dask-array 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dask_array/__init__.py +228 -0
- dask_array/_backends.py +76 -0
- dask_array/_backends_array.py +99 -0
- dask_array/_blockwise.py +1410 -0
- dask_array/_broadcast.py +272 -0
- dask_array/_chunk.py +445 -0
- dask_array/_chunk_types.py +54 -0
- dask_array/_collection.py +1644 -0
- dask_array/_concatenate.py +331 -0
- dask_array/_core_utils.py +1365 -0
- dask_array/_dispatch.py +141 -0
- dask_array/_einsum.py +277 -0
- dask_array/_expr.py +544 -0
- dask_array/_expr_flow.py +586 -0
- dask_array/_gufunc.py +805 -0
- dask_array/_histogram.py +617 -0
- dask_array/_map_blocks.py +652 -0
- dask_array/_new_collection.py +10 -0
- dask_array/_numpy_compat.py +135 -0
- dask_array/_overlap.py +1159 -0
- dask_array/_rechunk.py +1050 -0
- dask_array/_reshape.py +710 -0
- dask_array/_routines.py +102 -0
- dask_array/_shuffle.py +448 -0
- dask_array/_stack.py +264 -0
- dask_array/_svg.py +291 -0
- dask_array/_templates.py +29 -0
- dask_array/_test_utils.py +257 -0
- dask_array/_ufunc.py +385 -0
- dask_array/_utils.py +349 -0
- dask_array/_visualize.py +223 -0
- dask_array/_xarray.py +337 -0
- dask_array/core/__init__.py +34 -0
- dask_array/core/_blockwise_funcs.py +312 -0
- dask_array/core/_conversion.py +422 -0
- dask_array/core/_from_graph.py +97 -0
- dask_array/creation/__init__.py +71 -0
- dask_array/creation/_arange.py +121 -0
- dask_array/creation/_diag.py +116 -0
- dask_array/creation/_diagonal.py +241 -0
- dask_array/creation/_eye.py +103 -0
- dask_array/creation/_linspace.py +102 -0
- dask_array/creation/_mesh.py +134 -0
- dask_array/creation/_ones_zeros.py +454 -0
- dask_array/creation/_pad.py +270 -0
- dask_array/creation/_repeat.py +55 -0
- dask_array/creation/_tile.py +36 -0
- dask_array/creation/_tri.py +28 -0
- dask_array/creation/_utils.py +296 -0
- dask_array/fft.py +320 -0
- dask_array/io/__init__.py +39 -0
- dask_array/io/_base.py +10 -0
- dask_array/io/_from_array.py +257 -0
- dask_array/io/_from_delayed.py +95 -0
- dask_array/io/_from_graph.py +54 -0
- dask_array/io/_from_npy_stack.py +67 -0
- dask_array/io/_store.py +336 -0
- dask_array/io/_tiledb.py +159 -0
- dask_array/io/_to_npy_stack.py +65 -0
- dask_array/io/_zarr.py +449 -0
- dask_array/linalg/__init__.py +39 -0
- dask_array/linalg/_cholesky.py +234 -0
- dask_array/linalg/_lu.py +300 -0
- dask_array/linalg/_norm.py +94 -0
- dask_array/linalg/_qr.py +601 -0
- dask_array/linalg/_solve.py +349 -0
- dask_array/linalg/_svd.py +394 -0
- dask_array/linalg/_tensordot.py +334 -0
- dask_array/linalg/_utils.py +74 -0
- dask_array/manipulation/__init__.py +45 -0
- dask_array/manipulation/_expand.py +321 -0
- dask_array/manipulation/_flip.py +92 -0
- dask_array/manipulation/_roll.py +78 -0
- dask_array/manipulation/_transpose.py +309 -0
- dask_array/random/__init__.py +125 -0
- dask_array/random/_choice.py +181 -0
- dask_array/random/_expr.py +256 -0
- dask_array/random/_generator.py +441 -0
- dask_array/random/_random_state.py +259 -0
- dask_array/random/_utils.py +84 -0
- dask_array/reductions/__init__.py +84 -0
- dask_array/reductions/_arg_reduction.py +130 -0
- dask_array/reductions/_common.py +1082 -0
- dask_array/reductions/_cumulative.py +522 -0
- dask_array/reductions/_percentile.py +261 -0
- dask_array/reductions/_reduction.py +725 -0
- dask_array/reductions/_trace.py +56 -0
- dask_array/routines/__init__.py +133 -0
- dask_array/routines/_apply.py +84 -0
- dask_array/routines/_bincount.py +112 -0
- dask_array/routines/_broadcast.py +111 -0
- dask_array/routines/_coarsen.py +115 -0
- dask_array/routines/_diff.py +79 -0
- dask_array/routines/_gradient.py +158 -0
- dask_array/routines/_indexing.py +65 -0
- dask_array/routines/_insert_delete.py +132 -0
- dask_array/routines/_misc.py +122 -0
- dask_array/routines/_nonzero.py +72 -0
- dask_array/routines/_search.py +123 -0
- dask_array/routines/_select.py +113 -0
- dask_array/routines/_statistics.py +171 -0
- dask_array/routines/_topk.py +82 -0
- dask_array/routines/_triangular.py +74 -0
- dask_array/routines/_unique.py +232 -0
- dask_array/routines/_where.py +62 -0
- dask_array/slicing/__init__.py +67 -0
- dask_array/slicing/_basic.py +550 -0
- dask_array/slicing/_blocks.py +138 -0
- dask_array/slicing/_bool_index.py +145 -0
- dask_array/slicing/_setitem.py +329 -0
- dask_array/slicing/_squeeze.py +101 -0
- dask_array/slicing/_utils.py +1133 -0
- dask_array/slicing/_vindex.py +282 -0
- dask_array/stacking/__init__.py +15 -0
- dask_array/stacking/_block.py +83 -0
- dask_array/stacking/_simple.py +58 -0
- dask_array/templates/array.html.j2 +48 -0
- dask_array/tests/__init__.py +0 -0
- dask_array/tests/conftest.py +22 -0
- dask_array/tests/test_api.py +40 -0
- dask_array/tests/test_binary_op_chunks.py +107 -0
- dask_array/tests/test_coarse_slice_through_blockwise.py +362 -0
- dask_array/tests/test_collection.py +799 -0
- dask_array/tests/test_creation.py +1102 -0
- dask_array/tests/test_expr_flow.py +143 -0
- dask_array/tests/test_linalg.py +1130 -0
- dask_array/tests/test_map_blocks_multi_output.py +104 -0
- dask_array/tests/test_rechunk_pushdown.py +214 -0
- dask_array/tests/test_reductions.py +1091 -0
- dask_array/tests/test_routines.py +2853 -0
- dask_array/tests/test_shuffle_chunks.py +67 -0
- dask_array/tests/test_slice_pushdown.py +968 -0
- dask_array/tests/test_slice_through_blockwise.py +678 -0
- dask_array/tests/test_slice_through_overlap.py +366 -0
- dask_array/tests/test_slice_through_reshape.py +272 -0
- dask_array/tests/test_slicing.py +839 -0
- dask_array/tests/test_transpose_slice_pushdown.py +208 -0
- dask_array/tests/test_visualize.py +94 -0
- dask_array/tests/test_xarray.py +193 -0
- dask_array-0.1.0.dist-info/METADATA +48 -0
- dask_array-0.1.0.dist-info/RECORD +144 -0
- dask_array-0.1.0.dist-info/WHEEL +4 -0
- dask_array-0.1.0.dist-info/entry_points.txt +2 -0
- dask_array-0.1.0.dist-info/licenses/LICENSE +29 -0
|
@@ -0,0 +1,1130 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import sys
|
|
4
|
+
|
|
5
|
+
import pytest
|
|
6
|
+
|
|
7
|
+
pytest.importorskip("numpy")
|
|
8
|
+
pytest.importorskip("scipy")
|
|
9
|
+
|
|
10
|
+
import numpy as np
|
|
11
|
+
import scipy.linalg
|
|
12
|
+
from packaging.version import Version
|
|
13
|
+
|
|
14
|
+
import dask_array as da
|
|
15
|
+
from dask_array.linalg import qr, sfqr, svd, svd_compressed, tsqr
|
|
16
|
+
from dask_array._numpy_compat import _np_version
|
|
17
|
+
from dask_array._test_utils import assert_eq, same_keys
|
|
18
|
+
from dask_array._utils import svd_flip
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@pytest.mark.parametrize(
    "m,n,chunks,error_type",
    [
        (20, 10, 10, None),  # tall-skinny regular blocks
        (20, 10, (3, 10), None),  # tall-skinny regular fat layers
        (20, 10, ((8, 4, 8), 10), None),  # tall-skinny irregular fat layers
        (40, 10, ((15, 5, 5, 8, 7), 10), None),  # tall-skinny non-uniform chunks (why?)
        (128, 2, (16, 2), None),  # tall-skinny regular thin layers; recursion_depth=1
        (129, 2, (16, 2), None),  # tall-skinny regular thin layers; recursion_depth=2 --> 17x2
        (130, 2, (16, 2), None),  # tall-skinny regular thin layers; recursion_depth=2 --> 18x2 next
        (131, 2, (16, 2), None),  # tall-skinny regular thin layers; recursion_depth=2 --> 18x2 next
        (300, 10, (40, 10), None),  # tall-skinny regular thin layers; recursion_depth=2
        (300, 10, (30, 10), None),  # tall-skinny regular thin layers; recursion_depth=3
        (300, 10, (20, 10), None),  # tall-skinny regular thin layers; recursion_depth=4
        (10, 5, 10, None),  # single block tall
        (5, 10, 10, None),  # single block short
        (10, 10, 10, None),  # single block square
        (10, 40, (10, 10), ValueError),  # short-fat regular blocks
        (10, 40, (10, 15), ValueError),  # short-fat irregular blocks
        (10, 40, (10, (15, 5, 5, 8, 7)), ValueError),  # short-fat non-uniform chunks (why?)
        (20, 20, 10, ValueError),  # 2x2 regular blocks
    ],
)
def test_tsqr(m, n, chunks, error_type):
    """Check tsqr's QR and SVD results for shape, accuracy and orthonormality.

    When ``error_type`` is not None, both ``tsqr(data)`` and
    ``tsqr(data, compute_svd=True)`` must raise that exception instead.
    """
    mat = np.random.default_rng().random((m, n))
    data = da.from_array(mat, chunks=chunks, name="A")

    # k: columns of q and u, rows of r, number of singular values
    k = min(m, n)
    # vh is checked as a square (d_vh, d_vh) array -- full matrix returned
    d_vh = max(k, n)

    if error_type is not None:
        with pytest.raises(error_type):
            q, r = tsqr(data)
        with pytest.raises(error_type):
            u, s, vh = tsqr(data, compute_svd=True)
        return

    # --- QR ---
    q, r = tsqr(data)
    assert_eq((m, k), q.shape)  # shape check
    assert_eq((k, n), r.shape)  # shape check
    assert_eq(mat, da.dot(q, r))  # accuracy check
    assert_eq(np.eye(k, k), da.dot(q.T, q))  # q must be orthonormal
    assert_eq(r, da.triu(r.rechunk(r.shape[0])))  # r must be upper triangular

    # --- SVD ---
    u, s, vh = tsqr(data, compute_svd=True)
    s_exact = np.linalg.svd(mat)[1]
    assert_eq(s, s_exact)  # s must contain the singular values
    assert_eq((m, k), u.shape)  # shape check
    assert_eq((k,), s.shape)  # shape check
    assert_eq((d_vh, d_vh), vh.shape)  # shape check
    assert_eq(np.eye(k, k), da.dot(u.T, u))  # u must be orthonormal
    assert_eq(np.eye(d_vh, d_vh), da.dot(vh, vh.T))  # vh must be orthonormal
    assert_eq(mat, da.dot(da.dot(u, da.diag(s)), vh[:k]))  # accuracy check
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
@pytest.mark.parametrize(
    "m_min,n_max,chunks,vary_rows,vary_cols,error_type",
    [
        (10, 5, (10, 5), True, False, None),  # single block tall
        (10, 5, (10, 5), False, True, None),  # single block tall
        (10, 5, (10, 5), True, True, None),  # single block tall
        (40, 5, (10, 5), True, False, None),  # multiple blocks tall
        (40, 5, (10, 5), False, True, None),  # multiple blocks tall
        (40, 5, (10, 5), True, True, None),  # multiple blocks tall
        (300, 10, (40, 10), True, False, None),  # tall-skinny regular thin layers; recursion_depth=2
        (300, 10, (30, 10), True, False, None),  # tall-skinny regular thin layers; recursion_depth=3
        (300, 10, (20, 10), True, False, None),  # tall-skinny regular thin layers; recursion_depth=4
        (300, 10, (40, 10), False, True, None),  # tall-skinny regular thin layers; recursion_depth=2
        (300, 10, (30, 10), False, True, None),  # tall-skinny regular thin layers; recursion_depth=3
        (300, 10, (20, 10), False, True, None),  # tall-skinny regular thin layers; recursion_depth=4
        (300, 10, (40, 10), True, True, None),  # tall-skinny regular thin layers; recursion_depth=2
        (300, 10, (30, 10), True, True, None),  # tall-skinny regular thin layers; recursion_depth=3
        (300, 10, (20, 10), True, True, None),  # tall-skinny regular thin layers; recursion_depth=4
    ],
)
def test_tsqr_uncertain(m_min, n_max, chunks, vary_rows, vary_cols, error_type):
    """Run tsqr on inputs whose rows and/or columns were filtered by dask boolean masks.

    Masking with dask arrays leaves the chunk sizes uncertain, so the factors
    are computed eagerly and checked with NumPy against the equally-filtered
    NumPy matrix.
    """
    mat = np.random.default_rng().random((m_min * 2, n_max))
    # Lift the first m_min entries of column 0 to >= 1 so the "> 0.5" masks
    # below always keep at least those m_min rows and column 0.
    mat[0:m_min, 0] += 1
    row_keys = mat[:, 0]
    col_keys = mat[0, :]
    c0 = da.from_array(row_keys, chunks=m_min, name="c")
    r0 = da.from_array(col_keys, chunks=n_max, name="r")
    data = da.from_array(mat, chunks=chunks, name="A")

    if vary_rows:
        data = data[c0 > 0.5, :]
        mat = mat[row_keys > 0.5, :]
    if vary_cols:
        data = data[:, r0 > 0.5]
        mat = mat[:, col_keys > 0.5]
    m, n = mat.shape

    k = min(m, n)  # columns of q and u, rows of r, number of singular values
    d_vh = max(k, n)  # vh is checked as a square array -- full matrix returned

    if error_type is not None:
        with pytest.raises(error_type):
            q, r = tsqr(data)
        with pytest.raises(error_type):
            u, s, vh = tsqr(data, compute_svd=True)
        return

    # --- QR (computed eagerly because of the uncertain chunks) ---
    q, r = tsqr(data)
    q, r = q.compute(), r.compute()
    assert_eq((m, k), q.shape)  # shape check
    assert_eq((k, n), r.shape)  # shape check
    assert_eq(mat, np.dot(q, r))  # accuracy check
    assert_eq(np.eye(k, k), np.dot(q.T, q))  # q must be orthonormal
    assert_eq(r, np.triu(r))  # r must be upper triangular

    # --- SVD (computed eagerly because of the uncertain chunks) ---
    u, s, vh = tsqr(data, compute_svd=True)
    u, s, vh = u.compute(), s.compute(), vh.compute()
    s_exact = np.linalg.svd(mat)[1]
    assert_eq(s, s_exact)  # s must contain the singular values
    assert_eq((m, k), u.shape)  # shape check
    assert_eq((k,), s.shape)  # shape check
    assert_eq((d_vh, d_vh), vh.shape)  # shape check
    assert_eq(np.eye(k, k), np.dot(u.T, u))  # u must be orthonormal
    assert_eq(np.eye(d_vh, d_vh), np.dot(vh, vh.T))  # vh must be orthonormal
    assert_eq(mat, np.dot(np.dot(u, np.diag(s)), vh[:k]))  # accuracy check
