dask-array 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (144)
  1. dask_array/__init__.py +228 -0
  2. dask_array/_backends.py +76 -0
  3. dask_array/_backends_array.py +99 -0
  4. dask_array/_blockwise.py +1410 -0
  5. dask_array/_broadcast.py +272 -0
  6. dask_array/_chunk.py +445 -0
  7. dask_array/_chunk_types.py +54 -0
  8. dask_array/_collection.py +1644 -0
  9. dask_array/_concatenate.py +331 -0
  10. dask_array/_core_utils.py +1365 -0
  11. dask_array/_dispatch.py +141 -0
  12. dask_array/_einsum.py +277 -0
  13. dask_array/_expr.py +544 -0
  14. dask_array/_expr_flow.py +586 -0
  15. dask_array/_gufunc.py +805 -0
  16. dask_array/_histogram.py +617 -0
  17. dask_array/_map_blocks.py +652 -0
  18. dask_array/_new_collection.py +10 -0
  19. dask_array/_numpy_compat.py +135 -0
  20. dask_array/_overlap.py +1159 -0
  21. dask_array/_rechunk.py +1050 -0
  22. dask_array/_reshape.py +710 -0
  23. dask_array/_routines.py +102 -0
  24. dask_array/_shuffle.py +448 -0
  25. dask_array/_stack.py +264 -0
  26. dask_array/_svg.py +291 -0
  27. dask_array/_templates.py +29 -0
  28. dask_array/_test_utils.py +257 -0
  29. dask_array/_ufunc.py +385 -0
  30. dask_array/_utils.py +349 -0
  31. dask_array/_visualize.py +223 -0
  32. dask_array/_xarray.py +337 -0
  33. dask_array/core/__init__.py +34 -0
  34. dask_array/core/_blockwise_funcs.py +312 -0
  35. dask_array/core/_conversion.py +422 -0
  36. dask_array/core/_from_graph.py +97 -0
  37. dask_array/creation/__init__.py +71 -0
  38. dask_array/creation/_arange.py +121 -0
  39. dask_array/creation/_diag.py +116 -0
  40. dask_array/creation/_diagonal.py +241 -0
  41. dask_array/creation/_eye.py +103 -0
  42. dask_array/creation/_linspace.py +102 -0
  43. dask_array/creation/_mesh.py +134 -0
  44. dask_array/creation/_ones_zeros.py +454 -0
  45. dask_array/creation/_pad.py +270 -0
  46. dask_array/creation/_repeat.py +55 -0
  47. dask_array/creation/_tile.py +36 -0
  48. dask_array/creation/_tri.py +28 -0
  49. dask_array/creation/_utils.py +296 -0
  50. dask_array/fft.py +320 -0
  51. dask_array/io/__init__.py +39 -0
  52. dask_array/io/_base.py +10 -0
  53. dask_array/io/_from_array.py +257 -0
  54. dask_array/io/_from_delayed.py +95 -0
  55. dask_array/io/_from_graph.py +54 -0
  56. dask_array/io/_from_npy_stack.py +67 -0
  57. dask_array/io/_store.py +336 -0
  58. dask_array/io/_tiledb.py +159 -0
  59. dask_array/io/_to_npy_stack.py +65 -0
  60. dask_array/io/_zarr.py +449 -0
  61. dask_array/linalg/__init__.py +39 -0
  62. dask_array/linalg/_cholesky.py +234 -0
  63. dask_array/linalg/_lu.py +300 -0
  64. dask_array/linalg/_norm.py +94 -0
  65. dask_array/linalg/_qr.py +601 -0
  66. dask_array/linalg/_solve.py +349 -0
  67. dask_array/linalg/_svd.py +394 -0
  68. dask_array/linalg/_tensordot.py +334 -0
  69. dask_array/linalg/_utils.py +74 -0
  70. dask_array/manipulation/__init__.py +45 -0
  71. dask_array/manipulation/_expand.py +321 -0
  72. dask_array/manipulation/_flip.py +92 -0
  73. dask_array/manipulation/_roll.py +78 -0
  74. dask_array/manipulation/_transpose.py +309 -0
  75. dask_array/random/__init__.py +125 -0
  76. dask_array/random/_choice.py +181 -0
  77. dask_array/random/_expr.py +256 -0
  78. dask_array/random/_generator.py +441 -0
  79. dask_array/random/_random_state.py +259 -0
  80. dask_array/random/_utils.py +84 -0
  81. dask_array/reductions/__init__.py +84 -0
  82. dask_array/reductions/_arg_reduction.py +130 -0
  83. dask_array/reductions/_common.py +1082 -0
  84. dask_array/reductions/_cumulative.py +522 -0
  85. dask_array/reductions/_percentile.py +261 -0
  86. dask_array/reductions/_reduction.py +725 -0
  87. dask_array/reductions/_trace.py +56 -0
  88. dask_array/routines/__init__.py +133 -0
  89. dask_array/routines/_apply.py +84 -0
  90. dask_array/routines/_bincount.py +112 -0
  91. dask_array/routines/_broadcast.py +111 -0
  92. dask_array/routines/_coarsen.py +115 -0
  93. dask_array/routines/_diff.py +79 -0
  94. dask_array/routines/_gradient.py +158 -0
  95. dask_array/routines/_indexing.py +65 -0
  96. dask_array/routines/_insert_delete.py +132 -0
  97. dask_array/routines/_misc.py +122 -0
  98. dask_array/routines/_nonzero.py +72 -0
  99. dask_array/routines/_search.py +123 -0
  100. dask_array/routines/_select.py +113 -0
  101. dask_array/routines/_statistics.py +171 -0
  102. dask_array/routines/_topk.py +82 -0
  103. dask_array/routines/_triangular.py +74 -0
  104. dask_array/routines/_unique.py +232 -0
  105. dask_array/routines/_where.py +62 -0
  106. dask_array/slicing/__init__.py +67 -0
  107. dask_array/slicing/_basic.py +550 -0
  108. dask_array/slicing/_blocks.py +138 -0
  109. dask_array/slicing/_bool_index.py +145 -0
  110. dask_array/slicing/_setitem.py +329 -0
  111. dask_array/slicing/_squeeze.py +101 -0
  112. dask_array/slicing/_utils.py +1133 -0
  113. dask_array/slicing/_vindex.py +282 -0
  114. dask_array/stacking/__init__.py +15 -0
  115. dask_array/stacking/_block.py +83 -0
  116. dask_array/stacking/_simple.py +58 -0
  117. dask_array/templates/array.html.j2 +48 -0
  118. dask_array/tests/__init__.py +0 -0
  119. dask_array/tests/conftest.py +22 -0
  120. dask_array/tests/test_api.py +40 -0
  121. dask_array/tests/test_binary_op_chunks.py +107 -0
  122. dask_array/tests/test_coarse_slice_through_blockwise.py +362 -0
  123. dask_array/tests/test_collection.py +799 -0
  124. dask_array/tests/test_creation.py +1102 -0
  125. dask_array/tests/test_expr_flow.py +143 -0
  126. dask_array/tests/test_linalg.py +1130 -0
  127. dask_array/tests/test_map_blocks_multi_output.py +104 -0
  128. dask_array/tests/test_rechunk_pushdown.py +214 -0
  129. dask_array/tests/test_reductions.py +1091 -0
  130. dask_array/tests/test_routines.py +2853 -0
  131. dask_array/tests/test_shuffle_chunks.py +67 -0
  132. dask_array/tests/test_slice_pushdown.py +968 -0
  133. dask_array/tests/test_slice_through_blockwise.py +678 -0
  134. dask_array/tests/test_slice_through_overlap.py +366 -0
  135. dask_array/tests/test_slice_through_reshape.py +272 -0
  136. dask_array/tests/test_slicing.py +839 -0
  137. dask_array/tests/test_transpose_slice_pushdown.py +208 -0
  138. dask_array/tests/test_visualize.py +94 -0
  139. dask_array/tests/test_xarray.py +193 -0
  140. dask_array-0.1.0.dist-info/METADATA +48 -0
  141. dask_array-0.1.0.dist-info/RECORD +144 -0
  142. dask_array-0.1.0.dist-info/WHEEL +4 -0
  143. dask_array-0.1.0.dist-info/entry_points.txt +2 -0
  144. dask_array-0.1.0.dist-info/licenses/LICENSE +29 -0
