dask-array 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (144) hide show
  1. dask_array/__init__.py +228 -0
  2. dask_array/_backends.py +76 -0
  3. dask_array/_backends_array.py +99 -0
  4. dask_array/_blockwise.py +1410 -0
  5. dask_array/_broadcast.py +272 -0
  6. dask_array/_chunk.py +445 -0
  7. dask_array/_chunk_types.py +54 -0
  8. dask_array/_collection.py +1644 -0
  9. dask_array/_concatenate.py +331 -0
  10. dask_array/_core_utils.py +1365 -0
  11. dask_array/_dispatch.py +141 -0
  12. dask_array/_einsum.py +277 -0
  13. dask_array/_expr.py +544 -0
  14. dask_array/_expr_flow.py +586 -0
  15. dask_array/_gufunc.py +805 -0
  16. dask_array/_histogram.py +617 -0
  17. dask_array/_map_blocks.py +652 -0
  18. dask_array/_new_collection.py +10 -0
  19. dask_array/_numpy_compat.py +135 -0
  20. dask_array/_overlap.py +1159 -0
  21. dask_array/_rechunk.py +1050 -0
  22. dask_array/_reshape.py +710 -0
  23. dask_array/_routines.py +102 -0
  24. dask_array/_shuffle.py +448 -0
  25. dask_array/_stack.py +264 -0
  26. dask_array/_svg.py +291 -0
  27. dask_array/_templates.py +29 -0
  28. dask_array/_test_utils.py +257 -0
  29. dask_array/_ufunc.py +385 -0
  30. dask_array/_utils.py +349 -0
  31. dask_array/_visualize.py +223 -0
  32. dask_array/_xarray.py +337 -0
  33. dask_array/core/__init__.py +34 -0
  34. dask_array/core/_blockwise_funcs.py +312 -0
  35. dask_array/core/_conversion.py +422 -0
  36. dask_array/core/_from_graph.py +97 -0
  37. dask_array/creation/__init__.py +71 -0
  38. dask_array/creation/_arange.py +121 -0
  39. dask_array/creation/_diag.py +116 -0
  40. dask_array/creation/_diagonal.py +241 -0
  41. dask_array/creation/_eye.py +103 -0
  42. dask_array/creation/_linspace.py +102 -0
  43. dask_array/creation/_mesh.py +134 -0
  44. dask_array/creation/_ones_zeros.py +454 -0
  45. dask_array/creation/_pad.py +270 -0
  46. dask_array/creation/_repeat.py +55 -0
  47. dask_array/creation/_tile.py +36 -0
  48. dask_array/creation/_tri.py +28 -0
  49. dask_array/creation/_utils.py +296 -0
  50. dask_array/fft.py +320 -0
  51. dask_array/io/__init__.py +39 -0
  52. dask_array/io/_base.py +10 -0
  53. dask_array/io/_from_array.py +257 -0
  54. dask_array/io/_from_delayed.py +95 -0
  55. dask_array/io/_from_graph.py +54 -0
  56. dask_array/io/_from_npy_stack.py +67 -0
  57. dask_array/io/_store.py +336 -0
  58. dask_array/io/_tiledb.py +159 -0
  59. dask_array/io/_to_npy_stack.py +65 -0
  60. dask_array/io/_zarr.py +449 -0
  61. dask_array/linalg/__init__.py +39 -0
  62. dask_array/linalg/_cholesky.py +234 -0
  63. dask_array/linalg/_lu.py +300 -0
  64. dask_array/linalg/_norm.py +94 -0
  65. dask_array/linalg/_qr.py +601 -0
  66. dask_array/linalg/_solve.py +349 -0
  67. dask_array/linalg/_svd.py +394 -0
  68. dask_array/linalg/_tensordot.py +334 -0
  69. dask_array/linalg/_utils.py +74 -0
  70. dask_array/manipulation/__init__.py +45 -0
  71. dask_array/manipulation/_expand.py +321 -0
  72. dask_array/manipulation/_flip.py +92 -0
  73. dask_array/manipulation/_roll.py +78 -0
  74. dask_array/manipulation/_transpose.py +309 -0
  75. dask_array/random/__init__.py +125 -0
  76. dask_array/random/_choice.py +181 -0
  77. dask_array/random/_expr.py +256 -0
  78. dask_array/random/_generator.py +441 -0
  79. dask_array/random/_random_state.py +259 -0
  80. dask_array/random/_utils.py +84 -0
  81. dask_array/reductions/__init__.py +84 -0
  82. dask_array/reductions/_arg_reduction.py +130 -0
  83. dask_array/reductions/_common.py +1082 -0
  84. dask_array/reductions/_cumulative.py +522 -0
  85. dask_array/reductions/_percentile.py +261 -0
  86. dask_array/reductions/_reduction.py +725 -0
  87. dask_array/reductions/_trace.py +56 -0
  88. dask_array/routines/__init__.py +133 -0
  89. dask_array/routines/_apply.py +84 -0
  90. dask_array/routines/_bincount.py +112 -0
  91. dask_array/routines/_broadcast.py +111 -0
  92. dask_array/routines/_coarsen.py +115 -0
  93. dask_array/routines/_diff.py +79 -0
  94. dask_array/routines/_gradient.py +158 -0
  95. dask_array/routines/_indexing.py +65 -0
  96. dask_array/routines/_insert_delete.py +132 -0
  97. dask_array/routines/_misc.py +122 -0
  98. dask_array/routines/_nonzero.py +72 -0
  99. dask_array/routines/_search.py +123 -0
  100. dask_array/routines/_select.py +113 -0
  101. dask_array/routines/_statistics.py +171 -0
  102. dask_array/routines/_topk.py +82 -0
  103. dask_array/routines/_triangular.py +74 -0
  104. dask_array/routines/_unique.py +232 -0
  105. dask_array/routines/_where.py +62 -0
  106. dask_array/slicing/__init__.py +67 -0
  107. dask_array/slicing/_basic.py +550 -0
  108. dask_array/slicing/_blocks.py +138 -0
  109. dask_array/slicing/_bool_index.py +145 -0
  110. dask_array/slicing/_setitem.py +329 -0
  111. dask_array/slicing/_squeeze.py +101 -0
  112. dask_array/slicing/_utils.py +1133 -0
  113. dask_array/slicing/_vindex.py +282 -0
  114. dask_array/stacking/__init__.py +15 -0
  115. dask_array/stacking/_block.py +83 -0
  116. dask_array/stacking/_simple.py +58 -0
  117. dask_array/templates/array.html.j2 +48 -0
  118. dask_array/tests/__init__.py +0 -0
  119. dask_array/tests/conftest.py +22 -0
  120. dask_array/tests/test_api.py +40 -0
  121. dask_array/tests/test_binary_op_chunks.py +107 -0
  122. dask_array/tests/test_coarse_slice_through_blockwise.py +362 -0
  123. dask_array/tests/test_collection.py +799 -0
  124. dask_array/tests/test_creation.py +1102 -0
  125. dask_array/tests/test_expr_flow.py +143 -0
  126. dask_array/tests/test_linalg.py +1130 -0
  127. dask_array/tests/test_map_blocks_multi_output.py +104 -0
  128. dask_array/tests/test_rechunk_pushdown.py +214 -0
  129. dask_array/tests/test_reductions.py +1091 -0
  130. dask_array/tests/test_routines.py +2853 -0
  131. dask_array/tests/test_shuffle_chunks.py +67 -0
  132. dask_array/tests/test_slice_pushdown.py +968 -0
  133. dask_array/tests/test_slice_through_blockwise.py +678 -0
  134. dask_array/tests/test_slice_through_overlap.py +366 -0
  135. dask_array/tests/test_slice_through_reshape.py +272 -0
  136. dask_array/tests/test_slicing.py +839 -0
  137. dask_array/tests/test_transpose_slice_pushdown.py +208 -0
  138. dask_array/tests/test_visualize.py +94 -0
  139. dask_array/tests/test_xarray.py +193 -0
  140. dask_array-0.1.0.dist-info/METADATA +48 -0
  141. dask_array-0.1.0.dist-info/RECORD +144 -0
  142. dask_array-0.1.0.dist-info/WHEEL +4 -0
  143. dask_array-0.1.0.dist-info/entry_points.txt +2 -0
  144. dask_array-0.1.0.dist-info/licenses/LICENSE +29 -0