|
|
252
|
+
|
|
253
|
+
|
|
254
|
+
def test_tsqr_zero_height_chunks():
    """qr must handle zero-height row chunks, with both known and uncertain chunk sizes."""
    q_shape = (10, 5)
    r_shape = (5, 5)

    # certainty: explicit row chunks, two of which have height 0
    mat = np.random.default_rng().random((10, 5))
    x = da.from_array(mat, chunks=((4, 0, 1, 0, 5), (5,)))
    q, r = da.linalg.qr(x)
    assert_eq(q_shape, q.shape)  # shape check
    assert_eq(r_shape, r.shape)  # shape check
    assert_eq(mat, da.dot(q, r))  # accuracy check
    assert_eq(np.eye(5, 5), da.dot(q.T, q))  # q must be orthonormal
    assert_eq(r, da.triu(r.rechunk(r.shape[0])))  # r must be upper triangular

    # uncertainty: append 10 rows of -1 (two chunks of 5 rows) and drop them with
    # a lazy boolean mask; the last two row chunks end up empty after filtering
    padded = np.vstack([mat, -(np.ones((10, 5)))])
    x2 = da.from_array(padded, chunks=5)
    c = da.from_array(padded[:, 0], chunks=5)
    x = x2[c >= 0, :]  # remove the ones added above to yield mat
    q, r = da.linalg.qr(x)
    q, r = q.compute(), r.compute()  # because uncertainty
    assert_eq(q_shape, q.shape)  # shape check
    assert_eq(r_shape, r.shape)  # shape check
    assert_eq(mat, np.dot(q, r))  # accuracy check
    assert_eq(np.eye(5, 5), np.dot(q.T, q))  # q must be orthonormal
    assert_eq(r, np.triu(r))  # r must be upper triangular
