dask-array 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (144)
  1. dask_array/__init__.py +228 -0
  2. dask_array/_backends.py +76 -0
  3. dask_array/_backends_array.py +99 -0
  4. dask_array/_blockwise.py +1410 -0
  5. dask_array/_broadcast.py +272 -0
  6. dask_array/_chunk.py +445 -0
  7. dask_array/_chunk_types.py +54 -0
  8. dask_array/_collection.py +1644 -0
  9. dask_array/_concatenate.py +331 -0
  10. dask_array/_core_utils.py +1365 -0
  11. dask_array/_dispatch.py +141 -0
  12. dask_array/_einsum.py +277 -0
  13. dask_array/_expr.py +544 -0
  14. dask_array/_expr_flow.py +586 -0
  15. dask_array/_gufunc.py +805 -0
  16. dask_array/_histogram.py +617 -0
  17. dask_array/_map_blocks.py +652 -0
  18. dask_array/_new_collection.py +10 -0
  19. dask_array/_numpy_compat.py +135 -0
  20. dask_array/_overlap.py +1159 -0
  21. dask_array/_rechunk.py +1050 -0
  22. dask_array/_reshape.py +710 -0
  23. dask_array/_routines.py +102 -0
  24. dask_array/_shuffle.py +448 -0
  25. dask_array/_stack.py +264 -0
  26. dask_array/_svg.py +291 -0
  27. dask_array/_templates.py +29 -0
  28. dask_array/_test_utils.py +257 -0
  29. dask_array/_ufunc.py +385 -0
  30. dask_array/_utils.py +349 -0
  31. dask_array/_visualize.py +223 -0
  32. dask_array/_xarray.py +337 -0
  33. dask_array/core/__init__.py +34 -0
  34. dask_array/core/_blockwise_funcs.py +312 -0
  35. dask_array/core/_conversion.py +422 -0
  36. dask_array/core/_from_graph.py +97 -0
  37. dask_array/creation/__init__.py +71 -0
  38. dask_array/creation/_arange.py +121 -0
  39. dask_array/creation/_diag.py +116 -0
  40. dask_array/creation/_diagonal.py +241 -0
  41. dask_array/creation/_eye.py +103 -0
  42. dask_array/creation/_linspace.py +102 -0
  43. dask_array/creation/_mesh.py +134 -0
  44. dask_array/creation/_ones_zeros.py +454 -0
  45. dask_array/creation/_pad.py +270 -0
  46. dask_array/creation/_repeat.py +55 -0
  47. dask_array/creation/_tile.py +36 -0
  48. dask_array/creation/_tri.py +28 -0
  49. dask_array/creation/_utils.py +296 -0
  50. dask_array/fft.py +320 -0
  51. dask_array/io/__init__.py +39 -0
  52. dask_array/io/_base.py +10 -0
  53. dask_array/io/_from_array.py +257 -0
  54. dask_array/io/_from_delayed.py +95 -0
  55. dask_array/io/_from_graph.py +54 -0
  56. dask_array/io/_from_npy_stack.py +67 -0
  57. dask_array/io/_store.py +336 -0
  58. dask_array/io/_tiledb.py +159 -0
  59. dask_array/io/_to_npy_stack.py +65 -0
  60. dask_array/io/_zarr.py +449 -0
  61. dask_array/linalg/__init__.py +39 -0
  62. dask_array/linalg/_cholesky.py +234 -0
  63. dask_array/linalg/_lu.py +300 -0
  64. dask_array/linalg/_norm.py +94 -0
  65. dask_array/linalg/_qr.py +601 -0
  66. dask_array/linalg/_solve.py +349 -0
  67. dask_array/linalg/_svd.py +394 -0
  68. dask_array/linalg/_tensordot.py +334 -0
  69. dask_array/linalg/_utils.py +74 -0
  70. dask_array/manipulation/__init__.py +45 -0
  71. dask_array/manipulation/_expand.py +321 -0
  72. dask_array/manipulation/_flip.py +92 -0
  73. dask_array/manipulation/_roll.py +78 -0
  74. dask_array/manipulation/_transpose.py +309 -0
  75. dask_array/random/__init__.py +125 -0
  76. dask_array/random/_choice.py +181 -0
  77. dask_array/random/_expr.py +256 -0
  78. dask_array/random/_generator.py +441 -0
  79. dask_array/random/_random_state.py +259 -0
  80. dask_array/random/_utils.py +84 -0
  81. dask_array/reductions/__init__.py +84 -0
  82. dask_array/reductions/_arg_reduction.py +130 -0
  83. dask_array/reductions/_common.py +1082 -0
  84. dask_array/reductions/_cumulative.py +522 -0
  85. dask_array/reductions/_percentile.py +261 -0
  86. dask_array/reductions/_reduction.py +725 -0
  87. dask_array/reductions/_trace.py +56 -0
  88. dask_array/routines/__init__.py +133 -0
  89. dask_array/routines/_apply.py +84 -0
  90. dask_array/routines/_bincount.py +112 -0
  91. dask_array/routines/_broadcast.py +111 -0
  92. dask_array/routines/_coarsen.py +115 -0
  93. dask_array/routines/_diff.py +79 -0
  94. dask_array/routines/_gradient.py +158 -0
  95. dask_array/routines/_indexing.py +65 -0
  96. dask_array/routines/_insert_delete.py +132 -0
  97. dask_array/routines/_misc.py +122 -0
  98. dask_array/routines/_nonzero.py +72 -0
  99. dask_array/routines/_search.py +123 -0
  100. dask_array/routines/_select.py +113 -0
  101. dask_array/routines/_statistics.py +171 -0
  102. dask_array/routines/_topk.py +82 -0
  103. dask_array/routines/_triangular.py +74 -0
  104. dask_array/routines/_unique.py +232 -0
  105. dask_array/routines/_where.py +62 -0
  106. dask_array/slicing/__init__.py +67 -0
  107. dask_array/slicing/_basic.py +550 -0
  108. dask_array/slicing/_blocks.py +138 -0
  109. dask_array/slicing/_bool_index.py +145 -0
  110. dask_array/slicing/_setitem.py +329 -0
  111. dask_array/slicing/_squeeze.py +101 -0
  112. dask_array/slicing/_utils.py +1133 -0
  113. dask_array/slicing/_vindex.py +282 -0
  114. dask_array/stacking/__init__.py +15 -0
  115. dask_array/stacking/_block.py +83 -0
  116. dask_array/stacking/_simple.py +58 -0
  117. dask_array/templates/array.html.j2 +48 -0
  118. dask_array/tests/__init__.py +0 -0
  119. dask_array/tests/conftest.py +22 -0
  120. dask_array/tests/test_api.py +40 -0
  121. dask_array/tests/test_binary_op_chunks.py +107 -0
  122. dask_array/tests/test_coarse_slice_through_blockwise.py +362 -0
  123. dask_array/tests/test_collection.py +799 -0
  124. dask_array/tests/test_creation.py +1102 -0
  125. dask_array/tests/test_expr_flow.py +143 -0
  126. dask_array/tests/test_linalg.py +1130 -0
  127. dask_array/tests/test_map_blocks_multi_output.py +104 -0
  128. dask_array/tests/test_rechunk_pushdown.py +214 -0
  129. dask_array/tests/test_reductions.py +1091 -0
  130. dask_array/tests/test_routines.py +2853 -0
  131. dask_array/tests/test_shuffle_chunks.py +67 -0
  132. dask_array/tests/test_slice_pushdown.py +968 -0
  133. dask_array/tests/test_slice_through_blockwise.py +678 -0
  134. dask_array/tests/test_slice_through_overlap.py +366 -0
  135. dask_array/tests/test_slice_through_reshape.py +272 -0
  136. dask_array/tests/test_slicing.py +839 -0
  137. dask_array/tests/test_transpose_slice_pushdown.py +208 -0
  138. dask_array/tests/test_visualize.py +94 -0
  139. dask_array/tests/test_xarray.py +193 -0
  140. dask_array-0.1.0.dist-info/METADATA +48 -0
  141. dask_array-0.1.0.dist-info/RECORD +144 -0
  142. dask_array-0.1.0.dist-info/WHEEL +4 -0
  143. dask_array-0.1.0.dist-info/entry_points.txt +2 -0
  144. dask_array-0.1.0.dist-info/licenses/LICENSE +29 -0
