dask-array 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (144) hide show
  1. dask_array/__init__.py +228 -0
  2. dask_array/_backends.py +76 -0
  3. dask_array/_backends_array.py +99 -0
  4. dask_array/_blockwise.py +1410 -0
  5. dask_array/_broadcast.py +272 -0
  6. dask_array/_chunk.py +445 -0
  7. dask_array/_chunk_types.py +54 -0
  8. dask_array/_collection.py +1644 -0
  9. dask_array/_concatenate.py +331 -0
  10. dask_array/_core_utils.py +1365 -0
  11. dask_array/_dispatch.py +141 -0
  12. dask_array/_einsum.py +277 -0
  13. dask_array/_expr.py +544 -0
  14. dask_array/_expr_flow.py +586 -0
  15. dask_array/_gufunc.py +805 -0
  16. dask_array/_histogram.py +617 -0
  17. dask_array/_map_blocks.py +652 -0
  18. dask_array/_new_collection.py +10 -0
  19. dask_array/_numpy_compat.py +135 -0
  20. dask_array/_overlap.py +1159 -0
  21. dask_array/_rechunk.py +1050 -0
  22. dask_array/_reshape.py +710 -0
  23. dask_array/_routines.py +102 -0
  24. dask_array/_shuffle.py +448 -0
  25. dask_array/_stack.py +264 -0
  26. dask_array/_svg.py +291 -0
  27. dask_array/_templates.py +29 -0
  28. dask_array/_test_utils.py +257 -0
  29. dask_array/_ufunc.py +385 -0
  30. dask_array/_utils.py +349 -0
  31. dask_array/_visualize.py +223 -0
  32. dask_array/_xarray.py +337 -0
  33. dask_array/core/__init__.py +34 -0
  34. dask_array/core/_blockwise_funcs.py +312 -0
  35. dask_array/core/_conversion.py +422 -0
  36. dask_array/core/_from_graph.py +97 -0
  37. dask_array/creation/__init__.py +71 -0
  38. dask_array/creation/_arange.py +121 -0
  39. dask_array/creation/_diag.py +116 -0
  40. dask_array/creation/_diagonal.py +241 -0
  41. dask_array/creation/_eye.py +103 -0
  42. dask_array/creation/_linspace.py +102 -0
  43. dask_array/creation/_mesh.py +134 -0
  44. dask_array/creation/_ones_zeros.py +454 -0
  45. dask_array/creation/_pad.py +270 -0
  46. dask_array/creation/_repeat.py +55 -0
  47. dask_array/creation/_tile.py +36 -0
  48. dask_array/creation/_tri.py +28 -0
  49. dask_array/creation/_utils.py +296 -0
  50. dask_array/fft.py +320 -0
  51. dask_array/io/__init__.py +39 -0
  52. dask_array/io/_base.py +10 -0
  53. dask_array/io/_from_array.py +257 -0
  54. dask_array/io/_from_delayed.py +95 -0
  55. dask_array/io/_from_graph.py +54 -0
  56. dask_array/io/_from_npy_stack.py +67 -0
  57. dask_array/io/_store.py +336 -0
  58. dask_array/io/_tiledb.py +159 -0
  59. dask_array/io/_to_npy_stack.py +65 -0
  60. dask_array/io/_zarr.py +449 -0
  61. dask_array/linalg/__init__.py +39 -0
  62. dask_array/linalg/_cholesky.py +234 -0
  63. dask_array/linalg/_lu.py +300 -0
  64. dask_array/linalg/_norm.py +94 -0
  65. dask_array/linalg/_qr.py +601 -0
  66. dask_array/linalg/_solve.py +349 -0
  67. dask_array/linalg/_svd.py +394 -0
  68. dask_array/linalg/_tensordot.py +334 -0
  69. dask_array/linalg/_utils.py +74 -0
  70. dask_array/manipulation/__init__.py +45 -0
  71. dask_array/manipulation/_expand.py +321 -0
  72. dask_array/manipulation/_flip.py +92 -0
  73. dask_array/manipulation/_roll.py +78 -0
  74. dask_array/manipulation/_transpose.py +309 -0
  75. dask_array/random/__init__.py +125 -0
  76. dask_array/random/_choice.py +181 -0
  77. dask_array/random/_expr.py +256 -0
  78. dask_array/random/_generator.py +441 -0
  79. dask_array/random/_random_state.py +259 -0
  80. dask_array/random/_utils.py +84 -0
  81. dask_array/reductions/__init__.py +84 -0
  82. dask_array/reductions/_arg_reduction.py +130 -0
  83. dask_array/reductions/_common.py +1082 -0
  84. dask_array/reductions/_cumulative.py +522 -0
  85. dask_array/reductions/_percentile.py +261 -0
  86. dask_array/reductions/_reduction.py +725 -0
  87. dask_array/reductions/_trace.py +56 -0
  88. dask_array/routines/__init__.py +133 -0
  89. dask_array/routines/_apply.py +84 -0
  90. dask_array/routines/_bincount.py +112 -0
  91. dask_array/routines/_broadcast.py +111 -0
  92. dask_array/routines/_coarsen.py +115 -0
  93. dask_array/routines/_diff.py +79 -0
  94. dask_array/routines/_gradient.py +158 -0
  95. dask_array/routines/_indexing.py +65 -0
  96. dask_array/routines/_insert_delete.py +132 -0
  97. dask_array/routines/_misc.py +122 -0
  98. dask_array/routines/_nonzero.py +72 -0
  99. dask_array/routines/_search.py +123 -0
  100. dask_array/routines/_select.py +113 -0
  101. dask_array/routines/_statistics.py +171 -0
  102. dask_array/routines/_topk.py +82 -0
  103. dask_array/routines/_triangular.py +74 -0
  104. dask_array/routines/_unique.py +232 -0
  105. dask_array/routines/_where.py +62 -0
  106. dask_array/slicing/__init__.py +67 -0
  107. dask_array/slicing/_basic.py +550 -0
  108. dask_array/slicing/_blocks.py +138 -0
  109. dask_array/slicing/_bool_index.py +145 -0
  110. dask_array/slicing/_setitem.py +329 -0
  111. dask_array/slicing/_squeeze.py +101 -0
  112. dask_array/slicing/_utils.py +1133 -0
  113. dask_array/slicing/_vindex.py +282 -0
  114. dask_array/stacking/__init__.py +15 -0
  115. dask_array/stacking/_block.py +83 -0
  116. dask_array/stacking/_simple.py +58 -0
  117. dask_array/templates/array.html.j2 +48 -0
  118. dask_array/tests/__init__.py +0 -0
  119. dask_array/tests/conftest.py +22 -0
  120. dask_array/tests/test_api.py +40 -0
  121. dask_array/tests/test_binary_op_chunks.py +107 -0
  122. dask_array/tests/test_coarse_slice_through_blockwise.py +362 -0
  123. dask_array/tests/test_collection.py +799 -0
  124. dask_array/tests/test_creation.py +1102 -0
  125. dask_array/tests/test_expr_flow.py +143 -0
  126. dask_array/tests/test_linalg.py +1130 -0
  127. dask_array/tests/test_map_blocks_multi_output.py +104 -0
  128. dask_array/tests/test_rechunk_pushdown.py +214 -0
  129. dask_array/tests/test_reductions.py +1091 -0
  130. dask_array/tests/test_routines.py +2853 -0
  131. dask_array/tests/test_shuffle_chunks.py +67 -0
  132. dask_array/tests/test_slice_pushdown.py +968 -0
  133. dask_array/tests/test_slice_through_blockwise.py +678 -0
  134. dask_array/tests/test_slice_through_overlap.py +366 -0
  135. dask_array/tests/test_slice_through_reshape.py +272 -0
  136. dask_array/tests/test_slicing.py +839 -0
  137. dask_array/tests/test_transpose_slice_pushdown.py +208 -0
  138. dask_array/tests/test_visualize.py +94 -0
  139. dask_array/tests/test_xarray.py +193 -0
  140. dask_array-0.1.0.dist-info/METADATA +48 -0
  141. dask_array-0.1.0.dist-info/RECORD +144 -0
  142. dask_array-0.1.0.dist-info/WHEEL +4 -0
  143. dask_array-0.1.0.dist-info/entry_points.txt +2 -0
  144. dask_array-0.1.0.dist-info/licenses/LICENSE +29 -0
