gumath 0.2.0dev5 → 0.2.0dev8

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (99)
  1. checksums.yaml +4 -4
  2. data/CONTRIBUTING.md +7 -2
  3. data/Gemfile +0 -3
  4. data/ext/ruby_gumath/GPATH +0 -0
  5. data/ext/ruby_gumath/GRTAGS +0 -0
  6. data/ext/ruby_gumath/GTAGS +0 -0
  7. data/ext/ruby_gumath/extconf.rb +0 -5
  8. data/ext/ruby_gumath/functions.c +10 -2
  9. data/ext/ruby_gumath/gufunc_object.c +15 -4
  10. data/ext/ruby_gumath/gufunc_object.h +9 -3
  11. data/ext/ruby_gumath/gumath/Makefile +63 -0
  12. data/ext/ruby_gumath/gumath/Makefile.in +1 -0
  13. data/ext/ruby_gumath/gumath/config.h +56 -0
  14. data/ext/ruby_gumath/gumath/config.h.in +3 -0
  15. data/ext/ruby_gumath/gumath/config.log +497 -0
  16. data/ext/ruby_gumath/gumath/config.status +1034 -0
  17. data/ext/ruby_gumath/gumath/configure +375 -4
  18. data/ext/ruby_gumath/gumath/configure.ac +47 -3
  19. data/ext/ruby_gumath/gumath/libgumath/Makefile +236 -0
  20. data/ext/ruby_gumath/gumath/libgumath/Makefile.in +90 -24
  21. data/ext/ruby_gumath/gumath/libgumath/Makefile.vc +54 -15
  22. data/ext/ruby_gumath/gumath/libgumath/apply.c +92 -28
  23. data/ext/ruby_gumath/gumath/libgumath/apply.o +0 -0
  24. data/ext/ruby_gumath/gumath/libgumath/common.o +0 -0
  25. data/ext/ruby_gumath/gumath/libgumath/cpu_device_binary.o +0 -0
  26. data/ext/ruby_gumath/gumath/libgumath/cpu_device_unary.o +0 -0
  27. data/ext/ruby_gumath/gumath/libgumath/cpu_host_binary.o +0 -0
  28. data/ext/ruby_gumath/gumath/libgumath/cpu_host_unary.o +0 -0
  29. data/ext/ruby_gumath/gumath/libgumath/examples.o +0 -0
  30. data/ext/ruby_gumath/gumath/libgumath/extending/graph.c +27 -20
  31. data/ext/ruby_gumath/gumath/libgumath/extending/pdist.c +1 -1
  32. data/ext/ruby_gumath/gumath/libgumath/func.c +13 -9
  33. data/ext/ruby_gumath/gumath/libgumath/func.o +0 -0
  34. data/ext/ruby_gumath/gumath/libgumath/graph.o +0 -0
  35. data/ext/ruby_gumath/gumath/libgumath/gumath.h +55 -14
  36. data/ext/ruby_gumath/gumath/libgumath/kernels/common.c +513 -0
  37. data/ext/ruby_gumath/gumath/libgumath/kernels/common.h +155 -0
  38. data/ext/ruby_gumath/gumath/libgumath/kernels/contrib/bfloat16.h +520 -0
  39. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_binary.cc +1123 -0
  40. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_binary.h +1062 -0
  41. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_msvc.cc +555 -0
  42. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_unary.cc +368 -0
  43. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_unary.h +335 -0
  44. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_host_binary.c +2952 -0
  45. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_host_unary.c +1100 -0
  46. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_binary.cu +1143 -0
  47. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_binary.h +1061 -0
  48. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_unary.cu +528 -0
  49. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_unary.h +463 -0
  50. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_host_binary.c +2817 -0
  51. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_host_unary.c +1331 -0
  52. data/ext/ruby_gumath/gumath/libgumath/kernels/device.hh +614 -0
  53. data/ext/ruby_gumath/gumath/libgumath/libgumath.a +0 -0
  54. data/ext/ruby_gumath/gumath/libgumath/libgumath.so +1 -0
  55. data/ext/ruby_gumath/gumath/libgumath/libgumath.so.0 +1 -0
  56. data/ext/ruby_gumath/gumath/libgumath/libgumath.so.0.2.0dev3 +0 -0
  57. data/ext/ruby_gumath/gumath/libgumath/nploops.o +0 -0
  58. data/ext/ruby_gumath/gumath/libgumath/pdist.o +0 -0
  59. data/ext/ruby_gumath/gumath/libgumath/quaternion.o +0 -0
  60. data/ext/ruby_gumath/gumath/libgumath/tbl.o +0 -0
  61. data/ext/ruby_gumath/gumath/libgumath/thread.c +17 -4
  62. data/ext/ruby_gumath/gumath/libgumath/thread.o +0 -0
  63. data/ext/ruby_gumath/gumath/libgumath/xndloops.c +110 -0
  64. data/ext/ruby_gumath/gumath/libgumath/xndloops.o +0 -0
  65. data/ext/ruby_gumath/gumath/python/gumath/__init__.py +150 -0
  66. data/ext/ruby_gumath/gumath/python/gumath/_gumath.c +446 -80
  67. data/ext/ruby_gumath/gumath/python/gumath/cuda.c +78 -0
  68. data/ext/ruby_gumath/gumath/python/gumath/examples.c +0 -5
  69. data/ext/ruby_gumath/gumath/python/gumath/functions.c +2 -2
  70. data/ext/ruby_gumath/gumath/python/gumath/gumath.h +246 -0
  71. data/ext/ruby_gumath/gumath/python/gumath/libgumath.a +0 -0
  72. data/ext/ruby_gumath/gumath/python/gumath/libgumath.so +1 -0
  73. data/ext/ruby_gumath/gumath/python/gumath/libgumath.so.0 +1 -0
  74. data/ext/ruby_gumath/gumath/python/gumath/libgumath.so.0.2.0dev3 +0 -0
  75. data/ext/ruby_gumath/gumath/python/gumath/pygumath.h +31 -2
  76. data/ext/ruby_gumath/gumath/python/gumath_aux.py +767 -0
  77. data/ext/ruby_gumath/gumath/python/randdec.py +535 -0
  78. data/ext/ruby_gumath/gumath/python/randfloat.py +177 -0
  79. data/ext/ruby_gumath/gumath/python/test_gumath.py +1504 -24
  80. data/ext/ruby_gumath/gumath/python/test_xndarray.py +462 -0
  81. data/ext/ruby_gumath/gumath/setup.py +67 -6
  82. data/ext/ruby_gumath/gumath/tools/detect_cuda_arch.cc +35 -0
  83. data/ext/ruby_gumath/include/gumath.h +55 -14
  84. data/ext/ruby_gumath/include/ruby_gumath.h +4 -1
  85. data/ext/ruby_gumath/lib/libgumath.a +0 -0
  86. data/ext/ruby_gumath/lib/libgumath.so.0.2.0dev3 +0 -0
  87. data/ext/ruby_gumath/ruby_gumath.c +231 -70
  88. data/ext/ruby_gumath/ruby_gumath.h +4 -1
  89. data/ext/ruby_gumath/ruby_gumath_internal.h +25 -0
  90. data/ext/ruby_gumath/util.c +34 -0
  91. data/ext/ruby_gumath/util.h +9 -0
  92. data/gumath.gemspec +3 -2
  93. data/lib/gumath.rb +55 -1
  94. data/lib/gumath/version.rb +2 -2
  95. data/lib/ruby_gumath.so +0 -0
  96. metadata +63 -10
  97. data/ext/ruby_gumath/gumath/libgumath/extending/bfloat16.c +0 -130
  98. data/ext/ruby_gumath/gumath/libgumath/kernels/binary.c +0 -547
  99. data/ext/ruby_gumath/gumath/libgumath/kernels/unary.c +0 -449