@@ -0,0 +1,331 @@
1
+ """Concatenate operation - expression and collection function."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import functools
6
+ from bisect import bisect
7
+ from itertools import product
8
+ from operator import add
9
+
10
+ import numpy as np
11
+ from tlz import accumulate
12
+ from toolz import concat
13
+
14
+ from dask._task_spec import Alias, List, Task, TaskRef
15
+ from dask_array._expr import ArrayExpr, unify_chunks_expr
16
+ from dask_array._core_utils import concatenate3
17
+ from dask_array._dispatch import concatenate_lookup
18
+ from dask_array._utils import meta_from_array
19
+
20
+
21
class Concatenate(ArrayExpr):
    """Expression that joins several array expressions along an existing axis.

    The first input is held in the ``array`` parameter; every additional input
    is stored as an extra operand after the declared parameters (see ``args``).
    Each output block is an alias of exactly one input block, so the layer
    performs no computation of its own.
    """

    _parameters = ["array", "axis", "meta"]

    @functools.cached_property
    def args(self):
        """All input expressions, in concatenation order."""
        return [self.array] + self.operands[len(self._parameters) :]

    @functools.cached_property
    def _meta(self):
        """Meta object supplied by the caller through the ``meta`` operand."""
        return self.operand("meta")

    @functools.cached_property
    def chunks(self):
        """Chunks of the first input, with all inputs' chunks joined along ``axis``."""
        bds = [a.chunks for a in self.args]
        joined = sum((bd[self.axis] for bd in bds), ())
        return bds[0][: self.axis] + (joined,) + bds[0][self.axis + 1 :]

    @functools.cached_property
    def _name(self):
        # NOTE(review): the prefix reads "stack-" although this expression is a
        # concatenation. Left untouched because changing it alters graph key
        # names - confirm whether it is intentional.
        return "stack-" + self.deterministic_token

    def _layer(self) -> dict:
        """Build the graph layer: one ``Alias`` per output block.

        Returns
        -------
        dict
            Maps every output key to an ``Alias`` pointing at the input block
            that supplies it.
        """
        axis = self.axis
        # cum_dims[i] is the index of the first output block (along ``axis``)
        # that comes from input ``i``.
        cum_dims = [0] + list(accumulate(add, [len(a.chunks[axis]) for a in self.args]))
        names = [a.name for a in self.args]

        dsk = {}
        for key in product([self._name], *[range(len(bd)) for bd in self.chunks]):
            # key[0] is the name, so the block index for ``axis`` sits at axis + 1.
            block = key[axis + 1]
            # Locate the owning input once and reuse it for both the name and
            # the offset of the block within that input.
            src = bisect(cum_dims, block) - 1
            source_key = (names[src],) + key[1 : axis + 1] + (block - cum_dims[src],) + key[axis + 2 :]
            dsk[key] = Alias(key, source_key)

        return dsk

    def _simplify_up(self, parent, dependents):
        """Allow slice and shuffle operations to push through Concatenate.

        Returns the rewritten expression, or ``None`` when ``parent`` is not a
        supported operation or the pushdown is declined.
        """
        from dask_array._shuffle import Shuffle
        from dask_array.slicing import SliceSlicesIntegers

        if isinstance(parent, SliceSlicesIntegers):
            return self._accept_slice(parent)
        if isinstance(parent, Shuffle):
            return self._accept_shuffle(parent)
        return None

    def _accept_shuffle(self, shuffle_expr):
        """Accept a shuffle being pushed through Concatenate.

        Can only push through if not shuffling on the concat axis. Returns a
        new expression of the same type whose inputs are the shuffled inputs,
        or ``None`` when the shuffle is on the concat axis.
        """
        from dask_array._shuffle import Shuffle

        concat_axis = self.axis
        shuffle_axis = shuffle_expr.axis

        # Can't shuffle on concat axis (would split indices across arrays)
        if shuffle_axis == concat_axis:
            return None

        # Apply the same shuffle to every input
        shuffled_arrays = [
            Shuffle(a, shuffle_expr.indexer, shuffle_axis, shuffle_expr.operand("name")) for a in self.args
        ]

        return type(self)(
            shuffled_arrays[0],
            concat_axis,
            self._meta,
            *shuffled_arrays[1:],
        )

    def _accept_slice(self, slice_expr):
        """Accept a slice being pushed through Concatenate.

        Cases:
        1. Slice on concat axis: select/trim relevant arrays
        2. Slice on other axis: push to all inputs

        Returns the rewritten expression, or ``None`` to decline. The pushdown
        is declined when the index contains integers or ``None``, when the
        entry for the concat axis is not a unit-step slice, when any input has
        an unknown (NaN) extent along the concat axis, or when no input
        overlaps the requested range.
        """
        from numbers import Integral

        from dask_array._new_collection import new_collection

        axis = self.axis
        arrays = self.args
        index = slice_expr.index

        # Pad index to full length
        full_index = index + (slice(None),) * (self.ndim - len(index))

        # For now, only handle simple slices (no integers that reduce dims)
        if any(isinstance(idx, Integral) for idx in full_index):
            return None
        if any(idx is None for idx in full_index):
            return None

        axis_slice = full_index[axis]
        if not isinstance(axis_slice, slice):
            return None

        sizes = [a.shape[axis] for a in arrays]
        # slice.indices() requires an integer length; an unknown (NaN) extent
        # would make it raise TypeError, so decline the pushdown instead.
        if any(np.isnan(size) for size in sizes):
            return None

        # Normalize the axis slice against the total concatenated length
        start, stop, step = axis_slice.indices(sum(sizes))
        if step != 1:
            return None  # Don't handle non-unit steps

        # Build sliced arrays, keeping only inputs that overlap [start, stop)
        sliced_arrays = []
        offset = 0
        for arr, size in zip(arrays, sizes):
            arr_start, arr_end = offset, offset + size
            offset = arr_end

            overlap_start = max(start, arr_start)
            overlap_end = min(stop, arr_end)
            if overlap_end <= overlap_start:
                continue

            # Same index on every other axis; input-local range on the concat axis
            arr_slices = list(full_index)
            arr_slices[axis] = slice(overlap_start - arr_start, overlap_end - arr_start)
            sliced_arrays.append(new_collection(arr)[tuple(arr_slices)].expr)

        if not sliced_arrays:
            # Empty selection along the concat axis - leave the expression as is
            return None

        if len(sliced_arrays) == 1:
            # Only one array needed - just return it (already sliced)
            return sliced_arrays[0]

        # Multiple arrays - create new Concatenate
        return type(self)(
            sliced_arrays[0],
            axis,
            self._meta,
            *sliced_arrays[1:],
        )
