dask-array 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (144) hide show
  1. dask_array/__init__.py +228 -0
  2. dask_array/_backends.py +76 -0
  3. dask_array/_backends_array.py +99 -0
  4. dask_array/_blockwise.py +1410 -0
  5. dask_array/_broadcast.py +272 -0
  6. dask_array/_chunk.py +445 -0
  7. dask_array/_chunk_types.py +54 -0
  8. dask_array/_collection.py +1644 -0
  9. dask_array/_concatenate.py +331 -0
  10. dask_array/_core_utils.py +1365 -0
  11. dask_array/_dispatch.py +141 -0
  12. dask_array/_einsum.py +277 -0
  13. dask_array/_expr.py +544 -0
  14. dask_array/_expr_flow.py +586 -0
  15. dask_array/_gufunc.py +805 -0
  16. dask_array/_histogram.py +617 -0
  17. dask_array/_map_blocks.py +652 -0
  18. dask_array/_new_collection.py +10 -0
  19. dask_array/_numpy_compat.py +135 -0
  20. dask_array/_overlap.py +1159 -0
  21. dask_array/_rechunk.py +1050 -0
  22. dask_array/_reshape.py +710 -0
  23. dask_array/_routines.py +102 -0
  24. dask_array/_shuffle.py +448 -0
  25. dask_array/_stack.py +264 -0
  26. dask_array/_svg.py +291 -0
  27. dask_array/_templates.py +29 -0
  28. dask_array/_test_utils.py +257 -0
  29. dask_array/_ufunc.py +385 -0
  30. dask_array/_utils.py +349 -0
  31. dask_array/_visualize.py +223 -0
  32. dask_array/_xarray.py +337 -0
  33. dask_array/core/__init__.py +34 -0
  34. dask_array/core/_blockwise_funcs.py +312 -0
  35. dask_array/core/_conversion.py +422 -0
  36. dask_array/core/_from_graph.py +97 -0
  37. dask_array/creation/__init__.py +71 -0
  38. dask_array/creation/_arange.py +121 -0
  39. dask_array/creation/_diag.py +116 -0
  40. dask_array/creation/_diagonal.py +241 -0
  41. dask_array/creation/_eye.py +103 -0
  42. dask_array/creation/_linspace.py +102 -0
  43. dask_array/creation/_mesh.py +134 -0
  44. dask_array/creation/_ones_zeros.py +454 -0
  45. dask_array/creation/_pad.py +270 -0
  46. dask_array/creation/_repeat.py +55 -0
  47. dask_array/creation/_tile.py +36 -0
  48. dask_array/creation/_tri.py +28 -0
  49. dask_array/creation/_utils.py +296 -0
  50. dask_array/fft.py +320 -0
  51. dask_array/io/__init__.py +39 -0
  52. dask_array/io/_base.py +10 -0
  53. dask_array/io/_from_array.py +257 -0
  54. dask_array/io/_from_delayed.py +95 -0
  55. dask_array/io/_from_graph.py +54 -0
  56. dask_array/io/_from_npy_stack.py +67 -0
  57. dask_array/io/_store.py +336 -0
  58. dask_array/io/_tiledb.py +159 -0
  59. dask_array/io/_to_npy_stack.py +65 -0
  60. dask_array/io/_zarr.py +449 -0
  61. dask_array/linalg/__init__.py +39 -0
  62. dask_array/linalg/_cholesky.py +234 -0
  63. dask_array/linalg/_lu.py +300 -0
  64. dask_array/linalg/_norm.py +94 -0
  65. dask_array/linalg/_qr.py +601 -0
  66. dask_array/linalg/_solve.py +349 -0
  67. dask_array/linalg/_svd.py +394 -0
  68. dask_array/linalg/_tensordot.py +334 -0
  69. dask_array/linalg/_utils.py +74 -0
  70. dask_array/manipulation/__init__.py +45 -0
  71. dask_array/manipulation/_expand.py +321 -0
  72. dask_array/manipulation/_flip.py +92 -0
  73. dask_array/manipulation/_roll.py +78 -0
  74. dask_array/manipulation/_transpose.py +309 -0
  75. dask_array/random/__init__.py +125 -0
  76. dask_array/random/_choice.py +181 -0
  77. dask_array/random/_expr.py +256 -0
  78. dask_array/random/_generator.py +441 -0
  79. dask_array/random/_random_state.py +259 -0
  80. dask_array/random/_utils.py +84 -0
  81. dask_array/reductions/__init__.py +84 -0
  82. dask_array/reductions/_arg_reduction.py +130 -0
  83. dask_array/reductions/_common.py +1082 -0
  84. dask_array/reductions/_cumulative.py +522 -0
  85. dask_array/reductions/_percentile.py +261 -0
  86. dask_array/reductions/_reduction.py +725 -0
  87. dask_array/reductions/_trace.py +56 -0
  88. dask_array/routines/__init__.py +133 -0
  89. dask_array/routines/_apply.py +84 -0
  90. dask_array/routines/_bincount.py +112 -0
  91. dask_array/routines/_broadcast.py +111 -0
  92. dask_array/routines/_coarsen.py +115 -0
  93. dask_array/routines/_diff.py +79 -0
  94. dask_array/routines/_gradient.py +158 -0
  95. dask_array/routines/_indexing.py +65 -0
  96. dask_array/routines/_insert_delete.py +132 -0
  97. dask_array/routines/_misc.py +122 -0
  98. dask_array/routines/_nonzero.py +72 -0
  99. dask_array/routines/_search.py +123 -0
  100. dask_array/routines/_select.py +113 -0
  101. dask_array/routines/_statistics.py +171 -0
  102. dask_array/routines/_topk.py +82 -0
  103. dask_array/routines/_triangular.py +74 -0
  104. dask_array/routines/_unique.py +232 -0
  105. dask_array/routines/_where.py +62 -0
  106. dask_array/slicing/__init__.py +67 -0
  107. dask_array/slicing/_basic.py +550 -0
  108. dask_array/slicing/_blocks.py +138 -0
  109. dask_array/slicing/_bool_index.py +145 -0
  110. dask_array/slicing/_setitem.py +329 -0
  111. dask_array/slicing/_squeeze.py +101 -0
  112. dask_array/slicing/_utils.py +1133 -0
  113. dask_array/slicing/_vindex.py +282 -0
  114. dask_array/stacking/__init__.py +15 -0
  115. dask_array/stacking/_block.py +83 -0
  116. dask_array/stacking/_simple.py +58 -0
  117. dask_array/templates/array.html.j2 +48 -0
  118. dask_array/tests/__init__.py +0 -0
  119. dask_array/tests/conftest.py +22 -0
  120. dask_array/tests/test_api.py +40 -0
  121. dask_array/tests/test_binary_op_chunks.py +107 -0
  122. dask_array/tests/test_coarse_slice_through_blockwise.py +362 -0
  123. dask_array/tests/test_collection.py +799 -0
  124. dask_array/tests/test_creation.py +1102 -0
  125. dask_array/tests/test_expr_flow.py +143 -0
  126. dask_array/tests/test_linalg.py +1130 -0
  127. dask_array/tests/test_map_blocks_multi_output.py +104 -0
  128. dask_array/tests/test_rechunk_pushdown.py +214 -0
  129. dask_array/tests/test_reductions.py +1091 -0
  130. dask_array/tests/test_routines.py +2853 -0
  131. dask_array/tests/test_shuffle_chunks.py +67 -0
  132. dask_array/tests/test_slice_pushdown.py +968 -0
  133. dask_array/tests/test_slice_through_blockwise.py +678 -0
  134. dask_array/tests/test_slice_through_overlap.py +366 -0
  135. dask_array/tests/test_slice_through_reshape.py +272 -0
  136. dask_array/tests/test_slicing.py +839 -0
  137. dask_array/tests/test_transpose_slice_pushdown.py +208 -0
  138. dask_array/tests/test_visualize.py +94 -0
  139. dask_array/tests/test_xarray.py +193 -0
  140. dask_array-0.1.0.dist-info/METADATA +48 -0
  141. dask_array-0.1.0.dist-info/RECORD +144 -0
  142. dask_array-0.1.0.dist-info/WHEEL +4 -0
  143. dask_array-0.1.0.dist-info/entry_points.txt +2 -0
  144. dask_array-0.1.0.dist-info/licenses/LICENSE +29 -0