@@ -0,0 +1,158 @@
1
+ """Gradient implementation for array-expr."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import math
6
+ from numbers import Integral, Real
7
+
8
+ import numpy as np
9
+
10
+ from dask_array._collection import asarray
11
+ from dask_array._overlap import map_overlap
12
+ from dask_array._utils import validate_axis
13
+ from dask.base import is_dask_collection
14
+
15
+
16
+ def _gradient_kernel(x, block_id, coord, axis, array_locs, grad_kwargs):
17
+ """
18
+ x: nd-array
19
+ array of one block
20
+ coord: 1d-array or scalar
21
+ coordinate along which the gradient is computed.
22
+ axis: int
23
+ axis along which the gradient is computed
24
+ array_locs:
25
+ actual location along axis. None if coordinate is scalar
26
+ grad_kwargs:
27
+ keyword to be passed to np.gradient
28
+ """
29
+ block_loc = block_id[axis]
30
+ if array_locs is not None:
31
+ coord = coord[array_locs[0][block_loc] : array_locs[1][block_loc]]
32
+ grad = np.gradient(x, coord, axis=axis, **grad_kwargs)
33
+ return grad
34
+
35
+
36
def gradient(f, *varargs, axis=None, **kwargs):
    """
    Return the gradient of an N-dimensional array.

    This docstring was copied from numpy.gradient.

    Some inconsistencies with the Dask version may exist.

    The gradient is computed using second order accurate central differences
    in the interior points and either first or second order accurate one-sides
    (forward or backwards) differences at the boundaries.
    The returned gradient hence has the same shape as the input array.

    Parameters
    ----------
    f : array_like
        An N-dimensional array containing samples of a scalar function.
    varargs : list of scalar or array, optional
        Spacing between f values. Default unitary spacing for all dimensions.
        Spacing can be specified using:

        1. single scalar to specify a sample distance for all dimensions.
        2. N scalars to specify a constant sample distance for each dimension.
           i.e. `dx`, `dy`, `dz`, ...
        3. N arrays to specify the coordinates of the values along each
           dimension of F. The length of the array must match the size of
           the corresponding dimension
        4. Any combination of N scalars/arrays with the meaning of 2. and 3.

        If `axis` is given, the number of varargs must equal the number of axes.
        Default: 1.
    axis : None or int or tuple of ints, optional
        Gradient is calculated only along the given axis or axes.
        The default (axis = None) is to calculate the gradient for all the axes
        of the input array. axis may be negative, in which case it counts from
        the last to the first axis.

    Returns
    -------
    gradient : ndarray or list of ndarray
        A list of ndarrays (or a single ndarray if there is only one dimension)
        corresponding to the derivatives of f with respect to each dimension.
        Each derivative has the same shape as f.

    Other Parameters
    ----------------
    edge_order : {1, 2}, optional
        Gradient is calculated using N-th order accurate differences
        at the boundaries. Default: 1.

    Notes
    -----
    Dask-specific behavior:

    * Every chunk along each differentiated axis must hold at least
      ``edge_order + 1`` elements; otherwise ``ValueError`` is raised and the
      array has to be rechunked first.
    * Coordinate arrays must be in-memory (non-dask) and one-dimensional;
      dask collections raise ``NotImplementedError``. When the axis length is
      known, a coordinate array of the wrong length raises ``ValueError``.
    * Zero-dimensional array spacings (e.g. ``np.array(0.5)``) are treated as
      scalar spacings, as in NumPy.
    """
    f = asarray(f)

    # Non-integer edge orders are rounded up (e.g. 1.5 -> 2) before validation.
    kwargs["edge_order"] = math.ceil(kwargs.get("edge_order", 1))
    edge_order = kwargs["edge_order"]
    if edge_order > 2:
        raise ValueError("edge_order must be less than or equal to 2.")

    # A single integer axis returns one array instead of a list.
    drop_result_list = False
    if axis is None:
        axis = tuple(range(f.ndim))
    elif isinstance(axis, Integral):
        drop_result_list = True
        axis = (axis,)

    axis = validate_axis(axis, f.ndim)

    if len(axis) != len(set(axis)):
        raise ValueError("duplicate axes not allowed")

    axis = tuple(ax % f.ndim for ax in axis)

    if varargs == ():
        varargs = (1,)
    if len(varargs) == 1:
        varargs = len(axis) * varargs
    if len(varargs) != len(axis):
        raise TypeError("Spacing must either be a single scalar, or a scalar / 1d-array per axis")

    # Bool/integer and sub-32-bit real inputs are promoted to float64.
    if issubclass(f.dtype.type, (np.bool_, Integral)):
        f = f.astype(float)
    elif issubclass(f.dtype.type, Real) and f.dtype.itemsize < 4:
        f = f.astype(float)

    results = []
    for i, ax in enumerate(axis):
        # Each block is differentiated on its own (with a 1-element halo), so
        # every chunk needs enough points for the one-sided edge stencil.
        # Unknown (NaN) chunk sizes compare False and are not rejected here.
        too_small = [c for c in f.chunks[ax] if c < edge_order + 1]
        if too_small:
            raise ValueError(
                "Chunk size must be larger than edge_order + 1. "
                f"Minimum chunk for axis {ax} is {min(too_small)}. Rechunk to "
                "proceed."
            )

        spacing = varargs[i]
        if is_dask_collection(spacing):
            raise NotImplementedError("dask array coordinated is not supported.")

        if np.ndim(spacing) == 0:
            # Scalar spacing (including 0-d arrays) is used unchanged per block.
            array_locs = None
        else:
            if np.ndim(spacing) != 1:
                raise ValueError("distances must be either scalars or 1d")
            axis_len = f.shape[ax]
            if not math.isnan(axis_len) and len(spacing) != axis_len:
                raise ValueError(
                    "when 1d, distances must match the length of the corresponding dimension"
                )
            # coordinate position for each block taking overlap into account
            chunk = np.array(f.chunks[ax])
            array_loc_stop = np.cumsum(chunk) + 1
            array_loc_start = array_loc_stop - chunk - 2
            array_loc_stop[-1] -= 1
            array_loc_start[0] = 0
            array_locs = (array_loc_start, array_loc_stop)

        results.append(
            map_overlap(
                _gradient_kernel,
                f,
                dtype=f.dtype,
                depth={j: 1 if j == ax else 0 for j in range(f.ndim)},
                boundary="none",
                coord=spacing,
                axis=ax,
                array_locs=array_locs,
                grad_kwargs=kwargs,
            )
        )

    if drop_result_list:
        results = results[0]

    return results
@@ -0,0 +1,65 @@
1
+ """Index manipulation functions for array-expr."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import numpy as np
6
+
7
+ from dask_array._collection import asarray, stack
8
+ from dask.base import is_dask_collection
9
+ from dask.utils import derived_from
10
+
11
+
12
+ def _unravel_index_kernel(indices, func_kwargs):
13
+ return np.stack(np.unravel_index(indices, **func_kwargs))
14
+
15
+
16
@derived_from(np)
def unravel_index(indices, shape, order="C"):
    """Convert flat indices into a tuple of coordinate arrays.

    ``shape`` may also be a single integer, which is treated as ``(shape,)``.
    """
    from dask_array.creation import empty

    # Accept an integer ``shape`` as NumPy does, mirroring how
    # ``ravel_multi_index`` normalizes a scalar ``dims``.
    if np.isscalar(shape):
        shape = (shape,)

    indices = asarray(indices)
    if shape and indices.size:
        # Each block becomes a (len(shape), *block.shape) array; iterating the
        # resulting array's leading axis yields one coordinate array per dim.
        unraveled_indices = tuple(
            indices.map_blocks(
                _unravel_index_kernel,
                dtype=np.intp,
                chunks=((len(shape),),) + indices.chunks,
                new_axis=0,
                func_kwargs={"shape": shape, "order": order},
            )
        )
    else:
        # Nothing to unravel: one empty coordinate array per dimension.
        unraveled_indices = tuple(empty((0,), dtype=np.intp, chunks=1) for _ in shape)

    return unraveled_indices
