dask-array 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (144) hide show
  1. dask_array/__init__.py +228 -0
  2. dask_array/_backends.py +76 -0
  3. dask_array/_backends_array.py +99 -0
  4. dask_array/_blockwise.py +1410 -0
  5. dask_array/_broadcast.py +272 -0
  6. dask_array/_chunk.py +445 -0
  7. dask_array/_chunk_types.py +54 -0
  8. dask_array/_collection.py +1644 -0
  9. dask_array/_concatenate.py +331 -0
  10. dask_array/_core_utils.py +1365 -0
  11. dask_array/_dispatch.py +141 -0
  12. dask_array/_einsum.py +277 -0
  13. dask_array/_expr.py +544 -0
  14. dask_array/_expr_flow.py +586 -0
  15. dask_array/_gufunc.py +805 -0
  16. dask_array/_histogram.py +617 -0
  17. dask_array/_map_blocks.py +652 -0
  18. dask_array/_new_collection.py +10 -0
  19. dask_array/_numpy_compat.py +135 -0
  20. dask_array/_overlap.py +1159 -0
  21. dask_array/_rechunk.py +1050 -0
  22. dask_array/_reshape.py +710 -0
  23. dask_array/_routines.py +102 -0
  24. dask_array/_shuffle.py +448 -0
  25. dask_array/_stack.py +264 -0
  26. dask_array/_svg.py +291 -0
  27. dask_array/_templates.py +29 -0
  28. dask_array/_test_utils.py +257 -0
  29. dask_array/_ufunc.py +385 -0
  30. dask_array/_utils.py +349 -0
  31. dask_array/_visualize.py +223 -0
  32. dask_array/_xarray.py +337 -0
  33. dask_array/core/__init__.py +34 -0
  34. dask_array/core/_blockwise_funcs.py +312 -0
  35. dask_array/core/_conversion.py +422 -0
  36. dask_array/core/_from_graph.py +97 -0
  37. dask_array/creation/__init__.py +71 -0
  38. dask_array/creation/_arange.py +121 -0
  39. dask_array/creation/_diag.py +116 -0
  40. dask_array/creation/_diagonal.py +241 -0
  41. dask_array/creation/_eye.py +103 -0
  42. dask_array/creation/_linspace.py +102 -0
  43. dask_array/creation/_mesh.py +134 -0
  44. dask_array/creation/_ones_zeros.py +454 -0
  45. dask_array/creation/_pad.py +270 -0
  46. dask_array/creation/_repeat.py +55 -0
  47. dask_array/creation/_tile.py +36 -0
  48. dask_array/creation/_tri.py +28 -0
  49. dask_array/creation/_utils.py +296 -0
  50. dask_array/fft.py +320 -0
  51. dask_array/io/__init__.py +39 -0
  52. dask_array/io/_base.py +10 -0
  53. dask_array/io/_from_array.py +257 -0
  54. dask_array/io/_from_delayed.py +95 -0
  55. dask_array/io/_from_graph.py +54 -0
  56. dask_array/io/_from_npy_stack.py +67 -0
  57. dask_array/io/_store.py +336 -0
  58. dask_array/io/_tiledb.py +159 -0
  59. dask_array/io/_to_npy_stack.py +65 -0
  60. dask_array/io/_zarr.py +449 -0
  61. dask_array/linalg/__init__.py +39 -0
  62. dask_array/linalg/_cholesky.py +234 -0
  63. dask_array/linalg/_lu.py +300 -0
  64. dask_array/linalg/_norm.py +94 -0
  65. dask_array/linalg/_qr.py +601 -0
  66. dask_array/linalg/_solve.py +349 -0
  67. dask_array/linalg/_svd.py +394 -0
  68. dask_array/linalg/_tensordot.py +334 -0
  69. dask_array/linalg/_utils.py +74 -0
  70. dask_array/manipulation/__init__.py +45 -0
  71. dask_array/manipulation/_expand.py +321 -0
  72. dask_array/manipulation/_flip.py +92 -0
  73. dask_array/manipulation/_roll.py +78 -0
  74. dask_array/manipulation/_transpose.py +309 -0
  75. dask_array/random/__init__.py +125 -0
  76. dask_array/random/_choice.py +181 -0
  77. dask_array/random/_expr.py +256 -0
  78. dask_array/random/_generator.py +441 -0
  79. dask_array/random/_random_state.py +259 -0
  80. dask_array/random/_utils.py +84 -0
  81. dask_array/reductions/__init__.py +84 -0
  82. dask_array/reductions/_arg_reduction.py +130 -0
  83. dask_array/reductions/_common.py +1082 -0
  84. dask_array/reductions/_cumulative.py +522 -0
  85. dask_array/reductions/_percentile.py +261 -0
  86. dask_array/reductions/_reduction.py +725 -0
  87. dask_array/reductions/_trace.py +56 -0
  88. dask_array/routines/__init__.py +133 -0
  89. dask_array/routines/_apply.py +84 -0
  90. dask_array/routines/_bincount.py +112 -0
  91. dask_array/routines/_broadcast.py +111 -0
  92. dask_array/routines/_coarsen.py +115 -0
  93. dask_array/routines/_diff.py +79 -0
  94. dask_array/routines/_gradient.py +158 -0
  95. dask_array/routines/_indexing.py +65 -0
  96. dask_array/routines/_insert_delete.py +132 -0
  97. dask_array/routines/_misc.py +122 -0
  98. dask_array/routines/_nonzero.py +72 -0
  99. dask_array/routines/_search.py +123 -0
  100. dask_array/routines/_select.py +113 -0
  101. dask_array/routines/_statistics.py +171 -0
  102. dask_array/routines/_topk.py +82 -0
  103. dask_array/routines/_triangular.py +74 -0
  104. dask_array/routines/_unique.py +232 -0
  105. dask_array/routines/_where.py +62 -0
  106. dask_array/slicing/__init__.py +67 -0
  107. dask_array/slicing/_basic.py +550 -0
  108. dask_array/slicing/_blocks.py +138 -0
  109. dask_array/slicing/_bool_index.py +145 -0
  110. dask_array/slicing/_setitem.py +329 -0
  111. dask_array/slicing/_squeeze.py +101 -0
  112. dask_array/slicing/_utils.py +1133 -0
  113. dask_array/slicing/_vindex.py +282 -0
  114. dask_array/stacking/__init__.py +15 -0
  115. dask_array/stacking/_block.py +83 -0
  116. dask_array/stacking/_simple.py +58 -0
  117. dask_array/templates/array.html.j2 +48 -0
  118. dask_array/tests/__init__.py +0 -0
  119. dask_array/tests/conftest.py +22 -0
  120. dask_array/tests/test_api.py +40 -0
  121. dask_array/tests/test_binary_op_chunks.py +107 -0
  122. dask_array/tests/test_coarse_slice_through_blockwise.py +362 -0
  123. dask_array/tests/test_collection.py +799 -0
  124. dask_array/tests/test_creation.py +1102 -0
  125. dask_array/tests/test_expr_flow.py +143 -0
  126. dask_array/tests/test_linalg.py +1130 -0
  127. dask_array/tests/test_map_blocks_multi_output.py +104 -0
  128. dask_array/tests/test_rechunk_pushdown.py +214 -0
  129. dask_array/tests/test_reductions.py +1091 -0
  130. dask_array/tests/test_routines.py +2853 -0
  131. dask_array/tests/test_shuffle_chunks.py +67 -0
  132. dask_array/tests/test_slice_pushdown.py +968 -0
  133. dask_array/tests/test_slice_through_blockwise.py +678 -0
  134. dask_array/tests/test_slice_through_overlap.py +366 -0
  135. dask_array/tests/test_slice_through_reshape.py +272 -0
  136. dask_array/tests/test_slicing.py +839 -0
  137. dask_array/tests/test_transpose_slice_pushdown.py +208 -0
  138. dask_array/tests/test_visualize.py +94 -0
  139. dask_array/tests/test_xarray.py +193 -0
  140. dask_array-0.1.0.dist-info/METADATA +48 -0
  141. dask_array-0.1.0.dist-info/RECORD +144 -0
  142. dask_array-0.1.0.dist-info/WHEEL +4 -0
  143. dask_array-0.1.0.dist-info/entry_points.txt +2 -0
  144. dask_array-0.1.0.dist-info/licenses/LICENSE +29 -0
