dask-array 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (144):
  1. dask_array/__init__.py +228 -0
  2. dask_array/_backends.py +76 -0
  3. dask_array/_backends_array.py +99 -0
  4. dask_array/_blockwise.py +1410 -0
  5. dask_array/_broadcast.py +272 -0
  6. dask_array/_chunk.py +445 -0
  7. dask_array/_chunk_types.py +54 -0
  8. dask_array/_collection.py +1644 -0
  9. dask_array/_concatenate.py +331 -0
  10. dask_array/_core_utils.py +1365 -0
  11. dask_array/_dispatch.py +141 -0
  12. dask_array/_einsum.py +277 -0
  13. dask_array/_expr.py +544 -0
  14. dask_array/_expr_flow.py +586 -0
  15. dask_array/_gufunc.py +805 -0
  16. dask_array/_histogram.py +617 -0
  17. dask_array/_map_blocks.py +652 -0
  18. dask_array/_new_collection.py +10 -0
  19. dask_array/_numpy_compat.py +135 -0
  20. dask_array/_overlap.py +1159 -0
  21. dask_array/_rechunk.py +1050 -0
  22. dask_array/_reshape.py +710 -0
  23. dask_array/_routines.py +102 -0
  24. dask_array/_shuffle.py +448 -0
  25. dask_array/_stack.py +264 -0
  26. dask_array/_svg.py +291 -0
  27. dask_array/_templates.py +29 -0
  28. dask_array/_test_utils.py +257 -0
  29. dask_array/_ufunc.py +385 -0
  30. dask_array/_utils.py +349 -0
  31. dask_array/_visualize.py +223 -0
  32. dask_array/_xarray.py +337 -0
  33. dask_array/core/__init__.py +34 -0
  34. dask_array/core/_blockwise_funcs.py +312 -0
  35. dask_array/core/_conversion.py +422 -0
  36. dask_array/core/_from_graph.py +97 -0
  37. dask_array/creation/__init__.py +71 -0
  38. dask_array/creation/_arange.py +121 -0
  39. dask_array/creation/_diag.py +116 -0
  40. dask_array/creation/_diagonal.py +241 -0
  41. dask_array/creation/_eye.py +103 -0
  42. dask_array/creation/_linspace.py +102 -0
  43. dask_array/creation/_mesh.py +134 -0
  44. dask_array/creation/_ones_zeros.py +454 -0
  45. dask_array/creation/_pad.py +270 -0
  46. dask_array/creation/_repeat.py +55 -0
  47. dask_array/creation/_tile.py +36 -0
  48. dask_array/creation/_tri.py +28 -0
  49. dask_array/creation/_utils.py +296 -0
  50. dask_array/fft.py +320 -0
  51. dask_array/io/__init__.py +39 -0
  52. dask_array/io/_base.py +10 -0
  53. dask_array/io/_from_array.py +257 -0
  54. dask_array/io/_from_delayed.py +95 -0
  55. dask_array/io/_from_graph.py +54 -0
  56. dask_array/io/_from_npy_stack.py +67 -0
  57. dask_array/io/_store.py +336 -0
  58. dask_array/io/_tiledb.py +159 -0
  59. dask_array/io/_to_npy_stack.py +65 -0
  60. dask_array/io/_zarr.py +449 -0
  61. dask_array/linalg/__init__.py +39 -0
  62. dask_array/linalg/_cholesky.py +234 -0
  63. dask_array/linalg/_lu.py +300 -0
  64. dask_array/linalg/_norm.py +94 -0
  65. dask_array/linalg/_qr.py +601 -0
  66. dask_array/linalg/_solve.py +349 -0
  67. dask_array/linalg/_svd.py +394 -0
  68. dask_array/linalg/_tensordot.py +334 -0
  69. dask_array/linalg/_utils.py +74 -0
  70. dask_array/manipulation/__init__.py +45 -0
  71. dask_array/manipulation/_expand.py +321 -0
  72. dask_array/manipulation/_flip.py +92 -0
  73. dask_array/manipulation/_roll.py +78 -0
  74. dask_array/manipulation/_transpose.py +309 -0
  75. dask_array/random/__init__.py +125 -0
  76. dask_array/random/_choice.py +181 -0
  77. dask_array/random/_expr.py +256 -0
  78. dask_array/random/_generator.py +441 -0
  79. dask_array/random/_random_state.py +259 -0
  80. dask_array/random/_utils.py +84 -0
  81. dask_array/reductions/__init__.py +84 -0
  82. dask_array/reductions/_arg_reduction.py +130 -0
  83. dask_array/reductions/_common.py +1082 -0
  84. dask_array/reductions/_cumulative.py +522 -0
  85. dask_array/reductions/_percentile.py +261 -0
  86. dask_array/reductions/_reduction.py +725 -0
  87. dask_array/reductions/_trace.py +56 -0
  88. dask_array/routines/__init__.py +133 -0
  89. dask_array/routines/_apply.py +84 -0
  90. dask_array/routines/_bincount.py +112 -0
  91. dask_array/routines/_broadcast.py +111 -0
  92. dask_array/routines/_coarsen.py +115 -0
  93. dask_array/routines/_diff.py +79 -0
  94. dask_array/routines/_gradient.py +158 -0
  95. dask_array/routines/_indexing.py +65 -0
  96. dask_array/routines/_insert_delete.py +132 -0
  97. dask_array/routines/_misc.py +122 -0
  98. dask_array/routines/_nonzero.py +72 -0
  99. dask_array/routines/_search.py +123 -0
  100. dask_array/routines/_select.py +113 -0
  101. dask_array/routines/_statistics.py +171 -0
  102. dask_array/routines/_topk.py +82 -0
  103. dask_array/routines/_triangular.py +74 -0
  104. dask_array/routines/_unique.py +232 -0
  105. dask_array/routines/_where.py +62 -0
  106. dask_array/slicing/__init__.py +67 -0
  107. dask_array/slicing/_basic.py +550 -0
  108. dask_array/slicing/_blocks.py +138 -0
  109. dask_array/slicing/_bool_index.py +145 -0
  110. dask_array/slicing/_setitem.py +329 -0
  111. dask_array/slicing/_squeeze.py +101 -0
  112. dask_array/slicing/_utils.py +1133 -0
  113. dask_array/slicing/_vindex.py +282 -0
  114. dask_array/stacking/__init__.py +15 -0
  115. dask_array/stacking/_block.py +83 -0
  116. dask_array/stacking/_simple.py +58 -0
  117. dask_array/templates/array.html.j2 +48 -0
  118. dask_array/tests/__init__.py +0 -0
  119. dask_array/tests/conftest.py +22 -0
  120. dask_array/tests/test_api.py +40 -0
  121. dask_array/tests/test_binary_op_chunks.py +107 -0
  122. dask_array/tests/test_coarse_slice_through_blockwise.py +362 -0
  123. dask_array/tests/test_collection.py +799 -0
  124. dask_array/tests/test_creation.py +1102 -0
  125. dask_array/tests/test_expr_flow.py +143 -0
  126. dask_array/tests/test_linalg.py +1130 -0
  127. dask_array/tests/test_map_blocks_multi_output.py +104 -0
  128. dask_array/tests/test_rechunk_pushdown.py +214 -0
  129. dask_array/tests/test_reductions.py +1091 -0
  130. dask_array/tests/test_routines.py +2853 -0
  131. dask_array/tests/test_shuffle_chunks.py +67 -0
  132. dask_array/tests/test_slice_pushdown.py +968 -0
  133. dask_array/tests/test_slice_through_blockwise.py +678 -0
  134. dask_array/tests/test_slice_through_overlap.py +366 -0
  135. dask_array/tests/test_slice_through_reshape.py +272 -0
  136. dask_array/tests/test_slicing.py +839 -0
  137. dask_array/tests/test_transpose_slice_pushdown.py +208 -0
  138. dask_array/tests/test_visualize.py +94 -0
  139. dask_array/tests/test_xarray.py +193 -0
  140. dask_array-0.1.0.dist-info/METADATA +48 -0
  141. dask_array-0.1.0.dist-info/RECORD +144 -0
  142. dask_array-0.1.0.dist-info/WHEEL +4 -0
  143. dask_array-0.1.0.dist-info/entry_points.txt +2 -0
  144. dask_array-0.1.0.dist-info/licenses/LICENSE +29 -0
