dask-array 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (144) hide show
  1. dask_array/__init__.py +228 -0
  2. dask_array/_backends.py +76 -0
  3. dask_array/_backends_array.py +99 -0
  4. dask_array/_blockwise.py +1410 -0
  5. dask_array/_broadcast.py +272 -0
  6. dask_array/_chunk.py +445 -0
  7. dask_array/_chunk_types.py +54 -0
  8. dask_array/_collection.py +1644 -0
  9. dask_array/_concatenate.py +331 -0
  10. dask_array/_core_utils.py +1365 -0
  11. dask_array/_dispatch.py +141 -0
  12. dask_array/_einsum.py +277 -0
  13. dask_array/_expr.py +544 -0
  14. dask_array/_expr_flow.py +586 -0
  15. dask_array/_gufunc.py +805 -0
  16. dask_array/_histogram.py +617 -0
  17. dask_array/_map_blocks.py +652 -0
  18. dask_array/_new_collection.py +10 -0
  19. dask_array/_numpy_compat.py +135 -0
  20. dask_array/_overlap.py +1159 -0
  21. dask_array/_rechunk.py +1050 -0
  22. dask_array/_reshape.py +710 -0
  23. dask_array/_routines.py +102 -0
  24. dask_array/_shuffle.py +448 -0
  25. dask_array/_stack.py +264 -0
  26. dask_array/_svg.py +291 -0
  27. dask_array/_templates.py +29 -0
  28. dask_array/_test_utils.py +257 -0
  29. dask_array/_ufunc.py +385 -0
  30. dask_array/_utils.py +349 -0
  31. dask_array/_visualize.py +223 -0
  32. dask_array/_xarray.py +337 -0
  33. dask_array/core/__init__.py +34 -0
  34. dask_array/core/_blockwise_funcs.py +312 -0
  35. dask_array/core/_conversion.py +422 -0
  36. dask_array/core/_from_graph.py +97 -0
  37. dask_array/creation/__init__.py +71 -0
  38. dask_array/creation/_arange.py +121 -0
  39. dask_array/creation/_diag.py +116 -0
  40. dask_array/creation/_diagonal.py +241 -0
  41. dask_array/creation/_eye.py +103 -0
  42. dask_array/creation/_linspace.py +102 -0
  43. dask_array/creation/_mesh.py +134 -0
  44. dask_array/creation/_ones_zeros.py +454 -0
  45. dask_array/creation/_pad.py +270 -0
  46. dask_array/creation/_repeat.py +55 -0
  47. dask_array/creation/_tile.py +36 -0
  48. dask_array/creation/_tri.py +28 -0
  49. dask_array/creation/_utils.py +296 -0
  50. dask_array/fft.py +320 -0
  51. dask_array/io/__init__.py +39 -0
  52. dask_array/io/_base.py +10 -0
  53. dask_array/io/_from_array.py +257 -0
  54. dask_array/io/_from_delayed.py +95 -0
  55. dask_array/io/_from_graph.py +54 -0
  56. dask_array/io/_from_npy_stack.py +67 -0
  57. dask_array/io/_store.py +336 -0
  58. dask_array/io/_tiledb.py +159 -0
  59. dask_array/io/_to_npy_stack.py +65 -0
  60. dask_array/io/_zarr.py +449 -0
  61. dask_array/linalg/__init__.py +39 -0
  62. dask_array/linalg/_cholesky.py +234 -0
  63. dask_array/linalg/_lu.py +300 -0
  64. dask_array/linalg/_norm.py +94 -0
  65. dask_array/linalg/_qr.py +601 -0
  66. dask_array/linalg/_solve.py +349 -0
  67. dask_array/linalg/_svd.py +394 -0
  68. dask_array/linalg/_tensordot.py +334 -0
  69. dask_array/linalg/_utils.py +74 -0
  70. dask_array/manipulation/__init__.py +45 -0
  71. dask_array/manipulation/_expand.py +321 -0
  72. dask_array/manipulation/_flip.py +92 -0
  73. dask_array/manipulation/_roll.py +78 -0
  74. dask_array/manipulation/_transpose.py +309 -0
  75. dask_array/random/__init__.py +125 -0
  76. dask_array/random/_choice.py +181 -0
  77. dask_array/random/_expr.py +256 -0
  78. dask_array/random/_generator.py +441 -0
  79. dask_array/random/_random_state.py +259 -0
  80. dask_array/random/_utils.py +84 -0
  81. dask_array/reductions/__init__.py +84 -0
  82. dask_array/reductions/_arg_reduction.py +130 -0
  83. dask_array/reductions/_common.py +1082 -0
  84. dask_array/reductions/_cumulative.py +522 -0
  85. dask_array/reductions/_percentile.py +261 -0
  86. dask_array/reductions/_reduction.py +725 -0
  87. dask_array/reductions/_trace.py +56 -0
  88. dask_array/routines/__init__.py +133 -0
  89. dask_array/routines/_apply.py +84 -0
  90. dask_array/routines/_bincount.py +112 -0
  91. dask_array/routines/_broadcast.py +111 -0
  92. dask_array/routines/_coarsen.py +115 -0
  93. dask_array/routines/_diff.py +79 -0
  94. dask_array/routines/_gradient.py +158 -0
  95. dask_array/routines/_indexing.py +65 -0
  96. dask_array/routines/_insert_delete.py +132 -0
  97. dask_array/routines/_misc.py +122 -0
  98. dask_array/routines/_nonzero.py +72 -0
  99. dask_array/routines/_search.py +123 -0
  100. dask_array/routines/_select.py +113 -0
  101. dask_array/routines/_statistics.py +171 -0
  102. dask_array/routines/_topk.py +82 -0
  103. dask_array/routines/_triangular.py +74 -0
  104. dask_array/routines/_unique.py +232 -0
  105. dask_array/routines/_where.py +62 -0
  106. dask_array/slicing/__init__.py +67 -0
  107. dask_array/slicing/_basic.py +550 -0
  108. dask_array/slicing/_blocks.py +138 -0
  109. dask_array/slicing/_bool_index.py +145 -0
  110. dask_array/slicing/_setitem.py +329 -0
  111. dask_array/slicing/_squeeze.py +101 -0
  112. dask_array/slicing/_utils.py +1133 -0
  113. dask_array/slicing/_vindex.py +282 -0
  114. dask_array/stacking/__init__.py +15 -0
  115. dask_array/stacking/_block.py +83 -0
  116. dask_array/stacking/_simple.py +58 -0
  117. dask_array/templates/array.html.j2 +48 -0
  118. dask_array/tests/__init__.py +0 -0
  119. dask_array/tests/conftest.py +22 -0
  120. dask_array/tests/test_api.py +40 -0
  121. dask_array/tests/test_binary_op_chunks.py +107 -0
  122. dask_array/tests/test_coarse_slice_through_blockwise.py +362 -0
  123. dask_array/tests/test_collection.py +799 -0
  124. dask_array/tests/test_creation.py +1102 -0
  125. dask_array/tests/test_expr_flow.py +143 -0
  126. dask_array/tests/test_linalg.py +1130 -0
  127. dask_array/tests/test_map_blocks_multi_output.py +104 -0
  128. dask_array/tests/test_rechunk_pushdown.py +214 -0
  129. dask_array/tests/test_reductions.py +1091 -0
  130. dask_array/tests/test_routines.py +2853 -0
  131. dask_array/tests/test_shuffle_chunks.py +67 -0
  132. dask_array/tests/test_slice_pushdown.py +968 -0
  133. dask_array/tests/test_slice_through_blockwise.py +678 -0
  134. dask_array/tests/test_slice_through_overlap.py +366 -0
  135. dask_array/tests/test_slice_through_reshape.py +272 -0
  136. dask_array/tests/test_slicing.py +839 -0
  137. dask_array/tests/test_transpose_slice_pushdown.py +208 -0
  138. dask_array/tests/test_visualize.py +94 -0
  139. dask_array/tests/test_xarray.py +193 -0
  140. dask_array-0.1.0.dist-info/METADATA +48 -0
  141. dask_array-0.1.0.dist-info/RECORD +144 -0
  142. dask_array-0.1.0.dist-info/WHEEL +4 -0
  143. dask_array-0.1.0.dist-info/entry_points.txt +2 -0
  144. dask_array-0.1.0.dist-info/licenses/LICENSE +29 -0
