dask-array 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (144):
  1. dask_array/__init__.py +228 -0
  2. dask_array/_backends.py +76 -0
  3. dask_array/_backends_array.py +99 -0
  4. dask_array/_blockwise.py +1410 -0
  5. dask_array/_broadcast.py +272 -0
  6. dask_array/_chunk.py +445 -0
  7. dask_array/_chunk_types.py +54 -0
  8. dask_array/_collection.py +1644 -0
  9. dask_array/_concatenate.py +331 -0
  10. dask_array/_core_utils.py +1365 -0
  11. dask_array/_dispatch.py +141 -0
  12. dask_array/_einsum.py +277 -0
  13. dask_array/_expr.py +544 -0
  14. dask_array/_expr_flow.py +586 -0
  15. dask_array/_gufunc.py +805 -0
  16. dask_array/_histogram.py +617 -0
  17. dask_array/_map_blocks.py +652 -0
  18. dask_array/_new_collection.py +10 -0
  19. dask_array/_numpy_compat.py +135 -0
  20. dask_array/_overlap.py +1159 -0
  21. dask_array/_rechunk.py +1050 -0
  22. dask_array/_reshape.py +710 -0
  23. dask_array/_routines.py +102 -0
  24. dask_array/_shuffle.py +448 -0
  25. dask_array/_stack.py +264 -0
  26. dask_array/_svg.py +291 -0
  27. dask_array/_templates.py +29 -0
  28. dask_array/_test_utils.py +257 -0
  29. dask_array/_ufunc.py +385 -0
  30. dask_array/_utils.py +349 -0
  31. dask_array/_visualize.py +223 -0
  32. dask_array/_xarray.py +337 -0
  33. dask_array/core/__init__.py +34 -0
  34. dask_array/core/_blockwise_funcs.py +312 -0
  35. dask_array/core/_conversion.py +422 -0
  36. dask_array/core/_from_graph.py +97 -0
  37. dask_array/creation/__init__.py +71 -0
  38. dask_array/creation/_arange.py +121 -0
  39. dask_array/creation/_diag.py +116 -0
  40. dask_array/creation/_diagonal.py +241 -0
  41. dask_array/creation/_eye.py +103 -0
  42. dask_array/creation/_linspace.py +102 -0
  43. dask_array/creation/_mesh.py +134 -0
  44. dask_array/creation/_ones_zeros.py +454 -0
  45. dask_array/creation/_pad.py +270 -0
  46. dask_array/creation/_repeat.py +55 -0
  47. dask_array/creation/_tile.py +36 -0
  48. dask_array/creation/_tri.py +28 -0
  49. dask_array/creation/_utils.py +296 -0
  50. dask_array/fft.py +320 -0
  51. dask_array/io/__init__.py +39 -0
  52. dask_array/io/_base.py +10 -0
  53. dask_array/io/_from_array.py +257 -0
  54. dask_array/io/_from_delayed.py +95 -0
  55. dask_array/io/_from_graph.py +54 -0
  56. dask_array/io/_from_npy_stack.py +67 -0
  57. dask_array/io/_store.py +336 -0
  58. dask_array/io/_tiledb.py +159 -0
  59. dask_array/io/_to_npy_stack.py +65 -0
  60. dask_array/io/_zarr.py +449 -0
  61. dask_array/linalg/__init__.py +39 -0
  62. dask_array/linalg/_cholesky.py +234 -0
  63. dask_array/linalg/_lu.py +300 -0
  64. dask_array/linalg/_norm.py +94 -0
  65. dask_array/linalg/_qr.py +601 -0
  66. dask_array/linalg/_solve.py +349 -0
  67. dask_array/linalg/_svd.py +394 -0
  68. dask_array/linalg/_tensordot.py +334 -0
  69. dask_array/linalg/_utils.py +74 -0
  70. dask_array/manipulation/__init__.py +45 -0
  71. dask_array/manipulation/_expand.py +321 -0
  72. dask_array/manipulation/_flip.py +92 -0
  73. dask_array/manipulation/_roll.py +78 -0
  74. dask_array/manipulation/_transpose.py +309 -0
  75. dask_array/random/__init__.py +125 -0
  76. dask_array/random/_choice.py +181 -0
  77. dask_array/random/_expr.py +256 -0
  78. dask_array/random/_generator.py +441 -0
  79. dask_array/random/_random_state.py +259 -0
  80. dask_array/random/_utils.py +84 -0
  81. dask_array/reductions/__init__.py +84 -0
  82. dask_array/reductions/_arg_reduction.py +130 -0
  83. dask_array/reductions/_common.py +1082 -0
  84. dask_array/reductions/_cumulative.py +522 -0
  85. dask_array/reductions/_percentile.py +261 -0
  86. dask_array/reductions/_reduction.py +725 -0
  87. dask_array/reductions/_trace.py +56 -0
  88. dask_array/routines/__init__.py +133 -0
  89. dask_array/routines/_apply.py +84 -0
  90. dask_array/routines/_bincount.py +112 -0
  91. dask_array/routines/_broadcast.py +111 -0
  92. dask_array/routines/_coarsen.py +115 -0
  93. dask_array/routines/_diff.py +79 -0
  94. dask_array/routines/_gradient.py +158 -0
  95. dask_array/routines/_indexing.py +65 -0
  96. dask_array/routines/_insert_delete.py +132 -0
  97. dask_array/routines/_misc.py +122 -0
  98. dask_array/routines/_nonzero.py +72 -0
  99. dask_array/routines/_search.py +123 -0
  100. dask_array/routines/_select.py +113 -0
  101. dask_array/routines/_statistics.py +171 -0
  102. dask_array/routines/_topk.py +82 -0
  103. dask_array/routines/_triangular.py +74 -0
  104. dask_array/routines/_unique.py +232 -0
  105. dask_array/routines/_where.py +62 -0
  106. dask_array/slicing/__init__.py +67 -0
  107. dask_array/slicing/_basic.py +550 -0
  108. dask_array/slicing/_blocks.py +138 -0
  109. dask_array/slicing/_bool_index.py +145 -0
  110. dask_array/slicing/_setitem.py +329 -0
  111. dask_array/slicing/_squeeze.py +101 -0
  112. dask_array/slicing/_utils.py +1133 -0
  113. dask_array/slicing/_vindex.py +282 -0
  114. dask_array/stacking/__init__.py +15 -0
  115. dask_array/stacking/_block.py +83 -0
  116. dask_array/stacking/_simple.py +58 -0
  117. dask_array/templates/array.html.j2 +48 -0
  118. dask_array/tests/__init__.py +0 -0
  119. dask_array/tests/conftest.py +22 -0
  120. dask_array/tests/test_api.py +40 -0
  121. dask_array/tests/test_binary_op_chunks.py +107 -0
  122. dask_array/tests/test_coarse_slice_through_blockwise.py +362 -0
  123. dask_array/tests/test_collection.py +799 -0
  124. dask_array/tests/test_creation.py +1102 -0
  125. dask_array/tests/test_expr_flow.py +143 -0
  126. dask_array/tests/test_linalg.py +1130 -0
  127. dask_array/tests/test_map_blocks_multi_output.py +104 -0
  128. dask_array/tests/test_rechunk_pushdown.py +214 -0
  129. dask_array/tests/test_reductions.py +1091 -0
  130. dask_array/tests/test_routines.py +2853 -0
  131. dask_array/tests/test_shuffle_chunks.py +67 -0
  132. dask_array/tests/test_slice_pushdown.py +968 -0
  133. dask_array/tests/test_slice_through_blockwise.py +678 -0
  134. dask_array/tests/test_slice_through_overlap.py +366 -0
  135. dask_array/tests/test_slice_through_reshape.py +272 -0
  136. dask_array/tests/test_slicing.py +839 -0
  137. dask_array/tests/test_transpose_slice_pushdown.py +208 -0
  138. dask_array/tests/test_visualize.py +94 -0
  139. dask_array/tests/test_xarray.py +193 -0
  140. dask_array-0.1.0.dist-info/METADATA +48 -0
  141. dask_array-0.1.0.dist-info/RECORD +144 -0
  142. dask_array-0.1.0.dist-info/WHEEL +4 -0
  143. dask_array-0.1.0.dist-info/entry_points.txt +2 -0
  144. dask_array-0.1.0.dist-info/licenses/LICENSE +29 -0
