gumath 0.2.0dev5 → 0.2.0dev8

Sign up to get free protection for your applications and to get access to all the features.
Files changed (99) hide show
  1. checksums.yaml +4 -4
  2. data/CONTRIBUTING.md +7 -2
  3. data/Gemfile +0 -3
  4. data/ext/ruby_gumath/GPATH +0 -0
  5. data/ext/ruby_gumath/GRTAGS +0 -0
  6. data/ext/ruby_gumath/GTAGS +0 -0
  7. data/ext/ruby_gumath/extconf.rb +0 -5
  8. data/ext/ruby_gumath/functions.c +10 -2
  9. data/ext/ruby_gumath/gufunc_object.c +15 -4
  10. data/ext/ruby_gumath/gufunc_object.h +9 -3
  11. data/ext/ruby_gumath/gumath/Makefile +63 -0
  12. data/ext/ruby_gumath/gumath/Makefile.in +1 -0
  13. data/ext/ruby_gumath/gumath/config.h +56 -0
  14. data/ext/ruby_gumath/gumath/config.h.in +3 -0
  15. data/ext/ruby_gumath/gumath/config.log +497 -0
  16. data/ext/ruby_gumath/gumath/config.status +1034 -0
  17. data/ext/ruby_gumath/gumath/configure +375 -4
  18. data/ext/ruby_gumath/gumath/configure.ac +47 -3
  19. data/ext/ruby_gumath/gumath/libgumath/Makefile +236 -0
  20. data/ext/ruby_gumath/gumath/libgumath/Makefile.in +90 -24
  21. data/ext/ruby_gumath/gumath/libgumath/Makefile.vc +54 -15
  22. data/ext/ruby_gumath/gumath/libgumath/apply.c +92 -28
  23. data/ext/ruby_gumath/gumath/libgumath/apply.o +0 -0
  24. data/ext/ruby_gumath/gumath/libgumath/common.o +0 -0
  25. data/ext/ruby_gumath/gumath/libgumath/cpu_device_binary.o +0 -0
  26. data/ext/ruby_gumath/gumath/libgumath/cpu_device_unary.o +0 -0
  27. data/ext/ruby_gumath/gumath/libgumath/cpu_host_binary.o +0 -0
  28. data/ext/ruby_gumath/gumath/libgumath/cpu_host_unary.o +0 -0
  29. data/ext/ruby_gumath/gumath/libgumath/examples.o +0 -0
  30. data/ext/ruby_gumath/gumath/libgumath/extending/graph.c +27 -20
  31. data/ext/ruby_gumath/gumath/libgumath/extending/pdist.c +1 -1
  32. data/ext/ruby_gumath/gumath/libgumath/func.c +13 -9
  33. data/ext/ruby_gumath/gumath/libgumath/func.o +0 -0
  34. data/ext/ruby_gumath/gumath/libgumath/graph.o +0 -0
  35. data/ext/ruby_gumath/gumath/libgumath/gumath.h +55 -14
  36. data/ext/ruby_gumath/gumath/libgumath/kernels/common.c +513 -0
  37. data/ext/ruby_gumath/gumath/libgumath/kernels/common.h +155 -0
  38. data/ext/ruby_gumath/gumath/libgumath/kernels/contrib/bfloat16.h +520 -0
  39. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_binary.cc +1123 -0
  40. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_binary.h +1062 -0
  41. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_msvc.cc +555 -0
  42. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_unary.cc +368 -0
  43. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_unary.h +335 -0
  44. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_host_binary.c +2952 -0
  45. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_host_unary.c +1100 -0
  46. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_binary.cu +1143 -0
  47. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_binary.h +1061 -0
  48. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_unary.cu +528 -0
  49. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_unary.h +463 -0
  50. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_host_binary.c +2817 -0
  51. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_host_unary.c +1331 -0
  52. data/ext/ruby_gumath/gumath/libgumath/kernels/device.hh +614 -0
  53. data/ext/ruby_gumath/gumath/libgumath/libgumath.a +0 -0
  54. data/ext/ruby_gumath/gumath/libgumath/libgumath.so +1 -0
  55. data/ext/ruby_gumath/gumath/libgumath/libgumath.so.0 +1 -0
  56. data/ext/ruby_gumath/gumath/libgumath/libgumath.so.0.2.0dev3 +0 -0
  57. data/ext/ruby_gumath/gumath/libgumath/nploops.o +0 -0
  58. data/ext/ruby_gumath/gumath/libgumath/pdist.o +0 -0
  59. data/ext/ruby_gumath/gumath/libgumath/quaternion.o +0 -0
  60. data/ext/ruby_gumath/gumath/libgumath/tbl.o +0 -0
  61. data/ext/ruby_gumath/gumath/libgumath/thread.c +17 -4
  62. data/ext/ruby_gumath/gumath/libgumath/thread.o +0 -0
  63. data/ext/ruby_gumath/gumath/libgumath/xndloops.c +110 -0
  64. data/ext/ruby_gumath/gumath/libgumath/xndloops.o +0 -0
  65. data/ext/ruby_gumath/gumath/python/gumath/__init__.py +150 -0
  66. data/ext/ruby_gumath/gumath/python/gumath/_gumath.c +446 -80
  67. data/ext/ruby_gumath/gumath/python/gumath/cuda.c +78 -0
  68. data/ext/ruby_gumath/gumath/python/gumath/examples.c +0 -5
  69. data/ext/ruby_gumath/gumath/python/gumath/functions.c +2 -2
  70. data/ext/ruby_gumath/gumath/python/gumath/gumath.h +246 -0
  71. data/ext/ruby_gumath/gumath/python/gumath/libgumath.a +0 -0
  72. data/ext/ruby_gumath/gumath/python/gumath/libgumath.so +1 -0
  73. data/ext/ruby_gumath/gumath/python/gumath/libgumath.so.0 +1 -0
  74. data/ext/ruby_gumath/gumath/python/gumath/libgumath.so.0.2.0dev3 +0 -0
  75. data/ext/ruby_gumath/gumath/python/gumath/pygumath.h +31 -2
  76. data/ext/ruby_gumath/gumath/python/gumath_aux.py +767 -0
  77. data/ext/ruby_gumath/gumath/python/randdec.py +535 -0
  78. data/ext/ruby_gumath/gumath/python/randfloat.py +177 -0
  79. data/ext/ruby_gumath/gumath/python/test_gumath.py +1504 -24
  80. data/ext/ruby_gumath/gumath/python/test_xndarray.py +462 -0
  81. data/ext/ruby_gumath/gumath/setup.py +67 -6
  82. data/ext/ruby_gumath/gumath/tools/detect_cuda_arch.cc +35 -0
  83. data/ext/ruby_gumath/include/gumath.h +55 -14
  84. data/ext/ruby_gumath/include/ruby_gumath.h +4 -1
  85. data/ext/ruby_gumath/lib/libgumath.a +0 -0
  86. data/ext/ruby_gumath/lib/libgumath.so.0.2.0dev3 +0 -0
  87. data/ext/ruby_gumath/ruby_gumath.c +231 -70
  88. data/ext/ruby_gumath/ruby_gumath.h +4 -1
  89. data/ext/ruby_gumath/ruby_gumath_internal.h +25 -0
  90. data/ext/ruby_gumath/util.c +34 -0
  91. data/ext/ruby_gumath/util.h +9 -0
  92. data/gumath.gemspec +3 -2
  93. data/lib/gumath.rb +55 -1
  94. data/lib/gumath/version.rb +2 -2
  95. data/lib/ruby_gumath.so +0 -0
  96. metadata +63 -10
  97. data/ext/ruby_gumath/gumath/libgumath/extending/bfloat16.c +0 -130
  98. data/ext/ruby_gumath/gumath/libgumath/kernels/binary.c +0 -547
  99. data/ext/ruby_gumath/gumath/libgumath/kernels/unary.c +0 -449