173
+
174
+
175
class ConcatenateFinalize(ArrayExpr):
    """Collapse every block of ``arr`` into a single output block.

    Meant for arrays with unknown chunk sizes, where rechunking
    is not possible.
    """

    _parameters = ["arr"]

    @functools.cached_property
    def _name(self):
        return f"concatenate-finalize-{self.deterministic_token}"

    @functools.cached_property
    def _meta(self):
        # The meta is taken unchanged from the wrapped expression.
        return self.arr._meta

    @functools.cached_property
    def chunks(self):
        # A single chunk of unknown extent along every dimension.
        return ((np.nan,),) * self.arr.ndim

    @functools.cached_property
    def numblocks(self):
        # Exactly one block per dimension.
        return (1,) * self.arr.ndim

    @functools.cached_property
    def _cached_keys(self):
        return List(TaskRef((self._name,) + (0,) * self.arr.ndim))

    def _layer(self) -> dict:
        """Return a one-task layer that runs ``concatenate3`` over all input blocks."""

        def to_refs(node):
            # Mirror the nested key lists as nested List containers of TaskRefs.
            if not isinstance(node, list):
                return TaskRef(node)
            return List(*(to_refs(child) for child in node))

        target = (self._name,) + (0,) * self.arr.ndim
        nested_refs = to_refs(self.arr.__dask_keys__())
        return {target: Task(target, concatenate3, nested_refs)}