@@ -0,0 +1,394 @@
1
+ """SVD decomposition for array-expr."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import functools
6
+ import operator
7
+
8
+ import numpy as np
9
+
10
+ from dask_array._new_collection import new_collection
11
+ from dask._task_spec import Task, TaskRef
12
+ from dask_array._expr import ArrayExpr
13
+ from dask_array.random import RandomState, default_rng
14
+ from dask_array._utils import meta_from_array, svd_flip
15
+ from dask.base import wait
16
+ from dask.utils import derived_from
17
+
18
+
19
class InCoreSVD(ArrayExpr):
    """Run ``np.linalg.svd`` on the single ``(0, 0)`` block of ``r``.

    The output is one task whose value is the ``(u, s, vh)`` tuple returned
    by ``np.linalg.svd``. It is called with NumPy's defaults, so
    ``full_matrices=True``. The individual factors are pulled out of that
    tuple by ``InCoreSVDU`` / ``InCoreSVDS`` / ``InCoreSVDVh``.
    """

    _parameters = ["r"]

    @functools.cached_property
    def _meta(self):
        # Decompose a tiny probe matrix to learn the dtypes NumPy produces
        # for each factor, then build matching metas (2-D, 1-D, 2-D).
        probe = np.ones(shape=(1, 1), dtype=self.r.dtype)
        u_probe, s_probe, vh_probe = np.linalg.svd(probe)
        base_meta = self.r._meta
        return tuple(
            meta_from_array(base_meta, ndim=rank, dtype=sample.dtype)
            for rank, sample in ((2, u_probe), (1, s_probe), (2, vh_probe))
        )

    @functools.cached_property
    def chunks(self):
        return self.r.chunks

    @functools.cached_property
    def _name(self):
        return f"svd-core-{self.deterministic_token}"

    def _layer(self):
        key = (self._name, 0, 0)
        r_block = TaskRef((self.r._name, 0, 0))
        return {key: Task(key, np.linalg.svd, r_block)}
46
+
47
+
48
class InCoreSVDU(ArrayExpr):
    """Select the ``u`` factor (tuple item 0) from an ``InCoreSVD`` task.

    The result is a single square block whose side is the row count of the
    decomposed ``r`` factor.
    """

    _parameters = ["incore_svd"]

    @functools.cached_property
    def _meta(self):
        factor_metas = self.incore_svd._meta
        return factor_metas[0]

    @functools.cached_property
    def chunks(self):
        # u is (rows, rows) for the r factor's row count.
        rows = self.incore_svd.r.shape[0]
        return ((rows,), (rows,))

    @functools.cached_property
    def _name(self):
        return f"svd-u-{self.deterministic_token}"

    def _layer(self):
        key = (self._name, 0, 0)
        svd_result = TaskRef((self.incore_svd._name, 0, 0))
        return {key: Task(key, operator.getitem, svd_result, 0)}
71
+
72
+
73
class InCoreSVDS(ArrayExpr):
    """Select the singular values ``s`` (tuple item 1) from an ``InCoreSVD`` task.

    The result is one 1-D block of length ``min(m, n)``, where ``(m, n)``
    is the shape of the decomposed ``r`` factor.
    """

    _parameters = ["incore_svd"]

    @functools.cached_property
    def _meta(self):
        factor_metas = self.incore_svd._meta
        return factor_metas[1]

    @functools.cached_property
    def chunks(self):
        r_shape = self.incore_svd.r.shape
        n_singular = min(r_shape[0], r_shape[1])
        return ((n_singular,),)

    @functools.cached_property
    def _name(self):
        return f"svd-s-{self.deterministic_token}"

    def _layer(self):
        # s is 1-D, so its only output key carries a single block index.
        key = (self._name, 0)
        svd_result = TaskRef((self.incore_svd._name, 0, 0))
        return {key: Task(key, operator.getitem, svd_result, 1)}
98
+
99
+
100
class InCoreSVDVh(ArrayExpr):
    """Select the ``vh`` factor (tuple item 2) from an ``InCoreSVD`` task.

    The result is a single square block whose side is the column count of
    the decomposed ``r`` factor.
    """

    _parameters = ["incore_svd"]

    @functools.cached_property
    def _meta(self):
        factor_metas = self.incore_svd._meta
        return factor_metas[2]

    @functools.cached_property
    def chunks(self):
        # vh is (cols, cols) for the r factor's column count.
        cols = self.incore_svd.r.shape[1]
        return ((cols,), (cols,))

    @functools.cached_property
    def _name(self):
        return f"svd-vh-{self.deterministic_token}"

    def _layer(self):
        key = (self._name, 0, 0)
        svd_result = TaskRef((self.incore_svd._name, 0, 0))
        return {key: Task(key, operator.getitem, svd_result, 2)}
123
+
124
+
125
class BlockMatMul(ArrayExpr):
    """Multiply every row block of ``q`` by the single block of ``u``.

    Each output block ``i`` is ``np.dot(q[i, 0], u[0, 0])``. The output
    keeps ``q``'s row chunking and takes its column chunking from ``u``.
    Only column block 0 of ``q`` is read.
    """

    _parameters = ["q", "u"]

    @functools.cached_property
    def _meta(self):
        return self.q._meta

    @functools.cached_property
    def chunks(self):
        row_chunks = self.q.chunks[0]
        col_chunks = self.u.chunks[1]
        return (row_chunks, col_chunks)

    @functools.cached_property
    def _name(self):
        return f"block-matmul-{self.deterministic_token}"

    def _layer(self):
        u_block_key = (self.u._name, 0, 0)
        q_name = self.q._name
        layer = {}
        for row in range(len(self.q.chunks[0])):
            key = (self._name, row, 0)
            layer[key] = Task(
                key, np.dot, TaskRef((q_name, row, 0)), TaskRef(u_block_key)
            )
        return layer
155
+
156
+
157
def _tsqr_svd(q_expr, r_expr, data_expr):
    """Turn a TSQR factorisation ``(Q, R)`` into SVD factors.

    The single-block ``R`` is decomposed in core, and ``U`` is recovered
    as ``Q @ U_R``.

    Parameters
    ----------
    q_expr, r_expr : ArrayExpr
        The Q and R expressions of the TSQR factorisation.
    data_expr : ArrayExpr
        Accepted for interface compatibility; not used here.

    Returns
    -------
    tuple of collections
        ``(u, s, vh)``.
    """
    svd_of_r = InCoreSVD(r_expr)
    u_of_r = InCoreSVDU(svd_of_r)
    singular_values = InCoreSVDS(svd_of_r)
    vh_of_r = InCoreSVDVh(svd_of_r)

    full_u = BlockMatMul(q_expr, u_of_r)

    return (
        new_collection(full_u),
        new_collection(singular_values),
        new_collection(vh_of_r),
    )
167
+
168
+
169
@derived_from(np.linalg)
def svd(a, full_matrices=False, compute_uv=True):
    """Singular Value Decomposition of a tall-and-skinny or short-and-fat array.

    Parameters
    ----------
    a : Array
        Input 2-D array, chunked along at most one dimension.
    full_matrices : bool
        Only ``False`` is supported when ``compute_uv=True``; the reduced
        factors with ``k = min(a.shape)`` are returned.
    compute_uv : bool
        When True, return ``U`` and ``Vh`` as well as ``S``.

    Returns
    -------
    u, s, vh : Array, Array, Array
        SVD factors when ``compute_uv=True``.
    s : Array
        Singular values when ``compute_uv=False``.

    Raises
    ------
    NotImplementedError
        For ``full_matrices=True`` with ``compute_uv=True``, or when ``a``
        is chunked along both dimensions.
    ValueError
        If ``a`` is not 2-D.
    """
    from dask_array.core import asanyarray
    from dask_array.linalg._qr import tsqr

    if compute_uv and full_matrices:
        raise NotImplementedError("full_matrices=True is not supported")

    a = asanyarray(a)
    if a.ndim != 2:
        raise ValueError("Array must be 2D")

    row_blocks = len(a.chunks[0])
    col_blocks = len(a.chunks[1])
    if row_blocks > 1 and col_blocks > 1:
        raise NotImplementedError(
            "Array must be chunked in one dimension only. "
            "This function (svd) only supports tall-and-skinny or short-and-fat "
            "matrices (see da.linalg.svd_compressed for SVD on fully chunked arrays).\n"
            f"Input shape: {a.shape}\n"
            f"Input numblocks: {(row_blocks, col_blocks)}\n"
        )

    if row_blocks >= col_blocks:
        u, s, v = tsqr(a, compute_svd=True)
        needs_truncation = a.shape[0] < a.shape[1]
    else:
        # Factor the transpose instead; U and Vh swap roles and are transposed back.
        vt, s, ut = tsqr(a.T, compute_svd=True)
        u, v = ut.T, vt.T
        needs_truncation = a.shape[0] > a.shape[1]

    if needs_truncation:
        # Trim the square factor down to the reduced rank k = min(m, n).
        k = min(a.shape)
        u, v = u[:, :k], v[:k, :]

    return (u, s, v) if compute_uv else s
227
+
228
+
229
def compression_level(n, q, n_oversamples=10, min_subspace_size=20):
    """Pick the subspace size used by ``svd_compressed``.

    The target is ``q + n_oversamples``, raised to at least
    ``min_subspace_size`` and then capped at ``n``.

    Parameters
    ----------
    n: int
        Column/row dimension of original matrix
    q: int
        Size of the desired subspace
    n_oversamples: int, default=10
        Number of oversamples used for generating the sampling matrix.
    min_subspace_size : int, default=20
        Minimum subspace size.

    Returns
    -------
    int
        Compression level
    """
    target = max(min_subspace_size, q + n_oversamples)
    return min(target, n)
252
+
253
+
254
def compression_matrix(
    data,
    q,
    iterator="power",
    n_power_iter=0,
    n_oversamples=10,
    seed=None,
    compute=False,
):
    """Randomly sample ``data`` to find its most active subspace.

    Parameters
    ----------
    data: Array
        2-D input array.
    q: int
        Size of the desired subspace
    iterator: {'power', 'QR'}, default='power'
        Define the technique used for iterations
    n_power_iter: int
        Number of power iterations
    n_oversamples: int, default=10
        Number of oversamples
    seed: RandomState or seed accepted by ``default_rng``, optional
        Source of the random test matrix. A ``RandomState`` is used as-is;
        anything else is passed to ``default_rng``.
    compute : bool
        When True, persist (and wait for) intermediates at each iteration.

    Returns
    -------
    Array
        Compression matrix (transpose of an orthonormal basis).

    Raises
    ------
    ValueError
        If ``iterator`` is not ``'power'`` or ``'QR'``.
    """
    from dask_array.core import asanyarray
    from dask_array.linalg._qr import tsqr

    def _materialize(x):
        # Optionally pin an intermediate in memory before it is reused.
        if compute:
            x = x.persist()
            wait(x)
        return x

    data = asanyarray(data)

    if iterator not in ["power", "QR"]:
        raise ValueError(f"Iterator '{iterator}' not valid, must be one of ['power', 'QR']")
    m, n = data.shape
    comp_level = compression_level(min(m, n), q, n_oversamples=n_oversamples)
    state = seed if isinstance(seed, RandomState) else default_rng(seed)
    # Keep single-precision inputs in single precision for the test matrix.
    single_precision = (data.dtype).type in {np.float32, np.complex64}
    datatype = np.float32 if single_precision else np.float64
    omega = state.standard_normal(
        size=(n, comp_level), chunks=(data.chunks[1], (comp_level,))
    ).astype(datatype, copy=False)
    mat_h = data.dot(omega)

    if iterator == "power":
        for _ in range(n_power_iter):
            tmp = data.T.dot(_materialize(mat_h))
            mat_h = data.dot(_materialize(tmp))
        q_mat, _r = tsqr(mat_h)
        return q_mat.T

    # "QR" iterator: re-orthonormalise after every multiplication.
    q_mat, _r = tsqr(mat_h)
    for _ in range(n_power_iter):
        q_mat, _r = tsqr(data.T.dot(_materialize(q_mat)))
        q_mat, _r = tsqr(data.dot(_materialize(q_mat)))
    return q_mat.T
327
+
328
+
329
def svd_compressed(
    a,
    k,
    iterator="power",
    n_power_iter=0,
    n_oversamples=10,
    seed=None,
    compute=False,
    coerce_signs=True,
):
    """Randomly compressed rank-k thin Singular Value Decomposition.

    This computes an approximate SVD of a large array by projecting it
    onto a randomly sampled subspace (see ``compression_matrix``). It is
    generally faster than the exact algorithm but does not give exact
    results.

    Parameters
    ----------
    a: Array
        Input array
    k: int
        Rank of the desired thin SVD decomposition.
    iterator: {'power', 'QR'}, default='power'
        Define the technique used for iterations
    n_power_iter: int, default=0
        Number of power iterations
    n_oversamples: int, default=10
        Number of oversamples
    seed: RandomState or seed accepted by ``default_rng``, optional
        Passed through to ``compression_matrix``.
    compute : bool
        Whether or not to compute data at each use
    coerce_signs : bool
        Whether or not to apply sign coercion (``svd_flip``) to singular vectors

    Returns
    -------
    u: Array, unitary / orthogonal
    s: Array, singular values in decreasing order
    v: Array, unitary / orthogonal
    """
    from dask_array.core import asanyarray
    from dask_array.linalg._qr import tsqr

    a = asanyarray(a)

    comp = compression_matrix(
        a,
        k,
        iterator=iterator,
        n_power_iter=n_power_iter,
        n_oversamples=n_oversamples,
        seed=seed,
        compute=compute,
    )
    if compute:
        comp = comp.persist()
        wait(comp)

    # SVD of the small projected matrix, then lift U back to a's row space.
    v_full, s_full, ut = tsqr(comp.dot(a).T, compute_svd=True)
    u = comp.T.dot(ut.T)[:, :k]
    s = s_full[:k]
    v = v_full.T[:k, :]

    if coerce_signs:
        u, v = svd_flip(u, v)
    return u, s, v
@@ -0,0 +1,334 @@
1
+ """Tensor operations for array-expr (tensordot, dot, vdot, matmul)."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from collections.abc import Iterable
6
+ from numbers import Integral
7
+
8
+ import numpy as np
9
+
10
+ from dask_array._core_utils import is_scalar_for_elemwise, tensordot_lookup
11
+ from dask.utils import derived_from
12
+
13
+
14
def _result_type(*args):
    """Return the dtype ``np.result_type`` gives for ``args``.

    Arguments judged scalar by ``is_scalar_for_elemwise`` are passed by
    value; every other argument contributes its ``.dtype``.
    """
    operands = []
    for arg in args:
        operands.append(arg if is_scalar_for_elemwise(arg) else arg.dtype)
    return np.result_type(*operands)