@@ -0,0 +1,1130 @@
1
+ from __future__ import annotations
2
+
3
+ import sys
4
+
5
+ import pytest
6
+
7
+ pytest.importorskip("numpy")
8
+ pytest.importorskip("scipy")
9
+
10
+ import numpy as np
11
+ import scipy.linalg
12
+ from packaging.version import Version
13
+
14
+ import dask_array as da
15
+ from dask_array.linalg import qr, sfqr, svd, svd_compressed, tsqr
16
+ from dask_array._numpy_compat import _np_version
17
+ from dask_array._test_utils import assert_eq, same_keys
18
+ from dask_array._utils import svd_flip
19
+
20
+
21
@pytest.mark.parametrize(
    "m,n,chunks,error_type",
    [
        (20, 10, 10, None),  # tall-skinny, regular blocks
        (20, 10, (3, 10), None),  # tall-skinny, regular fat layers
        (20, 10, ((8, 4, 8), 10), None),  # tall-skinny, irregular fat layers
        (40, 10, ((15, 5, 5, 8, 7), 10), None),  # tall-skinny, non-uniform chunks
        (128, 2, (16, 2), None),  # thin layers; recursion_depth=1
        (129, 2, (16, 2), None),  # thin layers; recursion_depth=2 --> 17x2
        (130, 2, (16, 2), None),  # thin layers; recursion_depth=2 --> 18x2 next
        (131, 2, (16, 2), None),  # thin layers; recursion_depth=2 --> 18x2 next
        (300, 10, (40, 10), None),  # thin layers; recursion_depth=2
        (300, 10, (30, 10), None),  # thin layers; recursion_depth=3
        (300, 10, (20, 10), None),  # thin layers; recursion_depth=4
        (10, 5, 10, None),  # single block, tall
        (5, 10, 10, None),  # single block, short
        (10, 10, 10, None),  # single block, square
        (10, 40, (10, 10), ValueError),  # short-fat, regular blocks
        (10, 40, (10, 15), ValueError),  # short-fat, irregular blocks
        (10, 40, (10, (15, 5, 5, 8, 7)), ValueError),  # short-fat, non-uniform chunks
        (20, 20, 10, ValueError),  # 2x2 regular blocks
    ],
)
def test_tsqr(m, n, chunks, error_type):
    """Validate ``tsqr`` in both QR mode and SVD mode.

    Successful cases check output shapes, reconstruction of the input,
    orthonormality of Q / U / Vh and upper-triangularity of R. In the
    parametrization above, every case whose columns are split into more than
    one chunk expects ``error_type`` from both modes.
    """
    mat = np.random.default_rng().random((m, n))
    data = da.from_array(mat, chunks=chunks, name="A")

    k = min(m, n)  # inner dimension of the factorization
    vh_side = max(k, n)  # Vh is checked as a square (vh_side, vh_side) matrix

    if error_type is not None:
        with pytest.raises(error_type):
            _q, _r = tsqr(data)
        with pytest.raises(error_type):
            _u, _s, _vh = tsqr(data, compute_svd=True)
        return

    # QR mode
    q, r = tsqr(data)
    assert_eq((m, k), q.shape)
    assert_eq((k, n), r.shape)
    assert_eq(mat, da.dot(q, r))  # Q @ R reconstructs the input
    assert_eq(np.eye(k, k), da.dot(q.T, q))  # Q has orthonormal columns
    assert_eq(r, da.triu(r.rechunk(r.shape[0])))  # R is upper triangular

    # SVD mode
    u, s, vh = tsqr(data, compute_svd=True)
    assert_eq(s, np.linalg.svd(mat)[1])  # singular values agree with NumPy
    assert_eq((m, k), u.shape)
    assert_eq((k,), s.shape)
    assert_eq((vh_side, vh_side), vh.shape)
    assert_eq(np.eye(k, k), da.dot(u.T, u))  # U has orthonormal columns
    assert_eq(np.eye(vh_side, vh_side), da.dot(vh, vh.T))  # Vh is orthonormal
    assert_eq(mat, da.dot(da.dot(u, da.diag(s)), vh[:k]))  # U @ diag(S) @ Vh[:k]
106
+
107
+
108
@pytest.mark.parametrize(
    "m_min,n_max,chunks,vary_rows,vary_cols,error_type",
    [
        (10, 5, (10, 5), True, False, None),  # small, tall
        (10, 5, (10, 5), False, True, None),  # small, tall
        (10, 5, (10, 5), True, True, None),  # small, tall
        (40, 5, (10, 5), True, False, None),  # tall, more row blocks
        (40, 5, (10, 5), False, True, None),  # tall, more row blocks
        (40, 5, (10, 5), True, True, None),  # tall, more row blocks
        (300, 10, (40, 10), True, False, None),  # thin layers; recursion_depth=2
        (300, 10, (30, 10), True, False, None),  # thin layers; recursion_depth=3
        (300, 10, (20, 10), True, False, None),  # thin layers; recursion_depth=4
        (300, 10, (40, 10), False, True, None),  # thin layers; recursion_depth=2
        (300, 10, (30, 10), False, True, None),  # thin layers; recursion_depth=3
        (300, 10, (20, 10), False, True, None),  # thin layers; recursion_depth=4
        (300, 10, (40, 10), True, True, None),  # thin layers; recursion_depth=2
        (300, 10, (30, 10), True, True, None),  # thin layers; recursion_depth=3
        (300, 10, (20, 10), True, True, None),  # thin layers; recursion_depth=4
    ],
)
def test_tsqr_uncertain(m_min, n_max, chunks, vary_rows, vary_cols, error_type):
    """Validate ``tsqr`` on inputs whose rows and/or columns were filtered by a
    dask boolean mask, so that chunk sizes are uncertain until computed.

    Results are materialized with ``.compute()`` before the shape and
    numerical checks.
    """
    mat = np.random.default_rng().random((2 * m_min, n_max))
    # Push column 0 of the top half above 1 so that at least m_min rows (and
    # column 0 itself) survive the "> 0.5" masks below.
    mat[:m_min, 0] += 1
    col0 = mat[:, 0]
    row0 = mat[0, :]
    col0_mask_src = da.from_array(col0, chunks=m_min, name="c")
    row0_mask_src = da.from_array(row0, chunks=n_max, name="r")
    data = da.from_array(mat, chunks=chunks, name="A")

    if vary_rows:
        data = data[col0_mask_src > 0.5, :]
        mat = mat[col0 > 0.5, :]
    if vary_cols:
        data = data[:, row0_mask_src > 0.5]
        mat = mat[:, row0 > 0.5]
    m, n = mat.shape

    k = min(m, n)  # inner dimension of the factorization
    vh_side = max(k, n)  # Vh is checked as a square (vh_side, vh_side) matrix

    if error_type is not None:
        with pytest.raises(error_type):
            _q, _r = tsqr(data)
        with pytest.raises(error_type):
            _u, _s, _vh = tsqr(data, compute_svd=True)
        return

    # QR mode (computed first because the chunk sizes are uncertain)
    q, r = tsqr(data)
    q, r = q.compute(), r.compute()
    assert_eq((m, k), q.shape)
    assert_eq((k, n), r.shape)
    assert_eq(mat, np.dot(q, r))  # Q @ R reconstructs the input
    assert_eq(np.eye(k, k), np.dot(q.T, q))  # Q has orthonormal columns
    assert_eq(r, np.triu(r))  # R is upper triangular

    # SVD mode
    u, s, vh = tsqr(data, compute_svd=True)
    u, s, vh = u.compute(), s.compute(), vh.compute()
    assert_eq(s, np.linalg.svd(mat)[1])  # singular values agree with NumPy
    assert_eq((m, k), u.shape)
    assert_eq((k,), s.shape)
    assert_eq((vh_side, vh_side), vh.shape)
    assert_eq(np.eye(k, k), np.dot(u.T, u))  # U has orthonormal columns
    assert_eq(np.eye(vh_side, vh_side), np.dot(vh, vh.T))  # Vh is orthonormal
    assert_eq(mat, np.dot(np.dot(u, np.diag(s)), vh[:k]))  # U @ diag(S) @ Vh[:k]