@@ -0,0 +1,177 @@
1
+ # Copyright (c) 2010 Python Software Foundation. All Rights Reserved.
2
+ # Adapted from Python's Lib/test/test_strtod.py (by Mark Dickinson)
3
+
4
+ # More test cases for deccheck.py.
5
+
6
+ import random
7
+
8
+ TEST_SIZE = 2
9
+
10
+
11
def test_short_halfway_cases(test_size=None):
    """Yield exact halfway cases that have few significant digits.

    Every string has the form '<digits>e<exponent>'.  Each plain string is
    immediately followed by the same value written with 40 extra trailing
    zeros (digits * 10**40, exponent - 40).

    test_size scales the number of random samples drawn per power of five;
    it defaults to the module-level TEST_SIZE.
    """
    if test_size is None:
        test_size = TEST_SIZE
    for k in 0, 5, 10, 15, 20:
        # upper = smallest integer >= 2**54/5**k
        upper = -(-2**54 // 5**k)
        # lower = smallest odd number >= 2**53/5**k
        lower = -(-2**53 // 5**k)
        if lower % 2 == 0:
            lower += 1
        for _ in range(10 * test_size):
            # Select a random odd n in [2**53/5**k, 2**54/5**k).  Then
            # n * 10**k gives a halfway case with a small number of
            # significant digits.
            n, e = random.randrange(lower, upper, 2), k

            # Remove any additional powers of 5.
            while n % 5 == 0:
                n, e = n // 5, e + 1
            assert n % 10 in (1, 3, 7, 9)

            # Try numbers of the form n * 2**p2 * 10**e, p2 >= 0,
            # while n * 2**p2 < 10**20.
            digits, exponent = n, e
            while digits < 10**20:
                yield '{}e{}'.format(digits, exponent)
                # Same again, but with extra trailing zeros.
                yield '{}e{}'.format(digits * 10**40, exponent - 40)
                digits *= 2

            # Try numbers of the form n * 5**p5 * 10**(e - p5), p5 >= 0,
            # while n * 5**p5 < 10**20.
            digits, exponent = n, e
            while digits < 10**20:
                yield '{}e{}'.format(digits, exponent)
                # Same again, but with extra trailing zeros.
                yield '{}e{}'.format(digits * 10**40, exponent - 40)
                digits *= 5
                exponent -= 1
53
+
54
def test_halfway_cases(test_size=None):
    """Yield random values lying exactly halfway between two adjacent doubles.

    These exercise the round-half-to-even rule.  A random finite,
    non-negative bit pattern is decoded to m * 2**e and then shifted by
    half an ulp.  The result is emitted exactly as '<digits>e<exponent>'.
    1000 * test_size strings are produced; test_size defaults to the
    module-level TEST_SIZE.
    """
    if test_size is None:
        test_size = TEST_SIZE
    for _ in range(1000 * test_size):
        # bit pattern for a random finite positive (or +0.0) float
        bits = random.randrange(2047 * 2**52)

        # convert bit pattern to a number of the form m * 2**e
        e, m = divmod(bits, 2**52)
        if e:
            # normal number: restore the implicit leading bit
            m, e = m + 2**52, e - 1
        e -= 1074

        # add 0.5 ulps
        m, e = 2*m + 1, e - 1

        # convert to an exact decimal string
        if e >= 0:
            digits, exponent = m << e, 0
        else:
            # m * 2**e = (m * 5**-e) * 10**e
            digits, exponent = m * 5**-e, e
        yield '{}e{}'.format(digits, exponent)
80
+
81
def test_boundaries(test_size=None):
    """Yield random values close to special double boundaries.

    For each boundary, 1000 progressively longer decimal representations of
    (approximately) the same value are perturbed by up to +/-3 ulps.
    test_size samples are drawn per representation; it defaults to the
    module-level TEST_SIZE.
    """
    if test_size is None:
        test_size = TEST_SIZE
    # boundaries expressed as triples (n, e, u), where
    # n*10**e is an approximation to the boundary value and
    # u*10**e is 1ulp
    boundaries = [
        (10000000000000000000, -19, 1110),   # a power of 2 boundary (1.0)
        (17976931348623159077, 289, 1995),   # overflow boundary (2.**1024)
        (22250738585072013831, -327, 4941),  # normal/subnormal (2.**-1022)
        (0, -327, 4941),                     # zero
    ]
    for n, e, u in boundaries:
        for _ in range(1000):
            for _ in range(test_size):
                digits = n + random.randrange(-3*u, 3*u)
                yield '{}e{}'.format(digits, e)
            # Same value and ulp, written with one more decimal digit.
            n *= 10
            u *= 10
            e -= 1
101
+
102
def test_underflow_boundary(test_size=None):
    """Yield values close to 2**-1075, the underflow boundary.

    Similar to test_boundaries, except that the random error (+/-1000 in
    the last digit) does not scale with the number of digits.  test_size
    samples are drawn per exponent in [-400, -320); it defaults to the
    module-level TEST_SIZE.
    """
    if test_size is None:
        test_size = TEST_SIZE
    for exponent in range(-400, -320):
        # base * 10**exponent ~= 2**-1075
        base = 10**-exponent // 2**1075
        for _ in range(test_size):
            digits = base + random.randrange(-1000, 1000)
            yield '{}e{}'.format(digits, exponent)
112
+
113
def test_bigcomp(test_size=None):
    """Yield random values with many significant digits.

    For each digit count, 100 * test_size random integers below
    10**ndigs are paired with a random exponent in [-400, 400).
    test_size defaults to the module-level TEST_SIZE.
    """
    if test_size is None:
        test_size = TEST_SIZE
    for ndigs in 5, 10, 14, 15, 16, 17, 18, 19, 20, 40, 41, 50:
        dig10 = 10**ndigs
        for _ in range(100 * test_size):
            digits = random.randrange(dig10)
            exponent = random.randrange(-400, 400)
            yield '{}e{}'.format(digits, exponent)
121
+
122
def test_parsing(test_size=None):
    """Yield short random strings that are valid float literals.

    Each candidate is an optional sign, up to four integer digits, an
    optional '.' followed by up to four fraction digits, and an optional
    exponent ('e' or 'E', optional sign, one to three digits).  Candidates
    without any mantissa digit are skipped.  1000 * test_size candidates
    are drawn; test_size defaults to the module-level TEST_SIZE.
    """
    if test_size is None:
        test_size = TEST_SIZE
    # make '0' more likely to be chosen than other digits
    digits = '000000123456789'
    signs = ('+', '-', '')

    for _ in range(1000 * test_size):
        s = random.choice(signs)
        intpart_len = random.randrange(5)
        s += ''.join(random.choice(digits) for _ in range(intpart_len))
        if random.choice([True, False]):
            s += '.'
            fracpart_len = random.randrange(5)
            s += ''.join(random.choice(digits) for _ in range(fracpart_len))
        else:
            fracpart_len = 0
        if random.choice([True, False]):
            s += random.choice(['e', 'E'])
            s += random.choice(signs)
            exponent_len = random.randrange(1, 4)
            s += ''.join(random.choice(digits) for _ in range(exponent_len))

        # Skip strings such as '+.' or 'e5' that have no mantissa digits.
        if intpart_len + fracpart_len:
            yield s
150
+
151
+
152
# One fully materialized pool of test strings per generator, in a fixed
# order; the *_randfloat helpers sample from these pools.
TESTCASES = [
    list(generator())
    for generator in (
        test_short_halfway_cases,
        test_halfway_cases,
        test_boundaries,
        test_underflow_boundary,
        test_bigcomp,
        test_parsing,
    )
]
160
+
161
def un_randfloat():
    """Yield two operands for a unary operation.

    Each operand is drawn from a randomly chosen TESTCASES pool.
    """
    produced = 0
    while produced < 2:
        yield random.choice(random.choice(TESTCASES[:6]))
        produced += 1
165
+
166
def bin_randfloat():
    """Yield two operand pairs for a binary operation.

    Each operand comes from its own randomly chosen TESTCASES pool.
    """
    for _ in range(2):
        first_pool, second_pool = random.choice(TESTCASES), random.choice(TESTCASES)
        yield random.choice(first_pool), random.choice(second_pool)
171
+
172
def tern_randfloat():
    """Yield two operand triples for a ternary operation.

    Each operand comes from its own randomly chosen TESTCASES pool.
    """
    for _ in range(2):
        pools = [random.choice(TESTCASES) for _ in range(3)]
        yield tuple(random.choice(pool) for pool in pools)
@@ -35,37 +35,59 @@ import gumath.functions as fn
35
35
  import gumath.examples as ex
36
36
  from xnd import xnd
37
37
  from ndtypes import ndt
38
- from extending import Graph, bfloat16
38
+ from extending import Graph
39
39
  import sys, time
40
+ import platform
40
41
  import math
42
+ import cmath
41
43
  import unittest
42
44
  import argparse
45
+ from gumath_aux import *
46
+
47
+ try:
48
+ import gumath.cuda as cd
49
+ except ImportError:
50
+ cd = None
43
51
 
44
52
try:
    import numpy as np
    # Silence all warnings, as before.  Use the stdlib module directly:
    # `np.warnings` was only a re-export of it and is gone in newer NumPy.
    # The AttributeError it raises there would escape `except ImportError`.
    import warnings
    warnings.filterwarnings('ignore')
except ImportError:
    np = None
48
57
 
58
+ SKIP_LONG = True
59
+ SKIP_BRUTE_FORCE = True
49
60
 
50
- TEST_CASES = [
51
- ([float(i)/100.0 for i in range(2000)], "2000 * float64", "float64"),
61
+ ARCH = platform.architecture()[0]
52
62
 
53
- ([[float(i)/100.0 for i in range(1000)], [float(i+1) for i in range(1000)]],
54
- "2 * 1000 * float64", "float64"),
55
63
 
56
- (1000 * [[float(i+1) for i in range(2)]], "1000 * 2 * float64", "float64"),
64
class TestAPI(unittest.TestCase):

    def test_api(self):
        # Library functions are exposed as gufunc instances.
        self.assertIsInstance(fn.add, gm.gufunc)

        # Calling gufunc.__new__ without arguments, or with a non-type
        # argument, must raise TypeError.
        with self.assertRaises(TypeError):
            gm.gufunc.__new__()
        with self.assertRaises(TypeError):
            gm.gufunc.__new__(1)
65
71
 
66
72
 
67
73
  class TestCall(unittest.TestCase):
68
74
 
75
+ def test_subclass(self):
76
+
77
+ class X(xnd):
78
+ pass
79
+
80
+ x = X([1, 2, 3])
81
+ y = X([1, 2, 3])
82
+
83
+ z = fn.multiply(x, y)
84
+ self.assertEqual(z, [1, 4, 9])
85
+ self.assertEqual(type(z), xnd)
86
+
87
+ z = fn.multiply(x, y, cls=X)
88
+ self.assertEqual(z, [1, 4, 9])
89
+ self.assertEqual(type(z), X)
90
+
69
91
  def test_sin_scalar(self):
70
92
 
71
93
  x1 = xnd(1.2, type="float64")
@@ -213,6 +235,293 @@ class TestMissingValues(unittest.TestCase):
213
235
 
214
236
  self.assertEqual(ans.value, [{'valid': 2, 'missing': 1}, {'valid': 1, 'missing': 2}])
215
237
 
238
+ def test_unary(self):
239
+ a = [0, None, 2]
240
+ ans = xnd([math.sin(x) if x is not None else None for x in a])
241
+
242
+ x = xnd(a, dtype="?float64")
243
+ y = fn.sin(x)
244
+ self.assertEqual(y.value, ans)
245
+
246
+ def test_binary(self):
247
+ a = [3, None, 3]
248
+ b = [100, 1, None]
249
+ ans = xnd([t[0] * t[1] if t[0] is not None and t[1] is not None else None
250
+ for t in zip(a, b)])
251
+
252
+ x = xnd(a)
253
+ y = xnd(b)
254
+ z = fn.multiply(x, y)
255
+ self.assertEqual(z.value, ans)
256
+
257
+ def test_reduce(self):
258
+ a = [1, None, 2]
259
+ x = xnd(a)
260
+
261
+ y = gm.reduce(fn.add, x)
262
+ self.assertEqual(y, None)
263
+
264
+ y = gm.reduce(fn.multiply, x)
265
+ self.assertEqual(y, None)
266
+
267
+ y = gm.reduce(fn.subtract, x)
268
+ self.assertEqual(y, None)
269
+
270
+ x = xnd([], dtype="?int32")
271
+
272
+ y = gm.reduce(fn.add, x)
273
+ self.assertEqual(y, 0)
274
+
275
+ @unittest.skipIf(cd is None, "test requires cuda")
276
+ def test_reduce_cuda(self):
277
+ a = [1, None, 2]
278
+ x = xnd(a, device="cuda:managed")
279
+
280
+ y = gm.reduce(cd.add, x)
281
+ self.assertEqual(y, None)
282
+
283
+ y = gm.reduce(cd.multiply, x)
284
+ self.assertEqual(y, None)
285
+
286
+ x = xnd([], dtype="?int32", device="cuda:managed")
287
+ y = gm.reduce(fn.add, x)
288
+ self.assertEqual(y, 0)
289
+
290
+ def test_comparisons(self):
291
+ a = [1, None, 3, 5]
292
+ b = [2, None, 3, 4]
293
+
294
+ x = xnd(a)
295
+ y = xnd(b)
296
+
297
+ ans = fn.equal(x, y)
298
+ self.assertEqual(ans.value, [False, None, True, False])
299
+
300
+ ans = fn.not_equal(x, y)
301
+ self.assertEqual(ans.value, [True, None, False, True])
302
+
303
+ ans = fn.less(x, y)
304
+ self.assertEqual(ans.value, [True, None, False, False])
305
+
306
+ ans = fn.less_equal(x, y)
307
+ self.assertEqual(ans.value, [True, None, True, False])
308
+
309
+ ans = fn.greater_equal(x, y)
310
+ self.assertEqual(ans.value, [False, None, True, True])
311
+
312
+ ans = fn.greater(x, y)
313
+ self.assertEqual(ans.value, [False, None, False, True])
314
+
315
+ @unittest.skipIf(cd is None, "test requires cuda")
316
+ def test_comparisons_cuda(self):
317
+ a = [1, None, 3, 5]
318
+ b = [2, None, 3, 4]
319
+
320
+ x = xnd(a, device="cuda:managed")
321
+ y = xnd(b, device="cuda:managed")
322
+
323
+ ans = cd.equal(x, y)
324
+ self.assertEqual(ans.value, [False, None, True, False])
325
+
326
+ ans = cd.not_equal(x, y)
327
+ self.assertEqual(ans.value, [True, None, False, True])
328
+
329
+ ans = cd.less(x, y)
330
+ self.assertEqual(ans.value, [True, None, False, False])
331
+
332
+ ans = cd.less_equal(x, y)
333
+ self.assertEqual(ans.value, [True, None, True, False])
334
+
335
+ ans = cd.greater_equal(x, y)
336
+ self.assertEqual(ans.value, [False, None, True, True])
337
+
338
+ ans = cd.greater(x, y)
339
+ self.assertEqual(ans.value, [False, None, False, True])
340
+
341
+ def test_equaln(self):
342
+ a = [1, None, 3, 5]
343
+ b = [2, None, 3, 4]
344
+
345
+ x = xnd(a)
346
+ y = xnd(b)
347
+ z = fn.equaln(x, y)
348
+ self.assertEqual(z, [False, True, True, False])
349
+ self.assertEqual(z.dtype, ndt("bool"))
350
+
351
+ a = [1, None, 3, 5]
352
+ b = [2, 0, 3, 4]
353
+
354
+ x = xnd(a)
355
+ y = xnd(b)
356
+ z = fn.equaln(x, y)
357
+ self.assertEqual(z, [False, False, True, False])
358
+ self.assertEqual(z.dtype, ndt("bool"))
359
+
360
+ # NA eqn NA
361
+ a = [None]
362
+ b = [None]
363
+ x = xnd(a)
364
+ y = xnd(b)
365
+ z = fn.equaln(x, y)
366
+ self.assertEqual(z, [True])
367
+
368
+ # !(NA eqn 0)
369
+ a = [None]
370
+ b = [0.0]
371
+ x = xnd(a)
372
+ y = xnd(b)
373
+ z = fn.equaln(x, y)
374
+ self.assertEqual(z, [False])
375
+
376
+ # !(0 eqn NA)
377
+ a = [0.0]
378
+ b = [None]
379
+ x = xnd(a)
380
+ y = xnd(b)
381
+ z = fn.equaln(x, y)
382
+ self.assertEqual(z, [False])
383
+
384
+ # !(NA eqn NaN)
385
+ a = [None]
386
+ b = [float("nan")]
387
+ x = xnd(a)
388
+ y = xnd(b)
389
+ z = fn.equaln(x, y)
390
+ self.assertEqual(z, [False])
391
+
392
+ # !(NaN eqn NA)
393
+ a = [float("nan")]
394
+ b = [None]
395
+ x = xnd(a)
396
+ y = xnd(b)
397
+ z = fn.equaln(x, y)
398
+ self.assertEqual(z, [False])
399
+
400
+ @unittest.skipIf(cd is None, "test requires cuda")
401
+ def test_equaln_cuda(self):
402
+ a = [1, None, 3, 5]
403
+ b = [2, None, 3, 4]
404
+
405
+ x = xnd(a, device="cuda:managed")
406
+ y = xnd(b, device="cuda:managed")
407
+ z = cd.equaln(x, y)
408
+ self.assertEqual(z, [False, True, True, False])
409
+ self.assertEqual(z.dtype, ndt("bool"))
410
+
411
+ a = [1, None, 3, 5]
412
+ b = [2, 0, 3, 4]
413
+
414
+ x = xnd(a, device="cuda:managed")
415
+ y = xnd(b, device="cuda:managed")
416
+ z = cd.equaln(x, y)
417
+ self.assertEqual(z, [False, False, True, False])
418
+ self.assertEqual(z.dtype, ndt("bool"))
419
+
420
+ # NA eqn NA
421
+ a = [None]
422
+ b = [None]
423
+ x = xnd(a, device="cuda:managed")
424
+ y = xnd(b, device="cuda:managed")
425
+ z = cd.equaln(x, y)
426
+ self.assertEqual(z, [True])
427
+
428
+ # !(NA eqn 0)
429
+ a = [None]
430
+ b = [0.0]
431
+ x = xnd(a, device="cuda:managed")
432
+ y = xnd(b, device="cuda:managed")
433
+ z = cd.equaln(x, y)
434
+ self.assertEqual(z, [False])
435
+
436
+ # !(0 eqn NA)
437
+ a = [0.0]
438
+ b = [None]
439
+ x = xnd(a, device="cuda:managed")
440
+ y = xnd(b, device="cuda:managed")
441
+ z = cd.equaln(x, y)
442
+ self.assertEqual(z, [False])
443
+
444
+ # !(NA eqn NaN)
445
+ a = [None]
446
+ b = [float("nan")]
447
+ x = xnd(a, device="cuda:managed")
448
+ y = xnd(b, device="cuda:managed")
449
+ z = cd.equaln(x, y)
450
+ self.assertEqual(z, [False])
451
+
452
+ # !(NaN eqn NA)
453
+ a = [float("nan")]
454
+ b = [None]
455
+ x = xnd(a, device="cuda:managed")
456
+ y = xnd(b, device="cuda:managed")
457
+ z = cd.equaln(x, y)
458
+ self.assertEqual(z, [False])
459
+
460
+
461
class TestEqualN(unittest.TestCase):
    # equaln compares element-wise, but (per the expectations below) a NaN
    # matches a NaN in the same position.

    # (a, b, expected) triples for complex operands.  Shared by the CPU and
    # CUDA tests so both backends are held to the same expectations (the
    # table was previously duplicated in both methods).
    _COMPLEX_NAN_CASES = (
        (complex(float("nan"), 1.2), complex(float("nan"), 1.2), True),
        (complex(float("nan"), 1.2), complex(float("nan"), 1), False),
        (complex(float("nan"), float("nan")), complex(float("nan"), 1.2), False),

        (complex(1.2, float("nan")), complex(1.2, float("nan")), True),
        (complex(1.2, float("nan")), complex(1, float("nan")), False),
        (complex(float("nan"), float("nan")), complex(1.2, float("nan")), False),

        (complex(float("nan"), float("nan")), complex(float("nan"), float("nan")), True),
    )

    def test_nan_float(self):
        for dtype in "bfloat16", "float32", "float64":
            x = xnd([0, float("nan"), 2], dtype=dtype)

            y = xnd([0, float("nan"), 2], dtype=dtype)
            z = fn.equaln(x, y)
            self.assertEqual(z, [True, True, True])

            y = xnd([0, 1, 2], dtype=dtype)
            z = fn.equaln(x, y)
            self.assertEqual(z, [True, False, True])

    def test_nan_complex(self):
        for dtype in "complex64", "complex128":
            for a, b, ans in self._COMPLEX_NAN_CASES:
                x = xnd([0, a, 2], dtype=dtype)
                y = xnd([0, b, 2], dtype=dtype)
                z = fn.equaln(x, y)
                self.assertEqual(z, [True, ans, True])

    @unittest.skipIf(cd is None, "test requires cuda")
    def test_nan_float_cuda(self):
        for dtype in "bfloat16", "float16", "float32", "float64":
            x = xnd([0, float("nan"), 2], dtype=dtype, device="cuda:managed")

            y = xnd([0, float("nan"), 2], dtype=dtype, device="cuda:managed")
            z = cd.equaln(x, y)
            self.assertEqual(z, [True, True, True])

            y = xnd([0, 1, 2], dtype=dtype, device="cuda:managed")
            z = cd.equaln(x, y)
            self.assertEqual(z, [True, False, True])

    @unittest.skipIf(cd is None, "test requires cuda")
    def test_nan_complex_cuda(self):
        for dtype in "complex64", "complex128":
            for a, b, ans in self._COMPLEX_NAN_CASES:
                x = xnd([0, a, 2], dtype=dtype, device="cuda:managed")
                y = xnd([0, b, 2], dtype=dtype, device="cuda:managed")
                z = cd.equaln(x, y)
                self.assertEqual(z, [True, ans, True])
524
+
216
525
 
217
526
  class TestRaggedArrays(unittest.TestCase):
218
527
 
@@ -237,8 +546,59 @@ class TestRaggedArrays(unittest.TestCase):
237
546
  self.assertEqual(y.value, ans)
238
547
 
239
548
 
549
class TestFlexibleArrays(unittest.TestCase):
    # Nested arrays whose inner rows have differing lengths, typed as
    # "array * array * array * float64".

    def test_sin_var_compatible(self):
        data = [[[1.0], [2.0, 3.0], [4.0, 5.0, 6.0]],
                [[7.0], [8.0, 9.0], [10.0, 11.0, 12.0]]]
        expected = [[[math.sin(v) for v in row] for row in plane]
                    for plane in data]

        result = fn.sin(xnd(data, type="array * array * array * float64"))
        self.assertEqual(result.value, expected)

    def test_add(self):
        lhs = [[[1.0], [2.0, 3.0], [4.0, 5.0, 6.0]],
               [[7.0], [8.0, 9.0], [10.0, 11.0, 12.0]]]
        rhs = [[[2.0], [3.0, 4.0], [5.0, 6.0, 7.0]],
               [[-8.0], [-9.0, -10.0], [111.1, 121.2, 25.3]]]
        expected = [[[p + q for p, q in zip(row_l, row_r)]
                     for row_l, row_r in zip(plane_l, plane_r)]
                    for plane_l, plane_r in zip(lhs, rhs)]

        kind = "array * array * array * float64"
        total = fn.add(xnd(lhs, type=kind), xnd(rhs, type=kind))
        self.assertEqual(total.value, expected)
597
+
598
+
240
599
  class TestGraphs(unittest.TestCase):
241
600
 
601
+ @unittest.skipIf(True, "abstract return types are temporarily disabled")
242
602
  def test_shortest_path(self):
243
603
  graphs = [[[(1, 1.2), (2, 4.4)],
244
604
  [(2, 2.2)],
@@ -274,17 +634,6 @@ class TestGraphs(unittest.TestCase):
274
634
  self.assertRaises(ValueError, Graph, lst)
275
635
 
276
636
 
277
- @unittest.skipIf(sys.platform == "win32", "unresolved external symbols")
278
- class TestBFloat16(unittest.TestCase):
279
-
280
- def test_init(self):
281
- lst = [1.2e10, 2.1121, -3e20]
282
- ans = [11945377792.0, 2.109375, -2.997595911977802e+20]
283
-
284
- x = bfloat16(lst)
285
- self.assertEqual(x.value, ans)
286
-
287
-
288
637
  class TestPdist(unittest.TestCase):
289
638
 
290
639
  def test_exceptions(self):
@@ -373,15 +722,1142 @@ class TestNumba(unittest.TestCase):
373
722
  np.testing.assert_equal(z, c)
374
723
 
375
724
 
725
class TestOut(unittest.TestCase):
    # Tests for the out= keyword.  Results are written into the supplied
    # containers, and those same objects are returned.

    def test_api_cpu(self):
        # negative
        x = xnd([1, 2, 3])
        y = xnd.empty("3 * int64")
        z = fn.negative(x, out=y)

        self.assertIs(z, y)
        self.assertEqual(y, xnd([-1, -2, -3]))

        # divmod: out= takes one container per result
        x = xnd([10, 20, 30])
        y = xnd([7, 8, 9])
        a = xnd.empty("3 * int64")
        b = xnd.empty("3 * int64")
        q, r = fn.divmod(x, y, out=(a, b))

        self.assertIs(q, a)
        self.assertIs(r, b)

        self.assertEqual(q, xnd([1, 2, 3]))
        self.assertEqual(r, xnd([3, 4, 3]))

    @unittest.skipIf(cd is None, "test requires cuda")
    def test_api_cuda(self):
        # negative
        x = xnd([1, 2, 3], device="cuda:managed")
        y = xnd.empty("3 * int64", device="cuda:managed")
        z = cd.negative(x, out=y)

        self.assertIs(z, y)
        self.assertEqual(y, xnd([-1, -2, -3]))

        # divmod
        x = xnd([10, 20, 30], device="cuda:managed")
        y = xnd([7, 8, 9], device="cuda:managed")
        a = xnd.empty("3 * int64", device="cuda:managed")
        b = xnd.empty("3 * int64", device="cuda:managed")
        q, r = cd.divmod(x, y, out=(a, b))

        self.assertIs(q, a)
        self.assertIs(r, b)

        self.assertEqual(q, xnd([1, 2, 3]))
        self.assertEqual(r, xnd([3, 4, 3]))

    def test_broadcast_cpu(self):
        # multiply: a length-1 array and a scalar both broadcast
        x = xnd([1, 2, 3])
        y = xnd([2])
        z = xnd.empty("3 * int64")
        ans = fn.multiply(x, y, out=z)

        self.assertIs(ans, z)
        self.assertEqual(ans, xnd([2, 4, 6]))

        x = xnd([1, 2, 3])
        y = xnd(2)
        z = xnd.empty("3 * int64")
        ans = fn.multiply(x, y, out=z)

        self.assertIs(ans, z)
        self.assertEqual(ans, xnd([2, 4, 6]))

        # divmod
        x = xnd([10, 20, 30])
        y = xnd([3])
        a = xnd.empty("3 * int64")
        b = xnd.empty("3 * int64")
        q, r = fn.divmod(x, y, out=(a, b))

        self.assertIs(q, a)
        self.assertIs(r, b)
        self.assertEqual(q, xnd([3, 6, 10]))
        self.assertEqual(r, xnd([1, 2, 0]))

        x = xnd([10, 20, 30])
        y = xnd(3)
        a = xnd.empty("3 * int64")
        b = xnd.empty("3 * int64")
        q, r = fn.divmod(x, y, out=(a, b))

        self.assertIs(q, a)
        self.assertIs(r, b)
        self.assertEqual(q, xnd([3, 6, 10]))
        self.assertEqual(r, xnd([1, 2, 0]))

    @unittest.skipIf(cd is None, "test requires cuda")
    def test_broadcast_cuda(self):
        # multiply: use the cuda kernel (previously the CPU fn.multiply) so
        # this test actually exercises CUDA broadcasting.
        x = xnd([1, 2, 3], device="cuda:managed")
        y = xnd([2], device="cuda:managed")
        z = xnd.empty("3 * int64", device="cuda:managed")
        ans = cd.multiply(x, y, out=z)

        self.assertIs(ans, z)
        self.assertEqual(ans, xnd([2, 4, 6]))
823
+
824
+
825
class TestUnaryCPU(unittest.TestCase):

    def test_acos(self):
        values = [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
        expected = list(map(math.acos, values))

        result = fn.acos(xnd(values, dtype="float64"))
        self.assertEqual(result, expected)

    def test_acos_opt(self):
        # Missing entries pass through as None.
        values = [0, 0.1, 0.2, None, 0.4, 0.5, 0.6, None]
        expected = [None if v is None else math.acos(v) for v in values]

        result = fn.acos(xnd(values, dtype="?float64"))
        self.assertEqual(result, expected)

    def test_inexact_cast(self):
        # sin on int64 input is expected to raise ValueError.
        ints = xnd([0, 1, 2, 3, 4, 5, 6, 7], dtype="int64")
        with self.assertRaises(ValueError):
            fn.sin(ints)
847
+
848
+
849
@unittest.skipIf(cd is None, "test requires cuda")
class TestUnaryCUDA(unittest.TestCase):

    def test_cos(self):
        values = [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
        expected = list(map(math.cos, values))

        result = cd.cos(xnd(values, dtype="float64", device="cuda:managed"))
        self.assertEqual(result, expected)

    def test_cos_opt(self):
        # Missing entries pass through as None.
        values = [0, 0.1, 0.2, None, 0.4, 0.5, 0.6, None]
        expected = [None if v is None else math.cos(v) for v in values]

        result = cd.cos(xnd(values, dtype="?float64", device="cuda:managed"))
        self.assertEqual(result, expected)

    def test_inexact_cast(self):
        # sin on int64 input is expected to raise ValueError.
        ints = xnd([0, 1, 2, 3, 4, 5, 6, 7], dtype="int64", device="cuda:managed")
        with self.assertRaises(ValueError):
            cd.sin(ints)
872
+
873
+
874
class TestBinaryCPU(unittest.TestCase):
    # Run each arithmetic kernel over every implemented default binary
    # signature that has a CPU implementation.

    @staticmethod
    def _sigs():
        # Yield the (t, u) input-type pairs (the keys of
        # implemented_sigs["binary"]["default"]), skipping pairs where either
        # type has no CPU implementation.
        for t, u in implemented_sigs["binary"]["default"]:
            if t.cpu_noimpl() or u.cpu_noimpl():
                continue
            yield t, u

    def test_binary(self):
        for t, u in self._sigs():
            x = xnd([0, 1, 2, 3, 4, 5, 6, 7], dtype=t.type)
            y = xnd([1, 2, 3, 4, 5, 6, 7, 8], dtype=u.type)
            z = fn.add(x, y)
            self.assertEqual(z, [1, 3, 5, 7, 9, 11, 13, 15])

    def test_add_opt(self):
        for t, u in self._sigs():
            x = xnd([0, 1, None, 3, 4, 5, 6, 7], dtype="?" + t.type)
            y = xnd([1, 2, 3, 4, 5, 6, None, 8], dtype="?" + u.type)
            z = fn.add(x, y)
            self.assertEqual(z, [1, 3, None, 7, 9, 11, None, 15])

    def test_subtract(self):
        for t, u in self._sigs():
            x = xnd([2, 3, 4, 5, 6, 7, 8, 9], dtype=t.type)
            y = xnd([1, 2, 3, 4, 5, 6, 7, 8], dtype=u.type)
            z = fn.subtract(x, y)
            self.assertEqual(z, [1, 1, 1, 1, 1, 1, 1, 1])

    def test_multiply(self):
        for t, u in self._sigs():
            x = xnd([2, 3, 4, 5, 6, 7, 8, 9], dtype=t.type)
            y = xnd([1, 2, 3, 4, 5, 6, 7, 8], dtype=u.type)
            z = fn.multiply(x, y)
            self.assertEqual(z, [2, 6, 12, 20, 30, 42, 56, 72])
923
+
924
+
925
@unittest.skipIf(cd is None, "test requires cuda")
class TestBinaryCUDA(unittest.TestCase):
    # Run each arithmetic kernel over every implemented default binary
    # signature that has a CUDA implementation.

    @staticmethod
    def _sigs():
        # Yield the (t, u) input-type pairs, skipping pairs where either type
        # has no CUDA implementation.  The previous code checked cpu_noimpl(),
        # which filters on the wrong backend for these CUDA kernels.
        for t, u in implemented_sigs["binary"]["default"]:
            if t.cuda_noimpl() or u.cuda_noimpl():
                continue
            yield t, u

    def test_binary(self):
        for t, u in self._sigs():
            x = xnd([0, 1, 2, 3, 4, 5, 6, 7], dtype=t.type, device="cuda:managed")
            y = xnd([1, 2, 3, 4, 5, 6, 7, 8], dtype=u.type, device="cuda:managed")
            z = cd.add(x, y)
            self.assertEqual(z, [1, 3, 5, 7, 9, 11, 13, 15])

    def test_add_opt(self):
        for t, u in self._sigs():
            x = xnd([0, 1, None, 3, 4, 5, 6, 7], dtype="?" + t.type, device="cuda:managed")
            y = xnd([1, 2, 3, 4, 5, 6, None, 8], dtype="?" + u.type, device="cuda:managed")
            z = cd.add(x, y)
            self.assertEqual(z, [1, 3, None, 7, 9, 11, None, 15])

    def test_subtract(self):
        for t, u in self._sigs():
            x = xnd([2, 3, 4, 5, 6, 7, 8, 9], dtype=t.type, device="cuda:managed")
            y = xnd([1, 2, 3, 4, 5, 6, 7, 8], dtype=u.type, device="cuda:managed")
            z = cd.subtract(x, y)
            self.assertEqual(z, [1, 1, 1, 1, 1, 1, 1, 1])

    def test_multiply(self):
        for t, u in self._sigs():
            x = xnd([2, 3, 4, 5, 6, 7, 8, 9], dtype=t.type, device="cuda:managed")
            y = xnd([1, 2, 3, 4, 5, 6, 7, 8], dtype=u.type, device="cuda:managed")
            z = cd.multiply(x, y)
            self.assertEqual(z, [2, 6, 12, 20, 30, 42, 56, 72])
975
+
976
+
977
+ class TestBitwiseCPU(unittest.TestCase):
978
+
979
+ def test_and(self):
980
+ for t, u in implemented_sigs["binary"]["bitwise"]:
981
+ w = implemented_sigs["binary"]["bitwise"][(t, u)]
982
+
983
+ if t.cpu_noimpl() or u.cpu_noimpl():
984
+ continue
985
+
986
+ x = xnd([0, 1, 2, 3, 4, 5, 6, 7], dtype=t.type)
987
+ x = xnd([0, 1, 0, 1, 1, 1, 1, 0], dtype=t.type)
988
+ y = xnd([1, 0, 0, 0, 1, 1, 1, 1], dtype=u.type)
989
+ z = fn.bitwise_and(x, y)
990
+ self.assertEqual(z, [0, 0, 0, 0, 1, 1, 1, 0])
991
+
992
+ def test_and_opt(self):
993
+ for t, u in implemented_sigs["binary"]["bitwise"]:
994
+ w = implemented_sigs["binary"]["bitwise"][(t, u)]
995
+
996
+ if t.cpu_noimpl() or u.cpu_noimpl():
997
+ continue
998
+
999
+ a = [0, 1, None, 1, 1, 1, 1, 0]
1000
+ b = [1, 1, 1, 1, 1, 1, None, 0]
1001
+ c = [0, 1, None, 1, 1, 1, None, 0]
1002
+
1003
+ x = xnd(a, dtype="?" + t.type)
1004
+ y = xnd(b, dtype="?" + u.type)
1005
+ z = fn.bitwise_and(x, y)
1006
+ self.assertEqual(z, c)
1007
+
1008
+
1009
+ @unittest.skipIf(cd is None, "test requires cuda")
1010
+ class TestBitwiseCUDA(unittest.TestCase):
1011
+
1012
+ def test_and(self):
1013
+ for t, u in implemented_sigs["binary"]["bitwise"]:
1014
+ w = implemented_sigs["binary"]["bitwise"][(t, u)]
1015
+
1016
+ if t.cuda_noimpl() or u.cuda_noimpl():
1017
+ continue
1018
+
1019
+ x = xnd([0, 1, 2, 3, 4, 5, 6, 7], dtype=t.type, device="cuda:managed")
1020
+ x = xnd([0, 1, 0, 1, 1, 1, 1, 0], dtype=t.type, device="cuda:managed")
1021
+ y = xnd([1, 0, 0, 0, 1, 1, 1, 1], dtype=u.type, device="cuda:managed")
1022
+ z = cd.bitwise_and(x, y)
1023
+ self.assertEqual(z, [0, 0, 0, 0, 1, 1, 1, 0])
1024
+
1025
+ def test_and_opt(self):
1026
+ for t, u in implemented_sigs["binary"]["bitwise"]:
1027
+ w = implemented_sigs["binary"]["bitwise"][(t, u)]
1028
+
1029
+ if t.cuda_noimpl() or u.cuda_noimpl():
1030
+ continue
1031
+
1032
+ a = [0, 1, None, 1, 1, 1, 1, 0]
1033
+ b = [1, 1, 1, 1, 1, 1, None, 0]
1034
+ c = [0, 1, None, 1, 1, 1, None, 0]
1035
+
1036
+ x = xnd(a, dtype="?" + t.type, device="cuda:managed")
1037
+ y = xnd(b, dtype="?" + u.type, device="cuda:managed")
1038
+ z = cd.bitwise_and(x, y)
1039
+ self.assertEqual(z, c)
1040
+
1041
+
1042
+ @unittest.skipIf(np is None, "test requires numpy")
1043
+ class TestFunctions(unittest.TestCase):
1044
+
1045
+ def assertRelErrorLess(self, calc, expected, maxerr, msg):
1046
+ if cmath.isnan(calc) or cmath.isnan(expected):
1047
+ return
1048
+ elif cmath.isinf(calc) or cmath.isinf(expected):
1049
+ return
1050
+ elif abs(expected) < 1e-5 or abs(calc) < 1e-5:
1051
+ self.assertLess(abs(calc), 1e-5, msg)
1052
+ self.assertLess(abs(expected), 1e-5, msg)
1053
+ else:
1054
+ err = abs((calc-expected) / expected)
1055
+ self.assertLess(err, maxerr, msg)
1056
+
1057
+ def equal(self, calc, expected, msg):
1058
+ if np.isnan(calc) and np.isnan(expected):
1059
+ return
1060
+ else:
1061
+ self.assertEqual(calc, expected, msg)
1062
+
1063
+ def assert_equal(self, f, z1, z2, w, msg, a=None, b=None):
1064
+ if w.type == "bfloat16":
1065
+ self.assertRelErrorLess(z1, z2, 1e-2, msg)
1066
+ elif f == "power" and w.type in ("int8", "int16", "int32", "int64"):
1067
+ pass # equal mod INTN_MAX
1068
+ elif f == "power" and isinstance(z1, complex):
1069
+ # multivalued function, compare against Python
1070
+ try:
1071
+ ans = complex(a) ** complex(b)
1072
+ except ZeroDivisionError:
1073
+ pass
1074
+ except OverflowError:
1075
+ pass
1076
+ else:
1077
+ msg = "%s ans=%s" % (msg, ans)
1078
+ self.assertRelErrorLess(z1.real, ans.real, 1e-2, msg)
1079
+ self.assertRelErrorLess(z1.imag, ans.imag, 1e-2, msg)
1080
+ elif isinstance(z1, complex):
1081
+ if f not in ("add", "subtract"):
1082
+ self.assertRelErrorLess(z1.real, z2.real, 1e-2, msg)
1083
+ self.assertRelErrorLess(z1.imag, z2.imag, 1e-2, msg)
1084
+ else:
1085
+ self.equal(z1.real, z2.real, msg) and \
1086
+ self.equal(z1.imag, z2.imag, msg)
1087
+ elif f in functions["unary"]["real_math"] or \
1088
+ f in functions["unary"]["real_math_with_half"] or \
1089
+ f in functions["unary"]["complex_math"] or \
1090
+ f in functions["unary"]["complex_math_with_half"] or \
1091
+ f == "power":
1092
+ self.assertRelErrorLess(z1, z2, 1e-2, msg)
1093
+ elif f == "divide" and w.type in ("float16", "float32"):
1094
+ self.assertRelErrorLess(z1, z2, 1e-2, msg)
1095
+ else:
1096
+ return self.equal(z1, z2, msg)
1097
+
1098
+ def create_xnd(self, a, t, dev=None):
1099
+
1100
+ # Check that struct.pack(a) overflows iff xnd(a) overflows.
1101
+ overflow = struct_overflow(a, t)
1102
+ xnd_overflow = False
1103
+ try:
1104
+ x = xnd([a], dtype=t.type, device=dev)
1105
+ except OverflowError:
1106
+ xnd_overflow = True
1107
+
1108
+ self.assertEqual(xnd_overflow, overflow)
1109
+
1110
+ return None if xnd_overflow else x
1111
+
1112
+ def check_unary_not_implemented(self, f, a, t, mod=fn, dev=None):
1113
+
1114
+ x = self.create_xnd(a, t, dev)
1115
+ if x is None:
1116
+ return
1117
+
1118
+ self.assertRaises(NotImplementedError, getattr(mod, f), x)
1119
+
1120
+ def check_unary_type_error(self, f, a, t, mod=fn, dev=None):
1121
+
1122
+ x = self.create_xnd(a, t, dev)
1123
+ if x is None:
1124
+ return
1125
+
1126
+ self.assertRaises(TypeError, getattr(mod, f), x)
1127
+
1128
+ def check_unary(self, f, a, t, u, mod=fn, dev=None):
1129
+
1130
+ x1 = self.create_xnd(a, t, dev)
1131
+ if x1 is None:
1132
+ return
1133
+
1134
+ y1 = getattr(mod, f)(x1)
1135
+ self.assertEqual(str(y1[0].type), u.type)
1136
+ v1 = y1[0].value
1137
+
1138
+ value = x1.value if t.type == "bfloat16" else a
1139
+ dtype = "float32" if t.type == "bfloat16" else t.type
1140
+
1141
+ x2 = np.array([value], dtype=dtype)
1142
+ y2 = getattr(np, np_function(f))(x2)
1143
+ v2 = y2[0]
1144
+
1145
+ msg = "%s(%s : %s) -> %s xnd: %s np: %s" % (f, a, t, u, y1, y2)
1146
+ self.assert_equal(f, v1, v2, u, msg)
1147
+
1148
+ def check_binary_not_implemented(self, f, a, t, b, u, mod=fn, dev=None):
1149
+
1150
+ x1 = self.create_xnd(a, t, dev)
1151
+ if x1 is None:
1152
+ return
1153
+
1154
+ y1 = self.create_xnd(b, u, dev)
1155
+ if y1 is None:
1156
+ return
1157
+
1158
+ self.assertRaises(NotImplementedError, getattr(mod, f), x1, y1)
1159
+
1160
+ def check_binary_type_error(self, f, a, t, b, u, mod=fn, dev=None):
1161
+
1162
+ x1 = self.create_xnd(a, t, dev)
1163
+ if x1 is None:
1164
+ return
1165
+
1166
+ y1 = self.create_xnd(b, u, dev)
1167
+ if y1 is None:
1168
+ return
1169
+
1170
+ self.assertRaises(TypeError, getattr(mod, f), x1, y1)
1171
+
1172
+ def check_binary(self, f, a, t, b, u, w, mod=fn, dev=None):
1173
+
1174
+ x1 = self.create_xnd(a, t, dev)
1175
+ if x1 is None:
1176
+ return
1177
+
1178
+ y1 = self.create_xnd(b, u, dev)
1179
+ if y1 is None:
1180
+ return
1181
+
1182
+ xnd_exc = z1 = None
1183
+ try:
1184
+ z1 = getattr(mod, f)(x1, y1)
1185
+ self.assertEqual(str(z1[0].type), w.type)
1186
+ v1 = z1[0].value
1187
+ except Exception as e:
1188
+ xnd_exc = e.__class__
1189
+
1190
+ dtype1 = "float32" if t.type == "bfloat16" else t.type
1191
+ dtype2 = "float32" if u.type == "bfloat16" else u.type
1192
+ value1 = x1.value if t.type == "bfloat16" else a
1193
+ value2 = y1.value if u.type == "bfloat16" else b
1194
+
1195
+ x2 = np.array([value1], dtype=dtype1)
1196
+ y2 = np.array([value2], dtype=dtype2)
1197
+
1198
+ np_exc = z2 = None
1199
+ try:
1200
+ z2 = getattr(np, f)(x2, y2)
1201
+ v2 = z2[0]
1202
+ except Exception as e:
1203
+ np_exc = e.__class__
1204
+
1205
+ if xnd_exc or np_exc:
1206
+ if xnd_exc != NotImplementedError:
1207
+ self.assertEqual(xnd_exc, np_exc)
1208
+ else:
1209
+ msg = "%s(%s : %s, %s : %s) -> %s xnd: %s np: %s" % \
1210
+ (f, a, t, b, u, w, z1, z2)
1211
+ self.assert_equal(f, v1, v2, w, msg, a=x1[0].value, b=y1[0].value)
1212
+
1213
+ def check_binary_mv(self, f, a, t, b, u, v, w, mod=fn, dev=None):
1214
+
1215
+ x1 = self.create_xnd(a, t, dev)
1216
+ if x1 is None:
1217
+ return
1218
+
1219
+ y1 = self.create_xnd(b, u, dev)
1220
+ if y1 is None:
1221
+ return
1222
+
1223
+ c1, d1 = getattr(mod, f)(x1, y1)
1224
+ self.assertEqual(str(c1[0].type), v.type)
1225
+ self.assertEqual(str(d1[0].type), w.type)
1226
+ cv1 = c1[0].value
1227
+ dv1 = d1[0].value
1228
+
1229
+ x2 = np.array([a], dtype=t.type)
1230
+ y2 = np.array([b], dtype=u.type)
1231
+ c2, d2 = getattr(np, f)(x2, y2)
1232
+ cv2 = c2[0]
1233
+ dv2 = d2[0]
1234
+
1235
+ msg = "%s(%s : %s, %s : %s) -> %s, %s xnd: %s np: %s" % \
1236
+ (f, a, t, b, u, v, w, (cv1, dv1), (cv2, dv2))
1237
+ self.assert_equal(f, cv1, cv2, v, msg)
1238
+ self.assert_equal(f, dv2, dv2, v, msg)
1239
+
1240
+ @unittest.skipIf(sys.platform == "darwin", "complex trigonometry errors too large")
1241
+ @unittest.skipIf(sys.platform == "win32" and ARCH == "32bit", "complex trigonometry errors too large")
1242
+ def test_unary_cpu(self):
1243
+ skip_if(SKIP_LONG, "use --long argument to enable these tests")
1244
+
1245
+ print("\n", flush=True)
1246
+
1247
+ for pattern, return_type in [
1248
+ ("default", "default"),
1249
+ ("complex_math", "float_result"),
1250
+ ("real_math", "float_result")]:
1251
+
1252
+ for f in functions["unary"][pattern]:
1253
+ if np_noimpl(f):
1254
+ continue
1255
+
1256
+ print("testing %s ..." % f, flush=True)
1257
+
1258
+ for t, in implemented_sigs["unary"][return_type]:
1259
+ u = implemented_sigs["unary"][return_type][(t,)]
1260
+
1261
+ print(" %s -> %s" % (t, u), flush=True)
1262
+
1263
+ for a in t.testcases():
1264
+ if t.cpu_noimpl(f) or u.cpu_noimpl(f):
1265
+ self.check_unary_not_implemented(f, a, t)
1266
+ else:
1267
+ self.check_unary(f, a, t, u)
1268
+
1269
+ def test_binary_cpu(self):
1270
+ skip_if(SKIP_LONG, "use --long argument to enable these tests")
1271
+
1272
+ print("\n", flush=True)
1273
+
1274
+ for pattern in "default", "float_result", "bool_result":
1275
+ for f in functions["binary"][pattern]:
1276
+ print("testing %s ..." % f, flush=True)
1277
+
1278
+ for t, u in implemented_sigs["binary"][pattern]:
1279
+ w = implemented_sigs["binary"][pattern][(t, u)]
1280
+
1281
+ print(" %s, %s -> %s" % (t, u, w), flush=True)
1282
+
1283
+ for a in t.testcases():
1284
+ for b in u.testcases():
1285
+ if t.cpu_nokern(f) or u.cpu_nokern(f) or w.cpu_nokern(f):
1286
+ self.check_binary_type_error(f, a, t, b, u)
1287
+ elif t.cpu_noimpl(f) or u.cpu_noimpl(f) or w.cpu_noimpl(f):
1288
+ self.check_binary_not_implemented(f, a, t, b, u)
1289
+ else:
1290
+ self.check_binary(f, a, t, b, u, w)
1291
+
1292
+ def test_binary_mv_cpu(self):
1293
+ skip_if(SKIP_LONG, "use --long argument to enable these tests")
1294
+
1295
+ print("\n", flush=True)
1296
+
1297
+ for f in functions["binary_mv"]["default"]:
1298
+ print("testing %s ..." % f, flush=True)
1299
+
1300
+ for t, u in implemented_sigs["binary_mv"]["default"]:
1301
+ v, w = implemented_sigs["binary_mv"]["default"][(t, u)]
1302
+
1303
+ print(" %s, %s -> %s, %s" % (t, u, v, w), flush=True)
1304
+
1305
+ for a in t.testcases():
1306
+ for b in u.testcases():
1307
+ self.check_binary_mv(f, a, t, b, u, v, w)
1308
+
1309
+ @unittest.skipIf(cd is None, "test requires cuda")
1310
+ def test_unary_cuda(self):
1311
+ skip_if(SKIP_LONG, "use --long argument to enable these tests")
1312
+
1313
+ print("\n", flush=True)
1314
+
1315
+ for pattern, return_type in [
1316
+ ("default", "default"),
1317
+ ("complex_math_with_half", "float_result"),
1318
+ ("complex_math", "float_result"),
1319
+ ("real_math_with_half", "float_result"),
1320
+ ("real_math", "float_result")]:
1321
+
1322
+ for f in functions["unary"][pattern]:
1323
+ if np_noimpl(f):
1324
+ continue
1325
+
1326
+ print("testing %s ..." % f, flush=True)
1327
+
1328
+ for t, in implemented_sigs["unary"][return_type]:
1329
+ u = implemented_sigs["unary"][return_type][(t,)]
1330
+
1331
+ print(" %s -> %s" % (t, u), flush=True)
1332
+
1333
+ for a in t.testcases():
1334
+ if t.cuda_noimpl(f) or u.cuda_noimpl(f):
1335
+ self.check_unary_not_implemented(
1336
+ f, a, t, mod=cd, dev="cuda:managed")
1337
+ else:
1338
+ self.check_unary(f, a, t, u,
1339
+ mod=cd, dev="cuda:managed")
1340
+
1341
+ @unittest.skipIf(cd is None, "test requires cuda")
1342
+ def test_binary_cuda(self):
1343
+ skip_if(SKIP_LONG, "use --long argument to enable these tests")
1344
+
1345
+ print("\n", flush=True)
1346
+
1347
+ for pattern in "default", "float_result", "bool_result":
1348
+ for f in functions["binary"][pattern]:
1349
+ print("testing %s ..." % f, flush=True)
1350
+
1351
+ for t, u in implemented_sigs["binary"][pattern]:
1352
+ w = implemented_sigs["binary"][pattern][(t, u)]
1353
+
1354
+ print(" %s, %s -> %s" % (t, u, w), flush=True)
1355
+
1356
+ for a in t.testcases():
1357
+ for b in u.testcases():
1358
+ if t.cuda_nokern(f) or u.cuda_nokern(f) or w.cuda_nokern(f):
1359
+ self.check_binary_type_error(f, a, t, b, u,
1360
+ mod=cd, dev="cuda:managed")
1361
+ elif t.type == "complex32" or u.type == "complex32" or w.cuda_noimpl(f):
1362
+ self.check_binary_not_implemented(f, a, t, b, u,
1363
+ mod=cd, dev="cuda:managed")
1364
+ else:
1365
+ self.check_binary(f, a, t, b, u, w,
1366
+ mod=cd, dev="cuda:managed")
1367
+
1368
+ @unittest.skipIf(cd is None, "test requires cuda")
1369
+ def test_binary_mv_cuda(self):
1370
+ skip_if(SKIP_LONG, "use --long argument to enable these tests")
1371
+
1372
+ print("\n", flush=True)
1373
+
1374
+ for f in functions["binary_mv"]["default"]:
1375
+ print("testing %s ..." % f, flush=True)
1376
+
1377
+ for t, u in implemented_sigs["binary_mv"]["default"]:
1378
+ v, w = implemented_sigs["binary_mv"]["default"][(t, u)]
1379
+
1380
+ print(" %s, %s -> %s, %s" % (t, u, v, w), flush=True)
1381
+
1382
+ for a in t.testcases():
1383
+ for b in u.testcases():
1384
+ self.check_binary_mv(f, a, t, b, u, v, w, mod=cd,
1385
+ dev="cuda:managed")
1386
+
1387
+ def test_divide_inexact_cpu(self):
1388
+
1389
+ t = Tint("uint8")
1390
+ u = Tint("uint64")
1391
+
1392
+ a = next(t.testcases())
1393
+ b = next(u.testcases())
1394
+ self.check_binary_type_error("divide", a, t, b, u)
1395
+
1396
+ @unittest.skipIf(cd is None, "test requires cuda")
1397
+ def test_divide_inexact_cuda(self):
1398
+
1399
+ t = Tint("uint8")
1400
+ u = Tint("uint64")
1401
+
1402
+ a = next(t.testcases())
1403
+ b = next(u.testcases())
1404
+ self.check_binary_type_error("divide", a, t, b, u,
1405
+ mod=cd, dev="cuda:managed")
1406
+
1407
+ def test_divmod_type_error_cpu(self):
1408
+
1409
+ t = Tint("uint8")
1410
+ u = Tint("uint64")
1411
+
1412
+ a = next(t.testcases())
1413
+ b = next(u.testcases())
1414
+ self.check_binary_type_error("divmod", a, t, b, u)
1415
+
1416
+ @unittest.skipIf(cd is None, "test requires cuda")
1417
+ def test_divmod_type_error_cuda(self):
1418
+
1419
+ t = Tint("uint8")
1420
+ u = Tint("uint64")
1421
+
1422
+ a = next(t.testcases())
1423
+ b = next(u.testcases())
1424
+ self.check_binary_type_error("divmod", a, t, b, u)
1425
+
1426
+
1427
+ @unittest.skipIf(cd is None, "test requires cuda")
1428
+ class TestCudaManaged(unittest.TestCase):
1429
+
1430
+ def test_mixed_functions(self):
1431
+
1432
+ x = xnd([1,2,3])
1433
+ y = xnd([1,2,3])
1434
+
1435
+ a = xnd([1,2,3], device="cuda:managed")
1436
+ b = xnd([1,2,3], device="cuda:managed")
1437
+
1438
+ z = fn.multiply(x, y)
1439
+ c = cd.multiply(a, b)
1440
+ self.assertEqual(z, c)
1441
+
1442
+ z = fn.multiply(a, b)
1443
+ self.assertEqual(z, c)
1444
+
1445
+ z = fn.multiply(x, b)
1446
+ self.assertEqual(z, c)
1447
+
1448
+ z = fn.multiply(a, y)
1449
+ self.assertEqual(z, c)
1450
+
1451
+ self.assertRaises(ValueError, cd.multiply, x, y)
1452
+ self.assertRaises(ValueError, cd.multiply, x, b)
1453
+ self.assertRaises(ValueError, cd.multiply, a, y)
1454
+
1455
+
1456
+ class TestSpec(unittest.TestCase):
1457
+
1458
+ def __init__(self, *, constr, ndarray, mod,
1459
+ values, value_generator,
1460
+ indices_generator, indices_generator_args):
1461
+ super().__init__()
1462
+ self.constr = constr
1463
+ self.ndarray = ndarray
1464
+ self.mod = mod
1465
+ self.values = values
1466
+ self.value_generator = value_generator
1467
+ self.indices_generator = indices_generator
1468
+ self.indices_generator_args = indices_generator_args
1469
+ self.indices_stack = [None] * 8
1470
+
1471
+ def log_err(self, value, depth):
1472
+ """Dump an error as a Python script for debugging."""
1473
+ dtype = "?int32" if have_none(value) else "int32"
1474
+
1475
+ sys.stderr.write("\n\nfrom xnd import *\n")
1476
+ sys.stderr.write("import gumath.functions as fn\n")
1477
+ sys.stderr.write("from test_gumath import NDArray\n")
1478
+ sys.stderr.write("lst = %s\n\n" % value)
1479
+ sys.stderr.write("x0 = xnd(lst, dtype=\"%s\")\n" % dtype)
1480
+ sys.stderr.write("y0 = NDArray(lst)\n" % value)
1481
+
1482
+ for i in range(depth+1):
1483
+ sys.stderr.write("x%d = x%d[%s]\n" % (i+1, i, itos(self.indices_stack[i])))
1484
+ sys.stderr.write("y%d = y%d[%s]\n" % (i+1, i, itos(self.indices_stack[i])))
1485
+
1486
+ sys.stderr.write("\n")
1487
+
1488
+ def run_reduce(self, nd, d):
1489
+ if not isinstance(nd, xnd) or not isinstance(d, np.ndarray):
1490
+ return
1491
+
1492
+ for attr in ["add", "subtract", "multiply"]:
1493
+ f = getattr(fn, attr)
1494
+ g = getattr(np, attr)
1495
+
1496
+ x = nd_exception = None
1497
+ try:
1498
+ x = gm.reduce(f, nd, dtype=nd.dtype)
1499
+ except Exception as e:
1500
+ nd_exception = e
1501
+
1502
+ y = np_exception = None
1503
+ try:
1504
+ y = g.reduce(d, dtype=d.dtype)
1505
+ except Exception as e:
1506
+ np_exception = e
1507
+
1508
+ if nd_exception or np_exception:
1509
+ self.assertIs(nd_exception.__class__, np_exception.__class__,
1510
+ "f: %r nd: %r np: %r x: %r y: %r" % (attr, nd, d, x, y))
1511
+ else:
1512
+ self.assertEqual(x.value, y.tolist(),
1513
+ "f: %r nd: %r np: %r x: %r y: %r" % (attr, nd, d, x, y))
1514
+
1515
+ for axes in gen_axes(d.ndim):
1516
+ nd_exception = None
1517
+ try:
1518
+ x = gm.reduce(f, nd, axes=axes, dtype=nd.dtype)
1519
+ except Exception as e:
1520
+ nd_exception = e
1521
+
1522
+ np_exception = None
1523
+ try:
1524
+ y = g.reduce(d, axis=axes, dtype=d.dtype)
1525
+ except Exception as e:
1526
+ np_exception = e
1527
+
1528
+ if nd_exception or np_exception:
1529
+ self.assertIs(nd_exception.__class__, np_exception.__class__,
1530
+ "f: %r axes: %r nd: %r np: %r x: %r y: %r" % (attr, axes, nd, d, x, y))
1531
+ else:
1532
+ self.assertEqual(x.value, y.tolist(),
1533
+ "f: %r axes: %r nd: %r np: %r x: %r y: %r" % (attr, axes, nd, d, x, y))
1534
+
1535
+ def run_single(self, nd, d, indices):
1536
+ """Run a single test case."""
1537
+
1538
+ self.assertEqual(len(nd), len(d))
1539
+
1540
+ nd_exception = None
1541
+ try:
1542
+ nd_result = nd[indices]
1543
+ except Exception as e:
1544
+ nd_exception = e
1545
+
1546
+ def_exception = None
1547
+ try:
1548
+ def_result = d[indices]
1549
+ except Exception as e:
1550
+ def_exception = e
1551
+
1552
+ if nd_exception or def_exception:
1553
+ if nd_exception is None and def_exception.__class__ is IndexError:
1554
+ # Example: type = 0 * 0 * int64
1555
+ if len(indices) <= nd.ndim:
1556
+ return None, None
1557
+
1558
+ self.assertIs(nd_exception.__class__, def_exception.__class__)
1559
+ return None, None
1560
+
1561
+ assert(isinstance(nd_result, xnd))
1562
+
1563
+ x = self.mod.sin(nd_result)
1564
+ y = self.mod.multiply(nd_result, nd_result)
1565
+
1566
+ if isinstance(def_result, NDArray):
1567
+ aa = a = def_result.sin()
1568
+ b = def_result * def_result
1569
+ elif isinstance(def_result, int):
1570
+ aa = a = math.sin(def_result)
1571
+ b = def_result * def_result
1572
+ elif def_result is None:
1573
+ aa = a = None
1574
+ aa = b = None
1575
+ elif isinstance(def_result, np.ndarray):
1576
+ aa = np.sin(def_result)
1577
+ a = aa.tolist()
1578
+ bb = np.multiply(def_result, def_result)
1579
+ b = bb.tolist()
1580
+ elif isinstance(def_result, np.int32):
1581
+ aa = np.sin(def_result)
1582
+ a = aa.tolist()
1583
+ bb = np.multiply(def_result, def_result)
1584
+ b = bb.tolist()
1585
+ else:
1586
+ raise TypeError("unexpected def_result: %s : %s" % (def_result, type(def_result)))
1587
+
1588
+ if self.mod == cd:
1589
+ np.testing.assert_allclose(x, aa, 1e-6)
1590
+ np.testing.assert_allclose(y, bb, 1e-6)
1591
+ else:
1592
+ self.assertEqual(x, a)
1593
+ self.assertEqual(y, b)
1594
+
1595
+ if self.mod == fn:
1596
+ self.run_reduce(nd_result, def_result)
1597
+
1598
+ return nd_result, def_result
1599
+
1600
+ def run(self):
1601
+ def check(nd, d, value, depth):
1602
+ if depth > 3: # adjust for longer tests
1603
+ return
1604
+
1605
+ g = self.indices_generator(*self.indices_generator_args)
1606
+
1607
+ for indices in g:
1608
+ self.indices_stack[depth] = indices
1609
+
1610
+ try:
1611
+ next_nd, next_d = self.run_single(nd, d, indices)
1612
+ except Exception as e:
1613
+ self.log_err(value, depth)
1614
+ raise e
1615
+
1616
+ if isinstance(next_d, list): # possibly None or scalar
1617
+ check(next_nd, next_d, value, depth+1)
1618
+
1619
+ def check_buffer(nd, d, value, depth):
1620
+ if depth > 3: # adjust for longer tests
1621
+ return
1622
+ if not isinstance(nd, xnd) or nd.device == "cuda:managed" or \
1623
+ not isinstance(d, np.ndarray):
1624
+ return
1625
+
1626
+ nd = xnd.from_buffer(d)
1627
+ d = np.array(nd, copy=False)
1628
+
1629
+ g = self.indices_generator(*self.indices_generator_args)
1630
+
1631
+ for indices in g:
1632
+ self.indices_stack[depth] = indices
1633
+
1634
+ try:
1635
+ next_nd, next_d = self.run_single(nd, d, indices)
1636
+ except Exception as e:
1637
+ self.log_err(value, depth)
1638
+ raise e
1639
+
1640
+ if isinstance(next_d, list): # possibly None or scalar
1641
+ check_buffer(next_nd, next_d, value, depth+1)
1642
+
1643
+ for value in self.values:
1644
+ dtype = "?int32" if have_none(value) else "int32"
1645
+ if self.constr == xnd:
1646
+ nd = xnd(value, dtype=dtype, device=None if self.mod==fn else "cuda:managed")
1647
+ else:
1648
+ nd = self.constr(value, dtype=dtype)
1649
+ # NumPy does not support "?int32", NDArray does not need the dtype.
1650
+ d = self.ndarray(value, dtype="int32")
1651
+ check(nd, d, value, 0)
1652
+
1653
+ for max_ndim in range(1, 5):
1654
+ for min_shape in (0, 1):
1655
+ for max_shape in range(1, 8):
1656
+ for value in self.value_generator(max_ndim, min_shape, max_shape):
1657
+ dtype = "?int32" if have_none(value) else "int32"
1658
+ if self.constr == xnd:
1659
+ nd = xnd(value, dtype=dtype, device=None if self.mod==fn else "cuda:managed")
1660
+ else:
1661
+ nd = self.constr(value, dtype=dtype)
1662
+ # See above.
1663
+ d = self.ndarray(value, dtype="int32")
1664
+ check(nd, d, value, 0)
1665
+ check_buffer(nd, d, value, 0)
1666
+
1667
+
1668
+ class LongIndexSliceTest(unittest.TestCase):
1669
+
1670
+ def test_subarray(self):
1671
+ # Multidimensional indexing
1672
+ skip_if(SKIP_LONG, "use --long argument to enable these tests")
1673
+
1674
+ t = TestSpec(constr=xnd,
1675
+ ndarray=NDArray,
1676
+ mod=fn,
1677
+ values=SUBSCRIPT_FIXED_TEST_CASES,
1678
+ value_generator=gen_fixed,
1679
+ indices_generator=genindices,
1680
+ indices_generator_args=())
1681
+ t.run()
1682
+
1683
+ t = TestSpec(constr=xnd,
1684
+ ndarray=NDArray,
1685
+ mod=fn,
1686
+ values=SUBSCRIPT_VAR_TEST_CASES,
1687
+ value_generator=gen_var,
1688
+ indices_generator=genindices,
1689
+ indices_generator_args=())
1690
+ t.run()
1691
+
1692
+ @unittest.skipIf(cd is None or np is None, "cuda or numpy not found")
1693
+ def test_subarray_cuda(self):
1694
+ # Multidimensional indexing
1695
+ skip_if(SKIP_LONG, "use --long argument to enable these tests")
1696
+
1697
+ t = TestSpec(constr=xnd,
1698
+ ndarray=np.array,
1699
+ mod=cd,
1700
+ values=SUBSCRIPT_FIXED_TEST_CASES,
1701
+ value_generator=gen_fixed,
1702
+ indices_generator=genindices,
1703
+ indices_generator_args=())
1704
+ t.run()
1705
+
1706
+ def test_slices(self):
1707
+ # Multidimensional slicing
1708
+ skip_if(SKIP_LONG, "use --long argument to enable these tests")
1709
+
1710
+ t = TestSpec(constr=xnd,
1711
+ ndarray=NDArray,
1712
+ mod=fn,
1713
+ values=SUBSCRIPT_FIXED_TEST_CASES,
1714
+ value_generator=gen_fixed,
1715
+ indices_generator=randslices,
1716
+ indices_generator_args=(3,))
1717
+ t.run()
1718
+
1719
+ t = TestSpec(constr=xnd,
1720
+ ndarray=NDArray,
1721
+ mod=fn,
1722
+ values=SUBSCRIPT_VAR_TEST_CASES,
1723
+ value_generator=gen_var,
1724
+ indices_generator=randslices,
1725
+ indices_generator_args=(3,))
1726
+ t.run()
1727
+
1728
+ @unittest.skipIf(cd is None or np is None, "cuda or numpy not found")
1729
+ def test_slices_cuda(self):
1730
+ # Multidimensional slicing
1731
+ skip_if(SKIP_LONG, "use --long argument to enable these tests")
1732
+
1733
+ t = TestSpec(constr=xnd,
1734
+ ndarray=np.array,
1735
+ mod=cd,
1736
+ values=SUBSCRIPT_FIXED_TEST_CASES,
1737
+ value_generator=gen_fixed,
1738
+ indices_generator=randslices,
1739
+ indices_generator_args=(3,))
1740
+ t.run()
1741
+
1742
+ def test_chained_indices_slices(self):
1743
+ # Multidimensional indexing and slicing, chained
1744
+ skip_if(SKIP_LONG, "use --long argument to enable these tests")
1745
+
1746
+ t = TestSpec(constr=xnd,
1747
+ ndarray=NDArray,
1748
+ mod=fn,
1749
+ values=SUBSCRIPT_FIXED_TEST_CASES,
1750
+ value_generator=gen_fixed,
1751
+ indices_generator=gen_indices_or_slices,
1752
+ indices_generator_args=())
1753
+ t.run()
1754
+
1755
+
1756
+ t = TestSpec(constr=xnd,
1757
+ ndarray=NDArray,
1758
+ mod=fn,
1759
+ values=SUBSCRIPT_VAR_TEST_CASES,
1760
+ value_generator=gen_var,
1761
+ indices_generator=gen_indices_or_slices,
1762
+ indices_generator_args=())
1763
+ t.run()
1764
+
1765
+ def test_fixed_mixed_indices_slices(self):
1766
+ # Multidimensional indexing and slicing, mixed
1767
+ skip_if(SKIP_LONG, "use --long argument to enable these tests")
1768
+
1769
+ t = TestSpec(constr=xnd,
1770
+ ndarray=NDArray,
1771
+ mod=fn,
1772
+ values=SUBSCRIPT_FIXED_TEST_CASES,
1773
+ value_generator=gen_fixed,
1774
+ indices_generator=mixed_indices,
1775
+ indices_generator_args=(3,))
1776
+ t.run()
1777
+
1778
+ def test_var_mixed_indices_slices(self):
1779
+ # Multidimensional indexing and slicing, mixed
1780
+ skip_if(SKIP_LONG, "use --long argument to enable these tests")
1781
+
1782
+ t = TestSpec(constr=xnd,
1783
+ ndarray=NDArray,
1784
+ mod=fn,
1785
+ values=SUBSCRIPT_VAR_TEST_CASES,
1786
+ value_generator=gen_var,
1787
+ indices_generator=mixed_indices,
1788
+ indices_generator_args=(5,))
1789
+ t.run()
1790
+
1791
+ def test_slices_brute_force(self):
1792
+ # Test all possible slices for the given ndim and shape
1793
+ skip_if(SKIP_BRUTE_FORCE, "use --all argument to enable these tests")
1794
+
1795
+ t = TestSpec(constr=xnd,
1796
+ ndarray=NDArray,
1797
+ mod=fn,
1798
+ values=SUBSCRIPT_FIXED_TEST_CASES,
1799
+ value_generator=gen_fixed,
1800
+ indices_generator=genslices_ndim,
1801
+ indices_generator_args=(3, [3,3,3]))
1802
+ t.run()
1803
+
1804
+ t = TestSpec(constr=xnd,
1805
+ ndarray=NDArray,
1806
+ mod=fn,
1807
+ values=SUBSCRIPT_VAR_TEST_CASES,
1808
+ value_generator=gen_var,
1809
+ indices_generator=genslices_ndim,
1810
+ indices_generator_args=(3, [3,3,3]))
1811
+ t.run()
1812
+
1813
+ @unittest.skipIf(np is None, "numpy not found")
1814
+ def test_fixed_mixed_indices_slices_np(self):
1815
+ # Multidimensional indexing and slicing, mixed
1816
+ skip_if(SKIP_LONG, "use --long argument to enable these tests")
1817
+
1818
+ t = TestSpec(constr=xnd,
1819
+ ndarray=np.array,
1820
+ mod=fn,
1821
+ values=SUBSCRIPT_FIXED_TEST_CASES,
1822
+ value_generator=gen_fixed,
1823
+ indices_generator=mixed_indices,
1824
+ indices_generator_args=(3,))
1825
+ t.run()
1826
+
1827
+ @unittest.skipIf(np is None, "numpy not found")
1828
+ def test_reduce(self):
1829
+ skip_if(SKIP_LONG, "use --long argument to enable these tests")
1830
+
1831
+ t = TestSpec(constr=xnd,
1832
+ ndarray=np.array,
1833
+ mod=fn,
1834
+ values=SUBSCRIPT_FIXED_TEST_CASES,
1835
+ value_generator=gen_fixed,
1836
+ indices_generator=mixed_indices,
1837
+ indices_generator_args=(3,))
1838
+ t.run()
1839
+
376
1840
 
377
1841
  ALL_TESTS = [
1842
+ TestAPI,
378
1843
  TestCall,
379
1844
  TestRaggedArrays,
1845
+ TestFlexibleArrays,
380
1846
  TestMissingValues,
1847
+ TestEqualN,
381
1848
  TestGraphs,
382
- TestBFloat16,
383
1849
  TestPdist,
384
1850
  TestNumba,
1851
+ TestOut,
1852
+ TestUnaryCPU,
1853
+ TestUnaryCUDA,
1854
+ TestBinaryCPU,
1855
+ TestBinaryCUDA,
1856
+ TestBitwiseCPU,
1857
+ TestBitwiseCUDA,
1858
+ TestFunctions,
1859
+ TestCudaManaged,
1860
+ LongIndexSliceTest,
385
1861
  ]
386
1862
 
387
1863
 
@@ -389,7 +1865,11 @@ if __name__ == '__main__':
389
1865
  parser = argparse.ArgumentParser()
390
1866
  parser.add_argument("-f", "--failfast", action="store_true",
391
1867
  help="stop the test run on first error")
1868
+ parser.add_argument('--long', action="store_true", help="run long slice tests")
1869
+ parser.add_argument('--all', action="store_true", help="run brute force tests")
392
1870
  args = parser.parse_args()
1871
+ SKIP_LONG = not (args.long or args.all)
1872
+ SKIP_BRUTE_FORCE = not args.all
393
1873
 
394
1874
  suite = unittest.TestSuite()
395
1875
  loader = unittest.TestLoader()