@@ -0,0 +1,270 @@
1
+ from __future__ import annotations
2
+
3
+ import numpy as np
4
+
5
+ from dask_array._collection import asarray, concatenate
6
+ from dask_array._utils import meta_from_array
7
+ from dask.utils import derived_from
8
+
9
+ from ._utils import (
10
+ expand_pad_value,
11
+ get_pad_shapes_chunks,
12
+ linear_ramp_chunk,
13
+ wrapped_pad_func,
14
+ )
15
+
16
+
17
def _pad_reuse_expr(array, pad_width, mode, **kwargs):
    """
    Helper function for padding boundaries with values in the array.

    Handles the cases where the padding is constructed from values in
    the array. Namely by reflecting them or tiling them to create periodic
    boundary constraints.

    Parameters
    ----------
    array : Array
        The array to pad.
    pad_width : tuple of (before, after) pairs
        Normalized pad widths, one pair per axis.
    mode : {"reflect", "symmetric", "wrap"}
        The padding mode.
    **kwargs
        ``reflect_type`` for the reflecting modes. Only ``"even"`` is
        supported.

    Returns
    -------
    Array
        The padded array, assembled with ``block`` from a 3 x 3 x ... grid
        of (before, body, after) pieces.

    Raises
    ------
    NotImplementedError
        If ``reflect_type="odd"`` is requested, or if a pad width exceeds the
        data available along its axis. numpy's repeated reflection/tiling
        for such widths is not implemented here.
    ValueError
        If ``reflect_type`` is neither ``"even"`` nor ``"odd"``.
    """
    if mode in {"reflect", "symmetric"}:
        # ``pad`` validates and forwards numpy's ``reflect_type`` keyword.
        # Reading any other key would silently ignore ``reflect_type="odd"``.
        reflect_type = kwargs.get("reflect_type", "even")
        if reflect_type == "odd":
            raise NotImplementedError("`pad` does not support `reflect_type` of `odd`.")
        if reflect_type != "even":
            raise ValueError("unsupported value for reflect_type, must be one of (`even`, `odd`)")

    # Each pad region below is a single slice of the array, so it cannot be
    # wider than the values available. "reflect" additionally skips the edge
    # value itself. Without this check, slicing would silently truncate the
    # pads and the result would have the wrong shape.
    edge_offset = 1 if mode == "reflect" else 0
    for axis, (size, widths) in enumerate(zip(array.shape, pad_width)):
        available = max(size - edge_offset, 0)
        if max(widths) > available:
            raise NotImplementedError(
                f"`pad` with mode '{mode}' supports pad widths of at most {available} "
                f"along axis {axis} (length {size}); got {tuple(widths)}."
            )

    from dask_array._collection import block
    from dask_array.manipulation._flip import flip

    result = np.empty(array.ndim * (3,), dtype=object)
    for idx in np.ndindex(result.shape):
        select = []
        flip_axes = []
        for axis, (i, s, pw) in enumerate(zip(idx, array.shape, pad_width)):
            if mode == "wrap":
                # Wrap fills the "before" pad from the end of the array and the
                # "after" pad from its start. Swap the widths here and the grid
                # position below.
                pw = pw[::-1]

            if i < 1:
                if mode == "reflect":
                    select.append(slice(1, pw[0] + 1, None))
                else:
                    select.append(slice(None, pw[0], None))
            elif i > 1:
                if mode == "reflect":
                    select.append(slice(s - pw[1] - 1, s - 1, None))
                else:
                    select.append(slice(s - pw[1], None, None))
            else:
                select.append(slice(None))

            if i != 1 and mode in ["reflect", "symmetric"]:
                flip_axes.append(axis)

        select = tuple(select)

        if mode == "wrap":
            idx = tuple(2 - i for i in idx)

        chunk = array[select]
        # Reflecting modes mirror every pad axis of this grid cell.
        for axis in flip_axes:
            chunk = flip(chunk, axis)
        result[idx] = chunk

    return block(result.tolist())