252
+
253
+
254
def test_tsqr_zero_height_chunks():
    """``qr`` must handle zero-height row chunks, with known and uncertain sizes."""
    rng = np.random.default_rng()
    q_shape, r_shape = (10, 5), (5, 5)
    eye = np.eye(5, 5)

    # Known chunk sizes, two of which are zero.
    mat = rng.random((10, 5))
    x = da.from_array(mat, chunks=((4, 0, 1, 0, 5), (5,)))
    q, r = da.linalg.qr(x)
    assert_eq(q_shape, q.shape)
    assert_eq(r_shape, r.shape)
    assert_eq(mat, da.dot(q, r))  # Q @ R reconstructs the input
    assert_eq(eye, da.dot(q.T, q))  # Q has orthonormal columns
    assert_eq(r, da.triu(r.rechunk(r.shape[0])))  # R is upper triangular

    # Uncertain chunk sizes: stack ten rows of -1 under ``mat`` (four row
    # chunks of 5) and drop them again with a dask boolean index. The last two
    # row chunks hold only -1 rows, so they end up empty.
    padded = np.vstack([mat, -np.ones((10, 5))])
    x_padded = da.from_array(padded, chunks=5)
    first_col = da.from_array(padded[:, 0], chunks=5)
    x = x_padded[first_col >= 0, :]  # recovers ``mat``
    q, r = da.linalg.qr(x)
    q, r = q.compute(), r.compute()  # materialize because sizes are uncertain
    assert_eq(q_shape, q.shape)
    assert_eq(r_shape, r.shape)
    assert_eq(mat, np.dot(q, r))
    assert_eq(eye, np.dot(q.T, q))
    assert_eq(r, np.triu(r))
284
+
285
+
286
@pytest.mark.parametrize(
    "m,n,chunks,error_type",
    [
        (20, 10, 10, ValueError),  # tall-skinny, regular blocks
        (20, 10, (3, 10), ValueError),  # tall-skinny, regular fat layers
        (20, 10, ((8, 4, 8), 10), ValueError),  # tall-skinny, irregular fat layers
        (40, 10, ((15, 5, 5, 8, 7), 10), ValueError),  # tall-skinny, non-uniform
        (128, 2, (16, 2), ValueError),  # tall-skinny, many thin row chunks
        (129, 2, (16, 2), ValueError),  # tall-skinny, many thin row chunks
        (130, 2, (16, 2), ValueError),  # tall-skinny, many thin row chunks
        (131, 2, (16, 2), ValueError),  # tall-skinny, many thin row chunks
        (300, 10, (40, 10), ValueError),  # tall-skinny, many thin row chunks
        (300, 10, (30, 10), ValueError),  # tall-skinny, many thin row chunks
        (300, 10, (20, 10), ValueError),  # tall-skinny, many thin row chunks
        (10, 5, 10, None),  # single block, tall
        (5, 10, 10, None),  # single block, short
        (10, 10, 10, None),  # single block, square
        (10, 40, (10, 10), None),  # short-fat, regular blocks
        (10, 40, (10, 15), None),  # short-fat, irregular blocks
        (10, 40, (10, (15, 5, 5, 8, 7)), None),  # short-fat, non-uniform chunks
        (20, 20, 10, ValueError),  # 2x2 regular blocks
    ],
)
def test_sfqr(m, n, chunks, error_type):
    """Validate ``sfqr`` output shapes, accuracy, orthonormality and triangularity.

    In the parametrization above, every case whose rows are split into more
    than one chunk expects ``error_type``.
    """
    mat = np.random.default_rng().random((m, n))
    data = da.from_array(mat, chunks=chunks, name="A")

    if error_type is not None:
        with pytest.raises(error_type):
            _q, _r = sfqr(data)
        return

    k = min(m, n)  # inner dimension of the factorization
    q, r = sfqr(data)
    assert_eq((m, k), q.shape)
    assert_eq((k, n), r.shape)
    assert_eq(mat, da.dot(q, r))  # Q @ R reconstructs the input
    assert_eq(np.eye(k, k), da.dot(q.T, q))  # Q has orthonormal columns
    assert_eq(r, da.triu(r.rechunk(r.shape[0])))  # R is upper triangular
368
+
369
+
370
@pytest.mark.parametrize(
    "m,n,chunks,error_type",
    [
        (20, 10, 10, None),  # tall-skinny, regular blocks
        (20, 10, (3, 10), None),  # tall-skinny, regular fat layers
        (20, 10, ((8, 4, 8), 10), None),  # tall-skinny, irregular fat layers
        (40, 10, ((15, 5, 5, 8, 7), 10), None),  # tall-skinny, non-uniform chunks
        (128, 2, (16, 2), None),  # thin layers; recursion_depth=1
        (129, 2, (16, 2), None),  # thin layers; recursion_depth=2 --> 17x2
        (130, 2, (16, 2), None),  # thin layers; recursion_depth=2 --> 18x2 next
        (131, 2, (16, 2), None),  # thin layers; recursion_depth=2 --> 18x2 next
        (300, 10, (40, 10), None),  # thin layers; recursion_depth=2
        (300, 10, (30, 10), None),  # thin layers; recursion_depth=3
        (300, 10, (20, 10), None),  # thin layers; recursion_depth=4
        (10, 5, 10, None),  # single block, tall
        (5, 10, 10, None),  # single block, short
        (10, 10, 10, None),  # single block, square
        (10, 40, (10, 10), None),  # short-fat, regular blocks
        (10, 40, (10, 15), None),  # short-fat, irregular blocks
        (10, 40, (10, (15, 5, 5, 8, 7)), None),  # short-fat, non-uniform chunks
        (20, 20, 10, NotImplementedError),  # 2x2 regular blocks
    ],
)
def test_qr(m, n, chunks, error_type):
    """Validate the general ``qr`` entry point.

    Only the case chunked in both dimensions (2x2 blocks) is expected to raise
    ``error_type``; every other case checks shapes, reconstruction,
    orthonormality of Q and upper-triangularity of R.
    """
    mat = np.random.default_rng().random((m, n))
    data = da.from_array(mat, chunks=chunks, name="A")

    if error_type is not None:
        with pytest.raises(error_type):
            _q, _r = qr(data)
        return

    k = min(m, n)  # inner dimension of the factorization
    q, r = qr(data)
    assert_eq((m, k), q.shape)
    assert_eq((k, n), r.shape)
    assert_eq(mat, da.dot(q, r))  # Q @ R reconstructs the input
    assert_eq(np.eye(k, k), da.dot(q.T, q))  # Q has orthonormal columns
    assert_eq(r, da.triu(r.rechunk(r.shape[0])))  # R is upper triangular
