dask-array 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dask_array/__init__.py +228 -0
- dask_array/_backends.py +76 -0
- dask_array/_backends_array.py +99 -0
- dask_array/_blockwise.py +1410 -0
- dask_array/_broadcast.py +272 -0
- dask_array/_chunk.py +445 -0
- dask_array/_chunk_types.py +54 -0
- dask_array/_collection.py +1644 -0
- dask_array/_concatenate.py +331 -0
- dask_array/_core_utils.py +1365 -0
- dask_array/_dispatch.py +141 -0
- dask_array/_einsum.py +277 -0
- dask_array/_expr.py +544 -0
- dask_array/_expr_flow.py +586 -0
- dask_array/_gufunc.py +805 -0
- dask_array/_histogram.py +617 -0
- dask_array/_map_blocks.py +652 -0
- dask_array/_new_collection.py +10 -0
- dask_array/_numpy_compat.py +135 -0
- dask_array/_overlap.py +1159 -0
- dask_array/_rechunk.py +1050 -0
- dask_array/_reshape.py +710 -0
- dask_array/_routines.py +102 -0
- dask_array/_shuffle.py +448 -0
- dask_array/_stack.py +264 -0
- dask_array/_svg.py +291 -0
- dask_array/_templates.py +29 -0
- dask_array/_test_utils.py +257 -0
- dask_array/_ufunc.py +385 -0
- dask_array/_utils.py +349 -0
- dask_array/_visualize.py +223 -0
- dask_array/_xarray.py +337 -0
- dask_array/core/__init__.py +34 -0
- dask_array/core/_blockwise_funcs.py +312 -0
- dask_array/core/_conversion.py +422 -0
- dask_array/core/_from_graph.py +97 -0
- dask_array/creation/__init__.py +71 -0
- dask_array/creation/_arange.py +121 -0
- dask_array/creation/_diag.py +116 -0
- dask_array/creation/_diagonal.py +241 -0
- dask_array/creation/_eye.py +103 -0
- dask_array/creation/_linspace.py +102 -0
- dask_array/creation/_mesh.py +134 -0
- dask_array/creation/_ones_zeros.py +454 -0
- dask_array/creation/_pad.py +270 -0
- dask_array/creation/_repeat.py +55 -0
- dask_array/creation/_tile.py +36 -0
- dask_array/creation/_tri.py +28 -0
- dask_array/creation/_utils.py +296 -0
- dask_array/fft.py +320 -0
- dask_array/io/__init__.py +39 -0
- dask_array/io/_base.py +10 -0
- dask_array/io/_from_array.py +257 -0
- dask_array/io/_from_delayed.py +95 -0
- dask_array/io/_from_graph.py +54 -0
- dask_array/io/_from_npy_stack.py +67 -0
- dask_array/io/_store.py +336 -0
- dask_array/io/_tiledb.py +159 -0
- dask_array/io/_to_npy_stack.py +65 -0
- dask_array/io/_zarr.py +449 -0
- dask_array/linalg/__init__.py +39 -0
- dask_array/linalg/_cholesky.py +234 -0
- dask_array/linalg/_lu.py +300 -0
- dask_array/linalg/_norm.py +94 -0
- dask_array/linalg/_qr.py +601 -0
- dask_array/linalg/_solve.py +349 -0
- dask_array/linalg/_svd.py +394 -0
- dask_array/linalg/_tensordot.py +334 -0
- dask_array/linalg/_utils.py +74 -0
- dask_array/manipulation/__init__.py +45 -0
- dask_array/manipulation/_expand.py +321 -0
- dask_array/manipulation/_flip.py +92 -0
- dask_array/manipulation/_roll.py +78 -0
- dask_array/manipulation/_transpose.py +309 -0
- dask_array/random/__init__.py +125 -0
- dask_array/random/_choice.py +181 -0
- dask_array/random/_expr.py +256 -0
- dask_array/random/_generator.py +441 -0
- dask_array/random/_random_state.py +259 -0
- dask_array/random/_utils.py +84 -0
- dask_array/reductions/__init__.py +84 -0
- dask_array/reductions/_arg_reduction.py +130 -0
- dask_array/reductions/_common.py +1082 -0
- dask_array/reductions/_cumulative.py +522 -0
- dask_array/reductions/_percentile.py +261 -0
- dask_array/reductions/_reduction.py +725 -0
- dask_array/reductions/_trace.py +56 -0
- dask_array/routines/__init__.py +133 -0
- dask_array/routines/_apply.py +84 -0
- dask_array/routines/_bincount.py +112 -0
- dask_array/routines/_broadcast.py +111 -0
- dask_array/routines/_coarsen.py +115 -0
- dask_array/routines/_diff.py +79 -0
- dask_array/routines/_gradient.py +158 -0
- dask_array/routines/_indexing.py +65 -0
- dask_array/routines/_insert_delete.py +132 -0
- dask_array/routines/_misc.py +122 -0
- dask_array/routines/_nonzero.py +72 -0
- dask_array/routines/_search.py +123 -0
- dask_array/routines/_select.py +113 -0
- dask_array/routines/_statistics.py +171 -0
- dask_array/routines/_topk.py +82 -0
- dask_array/routines/_triangular.py +74 -0
- dask_array/routines/_unique.py +232 -0
- dask_array/routines/_where.py +62 -0
- dask_array/slicing/__init__.py +67 -0
- dask_array/slicing/_basic.py +550 -0
- dask_array/slicing/_blocks.py +138 -0
- dask_array/slicing/_bool_index.py +145 -0
- dask_array/slicing/_setitem.py +329 -0
- dask_array/slicing/_squeeze.py +101 -0
- dask_array/slicing/_utils.py +1133 -0
- dask_array/slicing/_vindex.py +282 -0
- dask_array/stacking/__init__.py +15 -0
- dask_array/stacking/_block.py +83 -0
- dask_array/stacking/_simple.py +58 -0
- dask_array/templates/array.html.j2 +48 -0
- dask_array/tests/__init__.py +0 -0
- dask_array/tests/conftest.py +22 -0
- dask_array/tests/test_api.py +40 -0
- dask_array/tests/test_binary_op_chunks.py +107 -0
- dask_array/tests/test_coarse_slice_through_blockwise.py +362 -0
- dask_array/tests/test_collection.py +799 -0
- dask_array/tests/test_creation.py +1102 -0
- dask_array/tests/test_expr_flow.py +143 -0
- dask_array/tests/test_linalg.py +1130 -0
- dask_array/tests/test_map_blocks_multi_output.py +104 -0
- dask_array/tests/test_rechunk_pushdown.py +214 -0
- dask_array/tests/test_reductions.py +1091 -0
- dask_array/tests/test_routines.py +2853 -0
- dask_array/tests/test_shuffle_chunks.py +67 -0
- dask_array/tests/test_slice_pushdown.py +968 -0
- dask_array/tests/test_slice_through_blockwise.py +678 -0
- dask_array/tests/test_slice_through_overlap.py +366 -0
- dask_array/tests/test_slice_through_reshape.py +272 -0
- dask_array/tests/test_slicing.py +839 -0
- dask_array/tests/test_transpose_slice_pushdown.py +208 -0
- dask_array/tests/test_visualize.py +94 -0
- dask_array/tests/test_xarray.py +193 -0
- dask_array-0.1.0.dist-info/METADATA +48 -0
- dask_array-0.1.0.dist-info/RECORD +144 -0
- dask_array-0.1.0.dist-info/WHEEL +4 -0
- dask_array-0.1.0.dist-info/entry_points.txt +2 -0
- dask_array-0.1.0.dist-info/licenses/LICENSE +29 -0
dask_array/linalg/_qr.py
ADDED
|
@@ -0,0 +1,601 @@
|
|
|
1
|
+
"""QR decomposition for array-expr.
|
|
2
|
+
|
|
3
|
+
Implements TSQR (Tall-and-Skinny QR) and SFQR (Short-and-Fat QR).
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
import functools
|
|
9
|
+
import operator
|
|
10
|
+
|
|
11
|
+
import numpy as np
|
|
12
|
+
|
|
13
|
+
from dask_array._new_collection import new_collection
|
|
14
|
+
from dask._task_spec import List, Task, TaskRef
|
|
15
|
+
from dask_array._expr import ArrayExpr
|
|
16
|
+
from dask_array.linalg._utils import (
|
|
17
|
+
_cumsum_blocks,
|
|
18
|
+
_cumsum_part,
|
|
19
|
+
_get_block_size,
|
|
20
|
+
_get_n,
|
|
21
|
+
_getitem_with_slice,
|
|
22
|
+
_has_uncertain_chunks,
|
|
23
|
+
_make_slice,
|
|
24
|
+
_nanmin,
|
|
25
|
+
)
|
|
26
|
+
from dask_array._utils import meta_from_array
|
|
27
|
+
from dask.utils import derived_from
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def _wrapped_qr(a):
|
|
31
|
+
"""A wrapper for np.linalg.qr that handles arrays with 0 rows."""
|
|
32
|
+
if a.shape[0] == 0:
|
|
33
|
+
return np.zeros_like(a, shape=(0, 0)), np.zeros_like(a, shape=(0, a.shape[1]))
|
|
34
|
+
else:
|
|
35
|
+
return np.linalg.qr(a)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class QRBlock(ArrayExpr):
    """Factorise every row block of ``array`` independently (TSQR stage one).

    Each output key ``(name, i, 0)`` holds whatever ``_wrapped_qr`` returns
    for input block ``(i, 0)``, i.e. a ``(Q, R)`` pair rather than a single
    array. ``QRBlockQ`` and ``QRBlockR`` pick that pair apart.
    """

    _parameters = ["array"]

    @functools.cached_property
    def _meta(self):
        # Factorising a 1x1 dummy of the input dtype is only a way to learn
        # the dtypes of the two factors; the values are discarded.
        qq, rr = np.linalg.qr(np.ones(shape=(1, 1), dtype=self.array.dtype))
        return (
            meta_from_array(self.array._meta, ndim=2, dtype=qq.dtype),
            meta_from_array(self.array._meta, ndim=2, dtype=rr.dtype),
        )

    @functools.cached_property
    def chunks(self):
        # Reported chunks are simply those of the input expression.
        return self.array.chunks

    @functools.cached_property
    def _name(self):
        return f"qr-block-{self.deterministic_token}"

    def _layer(self):
        """Build one ``_wrapped_qr`` task per row block of the input."""
        # Upstream keys are addressed through ``_name``, like every other
        # expression in this module does for its inputs.
        source = self.array._name
        dsk = {}
        for i in range(self.array.numblocks[0]):
            out_key = (self._name, i, 0)
            dsk[out_key] = Task(out_key, _wrapped_qr, TaskRef((source, i, 0)))
        return dsk
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
class QRBlockQ(ArrayExpr):
    """Select the Q factor (item 0) from each ``(Q, R)`` pair of ``qr_block``."""

    _parameters = ["qr_block"]

    @functools.cached_property
    def _meta(self):
        return self.qr_block._meta[0]

    @functools.cached_property
    def chunks(self):
        """Row chunks of the factorised input plus one column chunk.

        The column width is NaN when the input has uncertain chunks;
        otherwise it is the smallest ``_nanmin(block_height, n_cols)``
        over all row blocks.
        """
        source = self.qr_block.array
        n_cols = source.shape[1]
        row_chunks = source.chunks[0]
        if _has_uncertain_chunks(source.chunks):
            return (row_chunks, (np.nan,))
        width = min(tuple(_nanmin(height, n_cols) for height in row_chunks))
        return (row_chunks, (width,))

    @functools.cached_property
    def _name(self):
        return f"qr-q-{self.deterministic_token}"

    def _layer(self):
        """One ``getitem(pair, 0)`` task per row block."""
        upstream = self.qr_block._name
        n_row_blocks = self.qr_block.array.numblocks[0]
        layer = {}
        for idx in range(n_row_blocks):
            key = (self._name, idx, 0)
            pair_ref = TaskRef((upstream, idx, 0))
            layer[key] = Task(key, operator.getitem, pair_ref, 0)
        return layer
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
class QRBlockR(ArrayExpr):
    """Select the R factor (item 1) from each ``(Q, R)`` pair of ``qr_block``.

    With known chunks, block ``i`` is reported with height
    ``_nanmin(m_i, n)`` and width ``n``, where ``m_i`` is the height of input
    block ``i`` and ``n`` the number of input columns.
    """

    _parameters = ["qr_block"]

    @functools.cached_property
    def _meta(self):
        return self.qr_block._meta[1]

    @functools.cached_property
    def chunks(self):
        source = self.qr_block.array
        n_cols = source.shape[1]
        row_chunks = source.chunks[0]
        if _has_uncertain_chunks(source.chunks):
            # Heights cannot be derived; keep one NaN entry per row block.
            heights = tuple(np.nan for _ in row_chunks)
        else:
            heights = tuple(_nanmin(height, n_cols) for height in row_chunks)
        return (heights, (n_cols,))

    @functools.cached_property
    def _name(self):
        return f"qr-r-{self.deterministic_token}"

    def _layer(self):
        """One ``getitem(pair, 1)`` task per row block."""
        upstream = self.qr_block._name
        n_row_blocks = self.qr_block.array.numblocks[0]
        layer = {}
        for idx in range(n_row_blocks):
            key = (self._name, idx, 0)
            pair_ref = TaskRef((upstream, idx, 0))
            layer[key] = Task(key, operator.getitem, pair_ref, 1)
        return layer
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
class StackRFactors(ArrayExpr):
    """Vertically concatenate all R blocks into one single-chunk matrix.

    This feeds the in-core QR step, which works on the stacked R factors as
    a single block.
    """

    _parameters = ["r_blocks"]

    @functools.cached_property
    def _meta(self):
        return self.r_blocks._meta

    @functools.cached_property
    def chunks(self):
        """A single chunk: summed block heights (or NaN) by the R width."""
        width = self.r_blocks.chunks[1][0]
        if _has_uncertain_chunks(self.r_blocks.chunks):
            height = np.nan
        else:
            height = sum(self.r_blocks.chunks[0])
        return ((height,), (width,))

    @functools.cached_property
    def _name(self):
        return f"stack-r-{self.deterministic_token}"

    def _layer(self):
        """A single ``np.vstack`` task over every R block, in row order."""
        n_row_blocks = self.r_blocks.numblocks[0]
        key = (self._name, 0, 0)
        block_refs = List(
            *[TaskRef((self.r_blocks._name, idx, 0)) for idx in range(n_row_blocks)]
        )
        return {key: Task(key, np.vstack, block_refs)}
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
class InCoreQR(ArrayExpr):
    """QR factorisation of the stacked R matrix as one ``np.linalg.qr`` task.

    The single output key holds the ``(Q, R)`` result of that call;
    ``InCoreQRQ`` and ``InCoreQRR`` extract the individual factors.
    """

    _parameters = ["stacked_r"]

    @functools.cached_property
    def _meta(self):
        # A 1x1 dummy factorisation reveals the dtypes of the two factors.
        probe = np.ones(shape=(1, 1), dtype=self.stacked_r.dtype)
        q_probe, r_probe = np.linalg.qr(probe)
        base = self.stacked_r._meta
        q_meta = meta_from_array(base, ndim=2, dtype=q_probe.dtype)
        r_meta = meta_from_array(self.stacked_r._meta, ndim=2, dtype=r_probe.dtype)
        return (q_meta, r_meta)

    @functools.cached_property
    def chunks(self):
        # Reported chunks are those of the stacked input.
        return self.stacked_r.chunks

    @functools.cached_property
    def _name(self):
        return f"qr-core-{self.deterministic_token}"

    def _layer(self):
        """Factorise block ``(0, 0)`` of the stacked input."""
        key = (self._name, 0, 0)
        stacked_ref = TaskRef((self.stacked_r._name, 0, 0))
        return {key: Task(key, np.linalg.qr, stacked_ref)}
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
class InCoreQRQ(ArrayExpr):
    """Select the Q factor (item 0) of the in-core QR result.

    ``original_r_chunks`` holds the row chunk sizes of the R blocks before
    they were stacked; it is only used to derive ``chunks``.
    """

    _parameters = ["incore_qr", "original_r_chunks"]

    @functools.cached_property
    def _meta(self):
        return self.incore_qr._meta[0]

    @functools.cached_property
    def chunks(self):
        """One chunk of shape ``(sum(heights), _nanmin(sum(heights), n))``.

        Both extents are NaN as soon as any original block height is NaN.
        """
        n_cols = self.incore_qr.stacked_r.shape[1]
        heights = self.original_r_chunks
        for height in heights:
            if np.isnan(height):
                return ((np.nan,), (np.nan,))
        stacked_height = sum(heights)
        return ((stacked_height,), (_nanmin(stacked_height, n_cols),))

    @functools.cached_property
    def _name(self):
        return f"qr-core-q-{self.deterministic_token}"

    def _layer(self):
        """A single ``getitem(pair, 0)`` task."""
        key = (self._name, 0, 0)
        pair_ref = TaskRef((self.incore_qr._name, 0, 0))
        return {key: Task(key, operator.getitem, pair_ref, 0)}
|
|
231
|
+
|
|
232
|
+
|
|
233
|
+
class InCoreQRR(ArrayExpr):
    """Select the R factor (item 1) of the in-core QR result."""

    _parameters = ["incore_qr"]

    @functools.cached_property
    def _meta(self):
        return self.incore_qr._meta[1]

    @functools.cached_property
    def chunks(self):
        """One chunk of shape ``(_nanmin(m, n), n)`` for an ``m x n`` input."""
        stacked = self.incore_qr.stacked_r
        n_rows = stacked.shape[0]
        n_cols = stacked.shape[1]
        return ((_nanmin(n_rows, n_cols),), (n_cols,))

    @functools.cached_property
    def _name(self):
        return f"qr-core-r-{self.deterministic_token}"

    def _layer(self):
        """A single ``getitem(pair, 1)`` task."""
        key = (self._name, 0, 0)
        pair_ref = TaskRef((self.incore_qr._name, 0, 0))
        return {key: Task(key, operator.getitem, pair_ref, 1)}
|
|
258
|
+
|
|
259
|
+
|
|
260
|
+
class UnstackQInner(ArrayExpr):
    """Split the in-core Q back into row blocks matching the R block sizes.

    Parameters
    ----------
    q_inner
        Expression whose block ``(0, 0)`` is the Q factor to split.
    r_chunks
        Row chunk sizes of the R blocks; one output block is produced per
        entry.
    data_expr
        Expression holding the original data blocks. Only needed when
        ``r_chunks`` contains NaN: the slice bounds are then computed by
        tasks in the graph from the actual data blocks instead of being
        baked in at graph-construction time.
    """

    _parameters = ["q_inner", "r_chunks", "data_expr"]
    _defaults = {"data_expr": None}

    @functools.cached_property
    def _meta(self):
        return self.q_inner._meta

    @functools.cached_property
    def chunks(self):
        n = self.q_inner.chunks[1][0]
        return (self.r_chunks, (n,))

    @functools.cached_property
    def _name(self):
        return f"unstack-q-{self.deterministic_token}"

    def _layer(self):
        """Build the slicing tasks.

        Raises
        ------
        ValueError
            If ``r_chunks`` contains NaN but no ``data_expr`` was supplied.
        """
        dsk = {}
        block_sizes = self.r_chunks
        numblocks = len(block_sizes)
        prefix = self._name
        # Every output block is a slice of this single upstream block.
        q_key = (self.q_inner._name, 0, 0)

        if not any(np.isnan(c) for c in block_sizes):
            # Known sizes: the slices are plain constants in the graph.
            n = self.q_inner.chunks[1][0]
            for i, (start, end) in enumerate(_cumsum_blocks(block_sizes)):
                out_key = (prefix, i, 0)
                slc = (slice(start, end), slice(0, n))
                dsk[out_key] = Task(out_key, operator.getitem, TaskRef(q_key), slc)
            return dsk

        if self.data_expr is None:
            # Without this guard the failure would be an opaque
            # AttributeError on ``None._name`` below.
            raise ValueError(
                "UnstackQInner requires data_expr when r_chunks contains "
                "unknown (NaN) sizes"
            )
        data_name = self.data_expr._name

        # Column extent, derived at runtime from the first data block.
        n_key = (prefix + "-n",)
        dsk[n_key] = Task(n_key, _get_n, TaskRef((data_name, 0, 0)))

        # Per-block size, derived at runtime from each data block.
        for i in range(numblocks):
            bs_key = (prefix + "-bs", i)
            dsk[bs_key] = Task(bs_key, _get_block_size, TaskRef((data_name, i, 0)))

        # Running bounds: seeded with [0, size_0], then extended block by
        # block via _cumsum_part (presumably yielding the next start/end
        # pair -- confirm against linalg/_utils).
        dsk[(prefix + "-cs", 0)] = List(0, TaskRef((prefix + "-bs", 0)))
        for i in range(1, numblocks):
            cs_key = (prefix + "-cs", i)
            dsk[cs_key] = Task(
                cs_key,
                _cumsum_part,
                TaskRef((prefix + "-cs", i - 1)),
                TaskRef((prefix + "-bs", i)),
            )

        # Turn each bounds entry into a slice and apply it to Q.
        for i in range(numblocks):
            slice_key = (prefix + "-slice", i)
            dsk[slice_key] = Task(
                slice_key, _make_slice, TaskRef((prefix + "-cs", i)), TaskRef(n_key)
            )
            out_key = (prefix, i, 0)
            dsk[out_key] = Task(
                out_key, _getitem_with_slice, TaskRef(q_key), TaskRef(slice_key)
            )

        return dsk
|
|
328
|
+
|
|
329
|
+
|
|
330
|
+
class BlockDot(ArrayExpr):
    """Multiply matching row blocks of two block columns with ``np.dot``.

    Output block ``i`` is ``np.dot(q_blocks[i], q_inner_unstacked[i])``, which
    assembles the final Q as ``Q1 @ Q_inner`` one block at a time.
    """

    _parameters = ["q_blocks", "q_inner_unstacked"]

    @functools.cached_property
    def _meta(self):
        return self.q_blocks._meta

    @functools.cached_property
    def chunks(self):
        # Rows follow the left operand, the single column chunk the right one.
        width = self.q_inner_unstacked.chunks[1][0]
        row_chunks = self.q_blocks.chunks[0]
        return (row_chunks, (width,))

    @functools.cached_property
    def _name(self):
        return f"block-dot-{self.deterministic_token}"

    def _layer(self):
        """One ``np.dot`` task per row block of ``q_blocks``."""
        layer = {}
        for idx in range(len(self.q_blocks.chunks[0])):
            key = (self._name, idx, 0)
            left = TaskRef((self.q_blocks._name, idx, 0))
            right = TaskRef((self.q_inner_unstacked._name, idx, 0))
            layer[key] = Task(key, np.dot, left, right)
        return layer
|
|
360
|
+
|
|
361
|
+
|
|
362
|
+
def tsqr(data, compute_svd=False, _max_vchunk_size=None):
    """Direct Tall-and-Skinny QR algorithm.

    As presented in:

        A. Benson, D. Gleich, and J. Demmel.
        Direct QR factorizations for tall-and-skinny matrices in
        MapReduce architectures.
        IEEE International Conference on Big Data, 2013.
        https://arxiv.org/abs/1301.1071

    Parameters
    ----------
    data: Array
        Two-dimensional array with a single column of blocks.
    compute_svd: bool
        Whether to compute the SVD rather than the QR decomposition
    _max_vchunk_size: Integer
        Accepted for signature compatibility; this implementation does not
        recurse and never reads the value.

    Returns
    -------
    q, r : Array, Array
        Q and R factors if compute_svd=False
    u, s, vh : Array, Array, Array
        SVD factors if compute_svd=True

    Raises
    ------
    ValueError
        If ``data`` is not two-dimensional or has more than one column of
        blocks.
    """
    from dask_array.core import asanyarray
    from dask_array.linalg._svd import _tsqr_svd

    data = asanyarray(data)
    expr = data.expr

    # Test the dimensionality first: ``chunks[1]`` does not exist for 1-D
    # input, and indexing it up front would surface as an IndexError instead
    # of the explanatory ValueError below.
    if not (expr.ndim == 2 and len(expr.chunks[1]) == 1):
        raise ValueError(
            "Input must have the following properties:\n"
            "  1. Have two dimensions\n"
            "  2. Have only one column of blocks\n\n"
            "Note: This function (tsqr) supports QR decomposition in the case of\n"
            "tall-and-skinny matrices (single column chunk/block; see qr)\n"
            f"Current shape: {data.shape},\nCurrent chunksize: {data.chunksize}"
        )

    # Stage 1: independent QR of every row block.
    qr_block = QRBlock(expr)
    q_blocks = QRBlockQ(qr_block)
    r_blocks = QRBlockR(qr_block)

    # Stage 2: stack the R factors and factorise them in one task.
    stacked_r = StackRFactors(r_blocks)
    incore_qr = InCoreQR(stacked_r)
    q_inner = InCoreQRQ(incore_qr, r_blocks.chunks[0])
    r_final = InCoreQRR(incore_qr)

    # Stage 3: cut the inner Q back into per-block pieces and combine each
    # piece with the matching stage-1 Q block.
    q_inner_unstacked = UnstackQInner(q_inner, r_blocks.chunks[0], expr)
    q_final = BlockDot(q_blocks, q_inner_unstacked)

    if not compute_svd:
        return new_collection(q_final), new_collection(r_final)
    return _tsqr_svd(q_final, r_final, expr)
|
|
426
|
+
|
|
427
|
+
|
|
428
|
+
def _qt_dot(qr_tuple, a):
|
|
429
|
+
"""Compute Q.T @ A where qr_tuple = (Q, R)."""
|
|
430
|
+
q, _ = qr_tuple
|
|
431
|
+
return np.dot(q.T, a)
|
|
432
|
+
|
|
433
|
+
|
|
434
|
+
class SFQR(ArrayExpr):
    """Short-and-Fat QR: factorise block ``(0, 0)`` of ``array``.

    The single output key holds the ``(Q, R)`` result of ``np.linalg.qr``
    on that block; ``SFQRGetQ`` and ``SFQRGetR`` build the final factors.
    """

    _parameters = ["array"]

    @functools.cached_property
    def _meta(self):
        # A 1x1 dummy factorisation reveals the dtypes of the two factors.
        probe = np.ones(shape=(1, 1), dtype=self.array.dtype)
        q_probe, r_probe = np.linalg.qr(probe)
        q_meta = meta_from_array(self.array._meta, ndim=2, dtype=q_probe.dtype)
        r_meta = meta_from_array(self.array._meta, ndim=2, dtype=r_probe.dtype)
        return (q_meta, r_meta)

    @functools.cached_property
    def chunks(self):
        # Reported chunks are those of the input expression.
        return self.array.chunks

    @functools.cached_property
    def _name(self):
        return f"sfqr-{self.deterministic_token}"

    def _layer(self):
        """A single ``np.linalg.qr`` task on the first input block."""
        key = (self._name, 0, 0)
        first_block = TaskRef((self.array._name, 0, 0))
        return {key: Task(key, np.linalg.qr, first_block)}
|
|
461
|
+
|
|
462
|
+
|
|
463
|
+
class SFQRGetQ(ArrayExpr):
    """Select the Q factor (item 0) of the ``SFQR`` result."""

    _parameters = ["sfqr"]

    @functools.cached_property
    def _meta(self):
        return self.sfqr._meta[0]

    @functools.cached_property
    def chunks(self):
        """One chunk of shape ``(m, min(m, n))`` for an ``m x n`` input."""
        source = self.sfqr.array
        n_rows = source.shape[0]
        n_cols = source.shape[1]
        return ((n_rows,), (min(n_rows, n_cols),))

    @functools.cached_property
    def _name(self):
        return f"sfqr-q-{self.deterministic_token}"

    def _layer(self):
        """A single ``getitem(pair, 0)`` task."""
        key = (self._name, 0, 0)
        pair_ref = TaskRef((self.sfqr._name, 0, 0))
        return {key: Task(key, operator.getitem, pair_ref, 0)}
|
|
488
|
+
|
|
489
|
+
|
|
490
|
+
class SFQRGetR(ArrayExpr):
    """Assemble the R factor of a short-and-fat QR, one column block each.

    Column block 0 is the R returned by the ``SFQR`` task itself. Every
    further column block ``j`` is ``Q.T @ A_j``, computed by ``_qt_dot`` from
    the same ``(Q, R)`` pair and input block ``(0, j)``.
    """

    _parameters = ["sfqr"]

    @functools.cached_property
    def _meta(self):
        return self.sfqr._meta[1]

    @functools.cached_property
    def chunks(self):
        """Height ``min(m, n)`` with the column chunks of the input."""
        source = self.sfqr.array
        n_rows = source.shape[0]
        n_cols = source.shape[1]
        return ((min(n_rows, n_cols),), self.sfqr.array.chunks[1])

    @functools.cached_property
    def _name(self):
        return f"sfqr-r-{self.deterministic_token}"

    def _layer(self):
        n_col_blocks = len(self.sfqr.array.chunks[1])
        pair_key = (self.sfqr._name, 0, 0)

        # First column block: R straight out of the factorisation.
        first_key = (self._name, 0, 0)
        layer = {first_key: Task(first_key, operator.getitem, TaskRef(pair_key), 1)}

        # Remaining column blocks: project each input block onto Q.
        for col in range(1, n_col_blocks):
            key = (self._name, 0, col)
            block_ref = TaskRef((self.sfqr.array._name, 0, col))
            layer[key] = Task(key, _qt_dot, TaskRef(pair_key), block_ref)

        return layer
|
|
527
|
+
|
|
528
|
+
|
|
529
|
+
def sfqr(data):
    """Direct Short-and-Fat QR.

    For matrices that are one chunk tall and wider than they are tall.

    Q [R_1 R_2 ...] = [A_1 A_2 ...]

    Parameters
    ----------
    data : Array
        Two-dimensional array with a single row of blocks. It must either
        have a single column of blocks, or a first column chunk at least as
        wide as the row chunk is tall.

    Returns
    -------
    q, r : Array, Array
        Q and R factors

    Raises
    ------
    ValueError
        If ``data`` does not satisfy the properties listed above.
    """
    from dask_array.core import asanyarray

    data = asanyarray(data)
    expr = data.expr

    # ``and`` short-circuits, so the chunk tuples are only indexed once the
    # input is known to be 2-D. Reading ``chunks[1]`` eagerly would turn 1-D
    # input into an IndexError instead of the explanatory ValueError.
    valid = (
        expr.ndim == 2
        and len(expr.chunks[0]) == 1
        and (expr.chunks[0][0] <= expr.chunks[1][0] or len(expr.chunks[1]) == 1)
    )
    if not valid:
        raise ValueError(
            "Input must have the following properties:\n"
            "  1. Have two dimensions\n"
            "  2. Have only one row of blocks\n"
            "  3. Either one column of blocks or chunk size on cols >= rows"
        )

    sfqr_expr = SFQR(expr)
    q = SFQRGetQ(sfqr_expr)
    r = SFQRGetR(sfqr_expr)

    return new_collection(q), new_collection(r)
|
|
557
|
+
|
|
558
|
+
|
|
559
|
+
@derived_from(np.linalg)
def qr(a, mode="reduced"):
    """Compute the qr factorization of a matrix.

    Dispatches to ``tsqr`` when ``a`` has one column of blocks and several
    rows of blocks, and to ``sfqr`` when it has a single row of blocks.

    Parameters
    ----------
    a : Array
        Input array
    mode : {'reduced', 'r', 'raw', 'complete'}
        Mode of factorization. Only 'reduced' is currently supported.

    Returns
    -------
    q, r : Array, Array
        Q and R factors

    Raises
    ------
    NotImplementedError
        If ``mode`` is not ``'reduced'``, or if ``a`` is chunked along both
        axes.
    ValueError
        If ``a`` is not two-dimensional.
    """
    from dask_array.core import asanyarray

    a = asanyarray(a)

    if mode != "reduced":
        raise NotImplementedError(f"qr mode '{mode}' is not implemented")

    if a.ndim != 2:
        raise ValueError("qr requires 2-D array")

    nr, nc = len(a.chunks[0]), len(a.chunks[1])

    if nc == 1 and nr > 1:
        return tsqr(a, compute_svd=False)
    if nr == 1:
        return sfqr(a)

    raise NotImplementedError(
        "qr currently supports only tall-and-skinny (single column chunk/block; see tsqr)\n"
        "and short-and-fat (single row chunk/block; see sfqr) matrices\n\n"
        "Consider use of the rechunk method. For example,\n\n"
        "x.rechunk({0: -1, 1: 'auto'}) or x.rechunk({0: 'auto', 1: -1})\n\n"
        "which rechunk one shorter axis to a single chunk, while allowing\n"
        "the other axis to automatically grow/shrink appropriately."
    )
|