@@ -0,0 +1,234 @@
1
+ """Cholesky decomposition for array-expr."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import functools
6
+ import operator
7
+
8
+ import numpy as np
9
+
10
+ from dask_array._new_collection import new_collection
11
+ from dask._task_spec import Alias, List, Task, TaskRef
12
+ from dask_array._expr import ArrayExpr
13
+ from dask_array.linalg._utils import _solve_triangular_lower
14
+ from dask_array._utils import meta_from_array
15
+
16
+
17
+ def _cholesky_lower(a):
18
+ """Compute Cholesky decomposition (lower triangular)."""
19
+ return np.linalg.cholesky(a)
20
+
21
+
22
+ def _zeros_like_shape(arr_meta, shape):
23
+ """Create zeros with specific shape, dtype from arr_meta."""
24
+ return np.zeros(shape, dtype=arr_meta.dtype)
25
+
26
+
27
class Cholesky(ArrayExpr):
    """Block Cholesky decomposition.

    Computes both lower and upper triangular factors. The layer stores the
    lower factor ``L`` under ``(self._name, i, j)`` and the upper factor
    ``U = L.T`` under ``("cholesky-upper-<token>", i, j)``.
    """

    _parameters = ["array"]

    @functools.cached_property
    def _meta(self):
        from dask_array._utils import array_safe

        # Factor a tiny SPD matrix of the input dtype to learn the output dtype.
        arr_meta = meta_from_array(self.array._meta)
        cho = np.linalg.cholesky(
            array_safe([[1, 2], [2, 5]], dtype=self.array.dtype, like=arr_meta)
        )
        return meta_from_array(self.array._meta, dtype=cho.dtype)

    @functools.cached_property
    def chunks(self):
        return self.array.chunks

    @functools.cached_property
    def _name(self):
        return f"cholesky-{self.deterministic_token}"

    def _layer(self):
        vdim = len(self.array.chunks[0])
        hdim = len(self.array.chunks[1])

        name = self._name
        # Blocks of the upper factor; CholeskyUpper aliases keys with this prefix.
        name_upper = f"cholesky-upper-{self.deterministic_token}"
        # (name_lt_dot, r, p, c, p) holds the product L[r, p] @ U[p, c].
        name_lt_dot = f"cholesky-lt-dot-{self.deterministic_token}"

        dsk = {}

        for j in range(hdim):
            for i in range(vdim):
                if i < j:
                    # Strictly upper part of L is zero; U mirrors it below the diagonal.
                    chunk_shape = (
                        self.array.chunks[0][i],
                        self.array.chunks[1][j],
                    )
                    # Use the *output* meta so zero blocks share the factor's dtype
                    # (e.g. a float factor for integer input), not the input's dtype.
                    dsk[(name, i, j)] = Task(
                        (name, i, j),
                        _zeros_like_shape,
                        self._meta,
                        chunk_shape,
                    )
                    dsk[(name_upper, j, i)] = TaskRef((name, i, j))
                elif i == j:
                    # Diagonal: L[i, i] = chol(A[i, i] - sum_p L[i, p] @ U[p, i]).
                    target = TaskRef((self.array._name, i, j))
                    if i > 0:
                        prevs = []
                        for p in range(i):
                            prev = (name_lt_dot, i, p, i, p)
                            dsk[prev] = Task(
                                prev,
                                np.dot,
                                TaskRef((name, i, p)),
                                TaskRef((name_upper, p, i)),
                            )
                            prevs.append(TaskRef(prev))
                        target = Task(
                            None, operator.sub, target, Task(None, sum, List(*prevs))
                        )
                    dsk[(name, i, i)] = Task((name, i, i), _cholesky_lower, target)
                    dsk[(name_upper, i, i)] = Task(
                        (name_upper, i, i), np.transpose, TaskRef((name, i, i))
                    )
                else:
                    # Below the diagonal: X @ L[j, j].T = A[i, j] - ... is solved in
                    # transposed form, L[j, j] @ X.T = A[j, i] - sum_p L[j, p] @ U[p, i],
                    # which yields U[j, i]; L[i, j] is then its transpose.
                    target = TaskRef((self.array._name, j, i))
                    if j > 0:
                        prevs = []
                        for p in range(j):
                            prev = (name_lt_dot, j, p, i, p)
                            dsk[prev] = Task(
                                prev,
                                np.dot,
                                TaskRef((name, j, p)),
                                TaskRef((name_upper, p, i)),
                            )
                            prevs.append(TaskRef(prev))
                        target = Task(
                            None, operator.sub, target, Task(None, sum, List(*prevs))
                        )
                    dsk[(name_upper, j, i)] = Task(
                        (name_upper, j, i),
                        _solve_triangular_lower,
                        TaskRef((name, j, j)),
                        target,
                    )
                    dsk[(name, i, j)] = Task(
                        (name, i, j), np.transpose, TaskRef((name_upper, j, i))
                    )

        return dsk
114
+
115
+
116
class CholeskyLower(ArrayExpr):
    """Array view of the lower-triangular factor ``L`` of a ``Cholesky`` expression.

    Performs no computation: every output block is an ``Alias`` pointing at
    the block with the same index that the parent ``Cholesky`` layer stores
    under the parent's own name.
    """

    _parameters = ["chol"]

    @functools.cached_property
    def _name(self):
        return f"cholesky-lower-{self.chol.deterministic_token}"

    @functools.cached_property
    def _meta(self):
        return self.chol._meta

    @functools.cached_property
    def chunks(self):
        return self.chol.chunks

    def _layer(self):
        source = self.chol._name
        n_row_blocks = len(self.chol.chunks[0])
        n_col_blocks = len(self.chol.chunks[1])
        return {
            (self._name, r, c): Alias((self._name, r, c), (source, r, c))
            for r in range(n_row_blocks)
            for c in range(n_col_blocks)
        }
150
+
151
+
152
class CholeskyUpper(ArrayExpr):
    """Array view of the upper-triangular factor ``U`` of a ``Cholesky`` expression.

    Performs no computation: every output block is an ``Alias`` pointing at
    the ``cholesky-upper-<token>`` blocks that the parent ``Cholesky`` layer
    builds alongside the lower factor.
    """

    _parameters = ["chol"]

    @functools.cached_property
    def _name(self):
        return f"cholesky-upper-view-{self.chol.deterministic_token}"

    @functools.cached_property
    def _meta(self):
        return self.chol._meta

    @functools.cached_property
    def chunks(self):
        return self.chol.chunks

    def _layer(self):
        # Must match the upper-factor key prefix used inside Cholesky._layer.
        source = f"cholesky-upper-{self.chol.deterministic_token}"
        n_row_blocks = len(self.chol.chunks[0])
        n_col_blocks = len(self.chol.chunks[1])
        return {
            (self._name, r, c): Alias((self._name, r, c), (source, r, c))
            for r in range(n_row_blocks)
            for c in range(n_col_blocks)
        }
186
+
187
+
188
def _cholesky(a):
    """Build the lazy lower and upper Cholesky factors of ``a``.

    Returns
    -------
    (lower, upper) : tuple of Array
        Two views (``CholeskyLower`` / ``CholeskyUpper``) over one shared
        ``Cholesky`` expression.

    Raises
    ------
    ValueError
        If ``a`` is not 2-D, is not square, or its chunks are not all the
        same size.
    """
    from dask_array.core import asanyarray

    arr = asanyarray(a)

    if arr.ndim != 2:
        raise ValueError("Dimension must be 2 to perform cholesky decomposition")

    n_rows, n_cols = arr.shape
    if n_rows != n_cols:
        raise ValueError("Input must be a square matrix to perform cholesky decomposition")

    # Every row and column chunk must share a single size so each block is square.
    distinct_chunk_sizes = set(arr.chunks[0]) | set(arr.chunks[1])
    if len(distinct_chunk_sizes) != 1:
        raise ValueError(
            "All chunks must be a square matrix to perform cholesky decomposition. "
            "Use .rechunk method to change the size of chunks."
        )

    decomposition = Cholesky(arr.expr)
    return (
        new_collection(CholeskyLower(decomposition)),
        new_collection(CholeskyUpper(decomposition)),
    )
212
+
213
+
214
def cholesky(a, lower=False):
    """Returns the Cholesky decomposition of a Hermitian positive-definite matrix.

    Parameters
    ----------
    a : (M, M) array_like
        Matrix to be decomposed
    lower : bool, optional
        Whether to compute the upper or lower triangular Cholesky
        factorization. Default is upper-triangular.

    Returns
    -------
    c : (M, M) Array
        Upper- or lower-triangular Cholesky factor of `a`.
    """
    lower_factor, upper_factor = _cholesky(a)
    return lower_factor if lower else upper_factor
@@ -0,0 +1,300 @@
1
+ """LU decomposition for array-expr."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import functools
6
+ import operator
7
+
8
+ import numpy as np
9
+
10
+ from dask_array._new_collection import new_collection
11
+ from dask._task_spec import List, Task, TaskRef
12
+ from dask_array._expr import ArrayExpr
13
+ from dask_array.linalg._utils import (
14
+ _solve_triangular_lower,
15
+ _transpose,
16
+ )
17
+ from dask_array._utils import meta_from_array
18
+
19
+
20
+ def _scipy_lu(a):
21
+ """Compute LU decomposition using scipy."""
22
+ import scipy.linalg
23
+
24
+ return scipy.linalg.lu(a)
25
+
26
+
27
class LU(ArrayExpr):
    """Block LU decomposition returning (P, L, U) tuple.

    Uses scipy.linalg.lu for diagonal blocks and propagates through
    off-diagonal blocks with forward/backward substitution. The layer stores
    the factors under the key prefixes ``lu-p-<token>``, ``lu-l-<token>`` and
    ``lu-u-<token>``; ``_meta`` is the tuple ``(p_meta, l_meta, u_meta)``.
    """

    _parameters = ["array"]

    @functools.cached_property
    def _meta(self):
        import scipy.linalg

        # Factor a 1x1 matrix of the input dtype to learn each factor's dtype.
        pp, ll, uu = scipy.linalg.lu(np.ones(shape=(1, 1), dtype=self.array.dtype))
        arr_meta = self.array._meta
        return (
            meta_from_array(arr_meta, ndim=2, dtype=pp.dtype),
            meta_from_array(arr_meta, ndim=2, dtype=ll.dtype),
            meta_from_array(arr_meta, ndim=2, dtype=uu.dtype),
        )

    @functools.cached_property
    def chunks(self):
        return self.array.chunks

    @functools.cached_property
    def _name(self):
        return f"lu-{self.deterministic_token}"

    def _layer(self):
        vdim = len(self.array.chunks[0])
        hdim = len(self.array.chunks[1])

        # Per-block (P, L, U) results and the extracted factor blocks.
        name_lu = f"lu-lu-{self.deterministic_token}"
        name_p = f"lu-p-{self.deterministic_token}"
        name_l = f"lu-l-{self.deterministic_token}"
        name_u = f"lu-u-{self.deterministic_token}"
        # Intermediates: P.T of diagonal blocks, P @ L, and U.T of diagonal blocks.
        name_transpose = f"lu-p-inv-{self.deterministic_token}"
        name_l_permuted = f"lu-l-permute-{self.deterministic_token}"
        name_transposed = f"lu-u-transpose-{self.deterministic_token}"
        name_plu_dot = f"lu-plu-dot-{self.deterministic_token}"
        name_lu_dot = f"lu-lu-dot-{self.deterministic_token}"

        # Zero blocks take each factor's dtype from ``_meta`` so all chunks of
        # P, L and U agree with the array dtype instead of defaulting to float64.
        p_dtype, l_dtype, u_dtype = (meta.dtype for meta in self._meta)

        dsk = {}

        for i in range(min(vdim, hdim)):
            # Diagonal block: factor A[i, i] minus contributions of earlier panels.
            target = TaskRef((self.array._name, i, i))
            if i > 0:
                prevs = []
                for p in range(i):
                    prev = (name_plu_dot, i, p, p, i)
                    dsk[prev] = Task(
                        prev,
                        np.dot,
                        TaskRef((name_l_permuted, i, p)),
                        TaskRef((name_u, p, i)),
                    )
                    prevs.append(TaskRef(prev))
                target = Task(None, operator.sub, target, Task(None, sum, List(*prevs)))
            dsk[(name_lu, i, i)] = Task((name_lu, i, i), _scipy_lu, target)

            # Sweep right along row i: solve
            # L[i, i] @ X = P[i, i].T @ A[i, j] - sum_p L[i, p] @ U[p, j]  (X = U[i, j]).
            for j in range(i + 1, hdim):
                target_h = Task(
                    None,
                    np.dot,
                    TaskRef((name_transpose, i, i)),
                    TaskRef((self.array._name, i, j)),
                )
                if i > 0:
                    prevs = []
                    for p in range(i):
                        prev = (name_lu_dot, i, p, p, j)
                        dsk[prev] = Task(
                            prev,
                            np.dot,
                            TaskRef((name_l, i, p)),
                            TaskRef((name_u, p, j)),
                        )
                        prevs.append(TaskRef(prev))
                    target_h = Task(
                        None, operator.sub, target_h, Task(None, sum, List(*prevs))
                    )
                dsk[(name_lu, i, j)] = Task(
                    (name_lu, i, j),
                    _solve_triangular_lower,
                    TaskRef((name_l, i, i)),
                    target_h,
                )

            # Sweep down column i: X @ U[i, i] = T is solved as U[i, i].T @ X.T = T.T,
            # where T = A[k, i] - sum_p (P @ L)[k, p] @ U[p, i].
            for k in range(i + 1, vdim):
                target_v = TaskRef((self.array._name, k, i))
                if i > 0:
                    prevs = []
                    for p in range(i):
                        prev = (name_plu_dot, k, p, p, i)
                        dsk[prev] = Task(
                            prev,
                            np.dot,
                            TaskRef((name_l_permuted, k, p)),
                            TaskRef((name_u, p, i)),
                        )
                        prevs.append(TaskRef(prev))
                    target_v = Task(
                        None, operator.sub, target_v, Task(None, sum, List(*prevs))
                    )
                dsk[(name_lu, k, i)] = Task(
                    (name_lu, k, i),
                    np.transpose,
                    Task(
                        None,
                        _solve_triangular_lower,
                        TaskRef((name_transposed, i, i)),
                        Task(None, np.transpose, target_v),
                    ),
                )

        for i in range(min(vdim, hdim)):
            for j in range(min(vdim, hdim)):
                if i == j:
                    # Unpack the scipy (P, L, U) tuple of the diagonal block.
                    dsk[(name_p, i, j)] = Task(
                        (name_p, i, j),
                        operator.getitem,
                        TaskRef((name_lu, i, j)),
                        0,
                    )
                    dsk[(name_l, i, j)] = Task(
                        (name_l, i, j),
                        operator.getitem,
                        TaskRef((name_lu, i, j)),
                        1,
                    )
                    dsk[(name_u, i, j)] = Task(
                        (name_u, i, j),
                        operator.getitem,
                        TaskRef((name_lu, i, j)),
                        2,
                    )
                    # Permuted L is propagated to the blocks below the diagonal.
                    dsk[(name_l_permuted, i, j)] = Task(
                        (name_l_permuted, i, j),
                        np.dot,
                        TaskRef((name_p, i, j)),
                        TaskRef((name_l, i, j)),
                    )
                    dsk[(name_transposed, i, j)] = Task(
                        (name_transposed, i, j),
                        _transpose,
                        TaskRef((name_u, i, j)),
                    )
                    # A permutation matrix's transpose is its inverse.
                    dsk[(name_transpose, i, j)] = Task(
                        (name_transpose, i, j),
                        _transpose,
                        TaskRef((name_p, i, j)),
                    )
                elif i > j:
                    chunk_shape = (
                        self.array.chunks[0][i],
                        self.array.chunks[1][j],
                    )
                    dsk[(name_p, i, j)] = Task(
                        (name_p, i, j), np.zeros, chunk_shape, p_dtype
                    )
                    # The sweep computed permuted L; undo it with the row block's P.T.
                    dsk[(name_l, i, j)] = Task(
                        (name_l, i, j),
                        np.dot,
                        TaskRef((name_transpose, i, i)),
                        TaskRef((name_lu, i, j)),
                    )
                    dsk[(name_u, i, j)] = Task(
                        (name_u, i, j), np.zeros, chunk_shape, u_dtype
                    )
                    dsk[(name_l_permuted, i, j)] = TaskRef((name_lu, i, j))
                else:
                    chunk_shape = (
                        self.array.chunks[0][i],
                        self.array.chunks[1][j],
                    )
                    dsk[(name_p, i, j)] = Task(
                        (name_p, i, j), np.zeros, chunk_shape, p_dtype
                    )
                    dsk[(name_l, i, j)] = Task(
                        (name_l, i, j), np.zeros, chunk_shape, l_dtype
                    )
                    # Permuted L is never referenced above the diagonal.
                    dsk[(name_u, i, j)] = TaskRef((name_lu, i, j))

        return dsk