73
+
74
+
75
def _pad_stats_expr(array, pad_width, mode, stat_length):
    """
    Helper function for padding boundaries with statistics from the array.

    In cases where the padding requires computations of statistics from part
    or all of the array, this function helps compute those statistics as
    requested and then adds those statistics onto the boundaries of the array.

    Parameters
    ----------
    array : Array
        The array to pad.
    pad_width : tuple of (before, after) pairs
        Normalized pad widths, one pair per axis.
    mode : {"maximum", "mean", "minimum"}
        Statistic used to fill the pads. ``"median"`` raises.
    stat_length : int or sequence
        Number of edge values per side used for the statistic. It is
        normalized with ``expand_pad_value``. Lengths larger than the axis
        use the whole axis, as in numpy.

    Returns
    -------
    Array
        The padded array.

    Raises
    ------
    NotImplementedError
        If ``mode`` is ``"median"``.
    """
    from dask_array._collection import block, broadcast_to
    from dask_array._ufunc import rint

    if mode == "median":
        raise NotImplementedError("`pad` does not support `mode` of `median`.")

    stat_length = expand_pad_value(array, stat_length)
    stat_funcs = {"maximum": "max", "mean": "mean", "minimum": "min"}

    result = np.empty(array.ndim * (3,), dtype=object)
    for idx in np.ndindex(result.shape):
        axes = []
        select = []
        pad_shape = []
        pad_chunks = []
        for d, (i, s, c, w, length) in enumerate(
            zip(idx, array.shape, array.chunks, pad_width, stat_length)
        ):
            if i < 1:
                axes.append(d)
                select.append(slice(None, length[0], None))
                pad_shape.append(w[0])
                pad_chunks.append(w[0])
            elif i > 1:
                axes.append(d)
                # Clamp at 0. A stat_length longer than the axis means "whole
                # axis" (numpy clips it). A negative start would wrap around
                # and select only the last few elements.
                select.append(slice(max(s - length[1], 0), None, None))
                pad_shape.append(w[1])
                pad_chunks.append(w[1])
            else:
                select.append(slice(None))
                pad_shape.append(s)
                pad_chunks.append(c)

        axes = tuple(axes)
        select = tuple(select)
        pad_shape = tuple(pad_shape)
        pad_chunks = tuple(pad_chunks)

        result_idx = array[select]
        if axes:
            # Reduce over the padded axes, then stretch the statistic to the pad size.
            result_idx = getattr(result_idx, stat_funcs[mode])(axis=axes, keepdims=True)
            result_idx = broadcast_to(result_idx, pad_shape, chunks=pad_chunks)

            if mode == "mean":
                if np.issubdtype(array.dtype, np.integer):
                    result_idx = rint(result_idx)
                result_idx = result_idx.astype(array.dtype)

        result[idx] = result_idx

    return block(result.tolist())