@@ -0,0 +1,74 @@
1
+ """Shared utilities for linear algebra operations."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import numpy as np
6
+
7
+
8
+ def _cumsum_blocks(it):
9
+ """Yield cumulative (start, end) pairs for block sizes."""
10
+ total = 0
11
+ for x in it:
12
+ total_previous = total
13
+ total += x
14
+ yield (total_previous, total)
15
+
16
+
17
+ def _nanmin(m, n):
18
+ """Return min(m, n), handling NaN values.
19
+
20
+ If either value is NaN, return the other value if it's not NaN,
21
+ otherwise return NaN.
22
+ """
23
+ k_0 = min([m, n])
24
+ k_1 = m if np.isnan(n) else n
25
+ return k_1 if np.isnan(k_0) else k_0
26
+
27
+
28
+ def _has_uncertain_chunks(chunks):
29
+ """Check if any chunk sizes are uncertain (NaN)."""
30
+ return any(np.isnan(c) for cs in chunks for c in cs)
31
+
32
+
33
+ def _cumsum_part(last, new):
34
+ """Compute (start, end) from previous cumsum and new value."""
35
+ return (last[1], last[1] + new)
36
+
37
+
38
+ def _get_block_size(block):
39
+ """Get min dimension from a block (used for R block sizes)."""
40
+ return min(block.shape)
41
+
42
+
43
+ def _make_slice(cumsum_pair, n):
44
+ """Create slice tuple from cumsum pair and column count."""
45
+ return (slice(cumsum_pair[0], cumsum_pair[1]), slice(0, n))
46
+
47
+
48
+ def _getitem_with_slice(arr, slc):
49
+ """Apply slice to array."""
50
+ return arr[slc]
51
+
52
+
53
+ def _get_n(arr):
54
+ """Get the number of columns from an array."""
55
+ return arr.shape[1]
56
+
57
+
58
def _solve_triangular_lower(a, b):
    """Solve a triangular system whose matrix ``a`` is lower triangular.

    Thin wrapper around ``dask_array._utils.solve_triangular_safe`` with
    ``lower=True``; the import is deferred until call time.
    """
    from dask_array._utils import solve_triangular_safe as _solve

    return _solve(a, b, lower=True)
63
+
64
+
65
def _solve_triangular_upper(a, b):
    """Solve a triangular system whose matrix ``a`` is upper triangular.

    Thin wrapper around ``dask_array._utils.solve_triangular_safe`` with
    ``lower=False``; the import is deferred until call time.
    """
    from dask_array._utils import solve_triangular_safe as _solve

    return _solve(a, b, lower=False)
70
+
71
+
72
+ def _transpose(x):
73
+ """Transpose a matrix (used for P_inv and U.T)."""
74
+ return x.T
@@ -0,0 +1,45 @@
1
+ """Array manipulation functions: flip, transpose, reshape, expand_dims, etc."""
2
+
3
+ # Import from module files
4
+ from dask_array.manipulation._expand import (
5
+ atleast_1d,
6
+ atleast_2d,
7
+ atleast_3d,
8
+ expand_dims,
9
+ )
10
+ from dask_array.manipulation._flip import flip, fliplr, flipud, rot90
11
+ from dask_array.manipulation._roll import roll
12
+ from dask_array.manipulation._transpose import (
13
+ moveaxis,
14
+ rollaxis,
15
+ swapaxes,
16
+ transpose,
17
+ )
18
+
19
+
20
def __getattr__(name):
    """Module-level attribute hook that lazily provides ``reshape`` and ``ravel``.

    Both names are fetched from ``dask_array._collection`` on first access,
    which avoids a circular import at package import time. Any other missing
    attribute raises ``AttributeError``.
    """
    if name not in ("reshape", "ravel"):
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

    from dask_array import _collection

    return getattr(_collection, name)
27
+
28
+
29
__all__ = [
    # Flipping and rotation
    "flip", "flipud", "fliplr", "rot90",
    # Axis reordering
    "swapaxes", "moveaxis", "rollaxis", "transpose",
    # Dimension expansion
    "expand_dims", "atleast_1d", "atleast_2d", "atleast_3d",
    # Element shifting
    "roll",
    # Resolved lazily by the module-level __getattr__
    "reshape", "ravel",
]
@@ -0,0 +1,321 @@
1
+ """Expand operations: expand_dims, atleast_1d, atleast_2d, atleast_3d."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import functools
6
+
7
+ import numpy as np
8
+
9
+ from dask_array._new_collection import new_collection
10
+ from dask._task_spec import Task, TaskRef
11
+ from dask_array._expr import ArrayExpr
12
+ from dask_array._chunk import getitem
13
+
14
+
15
class ExpandDims(ArrayExpr):
    """Dimension expansion expression.

    Adds new axes of size 1 at specified positions using numpy indexing with None.
    This is more efficient than reshape for dimension expansion and integrates
    better with slice pushdown optimizations.

    Parameters
    ----------
    array : ArrayExpr
        The input array expression.
    axes : tuple of int
        Positions where new axes should be inserted (in output coordinates).
        Expected to be non-negative and free of duplicates; ``expand_dims``
        passes validated, sorted axes.
    """

    _parameters = ["array", "axes"]

    @functools.cached_property
    def _meta(self):
        meta = self.array._meta
        # Inserting in ascending order keeps each output-coordinate position
        # valid at the moment it is inserted.
        for ax in sorted(self.axes):
            meta = np.expand_dims(meta, axis=ax)
        return meta

    @functools.cached_property
    def chunks(self):
        chunks = list(self.array.chunks)
        for ax in sorted(self.axes):
            chunks.insert(ax, (1,))
        return tuple(chunks)

    @functools.cached_property
    def _name(self):
        return f"expand-dims-{self.deterministic_token}"

    @functools.cached_property
    def _indexer(self):
        """Build indexer tuple with None at expansion axes."""
        out_ndim = self.array.ndim + len(self.axes)
        return tuple(None if i in self.axes else slice(None) for i in range(out_ndim))

    def _layer(self) -> dict:
        """Map every input block to one output block via ``getitem`` with None."""
        indexer = self._indexer
        axes = sorted(self.axes)
        input_name = self.array._name

        dsk = {}
        for block_id in np.ndindex(self.array.numblocks):
            in_key = (input_name,) + block_id
            # Insert 0 at expansion axes: ('x', 2, 3) -> ('expand', 0, 2, 0, 3)
            out_block_id = list(block_id)
            for ax in axes:
                out_block_id.insert(ax, 0)
            out_key = (self._name,) + tuple(out_block_id)
            dsk[out_key] = Task(out_key, getitem, TaskRef(in_key), indexer)

        return dsk

    def _simplify_up(self, parent, dependents):
        """Allow slice and shuffle operations to push through ExpandDims."""
        from dask_array._shuffle import Shuffle
        from dask_array.slicing import SliceSlicesIntegers

        if isinstance(parent, SliceSlicesIntegers):
            return self._accept_slice(parent)
        if isinstance(parent, Shuffle):
            return self._accept_shuffle(parent)
        return None

    def _accept_slice(self, slice_expr):
        """Accept a slice being pushed through ExpandDims.

        Splits the output index into the part that applies to the input array
        and the part that applies to the inserted size-1 axes. It then rebuilds
        ``ExpandDims`` on the sliced input, with the surviving expanded axes
        placed at their positions in the *sliced* output.

        Returns ``None`` (declines the rewrite) if:
        - the index has entries other than slices and integers;
        - it has more entries than there are dimensions;
        - it selects something on an expanded axis that is not handled here
          (out-of-range integers, empty or negative-step slices).
        """
        from numbers import Integral

        axes = set(self.axes)
        index = slice_expr.index

        # The positional mapping below assumes one index entry per output
        # dimension, so anything else (None, Ellipsis, arrays) is declined.
        if len(index) > self.ndim or not all(
            isinstance(idx, (slice, Integral)) for idx in index
        ):
            return None

        # Pad index to full output length
        full_index = index + (slice(None),) * (self.ndim - len(index))

        input_index = []
        new_axes = []
        # Position of the current dimension in the sliced output. Every integer
        # index drops a dimension -- on input axes as well as on expanded
        # axes -- so only non-integer indices advance it.
        out_pos = 0
        for i, idx in enumerate(full_index):
            if i not in axes:
                input_index.append(idx)
                if not isinstance(idx, Integral):
                    out_pos += 1
                continue

            # This is an expanded axis (size 1)
            if isinstance(idx, Integral):
                # Integer index on size-1 axis must be 0 or -1
                if idx not in (0, -1):
                    return None  # Out of bounds
                continue  # The axis is dropped from the output
            if idx != slice(None):
                # Slicing a size-1 axis - normalize
                start, stop, _ = idx.indices(1)
                if stop <= start:
                    return None  # Empty or negative-step selection; not handled
            new_axes.append(out_pos)
            out_pos += 1

        # Slice the input
        if all(isinstance(idx, slice) and idx == slice(None) for idx in input_index):
            sliced_input = self.array
        else:
            sliced_input = new_collection(self.array)[tuple(input_index)].expr

        if not new_axes:
            # All expansion axes were removed by integer indexing
            return sliced_input

        return ExpandDims(sliced_input, tuple(new_axes))

    def _accept_shuffle(self, shuffle_expr):
        """Accept a shuffle being pushed through ExpandDims.

        Maps shuffle axis through expansion to input axis.
        """
        from dask_array._shuffle import Shuffle

        axes = set(self.axes)
        shuffle_axis = shuffle_expr.axis
        # Normalize defensively so the comparisons with (non-negative) axes hold.
        if shuffle_axis < 0:
            shuffle_axis += self.ndim

        # Can't shuffle on an expanded axis (size 1, nothing to shuffle)
        if shuffle_axis in axes:
            return None

        # Map output axis to input axis (subtract number of expansion axes before it)
        input_axis = shuffle_axis - sum(1 for ax in axes if ax < shuffle_axis)

        shuffled_input = Shuffle(
            self.array,
            shuffle_expr.indexer,
            input_axis,
            shuffle_expr.operand("name"),
        )
        return ExpandDims(shuffled_input, self.axes)