200
+
201
+
202
class LUGetP(ArrayExpr):
    """Array view of the permutation factor ``P`` of an ``LU`` expression.

    Contributes no tasks of its own: ``_name`` deliberately equals the
    ``lu-p-<token>`` prefix under which the parent ``LU`` layer stores the
    ``P`` blocks, so this array's keys resolve to those tasks.
    """

    _parameters = ["lu"]

    @functools.cached_property
    def _name(self):
        return f"lu-p-{self.lu.deterministic_token}"

    @functools.cached_property
    def chunks(self):
        return self.lu.chunks

    @functools.cached_property
    def _meta(self):
        factor_metas = self.lu._meta
        return factor_metas[0]

    def _layer(self):
        # All blocks are produced by the parent LU layer.
        return dict()
221
+
222
+
223
class LUGetL(ArrayExpr):
    """Array view of the lower factor ``L`` of an ``LU`` expression.

    Contributes no tasks of its own: ``_name`` deliberately equals the
    ``lu-l-<token>`` prefix under which the parent ``LU`` layer stores the
    ``L`` blocks, so this array's keys resolve to those tasks.
    """

    _parameters = ["lu"]

    @functools.cached_property
    def _name(self):
        return f"lu-l-{self.lu.deterministic_token}"

    @functools.cached_property
    def chunks(self):
        return self.lu.chunks

    @functools.cached_property
    def _meta(self):
        factor_metas = self.lu._meta
        return factor_metas[1]

    def _layer(self):
        # All blocks are produced by the parent LU layer.
        return dict()