|
|
284
|
+
|
|
285
|
+
|
|
286
|
+
@pytest.mark.parametrize(
    "m,n,chunks,error_type",
    [
        (20, 10, 10, ValueError),  # tall-skinny regular blocks
        (20, 10, (3, 10), ValueError),  # tall-skinny regular fat layers
        (20, 10, ((8, 4, 8), 10), ValueError),  # tall-skinny irregular fat layers
        (40, 10, ((15, 5, 5, 8, 7), 10), ValueError),  # tall-skinny non-uniform chunks (why?)
        (128, 2, (16, 2), ValueError),  # tall-skinny regular thin layers; recursion_depth=1
        (129, 2, (16, 2), ValueError),  # tall-skinny regular thin layers; recursion_depth=2 --> 17x2
        (130, 2, (16, 2), ValueError),  # tall-skinny regular thin layers; recursion_depth=2 --> 18x2 next
        (131, 2, (16, 2), ValueError),  # tall-skinny regular thin layers; recursion_depth=2 --> 18x2 next
        (300, 10, (40, 10), ValueError),  # tall-skinny regular thin layers; recursion_depth=2
        (300, 10, (30, 10), ValueError),  # tall-skinny regular thin layers; recursion_depth=3
        (300, 10, (20, 10), ValueError),  # tall-skinny regular thin layers; recursion_depth=4
        (10, 5, 10, None),  # single block tall
        (5, 10, 10, None),  # single block short
        (10, 10, 10, None),  # single block square
        (10, 40, (10, 10), None),  # short-fat regular blocks
        (10, 40, (10, 15), None),  # short-fat irregular blocks
        (10, 40, (10, (15, 5, 5, 8, 7)), None),  # short-fat non-uniform chunks (why?)
        (20, 20, 10, ValueError),  # 2x2 regular blocks
    ],
)
def test_sfqr(m, n, chunks, error_type):
    """Check sfqr's factors, or that it raises ``error_type`` for unsupported layouts."""
    mat = np.random.default_rng().random((m, n))
    data = da.from_array(mat, chunks=chunks, name="A")
    k = min(m, n)  # columns of q, rows of r

    if error_type is not None:
        with pytest.raises(error_type):
            q, r = sfqr(data)
        return

    q, r = sfqr(data)
    assert_eq((m, k), q.shape)  # shape check
    assert_eq((k, n), r.shape)  # shape check
    assert_eq(mat, da.dot(q, r))  # accuracy check
    assert_eq(np.eye(k, k), da.dot(q.T, q))  # q must be orthonormal
    assert_eq(r, da.triu(r.rechunk(r.shape[0])))  # r must be upper triangular
|
|
368
|
+
|
|
369
|
+
|
|
370
|
+
@pytest.mark.parametrize(
    "m,n,chunks,error_type",
    [
        (20, 10, 10, None),  # tall-skinny regular blocks
        (20, 10, (3, 10), None),  # tall-skinny regular fat layers
        (20, 10, ((8, 4, 8), 10), None),  # tall-skinny irregular fat layers
        (40, 10, ((15, 5, 5, 8, 7), 10), None),  # tall-skinny non-uniform chunks (why?)
        (128, 2, (16, 2), None),  # tall-skinny regular thin layers; recursion_depth=1
        (129, 2, (16, 2), None),  # tall-skinny regular thin layers; recursion_depth=2 --> 17x2
        (130, 2, (16, 2), None),  # tall-skinny regular thin layers; recursion_depth=2 --> 18x2 next
        (131, 2, (16, 2), None),  # tall-skinny regular thin layers; recursion_depth=2 --> 18x2 next
        (300, 10, (40, 10), None),  # tall-skinny regular thin layers; recursion_depth=2
        (300, 10, (30, 10), None),  # tall-skinny regular thin layers; recursion_depth=3
        (300, 10, (20, 10), None),  # tall-skinny regular thin layers; recursion_depth=4
        (10, 5, 10, None),  # single block tall
        (5, 10, 10, None),  # single block short
        (10, 10, 10, None),  # single block square
        (10, 40, (10, 10), None),  # short-fat regular blocks
        (10, 40, (10, 15), None),  # short-fat irregular blocks
        (10, 40, (10, (15, 5, 5, 8, 7)), None),  # short-fat non-uniform chunks (why?)
        (20, 20, 10, NotImplementedError),  # 2x2 regular blocks
    ],
)
def test_qr(m, n, chunks, error_type):
    """Check qr's factors across layouts, or that it raises ``error_type``."""
    mat = np.random.default_rng().random((m, n))
    data = da.from_array(mat, chunks=chunks, name="A")
    k = min(m, n)  # columns of q, rows of r

    if error_type is not None:
        with pytest.raises(error_type):
            q, r = qr(data)
        return

    q, r = qr(data)
    assert_eq((m, k), q.shape)  # shape check
    assert_eq((k, n), r.shape)  # shape check
    assert_eq(mat, da.dot(q, r))  # accuracy check
    assert_eq(np.eye(k, k), da.dot(q.T, q))  # q must be orthonormal
    assert_eq(r, da.triu(r.rechunk(r.shape[0])))  # r must be upper triangular