@@ -0,0 +1,601 @@
1
+ """QR decomposition for array-expr.
2
+
3
+ Implements TSQR (Tall-and-Skinny QR) and SFQR (Short-and-Fat QR).
4
+ """
5
+
6
+ from __future__ import annotations
7
+
8
+ import functools
9
+ import operator
10
+
11
+ import numpy as np
12
+
13
+ from dask_array._new_collection import new_collection
14
+ from dask._task_spec import List, Task, TaskRef
15
+ from dask_array._expr import ArrayExpr
16
+ from dask_array.linalg._utils import (
17
+ _cumsum_blocks,
18
+ _cumsum_part,
19
+ _get_block_size,
20
+ _get_n,
21
+ _getitem_with_slice,
22
+ _has_uncertain_chunks,
23
+ _make_slice,
24
+ _nanmin,
25
+ )
26
+ from dask_array._utils import meta_from_array
27
+ from dask.utils import derived_from
28
+
29
+
30
+ def _wrapped_qr(a):
31
+ """A wrapper for np.linalg.qr that handles arrays with 0 rows."""
32
+ if a.shape[0] == 0:
33
+ return np.zeros_like(a, shape=(0, 0)), np.zeros_like(a, shape=(0, a.shape[1]))
34
+ else:
35
+ return np.linalg.qr(a)
36
+
37
+
38
class QRBlock(ArrayExpr):
    """First TSQR stage: factor every row block of the input on its own.

    The value stored under each output key is whatever ``_wrapped_qr``
    returns for the matching input block, i.e. a ``(Q, R)`` pair rather than
    a single array.
    """

    _parameters = ["array"]

    @functools.cached_property
    def _meta(self):
        # Factor a 1x1 probe purely to learn which dtypes QR yields for the
        # input dtype; the probe's values are irrelevant.
        probe = np.ones(shape=(1, 1), dtype=self.array.dtype)
        q_probe, r_probe = np.linalg.qr(probe)
        q_meta = meta_from_array(self.array._meta, ndim=2, dtype=q_probe.dtype)
        r_meta = meta_from_array(self.array._meta, ndim=2, dtype=r_probe.dtype)
        return (q_meta, r_meta)

    @functools.cached_property
    def chunks(self):
        # The block layout is reported unchanged from the input.
        return self.array.chunks

    @functools.cached_property
    def _name(self):
        return f"qr-block-{self.deterministic_token}"

    def _layer(self):
        """Create one ``_wrapped_qr`` task per row block (column block 0 only)."""
        graph = {}
        for row in range(self.array.numblocks[0]):
            target = (self._name, row, 0)
            source = TaskRef((self.array.name, row, 0))
            graph[target] = Task(target, _wrapped_qr, source)
        return graph