242
+
243
+
244
class LUGetU(ArrayExpr):
    """Array view of the upper factor ``U`` of an ``LU`` expression.

    Contributes no tasks of its own: ``_name`` deliberately equals the
    ``lu-u-<token>`` prefix under which the parent ``LU`` layer stores the
    ``U`` blocks, so this array's keys resolve to those tasks.
    """

    _parameters = ["lu"]

    @functools.cached_property
    def _name(self):
        return f"lu-u-{self.lu.deterministic_token}"

    @functools.cached_property
    def chunks(self):
        return self.lu.chunks

    @functools.cached_property
    def _meta(self):
        factor_metas = self.lu._meta
        return factor_metas[2]

    def _layer(self):
        # All blocks are produced by the parent LU layer.
        return dict()
263
+
264
+
265
def lu(a):
    """Compute the LU decomposition of a matrix.

    Examples
    --------
    >>> p, l, u = da.linalg.lu(x) # doctest: +SKIP

    Returns
    -------
    p : Array, permutation matrix
    l : Array, lower triangular matrix with unit diagonal.
    u : Array, upper triangular matrix

    Raises
    ------
    ValueError
        If ``a`` is not 2-D, is not square, or its chunks are not all the
        same size.
    """
    from dask_array.core import asanyarray

    arr = asanyarray(a)

    if arr.ndim != 2:
        raise ValueError("Dimension must be 2 to perform lu decomposition")

    n_rows, n_cols = arr.shape
    if n_rows != n_cols:
        raise ValueError("Input must be a square matrix to perform lu decomposition")

    # Every row and column chunk must share a single size so each block is square.
    if len(set(arr.chunks[0]) | set(arr.chunks[1])) != 1:
        raise ValueError(
            "All chunks must be a square matrix to perform lu decomposition. "
            "Use .rechunk method to change the size of chunks."
        )

    decomposition = LU(arr.expr)
    return tuple(
        new_collection(view(decomposition)) for view in (LUGetP, LUGetL, LUGetU)
    )