|
|
427
|
+
|
|
428
|
+
|
|
429
|
+
def test_linalg_consistent_names():
    """Repeated qr/svd calls on the same input must yield identical graph keys."""
    rows, cols = 20, 10
    source = np.random.default_rng().random((rows, cols))
    data = da.from_array(source, chunks=(10, cols), name="A")

    q1, r1 = qr(data)
    q2, r2 = qr(data)
    for first, second in ((q1, q2), (r1, r2)):
        assert same_keys(first, second)

    u1, s1, v1 = svd(data)
    u2, s2, v2 = svd(data)
    for first, second in ((u1, u2), (s1, s2), (v1, v2)):
        assert same_keys(first, second)
|
|
444
|
+
|
|
445
|
+
|
|
446
|
+
@pytest.mark.parametrize("m,n", [(10, 20), (15, 15), (20, 10)])
def test_dask_svd_self_consistent(m, n):
    """The lazy svd factors must report the shape and dtype of their computed values."""
    a = np.random.default_rng().random((m, n))
    d_a = da.from_array(a, chunks=(3, n), name="A")

    d_u, d_s, d_vt = da.linalg.svd(d_a)
    lazy = (d_u, d_s, d_vt)
    u, s, vt = da.compute(*lazy)

    for lazy_arr, concrete in zip(lazy, (u, s, vt)):
        assert lazy_arr.shape == concrete.shape
        assert lazy_arr.dtype == concrete.dtype
|
|
457
|
+
|
|
458
|
+
|
|
459
|
+
def test_svd_compute_uv_false_returns_singular_values():
    """With compute_uv=False, svd returns only the singular values, matching NumPy."""
    a = np.random.default_rng().random((20, 10))
    singular_values = da.linalg.svd(da.from_array(a, chunks=(5, 10)), compute_uv=False)
    expected = np.linalg.svd(a, full_matrices=False, compute_uv=False)
    assert_eq(singular_values, expected)
|
|
466
|
+
|
|
467
|
+
|
|
468
|
+
def test_svd_compute_uv_false_ignores_full_matrices():
    """full_matrices=True is accepted with compute_uv=False, and the values match NumPy."""
    a = np.random.default_rng().random((20, 10))
    singular_values = da.linalg.svd(da.from_array(a, chunks=(5, 10)), full_matrices=True, compute_uv=False)
    expected = np.linalg.svd(a, full_matrices=True, compute_uv=False)
    assert_eq(singular_values, expected)
|
|
475
|
+
|
|
476
|
+
|
|
477
|
+
def test_svd_full_matrices_not_supported():
    """svd(..., full_matrices=True) raises NotImplementedError naming the option."""
    ones = da.ones((20, 10), chunks=(5, 10))
    with pytest.raises(NotImplementedError, match="full_matrices=True"):
        da.linalg.svd(ones, full_matrices=True)
|
|
482
|
+
|
|
483
|
+
|
|
484
|
+
@pytest.mark.parametrize("iterator", ["power", "QR"])
def test_svd_compressed_compute(iterator):
    """compute=True must give a smaller graph for v with the same values as the lazy path."""
    x = da.ones((100, 100), chunks=(10, 10))
    shared = dict(k=2, iterator=iterator, n_power_iter=1, seed=123)

    _, _, v_eager = da.linalg.svd_compressed(x, compute=True, **shared)
    _, _, v_lazy = da.linalg.svd_compressed(x, **shared)

    assert len(v_eager.dask) < len(v_lazy.dask)
    assert_eq(v_eager, v_lazy)
|
|
492
|
+
|
|
493
|
+
|
|
494
|
+
@pytest.mark.parametrize("iterator", [("power", 2), ("QR", 2)])
def test_svd_compressed(iterator):
    """The rank-r truncation of svd_compressed must be near-optimal with orthonormal factors."""
    method, power_iters = iterator
    m, n = 100, 50
    r = 5
    a = da.random.default_rng().random((m, n), chunks=(m, n))

    # calculate approximation (2 * r components, worst case) and true singular values
    u, s, vt = svd_compressed(a, 2 * r, iterator=method, n_power_iter=power_iters, seed=4321)
    s_true = scipy.linalg.svd(a.compute(), compute_uv=False)

    # spectral norm of the difference between a and its rank-r reconstruction
    u_r, s_r, vt_r = u[:, :r], s[:r], vt[:r, :]
    residual = a - (u_r * s_r) @ vt_r
    norm = scipy.linalg.norm(residual.compute(), 2)

    # ||a-a_hat||_2 <= (1+tol)s_{k+1}: based on eq. 1.10/1.11:
    # Halko, Nathan, Per-Gunnar Martinsson, and Joel A. Tropp.
    # "Finding structure with randomness: Probabilistic algorithms for constructing
    # approximate matrix decompositions." SIAM review 53.2 (2011): 217-288.
    # NOTE(review): with 1-based s_{k+1} and k = r, the 0-based index would be
    # s_true[r]; this test divides by s_true[r + 1], and the tolerance below was
    # presumably calibrated against this exact expression -- confirm before changing.
    frac = norm / s_true[r + 1] - 1
    # Tolerance determined via simulation to be slightly above max norm of difference matrix in 10k samples.
    # See https://github.com/dask/dask/pull/6799#issuecomment-726631175 for more details.
    tol = 0.4
    assert frac < tol

    assert_eq(np.eye(r, r), da.dot(u_r.T, u_r))  # u must be orthonormal
    assert_eq(np.eye(r, r), da.dot(vt_r, vt_r.T))  # v must be orthonormal
|
|
519
|
+
|
|
520
|
+
|
|
521
|
+
@pytest.mark.parametrize("input_dtype, output_dtype", [(np.float32, np.float32), (np.float64, np.float64)])
def test_svd_compressed_dtype_preservation(input_dtype, output_dtype):
    """Every svd_compressed factor carries the expected output dtype."""
    x = da.random.default_rng().random((50, 50), chunks=(50, 50)).astype(input_dtype)
    u, s, vt = svd_compressed(x, 1, seed=4321)
    assert all(factor.dtype == output_dtype for factor in (u, s, vt))
|
|
526
|
+
|
|
527
|
+
|
|
528
|
+
@pytest.mark.parametrize("chunks", [(10, 50), (50, 10), (-1, -1)])
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
def test_svd_dtype_preservation(chunks, dtype):
    """svd keeps the input floating-point dtype for every factor and chunking."""
    x = da.random.default_rng().random((50, 50), chunks=chunks).astype(dtype)
    u, s, v = svd(x)
    assert all(factor.dtype == dtype for factor in (u, s, v))
|
|
534
|
+
|
|
535
|
+
|
|
536
|
+
def test_svd_compressed_deterministic():
    """Two svd_compressed calls with the same seed produce element-wise identical factors."""
    x = da.random.default_rng(1234).random(size=(30, 25), chunks=(5, 5))
    u, s, vt = svd_compressed(x, 3, seed=1234)
    u2, s2, vt2 = svd_compressed(x, 3, seed=1234)

    equalities = [(lhs == rhs).all() for lhs, rhs in ((u, u2), (s, s2), (vt, vt2))]
    assert all(da.compute(*equalities))