71
+
72
+
73
class QRBlockQ(ArrayExpr):
    """Pick element 0 (the Q factor) out of every block-wise ``(Q, R)`` pair."""

    _parameters = ["qr_block"]

    @functools.cached_property
    def _meta(self):
        return self.qr_block._meta[0]

    @functools.cached_property
    def chunks(self):
        source = self.qr_block.array
        ncols = source.shape[1]
        row_chunks = source.chunks[0]
        if _has_uncertain_chunks(source.chunks):
            return (row_chunks, (np.nan,))
        # A single column chunk is reported: the smallest per-block value of
        # _nanmin(rows, ncols) (presumably a NaN-aware min -- see linalg._utils).
        narrowest = min(tuple(_nanmin(rows, ncols) for rows in row_chunks))
        return (row_chunks, (narrowest,))

    @functools.cached_property
    def _name(self):
        return f"qr-q-{self.deterministic_token}"

    def _layer(self):
        """Create one ``getitem(pair, 0)`` task per row block of the block QR."""
        graph = {}
        for row in range(self.qr_block.array.numblocks[0]):
            target = (self._name, row, 0)
            pair = TaskRef((self.qr_block._name, row, 0))
            graph[target] = Task(target, operator.getitem, pair, 0)
        return graph
104
+
105
+
106
class QRBlockR(ArrayExpr):
    """Pick element 1 (the R factor) out of every block-wise ``(Q, R)`` pair.

    With known chunks, block ``i`` is reported with
    ``_nanmin(m_i, n)`` rows and ``n`` columns, where ``m_i`` is the height of
    input block ``i`` and ``n`` the input's column count.
    """

    _parameters = ["qr_block"]

    @functools.cached_property
    def _meta(self):
        return self.qr_block._meta[1]

    @functools.cached_property
    def chunks(self):
        source = self.qr_block.array
        ncols = source.shape[1]
        if _has_uncertain_chunks(source.chunks):
            # Heights cannot be derived; keep one NaN entry per input block.
            heights = tuple(np.nan for _ in source.chunks[0])
        else:
            heights = tuple(_nanmin(rows, ncols) for rows in source.chunks[0])
        return (heights, (ncols,))

    @functools.cached_property
    def _name(self):
        return f"qr-r-{self.deterministic_token}"

    def _layer(self):
        """Create one ``getitem(pair, 1)`` task per row block of the block QR."""
        graph = {}
        for row in range(self.qr_block.array.numblocks[0]):
            target = (self._name, row, 0)
            pair = TaskRef((self.qr_block._name, row, 0))
            graph[target] = Task(target, operator.getitem, pair, 1)
        return graph
