dask-array 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (144)
  1. dask_array/__init__.py +228 -0
  2. dask_array/_backends.py +76 -0
  3. dask_array/_backends_array.py +99 -0
  4. dask_array/_blockwise.py +1410 -0
  5. dask_array/_broadcast.py +272 -0
  6. dask_array/_chunk.py +445 -0
  7. dask_array/_chunk_types.py +54 -0
  8. dask_array/_collection.py +1644 -0
  9. dask_array/_concatenate.py +331 -0
  10. dask_array/_core_utils.py +1365 -0
  11. dask_array/_dispatch.py +141 -0
  12. dask_array/_einsum.py +277 -0
  13. dask_array/_expr.py +544 -0
  14. dask_array/_expr_flow.py +586 -0
  15. dask_array/_gufunc.py +805 -0
  16. dask_array/_histogram.py +617 -0
  17. dask_array/_map_blocks.py +652 -0
  18. dask_array/_new_collection.py +10 -0
  19. dask_array/_numpy_compat.py +135 -0
  20. dask_array/_overlap.py +1159 -0
  21. dask_array/_rechunk.py +1050 -0
  22. dask_array/_reshape.py +710 -0
  23. dask_array/_routines.py +102 -0
  24. dask_array/_shuffle.py +448 -0
  25. dask_array/_stack.py +264 -0
  26. dask_array/_svg.py +291 -0
  27. dask_array/_templates.py +29 -0
  28. dask_array/_test_utils.py +257 -0
  29. dask_array/_ufunc.py +385 -0
  30. dask_array/_utils.py +349 -0
  31. dask_array/_visualize.py +223 -0
  32. dask_array/_xarray.py +337 -0
  33. dask_array/core/__init__.py +34 -0
  34. dask_array/core/_blockwise_funcs.py +312 -0
  35. dask_array/core/_conversion.py +422 -0
  36. dask_array/core/_from_graph.py +97 -0
  37. dask_array/creation/__init__.py +71 -0
  38. dask_array/creation/_arange.py +121 -0
  39. dask_array/creation/_diag.py +116 -0
  40. dask_array/creation/_diagonal.py +241 -0
  41. dask_array/creation/_eye.py +103 -0
  42. dask_array/creation/_linspace.py +102 -0
  43. dask_array/creation/_mesh.py +134 -0
  44. dask_array/creation/_ones_zeros.py +454 -0
  45. dask_array/creation/_pad.py +270 -0
  46. dask_array/creation/_repeat.py +55 -0
  47. dask_array/creation/_tile.py +36 -0
  48. dask_array/creation/_tri.py +28 -0
  49. dask_array/creation/_utils.py +296 -0
  50. dask_array/fft.py +320 -0
  51. dask_array/io/__init__.py +39 -0
  52. dask_array/io/_base.py +10 -0
  53. dask_array/io/_from_array.py +257 -0
  54. dask_array/io/_from_delayed.py +95 -0
  55. dask_array/io/_from_graph.py +54 -0
  56. dask_array/io/_from_npy_stack.py +67 -0
  57. dask_array/io/_store.py +336 -0
  58. dask_array/io/_tiledb.py +159 -0
  59. dask_array/io/_to_npy_stack.py +65 -0
  60. dask_array/io/_zarr.py +449 -0
  61. dask_array/linalg/__init__.py +39 -0
  62. dask_array/linalg/_cholesky.py +234 -0
  63. dask_array/linalg/_lu.py +300 -0
  64. dask_array/linalg/_norm.py +94 -0
  65. dask_array/linalg/_qr.py +601 -0
  66. dask_array/linalg/_solve.py +349 -0
  67. dask_array/linalg/_svd.py +394 -0
  68. dask_array/linalg/_tensordot.py +334 -0
  69. dask_array/linalg/_utils.py +74 -0
  70. dask_array/manipulation/__init__.py +45 -0
  71. dask_array/manipulation/_expand.py +321 -0
  72. dask_array/manipulation/_flip.py +92 -0
  73. dask_array/manipulation/_roll.py +78 -0
  74. dask_array/manipulation/_transpose.py +309 -0
  75. dask_array/random/__init__.py +125 -0
  76. dask_array/random/_choice.py +181 -0
  77. dask_array/random/_expr.py +256 -0
  78. dask_array/random/_generator.py +441 -0
  79. dask_array/random/_random_state.py +259 -0
  80. dask_array/random/_utils.py +84 -0
  81. dask_array/reductions/__init__.py +84 -0
  82. dask_array/reductions/_arg_reduction.py +130 -0
  83. dask_array/reductions/_common.py +1082 -0
  84. dask_array/reductions/_cumulative.py +522 -0
  85. dask_array/reductions/_percentile.py +261 -0
  86. dask_array/reductions/_reduction.py +725 -0
  87. dask_array/reductions/_trace.py +56 -0
  88. dask_array/routines/__init__.py +133 -0
  89. dask_array/routines/_apply.py +84 -0
  90. dask_array/routines/_bincount.py +112 -0
  91. dask_array/routines/_broadcast.py +111 -0
  92. dask_array/routines/_coarsen.py +115 -0
  93. dask_array/routines/_diff.py +79 -0
  94. dask_array/routines/_gradient.py +158 -0
  95. dask_array/routines/_indexing.py +65 -0
  96. dask_array/routines/_insert_delete.py +132 -0
  97. dask_array/routines/_misc.py +122 -0
  98. dask_array/routines/_nonzero.py +72 -0
  99. dask_array/routines/_search.py +123 -0
  100. dask_array/routines/_select.py +113 -0
  101. dask_array/routines/_statistics.py +171 -0
  102. dask_array/routines/_topk.py +82 -0
  103. dask_array/routines/_triangular.py +74 -0
  104. dask_array/routines/_unique.py +232 -0
  105. dask_array/routines/_where.py +62 -0
  106. dask_array/slicing/__init__.py +67 -0
  107. dask_array/slicing/_basic.py +550 -0
  108. dask_array/slicing/_blocks.py +138 -0
  109. dask_array/slicing/_bool_index.py +145 -0
  110. dask_array/slicing/_setitem.py +329 -0
  111. dask_array/slicing/_squeeze.py +101 -0
  112. dask_array/slicing/_utils.py +1133 -0
  113. dask_array/slicing/_vindex.py +282 -0
  114. dask_array/stacking/__init__.py +15 -0
  115. dask_array/stacking/_block.py +83 -0
  116. dask_array/stacking/_simple.py +58 -0
  117. dask_array/templates/array.html.j2 +48 -0
  118. dask_array/tests/__init__.py +0 -0
  119. dask_array/tests/conftest.py +22 -0
  120. dask_array/tests/test_api.py +40 -0
  121. dask_array/tests/test_binary_op_chunks.py +107 -0
  122. dask_array/tests/test_coarse_slice_through_blockwise.py +362 -0
  123. dask_array/tests/test_collection.py +799 -0
  124. dask_array/tests/test_creation.py +1102 -0
  125. dask_array/tests/test_expr_flow.py +143 -0
  126. dask_array/tests/test_linalg.py +1130 -0
  127. dask_array/tests/test_map_blocks_multi_output.py +104 -0
  128. dask_array/tests/test_rechunk_pushdown.py +214 -0
  129. dask_array/tests/test_reductions.py +1091 -0
  130. dask_array/tests/test_routines.py +2853 -0
  131. dask_array/tests/test_shuffle_chunks.py +67 -0
  132. dask_array/tests/test_slice_pushdown.py +968 -0
  133. dask_array/tests/test_slice_through_blockwise.py +678 -0
  134. dask_array/tests/test_slice_through_overlap.py +366 -0
  135. dask_array/tests/test_slice_through_reshape.py +272 -0
  136. dask_array/tests/test_slicing.py +839 -0
  137. dask_array/tests/test_transpose_slice_pushdown.py +208 -0
  138. dask_array/tests/test_visualize.py +94 -0
  139. dask_array/tests/test_xarray.py +193 -0
  140. dask_array-0.1.0.dist-info/METADATA +48 -0
  141. dask_array-0.1.0.dist-info/RECORD +144 -0
  142. dask_array-0.1.0.dist-info/WHEEL +4 -0
  143. dask_array-0.1.0.dist-info/entry_points.txt +2 -0
  144. dask_array-0.1.0.dist-info/licenses/LICENSE +29 -0