|
|
543
|
+
|
|
544
|
+
|
|
545
|
+
@pytest.mark.parametrize("m", [5, 10, 15, 20])
@pytest.mark.parametrize("n", [5, 10, 15, 20])
@pytest.mark.parametrize("k", [5])
@pytest.mark.parametrize("chunks", [(5, 10), (10, 5)])
def test_svd_compressed_shapes(m, n, k, chunks):
    """svd_compressed factors have min(m, n, k) components."""
    x = da.random.default_rng().random(size=(m, n), chunks=chunks)
    u, s, v = da.compute(*svd_compressed(x, k, n_power_iter=1, compute=True, seed=1))

    rank = min(m, n, k)
    assert (u.shape, s.shape, v.shape) == ((m, rank), (rank,), (rank, n))
|
|
557
|
+
|
|
558
|
+
|
|
559
|
+
def _check_lu_result(p, l, u, A):
    """Assert that p @ l @ u reconstructs A and that l / u are lower / upper triangular."""
    reconstructed = p.dot(l).dot(u)
    assert np.allclose(reconstructed, A)

    # check triangulars
    for factor, mask in ((l, da.tril), (u, da.triu)):
        assert_eq(factor, mask(factor), check_graph=False)
|
|
565
|
+
|
|
566
|
+
|
|
567
|
+
def test_lu_1():
    """Compare da.linalg.lu with scipy.linalg.lu on small matrices in 2x2 chunks.

    For A1 and A2 ("without shuffle") each factor must equal scipy's. For A3
    ("with shuffle") only the reconstruction and triangularity are checked;
    scipy's factors are not compared for that matrix.
    """
    A1 = np.array([[7, 3, -1, 2], [3, 8, 1, -4], [-1, 1, 4, -1], [2, -4, -1, 6]])
    A2 = np.diag([7, 8, 4, 6, 3, 5])

    # without shuffle
    for A in (A1, A2):
        dA = da.from_array(A, chunks=(2, 2))
        p, l, u = scipy.linalg.lu(A)
        dp, dl, du = da.linalg.lu(dA)
        assert_eq(p, dp, check_graph=False)
        assert_eq(l, dl, check_graph=False)
        assert_eq(u, du, check_graph=False)
        _check_lu_result(dp, dl, du, A)

    A3 = np.array(
        [
            [7, 3, 2, 1, 4, 1],
            [7, 11, 5, 2, 5, 2],
            [21, 25, 16, 10, 16, 5],
            [21, 41, 18, 13, 16, 11],
            [14, 46, 23, 24, 21, 22],
            [0, 56, 29, 17, 14, 8],
        ]
    )

    # with shuffle
    dA3 = da.from_array(A3, chunks=(2, 2))
    dp, dl, du = da.linalg.lu(dA3)
    _check_lu_result(dp, dl, du, A3)
|
|
607
|
+
|
|
608
|
+
|
|
609
|
+
@pytest.mark.slow
@pytest.mark.parametrize("size", [10, 20, 30, 50])
# NOTE(review): this filter names dask.array.core.PerformanceWarning while the
# package under test is dask_array -- confirm that category path resolves here.
@pytest.mark.filterwarnings("ignore:Increasing:dask.array.core.PerformanceWarning")
def test_lu_2(size):
    """LU of a seeded random integer matrix split into 5x5 chunks."""
    A = np.random.default_rng(10).integers(0, 10, (size, size))
    dp, dl, du = da.linalg.lu(da.from_array(A, chunks=(5, 5)))
    _check_lu_result(dp, dl, du, A)
|
|
619
|
+
|
|
620
|
+
|
|
621
|
+
@pytest.mark.slow
@pytest.mark.parametrize("size", [50, 100, 200])
def test_lu_3(size):
    """LU of a seeded random integer matrix split into 25x25 chunks."""
    A = np.random.default_rng(10).integers(0, 10, (size, size))
    dp, dl, du = da.linalg.lu(da.from_array(A, chunks=(25, 25)))
    _check_lu_result(dp, dl, du, A)
|
|
630
|
+
|
|
631
|
+
|
|
632
|
+
def test_lu_errors():
    """da.linalg.lu must raise ValueError for each of the invalid inputs below.

    Uses the context-manager form of ``pytest.raises`` instead of the legacy
    ``pytest.raises(exc, callable)`` form.
    """
    rng = np.random.default_rng()

    # 3-D input
    A = rng.integers(0, 11, (10, 10, 10))
    dA = da.from_array(A, chunks=(5, 5, 5))
    with pytest.raises(ValueError):
        da.linalg.lu(dA)

    # non-square 2-D matrix
    A = rng.integers(0, 11, (10, 8))
    dA = da.from_array(A, chunks=(5, 4))
    with pytest.raises(ValueError):
        da.linalg.lu(dA)

    # square matrix with non-square chunks
    A = rng.integers(0, 11, (20, 20))
    dA = da.from_array(A, chunks=(5, 4))
    with pytest.raises(ValueError):
        da.linalg.lu(dA)
|
|
646
|
+
|
|
647
|
+
|
|
648
|
+
@pytest.mark.parametrize(("shape", "chunk"), [(20, 10), (50, 10), (70, 20)])
def test_solve_triangular_vector(shape, chunk):
    """solve_triangular with a vector right-hand side agrees with SciPy (upper and lower)."""
    rng = np.random.default_rng(1)
    A = rng.integers(1, 11, (shape, shape))
    b = rng.integers(1, 11, shape)
    b_float = b.astype(float)

    # upper
    upper = np.triu(A)
    d_upper = da.from_array(upper, (chunk, chunk))
    solution = da.linalg.solve_triangular(d_upper, da.from_array(b, chunk))
    assert_eq(solution, scipy.linalg.solve_triangular(upper, b))
    assert_eq(d_upper.dot(solution), b_float, rtol=1e-4)

    # lower
    lower = np.tril(A)
    d_lower = da.from_array(lower, (chunk, chunk))
    solution = da.linalg.solve_triangular(d_lower, da.from_array(b, chunk), lower=True)
    assert_eq(solution, scipy.linalg.solve_triangular(lower, b, lower=True))
    assert_eq(d_lower.dot(solution), b_float)
|
|
670
|
+
|
|
671
|
+
|
|
672
|
+
@pytest.mark.parametrize(("shape", "chunk"), [(20, 10), (50, 10), (50, 20)])
def test_solve_triangular_matrix(shape, chunk):
    """Solve upper- and lower-triangular systems with a (shape, 5) right-hand side."""
    gen = np.random.default_rng(1)

    full = gen.integers(1, 10, (shape, shape))
    rhs = gen.integers(1, 10, (shape, 5))

    # Upper triangle uses the solver defaults; lower triangle passes lower=True.
    for extract, solve_kw in ((np.triu, {}), (np.tril, {"lower": True})):
        tri = extract(full)
        d_tri = da.from_array(tri, (chunk, chunk))
        d_rhs = da.from_array(rhs, (chunk, 5))
        solution = da.linalg.solve_triangular(d_tri, d_rhs, **solve_kw)
        assert_eq(solution, scipy.linalg.solve_triangular(tri, rhs, **solve_kw))
        assert_eq(d_tri.dot(solution), rhs.astype(float))