169
+
170
+
171
def expand_dims(a, axis):
    """Expand the shape of an array.

    Insert a new axis that will appear at the axis position in the expanded
    array shape.

    Parameters
    ----------
    a : array_like
        Input array.
    axis : int or tuple of ints
        Position in the expanded axes where the new axis (or axes) is placed.

    Returns
    -------
    result : Array
        Array with the number of dimensions increased.

    Raises
    ------
    TypeError
        If ``axis`` is None.
    ValueError
        If ``axis`` names the same output position more than once (after
        negative axes are normalized), matching ``numpy.expand_dims``.

    See Also
    --------
    numpy.expand_dims
    """
    from dask_array._utils import validate_axis
    from dask_array.core import asanyarray

    if axis is None:
        raise TypeError("axis must be an integer, not None")

    if type(axis) not in (tuple, list):
        axis = (axis,)

    # Accept array_like input, as documented and as the atleast_* helpers do.
    a = asanyarray(a)

    out_ndim = len(axis) + a.ndim
    axis = validate_axis(axis, out_ndim)

    # Duplicate positions (e.g. (0, 0) or (0, -out_ndim)) would make ExpandDims
    # insert more size-1 chunks than its None-indexer produces.
    if len(set(axis)) != len(axis):
        raise ValueError("repeated axis")

    return new_collection(ExpandDims(a.expr, tuple(sorted(axis))))