@@ -0,0 +1,349 @@
1
+ """Solve linear systems for array-expr."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import functools
6
+ import operator
7
+ import warnings
8
+
9
+ import numpy as np
10
+
11
+ from dask_array._new_collection import new_collection
12
+ from dask._task_spec import List, Task, TaskRef
13
+ from dask_array._expr import ArrayExpr
14
+ from dask_array.linalg._utils import (
15
+ _solve_triangular_lower,
16
+ _solve_triangular_upper,
17
+ )
18
+
19
+
20
class SolveTriangular(ArrayExpr):
    """Blocked triangular solve of ``a @ x = b`` for ``x``.

    Operands:

    * ``a`` -- the triangular coefficient matrix. Only its diagonal blocks
      ``(i, i)`` and the blocks on the substitution side of the diagonal are
      read.
    * ``b`` -- the 1-D or 2-D right-hand side.
    * ``lower`` -- when true, forward substitution using the lower triangle.
      Otherwise, backward substitution using the upper triangle.

    The result is chunked exactly like ``b``.
    """

    _parameters = ["a", "b", "lower"]
    _defaults = {"lower": False}

    @functools.cached_property
    def _meta(self):
        # Discover the result dtype by solving a tiny 2x2 system built on the
        # operands' meta arrays, then emit an empty meta with b's ndim.
        from dask_array._utils import array_safe, meta_from_array

        meta_a = meta_from_array(self.a._meta)
        meta_b = meta_from_array(self.b._meta)
        probe = _solve_triangular_lower(
            array_safe([[1, 0], [1, 2]], dtype=self.a.dtype, like=meta_a),
            array_safe([0, 1], dtype=self.b.dtype, like=meta_b),
        )
        return meta_from_array(self.a._meta, self.b.ndim, dtype=probe.dtype)

    @functools.cached_property
    def chunks(self):
        return self.b.chunks

    @functools.cached_property
    def _name(self):
        return f"solve-triangular-{self.deterministic_token}"

    def _layer(self):
        n_row_blocks = len(self.a.chunks[1])
        is_vector = self.b.ndim == 1
        n_col_blocks = 1 if is_vector else len(self.b.chunks[1])
        dot_name = f"solve-tri-dot-{self.deterministic_token}"

        def rhs_key(i, j):
            # Block (i, j) of b; 1-D b has only a row index.
            return (self.b._name, i) if is_vector else (self.b._name, i, j)

        def out_key(i, j):
            # Block (i, j) of the solution x; same indexing scheme as b.
            return (self._name, i) if is_vector else (self._name, i, j)

        if self.lower:
            # Forward substitution: block row i needs rows 0 .. i-1 solved.
            row_order = range(n_row_blocks)
            block_solver = _solve_triangular_lower

            def solved_rows(i):
                return range(i)

        else:
            # Backward substitution: block row i needs rows i+1 .. end solved.
            row_order = range(n_row_blocks - 1, -1, -1)
            block_solver = _solve_triangular_upper

            def solved_rows(i):
                return range(i + 1, n_row_blocks)

        graph = {}
        for i in row_order:
            for j in range(n_col_blocks):
                rhs = TaskRef(rhs_key(i, j))
                partials = []
                for k in solved_rows(i):
                    # (dot_name, i, k, k, j) holds a[i, k] @ x[k, j].
                    dot_key = (dot_name, i, k, k, j)
                    graph[dot_key] = Task(
                        dot_key,
                        np.dot,
                        TaskRef((self.a._name, i, k)),
                        TaskRef(out_key(k, j)),
                    )
                    partials.append(TaskRef(dot_key))
                if partials:
                    # Subtract the contribution of already-solved rows.
                    rhs = Task(
                        None,
                        operator.sub,
                        rhs,
                        Task(None, sum, List(*partials)),
                    )
                solved_key = out_key(i, j)
                graph[solved_key] = Task(
                    solved_key,
                    block_solver,
                    TaskRef((self.a._name, i, i)),
                    rhs,
                )
        return graph
123
+
124
+
125
def solve_triangular(a, b, lower=False):
    """Solve the equation `a x = b` for `x`, assuming a is a triangular matrix.

    Parameters
    ----------
    a : (M, M) array_like
        A triangular matrix
    b : (M,) or (M, N) array_like
        Right-hand side matrix in `a x = b`
    lower : bool, optional
        Use only data contained in the lower triangle of `a`.
        Default is to use upper triangle.

    Returns
    -------
    x : (M,) or (M, N) array
        Solution to the system `a x = b`. Shape of return matches `b`.

    Raises
    ------
    ValueError
        If `a` is not a square 2-D matrix, if `b` is not 1-D or 2-D, if the
        shapes of `a` and `b` do not conform, if the row and column chunks of
        `a` differ, or if the column chunks of `a` differ from the row chunks
        of `b`.
    """
    from dask_array.core import asanyarray

    a = asanyarray(a)
    b = asanyarray(b)

    if a.ndim != 2:
        raise ValueError("a must be 2 dimensional")
    # The blocked algorithm solves against diagonal blocks (i, i) of `a`, so
    # `a` must be square; otherwise the graph only fails at compute time.
    if a.shape[0] != a.shape[1]:
        raise ValueError("a must be a square matrix")
    if b.ndim <= 2:
        if a.shape[1] != b.shape[0]:
            raise ValueError("a.shape[1] and b.shape[0] must be equal")
        if a.chunks[1] != b.chunks[0]:
            msg = "a.chunks[1] and b.chunks[0] must be equal. Use .rechunk method to change the size of chunks."
            raise ValueError(msg)
    else:
        raise ValueError("b must be 1 or 2 dimensional")
    # Diagonal blocks are only square when row and column chunking agree.
    if a.chunks[0] != a.chunks[1]:
        msg = (
            "a.chunks[0] and a.chunks[1] must be equal so that the diagonal "
            "blocks of a are square. Use .rechunk method to change the size "
            "of chunks."
        )
        raise ValueError(msg)

    expr = SolveTriangular(a.expr, b.expr, lower)
    return new_collection(expr)
161
+
162
+
163
def solve(a, b, sym_pos=None, assume_a="gen"):
    """Solve the equation ``a x = b`` for ``x``.

    The matrix is factored into triangular pieces, and the system is then
    solved with a forward substitution followed by a backward substitution.
    With ``assume_a="gen"`` (default) an LU factorization is used. With
    ``assume_a="pos"`` a Cholesky factorization is used.

    Parameters
    ----------
    a : (M, M) array_like
        A square matrix.
    b : (M,) or (M, N) array_like
        Right-hand side matrix in ``a x = b``.
    sym_pos : bool, optional
        Assume a is symmetric and positive definite. If ``True``, use Cholesky
        decomposition.

        .. note::
            ``sym_pos`` is deprecated and will be removed in a future version.
            Use ``assume_a = 'pos'`` instead.

    assume_a : {'gen', 'pos'}, optional
        Type of data matrix. It is used to choose the dedicated solver.
        Note that Dask does not support 'her' and 'sym' types.

    Returns
    -------
    x : (M,) or (M, N) Array
        Solution to the system ``a x = b``. Shape of the return matches the
        shape of `b`.

    Raises
    ------
    ValueError
        If ``assume_a`` is neither ``'gen'`` nor ``'pos'``.

    See Also
    --------
    scipy.linalg.solve
    """
    from dask_array.core import asanyarray
    from dask_array.linalg._cholesky import _cholesky
    from dask_array.linalg._lu import lu

    # Any explicit sym_pos (even False) triggers the deprecation warning;
    # only a truthy one overrides assume_a.
    if sym_pos is not None:
        warnings.warn(
            "The sym_pos keyword is deprecated and should be replaced by using "
            "``assume_a = 'pos'``. ``sym_pos`` will be removed in a future version.",
            category=FutureWarning,
        )
    if sym_pos:
        assume_a = "pos"

    if assume_a == "pos":
        lower_factor, upper_factor = _cholesky(a)
    elif assume_a == "gen":
        perm, lower_factor, upper_factor = lu(a)
        # The right-hand side is permuted by perm.T before substitution,
        # presumably following scipy's convention a = perm @ L @ U.
        b = perm.T.dot(asanyarray(b))
    else:
        raise ValueError(
            f"{assume_a = } is not a recognized matrix structure, valid structures in Dask are 'pos' and 'gen'."
        )

    intermediate = solve_triangular(lower_factor, b, lower=True)
    return solve_triangular(upper_factor, intermediate)
223
+
224
+
225
def inv(a):
    """Compute the inverse of a matrix with LU decomposition.

    The inverse is obtained by solving ``a X = I`` with :func:`solve`. The
    identity is chunked using the size of the first row chunk of ``a``.

    Parameters
    ----------
    a : array_like
        Square matrix to be inverted.

    Returns
    -------
    ainv : Array
        Inverse of the matrix `a`.
    """
    from dask_array.core import asanyarray
    from dask_array.creation import eye

    matrix = asanyarray(a)
    size = matrix.shape[0]
    block = matrix.chunks[0][0]
    identity = eye(size, chunks=block)
    return solve(matrix, identity)
243
+
244
+
245
+ def _lstsq_singular(rt, r):
246
+ """Compute singular values from R'R eigenvalues."""
247
+ return np.sqrt(np.linalg.eigvalsh(np.dot(rt, r)))[::-1]
248
+
249
+
250
class LstsqRank(ArrayExpr):
    """0-d expression holding ``np.linalg.matrix_rank`` of the R factor.

    Only block ``(0, 0)`` of ``r`` is read. ``r`` is presumably expected to
    consist of a single chunk; TODO confirm against the qr implementation.
    """

    _parameters = ["r"]

    @functools.cached_property
    def _meta(self):
        # The rank is a scalar integer.
        return np.array(0, dtype=int)

    @functools.cached_property
    def chunks(self):
        # A 0-d result has no chunked dimensions.
        return ()

    @functools.cached_property
    def _name(self):
        return f"lstsq-rank-{self.deterministic_token}"

    def _layer(self):
        source = TaskRef((self.r._name, 0, 0))
        key = (self._name,)
        return {key: Task(key, np.linalg.matrix_rank, source)}