35
+
36
+
37
@derived_from(np)
def ravel_multi_index(multi_index, dims, mode="raise", order="C"):
    """Convert a tuple of index arrays into an array of flat indices."""
    from dask_array.routines._broadcast import broadcast_arrays

    if np.isscalar(dims):
        dims = (dims,)
    dims_are_lazy = is_dask_collection(dims) or any(is_dask_collection(d) for d in dims)
    if dims_are_lazy:
        raise NotImplementedError(f"Dask types are not supported in the `dims` argument: {dims!r}")

    # Build a single array with one row per dimension: either the caller
    # already gave one, or the per-dimension sequences are broadcast/stacked.
    already_stacked = hasattr(multi_index, "ndim") and multi_index.ndim > 0
    if already_stacked:
        index_stack = asarray(multi_index)
    else:
        index_stack = stack(broadcast_arrays(*multi_index))

    length_known = not np.isnan(index_stack.shape).any()
    if length_known and len(index_stack) != len(dims):
        raise ValueError(f"parameter multi_index must be a sequence of length {len(dims)}")
    if not np.issubdtype(index_stack.dtype, np.signedinteger):
        raise TypeError("only int indices permitted")

    # The leading per-dimension axis is consumed by np.ravel_multi_index.
    return index_stack.map_blocks(
        np.ravel_multi_index,
        dtype=np.intp,
        chunks=index_stack.chunks[1:],
        drop_axis=0,
        dims=dims,
        mode=mode,
        order=order,
    )