205
+
206
+
207
def atleast_1d(*arys):
    """Convert inputs to arrays with at least one dimension.

    Parameters
    ----------
    arys : array_like
        One or more array-like sequences. Non-array inputs are converted
        to arrays. Arrays that already have one or more dimensions are
        preserved.

    Returns
    -------
    ret : Array or tuple of Arrays
        A single array when one input is given. Otherwise a tuple of arrays
        (a list on NumPy < 2.0), each with ``ndim >= 1``.

    See Also
    --------
    numpy.atleast_1d
    """
    from dask_array.core import asanyarray
    from dask_array._numpy_compat import NUMPY_GE_200

    def _promote(arr):
        # Scalars gain a leading length-1 axis; everything else passes through.
        arr = asanyarray(arr)
        return arr[None] if arr.ndim == 0 else arr

    promoted = [_promote(arr) for arr in arys]

    if len(promoted) == 1:
        return promoted[0]
    return tuple(promoted) if NUMPY_GE_200 else promoted
242
+
243
+
244
def atleast_2d(*arys):
    """View inputs as arrays with at least two dimensions.

    Parameters
    ----------
    arys : array_like
        One or more array-like sequences. Non-array inputs are converted
        to arrays. Arrays that already have two or more dimensions are
        preserved.

    Returns
    -------
    ret : Array or tuple of Arrays
        A single array when one input is given. Otherwise a tuple of arrays
        (a list on NumPy < 2.0), each with ``ndim >= 2``.

    See Also
    --------
    numpy.atleast_2d
    """
    from dask_array.core import asanyarray
    from dask_array._numpy_compat import NUMPY_GE_200

    def _promote(arr):
        arr = asanyarray(arr)
        if arr.ndim == 0:
            return arr[None, None]
        if arr.ndim == 1:
            # A 1-D array of length N becomes a single row of shape (1, N).
            return arr[None, :]
        return arr

    promoted = [_promote(arr) for arr in arys]

    if len(promoted) == 1:
        return promoted[0]
    return tuple(promoted) if NUMPY_GE_200 else promoted