427
+
428
+
429
def test_linalg_consistent_names():
    """Calling qr / svd twice on the same input must yield identical graph keys."""
    m, n = 20, 10
    source = np.random.default_rng().random((m, n))
    data = da.from_array(source, chunks=(10, n), name="A")

    q_a, r_a = qr(data)
    q_b, r_b = qr(data)
    for first, second in ((q_a, q_b), (r_a, r_b)):
        assert same_keys(first, second)

    u_a, s_a, v_a = svd(data)
    u_b, s_b, v_b = svd(data)
    for first, second in ((u_a, u_b), (s_a, s_b), (v_a, v_b)):
        assert same_keys(first, second)
444
+
445
+
446
@pytest.mark.parametrize("m,n", [(10, 20), (15, 15), (20, 10)])
def test_dask_svd_self_consistent(m, n):
    """Lazy svd factors must advertise the shape and dtype of their computed values."""
    rng = np.random.default_rng()
    d_a = da.from_array(rng.random((m, n)), chunks=(3, n), name="A")

    d_u, d_s, d_vt = da.linalg.svd(d_a)
    lazy_factors = (d_u, d_s, d_vt)
    computed_factors = da.compute(*lazy_factors)

    for lazy, computed in zip(lazy_factors, computed_factors):
        assert lazy.shape == computed.shape
        assert lazy.dtype == computed.dtype
457
+
458
+
459
def test_svd_compute_uv_false_returns_singular_values():
    """With ``compute_uv=False`` svd returns just the singular values."""
    a = np.random.default_rng().random((20, 10))
    singular_values = da.linalg.svd(da.from_array(a, chunks=(5, 10)), compute_uv=False)
    expected = np.linalg.svd(a, full_matrices=False, compute_uv=False)
    assert_eq(singular_values, expected)
466
+
467
+
468
def test_svd_compute_uv_false_ignores_full_matrices():
    """``full_matrices=True`` is accepted together with ``compute_uv=False``, and
    the singular values still match NumPy's."""
    a = np.random.default_rng().random((20, 10))
    singular_values = da.linalg.svd(
        da.from_array(a, chunks=(5, 10)), full_matrices=True, compute_uv=False
    )
    expected = np.linalg.svd(a, full_matrices=True, compute_uv=False)
    assert_eq(singular_values, expected)
475
+
476
+
477
def test_svd_full_matrices_not_supported():
    """``svd(..., full_matrices=True)`` raises NotImplementedError naming the flag."""
    ones = da.ones((20, 10), chunks=(5, 10))
    with pytest.raises(NotImplementedError, match="full_matrices=True"):
        da.linalg.svd(ones, full_matrices=True)
482
+
483
+
484
@pytest.mark.parametrize("iterator", ["power", "QR"])
def test_svd_compressed_compute(iterator):
    """``compute=True`` must give a smaller graph yet the same factors as the lazy path."""
    x = da.ones((100, 100), chunks=(10, 10))
    options = {"k": 2, "iterator": iterator, "n_power_iter": 1, "seed": 123}

    _, _, v_eager = da.linalg.svd_compressed(x, compute=True, **options)
    _, _, v_lazy = da.linalg.svd_compressed(x, **options)

    assert len(v_eager.dask) < len(v_lazy.dask)
    assert_eq(v_eager, v_lazy)
492
+
493
+
494
@pytest.mark.parametrize("iterator", [("power", 2), ("QR", 2)])
def test_svd_compressed(iterator):
    """Check accuracy and orthonormality of a truncated randomized SVD.

    ``iterator`` is a ``(method, n_power_iter)`` pair forwarded to
    ``svd_compressed``.
    """
    method, n_power_iter = iterator
    m, n = 100, 50
    r = 5
    a = da.random.default_rng().random((m, n), chunks=(m, n))

    # Request 2 * r components but keep only the leading r ("worst case").
    u, s, vt = svd_compressed(a, 2 * r, iterator=method, n_power_iter=n_power_iter, seed=4321)
    s_true = scipy.linalg.svd(a.compute(), compute_uv=False)

    # Spectral-norm error of the rank-r reconstruction.
    a_hat = (u[:, :r] * s[:r]) @ vt[:r, :]
    err = scipy.linalg.norm((a - a_hat).compute(), 2)

    # ||a - a_hat||_2 <= (1 + tol) * s_{k+1}, based on eq. 1.10/1.11 of:
    # Halko, Nathan, Per-Gunnar Martinsson, and Joel A. Tropp.
    # "Finding structure with randomness: Probabilistic algorithms for constructing
    # approximate matrix decompositions." SIAM review 53.2 (2011): 217-288.
    # NOTE(review): the optimal rank-r error is s_true[r] (0-based), while this
    # compares against s_true[r + 1]; the tolerance below was presumably
    # calibrated for this index, so change both together if at all.
    frac = err / s_true[r + 1] - 1
    # Tolerance determined via simulation to be slightly above max norm of
    # difference matrix in 10k samples. See
    # https://github.com/dask/dask/pull/6799#issuecomment-726631175 for details.
    tol = 0.4
    assert frac < tol

    u_r, vt_r = u[:, :r], vt[:r, :]
    assert_eq(np.eye(r, r), da.dot(u_r.T, u_r))  # leading U columns orthonormal
    assert_eq(np.eye(r, r), da.dot(vt_r, vt_r.T))  # leading Vt rows orthonormal
519
+
520
+
521
@pytest.mark.parametrize(
    "input_dtype, output_dtype",
    [(np.float32, np.float32), (np.float64, np.float64)],
)
def test_svd_compressed_dtype_preservation(input_dtype, output_dtype):
    """Every ``svd_compressed`` factor must carry the expected floating dtype."""
    source = da.random.default_rng().random((50, 50), chunks=(50, 50))
    x = source.astype(input_dtype)
    u, s, vt = svd_compressed(x, 1, seed=4321)
    assert all(factor.dtype == output_dtype for factor in (u, s, vt))
526
+
527
+
528
@pytest.mark.parametrize("chunks", [(10, 50), (50, 10), (-1, -1)])
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
def test_svd_dtype_preservation(chunks, dtype):
    """``svd`` keeps the input floating dtype in U, S and V."""
    source = da.random.default_rng().random((50, 50), chunks=chunks)
    u, s, v = svd(source.astype(dtype))
    for factor in (u, s, v):
        assert factor.dtype == dtype