|
|
694
|
+
|
|
695
|
+
|
|
696
|
+
@pytest.mark.parametrize(("shape", "chunk"), [(20, 10), (50, 10), (50, 20)])
def test_solve_triangular_matrix2(shape, chunk):
    """Solve upper- and lower-triangular systems with a square right-hand side."""
    gen = np.random.default_rng(1)

    full = gen.integers(1, 10, (shape, shape))
    rhs = gen.integers(1, 10, (shape, shape))

    # Upper triangle uses the solver defaults; lower triangle passes lower=True.
    for extract, solve_kw in ((np.triu, {}), (np.tril, {"lower": True})):
        tri = extract(full)
        d_tri = da.from_array(tri, (chunk, chunk))
        d_rhs = da.from_array(rhs, (chunk, chunk))
        solution = da.linalg.solve_triangular(d_tri, d_rhs, **solve_kw)
        assert_eq(solution, scipy.linalg.solve_triangular(tri, rhs, **solve_kw))
        assert_eq(d_tri.dot(solution), rhs.astype(float))
|
|
718
|
+
|
|
719
|
+
|
|
720
|
+
def test_solve_triangular_errors():
    """``da.linalg.solve_triangular`` must raise ``ValueError`` for malformed inputs."""
    # One generator is enough; the values are irrelevant to the error paths.
    rng = np.random.default_rng()

    # 3-D coefficient array.
    dA = da.from_array(rng.integers(0, 10, (10, 10, 10)), chunks=(5, 5, 5))
    db = da.from_array(rng.integers(1, 10, 10), chunks=5)
    with pytest.raises(ValueError):
        da.linalg.solve_triangular(dA, db)

    # 2-D coefficient array chunked by 3 while the right-hand side is chunked by 5.
    dA = da.from_array(rng.integers(0, 10, (10, 10)), chunks=(3, 3))
    db = da.from_array(rng.integers(1, 10, 10), chunks=5)
    with pytest.raises(ValueError):
        da.linalg.solve_triangular(dA, db)
|
|
732
|
+
|
|
733
|
+
|
|
734
|
+
@pytest.mark.parametrize(("shape", "chunk"), [(20, 10), (50, 10)])
def test_solve(shape, chunk):
    """``solve`` agrees with SciPy for vector, tall-skinny and square right-hand sides."""
    gen = np.random.default_rng(1)

    coeffs = gen.integers(1, 10, (shape, shape))
    d_coeffs = da.from_array(coeffs, (chunk, chunk))

    # Drawn in this order so the generator stream matches the original sequence.
    rhs_layouts = [
        (shape, chunk),  # vector
        ((shape, 5), (chunk, 5)),  # tall-and-skinny matrix
        ((shape, shape), (chunk, chunk)),  # square matrix
    ]
    for rhs_shape, rhs_chunks in rhs_layouts:
        rhs = gen.integers(1, 10, rhs_shape)
        d_rhs = da.from_array(rhs, rhs_chunks)

        solution = da.linalg.solve(d_coeffs, d_rhs)
        assert_eq(solution, scipy.linalg.solve(coeffs, rhs), check_graph=False)
        assert_eq(d_coeffs.dot(solution), rhs.astype(float), check_graph=False)
|
|
764
|
+
|
|
765
|
+
|
|
766
|
+
@pytest.mark.parametrize(("shape", "chunk"), [(20, 10), (50, 10)])
def test_inv(shape, chunk):
    """``inv`` matches SciPy and multiplies back to the identity."""
    mat = np.random.default_rng(1).integers(1, 10, (shape, shape))
    d_mat = da.from_array(mat, (chunk, chunk))

    inverse = da.linalg.inv(d_mat)
    identity = np.eye(shape, dtype=float)
    assert_eq(inverse, scipy.linalg.inv(mat), check_graph=False)
    assert_eq(d_mat.dot(inverse), identity, check_graph=False)
|
|
776
|
+
|
|
777
|
+
|
|
778
|
+
def _get_symmat(size):
|
|
779
|
+
rng = np.random.default_rng(1)
|
|
780
|
+
A = rng.integers(1, 21, (size, size))
|
|
781
|
+
lA = np.tril(A)
|
|
782
|
+
return lA.dot(lA.T)
|
|
783
|
+
|
|
784
|
+
|
|
785
|
+
# SciPy 1.9.0 deprecated the `sym_pos` keyword in favour of `assume_a`.
# ref: https://github.com/dask/dask/issues/9335
def _scipy_linalg_solve(a, b, assume_a):
    """Call ``scipy.linalg.solve`` with the keyword spelling the installed SciPy expects.

    On SciPy >= 1.9.0 ``assume_a`` is forwarded unchanged. On older versions,
    ``assume_a == "pos"`` becomes ``sym_pos=True`` and any other value falls
    back to a plain ``solve(a, b)``.
    """
    if Version(scipy.__version__) >= Version("1.9.0"):
        return scipy.linalg.solve(a=a, b=b, assume_a=assume_a)
    legacy_kwargs = {"sym_pos": True} if assume_a == "pos" else {}
    return scipy.linalg.solve(a=a, b=b, **legacy_kwargs)
|
|
794
|
+
|
|
795
|
+
|
|
796
|
+
@pytest.mark.parametrize(("shape", "chunk"), [(20, 10), (30, 6)])
def test_solve_assume_a(shape, chunk):
    """``solve(..., assume_a="pos")`` and the deprecated ``sym_pos`` keyword match SciPy."""
    gen = np.random.default_rng(1)

    spd = _get_symmat(shape)
    d_spd = da.from_array(spd, (chunk, chunk))

    # Drawn in this order so the generator stream matches the original sequence.
    rhs_layouts = [
        (shape, chunk),  # vector
        ((shape, 5), (chunk, 5)),  # tall-and-skinny matrix
        ((shape, shape), (chunk, chunk)),  # square matrix
    ]
    for rhs_shape, rhs_chunks in rhs_layouts:
        rhs = gen.integers(1, 10, rhs_shape)
        d_rhs = da.from_array(rhs, rhs_chunks)

        solution = da.linalg.solve(d_spd, d_rhs, assume_a="pos")
        assert_eq(solution, _scipy_linalg_solve(spd, rhs, assume_a="pos"), check_graph=False)
        assert_eq(d_spd.dot(solution), rhs.astype(float), check_graph=False)

    # The square right-hand side from the last iteration is reused to exercise
    # the deprecated ``sym_pos`` keyword in both of its settings.
    for sym_pos, equivalent in ((True, "pos"), (False, "gen")):
        with pytest.warns(FutureWarning, match="sym_pos keyword is deprecated"):
            solution = da.linalg.solve(d_spd, d_rhs, sym_pos=sym_pos)
            assert_eq(solution, _scipy_linalg_solve(spd, rhs, assume_a=equivalent), check_graph=False)
            assert_eq(d_spd.dot(solution), rhs.astype(float), check_graph=False)
|
|
836
|
+
|
|
837
|
+
|
|
838
|
+
@pytest.mark.parametrize(("shape", "chunk"), [(20, 10), (12, 3), (30, 3), (30, 6)])
def test_cholesky(shape, chunk):
    """Default and ``lower=True`` Cholesky factors match SciPy."""
    spd = _get_symmat(shape)
    d_spd = da.from_array(spd, (chunk, chunk))

    compare_opts = {"check_graph": False, "check_chunks": False}
    # The default factor is compared after an explicit compute.
    assert_eq(
        da.linalg.cholesky(d_spd).compute(),
        scipy.linalg.cholesky(spd),
        **compare_opts,
    )
    assert_eq(
        da.linalg.cholesky(d_spd, lower=True),
        scipy.linalg.cholesky(spd, lower=True),
        **compare_opts,
    )