271
+
272
+
273
class LstsqSingular(ArrayExpr):
    """1-D expression holding the singular values derived from the R factor.

    Only block ``(0, 0)`` of ``r`` is read. ``r`` is presumably expected to
    consist of a single chunk; TODO confirm against the qr implementation.
    The output is a single chunk of length ``r.shape[0]``.
    """

    _parameters = ["r"]

    @functools.cached_property
    def _meta(self):
        # Singular values are real: a complex input dtype maps to its real
        # counterpart through np.finfo (e.g. complex128 -> float64).
        source_dtype = self.r.dtype
        is_complex = np.issubdtype(source_dtype, np.complexfloating)
        real_dtype = np.finfo(source_dtype).dtype if is_complex else source_dtype
        return np.empty((0,), dtype=real_dtype)

    @functools.cached_property
    def chunks(self):
        length = self.r.shape[0]
        return ((length,),)

    @functools.cached_property
    def _name(self):
        return f"lstsq-singular-{self.deterministic_token}"

    def _layer(self):
        r_block = (self.r._name, 0, 0)
        result_key = (self._name, 0)
        conj_t_key = (f"lstsq-rt-{self.deterministic_token}", 0, 0)
        # First task: conjugate transpose of r.
        # Second task: singular values from (r^H, r).
        graph = {}
        graph[conj_t_key] = Task(conj_t_key, lambda x: x.T.conj(), TaskRef(r_block))
        graph[result_key] = Task(
            result_key, _lstsq_singular, TaskRef(conj_t_key), TaskRef(r_block)
        )
        return graph