139
+
140
+
141
class StackRFactors(ArrayExpr):
    """Concatenate all per-block R factors vertically into one block.

    The result has a single chunk, produced by one ``np.vstack`` task over
    every row block of ``r_blocks``.
    """

    _parameters = ["r_blocks"]

    @functools.cached_property
    def _meta(self):
        return self.r_blocks._meta

    @functools.cached_property
    def chunks(self):
        r_chunks = self.r_blocks.chunks
        width = r_chunks[1][0]
        # The stacked height is the sum of the block heights, or NaN when
        # any chunk size is uncertain.
        height = np.nan if _has_uncertain_chunks(r_chunks) else sum(r_chunks[0])
        return ((height,), (width,))

    @functools.cached_property
    def _name(self):
        return f"stack-r-{self.deterministic_token}"

    def _layer(self):
        """Create the single ``np.vstack`` task over all R blocks."""
        nblocks = self.r_blocks.numblocks[0]
        target = (self._name, 0, 0)
        pieces = [TaskRef((self.r_blocks._name, row, 0)) for row in range(nblocks)]
        return {target: Task(target, np.vstack, List(*pieces))}
171
+
172
+
173
class InCoreQR(ArrayExpr):
    """QR factorization of the single stacked-R block, done in one task.

    The value stored under the output key is the ``(Q, R)`` result of
    ``np.linalg.qr``, not a single array.
    """

    _parameters = ["stacked_r"]

    @functools.cached_property
    def _meta(self):
        # A 1x1 probe is factored only to learn the output dtypes.
        probe = np.ones(shape=(1, 1), dtype=self.stacked_r.dtype)
        q_probe, r_probe = np.linalg.qr(probe)
        q_meta = meta_from_array(self.stacked_r._meta, ndim=2, dtype=q_probe.dtype)
        r_meta = meta_from_array(self.stacked_r._meta, ndim=2, dtype=r_probe.dtype)
        return (q_meta, r_meta)

    @functools.cached_property
    def chunks(self):
        # Reported unchanged from the stacked input.
        return self.stacked_r.chunks

    @functools.cached_property
    def _name(self):
        return f"qr-core-{self.deterministic_token}"

    def _layer(self):
        """Create the single ``np.linalg.qr`` task on block ``(0, 0)``."""
        target = (self._name, 0, 0)
        source = TaskRef((self.stacked_r._name, 0, 0))
        return {target: Task(target, np.linalg.qr, source)}