18
+
19
+
20
def _tensordot(a, b, axes, is_sparse):
    """Contract one pair of blocks and restore the contracted axes of ``a``.

    The tensordot implementation is chosen via ``tensordot_lookup`` from the
    operand with the higher ``__array_priority__``. Unless the sparse
    single-axis fast path applies, a length-1 axis is inserted at every
    contracted position of ``a``. This lets blockwise keep those
    dimensions; they are summed away afterwards.

    Parameters
    ----------
    a, b : array-like blocks
    axes : pair of sequences of int
        ``(left_axes, right_axes)`` as accepted by ``np.tensordot``;
        negative values count from the end of ``a`` / ``b``.
    is_sparse : bool
        True when a scipy-sparse operand is involved.
    """
    # Normalise the contracted axes of ``a`` to non-negative positions before
    # ``a`` is shadowed. ``list.insert`` with a negative index counts from the
    # end of the partially built index list, which misplaces the placeholder
    # axes when more than one axis is contracted.
    left_positions = sorted(ax + a.ndim if ax < 0 else ax for ax in axes[0])

    x = max([a, b], key=lambda arr: arr.__array_priority__)
    tensordot = tensordot_lookup.dispatch(type(x))
    x = tensordot(a, b, axes=axes)
    if is_sparse and len(axes[0]) == 1:
        return x

    # Insert in ascending order so earlier insertions don't shift later ones.
    ind = [slice(None, None)] * x.ndim
    for pos in left_positions:
        ind.insert(pos, None)
    return x[tuple(ind)]