@@ -0,0 +1,132 @@
1
+ """Insert/delete operations for array-expr."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import numpy as np
6
+
7
+ from dask_array._collection import (
8
+ asanyarray,
9
+ asarray,
10
+ broadcast_to,
11
+ concatenate,
12
+ ravel,
13
+ )
14
+ from dask_array._utils import validate_axis
15
+ from dask.utils import derived_from
16
+
17
+
18
@derived_from(np)
def append(arr, values, axis=None):
    """Append values to the end of an array.

    With ``axis=None`` both ``arr`` and ``values`` are flattened first.
    """
    arr = asanyarray(arr)
    if axis is not None:
        return concatenate((arr, values), axis=axis)

    if arr.ndim != 1:
        arr = ravel(arr)
    flat_values = ravel(asanyarray(values))
    return concatenate((arr, flat_values), axis=arr.ndim - 1)
28
+
29
+
30
@derived_from(np)
def ediff1d(ary, to_end=None, to_begin=None):
    """Compute the differences between consecutive elements of an array."""
    flat = asarray(ary).flatten()
    pieces = [flat[1:] - flat[:-1]]
    # Optional prefix/suffix values are flattened and joined around the diffs.
    if to_begin is not None:
        pieces.insert(0, asarray(to_begin).flatten())
    if to_end is not None:
        pieces.append(asarray(to_end).flatten())
    return concatenate(pieces)
46
+
47
+
48
def _split_at_breaks(array, breaks, axis=0):
    """Split an array into a list of arrays (using slices) at the given breaks

    The pieces are taken along ``axis``; ``len(breaks) + 1`` pieces are
    returned, the first starting at the beginning and the last running to
    the end.

    >>> _split_at_breaks(np.arange(6), [3, 5])
    [array([0, 1, 2]), array([3, 4]), array([5])]
    """
    edges = [None, *breaks, None]
    leading = (slice(None),) * axis
    return [array[leading + (slice(start, stop),)] for start, stop in zip(edges[:-1], edges[1:])]
61
+
62
+
63
@derived_from(np)
def insert(arr, obj, values, axis):
    """Insert values along the given axis before the given indices.

    ``obj`` (a slice, an int, or a sequence of ints) gives the insertion
    positions along ``axis``; negative positions count from the end.

    Raises
    ------
    NotImplementedError
        If the normalized positions in ``obj`` are not non-decreasing.
    """
    from tlz import interleave

    # axis is a required argument here to avoid needing to deal with the numpy
    # default case (which reshapes the array to make it flat)
    arr = asarray(arr)
    axis = validate_axis(axis, arr.ndim)

    # Normalize ``obj`` to a 1d array of insertion positions, remembering
    # whether the caller passed a single (scalar) position.
    if isinstance(obj, slice):
        obj = np.arange(*obj.indices(arr.shape[axis]))
    obj = np.asarray(obj)
    scalar_obj = obj.ndim == 0
    if scalar_obj:
        obj = np.atleast_1d(obj)

    obj = np.where(obj < 0, obj + arr.shape[axis], obj)
    if (np.diff(obj) < 0).any():
        raise NotImplementedError("da.insert only implemented for monotonic ``obj`` argument")

    # Pieces of ``arr`` between consecutive distinct insertion positions
    # (one more piece than there are distinct positions).
    split_arr = _split_at_breaks(arr, np.unique(obj), axis)

    if getattr(values, "ndim", 0) == 0:
        # we need to turn values into a dask array
        # NOTE(review): objects without ``.ndim`` (e.g. Python lists) also take
        # this branch and are broadcast to the full insertion shape below.
        values = asarray(values)

        values_shape = tuple(len(obj) if axis == n else s for n, s in enumerate(arr.shape))
        values = broadcast_to(values, values_shape)
    elif scalar_obj:
        # Single position: add a length-1 insertion axis at position ``axis``.
        values = values[(slice(None),) * axis + (None,)]

    # Match ``arr``'s chunking on every axis except the insertion axis.
    values = asarray(values)
    values_chunks = tuple(
        values_bd if axis == n else arr_bd for n, (arr_bd, values_bd) in enumerate(zip(arr.chunks, values.chunks))
    )
    values = values.rechunk(values_chunks)

    # Count values inserted at each distinct position; the last bin (the
    # largest position) is dropped so the cumulative sums are the cut points
    # that separate ``values`` into one group per gap of ``split_arr``.
    counts = np.bincount(obj)[:-1]
    values_breaks = np.cumsum(counts[counts > 0])
    split_values = _split_at_breaks(values, values_breaks, axis)

    # Alternate arr-piece / values-piece, drop empty pieces, then join.
    interleaved = list(interleave([split_arr, split_values]))
    interleaved = [i for i in interleaved if i.size]
    return concatenate(interleaved, axis=axis)
108
+
109
+
110
@derived_from(np)
def delete(arr, obj, axis):
    """Remove elements from an array along an axis.

    ``obj`` may be a slice, an int, a sequence of ints (negative values count
    from the end), or, as in NumPy, a boolean mask along ``axis`` whose
    ``True`` entries mark the elements to delete.

    Raises
    ------
    ValueError
        If a boolean ``obj`` is not 1-d or (when the axis length is known)
        does not match ``arr.shape[axis]``.
    """
    # axis is a required argument here to avoid needing to deal with the numpy
    # default case (which reshapes the array to make it flat)
    arr = asarray(arr)
    axis = validate_axis(axis, arr.ndim)
    axis_len = arr.shape[axis]

    if isinstance(obj, slice):
        positions = np.arange(*obj.indices(axis_len))
        # A negative step yields descending positions; splitting needs them ascending.
        obj = positions[::-1] if obj.step and obj.step < 0 else positions
    else:
        obj = np.asarray(obj)
        if obj.dtype == bool:
            # Boolean masks select what to delete; without this they would be
            # read as the integer positions 0/1.
            if obj.ndim != 1 or (not np.isnan(axis_len) and len(obj) != axis_len):
                raise ValueError(
                    "boolean array argument obj to delete must be one dimensional "
                    f"and match the axis length of {axis_len}"
                )
            obj = np.flatnonzero(obj)
        else:
            # Normalize negative positions, then sort and de-duplicate.
            obj = np.where(obj < 0, obj + axis_len, obj)
            obj = np.unique(obj)

    # Split just before each deleted position: every piece after the first
    # then starts with a deleted element, which is dropped.
    pieces = _split_at_breaks(arr, obj, axis)
    drop_first = tuple(slice(1, None) if n == axis else slice(None) for n in range(arr.ndim))
    pieces = [piece if i == 0 else piece[drop_first] for i, piece in enumerate(pieces)]
    return concatenate(pieces, axis=axis)
@@ -0,0 +1,122 @@
1
+ """Miscellaneous utility functions for array-expr."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import numpy as np
6
+
7
+ from dask_array._collection import Array, asarray, elemwise
8
+ from dask_array._utils import validate_axis
9
+ from dask.utils import derived_from
10
+
11
+
12
@derived_from(np)
def ndim(a):
    """Return the number of dimensions of an array."""
    return asarray(a).ndim