|
|
854
|
+
|
|
855
|
+
|
|
856
|
+
@pytest.mark.parametrize("iscomplex", [False, True])
@pytest.mark.parametrize(("nrow", "ncol", "chunk"), [(20, 10, 5), (100, 10, 10)])
def test_lstsq(nrow, ncol, chunk, iscomplex):
    """``lstsq`` matches NumPy for 1-D and 2-D right-hand sides and reports rank."""
    gen = np.random.default_rng(1)

    def complexify(*arrays):
        # When testing the complex variant, add a random imaginary part to each
        # array, drawing from the generator in argument order.
        if not iscomplex:
            return arrays
        return tuple(arr + 1.0j * gen.integers(1, 20, arr.shape) for arr in arrays)

    # Full-rank system with a 1-D right-hand side.
    mat = gen.integers(1, 20, (nrow, ncol))
    vec = gen.integers(1, 20, nrow)
    mat, vec = complexify(mat, vec)

    d_mat = da.from_array(mat, (chunk, ncol))
    d_vec = da.from_array(vec, chunk)

    np_x, np_res, np_rank, np_sv = np.linalg.lstsq(mat, vec, rcond=-1)
    da_x, da_res, da_rank, da_sv = da.linalg.lstsq(d_mat, d_vec)

    assert_eq(da_x, np_x)
    assert_eq(da_res, np_res)
    assert da_rank.compute() == np_rank
    assert_eq(da_sv, np_sv)

    # Duplicating a column makes the system rank-deficient (multicollinear),
    # so only the reported rank is compared.
    mat[:, 1] = mat[:, 2]
    d_mat = da.from_array(mat, (chunk, ncol))
    d_vec = da.from_array(vec, chunk)
    tol = np.finfo(np.double).eps * max(nrow, ncol)
    _, _, np_rank, _ = np.linalg.lstsq(mat, vec, rcond=tol)
    assert np_rank == ncol - 1
    _, _, da_rank, _ = da.linalg.lstsq(d_mat, d_vec)
    assert da_rank.compute() == np_rank

    # Full-rank system with a 2-D right-hand side.
    mat = gen.integers(1, 20, (nrow, ncol))
    rhs = gen.integers(1, 20, (nrow, ncol // 2))
    mat, rhs = complexify(mat, rhs)
    d_mat = da.from_array(mat, (chunk, ncol))
    d_rhs = da.from_array(rhs, (chunk, ncol // 2))

    np_x, np_res, np_rank, np_sv = np.linalg.lstsq(mat, rhs, rcond=-1)
    da_x, da_res, da_rank, da_sv = da.linalg.lstsq(d_mat, d_rhs)

    assert_eq(da_x, np_x)
    assert_eq(da_res, np_res)
    assert da_rank.compute() == np_rank
    assert_eq(da_sv, np_sv)
|
|
901
|
+
|
|
902
|
+
|
|
903
|
+
def test_no_chunks_svd():
    """``svd`` works when row (and optionally column) chunk sizes are unknown (NaN)."""
    x = np.random.default_rng().random((100, 10))
    u, s, v = np.linalg.svd(x, full_matrices=False)

    for chunks in [((np.nan,) * 10, (10,)), ((np.nan,) * 10, (np.nan,))]:
        dx = da.from_array(x, chunks=(10, 10))
        # Overwrite the known chunk sizes with NaN to simulate unknown chunks.
        dx._chunks = chunks

        du, ds, dv = da.linalg.svd(dx)

        assert_eq(s, ds)
        assert_eq(u.dot(np.diag(s)).dot(v), du.dot(da.diag(ds)).dot(dv))
        assert_eq(du.T.dot(du), np.eye(10))
        assert_eq(dv.T.dot(dv), np.eye(10))

        # Singular vectors are only determined up to sign, so compare magnitudes.
        # (The previous version rebuilt an unused ``dx`` here; that dead code
        # had no effect on these checks and has been removed.)
        assert_eq(abs(v), abs(dv))
        assert_eq(abs(u), abs(du))
|
|
922
|
+
|
|
923
|
+
|
|
924
|
+
@pytest.mark.parametrize("shape", [(10, 20), (10, 10), (20, 10)])
@pytest.mark.parametrize("chunks", [(-1, -1), (10, -1), (-1, 10)])
@pytest.mark.parametrize("dtype", ["f4", "f8"])
def test_svd_flip_correction(shape, chunks, dtype):
    """Sign-corrected SVD factors must still reconstruct the input.

    ``svd_flip`` is exercised on dask inputs and on computed NumPy inputs; in
    both cases the factor dtypes must be preserved.
    """
    x = da.random.default_rng().random(size=shape, chunks=chunks).astype(dtype)
    u, s, v = da.linalg.svd(x)

    # Single precision (itemsize <= 4) gets a looser tolerance than double.
    decimal = 6 if np.dtype(dtype).itemsize <= 4 else 9

    def check(u_fixed, v_fixed):
        assert u_fixed.dtype == u.dtype
        assert v_fixed.dtype == v.dtype
        rebuilt = np.asarray(np.dot(u_fixed * s, v_fixed))
        np.testing.assert_almost_equal(rebuilt, x, decimal=decimal)

    # dask inputs
    check(*svd_flip(u, v))
    # NumPy inputs
    check(*svd_flip(*da.compute(u, v)))
|
|
947
|
+
|
|
948
|
+
|
|
949
|
+
@pytest.mark.parametrize("dtype", ["f2", "f4", "f8", "f16", "c8", "c16", "c32"])
@pytest.mark.parametrize("u_based", [True, False])
def test_svd_flip_sign(dtype, u_based):
    """``svd_flip`` produces a fixed sign pattern and preserves the input dtype."""
    pattern = [[1, -1, 1, -1], [1, -1, 1, -1], [-1, 1, 1, -1], [-1, 1, 1, -1]]
    try:
        x = np.array(pattern, dtype=dtype)
    except TypeError:
        # The 128-bit dtypes ("f16"/"c32") may be unavailable on some platforms.
        pytest.skip("128-bit floats not supported by NumPy")

    u, v = svd_flip(x, x.T, u_based_decision=u_based)
    assert u.dtype == x.dtype
    assert v.dtype == x.dtype

    # All singular vectors keep their sign except the last one (last column),
    # which is negated.
    expected = x.copy()
    expected[:, -1] *= expected.dtype.type(-1)
    assert_eq(u, expected)
    assert_eq(v, expected.T)
|
|
968
|
+
|
|
969
|
+
|
|
970
|
+
@pytest.mark.parametrize("chunks", [(10, -1), (-1, 10), (9, -1), (-1, 9)])
@pytest.mark.parametrize("shape", [(10, 100), (100, 10), (10, 10)])
def test_svd_supported_array_shapes(chunks, shape):
    """SVD of tall-skinny, short-fat and square arrays matches NumPy.

    The parametrisation covers, for each shape:
    - no chunking,
    - chunking that contradicts the shape (e.g. a 10x100 array with 9x100 chunks),
    - chunking that aligns with the shape (e.g. a 10x100 array with 10x9 chunks).
    """
    x = np.random.default_rng().random(shape)

    du, ds, dv = da.linalg.svd(da.from_array(x, chunks=chunks))
    du, dv = da.compute(du, dv)

    nu, ns, nv = np.linalg.svd(x, full_matrices=False)

    # Singular-vector signs are arbitrary; normalise both sides first.
    du, dv = svd_flip(du, dv)
    nu, nv = svd_flip(nu, nv)

    for got, want in ((du, nu), (ds, ns), (dv, nv)):
        assert_eq(got, want)
|
|
992
|
+
|
|
993
|
+
|
|
994
|
+
def test_svd_incompatible_chunking():
    """``svd`` rejects arrays that are chunked along both dimensions."""
    x = da.random.default_rng().random((10, 10), chunks=(5, 5))
    # Only the call under test sits inside the raises scope, so a failure while
    # building the input cannot be mistaken for the expected error.
    with pytest.raises(NotImplementedError, match="Array must be chunked in one dimension only"):
        da.linalg.svd(x)
|
|
998
|
+
|
|
999
|
+
|
|
1000
|
+
@pytest.mark.parametrize("ndim", [0, 1, 3])
def test_svd_incompatible_dimensions(ndim):
    """``svd`` rejects inputs that are not two-dimensional."""
    x = da.random.default_rng().random((10,) * ndim, chunks=(-1,) * ndim)
    # Only the call under test sits inside the raises scope.
    with pytest.raises(ValueError, match="Array must be 2D"):
        da.linalg.svd(x)
|
|
1005
|
+
|
|
1006
|
+
|
|
1007
|
+
@pytest.mark.xfail(
    sys.platform == "darwin" and _np_version < (1, 22),
    reason="https://github.com/dask/dask/issues/7189",
    strict=False,
)
@pytest.mark.parametrize(
    "shape, chunks, axis",
    [[(5,), (2,), None], [(5,), (2,), 0], [(5,), (2,), (0,)], [(5, 6), (2, 2), None]],
)
@pytest.mark.parametrize("norm", [None, 1, -1, np.inf, -np.inf])
@pytest.mark.parametrize("keepdims", [False, True])
def test_norm_any_ndim(shape, chunks, axis, norm, keepdims):
    """``da.linalg.norm`` matches ``np.linalg.norm`` on the parametrised inputs."""
    data = np.random.default_rng().random(shape)
    darr = da.from_array(data, chunks=chunks)

    options = {"ord": norm, "axis": axis, "keepdims": keepdims}
    assert_eq(np.linalg.norm(data, **options), da.linalg.norm(darr, **options))
|
|
1026
|
+
|
|
1027
|
+
|
|
1028
|
+
@pytest.mark.xfail(
    _np_version < (1, 23),
    reason="https://github.com/numpy/numpy/pull/17709",
    strict=False,
)
@pytest.mark.parametrize("precision", ["single", "double"])
@pytest.mark.parametrize("isreal", [True, False])
@pytest.mark.parametrize("keepdims", [False, True])
@pytest.mark.parametrize("norm", [None, 1, -1, np.inf, -np.inf])
def test_norm_any_prec(norm, keepdims, precision, isreal):
    """``norm`` returns the real dtype matching the input precision, as NumPy does."""
    shape, chunks, axis = (5,), (2,), None

    real_of = {"single": "float32", "double": "float64"}
    complex_of = {"single": "complex64", "double": "complex128"}
    dtype = (real_of if isreal else complex_of)[precision]

    data = np.random.default_rng().random(shape).astype(dtype)
    darr = da.from_array(data, chunks=chunks)
    expected = np.linalg.norm(data, ord=norm, axis=axis, keepdims=keepdims)
    result = da.linalg.norm(darr, ord=norm, axis=axis, keepdims=keepdims)

    assert result.dtype == real_of[precision]
    assert result.dtype == expected.dtype
|
|
1052
|
+
|
|
1053
|
+
|
|
1054
|
+
@pytest.mark.slow
@pytest.mark.xfail(
    sys.platform == "darwin" and _np_version < (1, 22),
    reason="https://github.com/dask/dask/issues/7189",
    strict=False,
)
@pytest.mark.parametrize(
    "shape, chunks",
    [
        [(5,), (2,)],
        [(5, 3), (2, 2)],
        [(4, 5, 3), (2, 2, 2)],
        [(4, 5, 2, 3), (2, 2, 2, 2)],
        [(2, 5, 2, 4, 3), (2, 2, 2, 2, 2)],
    ],
)
@pytest.mark.parametrize("norm", [None, 1, -1, np.inf, -np.inf])
@pytest.mark.parametrize("keepdims", [False, True])
def test_norm_any_slice(shape, chunks, norm, keepdims):
    """``norm`` matches NumPy over every single axis and every ordered axis pair."""
    data = np.random.default_rng().random(shape)
    darr = da.from_array(data, chunks=chunks)

    ndim = len(shape)
    for first in range(ndim):
        for second in range(ndim):
            # A repeated index selects that one axis; distinct indices select the pair.
            axis = first if first == second else (first, second)
            expected = np.linalg.norm(data, ord=norm, axis=axis, keepdims=keepdims)
            result = da.linalg.norm(darr, ord=norm, axis=axis, keepdims=keepdims)
            assert_eq(expected, result)
|
|
1085
|
+
|
|
1086
|
+
|
|
1087
|
+
@pytest.mark.parametrize("shape, chunks, axis", [[(5,), (2,), None], [(5,), (2,), 0], [(5,), (2,), (0,)]])
@pytest.mark.parametrize("norm", [0, 2, -2, 0.5])
@pytest.mark.parametrize("keepdims", [False, True])
def test_norm_1dim(shape, chunks, axis, norm, keepdims):
    """Vector-only orders (0, 2, -2, 0.5) on a 1-D array match NumPy."""
    data = np.random.default_rng().random(shape)
    darr = da.from_array(data, chunks=chunks)

    options = {"ord": norm, "axis": axis, "keepdims": keepdims}
    assert_eq(np.linalg.norm(data, **options), da.linalg.norm(darr, **options))
|
|
1097
|
+
|
|
1098
|
+
|
|
1099
|
+
@pytest.mark.parametrize(
    "shape, chunks, axis",
    [[(5, 6), (2, 2), None], [(5, 6), (2, 2), (0, 1)], [(5, 6), (2, 2), (1, 0)]],
)
@pytest.mark.parametrize("norm", ["fro", "nuc", 2, -2])
@pytest.mark.parametrize("keepdims", [False, True])
def test_norm_2dim(shape, chunks, axis, norm, keepdims):
    """Matrix orders ("fro", "nuc", 2, -2) on a 2-D array match NumPy."""
    data = np.random.default_rng().random(shape)
    darr = da.from_array(data, chunks=chunks)

    # These orders go through the SVD, which needs one chunk on the last dimension.
    if norm in ("nuc", 2, -2):
        darr = darr.rechunk({-1: -1})

    options = {"ord": norm, "axis": axis, "keepdims": keepdims}
    assert_eq(np.linalg.norm(data, **options), da.linalg.norm(darr, **options))
|
|
1117
|
+
|
|
1118
|
+
|
|
1119
|
+
@pytest.mark.parametrize(
|
|
1120
|
+
"shape, chunks, axis",
|
|
1121
|
+
[[(3, 2, 4), (2, 2, 2), (1, 2)], [(2, 3, 4, 5), (2, 2, 2, 2), (-1, -2)]],
|
|
1122
|
+
)
|
|
1123
|
+
@pytest.mark.parametrize("norm", ["nuc", 2, -2])
|
|
1124
|
+
@pytest.mark.parametrize("keepdims", [False, True])
|
|
1125
|
+
def test_norm_implemented_errors(shape, chunks, axis, norm, keepdims):
|
|
1126
|
+
a = np.random.default_rng().random(shape)
|
|
1127
|
+
d = da.from_array(a, chunks=chunks)
|
|
1128
|
+
if len(shape) > 2 and len(axis) == 2:
|
|
1129
|
+
with pytest.raises(NotImplementedError):
|
|
1130
|
+
da.linalg.norm(d, ord=norm, axis=axis, keepdims=keepdims)
|