134
+
135
+
136
def _pad_udf_expr(array, pad_width, mode, **kwargs):
    """
    Pad ``array`` by applying a user-supplied padding callable.

    The array is first zero-padded to its final shape. Then, for each axis in
    turn, the padded array is rechunked so that axis is one single chunk, and
    ``mode`` is applied to every 1-D vector along it through
    ``wrapped_pad_func``. Finally the chunking from after the zero-padding
    step is restored.

    Parameters
    ----------
    array : Array
        The array to pad.
    pad_width : tuple of (before, after) pairs
        Normalized pad widths, one pair per axis.
    mode : callable
        Called as ``mode(vector, iaxis_pad_width, iaxis, kwargs)`` for each
        vector along each axis.
    **kwargs
        Passed to ``mode`` as its ``kwargs`` argument.

    Returns
    -------
    Array
        The padded array.
    """
    padded = _pad_edge_expr(array, pad_width, "constant", constant_values=0)
    base_chunks = padded.chunks

    for axis in range(padded.ndim):
        # Merge this axis into one chunk so every vector along it is complete.
        whole_axis = (padded.shape[axis : axis + 1],)
        padded = padded.rechunk(base_chunks[:axis] + whole_axis + base_chunks[axis + 1 :])

        padded = padded.map_blocks(
            wrapped_pad_func,
            name="pad",
            dtype=padded.dtype,
            pad_func=mode,
            iaxis_pad_width=pad_width[axis],
            iaxis=axis,
            pad_func_kwargs=kwargs,
        )

        padded = padded.rechunk(base_chunks)

    return padded
164
+
165
+
166
def _pad_edge_expr(array, pad_width, mode, **kwargs):
    """
    Pad ``array`` with values taken from its edges or independent of its data.

    Handles the ``"constant"``, ``"edge"``, ``"linear_ramp"`` and ``"empty"``
    modes. Axes are padded one after another. For each axis a "before" and an
    "after" pad are built with the shapes/chunks given by
    ``get_pad_shapes_chunks``, then concatenated around the array padded so
    far.

    Parameters
    ----------
    array : Array
        The array to pad.
    pad_width : tuple of (before, after) pairs
        Normalized pad widths, one pair per axis.
    mode : str
        One of the modes listed above.
    **kwargs
        ``constant_values`` (for ``"constant"``) or ``end_values`` (for
        ``"linear_ramp"``). Each is expanded per axis with
        ``expand_pad_value``.

    Returns
    -------
    Array
        The padded array.
    """
    from dask_array._collection import broadcast_to
    from dask_array._utils import asarray_safe

    from ._ones_zeros import empty_like

    per_axis = {key: expand_pad_value(array, value) for key, value in kwargs.items()}

    padded = array
    for axis in range(array.ndim):
        shapes, chunks = get_pad_shapes_chunks(padded, pad_width, (axis,), mode=mode)

        if mode == "constant":
            fills = [
                asarray_safe(value, like=meta_from_array(array), dtype=padded.dtype)
                for value in per_axis["constant_values"][axis]
            ]
            pads = [broadcast_to(fill, shp, chk) for fill, shp, chk in zip(fills, shapes, chunks)]
        elif mode in ("edge", "linear_ramp"):
            lead = (slice(None),) * axis
            trail = (slice(None),) * (padded.ndim - axis - 1)
            # One-wide slabs holding the first and the last values along ``axis``.
            edges = [
                padded[lead + (slice(None, 1, None),) + trail],
                padded[lead + (slice(-1, None, None),) + trail],
            ]
            if mode == "edge":
                pads = [broadcast_to(edge, shp, chk) for edge, shp, chk in zip(edges, shapes, chunks)]
            else:
                # step -1 reverses the "before" ramp so it ends next to the data.
                pads = [
                    edge.map_blocks(
                        linear_ramp_chunk,
                        end_value,
                        width,
                        chunks=chk,
                        dtype=padded.dtype,
                        dim=axis,
                        step=step,
                    )
                    for step, edge, end_value, width, chk in zip(
                        (-1, 1), edges, per_axis["end_values"][axis], pad_width[axis], chunks
                    )
                ]
        elif mode == "empty":
            pads = [
                empty_like(array, shape=shp, dtype=array.dtype, chunks=chk)
                for shp, chk in zip(shapes, chunks)
            ]
        else:
            pads = [padded, padded]

        padded = concatenate([pads[0], padded, pads[1]], axis=axis)

    return padded