17
+
18
+
19
@derived_from(np)
def shape(a):
    """Return the shape of an array."""
    return asarray(a).shape
24
+
25
+
26
def result_type(*arrays_and_dtypes):
    """Returns the type from NumPy type promotion rules.

    Dask ``Array`` arguments are replaced by their dtype; every other
    argument is passed to ``np.result_type`` unchanged.
    """
    normalized = (item.dtype if isinstance(item, Array) else item for item in arrays_and_dtypes)
    return np.result_type(*normalized)
30
+
31
+
32
def compress(condition, a, axis=None):
    """
    Return selected slices of an array along given axis.

    This docstring was copied from numpy.compress.

    Some inconsistencies with the Dask version may exist.

    Parameters
    ----------
    condition : 1-D array of bools
        Array that selects which entries to return. If len(condition)
        is less than the size of a along the given axis, then output is
        truncated to the length of the condition array.
    a : array_like
        Array from which to extract a part.
    axis : int, optional
        Axis along which to take slices. If None (default), work on the
        flattened array.

    Returns
    -------
    compressed_array : ndarray
        A copy of a without the slices along axis for which condition
        is false.
    """
    from dask_array._utils import is_arraylike

    # Non-array-likes are coerced with NumPy so ``astype``/``ndim`` exist.
    cond = condition if is_arraylike(condition) else np.asarray(condition)
    cond = cond.astype(bool)
    a = asarray(a)

    if cond.ndim != 1:
        raise ValueError("Condition must be one dimensional")

    if axis is None:
        a, axis = a.ravel(), 0
    axis = validate_axis(axis, a.ndim)

    def _along_axis(indexer):
        # Index tuple applying ``indexer`` on ``axis`` and ``:`` elsewhere.
        return tuple(indexer if dim == axis else slice(None) for dim in range(a.ndim))

    # A short condition behaves as if padded with False: trim ``a`` first,
    a = a[_along_axis(slice(None, len(cond)))]
    # then keep only the positions where the condition is True.
    return a[_along_axis(cond)]