534
+
535
+
536
def test_svd_compressed_deterministic():
    """Equal seeds must reproduce element-wise identical ``svd_compressed`` factors."""
    m, n = 30, 25
    x = da.random.default_rng(1234).random(size=(m, n), chunks=(5, 5))

    u_1, s_1, vt_1 = svd_compressed(x, 3, seed=1234)
    u_2, s_2, vt_2 = svd_compressed(x, 3, seed=1234)

    identical = [(left == right).all() for left, right in ((u_1, u_2), (s_1, s_2), (vt_1, vt_2))]
    assert all(da.compute(*identical))
543
+
544
+
545
@pytest.mark.parametrize("m", [5, 10, 15, 20])
@pytest.mark.parametrize("n", [5, 10, 15, 20])
@pytest.mark.parametrize("k", [5])
@pytest.mark.parametrize("chunks", [(5, 10), (10, 5)])
def test_svd_compressed_shapes(m, n, k, chunks):
    """Factors have shapes (m, rank), (rank,), (rank, n) with rank = min(m, n, k)."""
    x = da.random.default_rng().random(size=(m, n), chunks=chunks)
    u, s, v = svd_compressed(x, k, n_power_iter=1, compute=True, seed=1)
    u, s, v = da.compute(u, s, v)

    rank = min(m, n, k)
    assert u.shape == (m, rank)
    assert s.shape == (rank,)
    assert v.shape == (rank, n)
557
+
558
+
559
def _check_lu_result(p, l, u, A):
    """Assert that ``(p, l, u)`` is a valid LU factorization of ``A``.

    Checks that ``p @ l @ u`` is close to ``A`` (``np.allclose``), that ``l``
    equals its own lower triangle and that ``u`` equals its own upper triangle.
    """
    product = p.dot(l).dot(u)
    assert np.allclose(product, A)

    # Triangular structure of the two factors.
    for factor, keep_triangle in ((l, da.tril), (u, da.triu)):
        assert_eq(factor, keep_triangle(factor), check_graph=False)
565
+
566
+
567
def test_lu_1():
    """Compare ``da.linalg.lu`` against SciPy on small matrices split into 2x2 chunks.

    A1 and A2 need no row exchanges ("without shuffle"), so the dask factors
    must equal SciPy's exactly. A3 needs row pivoting ("with shuffle"), so only
    the reconstruction and the triangular structure of the factors are checked.
    """
    A1 = np.array([[7, 3, -1, 2], [3, 8, 1, -4], [-1, 1, 4, -1], [2, -4, -1, 6]])
    A2 = np.diag([7, 8, 4, 6, 3, 5])

    # Without shuffle: compare P, L and U element-wise with SciPy.
    for A in (A1, A2):
        dA = da.from_array(A, chunks=(2, 2))
        p, l, u = scipy.linalg.lu(A)
        dp, dl, du = da.linalg.lu(dA)
        assert_eq(p, dp, check_graph=False)
        assert_eq(l, dl, check_graph=False)
        assert_eq(u, du, check_graph=False)
        _check_lu_result(dp, dl, du, A)

    A3 = np.array(
        [
            [7, 3, 2, 1, 4, 1],
            [7, 11, 5, 2, 5, 2],
            [21, 25, 16, 10, 16, 5],
            [21, 41, 18, 13, 16, 11],
            [14, 46, 23, 24, 21, 22],
            [0, 56, 29, 17, 14, 8],
        ]
    )

    # With shuffle: the largest entry of A3's first column is not in row 0, so
    # partial pivoting must permute rows. SciPy's factors are deliberately not
    # compared here (NOTE(review): presumably because dask's permutation may
    # legitimately differ from SciPy's) -- only validity is checked.
    dp, dl, du = da.linalg.lu(da.from_array(A3, chunks=(2, 2)))
    _check_lu_result(dp, dl, du, A3)
607
+
608
+
609
@pytest.mark.slow
@pytest.mark.parametrize("size", [10, 20, 30, 50])
# NOTE(review): this filter names dask.array.core.PerformanceWarning while the
# package under test is dask_array -- confirm the category still resolves.
@pytest.mark.filterwarnings("ignore:Increasing:dask.array.core.PerformanceWarning")
def test_lu_2(size):
    """LU of a seeded random integer matrix in 5x5 chunks must be a valid factorization."""
    A = np.random.default_rng(10).integers(0, 10, (size, size))
    dp, dl, du = da.linalg.lu(da.from_array(A, chunks=(5, 5)))
    _check_lu_result(dp, dl, du, A)
619
+
620
+
621
@pytest.mark.slow
@pytest.mark.parametrize("size", [50, 100, 200])
def test_lu_3(size):
    """LU of larger seeded random integer matrices in 25x25 chunks must be valid."""
    A = np.random.default_rng(10).integers(0, 10, (size, size))
    dp, dl, du = da.linalg.lu(da.from_array(A, chunks=(25, 25)))
    _check_lu_result(dp, dl, du, A)
630
+
631
+
632
def test_lu_errors():
    """``lu`` rejects 3-D input, non-square matrices and non-square chunks."""
    rng = np.random.default_rng()
    bad_inputs = [
        ((10, 10, 10), (5, 5, 5)),  # 3-D array
        ((10, 8), (5, 4)),  # non-square matrix
        ((20, 20), (5, 4)),  # square matrix, non-square chunks
    ]
    for shape, chunks in bad_inputs:
        dA = da.from_array(rng.integers(0, 11, shape), chunks=chunks)
        with pytest.raises(ValueError):
            da.linalg.lu(dA)
646
+
647
+
648
@pytest.mark.parametrize(("shape", "chunk"), [(20, 10), (50, 10), (70, 20)])
def test_solve_triangular_vector(shape, chunk):
    """``solve_triangular`` with a 1-D right-hand side, upper then lower system."""
    rng = np.random.default_rng(1)
    A = rng.integers(1, 11, (shape, shape))
    b = rng.integers(1, 11, shape)

    # Upper-triangular system (no ``lower`` flag passed).
    upper = np.triu(A)
    d_upper = da.from_array(upper, (chunk, chunk))
    x_upper = da.linalg.solve_triangular(d_upper, da.from_array(b, chunk))
    assert_eq(x_upper, scipy.linalg.solve_triangular(upper, b))
    assert_eq(d_upper.dot(x_upper), b.astype(float), rtol=1e-4)

    # Lower-triangular system.
    lower = np.tril(A)
    d_lower = da.from_array(lower, (chunk, chunk))
    x_lower = da.linalg.solve_triangular(d_lower, da.from_array(b, chunk), lower=True)
    assert_eq(x_lower, scipy.linalg.solve_triangular(lower, b, lower=True))
    assert_eq(d_lower.dot(x_lower), b.astype(float))
670
+
671
+
672
@pytest.mark.parametrize(("shape", "chunk"), [(20, 10), (50, 10), (50, 20)])
def test_solve_triangular_matrix(shape, chunk):
    """``solve_triangular`` with a (shape, 5) right-hand side, upper then lower."""
    rng = np.random.default_rng(1)
    A = rng.integers(1, 10, (shape, shape))
    b = rng.integers(1, 10, (shape, 5))

    # The upper case passes no ``lower`` flag; the lower case passes lower=True.
    for make_triangular, kwargs in ((np.triu, {}), (np.tril, {"lower": True})):
        tri = make_triangular(A)
        d_tri = da.from_array(tri, (chunk, chunk))
        db = da.from_array(b, (chunk, 5))
        res = da.linalg.solve_triangular(d_tri, db, **kwargs)
        assert_eq(res, scipy.linalg.solve_triangular(tri, b, **kwargs))
        assert_eq(d_tri.dot(res), b.astype(float))