219
+
220
+
221
def concatenate(seq, axis=0, allow_unknown_chunksizes=False):
    """
    Concatenate arrays along an existing axis

    Given a sequence of dask Arrays form a new dask Array by stacking them
    along an existing dimension (axis=0 by default)

    Parameters
    ----------
    seq: list of dask.arrays
    axis: int
        Dimension along which to align all of the arrays. If axis is None,
        arrays are flattened before use.
    allow_unknown_chunksizes: bool
        Allow unknown chunksizes, such as come from converting from dask
        dataframes. Dask.array is unable to verify that chunks line up. If
        data comes from differently aligned sources then this can cause
        unexpected results.

    Returns
    -------
    The concatenated array collection. When only one input is non-empty
    (or there is a single input), that input is returned after conversion
    to the common dtype.

    Raises
    ------
    ValueError
        If ``seq`` is empty, if ``axis`` is outside ``[-ndim, ndim)``, or if
        the shapes differ on a non-concatenation axis while
        ``allow_unknown_chunksizes`` is False.

    Examples
    --------

    Create slices

    >>> import dask_array as da
    >>> import numpy as np

    >>> data = [da.from_array(np.ones((4, 4)), chunks=(2, 2))
    ...         for i in range(3)]

    >>> x = da.concatenate(data, axis=0)
    >>> x.shape
    (12, 4)

    >>> da.concatenate(data, axis=1).shape
    (4, 12)

    Result is a new dask Array

    See Also
    --------
    stack
    """
    # Lazy import to avoid circular dependency
    from dask_array._new_collection import new_collection
    from dask_array.core import asarray

    seq = [asarray(a, allow_unknown_chunksizes=allow_unknown_chunksizes) for a in seq]

    if not seq:
        raise ValueError("Need array(s) to concatenate")

    if axis is None:
        seq = [a.flatten() for a in seq]
        axis = 0

    # Concatenate the metas with the implementation registered for the
    # highest-priority meta type to obtain the output meta and dtype.
    seq_metas = [meta_from_array(s) for s in seq]
    _concatenate = concatenate_lookup.dispatch(type(max(seq_metas, key=lambda x: getattr(x, "__array_priority__", 0))))
    meta = _concatenate(seq_metas, axis=axis)

    # Promote types to match meta
    seq = [a.astype(meta.dtype) for a in seq]

    ndim = len(seq[0].shape)

    # Drop empty arrays; if every input is empty keep them all so that at
    # least one array always remains.
    seq2 = [a for a in seq if a.size]
    if not seq2:
        seq2 = seq

    # Normalize a negative axis, then reject anything still out of range
    # (including a negative axis below -ndim).
    requested_axis = axis
    if axis < 0:
        axis = ndim + axis
    if not 0 <= axis < ndim:
        msg = "Axis must be less than number of dimensions\nData has %d dimensions, but got axis=%d"
        raise ValueError(msg % (ndim, requested_axis))

    n = len(seq2)
    if n == 1:
        return seq2[0]

    if not allow_unknown_chunksizes and not all(
        i == axis or all(x.shape[i] == seq2[0].shape[i] for x in seq2) for i in range(ndim)
    ):
        if any(map(np.isnan, seq2[0].shape)):
            raise ValueError(
                "Tried to concatenate arrays with unknown"
                f" shape {seq2[0].shape}.\n\nTwo solutions:\n"
                "  1. Force concatenation pass"
                " allow_unknown_chunksizes=True.\n"
                "  2. Compute shapes with "
                "[x.compute_chunk_sizes() for x in seq]"
            )
        # Interpolate the shapes into the message instead of passing them as
        # a second exception argument.
        raise ValueError(f"Shapes do not align: {[x.shape for x in seq2]}")

    # Give each input a distinct (negative) index label on the concat axis so
    # chunk unification only aligns the remaining, shared axes.
    inds = [list(range(ndim)) for _ in range(n)]
    for i, ind in enumerate(inds):
        ind[axis] = -(i + 1)

    exprs = [s.expr for s in seq2]
    uc_args = list(concat(zip(exprs, inds)))
    _, unified, _ = unify_chunks_expr(*uc_args, warn=False)
    return new_collection(Concatenate(unified[0], axis, meta, *unified[1:]))