202
+
203
+
204
class InCoreQRQ(ArrayExpr):
    """Pick element 0 (Q) out of the in-core ``(Q, R)`` result.

    ``original_r_chunks`` holds the row-chunk sizes of the R blocks that were
    stacked; their sum is used as the height of this Q.
    """

    _parameters = ["incore_qr", "original_r_chunks"]

    @functools.cached_property
    def _meta(self):
        return self.incore_qr._meta[0]

    @functools.cached_property
    def chunks(self):
        ncols = self.incore_qr.stacked_r.shape[1]
        heights = self.original_r_chunks
        if any(np.isnan(h) for h in heights):
            # Neither dimension can be stated when a block height is unknown.
            return ((np.nan,), (np.nan,))
        stacked_height = sum(heights)
        return ((stacked_height,), (_nanmin(stacked_height, ncols),))

    @functools.cached_property
    def _name(self):
        return f"qr-core-q-{self.deterministic_token}"

    def _layer(self):
        """Create the single ``getitem(pair, 0)`` task."""
        target = (self._name, 0, 0)
        pair = TaskRef((self.incore_qr._name, 0, 0))
        return {target: Task(target, operator.getitem, pair, 0)}
231
+
232
+
233
class InCoreQRR(ArrayExpr):
    """Pick element 1 (R) out of the in-core ``(Q, R)`` result."""

    _parameters = ["incore_qr"]

    @functools.cached_property
    def _meta(self):
        return self.incore_qr._meta[1]

    @functools.cached_property
    def chunks(self):
        stacked = self.incore_qr.stacked_r
        nrows = stacked.shape[0]
        ncols = stacked.shape[1]
        # One chunk of _nanmin(rows, cols) rows by cols columns.
        return ((_nanmin(nrows, ncols),), (ncols,))

    @functools.cached_property
    def _name(self):
        return f"qr-core-r-{self.deterministic_token}"

    def _layer(self):
        """Create the single ``getitem(pair, 1)`` task."""
        target = (self._name, 0, 0)
        pair = TaskRef((self.incore_qr._name, 0, 0))
        return {target: Task(target, operator.getitem, pair, 1)}
258
+
259
+
260
class UnstackQInner(ArrayExpr):
    """Split the in-core Q back into row blocks sized like the R blocks.

    Parameters
    ----------
    q_inner
        Expression holding the single-block Q of the in-core QR.
    r_chunks
        Row-chunk sizes of the R blocks; one output block is produced per
        entry.
    data_expr
        Expression whose blocks ``(i, 0)`` are inspected at run time when
        ``r_chunks`` contains NaN.  Optional only when every size in
        ``r_chunks`` is known.

    When all sizes are known, the output blocks are static slices of
    ``q_inner``.  Otherwise the slice bounds are computed inside the graph by
    auxiliary tasks built from ``_get_n``, ``_get_block_size``,
    ``_cumsum_part`` and ``_make_slice``.
    """

    _parameters = ["q_inner", "r_chunks", "data_expr"]
    _defaults = {"data_expr": None}

    @functools.cached_property
    def _meta(self):
        return self.q_inner._meta

    @functools.cached_property
    def chunks(self):
        n = self.q_inner.chunks[1][0]
        return (self.r_chunks, (n,))

    @functools.cached_property
    def _name(self):
        return f"unstack-q-{self.deterministic_token}"

    def _layer(self):
        """Build the slicing tasks that cut ``q_inner`` into row blocks.

        Raises
        ------
        ValueError
            If ``r_chunks`` contains NaN but ``data_expr`` is None, since the
            run-time block sizes then have no source.
        """
        dsk = {}
        block_sizes = self.r_chunks
        numblocks = len(block_sizes)
        in_key = (self.q_inner._name, 0, 0)

        if not any(np.isnan(c) for c in block_sizes):
            # Static case: the row ranges are known while building the graph.
            n = self.q_inner.chunks[1][0]
            for i, (start, end) in enumerate(_cumsum_blocks(block_sizes)):
                out_key = (self._name, i, 0)
                slc = (slice(start, end), slice(0, n))
                dsk[out_key] = Task(out_key, operator.getitem, TaskRef(in_key), slc)
            return dsk

        if self.data_expr is None:
            # Fail with a clear message instead of an AttributeError on None.
            raise ValueError(
                "UnstackQInner requires data_expr when r_chunks contains "
                "unknown (NaN) sizes"
            )
        data_name = self.data_expr._name

        # Auxiliary task derived from the first data block (see _get_n).
        n_key = (self._name + "-n",)
        dsk[n_key] = Task(n_key, _get_n, TaskRef((data_name, 0, 0)))

        # One auxiliary task per data block (see _get_block_size).
        for i in range(numblocks):
            bs_key = (self._name + "-bs", i)
            dsk[bs_key] = Task(bs_key, _get_block_size, TaskRef((data_name, i, 0)))

        # Chain of "-cs" entries: the first is the literal pair [0, size_0];
        # each later one is _cumsum_part(previous entry, size_i).
        dsk[(self._name + "-cs", 0)] = List(0, TaskRef((self._name + "-bs", 0)))
        for i in range(1, numblocks):
            cs_key = (self._name + "-cs", i)
            cs_key_prev = (self._name + "-cs", i - 1)
            bs_key = (self._name + "-bs", i)
            dsk[cs_key] = Task(cs_key, _cumsum_part, TaskRef(cs_key_prev), TaskRef(bs_key))

        # Turn each "-cs" entry into a slice object and apply it to q_inner.
        for i in range(numblocks):
            cs_key = (self._name + "-cs", i)
            slice_key = (self._name + "-slice", i)
            dsk[slice_key] = Task(slice_key, _make_slice, TaskRef(cs_key), TaskRef(n_key))

            out_key = (self._name, i, 0)
            dsk[out_key] = Task(out_key, _getitem_with_slice, TaskRef(in_key), TaskRef(slice_key))

        return dsk