33
+
34
+
35
def _tensordot_is_sparse(x):
    """Return True when ``x._meta`` looks sparse but is not a pydata ``COO``.

    The check is purely textual, on ``str(type(x._meta))``.
    """
    meta_type_name = str(type(x._meta))
    if "sparse" not in meta_type_name:
        return False
    # pydata sparse arrays need no tensordot workaround, so they are excluded.
    return "sparse._coo.core.COO" not in meta_type_name
42
+
43
+
44
@derived_from(np)
def tensordot(lhs, rhs, axes=2):
    """Compute tensor dot product along specified axes.

    Parameters
    ----------
    lhs : array_like
        Left argument
    rhs : array_like
        Right argument
    axes : int or tuple of (int, int) or tuple of (sequence[int], sequence[int])
        If integer, sum over the last `axes` axes of `lhs` and first `axes`
        axes of `rhs`. If tuple, specifies axes to contract. Negative axes
        count from the end, as in NumPy.

    Returns
    -------
    output : dask array

    Raises
    ------
    ValueError
        If a different number of axes is given for ``lhs`` and ``rhs``.

    See Also
    --------
    numpy.tensordot
    """
    from dask_array._collection import Array, asarray, blockwise

    if not isinstance(lhs, Array):
        lhs = asarray(lhs)
    if not isinstance(rhs, Array):
        rhs = asarray(rhs)

    if isinstance(axes, Iterable):
        left_axes, right_axes = axes
    else:
        left_axes = tuple(range(lhs.ndim - axes, lhs.ndim))
        right_axes = tuple(range(0, axes))
    if isinstance(left_axes, Integral):
        left_axes = (left_axes,)
    if isinstance(right_axes, Integral):
        right_axes = (right_axes,)
    # Normalise to tuples of non-negative axes. The per-block helper
    # re-inserts contracted dimensions with ``list.insert``, which only lands
    # them correctly for non-negative positions; the final ``sum`` needs
    # non-negative axes as well.
    left_axes = tuple(ax + lhs.ndim if ax < 0 else ax for ax in left_axes)
    right_axes = tuple(ax + rhs.ndim if ax < 0 else ax for ax in right_axes)
    if len(left_axes) != len(right_axes):
        # ``zip`` below would otherwise silently drop the extra axes and the
        # error would only surface at compute time.
        raise ValueError("shape-mismatch for sum")

    is_sparse = _tensordot_is_sparse(lhs) or _tensordot_is_sparse(rhs)
    # Scipy-sparse single-axis contractions are done with concatenation;
    # everything else keeps size-1 contracted axes and sums afterwards.
    concatenate = is_sparse and len(left_axes) == 1

    dt = np.promote_types(lhs.dtype, rhs.dtype)

    left_index = list(range(lhs.ndim))
    right_index = list(range(lhs.ndim, lhs.ndim + rhs.ndim))
    out_index = left_index + right_index
    adjust_chunks = {}

    for l, r in zip(left_axes, right_axes):
        out_index.remove(right_index[r])
        right_index[r] = left_index[l]
        if concatenate:
            out_index.remove(left_index[l])
        else:
            adjust_chunks[left_index[l]] = lambda c: 1

    # Compute explicit meta to preserve masked array type
    # (compute_meta fails for masked arrays due to reshape issues with 0-dim arrays)
    meta = None
    for arr in (lhs, rhs):
        if hasattr(arr._meta, "mask"):  # MaskedArray check
            out_ndim = len(out_index)
            meta = np.ma.empty((0,) * out_ndim, dtype=dt)
            break

    intermediate = blockwise(
        _tensordot,
        out_index,
        lhs,
        left_index,
        rhs,
        right_index,
        dtype=dt,
        concatenate=concatenate,
        adjust_chunks=adjust_chunks,
        axes=(left_axes, right_axes),
        is_sparse=is_sparse,
        meta=meta,
    )

    if concatenate:
        return intermediate
    return intermediate.sum(axis=list(left_axes))