694
+
695
+
696
@pytest.mark.parametrize(("shape", "chunk"), [(20, 10), (50, 10), (50, 20)])
def test_solve_triangular_matrix2(shape, chunk):
    """``solve_triangular`` with a square right-hand side, upper then lower."""
    rng = np.random.default_rng(1)
    A = rng.integers(1, 10, (shape, shape))
    b = rng.integers(1, 10, (shape, shape))

    def check(tri, **kwargs):
        # Solve with dask, compare with SciPy, then verify tri @ x == b.
        d_tri = da.from_array(tri, (chunk, chunk))
        res = da.linalg.solve_triangular(d_tri, da.from_array(b, (chunk, chunk)), **kwargs)
        assert_eq(res, scipy.linalg.solve_triangular(tri, b, **kwargs))
        assert_eq(d_tri.dot(res), b.astype(float))

    check(np.triu(A))
    check(np.tril(A), lower=True)
718
+
719
+
720
def test_solve_triangular_errors():
    """``solve_triangular`` rejects a 3-D matrix, and a matrix whose 3x3 chunks
    do not line up with the right-hand side's chunks of 5."""
    rng = np.random.default_rng()
    bad_matrices = [
        ((10, 10, 10), (5, 5, 5)),  # 3-D "matrix"
        ((10, 10), (3, 3)),  # chunking differs from b's
    ]
    for a_shape, a_chunks in bad_matrices:
        dA = da.from_array(rng.integers(0, 10, a_shape), chunks=a_chunks)
        db = da.from_array(rng.integers(1, 10, 10), chunks=5)
        with pytest.raises(ValueError):
            da.linalg.solve_triangular(dA, db)
732
+
733
+
734
@pytest.mark.parametrize(("shape", "chunk"), [(20, 10), (50, 10)])
def test_solve(shape, chunk):
    """``solve`` matches SciPy for vector, tall-and-skinny and square right-hand sides."""
    rng = np.random.default_rng(1)
    A = rng.integers(1, 10, (shape, shape))
    dA = da.from_array(A, (chunk, chunk))

    # Right-hand sides are drawn in this order from the seeded generator.
    rhs_specs = [
        (shape, chunk),  # vector
        ((shape, 5), (chunk, 5)),  # tall-and-skinny matrix
        ((shape, shape), (chunk, chunk)),  # square matrix
    ]
    for b_shape, b_chunks in rhs_specs:
        b = rng.integers(1, 10, b_shape)
        res = da.linalg.solve(dA, da.from_array(b, b_chunks))
        assert_eq(res, scipy.linalg.solve(A, b), check_graph=False)
        assert_eq(dA.dot(res), b.astype(float), check_graph=False)
764
+
765
+
766
@pytest.mark.parametrize(("shape", "chunk"), [(20, 10), (50, 10)])
def test_inv(shape, chunk):
    """``inv`` matches SciPy, and ``A @ inv(A)`` is the identity."""
    matrix = np.random.default_rng(1).integers(1, 10, (shape, shape))
    dmatrix = da.from_array(matrix, (chunk, chunk))

    inverse = da.linalg.inv(dmatrix)
    assert_eq(inverse, scipy.linalg.inv(matrix), check_graph=False)

    identity = np.eye(shape, dtype=float)
    assert_eq(dmatrix.dot(inverse), identity, check_graph=False)
776
+
777
+
778
def _get_symmat(size):
    """Return a deterministic ``size x size`` integer matrix ``L @ L.T``.

    ``L`` is the lower triangle of a matrix drawn from a generator seeded
    with 1, with entries in ``[1, 20]``. ``L`` therefore has a strictly
    positive diagonal, so the product is symmetric positive definite. This
    makes it a valid input for Cholesky factorization and for
    ``assume_a="pos"`` solves.
    """
    lower = np.tril(np.random.default_rng(1).integers(1, 21, (size, size)))
    return lower @ lower.T
783
+
784
+
785
# `sym_pos` kwarg was deprecated in scipy 1.9.0
# ref: https://github.com/dask/dask/issues/9335
def _scipy_linalg_solve(a, b, assume_a):
    """Call ``scipy.linalg.solve`` without hitting the ``sym_pos`` deprecation.

    On SciPy >= 1.9.0, ``assume_a`` is forwarded unchanged. On older SciPy,
    ``assume_a == "pos"`` becomes ``sym_pos=True``, and any other value falls
    back to a plain general solve.
    """
    if Version(scipy.__version__) < Version("1.9.0"):
        legacy_kwargs = {"sym_pos": True} if assume_a == "pos" else {}
        return scipy.linalg.solve(a=a, b=b, **legacy_kwargs)
    return scipy.linalg.solve(a=a, b=b, assume_a=assume_a)
794
+
795
+
796
@pytest.mark.parametrize(("shape", "chunk"), [(20, 10), (30, 6)])
def test_solve_assume_a(shape, chunk):
    """``solve`` with ``assume_a="pos"`` and with the deprecated ``sym_pos`` flag.

    The left-hand side comes from ``_get_symmat``. Every result is compared
    against SciPy and multiplied back to confirm it reproduces ``b``.
    """
    rng = np.random.default_rng(1)

    A = _get_symmat(shape)
    dA = da.from_array(A, (chunk, chunk))

    def check(res, b, assume_a):
        # Compare with SciPy, then verify that A @ x reproduces b.
        expected = _scipy_linalg_solve(A, b, assume_a=assume_a)
        assert_eq(res, expected, check_graph=False)
        assert_eq(dA.dot(res), b.astype(float), check_graph=False)

    # Vector, tall-and-skinny, then square right-hand side. This order matches
    # the sequence of draws from the seeded RNG.
    for b_shape, b_chunks in [
        (shape, chunk),
        ((shape, 5), (chunk, 5)),
        ((shape, shape), (chunk, chunk)),
    ]:
        b = rng.integers(1, 10, b_shape)
        db = da.from_array(b, b_chunks)
        check(da.linalg.solve(dA, db, assume_a="pos"), b, "pos")

    # `b`/`db` still hold the square right-hand side from the last iteration.
    for sym_pos, assume_a in [(True, "pos"), (False, "gen")]:
        with pytest.warns(FutureWarning, match="sym_pos keyword is deprecated"):
            res = da.linalg.solve(dA, db, sym_pos=sym_pos)
        check(res, b, assume_a)
836
+
837
+
838
@pytest.mark.parametrize(("shape", "chunk"), [(20, 10), (12, 3), (30, 3), (30, 6)])
def test_cholesky(shape, chunk):
    """``cholesky`` matches SciPy for both the upper and the lower factor.

    The input is the positive-definite matrix from ``_get_symmat``.
    """
    A = _get_symmat(shape)
    dA = da.from_array(A, (chunk, chunk))

    for lower in (False, True):
        # Pass the lazy result (not ``.compute()``) for both factors, so that
        # ``assert_eq`` handles the upper and lower cases the same way.
        assert_eq(
            da.linalg.cholesky(dA, lower=lower),
            scipy.linalg.cholesky(A, lower=lower),
            check_graph=False,
            check_chunks=False,
        )