328
+
329
+
330
class BlockDot(ArrayExpr):
    """Multiply matching row blocks: ``q_blocks[i] @ q_inner_unstacked[i]``.

    In TSQR this yields the final Q, block by block, from the first-stage Q
    blocks and the unstacked in-core Q.
    """

    _parameters = ["q_blocks", "q_inner_unstacked"]

    @functools.cached_property
    def _meta(self):
        return self.q_blocks._meta

    @functools.cached_property
    def chunks(self):
        # Rows follow q_blocks; the single column chunk follows the inner Q.
        width = self.q_inner_unstacked.chunks[1][0]
        return (self.q_blocks.chunks[0], (width,))

    @functools.cached_property
    def _name(self):
        return f"block-dot-{self.deterministic_token}"

    def _layer(self):
        """Create one ``np.dot`` task per row block of ``q_blocks``."""
        graph = {}
        for row in range(len(self.q_blocks.chunks[0])):
            target = (self._name, row, 0)
            left = TaskRef((self.q_blocks._name, row, 0))
            right = TaskRef((self.q_inner_unstacked._name, row, 0))
            graph[target] = Task(target, np.dot, left, right)
        return graph
360
+
361
+
362
def tsqr(data, compute_svd=False, _max_vchunk_size=None):
    """Direct Tall-and-Skinny QR algorithm.

    As presented in:
        A. Benson, D. Gleich, and J. Demmel.
        Direct QR factorizations for tall-and-skinny matrices in
        MapReduce architectures.
        IEEE International Conference on Big Data, 2013.
        https://arxiv.org/abs/1301.1071

    Parameters
    ----------
    data: Array
        Two-dimensional input with exactly one column of blocks.
    compute_svd: bool
        Whether to compute the SVD rather than the QR decomposition
    _max_vchunk_size: Integer
        Accepted for signature compatibility; this implementation does not
        read it (the R factors are always stacked and factored in one task).

    Returns
    -------
    q, r : Array, Array
        Q and R factors if compute_svd=False
    u, s, vh : Array, Array, Array
        SVD factors if compute_svd=True

    Raises
    ------
    ValueError
        If ``data`` is not two-dimensional or has more than one column of
        blocks.
    """
    from dask_array.core import asanyarray

    data = asanyarray(data)
    expr = data.expr

    # Check the dimensionality before touching chunks[1]; otherwise a
    # non-2-D input fails with IndexError instead of this ValueError.
    if not (expr.ndim == 2 and len(expr.chunks[1]) == 1):
        raise ValueError(
            "Input must have the following properties:\n"
            " 1. Have two dimensions\n"
            " 2. Have only one column of blocks\n\n"
            "Note: This function (tsqr) supports QR decomposition in the case of\n"
            "tall-and-skinny matrices (single column chunk/block; see qr)\n"
            f"Current shape: {data.shape},\nCurrent chunksize: {data.chunksize}"
        )

    # Stage 1: independent QR of every row block.
    qr_block = QRBlock(expr)
    q_blocks = QRBlockQ(qr_block)
    r_blocks = QRBlockR(qr_block)

    # Stage 2: stack the R factors and factor the stack in a single task.
    stacked_r = StackRFactors(r_blocks)
    incore_qr = InCoreQR(stacked_r)
    q_inner = InCoreQRQ(incore_qr, r_blocks.chunks[0])
    r_final = InCoreQRR(incore_qr)

    # Stage 3: cut the inner Q back into blocks and multiply block-wise.
    q_inner_unstacked = UnstackQInner(q_inner, r_blocks.chunks[0], expr)
    q_final = BlockDot(q_blocks, q_inner_unstacked)

    if not compute_svd:
        return new_collection(q_final), new_collection(r_final)

    # Imported lazily: only the SVD path needs it.
    from dask_array.linalg._svd import _tsqr_svd

    return _tsqr_svd(q_final, r_final, expr)