281
+
282
+
283
def atleast_3d(*arys):
    """View inputs as arrays with at least three dimensions.

    Parameters
    ----------
    arys : array_like
        One or more array-like sequences. Non-array inputs are converted
        to arrays. Arrays that already have three or more dimensions are
        preserved.

    Returns
    -------
    ret : Array or tuple of Arrays
        A single array when one input is given. Otherwise a tuple of arrays
        (a list on NumPy < 2.0), each with ``ndim >= 3``.

    See Also
    --------
    numpy.atleast_3d
    """
    from dask_array.core import asanyarray
    from dask_array._numpy_compat import NUMPY_GE_200

    def _promote(arr):
        arr = asanyarray(arr)
        if arr.ndim == 0:
            return arr[None, None, None]
        if arr.ndim == 1:
            # Shape (N,) -> (1, N, 1), following numpy.atleast_3d.
            return arr[None, :, None]
        if arr.ndim == 2:
            # Shape (M, N) -> (M, N, 1).
            return arr[:, :, None]
        return arr

    promoted = [_promote(arr) for arr in arys]

    if len(promoted) == 1:
        return promoted[0]
    return tuple(promoted) if NUMPY_GE_200 else promoted
@@ -0,0 +1,92 @@
1
+ """Flip operations: flip, flipud, fliplr, rot90."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from collections.abc import Iterable
6
+
7
+ import numpy as np
8
+
9
+
10
def flip(m, axis=None):
    """Reverse element order along axis.

    ``axis`` may be None (reverse every axis), a single integer, or an
    iterable of integers. Out-of-range axes raise ``ValueError``, chained
    from the underlying ``IndexError``.

    See Also
    --------
    numpy.flip
    """
    from dask_array.core import asanyarray

    m = asanyarray(m)

    if axis is None:
        axis = range(m.ndim)
    if not isinstance(axis, Iterable):
        axis = (axis,)

    index = [slice(None)] * m.ndim
    reverse = slice(None, None, -1)
    try:
        for ax in axis:
            index[ax] = reverse
    except IndexError as e:
        raise ValueError(f"`axis` of {axis} invalid for {m.ndim}-D array") from e

    return m[tuple(index)]
34
+
35
+
36
def flipud(m):
    """Flip array in the up/down direction, i.e. reverse it along axis 0.

    See Also
    --------
    numpy.flipud
    """
    return flip(m, axis=0)
44
+
45
+
46
def fliplr(m):
    """Flip array in the left/right direction, i.e. reverse it along axis 1.

    See Also
    --------
    numpy.fliplr
    """
    return flip(m, axis=1)
54
+
55
+
56
def rot90(m, k=1, axes=(0, 1)):
    """Rotate an array by 90 degrees in the plane specified by axes.

    ``k`` is taken modulo 4. ``k == 0`` returns ``m[:]``. ``k == 2`` flips both
    axes. ``k`` of 1 or 3 combines a flip with a transpose that swaps the two
    axes.

    Raises ``ValueError`` when ``axes`` does not have exactly two entries,
    names the same axis twice, or is out of range for ``m.ndim``.

    See Also
    --------
    numpy.rot90
    """
    from dask_array.core import asanyarray
    from dask_array.manipulation._transpose import transpose

    axes = tuple(axes)
    if len(axes) != 2:
        raise ValueError("len(axes) must be 2.")
    first, second = axes

    m = asanyarray(m)
    ndim = m.ndim

    # Also catches a positive/negative pair naming the same axis, e.g. (0, -ndim).
    if first == second or np.absolute(first - second) == ndim:
        raise ValueError("Axes must be different.")

    if not (-ndim <= first < ndim and -ndim <= second < ndim):
        raise ValueError(f"Axes={axes} out of range for array of ndim={m.ndim}.")

    k %= 4

    if k == 0:
        return m[:]
    if k == 2:
        return flip(flip(m, first), second)

    # Identity permutation with the two rotation axes exchanged.
    perm = list(range(0, ndim))
    perm[first], perm[second] = perm[second], perm[first]

    if k == 1:
        return transpose(flip(m, second), perm)
    # k == 3
    return flip(transpose(m, perm), second)
@@ -0,0 +1,78 @@
1
+ """Roll operation."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from numbers import Integral
6
+
7
+
8
def roll(array, shift, axis=None):
    """Roll array elements along a given axis.

    Elements that roll beyond the last position are re-introduced at the first.

    Parameters
    ----------
    array : array_like
        Input array.
    shift : int or tuple of ints
        The number of places by which elements are shifted.
    axis : int or tuple of ints, optional
        Axis or axes along which elements are shifted. By default, the
        array is flattened before shifting, after which the original shape
        is restored.

    Returns
    -------
    result : Array
        Array with the same shape as array.

    Raises
    ------
    TypeError
        If ``axis`` is None and ``shift`` is not an integer.
    ValueError
        If the number of shifts differs from the number of axes.

    See Also
    --------
    numpy.roll
    """
    # Import here to avoid circular imports
    from dask_array._collection import concatenate, ravel

    def _as_sequence(value):
        # Anything without a len() is treated as a single scalar entry.
        try:
            len(value)
        except TypeError:
            return (value,)
        return value

    if axis is None:
        rolled = ravel(array)

        if not isinstance(shift, Integral):
            raise TypeError("Expect `shift` to be an instance of Integral when `axis` is None.")

        shifts, axes = (shift,), (0,)
    else:
        rolled = array
        shifts = _as_sequence(shift)
        axes = _as_sequence(axis)

    if len(shifts) != len(axes):
        raise ValueError("Must have the same number of shifts as axes.")

    for ax, amount in zip(axes, shifts):
        size = rolled.shape[ax]
        # Index at which the rolled output starts; 0 for empty axes.
        split = 0 if size == 0 else -amount % size

        head = [slice(None)] * rolled.ndim
        tail = [slice(None)] * rolled.ndim
        head[ax] = slice(split, None)
        tail[ax] = slice(None, split)

        rolled = concatenate([rolled[tuple(head)], rolled[tuple(tail)]], axis=ax)

    rolled = rolled.reshape(array.shape)

    # Ensure that the output is always a new array object
    if rolled is array:
        rolled = rolled.copy()

    return rolled