@@ -0,0 +1,1365 @@
1
+ """
2
+ Core utility functions extracted from dask.array.core.
3
+
4
+ This module provides helper functions used throughout dask_array that were
5
+ previously imported from dask.array.core.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import functools
11
+ import math
12
+ import sys
13
+ import traceback
14
+ import warnings
15
+ from collections.abc import Iterable, Iterator
16
+ from itertools import product, zip_longest
17
+ from numbers import Integral, Number
18
+ from typing import TYPE_CHECKING
19
+
20
+ import numpy as np
21
+ from tlz import first
22
+ from toolz import frequencies
23
+
24
+ from dask import config
25
+ from dask.base import is_dask_collection, tokenize
26
+ from dask.core import flatten
27
+ from dask.delayed import delayed
28
+ from dask.sizeof import sizeof
29
+ from dask.utils import (
30
+ Dispatch,
31
+ cached_cumsum,
32
+ cached_max,
33
+ concrete,
34
+ funcname,
35
+ has_keyword,
36
+ is_arraylike,
37
+ is_integer,
38
+ ndimlist,
39
+ parse_bytes,
40
+ )
41
+
42
+ if TYPE_CHECKING:
43
+ pass
44
+
45
+ # Type definition
46
+ T_IntOrNaN = int | float # Should be int | Literal[np.nan]
47
+
48
+
49
+ # Error message constant
50
+ unknown_chunk_message = (
51
+ "\n\n"
52
+ "A possible solution: "
53
+ "https://docs.dask.org/en/latest/array-chunks.html#unknown-chunks\n"
54
+ "Summary: to compute chunks sizes, use\n\n"
55
+ " x.compute_chunk_sizes() # for Dask Array `x`\n"
56
+ " ddf.to_dask_array(lengths=True) # for Dask DataFrame `ddf`"
57
+ )
58
+
59
+
60
class PerformanceWarning(Warning):
    """Warning category for chunkings that are likely to perform poorly."""
62
+
63
+
64
# Type-dispatched registries for backend-specific array operations.
# Implementations are presumably registered elsewhere (not visible here).
tensordot_lookup = Dispatch("tensordot")
concatenate_lookup = Dispatch("concatenate")
67
+
68
+
69
def getter(a, b, asarray=True, lock=None):
    """Return ``a[b]``, optionally holding ``lock`` and coercing to ndarray.

    ``None`` entries (``np.newaxis``) in a tuple index are removed before
    ``a`` is indexed, then re-applied to the result.

    Parameters
    ----------
    a : array-like
        Object to index.
    b : index
        Index passed to ``a.__getitem__``.
    asarray : bool
        If true, results that are not array-like, and ``np.matrix`` results,
        are converted with ``np.asarray``.
    lock : lock-like, optional
        If truthy, acquired before indexing and always released afterwards,
        even when indexing raises.
    """
    if isinstance(b, tuple) and any(item is None for item in b):
        stripped = tuple(item for item in b if item is not None)
        # Integer entries remove their axis from the result, so they get no
        # slot in the index that re-inserts the new axes.
        expand = tuple(
            None if item is None else slice(None, None)
            for item in b
            if not isinstance(item, Integral)
        )
        partial = getter(a, stripped, asarray=asarray, lock=lock)
        return partial[expand]

    if lock:
        lock.acquire()
    try:
        result = a[b]
        if asarray:
            # `np.matrix` is array-like, but it is still converted so that
            # callers always receive a plain `np.ndarray`.
            if not is_arraylike(result) or isinstance(result, np.matrix):
                result = np.asarray(result)
    finally:
        if lock:
            lock.release()
    return result
89
+
90
+
91
def getter_nofancy(a, b, asarray=True, lock=None):
    """Index ``a`` with ``b`` exactly as :func:`getter` does.

    This is a distinct function object so that optimization passes can tell
    the backend does not support fancy indexing.
    """
    return getter(a, b, lock=lock, asarray=asarray)
98
+
99
+
100
def getter_inline(a, b, asarray=True, lock=None):
    """Index ``a`` with ``b`` exactly as :func:`getter` does.

    Using this function (rather than ``getter``) marks a slicing task as one
    that optimizations may inline into its consumers, as in this rewrite:

    **Before**

    >>> a = x[:10]  # doctest: +SKIP
    >>> b = a + 1  # doctest: +SKIP
    >>> c = a * 2  # doctest: +SKIP

    **After**

    >>> b = x[:10] + 1  # doctest: +SKIP
    >>> c = x[:10] * 2  # doctest: +SKIP

    Such inlining can matter when the data is read from disk.
    """
    return getter(a, b, lock=lock, asarray=asarray)
120
+
121
+
122
def slices_from_chunks(chunks):
    """Return one tuple of slices per block, in ``itertools.product`` order.

    >>> slices_from_chunks(((2, 2), (3, 3, 3)))  # doctest: +NORMALIZE_WHITESPACE
    [(slice(0, 2, None), slice(0, 3, None)),
     (slice(0, 2, None), slice(3, 6, None)),
     (slice(0, 2, None), slice(6, 9, None)),
     (slice(2, 4, None), slice(0, 3, None)),
     (slice(2, 4, None), slice(3, 6, None)),
     (slice(2, 4, None), slice(6, 9, None))]
    """
    per_axis = []
    for sizes in chunks:
        # Block start offsets along this axis (leading 0 included).
        offsets = cached_cumsum(sizes, initial_zero=True)
        per_axis.append([slice(start, start + size) for start, size in zip(offsets, sizes)])
    return list(product(*per_axis))
136
+
137
+
138
def graph_from_arraylike(
    arr,  # Any array-like which supports slicing
    chunks,
    shape,
    name,
    getitem=None,
    lock=False,
    asarray=True,
    dtype=None,
    inline_array=False,
):
    """Build a task graph that slices ``arr`` into the blocks of ``chunks``.

    Returns a plain ``dict`` graph (not a ``HighLevelGraph``) for use with the
    expression system.

    Parameters
    ----------
    arr : array-like
        Source object, indexed by each task with a tuple of slices.
    chunks, shape, dtype
        Normalized with :func:`normalize_chunks` into a tuple of tuples.
    name : str
        First element of every block key ``(name, i, j, ...)``.
    getitem : callable, optional
        Slicing function; :func:`getter` when ``None``.
    lock, asarray
        Appended positionally after the slice in each task only when
        ``getitem`` accepts both keywords and ``asarray`` is false or
        ``lock`` is truthy. Otherwise tasks are ``(getitem, source, slice)``.
    inline_array : bool
        If true, ``arr`` itself is placed in every task. Otherwise it is
        stored once under ``"original-<name>"`` and each task refers to it
        through a ``TaskRef``.
    """
    from dask._task_spec import TaskRef

    if getitem is None:
        getitem = getter

    chunks = normalize_chunks(chunks, shape, dtype=dtype)

    forward_options = (
        has_keyword(getitem, "asarray")
        and has_keyword(getitem, "lock")
        and (not asarray or lock)
    )
    # Common case: the defaults apply, so the extra parameters are dropped.
    trailing = (asarray, lock) if forward_options else ()

    original_name = f"original-{name}"
    graph = {} if inline_array else {original_name: arr}

    block_ids = product(*[range(len(c)) for c in chunks])
    for block_id, region in zip(block_ids, slices_from_chunks(chunks)):
        source = arr if inline_array else TaskRef(original_name)
        graph[(name, *block_id)] = (getitem, source, region, *trailing)
    return graph
194
+
195
+
196
def _concatenate2(arrays, axes=None):
    """Recursively concatenate nested lists of arrays along axes

    Each entry in axes corresponds to each level of the nested list. The
    length of axes should correspond to the level of nesting of arrays.
    If axes is empty (``None``, ``[]`` or ``()``), return arrays, or
    arrays[0] if arrays is a list.

    >>> x = np.array([[1, 2], [3, 4]])
    >>> _concatenate2([x, x], axes=[0])
    array([[1, 2],
           [3, 4],
           [1, 2],
           [3, 4]])

    >>> _concatenate2([x, x], axes=[1])
    array([[1, 2, 1, 2],
           [3, 4, 3, 4]])

    >>> _concatenate2([[x, x], [x, x]], axes=[0, 1])
    array([[1, 2, 1, 2],
           [3, 4, 3, 4],
           [1, 2, 1, 2],
           [3, 4, 3, 4]])

    Supports Iterators
    >>> _concatenate2(iter([x, x]), axes=[1])
    array([[1, 2, 1, 2],
           [3, 4, 3, 4]])

    Special Case
    >>> _concatenate2([x, x], axes=())
    array([[1, 2],
           [3, 4]])
    >>> _concatenate2([x, x], axes=[])
    array([[1, 2],
           [3, 4]])
    """
    if axes is None:
        axes = ()

    # Checked with len() so that an empty list (and the default) gets the
    # same treatment as (). An equality test against () only matches tuples,
    # and anything it misses later fails on axes[0] with IndexError.
    if len(axes) == 0:
        if isinstance(arrays, list):
            return arrays[0]
        return arrays

    if isinstance(arrays, Iterator):
        arrays = list(arrays)
    if not isinstance(arrays, (list, tuple)):
        return arrays
    if len(axes) > 1:
        arrays = [_concatenate2(a, axes=axes[1:]) for a in arrays]
    # Pick the concatenate implementation for the type of the input with the
    # highest __array_priority__.
    concatenate = concatenate_lookup.dispatch(type(max(arrays, key=lambda x: getattr(x, "__array_priority__", 0))))
    if isinstance(arrays[0], dict):
        # Handle concatenation of `dict`s, used as a replacement for structured
        # arrays when that's not supported by the array library (e.g., CuPy).
        keys = list(arrays[0].keys())
        assert all(list(a.keys()) == keys for a in arrays)
        return {k: concatenate([a[k] for a in arrays], axis=axes[0]) for k in keys}
    return concatenate(arrays, axis=axes[0])
258
+
259
+
260
def apply_infer_dtype(func, args, kwargs, funcname, suggest_dtype="dtype", nout=None):
    """
    Infer the output dtype of ``func`` by calling it on tiny stand-in inputs.

    Parameters
    ----------
    func: Callable
        Function whose output dtype is to be determined
    args: List of array like
        Arguments the function would usually receive. For array-like entries,
        only ``ndim`` and ``dtype`` (plus the meta from ``meta_from_array``)
        are used.
    kwargs: dict
        Additional ``kwargs`` to the ``func``
    funcname: String
        Name of calling function, used in the error message
    suggest_dtype: None/False or String
        If truthy, the error message suggests passing a dtype via this kwarg.
        Defaults to ``'dtype'``.
    nout: None or Int
        ``None`` if the function returns a single output, an integer if many.
        Defaults to ``None``.

    Returns
    -------
    : dtype or tuple of dtype
        The output's ``dtype`` (or its ``type`` when it has no ``dtype``) when
        ``nout`` is ``None``; otherwise a tuple with each output's ``dtype``.

    Raises
    ------
    ValueError
        If ``func`` raises on the stand-in inputs; the message includes the
        original exception and its traceback.
    """
    from dask_array._utils import meta_from_array

    # Replace each array-like argument with a concrete array of the same ndim
    # and dtype, with every axis of length 1; other arguments pass through.
    probe_args = []
    for arg in args:
        if is_arraylike(arg):
            arg = np.zeros_like(meta_from_array(arg), shape=(1,) * arg.ndim, dtype=arg.dtype)
        probe_args.append(arg)

    failure = None
    try:
        with np.errstate(all="ignore"):
            out = func(*probe_args, **kwargs)
    except Exception as exc:
        tb = "".join(traceback.format_tb(exc.__traceback__))
        if suggest_dtype:
            suggest = f"Please specify the dtype explicitly using the `{suggest_dtype}` kwarg.\n\n"
        else:
            suggest = ""
        failure = (
            f"`dtype` inference failed in `{funcname}`.\n\n"
            f"{suggest}"
            "Original error is below:\n"
            "------------------------\n"
            f"{exc!r}\n\n"
            "Traceback:\n"
            "---------\n"
            f"{tb}"
        )
    # Raised outside the ``except`` block, so the ValueError is not chained to
    # the original exception; its repr and traceback are already in the text.
    if failure is not None:
        raise ValueError(failure)

    if nout is None:
        return getattr(out, "dtype", type(out))
    return tuple(item.dtype for item in out)
323
+
324
+
325
def normalize_arg(x):
    """Normalize a user-provided argument to ``blockwise`` or ``map_blocks``.

    - Dask collections are returned unchanged.
    - Strings that begin with ``_<digits>`` (which could collide with
      ``blockwise_token`` placeholders) are wrapped in ``dask.delayed`` so
      they are quoted.
    - Lists with 10 or more elements, and objects whose ``sizeof`` exceeds
      1e6 bytes, are wrapped in ``dask.delayed`` so that they become their
      own graph node.
    - Everything else is returned unchanged.
    """
    import re

    if is_dask_collection(x):
        return x

    should_delay = (
        (isinstance(x, str) and re.match(r"_\d+", x) is not None)
        or (isinstance(x, list) and len(x) >= 10)
        or sizeof(x) > 1e6
    )
    return delayed(x) if should_delay else x
347
+
348
+
349
def _pass_extra_kwargs(func, keys, *args, **kwargs):
    """Call ``func`` after converting leading positional args to keywords.

    Helper for :func:`dask.array.map_blocks` to pass ``block_info`` or
    ``block_id``. The first ``len(keys)`` values in ``args`` are passed to
    ``func`` as keyword arguments named by ``keys``. The remaining positional
    arguments, and ``kwargs``, are forwarded unchanged.
    """
    n_leading = len(keys)
    leading, remaining = args[:n_leading], args[n_leading:]
    for key, value in zip(keys, leading):
        kwargs[key] = value
    return func(*remaining, **kwargs)
358
+
359
+
360
def apply_and_enforce(*args, **kwargs):
    """Call ``kwargs["_func"]`` and check the dimensionality of its output.

    ``_func`` and ``expected_ndim`` are removed from ``kwargs``; everything
    else is forwarded to the function. Outputs without an ``ndim`` attribute
    count as 0-dimensional.

    Raises
    ------
    ValueError
        If the output's ``ndim`` differs from ``expected_ndim``.
    """
    func = kwargs.pop("_func")
    expected_ndim = kwargs.pop("expected_ndim")
    out = func(*args, **kwargs)
    out_ndim = getattr(out, "ndim", 0)
    if out_ndim == expected_ndim:
        return out
    raise ValueError(
        f"Dimension mismatch: expected output of {func} to have dims = {expected_ndim}. Got {out_ndim} instead."
    )
373
+
374
+
375
def broadcast_chunks(*chunkss):
    """Combine several chunks tuples into one, following broadcasting rules.

    Lower-dimensional inputs are left-padded with singleton ``(1,)`` axes.
    In each dimension, all non-singleton chunkings must agree.

    >>> a = ((5, 5),)
    >>> b = ((5, 5),)
    >>> broadcast_chunks(a, b)
    ((5, 5),)

    >>> a = ((10, 10, 10), (5, 5),)
    >>> b = ((5, 5),)
    >>> broadcast_chunks(a, b)
    ((10, 10, 10), (5, 5))

    >>> a = ((10, 10, 10), (5, 5),)
    >>> b = ((1,), (5, 5),)
    >>> broadcast_chunks(a, b)
    ((10, 10, 10), (5, 5))

    >>> a = ((10, 10, 10), (5, 5),)
    >>> b = ((3, 3,), (5, 5),)
    >>> broadcast_chunks(a, b)
    Traceback (most recent call last):
    ...
    ValueError: Chunks do not align: [(10, 10, 10), (3, 3)]
    """
    if not chunkss:
        return ()
    if len(chunkss) == 1:
        return chunkss[0]

    ndim = max(len(c) for c in chunkss)
    padded = [((1,),) * (ndim - len(c)) + c for c in chunkss]

    combined = []
    for per_input in zip(*padded):
        non_unit = [c for c in per_input if c != (1,)]
        if not non_unit:
            # Every input is a singleton here; keep the first one.
            combined.append(per_input[0])
            continue
        if len(set(non_unit)) != 1:
            raise ValueError(f"Chunks do not align: {non_unit}")
        combined.append(non_unit[0])
    return tuple(combined)
417
+
418
+
419
# Message for the ValueError that normalize_chunks raises when chunks is None.
CHUNKS_NONE_ERROR_MESSAGE = "\n".join(
    [
        "You must specify a chunks= keyword argument.",
        "This specifies the chunksize of your array blocks.",
        "",
        "See the following documentation page for details:",
        "https://docs.dask.org/en/latest/array-creation.html#chunks",
    ]
)
426
+
427
+
428
def blockdims_from_blockshape(shape, chunks):
    """Expand a shape and a uniform per-axis block size into explicit chunks.

    An axis of length ``d`` with block size ``bd`` becomes ``d // bd`` blocks
    of size ``bd``, followed by one block of ``d % bd`` when that remainder
    is non-zero. Zero-length axes become ``(0,)`` whatever the block size.

    >>> blockdims_from_blockshape((10, 10), (4, 3))
    ((4, 4, 2), (3, 3, 3, 1))
    >>> blockdims_from_blockshape((10, 0), (4, 0))
    ((4, 4, 2), (0,))

    Raises
    ------
    TypeError
        If ``shape`` or ``chunks`` is ``None``.
    ValueError
        If any size is unknown (NaN) or not an integer, or if a non-empty
        axis is given a block size that is not positive.
    """
    if chunks is None:
        raise TypeError("Must supply chunks= keyword argument")
    if shape is None:
        raise TypeError("Must supply shape= keyword argument")
    if np.isnan(sum(shape)) or np.isnan(sum(chunks)):
        raise ValueError(f"Array chunk sizes are unknown. shape: {shape}, chunks: {chunks}{unknown_chunk_message}")
    if not all(map(is_integer, chunks)):
        raise ValueError("chunks can only contain integers.")
    if not all(map(is_integer, shape)):
        raise ValueError("shape can only contain integers.")
    shape = tuple(map(int, shape))
    chunks = tuple(map(int, chunks))

    result = []
    for d, bd in zip(shape, chunks):
        if not d:
            result.append((0,))
            continue
        # A zero block size would otherwise raise a bare ZeroDivisionError,
        # and a negative one would silently produce nonsensical chunks.
        if bd <= 0:
            raise ValueError(
                f"chunk sizes must be positive for non-empty dimensions; "
                f"got chunk size {bd} for a dimension of length {d}."
            )
        n_full, remainder = divmod(d, bd)
        result.append((bd,) * n_full + ((remainder,) if remainder else ()))
    return tuple(result)
448
+
449
+
450
def _convert_int_chunk_to_tuple(shape, chunks):
    """Expand scalar per-axis chunk sizes into explicit chunk tuples.

    Entries of ``chunks`` that are already a tuple or list are kept as they
    are. Any other entry is expanded for the matching ``shape`` entry with
    :func:`blockdims_from_blockshape`. The result is a tuple with one entry
    per ``(shape, chunks)`` pair; ``zip`` stops at the shorter input.
    """
    expanded = []
    for size, chunk in zip(shape, chunks):
        if isinstance(chunk, (tuple, list)):
            expanded.append(chunk)
        else:
            expanded.extend(blockdims_from_blockshape((size,), (chunk,)))
    return tuple(expanded)
458
+
459
+
460
def _compute_multiplier(limit: int, dtype, largest_block: int, result):
    """Ratio of the ideal chunk volume to the current one, for auto_chunks.

    The element budget is ``limit / dtype.itemsize``. It is divided by
    ``largest_block`` and then by the product of the current sizes in
    ``result``. Tuple entries contribute their largest element, and falsy
    entries (``0``, empty) are skipped. A value above 1 means chunks should
    grow; below 1 means they should shrink.
    """
    budget = limit / dtype.itemsize / largest_block
    current_sizes = (max(size) if isinstance(size, tuple) else size for size in result.values() if size)
    return budget / math.prod(current_sizes)
471
+
472
+
473
def round_to(c, s):
    """Return a chunk size near ``c`` that lines up with ``s``.

    - ``c <= s``: ``c`` truncated to an int, but at least 1. An uneven final
      chunk is accepted.
    - otherwise: the largest multiple of ``s`` not exceeding ``c``, computed
      as ``c // s * s`` (a float when ``c`` is a float).
    """
    return max(1, int(c)) if c <= s else c // s * s
488
+
489
+
490
def auto_chunks(chunks, shape, limit, dtype, previous_chunks=None):
    """Determine automatic chunks

    This takes in a chunks value that contains ``"auto"`` values in certain
    dimensions and replaces those values with concrete dimension sizes that try
    to get chunks to be of a certain size in bytes, provided by the ``limit=``
    keyword. If multiple dimensions are marked as ``"auto"`` then they will
    all respond to meet the desired byte limit, trying to respect the aspect
    ratio of their dimensions in ``previous_chunks=``, if given.

    Parameters
    ----------
    chunks: Tuple
        A tuple of either dimensions or tuples of explicit chunk dimensions
        Some entries should be "auto"
    shape: Tuple[int]
    limit: int, str
        The maximum allowable size of a chunk in bytes. ``None`` falls back
        to the ``array.chunk-size`` config value; strings are parsed with
        ``parse_bytes``.
    dtype: np.dtype
        Element type; its ``itemsize`` turns ``limit`` into an element count.
    previous_chunks: Tuple[Tuple[int]]
        Existing chunking used as a guide for the auto dimensions.

    Returns
    -------
    tuple
        ``chunks`` with each ``"auto"`` entry replaced by a block size or a
        tuple of block sizes. If there is no ``"auto"`` entry, ``chunks`` is
        returned as a tuple unchanged.

    Raises
    ------
    TypeError
        If ``dtype`` is ``None``.
    NotImplementedError
        If ``dtype`` holds Python objects.
    ValueError
        If any chunk or shape entry is NaN, or ``dtype.itemsize == 0``.

    See also
    --------
    normalize_chunks: for full docstring and parameters
    """
    if previous_chunks is not None:
        # rioxarray is passing ((1, ), (x,)) for shapes like (100, 5x),
        # so add this compat code for now
        # https://github.com/corteva/rioxarray/pull/820
        previous_chunks = (c[0] if isinstance(c, tuple) and len(c) == 1 else c for c in previous_chunks)
        previous_chunks = _convert_int_chunk_to_tuple(shape, previous_chunks)
    chunks = list(chunks)

    autos = {i for i, c in enumerate(chunks) if c == "auto"}
    if not autos:
        return tuple(chunks)

    if limit is None:
        limit = config.get("array.chunk-size")
    if isinstance(limit, str):
        limit = parse_bytes(limit)

    if dtype is None:
        raise TypeError("dtype must be known for auto-chunking")

    if dtype.hasobject:
        raise NotImplementedError(
            "Can not use auto rechunking with object dtype. We are unable to estimate the size in bytes of object data"
        )

    for x in tuple(chunks) + tuple(shape):
        if isinstance(x, Number) and np.isnan(x) or isinstance(x, tuple) and np.isnan(x).any():
            raise ValueError(
                f"Can not perform automatic rechunking with unknown (nan) chunk sizes.{unknown_chunk_message}"
            )

    # Checked once for both strategies below. Each divides by dtype.itemsize
    # (the previous_chunks path through _compute_multiplier), and a zero
    # itemsize would otherwise surface as a bare ZeroDivisionError there.
    if dtype.itemsize == 0:
        raise ValueError(
            "auto-chunking with dtype.itemsize == 0 is not supported, please pass in `chunks` explicitly"
        )

    limit = max(1, limit)
    chunksize_tolerance = config.get("array.chunk-size-tolerance")

    # Elements in the largest block spanned by the non-"auto" dimensions.
    largest_block = math.prod(cs if isinstance(cs, Number) else max(cs) for cs in chunks if cs != "auto")

    if previous_chunks:
        # Base ideal ratio on the median chunk size of the previous chunks
        median_chunks = {a: np.median(previous_chunks[a]) for a in autos}
        result = {}

        # How much larger or smaller the ideal chunk size is relative to what we have now
        multiplier = _compute_multiplier(limit, dtype, largest_block, median_chunks)
        if multiplier < 1:
            # we want to update inplace, algorithm relies on it in this case
            result = median_chunks

        # Per-dimension alignment target: the most frequent previous chunk
        # size when it is > 1 and covers at least half of the blocks,
        # otherwise the full dimension length.
        ideal_shape = []
        for i, s in enumerate(shape):
            chunk_frequencies = frequencies(previous_chunks[i])
            mode, count = max(chunk_frequencies.items(), key=lambda kv: kv[1])
            if mode > 1 and count >= len(previous_chunks[i]) / 2:
                ideal_shape.append(mode)
            else:
                ideal_shape.append(s)

        def _trivial_aggregate(a):
            # Stop auto-sizing dimension ``a``; it no longer feeds the
            # multiplier computation.
            autos.remove(a)
            del median_chunks[a]
            return True

        multiplier_remaining = True
        reduce_case = multiplier < 1
        while multiplier_remaining:  # while things change
            last_autos = set(autos)  # record previous values
            multiplier_remaining = False

            # Expand or contract each of the dimensions appropriately
            for a in sorted(autos):
                # Spread the multiplier and tolerance evenly over the
                # dimensions that were "auto" at the start of this pass.
                this_multiplier = multiplier ** (1 / len(last_autos))

                proposed = median_chunks[a] * this_multiplier
                this_chunksize_tolerance = chunksize_tolerance ** (1 / len(last_autos))
                max_chunk_size = proposed * this_chunksize_tolerance

                if proposed > shape[a]:  # we've hit the shape boundary
                    chunks[a] = shape[a]
                    multiplier_remaining = _trivial_aggregate(a)
                    largest_block *= shape[a]
                    result[a] = (shape[a],)
                    continue
                elif reduce_case or max(previous_chunks[a]) > max_chunk_size:
                    result[a] = round_to(proposed, ideal_shape[a])
                    if proposed < 1:
                        multiplier_remaining = True
                        autos.discard(a)
                    continue
                else:
                    # Merge consecutive previous chunks while the merged size
                    # stays within ``proposed``.
                    dimension_result, new_chunk = [], 0
                    for c in previous_chunks[a]:
                        if c + new_chunk <= proposed:
                            # keep increasing the chunk
                            new_chunk += c
                        else:
                            # We reach the boundary so start a new chunk
                            if new_chunk > 0:
                                dimension_result.append(new_chunk)
                            new_chunk = c
                    if new_chunk > 0:
                        dimension_result.append(new_chunk)

                    result[a] = tuple(dimension_result)

            # recompute how much multiplier we have left, repeat
            if multiplier_remaining or reduce_case:
                last_multiplier = multiplier
                multiplier = _compute_multiplier(limit, dtype, largest_block, median_chunks)
                if multiplier != last_multiplier:
                    multiplier_remaining = True

        for k, v in result.items():
            chunks[k] = v if v else 0
        return tuple(chunks)

    else:
        # Equal-sided target per auto dimension: the remaining element budget
        # spread evenly (geometric mean) over the auto dimensions.
        size = (limit / dtype.itemsize / largest_block) ** (1 / len(autos))
        small = [i for i in autos if shape[i] < size]
        if small:
            # Dimensions shorter than the target become a single chunk; the
            # remaining budget is redistributed by recursing.
            for i in small:
                chunks[i] = (shape[i],)
            return auto_chunks(chunks, shape, limit, dtype)

        for i in autos:
            chunks[i] = round_to(size, shape[i])

        return tuple(chunks)
645
+
646
+
647
@functools.lru_cache(maxsize=128)
def normalize_chunks_cached(chunks, shape=None, limit=None, dtype=None, previous_chunks=None):
    """Memoized :func:`normalize_chunks` (LRU cache of up to 128 entries).

    .. note::

        Every argument is part of the cache key, so ``chunks`` and
        ``previous_chunks`` must be hashable: pass tuples, not dicts or lists.

    See :func:`normalize_chunks` for further documentation.
    """
    return normalize_chunks(
        chunks,
        shape=shape,
        limit=limit,
        dtype=dtype,
        previous_chunks=previous_chunks,
    )
659
+
660
+
661
+ def normalize_chunks(chunks, shape=None, limit=None, dtype=None, previous_chunks=None):
662
+ """Normalize chunks to tuple of tuples
663
+
664
+ This takes in a variety of input types and information and produces a full
665
+ tuple-of-tuples result for chunks, suitable to be passed to Array or
666
+ rechunk or any other operation that creates a Dask array.
667
+
668
+ Parameters
669
+ ----------
670
+ chunks: tuple, int, dict, or string
671
+ The chunks to be normalized. See examples below for more details
672
+ shape: Tuple[int]
673
+ The shape of the array
674
+ limit: int (optional)
675
+ The maximum block size to target in bytes,
676
+ if freedom is given to choose
677
+ dtype: np.dtype
678
+ previous_chunks: Tuple[Tuple[int]] optional
679
+ Chunks from a previous array that we should use for inspiration when
680
+ rechunking auto dimensions. If not provided but auto-chunking exists
681
+ then auto-dimensions will prefer square-like chunk shapes.
682
+
683
+ Examples
684
+ --------
685
+ Fully explicit tuple-of-tuples
686
+
687
+ >>> normalize_chunks(((2, 2, 1), (2, 2, 2)), shape=(5, 6))
688
+ ((2, 2, 1), (2, 2, 2))
689
+
690
+ Specify uniform chunk sizes
691
+
692
+ >>> normalize_chunks((2, 2), shape=(5, 6))
693
+ ((2, 2, 1), (2, 2, 2))
694
+
695
+ Cleans up missing outer tuple
696
+
697
+ >>> normalize_chunks((3, 2), (5,))
698
+ ((3, 2),)
699
+
700
+ Cleans up lists to tuples
701
+
702
+ >>> normalize_chunks([[2, 2], [3, 3]])
703
+ ((2, 2), (3, 3))
704
+
705
+ Expands integer inputs 10 -> (10, 10)
706
+
707
+ >>> normalize_chunks(10, shape=(30, 5))
708
+ ((10, 10, 10), (5,))
709
+
710
+ Expands dict inputs
711
+
712
+ >>> normalize_chunks({0: 2, 1: 3}, shape=(6, 6))
713
+ ((2, 2, 2), (3, 3))
714
+
715
+ The values -1 and None get mapped to full size
716
+
717
+ >>> normalize_chunks((5, -1), shape=(10, 10))
718
+ ((5, 5), (10,))
719
+ >>> normalize_chunks((5, None), shape=(10, 10))
720
+ ((5, 5), (10,))
721
+
722
+ Use the value "auto" to automatically determine chunk sizes along certain
723
+ dimensions. This uses the ``limit=`` and ``dtype=`` keywords to
724
+ determine how large to make the chunks. The term "auto" can be used
725
+ anywhere an integer can be used. See array chunking documentation for more
726
+ information.
727
+
728
+ >>> normalize_chunks(("auto",), shape=(20,), limit=5, dtype='uint8')
729
+ ((5, 5, 5, 5),)
730
+ >>> normalize_chunks("auto", (2, 3), dtype=np.int32)
731
+ ((2,), (3,))
732
+
733
+ You can also use byte sizes (see :func:`dask.utils.parse_bytes`) in place of
734
+ "auto" to ask for a particular size
735
+
736
+ >>> normalize_chunks("1kiB", shape=(2000,), dtype='float32')
737
+ ((256, 256, 256, 256, 256, 256, 256, 208),)
738
+
739
+ Respects null dimensions
740
+
741
+ >>> normalize_chunks(())
742
+ ()
743
+ >>> normalize_chunks((), ())
744
+ ()
745
+ >>> normalize_chunks((1,), ())
746
+ ()
747
+ >>> normalize_chunks((), shape=(0, 0))
748
+ ((0,), (0,))
749
+
750
+ Handles NaNs
751
+
752
+ >>> normalize_chunks((1, (np.nan,)), (1, np.nan))
753
+ ((1,), (nan,))
754
+ """
755
+ if dtype and not isinstance(dtype, np.dtype):
756
+ dtype = np.dtype(dtype)
757
+ if chunks is None:
758
+ raise ValueError(CHUNKS_NONE_ERROR_MESSAGE)
759
+ if isinstance(chunks, list):
760
+ chunks = tuple(chunks)
761
+ if isinstance(chunks, (Number, str)):
762
+ chunks = (chunks,) * len(shape)
763
+ if isinstance(chunks, dict):
764
+ chunks = tuple(chunks.get(i, None) for i in range(len(shape)))
765
+ if isinstance(chunks, np.ndarray):
766
+ chunks = chunks.tolist()
767
+ if not chunks and shape and all(s == 0 for s in shape):
768
+ chunks = ((0,),) * len(shape)
769
+
770
+ if shape and len(shape) == 1 and len(chunks) > 1 and all(isinstance(c, (Number, str)) for c in chunks):
771
+ chunks = (chunks,)
772
+
773
+ if shape and len(chunks) != len(shape):
774
+ raise ValueError(f"Chunks and shape must be of the same length/dimension. Got chunks={chunks}, shape={shape}")
775
+ if -1 in chunks or None in chunks:
776
+ chunks = tuple(s if c == -1 or c is None else c for c, s in zip(chunks, shape))
777
+
778
+ # If specifying chunk size in bytes, use that value to set the limit.
779
+ # Verify there is only one consistent value of limit or chunk-bytes used.
780
+ for c in chunks:
781
+ if isinstance(c, str) and c != "auto":
782
+ parsed = parse_bytes(c)
783
+ if limit is None:
784
+ limit = parsed
785
+ elif parsed != limit:
786
+ raise ValueError(f"Only one consistent value of limit or chunk is allowed.Used {parsed} != {limit}")
787
+ # Substitute byte limits with 'auto' now that limit is set.
788
+ chunks = tuple("auto" if isinstance(c, str) and c != "auto" else c for c in chunks)
789
+
790
+ if any(c == "auto" for c in chunks):
791
+ chunks = auto_chunks(chunks, shape, limit, dtype, previous_chunks)
792
+
793
+ allints = None
794
+ if chunks and shape is not None:
795
+ # allints: did we start with chunks as a simple tuple of ints?
796
+ allints = all(isinstance(c, int) for c in chunks)
797
+ chunks = _convert_int_chunk_to_tuple(shape, chunks)
798
+ for c in chunks:
799
+ if not c:
800
+ raise ValueError(
801
+ "Empty tuples are not allowed in chunks. Express zero length dimensions with 0(s) in chunks"
802
+ )
803
+
804
+ if not allints and shape is not None:
805
+ if not all(c == s or (math.isnan(c) or math.isnan(s)) for c, s in zip(map(sum, chunks), shape)):
806
+ raise ValueError(f"Chunks do not add up to shape. Got chunks={chunks}, shape={shape}")
807
+ if allints or isinstance(sum(sum(_) for _ in chunks), int):
808
+ # Fastpath for when we already know chunks contains only integers
809
+ return tuple(tuple(ch) for ch in chunks)
810
+ return tuple(tuple(int(x) if not math.isnan(x) else np.nan for x in c) for c in chunks)
811
+
812
+
813
def common_blockdim(blockdims):
    """Find the chunk tuple that every entry of ``blockdims`` can be aligned to.

    Only a simple, conservative heuristic is implemented:

    - if exactly one entry is split into more than one chunk, that entry is
      returned unchanged (this avoids potentially very expensive rechunking);
    - if no entry is split, the entry with the largest first chunk is returned;
    - otherwise all split entries are merged into the coarsest chunking whose
      boundaries include every entry's boundaries, e.g. ``(5, 2)`` and
      ``(4, 3)`` become ``(4, 1, 2)``.

    All entries are assumed to have the same sum (i.e. to describe a dimension
    of the same size).

    Examples
    --------
    >>> common_blockdim([(3,), (2, 1)])
    (2, 1)
    >>> common_blockdim([(1, 2), (2, 1)])
    (1, 1, 1)
    >>> common_blockdim([(5, 2), (4, 3)])
    (4, 1, 2)

    Raises
    ------
    ValueError
        If several entries are split and the chunk sizes are unknown (NaN),
        or the split entries do not add up to the same total.
    """
    if not any(blockdims):
        return ()

    split_dims = {dims for dims in blockdims if len(dims) > 1}
    if not split_dims:
        # Nothing is split: pick the entry with the largest first chunk.
        return max(blockdims, key=lambda dims: next(iter(dims)))
    if len(split_dims) == 1:
        return next(iter(split_dims))

    if np.isnan(sum(map(sum, blockdims))):
        raise ValueError(
            f"Arrays' chunk sizes ({blockdims}) are unknown.\n\nA possible solution:\n x.compute_chunk_sizes()"
        )

    totals = {sum(dims) for dims in split_dims}
    if len(totals) > 1:
        raise ValueError("Chunks do not add up to same value", blockdims)

    # Several differently split entries remain. Walk them in lockstep:
    # repeatedly emit the smallest leading chunk, subtract it from every
    # entry's leading chunk and drop leading chunks that reach zero. Entries
    # are stored reversed so the leading chunk sits at the end of each list,
    # where popping is cheap.
    stacks = [list(reversed(dims)) for dims in split_dims]
    total = sum(next(iter(split_dims)))

    merged = []
    covered = 0
    while covered < total:
        step = min(stack[-1] for stack in stacks)
        merged.append(step)
        for stack in stacks:
            stack[-1] -= step
            if stack[-1] == 0:
                stack.pop()
        covered += step

    return tuple(merged)
877
+
878
+
879
def is_scalar_for_elemwise(arg):
    """Return whether ``arg`` should be treated as a scalar by elemwise.

    ``arg`` counts as a scalar when any of the following holds:

    - ``np.isscalar`` accepts it;
    - it has no iterable ``shape`` attribute (so plain lists count too);
    - its ``shape`` contains a dask collection, which is essentially there so
      that dask series / frames are treated as scalars in elemwise;
    - it is a ``np.dtype``;
    - it is a 0-d ``np.ndarray``.

    >>> is_scalar_for_elemwise(42)
    True
    >>> is_scalar_for_elemwise('foo')
    True
    >>> is_scalar_for_elemwise(True)
    True
    >>> is_scalar_for_elemwise(np.array(42))
    True
    >>> is_scalar_for_elemwise([1, 2, 3])
    True
    >>> is_scalar_for_elemwise(np.array([1, 2, 3]))
    False
    """
    shape = getattr(arg, "shape", None)
    if isinstance(shape, Iterable):
        shape_marks_scalar = any(is_dask_collection(dim) for dim in shape)
    else:
        shape_marks_scalar = True

    if np.isscalar(arg) or shape_marks_scalar or isinstance(arg, np.dtype):
        return True
    return isinstance(arg, np.ndarray) and arg.ndim == 0
905
+
906
+
907
def broadcast_shapes(*shapes):
    """Compute the shape that results from broadcasting ``shapes`` together.

    Shapes are aligned from the right; shorter shapes are padded on the left
    with ``-1`` (a missing dimension). For each aligned position:

    - if any size is NaN (unknown), the output size is NaN, but the known
      sizes other than -1, 0 and 1 must still all agree;
    - otherwise the output is 0 if any size is 0, else the largest size, and
      every size must be -1, 0, 1 or that output size.

    Parameters
    ----------
    *shapes : tuples
        The shapes of the arguments.

    Returns
    -------
    output_shape : tuple
        The broadcast shape. A single input shape is returned as-is.

    Raises
    ------
    ValueError
        If the input shapes cannot be successfully broadcast together.
    """
    if len(shapes) == 1:
        return shapes[0]

    def incompatible():
        return ValueError(
            "operands could not be broadcast together with shapes {}".format(" ".join(map(str, shapes)))
        )

    dims_reversed = []
    for sizes in zip_longest(*(reversed(shape) for shape in shapes), fillvalue=-1):
        if np.isnan(sizes).any():
            known = {size for size in sizes if size not in (-1, 0, 1) and not np.isnan(size)}
            if len(known) > 1:
                raise incompatible()
            dims_reversed.append(np.nan)
        else:
            dim = 0 if 0 in sizes else np.max(sizes).item()
            if any(size not in [-1, 0, 1, dim] for size in sizes):
                raise incompatible()
            dims_reversed.append(dim)

    return tuple(reversed(dims_reversed))
949
+
950
+
951
+ def _elemwise_handle_where(*args, **kwargs):
952
+ function = kwargs.pop("elemwise_where_function")
953
+ *args, where, out = args
954
+ if hasattr(out, "copy"):
955
+ out = out.copy()
956
+ return function(*args, where=where, out=out, **kwargs)
957
+
958
+
959
def handle_out(out, result):
    """Handle ``out=`` parameters.

    ``out`` may be ``None``, a dask ``Array``, or a tuple holding at most one
    of those. If it resolves to an ``Array``, that array's expression is
    replaced by ``result``'s expression and ``out`` itself is returned. If it
    resolves to ``None``, ``result`` is returned unchanged.

    Raises
    ------
    NotImplementedError
        If ``out`` is a tuple with more than one element, or resolves to
        something other than ``None`` or a dask ``Array``.
    ValueError
        If ``out`` and ``result`` have different shapes.
    """
    from dask_array._collection import Array

    if isinstance(out, tuple):
        if len(out) > 1:
            raise NotImplementedError("The out parameter is not fully supported")
        out = out[0] if out else None

    if out is None:
        return result
    if not isinstance(out, Array):
        raise NotImplementedError(
            f"The out parameter is not fully supported. Received type {type(out).__name__}, expected Dask Array"
        )
    if out.shape != result.shape:
        raise ValueError(
            f"Mismatched shapes between result and out parameter. out={out.shape}, result={result.shape}"
        )

    # Expression-based arrays: point ``out`` at the result's expression.
    out._expr = result._expr
    return out
988
+
989
+
990
+ def _enforce_dtype(*args, **kwargs):
991
+ """Calls a function and converts its result to the given dtype.
992
+
993
+ The parameters have deliberately been given unwieldy names to avoid
994
+ clashes with keyword arguments consumed by blockwise
995
+
996
+ A dtype of `object` is treated as a special case and not enforced,
997
+ because it is used as a dummy value in some places when the result will
998
+ not be a block in an Array.
999
+
1000
+ Parameters
1001
+ ----------
1002
+ enforce_dtype : dtype
1003
+ Result dtype
1004
+ enforce_dtype_function : callable
1005
+ The wrapped function, which will be passed the remaining arguments
1006
+ """
1007
+ dtype = kwargs.pop("enforce_dtype")
1008
+ function = kwargs.pop("enforce_dtype_function")
1009
+
1010
+ result = function(*args, **kwargs)
1011
+ if hasattr(result, "dtype") and dtype != result.dtype and dtype != object:
1012
+ if not np.can_cast(result, dtype, casting="same_kind"):
1013
+ raise ValueError(
1014
+ f"Inferred dtype from function {funcname(function)!r} was {str(dtype)!r} "
1015
+ f"but got {str(result.dtype)!r}, which can't be cast using "
1016
+ "casting='same_kind'"
1017
+ )
1018
+ if np.isscalar(result):
1019
+ # scalar astype method doesn't take the keyword arguments, so
1020
+ # have to convert via 0-dimensional array and back.
1021
+ result = result.astype(dtype)
1022
+ else:
1023
+ try:
1024
+ result = result.astype(dtype, copy=False)
1025
+ except TypeError:
1026
+ # Missing copy kwarg
1027
+ result = result.astype(dtype)
1028
+ return result
1029
+
1030
+
1031
def unpack_singleton(x):
    """Strip list/tuple wrappers by repeatedly taking the first element.

    Descent stops at the first value that is not a list or tuple, or at a
    container whose first element cannot be retrieved (e.g. an empty list);
    that value is returned as-is.

    >>> unpack_singleton([[[[1]]]])
    1
    >>> unpack_singleton(np.array(np.datetime64('2000-01-01')))
    array('2000-01-01', dtype='datetime64[D]')
    """
    while isinstance(x, (list, tuple)):
        try:
            head = x[0]
        except (IndexError, TypeError, KeyError):
            return x
        x = head
    return x
1044
+
1045
+
1046
def deepfirst(seq):
    """Return the first leaf of a nested list/tuple structure.

    Follows index ``0`` through lists and tuples until reaching a value that
    is neither; such a value passed directly is returned unchanged.

    >>> deepfirst([[[1, 2], [3, 4]], [5, 6], [7, 8]])
    1
    """
    node = seq
    while isinstance(node, (list, tuple)):
        node = node[0]
    return node
1056
+
1057
+
1058
def chunks_from_arrays(arrays):
    """Chunks tuple from nested list of arrays

    Level ``k`` of the nesting gives the chunk sizes along axis ``k``. Each
    size is read from the first leaf array inside that entry. Leaves without
    a ``shape``, or with an empty one, count as size 1.

    >>> x = np.array([1, 2])
    >>> chunks_from_arrays([x, x])
    ((2, 2),)

    >>> x = np.array([[1, 2]])
    >>> chunks_from_arrays([[x], [x]])
    ((1, 1), (2,))

    >>> x = np.array([[1, 2]])
    >>> chunks_from_arrays([[x, x]])
    ((1,), (2, 2))

    >>> chunks_from_arrays([1, 1])
    ((1, 1),)
    """
    if not arrays:
        return ()

    def leaf(node):
        # First non-list/tuple value reached by following index 0.
        while isinstance(node, (list, tuple)):
            node = node[0]
        return node

    def leaf_shape(x):
        try:
            return x.shape if x.shape else (1,)
        except AttributeError:
            return (1,)

    chunks = []
    level = arrays
    axis = 0
    while isinstance(level, (list, tuple)):
        chunks.append(tuple(leaf_shape(leaf(entry))[axis] for entry in level))
        level = level[0]
        axis += 1
    return tuple(chunks)
1092
+
1093
+
1094
def concatenate3(arrays):
    """Recursive np.concatenate

    Input should be a nested list of numpy arrays arranged in the order they
    should appear in the array itself. Each array should have the same number
    of dimensions as the desired output and the nesting of the lists.

    Dispatch order:

    1. Empty input, or input whose leaves are all ``None``, returns
       ``np.empty(0)``.
    2. If any leaf's type defines an ``__array_function__`` different from
       ``np.ndarray``'s, try ``_concatenate2`` over all axes; a ``TypeError``
       falls through to the next step.
    3. If ``concatenate_lookup`` dispatches the leaf with the highest
       ``__array_priority__`` to something other than ``np.concatenate``,
       use ``_concatenate2`` over all axes.
    4. Otherwise, if ``ndimlist(arrays)`` is falsy, ``arrays`` is returned
       unchanged; else a single NumPy output array is allocated and each leaf
       is copied into its slot.

    >>> x = np.array([[1, 2]])
    >>> concatenate3([[x, x, x], [x, x, x]])
    array([[1, 2, 1, 2, 1, 2],
           [1, 2, 1, 2, 1, 2]])

    >>> concatenate3([[x, x], [x, x], [x, x]])
    array([[1, 2, 1, 2],
           [1, 2, 1, 2],
           [1, 2, 1, 2]])
    """
    from dask import core as dask_core

    # We need this as __array_function__ may not exist on older NumPy versions.
    # And to reduce verbosity.
    NDARRAY_ARRAY_FUNCTION = getattr(np.ndarray, "__array_function__", None)

    arrays = concrete(arrays)
    # NOTE(review): unlike the ``dask_core.flatten`` calls below, this
    # ``flatten`` is called without ``container=(list, tuple)``; confirm it
    # also descends into tuples.
    if not arrays or all(el is None for el in flatten(arrays)):
        return np.empty(0)

    # Leaf with the highest ``__array_priority__``; its type picks the
    # concatenate implementation in step 3.
    advanced = max(
        dask_core.flatten(arrays, container=(list, tuple)),
        key=lambda x: getattr(x, "__array_priority__", 0),
    )

    # Some leaf type overrides __array_function__: let it handle the
    # concatenation through _concatenate2 if it can.
    if not all(
        NDARRAY_ARRAY_FUNCTION is getattr(type(arr), "__array_function__", NDARRAY_ARRAY_FUNCTION)
        for arr in dask_core.flatten(arrays, container=(list, tuple))
    ):
        try:
            # The leftmost leaf defines how many axes to concatenate over.
            x = unpack_singleton(arrays)
            return _concatenate2(arrays, axes=tuple(range(x.ndim)))
        except TypeError:
            pass

    # A non-NumPy concatenate implementation is registered for the
    # highest-priority leaf type.
    if concatenate_lookup.dispatch(type(advanced)) is not np.concatenate:
        x = unpack_singleton(arrays)
        return _concatenate2(arrays, axes=list(range(x.ndim)))

    # Plain NumPy path: preallocate the output and fill it block by block.
    ndim = ndimlist(arrays)
    if not ndim:
        return arrays
    chunks = chunks_from_arrays(arrays)
    shape = tuple(map(sum, chunks))

    def dtype(x):
        # Leaves without a ``dtype`` attribute fall back to their Python type.
        try:
            return x.dtype
        except AttributeError:
            return type(x)

    result = np.empty(shape=shape, dtype=dtype(deepfirst(arrays)))

    for idx, arr in zip(slices_from_chunks(chunks), dask_core.flatten(arrays, container=(list, tuple))):
        if hasattr(arr, "ndim"):
            # Prepend length-1 axes so lower-dimensional leaves fit the slot.
            while arr.ndim < ndim:
                arr = arr[None, ...]
        result[idx] = arr

    return result
1161
+
1162
+
1163
# Register the NumPy implementations as the defaults, both for plain
# ``np.ndarray`` chunks and for ``object`` (registered last for each lookup,
# presumably as the catch-all entry). Registration order is: concatenate for
# ndarray, concatenate for object, tensordot for ndarray, tensordot for object.
for _lookup, _default_impl in ((concatenate_lookup, np.concatenate), (tensordot_lookup, np.tensordot)):
    for _chunk_type in (np.ndarray, object):
        _lookup.register(_chunk_type, _default_impl)
del _lookup, _default_impl, _chunk_type
1170
+
1171
+
1172
+ # Vindex helper functions
1173
+
1174
+
1175
+ def _get_axis(indexes):
1176
+ """Get axis along which point-wise slicing results lie
1177
+
1178
+ This is mostly a hack because I can't figure out NumPy's rule on this and
1179
+ can't be bothered to go reading.
1180
+
1181
+ >>> _get_axis([[1, 2], None, [1, 2], None])
1182
+ 0
1183
+ >>> _get_axis([None, [1, 2], [1, 2], None])
1184
+ 1
1185
+ >>> _get_axis([None, None, [1, 2], [1, 2]])
1186
+ 2
1187
+ """
1188
+ ndim = len(indexes)
1189
+ indexes = [slice(None, None) if i is None else [0] for i in indexes]
1190
+ x = np.empty((2,) * ndim)
1191
+ x2 = x[tuple(indexes)]
1192
+ return x2.shape.index(1)
1193
+
1194
+
1195
+ def _vindex_merge(locations, values):
1196
+ """
1197
+
1198
+ >>> locations = [0], [2, 1]
1199
+ >>> values = [np.array([[1, 2, 3]]),
1200
+ ... np.array([[10, 20, 30], [40, 50, 60]])]
1201
+
1202
+ >>> _vindex_merge(locations, values)
1203
+ array([[ 1, 2, 3],
1204
+ [40, 50, 60],
1205
+ [10, 20, 30]])
1206
+ """
1207
+ locations = list(map(list, locations))
1208
+ values = list(values)
1209
+
1210
+ n = sum(map(len, locations))
1211
+
1212
+ shape = list(values[0].shape)
1213
+ shape[0] = n
1214
+ shape = tuple(shape)
1215
+
1216
+ dtype = values[0].dtype
1217
+
1218
+ x = np.empty_like(values[0], dtype=dtype, shape=shape)
1219
+
1220
+ ind = [slice(None, None) for i in range(x.ndim)]
1221
+ for loc, val in zip(locations, values):
1222
+ ind[0] = loc
1223
+ x[tuple(ind)] = val
1224
+
1225
+ return x
1226
+
1227
+
1228
+ def _vindex_slice_and_transpose(block, points, axis):
1229
+ """Pull out point-wise slices from block and rotate block so that
1230
+ points are on the first dimension"""
1231
+ points = [p if isinstance(p, slice) else list(p) for p in points]
1232
+ block = block[tuple(points)]
1233
+ axes = [axis] + list(range(axis)) + list(range(axis + 1, block.ndim))
1234
+ return block.transpose(axes)
1235
+
1236
+
1237
def interleave_none(a, b):
    """Fill the ``None`` placeholders of ``a`` with the elements of ``b``.

    Non-``None`` entries of ``a`` stay in place, and the k-th ``None`` is
    replaced by ``b[k]``. The result is a tuple of length ``len(a)``.

    >>> interleave_none([0, None, 2, None], [1, 3])
    (0, 1, 2, 3)

    Raises
    ------
    ValueError
        If the number of ``None`` entries in ``a`` differs from ``len(b)``.
    """
    n_slots = sum(1 for item in a if item is None)
    # Validate up front: with a mismatch there is no well-defined result.
    if n_slots != len(b):
        raise ValueError(
            f"interleave_none: a has {n_slots} None placeholder(s) but b has {len(b)} element(s)"
        )
    fill = iter(b)
    return tuple(next(fill) if item is None else item for item in a)
1255
+
1256
+
1257
def keyname(name, i, okey):
    """Return the tuple ``(name, i, *k)``, where ``k`` are the non-``None`` entries of ``okey``.

    >>> keyname('x', 3, [None, None, 0, 2])
    ('x', 3, 0, 2)
    """
    return (name, i, *(k for k in okey if k is not None))
1264
+
1265
+
1266
+ # __array_function__ dict for mapping aliases and mismatching names
1267
+ _HANDLED_FUNCTIONS = {}
1268
+
1269
+
1270
+ def implements(*numpy_functions):
1271
+ """Register an __array_function__ implementation for dask.array.Array
1272
+
1273
+ Register that a function implements the API of a NumPy function (or several
1274
+ NumPy functions in case of aliases) which is handled with
1275
+ ``__array_function__``.
1276
+
1277
+ Parameters
1278
+ ----------
1279
+ \\*numpy_functions : callables
1280
+ One or more NumPy functions that are handled by ``__array_function__``
1281
+ and will be mapped by `implements` to a `dask.array` function.
1282
+ """
1283
+
1284
+ def decorator(dask_func):
1285
+ for numpy_function in numpy_functions:
1286
+ _HANDLED_FUNCTIONS[numpy_function] = dask_func
1287
+
1288
+ return dask_func
1289
+
1290
+ return decorator
1291
+
1292
+
1293
def _should_delegate(self, other) -> bool:
    """Check whether Dask should delegate a binary operation to ``other``.

    This implementation follows NEP-13:
    https://numpy.org/neps/nep-0013-ufunc-overrides.html#behavior-in-combination-with-python-s-binary-operations

    Delegation happens when any of the following holds:

    - ``other`` opts out of ufuncs (``__array_ufunc__ = None``);
    - ``other`` defines ``__array_ufunc__``, is not a valid array chunk type,
      and is not ``self``'s own class or one of its parent classes;
    - ``other`` has no ``__array_ufunc__`` but a higher
      ``__array_priority__`` than ``self``.
    """
    from dask_array._chunk_types import is_valid_array_chunk

    if hasattr(other, "__array_ufunc__"):
        if other.__array_ufunc__ is None:
            return True
        # Don't delegate to valid chunk types or to our own (parent) classes.
        # ``isinstance(self, type(other))`` also covers ``type(self) is type(other)``.
        return not is_valid_array_chunk(other) and not isinstance(self, type(other))

    return bool(hasattr(other, "__array_priority__") and other.__array_priority__ > self.__array_priority__)
1317
+
1318
+
1319
def check_if_handled_given_other(f):
    """Decorate a binary dunder method to defer to ``other`` when appropriate.

    The wrapped method returns ``NotImplemented`` whenever
    ``_should_delegate(self, other)`` is true, and otherwise calls ``f``.
    This ensures proper deferral to upcast types in dunder operations without
    assuming unknown types are automatically downcast types.
    """
    from functools import wraps

    def wrapper(self, other):
        if _should_delegate(self, other):
            return NotImplemented
        return f(self, other)

    return wraps(f)(wrapper)
1335
+
1336
+
1337
def finalize(results):
    """Finalize results from a dask array computation.

    ``results`` is a (possibly nested) list/tuple of computed blocks.

    - If every nesting level along the first-element path holds exactly one
      element, the single block is unwrapped and a copy of it is returned.
    - Otherwise (several blocks, or an empty container at some level) the
      whole structure is passed to ``concatenate3``.
    """
    if not results:
        return concatenate3(results)

    # Follow the first element down through the nesting. A level that does
    # not hold exactly one element means there is no single block: several
    # elements need concatenating, and an empty level (e.g. ``[[]]``) would
    # otherwise raise IndexError on ``node[0]``.
    node = results
    while isinstance(node, (tuple, list)):
        if len(node) != 1:
            return concatenate3(results)
        node = node[0]

    results = unpack_singleton(results)
    # Single chunk. There is a risk that the result holds a buffer stored in the
    # graph or on a process-local Worker. Copy it to make sure that nothing can
    # accidentally write back to it.
    try:
        return results.copy()  # numpy, sparse, scipy.sparse (any version)
    except AttributeError:
        # The result has no ``copy`` method; return it as-is.
        return results
1360
+
1361
+
1362
+ def _get_chunk_shape(a):
1363
+ """Get chunk shape as an array suitable for stacking."""
1364
+ s = np.asarray(a.shape, dtype=int)
1365
+ return s[len(s) * (None,) + (slice(None),)]