426
+
427
+
428
+ def _qt_dot(qr_tuple, a):
429
+ """Compute Q.T @ A where qr_tuple = (Q, R)."""
430
+ q, _ = qr_tuple
431
+ return np.dot(q.T, a)
432
+
433
+
434
class SFQR(ArrayExpr):
    """Short-and-Fat QR: factor block ``(0, 0)`` of the input in one task.

    The value stored under the output key is the ``(Q, R)`` result of
    ``np.linalg.qr`` on that first block.
    """

    _parameters = ["array"]

    @functools.cached_property
    def _meta(self):
        # A 1x1 probe is factored only to learn the output dtypes.
        probe = np.ones(shape=(1, 1), dtype=self.array.dtype)
        q_probe, r_probe = np.linalg.qr(probe)
        q_meta = meta_from_array(self.array._meta, ndim=2, dtype=q_probe.dtype)
        r_meta = meta_from_array(self.array._meta, ndim=2, dtype=r_probe.dtype)
        return (q_meta, r_meta)

    @functools.cached_property
    def chunks(self):
        # Reported unchanged from the input.
        return self.array.chunks

    @functools.cached_property
    def _name(self):
        return f"sfqr-{self.deterministic_token}"

    def _layer(self):
        """Create the single ``np.linalg.qr`` task on input block ``(0, 0)``."""
        target = (self._name, 0, 0)
        source = TaskRef((self.array._name, 0, 0))
        return {target: Task(target, np.linalg.qr, source)}
461
+
462
+
463
class SFQRGetQ(ArrayExpr):
    """Pick element 0 (Q) out of the ``(Q, R)`` pair produced by ``SFQR``."""

    _parameters = ["sfqr"]

    @functools.cached_property
    def _meta(self):
        return self.sfqr._meta[0]

    @functools.cached_property
    def chunks(self):
        source = self.sfqr.array
        nrows = source.shape[0]
        ncols = source.shape[1]
        # One chunk: all input rows by min(rows, cols) columns.
        return ((nrows,), (min(nrows, ncols),))

    @functools.cached_property
    def _name(self):
        return f"sfqr-q-{self.deterministic_token}"

    def _layer(self):
        """Create the single ``getitem(pair, 0)`` task."""
        target = (self._name, 0, 0)
        pair = TaskRef((self.sfqr._name, 0, 0))
        return {target: Task(target, operator.getitem, pair, 0)}