137
+
138
+
139
@derived_from(np, ua_args=["out"])
def dot(a, b):
    """Dot product of two arrays.

    For 2-D arrays it is equivalent to matrix multiplication,
    for 1-D arrays to inner product of vectors.

    Parameters
    ----------
    a : array_like
        First argument
    b : array_like
        Second argument

    Returns
    -------
    output : dask array

    See Also
    --------
    numpy.dot
    tensordot
    """
    from dask_array._collection import Array, asarray

    # Coerce array-likes up front (the same way ``tensordot`` does): ``.ndim``
    # is needed below to choose the contraction axes, and e.g. plain lists
    # have no ``.ndim``.
    if not isinstance(a, Array):
        a = asarray(a)
    if not isinstance(b, Array):
        b = asarray(b)
    # Contract the last axis of ``a`` with the second-to-last axis of ``b``
    # (for 1-D ``b``, ``b.ndim - 2 == -1`` refers to its only axis).
    return tensordot(a, b, axes=((a.ndim - 1,), (b.ndim - 2,)))
163
+
164
+
165
@derived_from(np)
def vdot(a, b):
    """Return the dot product of two vectors.

    Unlike ``dot``, the first argument is complex-conjugated, and both
    arguments are flattened before the product is taken.

    Parameters
    ----------
    a : array_like
        First argument
    b : array_like
        Second argument

    Returns
    -------
    output : dask array

    See Also
    --------
    numpy.vdot
    dot
    """
    from dask_array._collection import ravel

    conj_flat_a = ravel(a).conj()
    flat_b = ravel(b)
    return dot(conj_flat_a, flat_b)