303
+
304
+
305
def lstsq(a, b):
    """Return the least-squares solution to a linear matrix equation using QR.

    Solves the equation `a x = b` by computing a vector `x` that
    minimizes the Euclidean 2-norm `|| b - a x ||^2`. The equation may
    be under-, well-, or over- determined (i.e., the number of
    linearly independent rows of `a` can be less than, equal to, or
    greater than its number of linearly independent columns). If `a`
    is square and of full rank, then `x` (but for round-off error) is
    the "exact" solution of the equation.

    Parameters
    ----------
    a : (M, N) array_like
        "Coefficient" matrix.
    b : {(M,), (M, K)} array_like
        Ordinate or "dependent variable" values.

    Returns
    -------
    x : {(N,), (N, K)} Array
        Least-squares solution.
    residuals : {(1,), (K,)} Array
        Sums of residuals; squared Euclidean 2-norm for each column in
        ``b - a*x``.
    rank : Array
        Rank of matrix `a`.
    s : (min(M, N),) Array
        Singular values of `a`.
    """
    from dask_array.core import asanyarray
    from dask_array.linalg._qr import qr

    a = asanyarray(a)
    b = asanyarray(b)

    q, r = qr(a)
    # With a = q r, the solution satisfies the triangular system r x = q^H b.
    projected = q.T.conj().dot(b)
    x = solve_triangular(r, projected)

    # Column-wise sum of squared residual magnitudes. For 1-D b the summed
    # axis is kept, so the result has shape (1,).
    misfit = b - a.dot(x)
    residuals = abs(misfit**2).sum(axis=0, keepdims=b.ndim == 1)

    rank_node = LstsqRank(r.expr)
    singular_node = LstsqSingular(r.expr)
    return x, residuals, new_collection(rank_node), new_collection(singular_node)