80
+
81
+
82
+ def _take_constant(indices, a, axis):
83
+ """Take from a constant array using indices."""
84
+ return np.take(a, indices, axis)
85
+
86
+
87
+ def _take_dask_array_from_numpy(a, indices, axis):
88
+ """Take from a numpy array using a dask array of indices."""
89
+ assert isinstance(a, np.ndarray)
90
+ assert isinstance(indices, Array)
91
+
92
+ return elemwise(_take_constant, indices, dtype=a.dtype, a=a, axis=axis)
93
+
94
+
95
+ @derived_from(np)
96
+ def take(a, indices, axis=0):
97
+ """
98
+ Take elements from an array along an axis.
99
+
100
+ This docstring was copied from numpy.take.
101
+
102
+ Parameters
103
+ ----------
104
+ a : dask array
105
+ The source array.
106
+ indices : array_like
107
+ The indices of the values to extract.
108
+ axis : int, optional
109
+ The axis over which to select values.
110
+
111
+ Returns
112
+ -------
113
+ out : dask array
114
+ The returned array has the same type as a.
115
+ """
116
+ a = asarray(a)
117
+ axis = validate_axis(axis, a.ndim)
118
+
119
+ if isinstance(a, np.ndarray) and isinstance(indices, Array):
120
+ return _take_dask_array_from_numpy(a, indices, axis)
121
+ else:
122
+ return a[(slice(None),) * axis + (indices,)]
@@ -0,0 +1,72 @@
1
+ """Nonzero-related functions for array-expr."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import numpy as np
6
+
7
+ from dask_array._collection import asarray, elemwise, stack
8
+ from dask.utils import derived_from
9
+
10
+
11
+ def _isnonzero_vec(v):
12
+ return bool(np.count_nonzero(v))
13
+
14
+
15
+ _isnonzero_vec = np.vectorize(_isnonzero_vec, otypes=[bool])
16
+
17
+
18
def _isnonzero(a):
    """Elementwise non-zero test for use in task graphs.

    A plain function wrapper is used because the ``np.vectorize`` object
    itself can't be pickled.
    """
    result = _isnonzero_vec(a)
    return result
21
+
22
+
23
def isnonzero(a):
    """Handle special cases where conversion to bool does not work correctly.
    xref: https://github.com/numpy/numpy/issues/9479
    """

    def _bool_cast_works(dtype):
        # Probe the dtype with a 0-d zero array instead of touching ``a``.
        try:
            np.zeros([], dtype=dtype).astype(bool)
        except ValueError:
            return False
        return True

    if _bool_cast_works(a.dtype):
        return a.astype(bool)
    # Direct casting fails for this dtype: test each element instead.
    return elemwise(_isnonzero, a, dtype=bool)
33
+
34
+
35
@derived_from(np)
def count_nonzero(a, axis=None):
    """Counts the number of non-zero values in the array."""
    mask = isnonzero(asarray(a))
    return mask.astype(np.intp).sum(axis=axis)
39
+
40
+
41
@derived_from(np)
def argwhere(a):
    """Find the indices of array elements that are non-zero."""
    from dask_array._routines import compress
    from dask_array.creation import indices

    a = asarray(a)
    keep = isnonzero(a).flatten()

    coords = indices(a.shape, dtype=np.intp, chunks=a.chunks)
    if coords.ndim > 1:
        # Turn the (ndim, *shape) index grid into an (N, ndim) table whose rows
        # line up with the flattened mask.
        coords = stack([coords[dim].ravel() for dim in range(len(coords))], axis=1)
    return compress(keep, coords, axis=0)
57
+
58
+
59
@derived_from(np)
def flatnonzero(a):
    """Return indices that are non-zero in the flattened array."""
    flat = asarray(a).ravel()
    return argwhere(flat)[:, 0]
63
+
64
+
65
@derived_from(np)
def nonzero(a):
    """Return the indices of the elements that are non-zero."""
    ind = argwhere(a)
    if ind.ndim <= 1:
        return (ind,)
    # One index array per input dimension (one per column of ``argwhere``).
    return tuple(ind[:, col] for col in range(ind.shape[1]))
@@ -0,0 +1,123 @@
1
+ """Search functions for array-expr."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import numpy as np
6
+
7
+ from dask_array._collection import asarray
8
+ from dask.utils import derived_from
9
+
10
+
11
+ def _searchsorted_block(x, y, side):
12
+ res = np.searchsorted(x, y, side=side)
13
+ # 0 is only correct for the first block of a, but blockwise doesn't have a way
14
+ # of telling which block is being operated on (unlike map_blocks),
15
+ # so set all 0 values to a special value and set back at the end of searchsorted
16
+ res[res == 0] = -1
17
+ return res[np.newaxis, :]
18
+
19
+
20
def searchsorted(a, v, side="left", sorter=None):
    """
    Find indices where elements should be inserted to maintain order.

    This docstring was copied from numpy.searchsorted.

    Some inconsistencies with the Dask version may exist.

    Find the indices into a sorted array `a` such that, if the
    corresponding elements in `v` were inserted before the indices, the
    order of `a` would be preserved.

    Parameters
    ----------
    a : 1-D array_like
        Input array. If `sorter` is None, then it must be sorted in
        ascending order, otherwise `sorter` must be an array of indices
        that sort it.
    v : array_like
        Values to insert into `a`.
    side : {'left', 'right'}, optional
        If 'left', the index of the first suitable location found is given.
        If 'right', return the last such index. If there is no suitable
        index, return either 0 or N (where N is the length of `a`).
    sorter : 1-D array_like, optional
        Optional array of integer indices that sort array a into ascending
        order. They are typically the result of argsort.

    Returns
    -------
    indices : int or array of ints
        Array of insertion points with the same shape as `v`,
        or an integer if `v` is a scalar.

    Notes
    -----
    In the Dask version ``a`` and ``v`` may be any array-likes (they are
    converted with ``asarray``). A scalar ``v`` yields a 0-d dask array, and
    ``sorter`` is not supported (``NotImplementedError``).
    """
    from dask_array._collection import blockwise
    from dask_array.routines._where import where
    from dask_array._utils import array_safe, meta_from_array

    a = asarray(a)
    v = asarray(v)

    if a.ndim != 1:
        raise ValueError("Input array a must be one dimensional")

    if sorter is not None:
        raise NotImplementedError("da.searchsorted with a sorter argument is not supported")

    # The per-block kernel needs ``v`` to have at least one dimension; a scalar
    # is searched as a length-1 array and unwrapped at the end.
    scalar_v = v.ndim == 0
    if scalar_v:
        v = v[None]

    # call np.searchsorted for each pair of blocks in a and v
    meta = np.searchsorted(a._meta, v._meta)
    out = blockwise(
        _searchsorted_block,
        list(range(v.ndim + 1)),
        a,
        [0],
        v,
        list(range(1, v.ndim + 1)),
        side,
        None,
        meta=meta,
        adjust_chunks={0: 1},  # one row for each block in a
    )

    # add offsets to take account of the position of each block within the array a
    a_chunk_sizes = array_safe((0, *a.chunks[0]), like=meta_from_array(a))
    a_chunk_offsets = np.cumsum(a_chunk_sizes)[:-1]
    a_chunk_offsets = a_chunk_offsets[(Ellipsis,) + v.ndim * (np.newaxis,)]
    a_offsets = asarray(a_chunk_offsets, chunks=1)
    out = where(out < 0, out, out + a_offsets)

    # combine the results from each block (of a)
    out = out.max(axis=0)

    # fix up any -1 values
    out[out == -1] = 0

    if scalar_v:
        out = out[0]
    return out