192
+
193
+
194
def _matmul(a, b):
    """Matrix-multiply two blocks, keeping a length-1 contraction axis.

    ``cupy.matmul`` is used when cupy is importable and either operand
    exposes ``__cuda_array_interface__``; otherwise ``numpy.matmul``.
    """
    xp = np
    try:
        import cupy
    except ImportError:
        pass
    else:
        if hasattr(a, "__cuda_array_interface__") or hasattr(b, "__cuda_array_interface__"):
            xp = cupy

    product = xp.matmul(a, b)
    # The blockwise call in ``matmul`` expects every index back, including
    # the contracted one as a size-1 axis in the second-to-last position.
    return product[..., xp.newaxis, :]
214
+
215
+
216
def _sum_wo_cat(a, axis=None, dtype=None):
    """Sum ``a`` over ``axis`` without concatenating blocks (used by matmul).

    When ``a`` already has length 1 along ``axis``, that axis is just
    squeezed away. Otherwise ``reduction`` is run with ``concatenate=False``,
    adding the blocks element-wise.
    """
    from functools import partial, reduce

    from dask_array.reductions import reduction

    def _chunk_sum(a, axis=None, dtype=None, keepdims=None):
        # Not a general array sum. The preceding blockwise contraction
        # leaves every block with identical shape and length 1 along
        # ``axis``, so a list of blocks can be folded with element-wise
        # addition, and dropping the axis is a plain squeeze.
        total = reduce(partial(np.add, dtype=dtype), a) if type(a) is list else a
        return total if keepdims else total.squeeze(axis[0])

    if dtype is None:
        dtype = getattr(np.zeros(1, dtype=a.dtype).sum(), "dtype", object)

    if a.shape[axis] == 1:
        from dask_array._collection import squeeze

        return squeeze(a, axis=axis)

    return reduction(a, _chunk_sum, _chunk_sum, axis=axis, dtype=dtype, concatenate=False)