224
+
225
+
226
@derived_from(np)
def pad(array, pad_width, mode="constant", **kwargs):
    # The public docstring is generated from ``numpy.pad`` by ``derived_from``.
    # A docstring written here would be merged into it, so notes stay as comments.
    array = asarray(array)

    pad_width = expand_pad_value(array, pad_width)

    if callable(mode):
        return _pad_udf_expr(array, pad_width, mode, **kwargs)

    # Keyword arguments accepted for each string mode.
    allowed_kwargs = {
        "empty": [],
        "edge": [],
        "wrap": [],
        "constant": ["constant_values"],
        "linear_ramp": ["end_values"],
        "maximum": ["stat_length"],
        "mean": ["stat_length"],
        "median": ["stat_length"],
        "minimum": ["stat_length"],
        "reflect": ["reflect_type"],
        "symmetric": ["reflect_type"],
    }
    try:
        unsupported_kwargs = set(kwargs) - set(allowed_kwargs[mode])
    except (KeyError, TypeError) as e:
        # TypeError covers unhashable ``mode`` values such as lists.
        raise ValueError(f"mode '{mode}' is not supported") from e
    if unsupported_kwargs:
        raise ValueError(f"unsupported keyword arguments for mode '{mode}': {unsupported_kwargs}")

    if mode in {"maximum", "mean", "median", "minimum"}:
        # numpy's default ``stat_length=None`` (explicit or omitted) means
        # "use the whole axis" on both sides.
        stat_length = kwargs.get("stat_length")
        if stat_length is None:
            stat_length = tuple((n, n) for n in array.shape)
        return _pad_stats_expr(array, pad_width, mode, stat_length)
    elif mode == "constant":
        kwargs.setdefault("constant_values", 0)
        return _pad_edge_expr(array, pad_width, mode, **kwargs)
    elif mode == "linear_ramp":
        kwargs.setdefault("end_values", 0)
        return _pad_edge_expr(array, pad_width, mode, **kwargs)
    elif mode in {"edge", "empty"}:
        return _pad_edge_expr(array, pad_width, mode)
    elif mode in ["reflect", "symmetric", "wrap"]:
        return _pad_reuse_expr(array, pad_width, mode, **kwargs)

    raise RuntimeError("unreachable")
@@ -0,0 +1,55 @@
1
+ from __future__ import annotations
2
+
3
+ from numbers import Integral
4
+
5
+ import numpy as np
6
+ from tlz import sliding_window
7
+
8
+ from dask_array._collection import concatenate
9
+ from dask.utils import cached_cumsum, derived_from
10
+
11
+
12
@derived_from(np)
def repeat(a, repeats, axis=None):
    # Only 1-D input may omit ``axis``. It then means the single axis 0.
    if axis is None:
        if a.ndim != 1:
            raise NotImplementedError("Must supply an integer axis value")
        axis = 0

    if not isinstance(repeats, Integral):
        raise NotImplementedError("Only integer valued repeats supported")

    # Fold a negative axis into range; anything else outside [0, ndim) is an error.
    if -a.ndim <= axis < 0:
        axis += a.ndim
    elif not 0 <= axis <= a.ndim - 1:
        raise ValueError(f"axis(={axis}) out of bounds")

    if repeats == 0:
        empty_index = tuple(slice(0) if dim == axis else slice(None) for dim in range(a.ndim))
        return a[empty_index]
    if repeats == 1:
        return a

    # Split every chunk along ``axis`` at ``repeats`` evenly spaced (rounded)
    # positions, i.e. into at most ``repeats - 1`` slabs. This presumably keeps
    # each repeated output block close to the size of the input chunk.
    boundaries = cached_cumsum(a.chunks[axis], initial_zero=True)
    cuts = []
    for chunk_lo, chunk_hi in sliding_window(2, boundaries):
        marks = np.linspace(chunk_lo, chunk_hi, repeats).round(0)
        cuts.extend(slice(lo, hi) for lo, hi in sliding_window(2, marks) if lo != hi)

    full = slice(None, None, None)
    before_axis = (full,) * axis
    after_axis = (full,) * (a.ndim - axis - 1)

    pieces = []
    for cut in cuts:
        slab = a[before_axis + (cut,) + after_axis]
        out_chunks = list(slab.chunks)
        # Each slab lies inside one input chunk, hence one block along ``axis``.
        assert len(out_chunks[axis]) == 1
        out_chunks[axis] = (out_chunks[axis][0] * repeats,)
        pieces.append(
            slab.map_blocks(np.repeat, repeats, axis=axis, chunks=tuple(out_chunks), dtype=slab.dtype)
        )

    return concatenate(pieces, axis=axis)