@@ -0,0 +1,94 @@
1
+ """Matrix and vector norms for array-expr."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from numbers import Number
6
+
7
+ import numpy as np
8
+
9
+ from dask.utils import derived_from
10
+
11
+
12
@derived_from(np.linalg)
def norm(x, ord=None, axis=None, keepdims=False):
    """Matrix or vector norm.

    This function uses array operations (abs, sum, max, min) which are
    already implemented in array-expr.

    As in NumPy, when both ``ord`` and ``axis`` are None the 2-norm of the
    flattened array is returned for input of any dimensionality; otherwise at
    most two axes may be reduced.
    """
    from dask_array.core import asanyarray
    from dask_array.linalg._svd import svd

    x = asanyarray(x)

    # ord=None with axis=None means "2-norm of the flattened array" for any ndim.
    flatten_all = axis is None and ord is None

    if axis is None:
        axis = tuple(range(x.ndim))
    elif isinstance(axis, Number):
        axis = (int(axis),)
    else:
        axis = tuple(axis)

    if len(axis) > 2 and not flatten_all:
        raise ValueError("Improper number of dimensions to norm.")

    if ord == "fro":
        ord = None
        if len(axis) == 1:
            raise ValueError("Invalid norm order for vectors.")

    r = abs(x)

    if ord is None:
        r = (r**2).sum(axis=axis, keepdims=keepdims) ** 0.5
    elif ord == "nuc":
        if len(axis) == 1:
            raise ValueError("Invalid norm order for vectors.")
        if x.ndim > 2:
            raise NotImplementedError("SVD based norm not implemented for ndim > 2")

        # [None] adds a leading axis so keepdims=True yields a (1, 1) result.
        r = svd(x)[1][None].sum(keepdims=keepdims)
    elif ord == np.inf:
        if len(axis) == 1:
            r = r.max(axis=axis, keepdims=keepdims)
        else:
            r = r.sum(axis=axis[1], keepdims=True).max(axis=axis[0], keepdims=True)
            # `not keepdims` (not `is False`) so np.False_ / 0 also squeeze.
            if not keepdims:
                r = r.squeeze(axis=axis)
    elif ord == -np.inf:
        if len(axis) == 1:
            r = r.min(axis=axis, keepdims=keepdims)
        else:
            r = r.sum(axis=axis[1], keepdims=True).min(axis=axis[0], keepdims=True)
            if not keepdims:
                r = r.squeeze(axis=axis)
    elif ord == 0:
        if len(axis) == 2:
            raise ValueError("Invalid norm order for matrices.")

        # Count of non-zero entries, in the (real) dtype of abs(x).
        r = (r != 0).astype(r.dtype).sum(axis=axis, keepdims=keepdims)
    elif ord == 1:
        if len(axis) == 1:
            r = r.sum(axis=axis, keepdims=keepdims)
        else:
            r = r.sum(axis=axis[0], keepdims=True).max(axis=axis[1], keepdims=True)
            if not keepdims:
                r = r.squeeze(axis=axis)
    elif len(axis) == 2 and ord == -1:
        r = r.sum(axis=axis[0], keepdims=True).min(axis=axis[1], keepdims=True)
        if not keepdims:
            r = r.squeeze(axis=axis)
    elif len(axis) == 2 and ord == 2:
        if x.ndim > 2:
            raise NotImplementedError("SVD based norm not implemented for ndim > 2")
        r = svd(x)[1][None].max(keepdims=keepdims)
    elif len(axis) == 2 and ord == -2:
        if x.ndim > 2:
            raise NotImplementedError("SVD based norm not implemented for ndim > 2")
        r = svd(x)[1][None].min(keepdims=keepdims)
    else:
        if len(axis) == 2:
            raise ValueError("Invalid norm order for matrices.")
        # Remaining string orders are not meaningful for vectors; fail clearly
        # instead of attempting `r ** "<str>"`.
        if isinstance(ord, str):
            raise ValueError(f"Invalid norm order {ord!r} for vectors.")

        r = (r**ord).sum(axis=axis, keepdims=keepdims) ** (1.0 / ord)

    return r
+ return r