250
+
251
+
252
@derived_from(np)
def matmul(a, b):
    """Matrix product of two arrays.

    1-D operands are promoted to 2-D (and the added axis is removed from the
    result). Operands of different rank are left-padded with length-1 axes.
    The product is formed blockwise without contraction, then summed over
    the contracted axis.

    Parameters
    ----------
    a : array_like
        First argument
    b : array_like
        Second argument

    Returns
    -------
    output : dask array

    Raises
    ------
    ValueError
        If either operand is 0-dimensional.

    See Also
    --------
    numpy.matmul
    """
    from dask_array._collection import asanyarray, blockwise

    a = asanyarray(a)
    b = asanyarray(b)

    if a.ndim == 0 or b.ndim == 0:
        raise ValueError("`matmul` does not support scalars.")

    a_is_1d = a.ndim == 1
    if a_is_1d:
        a = a[np.newaxis, :]

    b_is_1d = b.ndim == 1
    if b_is_1d:
        b = b[:, np.newaxis]

    # Left-pad the lower-rank operand with new axes so both ranks match.
    rank_gap = b.ndim - a.ndim
    if rank_gap > 0:
        a = a[rank_gap * (np.newaxis,)]
    elif rank_gap < 0:
        b = b[-rank_gap * (np.newaxis,)]

    ndim = a.ndim
    contracted = ndim - 1
    # a: (batch..., rows, contracted); b: (batch..., contracted, cols).
    lhs_ind = tuple(range(ndim))
    rhs_ind = tuple(range(ndim - 2)) + (contracted, ndim)
    # Keep every index in the output (nothing contracted inside blockwise);
    # the contracted axis sits second-to-last with size 1 per block.
    out_ind = tuple(range(ndim + 1))

    out = blockwise(
        _matmul,
        out_ind,
        a,
        lhs_ind,
        b,
        rhs_ind,
        adjust_chunks={contracted: 1},
        dtype=_result_type(a, b),
        concatenate=False,
    )

    # Contracting with concatenation inside blockwise has a large memory
    # footprint (https://github.com/dask/dask/issues/6874), so the per-block
    # partial products are reduced separately, also without concatenation.
    out = _sum_wo_cat(out, axis=-2)

    if a_is_1d or b_is_1d:
        from dask_array._collection import squeeze

        if a_is_1d:
            out = squeeze(out, axis=-2)
        if b_is_1d:
            out = squeeze(out, axis=-1)

    return out