@@ -0,0 +1,36 @@
1
+ from __future__ import annotations
2
+
3
+ import numpy as np
4
+
5
+ from dask_array._collection import asarray
6
+ from dask.utils import derived_from
7
+
8
+
9
@derived_from(np)
def tile(A, reps):
    from dask_array._collection import block

    from ._ones_zeros import empty

    # ``reps`` may be a single count or an iterable of per-axis counts.
    try:
        counts = tuple(reps)
    except TypeError:
        counts = (reps,)
    if any(n < 0 for n in counts):
        raise ValueError("Negative `reps` are not allowed.")
    arr = asarray(A)

    if all(counts):
        # Nest lists from the last axis outwards and let ``block`` glue them.
        nested = arr
        for n in reversed(counts):
            nested = [nested] * n
        return block(nested)

    # Some count is zero: the result is empty, so only its shape matters.
    # Left-pad whichever of counts/shape is shorter with ones, as numpy does.
    ndim_out = max(len(counts), arr.ndim)
    counts = (1,) * (ndim_out - len(counts)) + counts
    in_shape = (1,) * (ndim_out - arr.ndim) + arr.shape
    shape_out = tuple(size * n for size, n in zip(in_shape, counts))
    return empty(shape=shape_out, dtype=arr.dtype)
@@ -0,0 +1,28 @@
1
+ from __future__ import annotations
2
+
3
+ import numpy as np
4
+
5
+ from dask.utils import derived_from
6
+
7
+
8
@derived_from(np)
def tri(N, M=None, k=0, dtype=float, chunks="auto", *, like=None):
    from dask_array._core_utils import normalize_chunks
    from dask_array._ufunc import greater_equal

    from ._arange import arange

    n_cols = N if M is None else M

    row_chunks, col_chunks = normalize_chunks(chunks, shape=(N, n_cols), dtype=dtype)

    # Column vector of row indices i, shape (N, 1).
    row_index = arange(N, chunks=row_chunks[0], like=like).reshape(1, N).T
    # Row vector of shifted column indices j - k, shape (n_cols,).
    shifted_col_index = arange(-k, n_cols - k, chunks=col_chunks[0], like=like)

    # Broadcasting yields mask[i, j] = (i >= j - k).
    mask = greater_equal(row_index, shifted_col_index)

    # copy=False skips the copy when ``dtype`` is already bool.
    return mask.astype(dtype, copy=False)