488
+
489
+
490
class SFQRGetR(ArrayExpr):
    """Assemble the R blocks of a short-and-fat QR.

    Column block 0 is element 1 of the ``SFQR`` result.  Every other column
    block ``k`` is computed as ``Q.T @ A_k`` via ``_qt_dot``.
    """

    _parameters = ["sfqr"]

    @functools.cached_property
    def _meta(self):
        return self.sfqr._meta[1]

    @functools.cached_property
    def chunks(self):
        source = self.sfqr.array
        nrows = source.shape[0]
        ncols = source.shape[1]
        # One row chunk of min(rows, cols); column chunks follow the input.
        return ((min(nrows, ncols),), source.chunks[1])

    @functools.cached_property
    def _name(self):
        return f"sfqr-r-{self.deterministic_token}"

    def _layer(self):
        """Create the R block tasks, one per column block of the input."""
        ncolblocks = len(self.sfqr.array.chunks[1])
        factored = (self.sfqr._name, 0, 0)

        first = (self._name, 0, 0)
        graph = {first: Task(first, operator.getitem, TaskRef(factored), 1)}

        for col in range(1, ncolblocks):
            target = (self._name, 0, col)
            block = TaskRef((self.sfqr.array._name, 0, col))
            graph[target] = Task(target, _qt_dot, TaskRef(factored), block)

        return graph
527
+
528
+
529
def sfqr(data):
    """Direct Short-and-Fat QR.

    For matrices that are one chunk tall and wider than they are tall.

    Q [R_1 R_2 ...] = [A_1 A_2 ...]

    Parameters
    ----------
    data : Array
        Two-dimensional input with a single row of blocks.  It must either
        have a single column of blocks, or a first column chunk at least as
        wide as the row chunk is tall.

    Returns
    -------
    q, r : Array, Array
        Q and R factors.

    Raises
    ------
    ValueError
        If ``data`` does not satisfy the layout requirements above.
    """
    from dask_array.core import asanyarray

    data = asanyarray(data)
    expr = data.expr

    # Validate step by step so that chunks[1] / chunks[1][0] are only read
    # once they are known to exist; a non-2-D input must give ValueError,
    # not IndexError.
    valid = expr.ndim == 2 and len(expr.chunks[0]) == 1
    if valid:
        nc = len(expr.chunks[1])
        valid = nc == 1 or (nc > 1 and expr.chunks[0][0] <= expr.chunks[1][0])

    if not valid:
        raise ValueError(
            "Input must have the following properties:\n"
            " 1. Have two dimensions\n"
            " 2. Have only one row of blocks\n"
            " 3. Either one column of blocks or chunk size on cols >= rows"
        )

    sfqr_expr = SFQR(expr)
    q = SFQRGetQ(sfqr_expr)
    r = SFQRGetR(sfqr_expr)

    return new_collection(q), new_collection(r)
557
+
558
+
559
@derived_from(np.linalg)
def qr(a, mode="reduced"):
    """Compute the qr factorization of a matrix.

    Dispatches to ``tsqr`` when the input has one column of blocks and more
    than one row of blocks, and to ``sfqr`` when it has one row of blocks.

    Parameters
    ----------
    a : Array
        Input array
    mode : {'reduced', 'r', 'raw', 'complete'}
        Mode of factorization. Only 'reduced' is currently supported.

    Returns
    -------
    q, r : Array, Array
        Q and R factors

    Raises
    ------
    NotImplementedError
        If ``mode`` is not 'reduced', or if the input has more than one
        block along both axes.
    ValueError
        If ``a`` is not two-dimensional.
    """
    from dask_array.core import asanyarray

    a = asanyarray(a)

    if mode != "reduced":
        raise NotImplementedError(f"qr mode '{mode}' is not implemented")

    if a.ndim != 2:
        raise ValueError("qr requires 2-D array")

    nr, nc = len(a.chunks[0]), len(a.chunks[1])

    if nc == 1 and nr > 1:
        return tsqr(a, compute_svd=False)
    if nr == 1:
        return sfqr(a)

    raise NotImplementedError(
        "qr currently supports only tall-and-skinny (single column chunk/block; see tsqr)\n"
        "and short-and-fat (single row chunk/block; see sfqr) matrices\n\n"
        "Consider use of the rechunk method. For example,\n\n"
        "x.rechunk({0: -1, 1: 'auto'}) or x.rechunk({0: 'auto', 1: -1})\n\n"
        "which rechunk one shorter axis to a single chunk, while allowing\n"
        "the other axis to automatically grow/shrink appropriately."
    )