@@ -0,0 +1,177 @@
1
+ # Copyright (c) 2010 Python Software Foundation. All Rights Reserved.
2
+ # Adapted from Python's Lib/test/test_strtod.py (by Mark Dickinson)
3
+
4
+ # More test cases for deccheck.py.
5
+
6
+ import random
7
+
8
# Scale factor shared by the generators in this module: each one multiplies
# (or nests a loop over) its iteration count by TEST_SIZE.
TEST_SIZE = 2
9
+
10
+
11
def test_short_halfway_cases():
    """Yield exact halfway cases that have few significant digits.

    For each k in (0, 5, 10, 15, 20) a random odd n is drawn from
    [2**53/5**k, 2**54/5**k), so that n * 10**k is a halfway case.  Each n
    is then scaled by successive powers of 2, and separately by successive
    powers of 5 (lowering the decimal exponent to compensate), while it has
    at most 20 digits.  Every candidate is yielded twice: as is, and padded
    with 40 extra trailing zeros.
    """
    def variants(digits, exponent, factor, exponent_step):
        # Plain and zero-padded spellings of digits * factor**p * 10**exponent
        # for p = 0, 1, ... while the digit count stays below 10**20.
        while digits < 10**20:
            yield '{}e{}'.format(digits, exponent)
            yield '{}e{}'.format(digits * 10**40, exponent - 40)
            digits *= factor
            exponent += exponent_step

    for k in (0, 5, 10, 15, 20):
        hi = -(-2**54 // 5**k)          # ceil(2**54 / 5**k), exclusive bound
        lo = -(-2**53 // 5**k)          # ceil(2**53 / 5**k) ...
        if not lo % 2:
            lo += 1                     # ... rounded up to the next odd value
        for _ in range(10 * TEST_SIZE):
            n = random.randrange(lo, hi, 2)
            e = k
            # Strip any further factors of 5.
            while not n % 5:
                n //= 5
                e += 1
            assert n % 10 in (1, 3, 7, 9)

            # n * 2**p2 * 10**e for p2 = 0, 1, ...
            yield from variants(n, e, 2, 0)
            # n * 5**p5 * 10**(e - p5) for p5 = 0, 1, ...
            yield from variants(n, e, 5, -1)
53
+
54
def test_halfway_cases():
    """Yield decimal strings for round-half-to-even halfway cases.

    A random bit pattern for a finite, non-negative double is decoded into
    the form m * 2**e, half an ulp is added, and the exact value is written
    as a decimal string.
    """
    for _ in range(1000 * TEST_SIZE):
        bits = random.randrange(2047 * 2**52)

        # Split into biased exponent and stored mantissa bits.
        biased_exp, mantissa = divmod(bits, 2**52)
        if biased_exp:
            # Normal number: restore the implicit leading bit.
            mantissa += 2**52
            biased_exp -= 1
        exp = biased_exp - 1074

        # m*2**e + 2**(e-1) == (2*m + 1) * 2**(e-1)
        mantissa = 2 * mantissa + 1
        exp -= 1

        if exp >= 0:
            digits, exponent = mantissa << exp, 0
        else:
            # m * 2**e == (m * 5**-e) * 10**e
            digits, exponent = mantissa * 5**-exp, exp
        yield '{}e{}'.format(digits, exponent)
80
+
81
def test_boundaries():
    """Yield random strings within a few ulps of notable float boundaries.

    Each entry is (n, e, u): n * 10**e approximates the boundary value and
    u * 10**e is about 1 ulp there.  Strings are n plus an offset drawn from
    [-3u, 3u).  After each round, n and u gain a trailing zero while e drops
    by one, so the same value is revisited with ever more digits.
    """
    boundaries = [
        (10000000000000000000, -19, 1110),   # a power of 2 boundary (1.0)
        (17976931348623159077, 289, 1995),   # overflow boundary (2.**1024)
        (22250738585072013831, -327, 4941),  # normal/subnormal (2.**-1022)
        (0, -327, 4941),                     # zero
    ]
    for n, e, u in boundaries:
        for _round in range(1000):
            for _ in range(TEST_SIZE):
                offset = random.randrange(-3 * u, 3 * u)
                yield '{}e{}'.format(n + offset, e)
            n, u, e = 10 * n, 10 * u, e - 1
101
+
102
def test_underflow_boundary():
    """Yield values close to 2**-1075, the underflow boundary.

    Unlike test_boundaries, the random offset (drawn from [-1000, 1000))
    does not scale with the number of digits.
    """
    for exponent in range(-400, -320):
        base = 10 ** -exponent // 2 ** 1075
        for _ in range(TEST_SIZE):
            yield '{}e{}'.format(base + random.randrange(-1000, 1000), exponent)
112
+
113
def test_bigcomp():
    """Yield random decimal strings with up to 50 significant digits.

    For each digit count ndigs, a random integer below 10**ndigs is paired
    with a random decimal exponent from [-400, 400).
    """
    for ndigs in (5, 10, 14, 15, 16, 17, 18, 19, 20, 40, 41, 50):
        bound = 10 ** ndigs
        for _ in range(100 * TEST_SIZE):
            mantissa = random.randrange(bound)
            yield '{}e{}'.format(mantissa, random.randrange(-400, 400))
121
+
122
def test_parsing():
    """Yield short random strings in decimal float-literal syntax.

    Shape: sign? digits* ('.' digits*)? (('e'|'E') sign? digits{1,3})?
    A string is yielded only if it has at least one integer or fraction digit.
    """
    digit_pool = '000000123456789'  # weights '0' above the other digits
    signs = ('+', '-', '')
    coin = [True, False]

    def draw(count):
        return ''.join(random.choice(digit_pool) for _ in range(count))

    for _ in range(1000 * TEST_SIZE):
        pieces = [random.choice(signs)]
        int_len = random.randrange(5)
        pieces.append(draw(int_len))

        frac_len = 0
        if random.choice(coin):
            pieces.append('.')
            frac_len = random.randrange(5)
            pieces.append(draw(frac_len))

        if random.choice(coin):
            pieces.append(random.choice(['e', 'E']))
            pieces.append(random.choice(signs))
            pieces.append(draw(random.randrange(1, 4)))

        if int_len + frac_len:
            yield ''.join(pieces)
150
+
151
+
152
# Pre-generated corpora: one list per generator above, in this fixed order.
# un_randfloat, bin_randfloat and tern_randfloat sample from these lists.
# Building them runs every generator once at import time and so consumes
# the module-level ``random`` state in this order.
TESTCASES = [
    list(test_short_halfway_cases()),
    list(test_halfway_cases()),
    list(test_boundaries()),
    list(test_underflow_boundary()),
    list(test_bigcomp()),
    list(test_parsing()),
]
+ ]
160
+
161
def un_randfloat():
    """Yield two random operand strings for unary-operation tests.

    Each string comes from a randomly chosen corpus among the first six
    entries of TESTCASES.
    """
    for _ in range(2):
        corpus = random.choice(TESTCASES[:6])
        yield random.choice(corpus)
165
+
166
def bin_randfloat():
    """Yield two random (left, right) operand pairs for binary-operation tests.

    Each operand is picked from its own randomly chosen TESTCASES corpus.
    """
    for _ in range(2):
        left_corpus = random.choice(TESTCASES)
        right_corpus = random.choice(TESTCASES)
        left = random.choice(left_corpus)
        right = random.choice(right_corpus)
        yield left, right
171
+
172
def tern_randfloat():
    """Yield two random operand triples for ternary-operation tests.

    Each operand is picked from its own randomly chosen TESTCASES corpus.
    """
    for _ in range(2):
        corpora = [random.choice(TESTCASES) for _ in range(3)]
        first = random.choice(corpora[0])
        second = random.choice(corpora[1])
        third = random.choice(corpora[2])
        yield first, second, third
@@ -35,37 +35,59 @@ import gumath.functions as fn
35
35
  import gumath.examples as ex
36
36
  from xnd import xnd
37
37
  from ndtypes import ndt
38
- from extending import Graph, bfloat16
38
+ from extending import Graph
39
39
  import sys, time
40
+ import platform
40
41
  import math
42
+ import cmath
41
43
  import unittest
42
44
  import argparse
45
+ from gumath_aux import *
46
+
47
+ try:
48
+ import gumath.cuda as cd
49
+ except ImportError:
50
+ cd = None
43
51
 
44
52
try:
    import numpy as np
except ImportError:
    np = None
else:
    # Silence warnings for the whole test run (the same process-wide effect
    # the ``np.warnings`` spelling had: it was only NumPy re-exporting the
    # stdlib ``warnings`` module).  ``np.warnings`` no longer exists in
    # NumPy >= 1.24, and the resulting AttributeError is not caught by
    # ``except ImportError``, so the old spelling broke module import.
    import warnings
    warnings.filterwarnings('ignore')
48
57
 
58
# Module-wide switches; their readers are not visible in this hunk
# (presumably they gate long-running / brute-force tests — confirm).
SKIP_LONG = SKIP_BRUTE_FORCE = True
49
60
 
50
- TEST_CASES = [
51
- ([float(i)/100.0 for i in range(2000)], "2000 * float64", "float64"),
61
# Bit-architecture string of the Python executable: the first element of
# platform.architecture(), e.g. '64bit'.  Its readers are outside this view.
ARCH = platform.architecture()[0]
52
62
 
53
- ([[float(i)/100.0 for i in range(1000)], [float(i+1) for i in range(1000)]],
54
- "2 * 1000 * float64", "float64"),
55
63
 
56
- (1000 * [[float(i+1) for i in range(2)]], "1000 * 2 * float64", "float64"),
64
class TestAPI(unittest.TestCase):
    """Basic checks on the gufunc type exposed by the gumath module."""

    def test_api(self):
        gufunc = gm.gufunc
        self.assertIsInstance(fn.add, gufunc)
        # __new__ must reject both a missing and a non-type first argument.
        for bad_args in ((), (1,)):
            self.assertRaises(TypeError, gufunc.__new__, *bad_args)
65
71
 
66
72
 
67
73
  class TestCall(unittest.TestCase):
68
74
 
75
def test_subclass(self):
    """Results are plain xnd by default; cls= selects the result type."""
    class X(xnd):
        pass

    x, y = X([1, 2, 3]), X([1, 2, 3])

    for kwargs, expected_type in (({}, xnd), ({"cls": X}, X)):
        z = fn.multiply(x, y, **kwargs)
        self.assertEqual(z, [1, 4, 9])
        self.assertEqual(type(z), expected_type)
90
+
69
91
  def test_sin_scalar(self):
70
92
 
71
93
  x1 = xnd(1.2, type="float64")
@@ -213,6 +235,293 @@ class TestMissingValues(unittest.TestCase):
213
235
 
214
236
  self.assertEqual(ans.value, [{'valid': 2, 'missing': 1}, {'valid': 1, 'missing': 2}])
215
237
 
238
def test_unary(self):
    """sin maps missing (None) entries of a ?float64 array to None."""
    values = [0, None, 2]
    expected = xnd([None if v is None else math.sin(v) for v in values])
    result = fn.sin(xnd(values, dtype="?float64"))
    self.assertEqual(result.value, expected)
245
+
246
def test_binary(self):
    """multiply yields None wherever either operand is missing."""
    left = [3, None, 3]
    right = [100, 1, None]
    expected = xnd([None if p is None or q is None else p * q
                    for p, q in zip(left, right)])

    product = fn.multiply(xnd(left), xnd(right))
    self.assertEqual(product.value, expected)
256
+
257
def test_reduce(self):
    """gm.reduce over arrays containing missing values."""
    x = xnd([1, None, 2])
    # One missing element makes the whole reduction missing.
    for op in (fn.add, fn.multiply, fn.subtract):
        self.assertEqual(gm.reduce(op, x), None)

    # Reducing an empty ?int32 array with add gives 0.
    empty = xnd([], dtype="?int32")
    self.assertEqual(gm.reduce(fn.add, empty), 0)
274
+
275
@unittest.skipIf(cd is None, "test requires cuda")
def test_reduce_cuda(self):
    """gm.reduce with gumath.cuda kernels on cuda:managed arrays."""
    x = xnd([1, None, 2], device="cuda:managed")
    # One missing element makes the whole reduction missing.
    for op in (cd.add, cd.multiply):
        self.assertEqual(gm.reduce(op, x), None)

    # Use the CUDA kernel here too: fn.add is the kernel the CPU tests use
    # and would not exercise the CUDA path for the empty-array case.
    empty = xnd([], dtype="?int32", device="cuda:managed")
    self.assertEqual(gm.reduce(cd.add, empty), 0)
289
+
290
def test_comparisons(self):
    """Element-wise comparisons propagate missing values as None."""
    x = xnd([1, None, 3, 5])
    y = xnd([2, None, 3, 4])
    expectations = (
        (fn.equal,         [False, None, True, False]),
        (fn.not_equal,     [True, None, False, True]),
        (fn.less,          [True, None, False, False]),
        (fn.less_equal,    [True, None, True, False]),
        (fn.greater_equal, [False, None, True, True]),
        (fn.greater,       [False, None, False, True]),
    )
    for func, expected in expectations:
        self.assertEqual(func(x, y).value, expected)
314
+
315
@unittest.skipIf(cd is None, "test requires cuda")
def test_comparisons_cuda(self):
    """CUDA comparisons propagate missing values as None."""
    x = xnd([1, None, 3, 5], device="cuda:managed")
    y = xnd([2, None, 3, 4], device="cuda:managed")
    expectations = (
        (cd.equal,         [False, None, True, False]),
        (cd.not_equal,     [True, None, False, True]),
        (cd.less,          [True, None, False, False]),
        (cd.less_equal,    [True, None, True, False]),
        (cd.greater_equal, [False, None, True, True]),
        (cd.greater,       [False, None, False, True]),
    )
    for func, expected in expectations:
        self.assertEqual(func(x, y).value, expected)
340
+
341
def test_equaln(self):
    """equaln treats NA == NA as True and NA vs any value (incl. NaN) as False."""
    nan = float("nan")
    # (a, b, expected, whether to also check that the result dtype is bool)
    cases = [
        ([1, None, 3, 5], [2, None, 3, 4], [False, True, True, False], True),
        ([1, None, 3, 5], [2, 0, 3, 4], [False, False, True, False], True),
        ([None], [None], [True], False),    # NA eqn NA
        ([None], [0.0], [False], False),    # !(NA eqn 0)
        ([0.0], [None], [False], False),    # !(0 eqn NA)
        ([None], [nan], [False], False),    # !(NA eqn NaN)
        ([nan], [None], [False], False),    # !(NaN eqn NA)
    ]
    for a, b, expected, check_dtype in cases:
        x = xnd(a)
        y = xnd(b)
        z = fn.equaln(x, y)
        self.assertEqual(z, expected)
        if check_dtype:
            self.assertEqual(z.dtype, ndt("bool"))
399
+
400
@unittest.skipIf(cd is None, "test requires cuda")
def test_equaln_cuda(self):
    """CUDA equaln: NA == NA is True, NA vs any value (incl. NaN) is False."""
    nan = float("nan")
    # (a, b, expected, whether to also check that the result dtype is bool)
    cases = [
        ([1, None, 3, 5], [2, None, 3, 4], [False, True, True, False], True),
        ([1, None, 3, 5], [2, 0, 3, 4], [False, False, True, False], True),
        ([None], [None], [True], False),    # NA eqn NA
        ([None], [0.0], [False], False),    # !(NA eqn 0)
        ([0.0], [None], [False], False),    # !(0 eqn NA)
        ([None], [nan], [False], False),    # !(NA eqn NaN)
        ([nan], [None], [False], False),    # !(NaN eqn NA)
    ]
    for a, b, expected, check_dtype in cases:
        x = xnd(a, device="cuda:managed")
        y = xnd(b, device="cuda:managed")
        z = cd.equaln(x, y)
        self.assertEqual(z, expected)
        if check_dtype:
            self.assertEqual(z.dtype, ndt("bool"))
459
+
460
+
461
class TestEqualN(unittest.TestCase):
    """equaln: NaN compares equal to NaN, component-wise for complex values."""

    def test_nan_float(self):
        nan = float("nan")
        for dtype in ("bfloat16", "float32", "float64"):
            x = xnd([0, nan, 2], dtype=dtype)
            for other, expected in (([0, nan, 2], [True, True, True]),
                                    ([0, 1, 2], [True, False, True])):
                y = xnd(other, dtype=dtype)
                self.assertEqual(fn.equaln(x, y), expected)

    def test_nan_complex(self):
        nan = float("nan")
        cases = [
            (complex(nan, 1.2), complex(nan, 1.2), True),
            (complex(nan, 1.2), complex(nan, 1), False),
            (complex(nan, nan), complex(nan, 1.2), False),

            (complex(1.2, nan), complex(1.2, nan), True),
            (complex(1.2, nan), complex(1, nan), False),
            (complex(nan, nan), complex(1.2, nan), False),

            (complex(nan, nan), complex(nan, nan), True),
        ]
        for dtype in ("complex64", "complex128"):
            for a, b, expected in cases:
                x = xnd([0, a, 2], dtype=dtype)
                y = xnd([0, b, 2], dtype=dtype)
                self.assertEqual(fn.equaln(x, y), [True, expected, True])

    @unittest.skipIf(cd is None, "test requires cuda")
    def test_nan_float_cuda(self):
        nan = float("nan")
        for dtype in ("bfloat16", "float16", "float32", "float64"):
            x = xnd([0, nan, 2], dtype=dtype, device="cuda:managed")
            for other, expected in (([0, nan, 2], [True, True, True]),
                                    ([0, 1, 2], [True, False, True])):
                y = xnd(other, dtype=dtype, device="cuda:managed")
                self.assertEqual(cd.equaln(x, y), expected)

    @unittest.skipIf(cd is None, "test requires cuda")
    def test_nan_complex_cuda(self):
        nan = float("nan")
        cases = [
            (complex(nan, 1.2), complex(nan, 1.2), True),
            (complex(nan, 1.2), complex(nan, 1), False),
            (complex(nan, nan), complex(nan, 1.2), False),

            (complex(1.2, nan), complex(1.2, nan), True),
            (complex(1.2, nan), complex(1, nan), False),
            (complex(nan, nan), complex(1.2, nan), False),

            (complex(nan, nan), complex(nan, nan), True),
        ]
        for dtype in ("complex64", "complex128"):
            for a, b, expected in cases:
                x = xnd([0, a, 2], dtype=dtype, device="cuda:managed")
                y = xnd([0, b, 2], dtype=dtype, device="cuda:managed")
                self.assertEqual(cd.equaln(x, y), [True, expected, True])
524
+
216
525
 
217
526
  class TestRaggedArrays(unittest.TestCase):
218
527
 
@@ -237,8 +546,59 @@ class TestRaggedArrays(unittest.TestCase):
237
546
  self.assertEqual(y.value, ans)
238
547
 
239
548
 
549
class TestFlexibleArrays(unittest.TestCase):
    """Gufuncs on 'array * array * array * float64' (variable-shape) inputs."""

    def test_sin_var_compatible(self):
        lst = [[[1.0], [2.0, 3.0], [4.0, 5.0, 6.0]],
               [[7.0], [8.0, 9.0], [10.0, 11.0, 12.0]]]
        ans = [[[math.sin(v) for v in row] for row in block] for block in lst]

        x = xnd(lst, type="array * array * array * float64")
        self.assertEqual(fn.sin(x).value, ans)

    def test_add(self):
        a = [[[1.0], [2.0, 3.0], [4.0, 5.0, 6.0]],
             [[7.0], [8.0, 9.0], [10.0, 11.0, 12.0]]]
        b = [[[2.0], [3.0, 4.0], [5.0, 6.0, 7.0]],
             [[-8.0], [-9.0, -10.0], [111.1, 121.2, 25.3]]]
        # Element-wise sums, computed with the same nesting as the inputs.
        ans = [[[p + q for p, q in zip(row_a, row_b)]
                for row_a, row_b in zip(block_a, block_b)]
               for block_a, block_b in zip(a, b)]

        x = xnd(a, type="array * array * array * float64")
        y = xnd(b, type="array * array * array * float64")
        self.assertEqual(fn.add(x, y).value, ans)
597
+
598
+
240
599
  class TestGraphs(unittest.TestCase):
241
600
 
601
+ @unittest.skipIf(True, "abstract return types are temporarily disabled")
242
602
  def test_shortest_path(self):
243
603
  graphs = [[[(1, 1.2), (2, 4.4)],
244
604
  [(2, 2.2)],
@@ -274,17 +634,6 @@ class TestGraphs(unittest.TestCase):
274
634
  self.assertRaises(ValueError, Graph, lst)
275
635
 
276
636
 
277
- @unittest.skipIf(sys.platform == "win32", "unresolved external symbols")
278
- class TestBFloat16(unittest.TestCase):
279
-
280
- def test_init(self):
281
- lst = [1.2e10, 2.1121, -3e20]
282
- ans = [11945377792.0, 2.109375, -2.997595911977802e+20]
283
-
284
- x = bfloat16(lst)
285
- self.assertEqual(x.value, ans)
286
-
287
-
288
637
  class TestPdist(unittest.TestCase):
289
638
 
290
639
  def test_exceptions(self):
@@ -373,15 +722,1142 @@ class TestNumba(unittest.TestCase):
373
722
  np.testing.assert_equal(z, c)
374
723
 
375
724
 
725
class TestOut(unittest.TestCase):
    """Results written into caller-supplied arrays via the ``out`` keyword."""

    def test_api_cpu(self):
        # Single output: the call returns the very object passed as out=.
        x = xnd([1, 2, 3])
        y = xnd.empty("3 * int64")
        z = fn.negative(x, out=y)

        self.assertIs(z, y)
        self.assertEqual(y, xnd([-1, -2, -3]))

        # Two outputs: out= takes a tuple and each element is returned as-is.
        x = xnd([10, 20, 30])
        y = xnd([7, 8, 9])
        a = xnd.empty("3 * int64")
        b = xnd.empty("3 * int64")
        q, r = fn.divmod(x, y, out=(a, b))

        self.assertIs(q, a)
        self.assertIs(r, b)

        self.assertEqual(q, xnd([1, 2, 3]))
        self.assertEqual(r, xnd([3, 4, 3]))

    @unittest.skipIf(cd is None, "test requires cuda")
    def test_api_cuda(self):
        # Single output on cuda:managed memory.
        x = xnd([1, 2, 3], device="cuda:managed")
        y = xnd.empty("3 * int64", device="cuda:managed")
        z = cd.negative(x, out=y)

        self.assertIs(z, y)
        self.assertEqual(y, xnd([-1, -2, -3]))

        # Two outputs on cuda:managed memory.
        x = xnd([10, 20, 30], device="cuda:managed")
        y = xnd([7, 8, 9], device="cuda:managed")
        a = xnd.empty("3 * int64", device="cuda:managed")
        b = xnd.empty("3 * int64", device="cuda:managed")
        q, r = cd.divmod(x, y, out=(a, b))

        self.assertIs(q, a)
        self.assertIs(r, b)

        self.assertEqual(q, xnd([1, 2, 3]))
        self.assertEqual(r, xnd([3, 4, 3]))

    def test_broadcast_cpu(self):
        # A length-1 array operand broadcast against a length-3 one.
        x = xnd([1, 2, 3])
        y = xnd([2])
        z = xnd.empty("3 * int64")
        ans = fn.multiply(x, y, out=z)

        self.assertIs(ans, z)
        self.assertEqual(ans, xnd([2, 4, 6]))

        # A 0-d scalar operand broadcast against a length-3 one.
        x = xnd([1, 2, 3])
        y = xnd(2)
        z = xnd.empty("3 * int64")
        ans = fn.multiply(x, y, out=z)

        self.assertIs(ans, z)
        self.assertEqual(ans, xnd([2, 4, 6]))

        # Same two broadcasts with two outputs.
        x = xnd([10, 20, 30])
        y = xnd([3])
        a = xnd.empty("3 * int64")
        b = xnd.empty("3 * int64")
        q, r = fn.divmod(x, y, out=(a, b))

        self.assertIs(q, a)
        self.assertIs(r, b)
        self.assertEqual(q, xnd([3, 6, 10]))
        self.assertEqual(r, xnd([1, 2, 0]))

        x = xnd([10, 20, 30])
        y = xnd(3)
        a = xnd.empty("3 * int64")
        b = xnd.empty("3 * int64")
        q, r = fn.divmod(x, y, out=(a, b))

        self.assertIs(q, a)
        self.assertIs(r, b)
        self.assertEqual(q, xnd([3, 6, 10]))
        self.assertEqual(r, xnd([1, 2, 0]))

    @unittest.skipIf(cd is None, "test requires cuda")
    def test_broadcast_cuda(self):
        # Use the CUDA kernel (cd.multiply), as the other *_cuda tests do;
        # fn.multiply is the kernel the CPU tests use and would not exercise
        # the CUDA out= path.
        x = xnd([1, 2, 3], device="cuda:managed")
        y = xnd([2], device="cuda:managed")
        z = xnd.empty("3 * int64", device="cuda:managed")
        ans = cd.multiply(x, y, out=z)

        self.assertIs(ans, z)
        self.assertEqual(ans, xnd([2, 4, 6]))
+ self.assertEqual(ans, xnd([2, 4, 6]))
823
+
824
+
825
class TestUnaryCPU(unittest.TestCase):
    """Unary gufuncs from gumath.functions on the CPU."""

    def test_acos(self):
        inputs = [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
        expected = list(map(math.acos, inputs))
        self.assertEqual(fn.acos(xnd(inputs, dtype="float64")), expected)

    def test_acos_opt(self):
        inputs = [0, 0.1, 0.2, None, 0.4, 0.5, 0.6, None]
        expected = [None if v is None else math.acos(v) for v in inputs]
        self.assertEqual(fn.acos(xnd(inputs, dtype="?float64")), expected)

    def test_inexact_cast(self):
        # Per the test name, sin is expected to refuse int64 input (an
        # inexact cast) with ValueError.
        x = xnd(list(range(8)), dtype="int64")
        with self.assertRaises(ValueError):
            fn.sin(x)
847
+
848
+
849
@unittest.skipIf(cd is None, "test requires cuda")
class TestUnaryCUDA(unittest.TestCase):
    """Unary gufuncs from gumath.cuda on cuda:managed memory."""

    def test_cos(self):
        inputs = [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
        expected = list(map(math.cos, inputs))
        x = xnd(inputs, dtype="float64", device="cuda:managed")
        self.assertEqual(cd.cos(x), expected)

    def test_cos_opt(self):
        inputs = [0, 0.1, 0.2, None, 0.4, 0.5, 0.6, None]
        expected = [None if v is None else math.cos(v) for v in inputs]
        x = xnd(inputs, dtype="?float64", device="cuda:managed")
        self.assertEqual(cd.cos(x), expected)

    def test_inexact_cast(self):
        # Per the test name, sin is expected to refuse int64 input (an
        # inexact cast) with ValueError.
        x = xnd(list(range(8)), dtype="int64", device="cuda:managed")
        with self.assertRaises(ValueError):
            cd.sin(x)
872
+
873
+
874
class TestBinaryCPU(unittest.TestCase):
    """add/subtract/multiply over the default binary dtype pairs (CPU)."""

    @staticmethod
    def _pairs():
        """Yield the (t, u) keys of implemented_sigs["binary"]["default"],
        skipping pairs where either side reports cpu_noimpl()."""
        for t, u in implemented_sigs["binary"]["default"]:
            if t.cpu_noimpl() or u.cpu_noimpl():
                continue
            yield t, u

    def test_binary(self):
        for t, u in self._pairs():
            x = xnd([0, 1, 2, 3, 4, 5, 6, 7], dtype=t.type)
            y = xnd([1, 2, 3, 4, 5, 6, 7, 8], dtype=u.type)
            z = fn.add(x, y)
            self.assertEqual(z, [1, 3, 5, 7, 9, 11, 13, 15])

    def test_add_opt(self):
        for t, u in self._pairs():
            x = xnd([0, 1, None, 3, 4, 5, 6, 7], dtype="?" + t.type)
            y = xnd([1, 2, 3, 4, 5, 6, None, 8], dtype="?" + u.type)
            z = fn.add(x, y)
            self.assertEqual(z, [1, 3, None, 7, 9, 11, None, 15])

    def test_subtract(self):
        for t, u in self._pairs():
            x = xnd([2, 3, 4, 5, 6, 7, 8, 9], dtype=t.type)
            y = xnd([1, 2, 3, 4, 5, 6, 7, 8], dtype=u.type)
            z = fn.subtract(x, y)
            self.assertEqual(z, [1, 1, 1, 1, 1, 1, 1, 1])

    def test_multiply(self):
        for t, u in self._pairs():
            x = xnd([2, 3, 4, 5, 6, 7, 8, 9], dtype=t.type)
            y = xnd([1, 2, 3, 4, 5, 6, 7, 8], dtype=u.type)
            z = fn.multiply(x, y)
            self.assertEqual(z, [2, 6, 12, 20, 30, 42, 56, 72])
923
+
924
+
925
@unittest.skipIf(cd is None, "test requires cuda")
class TestBinaryCUDA(unittest.TestCase):
    """add/subtract/multiply over the default binary dtype pairs (CUDA)."""

    @staticmethod
    def _pairs():
        """Yield the (t, u) keys of implemented_sigs["binary"]["default"],
        skipping pairs where either side reports cuda_noimpl().

        Filtering on cuda_noimpl() (as TestBitwiseCUDA does) rather than
        cpu_noimpl() matches the backend these tests actually exercise.
        """
        for t, u in implemented_sigs["binary"]["default"]:
            if t.cuda_noimpl() or u.cuda_noimpl():
                continue
            yield t, u

    def test_binary(self):
        for t, u in self._pairs():
            x = xnd([0, 1, 2, 3, 4, 5, 6, 7], dtype=t.type, device="cuda:managed")
            y = xnd([1, 2, 3, 4, 5, 6, 7, 8], dtype=u.type, device="cuda:managed")
            z = cd.add(x, y)
            self.assertEqual(z, [1, 3, 5, 7, 9, 11, 13, 15])

    def test_add_opt(self):
        for t, u in self._pairs():
            x = xnd([0, 1, None, 3, 4, 5, 6, 7], dtype="?" + t.type, device="cuda:managed")
            y = xnd([1, 2, 3, 4, 5, 6, None, 8], dtype="?" + u.type, device="cuda:managed")
            z = cd.add(x, y)
            self.assertEqual(z, [1, 3, None, 7, 9, 11, None, 15])

    def test_subtract(self):
        for t, u in self._pairs():
            x = xnd([2, 3, 4, 5, 6, 7, 8, 9], dtype=t.type, device="cuda:managed")
            y = xnd([1, 2, 3, 4, 5, 6, 7, 8], dtype=u.type, device="cuda:managed")
            z = cd.subtract(x, y)
            self.assertEqual(z, [1, 1, 1, 1, 1, 1, 1, 1])

    def test_multiply(self):
        for t, u in self._pairs():
            x = xnd([2, 3, 4, 5, 6, 7, 8, 9], dtype=t.type, device="cuda:managed")
            y = xnd([1, 2, 3, 4, 5, 6, 7, 8], dtype=u.type, device="cuda:managed")
            z = cd.multiply(x, y)
            self.assertEqual(z, [2, 6, 12, 20, 30, 42, 56, 72])
975
+
976
+
977
class TestBitwiseCPU(unittest.TestCase):
    """fn.bitwise_and over every bitwise binary signature implemented on the CPU."""

    def test_and(self):
        for t, u in implemented_sigs["binary"]["bitwise"]:
            if t.cpu_noimpl() or u.cpu_noimpl():
                continue

            x = xnd([0, 1, 0, 1, 1, 1, 1, 0], dtype=t.type)
            y = xnd([1, 0, 0, 0, 1, 1, 1, 1], dtype=u.type)
            z = fn.bitwise_and(x, y)
            self.assertEqual(z, [0, 0, 0, 0, 1, 1, 1, 0])

    def test_and_opt(self):
        for t, u in implemented_sigs["binary"]["bitwise"]:
            if t.cpu_noimpl() or u.cpu_noimpl():
                continue

            # Optional ("?") dtypes: missing inputs propagate to the result.
            a = [0, 1, None, 1, 1, 1, 1, 0]
            b = [1, 1, 1, 1, 1, 1, None, 0]
            c = [0, 1, None, 1, 1, 1, None, 0]

            x = xnd(a, dtype="?" + t.type)
            y = xnd(b, dtype="?" + u.type)
            z = fn.bitwise_and(x, y)
            self.assertEqual(z, c)
1007
+
1008
+
1009
@unittest.skipIf(cd is None, "test requires cuda")
class TestBitwiseCUDA(unittest.TestCase):
    """cd.bitwise_and on ``cuda:managed`` operands over every bitwise
    binary signature the CUDA backend implements."""

    def test_and(self):
        for t, u in implemented_sigs["binary"]["bitwise"]:
            if t.cuda_noimpl() or u.cuda_noimpl():
                continue

            x = xnd([0, 1, 0, 1, 1, 1, 1, 0], dtype=t.type, device="cuda:managed")
            y = xnd([1, 0, 0, 0, 1, 1, 1, 1], dtype=u.type, device="cuda:managed")
            z = cd.bitwise_and(x, y)
            self.assertEqual(z, [0, 0, 0, 0, 1, 1, 1, 0])

    def test_and_opt(self):
        for t, u in implemented_sigs["binary"]["bitwise"]:
            if t.cuda_noimpl() or u.cuda_noimpl():
                continue

            # Optional ("?") dtypes: missing inputs propagate to the result.
            a = [0, 1, None, 1, 1, 1, 1, 0]
            b = [1, 1, 1, 1, 1, 1, None, 0]
            c = [0, 1, None, 1, 1, 1, None, 0]

            x = xnd(a, dtype="?" + t.type, device="cuda:managed")
            y = xnd(b, dtype="?" + u.type, device="cuda:managed")
            z = cd.bitwise_and(x, y)
            self.assertEqual(z, c)
1040
+
1041
+
1042
@unittest.skipIf(np is None, "test requires numpy")
class TestFunctions(unittest.TestCase):
    """Cross-check gumath kernels against NumPy for every implemented
    signature, on the CPU module ``fn`` and (when available) the CUDA
    module ``cd``.

    The exhaustive ``test_unary_*`` / ``test_binary_*`` methods only run
    when ``SKIP_LONG`` is false.
    """

    def assertRelErrorLess(self, calc, expected, maxerr, msg):
        """Assert that *calc* is within relative error *maxerr* of *expected*.

        NaN and infinite values are not compared.  If either value is below
        1e-5 in magnitude, both are required to be below 1e-5 instead,
        because a relative error is not meaningful near zero.
        """
        if cmath.isnan(calc) or cmath.isnan(expected):
            return
        elif cmath.isinf(calc) or cmath.isinf(expected):
            return
        elif abs(expected) < 1e-5 or abs(calc) < 1e-5:
            self.assertLess(abs(calc), 1e-5, msg)
            self.assertLess(abs(expected), 1e-5, msg)
        else:
            err = abs((calc-expected) / expected)
            self.assertLess(err, maxerr, msg)

    def equal(self, calc, expected, msg):
        """assertEqual that additionally accepts NaN == NaN."""
        if np.isnan(calc) and np.isnan(expected):
            return
        else:
            self.assertEqual(calc, expected, msg)

    def assert_equal(self, f, z1, z2, w, msg, a=None, b=None):
        """Compare the xnd result *z1* with the NumPy result *z2* of *f*.

        *w* is the result type.  *a* and *b* are the operands; they are only
        used to recompute complex ``power`` with Python's ``**``.
        bfloat16 results, math functions, ``power`` and float16/float32
        ``divide`` use a 1e-2 relative tolerance; integer ``power`` is not
        compared; everything else must match exactly (NaN matching NaN).
        """
        if w.type == "bfloat16":
            self.assertRelErrorLess(z1, z2, 1e-2, msg)
        elif f == "power" and w.type in ("int8", "int16", "int32", "int64"):
            pass # equal mod INTN_MAX
        elif f == "power" and isinstance(z1, complex):
            # multivalued function, compare against Python
            try:
                ans = complex(a) ** complex(b)
            except (ZeroDivisionError, OverflowError):
                pass
            else:
                msg = "%s ans=%s" % (msg, ans)
                self.assertRelErrorLess(z1.real, ans.real, 1e-2, msg)
                self.assertRelErrorLess(z1.imag, ans.imag, 1e-2, msg)
        elif isinstance(z1, complex):
            if f not in ("add", "subtract"):
                self.assertRelErrorLess(z1.real, z2.real, 1e-2, msg)
                self.assertRelErrorLess(z1.imag, z2.imag, 1e-2, msg)
            else:
                # equal() returns None, so these must be two statements;
                # chaining them with "and" would never check the imaginary
                # part.
                self.equal(z1.real, z2.real, msg)
                self.equal(z1.imag, z2.imag, msg)
        elif f in functions["unary"]["real_math"] or \
             f in functions["unary"]["real_math_with_half"] or \
             f in functions["unary"]["complex_math"] or \
             f in functions["unary"]["complex_math_with_half"] or \
             f == "power":
            self.assertRelErrorLess(z1, z2, 1e-2, msg)
        elif f == "divide" and w.type in ("float16", "float32"):
            self.assertRelErrorLess(z1, z2, 1e-2, msg)
        else:
            return self.equal(z1, z2, msg)

    def create_xnd(self, a, t, dev=None):
        """Return a one-element xnd ``[a]`` of type *t*, or None on overflow.

        Also asserts that xnd overflows exactly when struct.pack would
        (per struct_overflow()).
        """

        # Check that struct.pack(a) overflows iff xnd(a) overflows.
        overflow = struct_overflow(a, t)
        xnd_overflow = False
        try:
            x = xnd([a], dtype=t.type, device=dev)
        except OverflowError:
            xnd_overflow = True

        self.assertEqual(xnd_overflow, overflow)

        return None if xnd_overflow else x

    def check_unary_not_implemented(self, f, a, t, mod=fn, dev=None):
        """Assert that unary *f* raises NotImplementedError for ``[a] : t``."""

        x = self.create_xnd(a, t, dev)
        if x is None:
            return

        self.assertRaises(NotImplementedError, getattr(mod, f), x)

    def check_unary_type_error(self, f, a, t, mod=fn, dev=None):
        """Assert that unary *f* raises TypeError for ``[a] : t``."""

        x = self.create_xnd(a, t, dev)
        if x is None:
            return

        self.assertRaises(TypeError, getattr(mod, f), x)

    def check_unary(self, f, a, t, u, mod=fn, dev=None):
        """Apply unary *f* to ``[a] : t`` and compare with NumPy.

        The result must have type *u*.  bfloat16 inputs are compared via
        float32, since NumPy has no bfloat16.
        """

        x1 = self.create_xnd(a, t, dev)
        if x1 is None:
            return

        y1 = getattr(mod, f)(x1)
        self.assertEqual(str(y1[0].type), u.type)
        v1 = y1[0].value

        value = x1.value if t.type == "bfloat16" else a
        dtype = "float32" if t.type == "bfloat16" else t.type

        x2 = np.array([value], dtype=dtype)
        y2 = getattr(np, np_function(f))(x2)
        v2 = y2[0]

        msg = "%s(%s : %s) -> %s xnd: %s np: %s" % (f, a, t, u, y1, y2)
        self.assert_equal(f, v1, v2, u, msg)

    def check_binary_not_implemented(self, f, a, t, b, u, mod=fn, dev=None):
        """Assert that binary *f* raises NotImplementedError for
        ``[a] : t`` and ``[b] : u``."""

        x1 = self.create_xnd(a, t, dev)
        if x1 is None:
            return

        y1 = self.create_xnd(b, u, dev)
        if y1 is None:
            return

        self.assertRaises(NotImplementedError, getattr(mod, f), x1, y1)

    def check_binary_type_error(self, f, a, t, b, u, mod=fn, dev=None):
        """Assert that binary *f* raises TypeError for ``[a] : t`` and
        ``[b] : u``."""

        x1 = self.create_xnd(a, t, dev)
        if x1 is None:
            return

        y1 = self.create_xnd(b, u, dev)
        if y1 is None:
            return

        self.assertRaises(TypeError, getattr(mod, f), x1, y1)

    def check_binary(self, f, a, t, b, u, w, mod=fn, dev=None):
        """Apply binary *f* to ``[a] : t`` and ``[b] : u``; compare with NumPy.

        The result must have type *w*.  If either library raises, both must
        raise the same exception class, unless xnd raised
        NotImplementedError, which is tolerated.
        """

        x1 = self.create_xnd(a, t, dev)
        if x1 is None:
            return

        y1 = self.create_xnd(b, u, dev)
        if y1 is None:
            return

        xnd_exc = z1 = None
        try:
            z1 = getattr(mod, f)(x1, y1)
            self.assertEqual(str(z1[0].type), w.type)
            v1 = z1[0].value
        except Exception as e:
            xnd_exc = e.__class__

        dtype1 = "float32" if t.type == "bfloat16" else t.type
        dtype2 = "float32" if u.type == "bfloat16" else u.type
        value1 = x1.value if t.type == "bfloat16" else a
        value2 = y1.value if u.type == "bfloat16" else b

        x2 = np.array([value1], dtype=dtype1)
        y2 = np.array([value2], dtype=dtype2)

        np_exc = z2 = None
        try:
            z2 = getattr(np, f)(x2, y2)
            v2 = z2[0]
        except Exception as e:
            np_exc = e.__class__

        if xnd_exc or np_exc:
            if xnd_exc != NotImplementedError:
                self.assertEqual(xnd_exc, np_exc)
        else:
            msg = "%s(%s : %s, %s : %s) -> %s xnd: %s np: %s" % \
                  (f, a, t, b, u, w, z1, z2)
            self.assert_equal(f, v1, v2, w, msg, a=x1[0].value, b=y1[0].value)

    def check_binary_mv(self, f, a, t, b, u, v, w, mod=fn, dev=None):
        """Apply a two-output function *f* (e.g. divmod) and compare both
        outputs with NumPy.  The outputs must have types *v* and *w*."""

        x1 = self.create_xnd(a, t, dev)
        if x1 is None:
            return

        y1 = self.create_xnd(b, u, dev)
        if y1 is None:
            return

        c1, d1 = getattr(mod, f)(x1, y1)
        self.assertEqual(str(c1[0].type), v.type)
        self.assertEqual(str(d1[0].type), w.type)
        cv1 = c1[0].value
        dv1 = d1[0].value

        x2 = np.array([a], dtype=t.type)
        y2 = np.array([b], dtype=u.type)
        c2, d2 = getattr(np, f)(x2, y2)
        cv2 = c2[0]
        dv2 = d2[0]

        msg = "%s(%s : %s, %s : %s) -> %s, %s xnd: %s np: %s" % \
              (f, a, t, b, u, v, w, (cv1, dv1), (cv2, dv2))
        self.assert_equal(f, cv1, cv2, v, msg)
        # The second output is compared xnd-vs-NumPy with its own type w.
        self.assert_equal(f, dv1, dv2, w, msg)

    @unittest.skipIf(sys.platform == "darwin", "complex trigonometry errors too large")
    @unittest.skipIf(sys.platform == "win32" and ARCH == "32bit", "complex trigonometry errors too large")
    def test_unary_cpu(self):
        skip_if(SKIP_LONG, "use --long argument to enable these tests")

        print("\n", flush=True)

        for pattern, return_type in [
            ("default", "default"),
            ("complex_math", "float_result"),
            ("real_math", "float_result")]:

            for f in functions["unary"][pattern]:
                if np_noimpl(f):
                    continue

                print("testing %s ..." % f, flush=True)

                for t, in implemented_sigs["unary"][return_type]:
                    u = implemented_sigs["unary"][return_type][(t,)]

                    print(" %s -> %s" % (t, u), flush=True)

                    for a in t.testcases():
                        if t.cpu_noimpl(f) or u.cpu_noimpl(f):
                            self.check_unary_not_implemented(f, a, t)
                        else:
                            self.check_unary(f, a, t, u)

    def test_binary_cpu(self):
        skip_if(SKIP_LONG, "use --long argument to enable these tests")

        print("\n", flush=True)

        for pattern in "default", "float_result", "bool_result":
            for f in functions["binary"][pattern]:
                print("testing %s ..." % f, flush=True)

                for t, u in implemented_sigs["binary"][pattern]:
                    w = implemented_sigs["binary"][pattern][(t, u)]

                    print(" %s, %s -> %s" % (t, u, w), flush=True)

                    for a in t.testcases():
                        for b in u.testcases():
                            if t.cpu_nokern(f) or u.cpu_nokern(f) or w.cpu_nokern(f):
                                self.check_binary_type_error(f, a, t, b, u)
                            elif t.cpu_noimpl(f) or u.cpu_noimpl(f) or w.cpu_noimpl(f):
                                self.check_binary_not_implemented(f, a, t, b, u)
                            else:
                                self.check_binary(f, a, t, b, u, w)

    def test_binary_mv_cpu(self):
        skip_if(SKIP_LONG, "use --long argument to enable these tests")

        print("\n", flush=True)

        for f in functions["binary_mv"]["default"]:
            print("testing %s ..." % f, flush=True)

            for t, u in implemented_sigs["binary_mv"]["default"]:
                v, w = implemented_sigs["binary_mv"]["default"][(t, u)]

                print(" %s, %s -> %s, %s" % (t, u, v, w), flush=True)

                for a in t.testcases():
                    for b in u.testcases():
                        self.check_binary_mv(f, a, t, b, u, v, w)

    @unittest.skipIf(cd is None, "test requires cuda")
    def test_unary_cuda(self):
        skip_if(SKIP_LONG, "use --long argument to enable these tests")

        print("\n", flush=True)

        for pattern, return_type in [
            ("default", "default"),
            ("complex_math_with_half", "float_result"),
            ("complex_math", "float_result"),
            ("real_math_with_half", "float_result"),
            ("real_math", "float_result")]:

            for f in functions["unary"][pattern]:
                if np_noimpl(f):
                    continue

                print("testing %s ..." % f, flush=True)

                for t, in implemented_sigs["unary"][return_type]:
                    u = implemented_sigs["unary"][return_type][(t,)]

                    print(" %s -> %s" % (t, u), flush=True)

                    for a in t.testcases():
                        if t.cuda_noimpl(f) or u.cuda_noimpl(f):
                            self.check_unary_not_implemented(
                                f, a, t, mod=cd, dev="cuda:managed")
                        else:
                            self.check_unary(f, a, t, u,
                                             mod=cd, dev="cuda:managed")

    @unittest.skipIf(cd is None, "test requires cuda")
    def test_binary_cuda(self):
        skip_if(SKIP_LONG, "use --long argument to enable these tests")

        print("\n", flush=True)

        for pattern in "default", "float_result", "bool_result":
            for f in functions["binary"][pattern]:
                print("testing %s ..." % f, flush=True)

                for t, u in implemented_sigs["binary"][pattern]:
                    w = implemented_sigs["binary"][pattern][(t, u)]

                    print(" %s, %s -> %s" % (t, u, w), flush=True)

                    for a in t.testcases():
                        for b in u.testcases():
                            if t.cuda_nokern(f) or u.cuda_nokern(f) or w.cuda_nokern(f):
                                self.check_binary_type_error(f, a, t, b, u,
                                    mod=cd, dev="cuda:managed")
                            elif t.type == "complex32" or u.type == "complex32" or w.cuda_noimpl(f):
                                self.check_binary_not_implemented(f, a, t, b, u,
                                    mod=cd, dev="cuda:managed")
                            else:
                                self.check_binary(f, a, t, b, u, w,
                                    mod=cd, dev="cuda:managed")

    @unittest.skipIf(cd is None, "test requires cuda")
    def test_binary_mv_cuda(self):
        skip_if(SKIP_LONG, "use --long argument to enable these tests")

        print("\n", flush=True)

        for f in functions["binary_mv"]["default"]:
            print("testing %s ..." % f, flush=True)

            for t, u in implemented_sigs["binary_mv"]["default"]:
                v, w = implemented_sigs["binary_mv"]["default"][(t, u)]

                print(" %s, %s -> %s, %s" % (t, u, v, w), flush=True)

                for a in t.testcases():
                    for b in u.testcases():
                        self.check_binary_mv(f, a, t, b, u, v, w, mod=cd,
                                             dev="cuda:managed")

    def test_divide_inexact_cpu(self):

        t = Tint("uint8")
        u = Tint("uint64")

        a = next(t.testcases())
        b = next(u.testcases())
        self.check_binary_type_error("divide", a, t, b, u)

    @unittest.skipIf(cd is None, "test requires cuda")
    def test_divide_inexact_cuda(self):

        t = Tint("uint8")
        u = Tint("uint64")

        a = next(t.testcases())
        b = next(u.testcases())
        self.check_binary_type_error("divide", a, t, b, u,
                                     mod=cd, dev="cuda:managed")

    def test_divmod_type_error_cpu(self):

        t = Tint("uint8")
        u = Tint("uint64")

        a = next(t.testcases())
        b = next(u.testcases())
        self.check_binary_type_error("divmod", a, t, b, u)

    @unittest.skipIf(cd is None, "test requires cuda")
    def test_divmod_type_error_cuda(self):

        t = Tint("uint8")
        u = Tint("uint64")

        a = next(t.testcases())
        b = next(u.testcases())
        # Target the CUDA module on managed memory; without these arguments
        # this would just repeat the CPU test.
        self.check_binary_type_error("divmod", a, t, b, u,
                                     mod=cd, dev="cuda:managed")
1425
+
1426
+
1427
@unittest.skipIf(cd is None, "test requires cuda")
class TestCudaManaged(unittest.TestCase):
    """Mixing host and ``cuda:managed`` operands across the CPU (``fn``)
    and CUDA (``cd``) kernel modules."""

    def test_mixed_functions(self):
        host_lhs = xnd([1,2,3])
        host_rhs = xnd([1,2,3])

        managed_lhs = xnd([1,2,3], device="cuda:managed")
        managed_rhs = xnd([1,2,3], device="cuda:managed")

        host_product = fn.multiply(host_lhs, host_rhs)
        expected = cd.multiply(managed_lhs, managed_rhs)
        self.assertEqual(host_product, expected)

        # The CPU kernels accept managed operands, alone or mixed with host ones.
        mixed_pairs = ((managed_lhs, managed_rhs),
                       (host_lhs, managed_rhs),
                       (managed_lhs, host_rhs))
        for lhs, rhs in mixed_pairs:
            self.assertEqual(fn.multiply(lhs, rhs), expected)

        # The CUDA kernels reject any operand that lives in host memory.
        rejected_pairs = ((host_lhs, host_rhs),
                          (host_lhs, managed_rhs),
                          (managed_lhs, host_rhs))
        for lhs, rhs in rejected_pairs:
            self.assertRaises(ValueError, cd.multiply, lhs, rhs)
1454
+
1455
+
1456
class TestSpec(unittest.TestCase):
    """Differential indexing test.

    Applies the same indices to an xnd container and to a reference array
    built by ``ndarray``.  It then compares ``sin`` and ``multiply`` from
    ``mod`` against the reference, plus reductions when ``mod`` is the CPU
    module.  Instances are configured with keyword-only arguments and driven
    by calling :meth:`run`.
    """

    def __init__(self, *, constr, ndarray, mod,
                 values, value_generator,
                 indices_generator, indices_generator_args):
        super().__init__()
        self.constr = constr
        self.ndarray = ndarray
        self.mod = mod
        self.values = values
        self.value_generator = value_generator
        self.indices_generator = indices_generator
        self.indices_generator_args = indices_generator_args
        # Index used at each recursion depth; replayed by log_err().
        self.indices_stack = [None] * 8

    def log_err(self, value, depth):
        """Dump an error as a Python script for debugging.

        Writes to stderr a script that rebuilds *value* and replays the
        indices recorded in ``indices_stack`` for depths 0..*depth*.
        """
        dtype = "?int32" if have_none(value) else "int32"

        sys.stderr.write("\n\nfrom xnd import *\n")
        sys.stderr.write("import gumath.functions as fn\n")
        sys.stderr.write("from test_gumath import NDArray\n")
        sys.stderr.write("lst = %s\n\n" % value)
        sys.stderr.write("x0 = xnd(lst, dtype=\"%s\")\n" % dtype)
        # Plain literal: it has no format placeholders, and applying "%"
        # to it would raise TypeError and mask the original failure.
        sys.stderr.write("y0 = NDArray(lst)\n")

        for i in range(depth+1):
            sys.stderr.write("x%d = x%d[%s]\n" % (i+1, i, itos(self.indices_stack[i])))
            sys.stderr.write("y%d = y%d[%s]\n" % (i+1, i, itos(self.indices_stack[i])))

        sys.stderr.write("\n")

    def run_reduce(self, nd, d):
        """Compare gm.reduce with NumPy's ufunc.reduce for add, subtract and
        multiply.

        Reduces once without an axes argument, then once for each axes value
        from gen_axes(d.ndim).  Either both libraries raise the same
        exception class, or the values must be equal.  Does nothing unless
        *nd* is an xnd and *d* is a NumPy array.
        """
        if not isinstance(nd, xnd) or not isinstance(d, np.ndarray):
            return

        for attr in ["add", "subtract", "multiply"]:
            f = getattr(fn, attr)
            g = getattr(np, attr)

            x = nd_exception = None
            try:
                x = gm.reduce(f, nd, dtype=nd.dtype)
            except Exception as e:
                nd_exception = e

            y = np_exception = None
            try:
                y = g.reduce(d, dtype=d.dtype)
            except Exception as e:
                np_exception = e

            if nd_exception or np_exception:
                self.assertIs(nd_exception.__class__, np_exception.__class__,
                              "f: %r nd: %r np: %r x: %r y: %r" % (attr, nd, d, x, y))
            else:
                self.assertEqual(x.value, y.tolist(),
                                 "f: %r nd: %r np: %r x: %r y: %r" % (attr, nd, d, x, y))

            for axes in gen_axes(d.ndim):
                # Reset the results as well, so that failure messages never
                # show values left over from a previous reduction.
                x = nd_exception = None
                try:
                    x = gm.reduce(f, nd, axes=axes, dtype=nd.dtype)
                except Exception as e:
                    nd_exception = e

                y = np_exception = None
                try:
                    y = g.reduce(d, axis=axes, dtype=d.dtype)
                except Exception as e:
                    np_exception = e

                if nd_exception or np_exception:
                    self.assertIs(nd_exception.__class__, np_exception.__class__,
                                  "f: %r axes: %r nd: %r np: %r x: %r y: %r" % (attr, axes, nd, d, x, y))
                else:
                    self.assertEqual(x.value, y.tolist(),
                                     "f: %r axes: %r nd: %r np: %r x: %r y: %r" % (attr, axes, nd, d, x, y))

    def run_single(self, nd, d, indices):
        """Run a single test case.

        Indexes *nd* and *d* with *indices*.  If either side raises, both
        must raise the same exception class and ``(None, None)`` is
        returned.  The one exception is an IndexError from the reference
        alone when ``len(indices) <= nd.ndim``, which is accepted.
        Otherwise ``sin`` and ``multiply`` from ``self.mod`` are applied to
        the xnd result and compared with the reference, and the pair of
        indexed results is returned.
        """

        self.assertEqual(len(nd), len(d))

        nd_exception = None
        try:
            nd_result = nd[indices]
        except Exception as e:
            nd_exception = e

        def_exception = None
        try:
            def_result = d[indices]
        except Exception as e:
            def_exception = e

        if nd_exception or def_exception:
            if nd_exception is None and def_exception.__class__ is IndexError:
                # Example: type = 0 * 0 * int64
                if len(indices) <= nd.ndim:
                    return None, None

            self.assertIs(nd_exception.__class__, def_exception.__class__)
            return None, None

        assert(isinstance(nd_result, xnd))

        x = self.mod.sin(nd_result)
        y = self.mod.multiply(nd_result, nd_result)

        # aa/bb are the reference results in the form used by the CUDA
        # comparison below; a/b are the forms compared exactly otherwise.
        if isinstance(def_result, NDArray):
            aa = a = def_result.sin()
            bb = b = def_result * def_result
        elif isinstance(def_result, int):
            aa = a = math.sin(def_result)
            bb = b = def_result * def_result
        elif def_result is None:
            aa = a = None
            bb = b = None
        elif isinstance(def_result, (np.ndarray, np.int32)):
            aa = np.sin(def_result)
            a = aa.tolist()
            bb = np.multiply(def_result, def_result)
            b = bb.tolist()
        else:
            raise TypeError("unexpected def_result: %s : %s" % (def_result, type(def_result)))

        if self.mod == cd:
            # CUDA results are compared with rtol=1e-6 rather than exactly.
            np.testing.assert_allclose(x, aa, 1e-6)
            np.testing.assert_allclose(y, bb, 1e-6)
        else:
            self.assertEqual(x, a)
            self.assertEqual(y, b)

        if self.mod == fn:
            self.run_reduce(nd_result, def_result)

        return nd_result, def_result

    def run(self):
        """Drive run_single() over every configured value.

        Overrides TestCase.run().  For each explicit value in
        ``self.values`` and each value from ``self.value_generator``, an xnd
        container and a reference array are built and indexed recursively,
        descending into results that are lists until the depth limit is
        reached.  Generated values are additionally re-checked through the
        buffer protocol (xnd.from_buffer / np.array) when the reference is
        a host NumPy array.  On failure, a reproducer script is written to
        stderr and the exception is re-raised.
        """
        def check(nd, d, value, depth):
            if depth > 3: # adjust for longer tests
                return

            g = self.indices_generator(*self.indices_generator_args)

            for indices in g:
                self.indices_stack[depth] = indices

                try:
                    next_nd, next_d = self.run_single(nd, d, indices)
                except Exception as e:
                    self.log_err(value, depth)
                    raise e

                if isinstance(next_d, list): # possibly None or scalar
                    check(next_nd, next_d, value, depth+1)

        def check_buffer(nd, d, value, depth):
            if depth > 3: # adjust for longer tests
                return
            if not isinstance(nd, xnd) or nd.device == "cuda:managed" or \
               not isinstance(d, np.ndarray):
                return

            nd = xnd.from_buffer(d)
            d = np.array(nd, copy=False)

            g = self.indices_generator(*self.indices_generator_args)

            for indices in g:
                self.indices_stack[depth] = indices

                try:
                    next_nd, next_d = self.run_single(nd, d, indices)
                except Exception as e:
                    self.log_err(value, depth)
                    raise e

                if isinstance(next_d, list): # possibly None or scalar
                    check_buffer(next_nd, next_d, value, depth+1)

        for value in self.values:
            dtype = "?int32" if have_none(value) else "int32"
            if self.constr == xnd:
                nd = xnd(value, dtype=dtype, device=None if self.mod==fn else "cuda:managed")
            else:
                nd = self.constr(value, dtype=dtype)
            # NumPy does not support "?int32", NDArray does not need the dtype.
            d = self.ndarray(value, dtype="int32")
            check(nd, d, value, 0)

        for max_ndim in range(1, 5):
            for min_shape in (0, 1):
                for max_shape in range(1, 8):
                    for value in self.value_generator(max_ndim, min_shape, max_shape):
                        dtype = "?int32" if have_none(value) else "int32"
                        if self.constr == xnd:
                            nd = xnd(value, dtype=dtype, device=None if self.mod==fn else "cuda:managed")
                        else:
                            nd = self.constr(value, dtype=dtype)
                        # See above.
                        d = self.ndarray(value, dtype="int32")
                        check(nd, d, value, 0)
                        check_buffer(nd, d, value, 0)
1666
+
1667
+
1668
class LongIndexSliceTest(unittest.TestCase):
    """Long-running indexing and slicing comparisons built on TestSpec.

    Every test is skipped unless enabled via the module-level flags
    (SKIP_LONG, or SKIP_BRUTE_FORCE for the brute-force variant).
    """

    def _run_spec(self, ndarray, mod, values, value_generator,
                  indices_generator, indices_generator_args):
        """Build a TestSpec over xnd with the given settings and run it."""
        spec = TestSpec(constr=xnd,
                        ndarray=ndarray,
                        mod=mod,
                        values=values,
                        value_generator=value_generator,
                        indices_generator=indices_generator,
                        indices_generator_args=indices_generator_args)
        spec.run()

    def _run_fixed(self, ndarray, mod, indices_generator, indices_generator_args):
        """Run TestSpec over the fixed-dimension test values."""
        self._run_spec(ndarray, mod, SUBSCRIPT_FIXED_TEST_CASES, gen_fixed,
                       indices_generator, indices_generator_args)

    def _run_var(self, ndarray, mod, indices_generator, indices_generator_args):
        """Run TestSpec over the var-dimension test values."""
        self._run_spec(ndarray, mod, SUBSCRIPT_VAR_TEST_CASES, gen_var,
                       indices_generator, indices_generator_args)

    def test_subarray(self):
        # Multidimensional indexing
        skip_if(SKIP_LONG, "use --long argument to enable these tests")
        self._run_fixed(NDArray, fn, genindices, ())
        self._run_var(NDArray, fn, genindices, ())

    @unittest.skipIf(cd is None or np is None, "cuda or numpy not found")
    def test_subarray_cuda(self):
        # Multidimensional indexing
        skip_if(SKIP_LONG, "use --long argument to enable these tests")
        self._run_fixed(np.array, cd, genindices, ())

    def test_slices(self):
        # Multidimensional slicing
        skip_if(SKIP_LONG, "use --long argument to enable these tests")
        self._run_fixed(NDArray, fn, randslices, (3,))
        self._run_var(NDArray, fn, randslices, (3,))

    @unittest.skipIf(cd is None or np is None, "cuda or numpy not found")
    def test_slices_cuda(self):
        # Multidimensional slicing
        skip_if(SKIP_LONG, "use --long argument to enable these tests")
        self._run_fixed(np.array, cd, randslices, (3,))

    def test_chained_indices_slices(self):
        # Multidimensional indexing and slicing, chained
        skip_if(SKIP_LONG, "use --long argument to enable these tests")
        self._run_fixed(NDArray, fn, gen_indices_or_slices, ())
        self._run_var(NDArray, fn, gen_indices_or_slices, ())

    def test_fixed_mixed_indices_slices(self):
        # Multidimensional indexing and slicing, mixed
        skip_if(SKIP_LONG, "use --long argument to enable these tests")
        self._run_fixed(NDArray, fn, mixed_indices, (3,))

    def test_var_mixed_indices_slices(self):
        # Multidimensional indexing and slicing, mixed
        skip_if(SKIP_LONG, "use --long argument to enable these tests")
        self._run_var(NDArray, fn, mixed_indices, (5,))

    def test_slices_brute_force(self):
        # Test all possible slices for the given ndim and shape
        skip_if(SKIP_BRUTE_FORCE, "use --all argument to enable these tests")
        self._run_fixed(NDArray, fn, genslices_ndim, (3, [3,3,3]))
        self._run_var(NDArray, fn, genslices_ndim, (3, [3,3,3]))

    @unittest.skipIf(np is None, "numpy not found")
    def test_fixed_mixed_indices_slices_np(self):
        # Multidimensional indexing and slicing, mixed
        skip_if(SKIP_LONG, "use --long argument to enable these tests")
        self._run_fixed(np.array, fn, mixed_indices, (3,))

    @unittest.skipIf(np is None, "numpy not found")
    def test_reduce(self):
        skip_if(SKIP_LONG, "use --long argument to enable these tests")
        self._run_fixed(np.array, fn, mixed_indices, (3,))
1839
+
376
1840
 
# TestCase classes for the script entry point; presumably loaded into the
# suite by the __main__ block's TestLoader (TODO confirm).
ALL_TESTS = [
    TestAPI,
    TestCall,
    TestRaggedArrays,
    TestFlexibleArrays,
    TestMissingValues,
    TestEqualN,
    TestGraphs,
    TestPdist,
    TestNumba,
    TestOut,
    TestUnaryCPU,
    TestUnaryCUDA,
    TestBinaryCPU,
    TestBinaryCUDA,
    TestBitwiseCPU,
    TestBitwiseCUDA,
    TestFunctions,
    TestCudaManaged,
    LongIndexSliceTest,
]
386
1862
 
387
1863
 
@@ -389,7 +1865,11 @@ if __name__ == '__main__':
389
1865
  parser = argparse.ArgumentParser()
390
1866
  parser.add_argument("-f", "--failfast", action="store_true",
391
1867
  help="stop the test run on first error")
1868
+ parser.add_argument('--long', action="store_true", help="run long slice tests")
1869
+ parser.add_argument('--all', action="store_true", help="run brute force tests")
392
1870
  args = parser.parse_args()
1871
+ SKIP_LONG = not (args.long or args.all)
1872
+ SKIP_BRUTE_FORCE = not args.all
393
1873
 
394
1874
  suite = unittest.TestSuite()
395
1875
  loader = unittest.TestLoader()