@@ -0,0 +1,296 @@
1
+ """Helper functions for array creation operations.
2
+
3
+ Copied/adapted from dask.array.creation and dask.array.wrap to reduce
4
+ imports from dask.array.* modules.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ from collections.abc import Sequence
10
+ from numbers import Number
11
+
12
+ import numpy as np
13
+ from tlz import curry
14
+
15
+ from dask_array._core_utils import normalize_chunks
16
+ from dask.base import tokenize
17
+ from dask.utils import funcname
18
+
19
+
20
def _parse_wrap_args(func, args, kwargs, shape):
    """Normalize the arguments shared by the wrapped creation functions.

    ``name``, ``chunks`` and ``dtype`` are popped from ``kwargs`` in place.
    Whatever remains is returned as the ``kwargs`` entry.

    Parameters
    ----------
    func : callable
        The underlying array-producing function. When no ``dtype`` is given,
        it is called once as ``func(shape, *args, **kwargs)`` to infer one.
    args : tuple
        Positional arguments that follow ``shape``.
    kwargs : dict
        Keyword arguments; may contain ``name``, ``chunks`` and ``dtype``.
    shape : int, tuple, list or ndarray
        Requested shape. An ndarray is converted with ``tolist`` and a scalar
        is wrapped in a 1-tuple.

    Returns
    -------
    dict
        Keys ``shape``, ``dtype`` (an ``np.dtype``), ``kwargs``, ``chunks``
        (normalized) and ``name`` (generated from a token when not given).
    """
    if isinstance(shape, np.ndarray):
        shape = shape.tolist()
    shape = shape if isinstance(shape, (tuple, list)) else (shape,)

    name = kwargs.pop("name", None)
    chunks = kwargs.pop("chunks", "auto")
    dtype = kwargs.pop("dtype", None)

    # With no explicit dtype, ask ``func`` for one.
    dtype = np.dtype(func(shape, *args, **kwargs).dtype if dtype is None else dtype)

    chunks = normalize_chunks(chunks, shape, dtype=dtype)

    if not name:
        name = funcname(func) + "-" + tokenize(func, shape, chunks, dtype, args, kwargs)

    return dict(shape=shape, dtype=dtype, kwargs=kwargs, chunks=chunks, name=name)
63
+
64
+
65
@curry
def _broadcast_trick_inner(func, shape, meta=(), *args, **kwargs):
    """Call ``func`` for a tiny array and broadcast the result to ``shape``.

    ``func`` receives ``meta`` plus the remaining arguments, with ``shape``
    forced to ``()`` for a 0-d request and to ``1`` otherwise. The non-empty
    size is a cupy-specific workaround; numpy alone would accept ``()``
    everywhere.
    """
    tiny_shape = 1
    if shape == ():
        tiny_shape = ()
    seed = func(meta, *args, shape=tiny_shape, **kwargs)
    return np.broadcast_to(seed, shape)
73
+
74
+
75
def broadcast_trick(func):
    """Wrap a numpy ``*_like`` function so its output is a broadcast view.

    When an array is known to be uniform, it is enough to store a single
    value and broadcast it to the requested shape. Every element then refers
    to the same memory, so the array is tiny:

    >>> x = np.broadcast_to(1, (100, 100, 100))
    >>> x.strides
    (0, 0, 0)

    Such arrays are cheap to keep locally, and dask serialisation sees their
    real, small size, so they are cheap to move and schedule too. They are
    read-only (numpy refuses writes into them), which is safe because dask
    arrays are immutable.

    The returned callable carries ``func``'s ``__doc__`` and ``__name__``.
    """
    wrapped = _broadcast_trick_inner(func)
    for attribute in ("__doc__", "__name__"):
        setattr(wrapped, attribute, getattr(func, attribute))
    return wrapped
97
+
98
+
99
def _get_like_function_shapes_chunks(a, chunks, shape):
    """Resolve the shape and chunks for a ``*_like`` creation function.

    Parameters
    ----------
    a : dask array
        Template array providing ``shape`` and ``chunks``.
    chunks : tuple, str, int or None
        Requested chunks. ``None`` means "inherit": ``a.chunks`` when the
        shape is inherited too, otherwise ``"auto"``.
    shape : tuple or None
        Requested shape. ``None`` means use ``a.shape``.

    Returns
    -------
    shape, chunks : tuple
    """
    if shape is not None:
        # A different shape cannot reuse ``a``'s chunks; fall back to "auto".
        return shape, ("auto" if chunks is None else chunks)
    return a.shape, (a.chunks if chunks is None else chunks)
122
+
123
+
124
def expand_pad_value(array, pad_value):
    """Expand pad_value to a per-dimension format.

    Accepted layouts (``ndim`` is ``array.ndim``):

    * scalar or 0-d array ``v``      -> ``((v, v),) * ndim``
    * ``(v,)``                       -> ``((v, v),) * ndim``
    * ``(before, after)``            -> ``((before, after),) * ndim``
    * ``((before, after),)``         -> ``((before, after),) * ndim``
    * ``((b0, a0), (b1, a1), ...)``  -> one pair per dimension, as tuples

    NumPy arrays of these layouts are accepted as well, as in ``numpy.pad``.

    Parameters
    ----------
    array : dask array
        The array to be padded (only ``ndim`` is used).
    pad_value : various
        The pad value in one of the layouts above.

    Returns
    -------
    tuple of tuples
        Normalized pad_value as ((before, after), ...) for each dimension.

    Raises
    ------
    TypeError
        If ``pad_value`` matches none of the layouts above.
    """
    if isinstance(pad_value, np.ndarray) and pad_value.ndim > 0:
        # ndarray is not a ``collections.abc.Sequence``. Convert it so the
        # checks below apply; ``tolist`` also turns numpy scalars into Python
        # numbers. 0-d arrays are handled by the scalar branch.
        pad_value = pad_value.tolist()

    if isinstance(pad_value, Number) or getattr(pad_value, "ndim", None) == 0:
        pad_value = array.ndim * ((pad_value, pad_value),)
    elif isinstance(pad_value, Sequence) and all(isinstance(pw, Number) for pw in pad_value) and len(pad_value) == 1:
        pad_value = array.ndim * ((pad_value[0], pad_value[0]),)
    elif isinstance(pad_value, Sequence) and len(pad_value) == 2 and all(isinstance(pw, Number) for pw in pad_value):
        pad_value = array.ndim * (tuple(pad_value),)
    elif (
        isinstance(pad_value, Sequence)
        and len(pad_value) == array.ndim
        and all(isinstance(pw, Sequence) for pw in pad_value)
        and all((len(pw) == 2) for pw in pad_value)
        and all(all(isinstance(w, Number) for w in pw) for pw in pad_value)
    ):
        pad_value = tuple(tuple(pw) for pw in pad_value)
    elif (
        isinstance(pad_value, Sequence)
        and len(pad_value) == 1
        and isinstance(pad_value[0], Sequence)
        and len(pad_value[0]) == 2
        and all(isinstance(pw, Number) for pw in pad_value[0])
    ):
        pad_value = array.ndim * (tuple(pad_value[0]),)
    else:
        raise TypeError("`pad_value` must be composed of integral typed values.")

    return pad_value
165
+
166
+
167
def get_pad_shapes_chunks(array, pad_width, axes, mode):
    """Helper function for finding shapes and chunks of end pads.

    Along every axis in ``axes``, each pad gets its pad width as its size.
    All other axes keep ``array``'s shape and chunks. A pad is a single
    chunk, except for non-empty ``"constant"`` pads next to non-empty data.
    Those are split using the array's largest chunk size along that axis.

    Parameters
    ----------
    array : dask array
        The array to be padded
    pad_width : tuple of tuples
        The pad widths as ((before, after), ...) for each dimension
    axes : tuple of ints
        Which axes to compute pad info for
    mode : str
        The padding mode

    Returns
    -------
    pad_shapes : list of tuples
        Shape for [before, after] pads
    pad_chunks : list of tuples
        Chunks for [before, after] pads
    """
    pad_shapes = [list(array.shape), list(array.shape)]
    pad_chunks = [list(array.chunks), list(array.chunks)]

    for d in axes:
        for i in range(2):
            width = pad_width[d][i]
            pad_shapes[i][d] = width
            # Largest existing block along ``d``. It is 0 for a zero-length
            # axis, where normalize_chunks would divide by zero.
            block_size = max(pad_chunks[i][d], default=0)
            if mode != "constant" or width == 0 or block_size == 0:
                pad_chunks[i][d] = (width,)
            else:
                pad_chunks[i][d] = normalize_chunks((block_size,), (width,))[0]

    pad_shapes = [tuple(s) for s in pad_shapes]
    pad_chunks = [tuple(c) for c in pad_chunks]

    return pad_shapes, pad_chunks
203
+
204
+
205
def linear_ramp_chunk(start, stop, num, dim, step):
    """Build ``num`` linearly spaced values from each edge value towards ``stop``.

    For every position of ``start`` (which has length 1 along ``dim``),
    ``num + 1`` points are spaced evenly from that edge value to ``stop``.
    The first point (the edge value itself) is dropped. The remaining ``num``
    points are written along ``dim`` in order (``step=1``) or reversed
    (``step=-1``).

    Parameters
    ----------
    start : array
        Edge values; length 1 along ``dim``.
    stop : scalar
        Final value of the ramp.
    num : int
        Number of values produced along ``dim``.
    dim : int
        Axis along which the ramp runs.
    step : int
        ``1`` or ``-1``: slicing step applied to each ramp.

    Returns
    -------
    array
        Same type and dtype as ``start``, with length ``num`` along ``dim``.
    """
    dtype = np.dtype(start.dtype)

    out_shape = list(start.shape)
    out_shape[dim] = num
    result = np.empty_like(start, shape=tuple(out_shape), dtype=dtype)

    for position in np.ndindex(start.shape):
        # Select the whole line along ``dim`` passing through ``position``.
        line = list(position)
        line[dim] = slice(None)
        ramp = np.linspace(start[position], stop, num + 1, dtype=dtype)[1:]
        result[tuple(line)] = ramp[::step]

    return result
242
+
243
+
244
def wrapped_pad_func(array, pad_func, iaxis_pad_width, iaxis, pad_func_kwargs):
    """Wrapper to apply a user-defined pad function along an axis.

    Follows the callable form of ``numpy.pad``. ``pad_func`` is called once per
    1-D vector along ``iaxis`` as
    ``pad_func(vector, iaxis_pad_width, iaxis, pad_func_kwargs)``, where
    ``vector`` is a writable view into the result. The function may either
    fill the pad regions in place and return ``None`` (numpy's documented
    contract) or return the padded vector.

    The input ``array`` is never modified; all work happens on a copy.

    Parameters
    ----------
    array : ndarray
        The input array chunk, already sized to include the padding.
    pad_func : callable
        User-defined padding function
    iaxis_pad_width : tuple
        (before, after) pad widths for this axis
    iaxis : int
        The axis index
    pad_func_kwargs : dict
        Keyword arguments to pass to pad_func

    Returns
    -------
    array with padding applied
    """
    # Copy so that in-place pad functions cannot mutate the input block.
    result = array.copy()
    for index in np.ndindex(array.shape[:iaxis] + array.shape[iaxis + 1 :]):
        index = index[:iaxis] + (slice(None),) + index[iaxis:]
        vector = result[index]
        padded = pad_func(vector, iaxis_pad_width, iaxis, pad_func_kwargs)
        # numpy-style pad functions work in place and return None. Storing
        # that None would yield NaN (float) or raise TypeError (int).
        if padded is not None:
            result[index] = padded

    return result
270
+
271
+
272
def to_backend(x, backend: str | None = None, **kwargs):
    """Convert an Array collection to another array backend.

    Parameters
    ----------
    x : Array
        The collection to convert.
    backend : str, optional
        Name of the target backend. When omitted (or empty), the backend
        currently configured on ``array_creation_dispatch`` is used.
    **kwargs
        Forwarded to the backend entrypoint's ``to_backend``.

    Returns
    -------
    Array
        The collection converted by the selected backend's entrypoint.
    """
    from dask_array._backends_array import array_creation_dispatch

    target = backend if backend else array_creation_dispatch.backend
    # Resolving the entrypoint also checks that ``target`` is registered.
    entrypoint = array_creation_dispatch.dispatch(target)
    return entrypoint.to_backend(x, **kwargs)