854
+
855
+
856
@pytest.mark.parametrize("iscomplex", [False, True])
@pytest.mark.parametrize(("nrow", "ncol", "chunk"), [(20, 10, 5), (100, 10, 10)])
def test_lstsq(nrow, ncol, chunk, iscomplex):
    """``lstsq`` matches ``np.linalg.lstsq`` for 1-D and 2-D right-hand sides.

    It must also report the reduced rank of a system with duplicated columns.
    """
    rng = np.random.default_rng(1)

    def complexify(*arrays):
        # Add imaginary parts only after all real parts have been drawn, one
        # per array in argument order. This keeps the seeded RNG sequence.
        if not iscomplex:
            return arrays
        return tuple(arr + 1.0j * rng.integers(1, 20, arr.shape) for arr in arrays)

    def compare_with_numpy(A, b, b_chunks):
        # Check all four lstsq outputs (solution, residuals, rank, singular values).
        dA = da.from_array(A, (chunk, ncol))
        db = da.from_array(b, b_chunks)
        x, r, rank, s = np.linalg.lstsq(A, b, rcond=-1)
        dx, dr, drank, ds = da.linalg.lstsq(dA, db)
        assert_eq(dx, x)
        assert_eq(dr, r)
        assert drank.compute() == rank
        assert_eq(ds, s)

    # Full-rank system with a 1-D right-hand side.
    A, b = complexify(rng.integers(1, 20, (nrow, ncol)), rng.integers(1, 20, nrow))
    compare_with_numpy(A, b, chunk)

    # Duplicating a column makes the system rank deficient (multicollinear);
    # only the reported rank is compared.
    A[:, 1] = A[:, 2]
    _, _, rank, _ = np.linalg.lstsq(A, b, rcond=np.finfo(np.double).eps * max(nrow, ncol))
    assert rank == ncol - 1
    _, _, drank, _ = da.linalg.lstsq(da.from_array(A, (chunk, ncol)), da.from_array(b, chunk))
    assert drank.compute() == rank

    # Fresh full-rank system with a 2-D right-hand side.
    A2, b2 = complexify(
        rng.integers(1, 20, (nrow, ncol)),
        rng.integers(1, 20, (nrow, ncol // 2)),
    )
    compare_with_numpy(A2, b2, (chunk, ncol // 2))
901
+
902
+
903
def test_no_chunks_svd():
    """``svd`` works on arrays whose chunk sizes are unknown (NaN).

    Two layouts are covered: unknown row chunks with a known column chunk,
    and unknown row chunks with an unknown column chunk. For each layout
    the factors must reproduce NumPy's singular values and the input, be
    orthonormal, and match NumPy's singular vectors up to sign.
    """
    x = np.random.default_rng().random((100, 10))
    u, s, v = np.linalg.svd(x, full_matrices=False)

    for chunks in [((np.nan,) * 10, (10,)), ((np.nan,) * 10, (np.nan,))]:
        dx = da.from_array(x, chunks=(10, 10))
        # Overwrite the chunk metadata to simulate unknown chunk sizes.
        dx._chunks = chunks

        du, ds, dv = da.linalg.svd(dx)

        assert_eq(s, ds)
        assert_eq(u.dot(np.diag(s)).dot(v), du.dot(da.diag(ds)).dot(dv))
        assert_eq(du.T.dot(du), np.eye(10))
        assert_eq(dv.T.dot(dv), np.eye(10))
        # Singular vectors are only determined up to sign, so compare magnitudes.
        assert_eq(abs(v), abs(dv))
        assert_eq(abs(u), abs(du))
922
+
923
+
924
@pytest.mark.parametrize("shape", [(10, 20), (10, 10), (20, 10)])
@pytest.mark.parametrize("chunks", [(-1, -1), (10, -1), (-1, 10)])
@pytest.mark.parametrize("dtype", ["f4", "f8"])
def test_svd_flip_correction(shape, chunks, dtype):
    """Sign-corrected SVD factors still reconstruct the input.

    ``svd_flip`` is exercised first on lazy dask factors and then on
    computed NumPy factors. In both cases it must preserve the dtypes.
    """
    x = da.random.default_rng().random(size=shape, chunks=chunks).astype(dtype)
    u, s, v = da.linalg.svd(x)

    # Use a looser tolerance for single precision than for double.
    decimal = 6 if np.dtype(dtype).itemsize <= 4 else 9

    def check(flipped_u, flipped_v):
        assert flipped_u.dtype == u.dtype
        assert flipped_v.dtype == v.dtype
        rebuilt = np.asarray(np.dot(flipped_u * s, flipped_v))
        np.testing.assert_almost_equal(rebuilt, x, decimal=decimal)

    # Lazy (dask) inputs.
    check(*svd_flip(u, v))
    # Eager (numpy) inputs.
    check(*svd_flip(*da.compute(u, v)))
947
+
948
+
949
@pytest.mark.parametrize("dtype", ["f2", "f4", "f8", "f16", "c8", "c16", "c32"])
@pytest.mark.parametrize("u_based", [True, False])
def test_svd_flip_sign(dtype, u_based):
    """``svd_flip`` flips only the singular vector whose sign is inconsistent.

    The test is skipped when NumPy cannot build the requested dtype
    (constructing the array raises ``TypeError``).
    """
    rows = [[1, -1, 1, -1], [1, -1, 1, -1], [-1, 1, 1, -1], [-1, 1, 1, -1]]
    try:
        x = np.array(rows, dtype=dtype)
    except TypeError:
        pytest.skip("128-bit floats not supported by NumPy")

    u, v = svd_flip(x, x.T, u_based_decision=u_based)
    assert u.dtype == x.dtype
    assert v.dtype == x.dtype

    # Every singular vector keeps its sign except the last column, which flips.
    expected = x.copy()
    expected[:, -1] *= expected.dtype.type(-1)
    assert_eq(u, expected)
    assert_eq(v, expected.T)
968
+
969
+
970
@pytest.mark.parametrize("chunks", [(10, -1), (-1, 10), (9, -1), (-1, 9)])
@pytest.mark.parametrize("shape", [(10, 100), (100, 10), (10, 10)])
def test_svd_supported_array_shapes(chunks, shape):
    """``svd`` agrees with NumPy for tall-skinny, short-fat and square arrays.

    The chunkings cover three cases:
    - no chunking;
    - chunking that contradicts the shape (e.g. a 10x100 array with 9x100 chunks);
    - chunking that aligns with the shape (e.g. a 10x100 array with 10x9 chunks).
    """
    x = np.random.default_rng().random(shape)

    du, ds, dv = da.linalg.svd(da.from_array(x, chunks=chunks))
    # Normalize signs on both sides before comparing.
    du, dv = svd_flip(*da.compute(du, dv))

    nu, ns, nv = np.linalg.svd(x, full_matrices=False)
    nu, nv = svd_flip(nu, nv)

    for actual, expected in ((du, nu), (ds, ns), (dv, nv)):
        assert_eq(actual, expected)
992
+
993
+
994
def test_svd_incompatible_chunking():
    """``svd`` raises ``NotImplementedError`` for arrays chunked along both axes."""
    # Build the input outside ``pytest.raises``, so only the ``svd`` call
    # itself can produce the expected error.
    x = da.random.default_rng().random((10, 10), chunks=(5, 5))
    with pytest.raises(NotImplementedError, match="Array must be chunked in one dimension only"):
        da.linalg.svd(x)
998
+
999
+
1000
@pytest.mark.parametrize("ndim", [0, 1, 3])
def test_svd_incompatible_dimensions(ndim):
    """``svd`` raises ``ValueError`` for inputs that are not 2-D."""
    # Build the input outside ``pytest.raises``, so only the ``svd`` call
    # itself can produce the expected error.
    x = da.random.default_rng().random((10,) * ndim, chunks=(-1,) * ndim)
    with pytest.raises(ValueError, match="Array must be 2D"):
        da.linalg.svd(x)
1005
+
1006
+
1007
@pytest.mark.xfail(
    sys.platform == "darwin" and _np_version < (1, 22),
    reason="https://github.com/dask/dask/issues/7189",
    strict=False,
)
@pytest.mark.parametrize(
    "shape, chunks, axis",
    [[(5,), (2,), None], [(5,), (2,), 0], [(5,), (2,), (0,)], [(5, 6), (2, 2), None]],
)
@pytest.mark.parametrize("norm", [None, 1, -1, np.inf, -np.inf])
@pytest.mark.parametrize("keepdims", [False, True])
def test_norm_any_ndim(shape, chunks, axis, norm, keepdims):
    """``norm`` matches NumPy on 1-D and 2-D inputs.

    The orders tested (None, 1, -1, inf, -inf) are valid for both vectors
    and matrices.
    """
    norm_kwargs = dict(ord=norm, axis=axis, keepdims=keepdims)

    a = np.random.default_rng().random(shape)
    d = da.from_array(a, chunks=chunks)

    expected = np.linalg.norm(a, **norm_kwargs)
    actual = da.linalg.norm(d, **norm_kwargs)
    assert_eq(expected, actual)
1026
+
1027
+
1028
@pytest.mark.xfail(
    _np_version < (1, 23),
    reason="https://github.com/numpy/numpy/pull/17709",
    strict=False,
)
@pytest.mark.parametrize("precision", ["single", "double"])
@pytest.mark.parametrize("isreal", [True, False])
@pytest.mark.parametrize("keepdims", [False, True])
@pytest.mark.parametrize("norm", [None, 1, -1, np.inf, -np.inf])
def test_norm_any_prec(norm, keepdims, precision, isreal):
    """``norm`` returns the real dtype of the input's precision, as NumPy does.

    Both real and complex inputs of each precision are checked.
    """
    shape, chunks, axis = (5,), (2,), None

    real_dtypes = {"single": "float32", "double": "float64"}
    complex_dtypes = {"single": "complex64", "double": "complex128"}
    dtype = (real_dtypes if isreal else complex_dtypes)[precision]

    a = np.random.default_rng().random(shape).astype(dtype)
    d = da.from_array(a, chunks=chunks)
    a_r = np.linalg.norm(a, ord=norm, axis=axis, keepdims=keepdims)
    d_r = da.linalg.norm(d, ord=norm, axis=axis, keepdims=keepdims)

    assert d_r.dtype == real_dtypes[precision]
    assert d_r.dtype == a_r.dtype
1052
+
1053
+
1054
@pytest.mark.slow
@pytest.mark.xfail(
    sys.platform == "darwin" and _np_version < (1, 22),
    reason="https://github.com/dask/dask/issues/7189",
    strict=False,
)
@pytest.mark.parametrize(
    "shape, chunks",
    [
        [(5,), (2,)],
        [(5, 3), (2, 2)],
        [(4, 5, 3), (2, 2, 2)],
        [(4, 5, 2, 3), (2, 2, 2, 2)],
        [(2, 5, 2, 4, 3), (2, 2, 2, 2, 2)],
    ],
)
@pytest.mark.parametrize("norm", [None, 1, -1, np.inf, -np.inf])
@pytest.mark.parametrize("keepdims", [False, True])
def test_norm_any_slice(shape, chunks, norm, keepdims):
    """``norm`` matches NumPy over every single axis and every ordered pair of axes."""
    a = np.random.default_rng().random(shape)
    d = da.from_array(a, chunks=chunks)

    ndim = len(shape)
    # A "pair" made of the same axis twice reduces to that single axis.
    axes = [i if i == j else (i, j) for i in range(ndim) for j in range(ndim)]
    for axis in axes:
        a_r = np.linalg.norm(a, ord=norm, axis=axis, keepdims=keepdims)
        d_r = da.linalg.norm(d, ord=norm, axis=axis, keepdims=keepdims)
        assert_eq(a_r, d_r)
1085
+
1086
+
1087
@pytest.mark.parametrize("shape, chunks, axis", [[(5,), (2,), None], [(5,), (2,), 0], [(5,), (2,), (0,)]])
@pytest.mark.parametrize("norm", [0, 2, -2, 0.5])
@pytest.mark.parametrize("keepdims", [False, True])
def test_norm_1dim(shape, chunks, axis, norm, keepdims):
    """``norm`` matches NumPy on 1-D inputs for orders 0, 2, -2 and 0.5."""
    data = np.random.default_rng().random(shape)
    ddata = da.from_array(data, chunks=chunks)

    expected = np.linalg.norm(data, ord=norm, axis=axis, keepdims=keepdims)
    actual = da.linalg.norm(ddata, ord=norm, axis=axis, keepdims=keepdims)
    assert_eq(expected, actual)
1097
+
1098
+
1099
@pytest.mark.parametrize(
    "shape, chunks, axis",
    [[(5, 6), (2, 2), None], [(5, 6), (2, 2), (0, 1)], [(5, 6), (2, 2), (1, 0)]],
)
@pytest.mark.parametrize("norm", ["fro", "nuc", 2, -2])
@pytest.mark.parametrize("keepdims", [False, True])
def test_norm_2dim(shape, chunks, axis, norm, keepdims):
    """``norm`` matches NumPy for matrix orders (fro, nuc, 2, -2) on 2-D inputs."""
    a = np.random.default_rng().random(shape)
    d = da.from_array(a, chunks=chunks)

    # The svd-based orders need the last dimension in a single chunk.
    if norm in ("nuc", 2, -2):
        d = d.rechunk({-1: -1})

    expected = np.linalg.norm(a, ord=norm, axis=axis, keepdims=keepdims)
    actual = da.linalg.norm(d, ord=norm, axis=axis, keepdims=keepdims)
    assert_eq(expected, actual)
1117
+
1118
+
1119
+ @pytest.mark.parametrize(
1120
+ "shape, chunks, axis",
1121
+ [[(3, 2, 4), (2, 2, 2), (1, 2)], [(2, 3, 4, 5), (2, 2, 2, 2), (-1, -2)]],
1122
+ )
1123
+ @pytest.mark.parametrize("norm", ["nuc", 2, -2])
1124
+ @pytest.mark.parametrize("keepdims", [False, True])
1125
+ def test_norm_implemented_errors(shape, chunks, axis, norm, keepdims):
1126
+ a = np.random.default_rng().random(shape)
1127
+ d = da.from_array(a, chunks=chunks)
1128
+ if len(shape) > 2 and len(axis) == 2:
1129
+ with pytest.raises(NotImplementedError):
1130
+ da.linalg.norm(d, ord=norm, axis=axis, keepdims=keepdims)