93
+
94
+
95
+ def _isin_kernel(element, test_elements, assume_unique=False):
96
+ values = np.isin(element.ravel(), test_elements, assume_unique=assume_unique)
97
+ return values.reshape(element.shape + (1,) * test_elements.ndim)
98
+
99
+
100
@derived_from(np)
def isin(element, test_elements, assume_unique=False, invert=False):
    """Test whether each value of ``element`` occurs in ``test_elements``."""
    from dask_array._collection import blockwise

    element = asarray(element)
    test_elements = asarray(test_elements)

    # ``test_elements`` gets its own output axes after those of ``element``,
    # so every element block is paired with every test block.
    n_lead = element.ndim
    element_axes = tuple(range(n_lead))
    test_axes = tuple(range(n_lead, n_lead + test_elements.ndim))

    mapped = blockwise(
        _isin_kernel,
        element_axes + test_axes,
        element,
        element_axes,
        test_elements,
        test_axes,
        # each kernel call collapses a test block to length 1 on its axes
        adjust_chunks=dict.fromkeys(test_axes, lambda _: 1),
        dtype=bool,
        assume_unique=assume_unique,
    )

    # A value is present if any test block found it; invert afterwards so the
    # per-block results can still be combined with ``any``.
    found = mapped.any(axis=test_axes)
    return ~found if invert else found