dask-array 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (144)
  1. dask_array/__init__.py +228 -0
  2. dask_array/_backends.py +76 -0
  3. dask_array/_backends_array.py +99 -0
  4. dask_array/_blockwise.py +1410 -0
  5. dask_array/_broadcast.py +272 -0
  6. dask_array/_chunk.py +445 -0
  7. dask_array/_chunk_types.py +54 -0
  8. dask_array/_collection.py +1644 -0
  9. dask_array/_concatenate.py +331 -0
  10. dask_array/_core_utils.py +1365 -0
  11. dask_array/_dispatch.py +141 -0
  12. dask_array/_einsum.py +277 -0
  13. dask_array/_expr.py +544 -0
  14. dask_array/_expr_flow.py +586 -0
  15. dask_array/_gufunc.py +805 -0
  16. dask_array/_histogram.py +617 -0
  17. dask_array/_map_blocks.py +652 -0
  18. dask_array/_new_collection.py +10 -0
  19. dask_array/_numpy_compat.py +135 -0
  20. dask_array/_overlap.py +1159 -0
  21. dask_array/_rechunk.py +1050 -0
  22. dask_array/_reshape.py +710 -0
  23. dask_array/_routines.py +102 -0
  24. dask_array/_shuffle.py +448 -0
  25. dask_array/_stack.py +264 -0
  26. dask_array/_svg.py +291 -0
  27. dask_array/_templates.py +29 -0
  28. dask_array/_test_utils.py +257 -0
  29. dask_array/_ufunc.py +385 -0
  30. dask_array/_utils.py +349 -0
  31. dask_array/_visualize.py +223 -0
  32. dask_array/_xarray.py +337 -0
  33. dask_array/core/__init__.py +34 -0
  34. dask_array/core/_blockwise_funcs.py +312 -0
  35. dask_array/core/_conversion.py +422 -0
  36. dask_array/core/_from_graph.py +97 -0
  37. dask_array/creation/__init__.py +71 -0
  38. dask_array/creation/_arange.py +121 -0
  39. dask_array/creation/_diag.py +116 -0
  40. dask_array/creation/_diagonal.py +241 -0
  41. dask_array/creation/_eye.py +103 -0
  42. dask_array/creation/_linspace.py +102 -0
  43. dask_array/creation/_mesh.py +134 -0
  44. dask_array/creation/_ones_zeros.py +454 -0
  45. dask_array/creation/_pad.py +270 -0
  46. dask_array/creation/_repeat.py +55 -0
  47. dask_array/creation/_tile.py +36 -0
  48. dask_array/creation/_tri.py +28 -0
  49. dask_array/creation/_utils.py +296 -0
  50. dask_array/fft.py +320 -0
  51. dask_array/io/__init__.py +39 -0
  52. dask_array/io/_base.py +10 -0
  53. dask_array/io/_from_array.py +257 -0
  54. dask_array/io/_from_delayed.py +95 -0
  55. dask_array/io/_from_graph.py +54 -0
  56. dask_array/io/_from_npy_stack.py +67 -0
  57. dask_array/io/_store.py +336 -0
  58. dask_array/io/_tiledb.py +159 -0
  59. dask_array/io/_to_npy_stack.py +65 -0
  60. dask_array/io/_zarr.py +449 -0
  61. dask_array/linalg/__init__.py +39 -0
  62. dask_array/linalg/_cholesky.py +234 -0
  63. dask_array/linalg/_lu.py +300 -0
  64. dask_array/linalg/_norm.py +94 -0
  65. dask_array/linalg/_qr.py +601 -0
  66. dask_array/linalg/_solve.py +349 -0
  67. dask_array/linalg/_svd.py +394 -0
  68. dask_array/linalg/_tensordot.py +334 -0
  69. dask_array/linalg/_utils.py +74 -0
  70. dask_array/manipulation/__init__.py +45 -0
  71. dask_array/manipulation/_expand.py +321 -0
  72. dask_array/manipulation/_flip.py +92 -0
  73. dask_array/manipulation/_roll.py +78 -0
  74. dask_array/manipulation/_transpose.py +309 -0
  75. dask_array/random/__init__.py +125 -0
  76. dask_array/random/_choice.py +181 -0
  77. dask_array/random/_expr.py +256 -0
  78. dask_array/random/_generator.py +441 -0
  79. dask_array/random/_random_state.py +259 -0
  80. dask_array/random/_utils.py +84 -0
  81. dask_array/reductions/__init__.py +84 -0
  82. dask_array/reductions/_arg_reduction.py +130 -0
  83. dask_array/reductions/_common.py +1082 -0
  84. dask_array/reductions/_cumulative.py +522 -0
  85. dask_array/reductions/_percentile.py +261 -0
  86. dask_array/reductions/_reduction.py +725 -0
  87. dask_array/reductions/_trace.py +56 -0
  88. dask_array/routines/__init__.py +133 -0
  89. dask_array/routines/_apply.py +84 -0
  90. dask_array/routines/_bincount.py +112 -0
  91. dask_array/routines/_broadcast.py +111 -0
  92. dask_array/routines/_coarsen.py +115 -0
  93. dask_array/routines/_diff.py +79 -0
  94. dask_array/routines/_gradient.py +158 -0
  95. dask_array/routines/_indexing.py +65 -0
  96. dask_array/routines/_insert_delete.py +132 -0
  97. dask_array/routines/_misc.py +122 -0
  98. dask_array/routines/_nonzero.py +72 -0
  99. dask_array/routines/_search.py +123 -0
  100. dask_array/routines/_select.py +113 -0
  101. dask_array/routines/_statistics.py +171 -0
  102. dask_array/routines/_topk.py +82 -0
  103. dask_array/routines/_triangular.py +74 -0
  104. dask_array/routines/_unique.py +232 -0
  105. dask_array/routines/_where.py +62 -0
  106. dask_array/slicing/__init__.py +67 -0
  107. dask_array/slicing/_basic.py +550 -0
  108. dask_array/slicing/_blocks.py +138 -0
  109. dask_array/slicing/_bool_index.py +145 -0
  110. dask_array/slicing/_setitem.py +329 -0
  111. dask_array/slicing/_squeeze.py +101 -0
  112. dask_array/slicing/_utils.py +1133 -0
  113. dask_array/slicing/_vindex.py +282 -0
  114. dask_array/stacking/__init__.py +15 -0
  115. dask_array/stacking/_block.py +83 -0
  116. dask_array/stacking/_simple.py +58 -0
  117. dask_array/templates/array.html.j2 +48 -0
  118. dask_array/tests/__init__.py +0 -0
  119. dask_array/tests/conftest.py +22 -0
  120. dask_array/tests/test_api.py +40 -0
  121. dask_array/tests/test_binary_op_chunks.py +107 -0
  122. dask_array/tests/test_coarse_slice_through_blockwise.py +362 -0
  123. dask_array/tests/test_collection.py +799 -0
  124. dask_array/tests/test_creation.py +1102 -0
  125. dask_array/tests/test_expr_flow.py +143 -0
  126. dask_array/tests/test_linalg.py +1130 -0
  127. dask_array/tests/test_map_blocks_multi_output.py +104 -0
  128. dask_array/tests/test_rechunk_pushdown.py +214 -0
  129. dask_array/tests/test_reductions.py +1091 -0
  130. dask_array/tests/test_routines.py +2853 -0
  131. dask_array/tests/test_shuffle_chunks.py +67 -0
  132. dask_array/tests/test_slice_pushdown.py +968 -0
  133. dask_array/tests/test_slice_through_blockwise.py +678 -0
  134. dask_array/tests/test_slice_through_overlap.py +366 -0
  135. dask_array/tests/test_slice_through_reshape.py +272 -0
  136. dask_array/tests/test_slicing.py +839 -0
  137. dask_array/tests/test_transpose_slice_pushdown.py +208 -0
  138. dask_array/tests/test_visualize.py +94 -0
  139. dask_array/tests/test_xarray.py +193 -0
  140. dask_array-0.1.0.dist-info/METADATA +48 -0
  141. dask_array-0.1.0.dist-info/RECORD +144 -0
  142. dask_array-0.1.0.dist-info/WHEEL +4 -0
  143. dask_array-0.1.0.dist-info/entry_points.txt +2 -0
  144. dask_array-0.1.0.dist-info/licenses/LICENSE +29 -0
@@ -0,0 +1,113 @@
1
+ """Selection functions for array-expr."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import numpy as np
6
+
7
+ from dask_array._collection import asarray, elemwise
8
+ from dask.utils import derived_from
9
+
10
+
11
@derived_from(np)
def digitize(a, bins, right=False):
    """Find, for every element of ``a``, the index of the bin it falls into.

    Parameters
    ----------
    a : dask array
        Values to bin.
    bins : array_like
        One-dimensional, monotonic sequence of bin edges.
    right : bool, optional
        Whether each interval is closed on its right edge instead of its left.

    Returns
    -------
    indices : dask array of ints
        Bin index for every element of ``a``.
    """
    edges = np.asarray(bins)
    if edges.ndim != 1:
        raise ValueError("bins must be 1-dimensional")

    # Run NumPy on a one-element probe to learn the output dtype without
    # evaluating the lazy input.
    probe = np.asarray([0], like=edges)
    out_dtype = np.digitize(probe, edges, right=right).dtype
    return elemwise(np.digitize, a, dtype=out_dtype, bins=edges, right=right)
35
+
36
+
37
+ def _variadic_choose(a, *choices):
38
+ return np.choose(a, choices)
39
+
40
+
41
@derived_from(np)
def choose(a, choices):
    """Lazily pick elements from ``choices`` according to the index array ``a``."""
    index = asarray(a)
    candidates = list(map(asarray, choices))
    return elemwise(_variadic_choose, index, *candidates)
46
+
47
+
48
@derived_from(np)
def extract(condition, arr):
    """Return the elements of ``arr`` where ``condition`` is true, as a 1-D array.

    Both inputs are flattened; ``condition`` is cast to ``bool`` first.
    """
    from dask_array.routines._misc import compress

    flat_mask = asarray(condition).astype(bool).ravel()
    flat_values = asarray(arr).ravel()
    return compress(flat_mask, flat_values)
55
+
56
+
57
+ def _int_piecewise(x, *condlist, funclist=None, func_args=(), func_kw=None):
58
+ return np.piecewise(x, list(condlist), funclist, *func_args, **(func_kw or {}))
59
+
60
+
61
@derived_from(np)
def piecewise(x, condlist, funclist, *args, **kw):
    """Evaluate a piecewise-defined function block by block.

    Extra positional/keyword arguments are passed on to the functions in
    ``funclist``. The result keeps the dtype of ``x``.
    """
    x = asarray(x)
    options = dict(
        dtype=x.dtype,
        name="piecewise",
        funclist=funclist,
        func_args=args,
        func_kw=kw,
    )
    return elemwise(_int_piecewise, x, *condlist, **options)
74
+
75
+
76
+ def _select(*args, **kwargs):
77
+ split_at = len(args) // 2
78
+ condlist = args[:split_at]
79
+ choicelist = args[split_at:]
80
+ return np.select(condlist, choicelist, **kwargs)
81
+
82
+
83
@derived_from(np)
def select(condlist, choicelist, default=0):
    """Lazily build an array from ``choicelist`` wherever ``condlist`` holds.

    Raises
    ------
    ValueError
        If the two lists differ in length or are empty.
    TypeError
        If the choices share no common dtype.
    """
    from dask_array._collection import blockwise
    from dask_array.routines._misc import result_type

    if len(condlist) != len(choicelist):
        raise ValueError("list of cases must be same length as list of conditions")
    if len(condlist) == 0:
        raise ValueError("select with an empty condition list is not possible")

    choices = [asarray(c) for c in choicelist]
    try:
        out_dtype = result_type(*choices)
    except TypeError as exc:
        raise TypeError("Choicelist elements do not have a common dtype.") from exc

    # Every operand gets the same index tuple (taken from the first choice) so
    # blocks are aligned one-to-one; conditions come first, then choices, which
    # is the layout _select splits back apart.
    index = tuple(range(choices[0].ndim))
    paired = []
    for operand in (*condlist, *choices):
        paired.extend((operand, index))

    return blockwise(
        _select,
        index,
        *paired,
        dtype=out_dtype,
        default=default,
    )
@@ -0,0 +1,171 @@
1
+ """Statistical functions for array-expr."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import warnings
6
+
7
+ import numpy as np
8
+
9
+ from dask_array._collection import (
10
+ array,
11
+ asanyarray,
12
+ asarray,
13
+ broadcast_to,
14
+ concatenate,
15
+ )
16
+ from dask.utils import derived_from
17
+
18
+
19
def result_type(*arrays_and_dtypes):
    """Apply NumPy type-promotion rules to arrays, dtypes and scalars.

    Collections of this package are replaced by their ``dtype`` before
    delegating to :func:`numpy.result_type`. NumPy therefore only sees dtypes
    or plain values, never lazy collections.
    """
    from dask_array._collection import Array

    def _as_dtype_like(item):
        return item.dtype if isinstance(item, Array) else item

    return np.result_type(*map(_as_dtype_like, arrays_and_dtypes))
25
+
26
+
27
def average(a, axis=None, weights=None, returned=False, keepdims=False):
    """Compute the (optionally weighted) average along ``axis``.

    Parameters
    ----------
    a : array_like
        Data to average.
    axis : int, optional
        Axis to average over. Required when ``weights`` has a different shape
        than ``a``.
    weights : array_like, optional
        Either the same shape as ``a``, or 1-D with length ``a.shape[axis]``.
    returned : bool
        If true, also return the sum of weights, broadcast to the shape of
        the average.
    keepdims : bool
        Keep the reduced axes with length one.

    Returns
    -------
    avg, or ``(avg, sum_of_weights)`` when ``returned`` is true.

    Raises
    ------
    TypeError
        If shapes differ and ``axis`` is None or ``weights`` is not 1-D.
    ValueError
        If 1-D weights do not match the length of ``axis``.
    """
    a = asanyarray(a)

    if weights is not None:
        wgt = asanyarray(weights)

        # Integer and boolean data are averaged in floating point.
        promote_to = (a.dtype, wgt.dtype)
        if issubclass(a.dtype.type, (np.integer, np.bool_)):
            promote_to += ("f8",)
        result_dtype = result_type(*promote_to)

        if a.shape != wgt.shape:
            if axis is None:
                raise TypeError("Axis must be specified when shapes of a and weights differ.")
            if wgt.ndim != 1:
                raise TypeError("1D weights expected when shapes of a and weights differ.")
            if wgt.shape[0] != a.shape[axis]:
                raise ValueError("Length of weights not compatible with specified axis.")
            # Lift the 1-D weights to a.ndim dimensions, then move their only
            # non-trivial axis into position ``axis`` so they broadcast against ``a``.
            wgt = broadcast_to(wgt, (1,) * (a.ndim - 1) + wgt.shape).swapaxes(-1, axis)

        from dask_array._ufunc import multiply

        scl = wgt.sum(axis=axis, dtype=result_dtype, keepdims=keepdims)
        avg = multiply(a, wgt, dtype=result_dtype).sum(axis, keepdims=keepdims) / scl
    else:
        avg = a.mean(axis, keepdims=keepdims)
        # Unweighted: each element has weight one, so the weight sum is the
        # number of elements folded into each output value.
        scl = avg.dtype.type(a.size / avg.size)

    if not returned:
        return avg
    if scl.shape != avg.shape:
        scl = broadcast_to(scl, avg.shape)
    return avg, scl
64
+
65
+
66
@derived_from(np)
def cov(
    m,
    y=None,
    rowvar=True,
    bias=False,
    ddof=None,
    fweights=None,
    aweights=None,
    *,
    dtype=None,
):
    """Estimate a covariance matrix (lazy counterpart of :func:`numpy.cov`).

    ``dtype`` sets the dtype the observations are cast to and hence the dtype
    of the result. When it is ``None``, the inputs are promoted to at least
    ``float64``.

    Raises
    ------
    ValueError
        If ``ddof`` is not integral, or ``m`` / ``y`` has more than two
        dimensions.
    RuntimeError
        If ``fweights`` / ``aweights`` are multidimensional or their length
        differs from the number of observations.
    """
    from dask_array._ufunc import true_divide
    from dask_array.linalg import dot

    if ddof is not None and ddof != int(ddof):
        raise ValueError("ddof must be integer")

    m = asarray(m)
    if y is not None:
        y = asarray(y)

    # Honour an explicitly requested dtype, as numpy.cov does; only fall back
    # to float64-or-wider promotion when none was given.
    if dtype is None:
        if y is None:
            dtype = result_type(m, np.float64)
        else:
            dtype = result_type(m, y, np.float64)

    if m.ndim > 2:
        raise ValueError("m has more than 2 dimensions")
    if y is not None and y.ndim > 2:
        raise ValueError("y has more than 2 dimensions")

    X = array(m, ndmin=2, dtype=dtype)

    if ddof is None:
        ddof = 1 if bias == 0 else 0

    # After this transpose each row of X is one variable and axis 1 holds
    # its observations.
    if not rowvar and m.ndim != 1:
        X = X.T
    if X.shape[0] == 0:
        return array([]).reshape(0, 0)
    if y is not None:
        y = array(y, ndmin=2, dtype=dtype)
        if not rowvar and y.shape[0] != 1:
            y = y.T
        X = concatenate((X, y), axis=0)

    # Combined per-observation weights: the product of frequency and
    # analytic weights, whichever are supplied.
    w = None
    if fweights is not None:
        fweights = asarray(fweights, dtype=float)
        if fweights.ndim > 1:
            raise RuntimeError("cannot handle multidimensional fweights")
        if fweights.shape[0] != X.shape[1]:
            raise RuntimeError("incompatible numbers of samples and fweights")
        w = fweights
    if aweights is not None:
        aweights = asarray(aweights, dtype=float)
        if aweights.ndim > 1:
            raise RuntimeError("cannot handle multidimensional aweights")
        if aweights.shape[0] != X.shape[1]:
            raise RuntimeError("incompatible numbers of samples and aweights")
        if w is None:
            w = aweights
        else:
            w *= aweights

    avg, w_sum = average(X, axis=1, weights=w, returned=True)
    # ``returned=True`` broadcasts the weight sum to avg's shape. The weights
    # are shared by every row, so all entries are equal; keep the first.
    w_sum = w_sum[0]

    # Normalisation factor (degrees of freedom), following numpy.cov.
    if w is None:
        fact = X.shape[1] - ddof
    elif ddof == 0:
        fact = w_sum
    elif aweights is None:
        fact = w_sum - ddof
    else:
        fact = w_sum - ddof * (w * aweights).sum() / w_sum

    # NOTE(review): with weights, ``fact`` is derived from lazy sums, so this
    # comparison presumably forces a computation — confirm that is acceptable.
    if fact <= 0:
        warnings.warn("Degrees of freedom <= 0 for slice", RuntimeWarning, stacklevel=2)
        fact = 0.0

    X -= avg[:, None]
    if w is None:
        X_T = X.T
    else:
        X_T = (X * w).T
    c = dot(X, X_T.conj())
    c *= true_divide(1, fact)
    return c.squeeze()
156
+
157
+
158
@derived_from(np)
def corrcoef(x, y=None, rowvar=1):
    """Return Pearson product-moment correlation coefficients.

    The covariance matrix from ``cov`` is divided by the outer product of
    the standard deviations (square roots of its diagonal).
    """
    from dask_array._ufunc import sqrt
    from dask_array.creation import diag

    cov_matrix = cov(x, y, rowvar)
    if cov_matrix.shape == ():
        # A single variable: its correlation is simply var / var.
        return cov_matrix / cov_matrix

    variances = diag(cov_matrix)
    as_column = variances.reshape((variances.shape[0], 1))
    stddev_col = sqrt(as_column)
    return (cov_matrix / stddev_col) / stddev_col.T
@@ -0,0 +1,82 @@
1
+ """Top-k selection for array-expr."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from functools import partial
6
+ from numbers import Number
7
+
8
+ import numpy as np
9
+
10
+ from dask_array._collection import asarray
11
+ from dask_array._utils import validate_axis
12
+
13
+
14
def topk(a, k, axis=-1, split_every=None):
    """Select the ``k`` largest elements of ``a`` along ``axis``.

    The result is ordered from largest to smallest. A negative ``k`` selects
    the ``-k`` smallest elements instead, ordered from smallest to largest.
    """
    from dask_array import _chunk as chunk
    from dask_array.reductions import reduction

    a = asarray(a)
    axis = validate_axis(axis, a.ndim)

    # ``chunk.topk`` serves both as the per-block step and as the combine
    # step; ``chunk.topk_aggregate`` produces the final result.
    select_k = partial(chunk.topk, k=k)

    return reduction(
        a,
        chunk=select_k,
        combine=select_k,
        aggregate=partial(chunk.topk_aggregate, k=k),
        axis=axis,
        keepdims=True,
        dtype=a.dtype,
        split_every=split_every,
        output_size=abs(k),
    )
41
+
42
+
43
def argtopk(a, k, axis=-1, split_every=None):
    """Return the indices of the ``k`` largest elements of ``a`` along ``axis``.

    Indices are ordered from largest to smallest value. A negative ``k``
    selects the indices of the ``-k`` smallest elements instead.
    """
    from dask_array import _chunk as chunk
    from dask_array.creation import arange
    from dask_array.reductions import reduction

    a = asarray(a)
    axis = validate_axis(axis, a.ndim)

    # Position of every element along ``axis``, chunked like ``a`` and given
    # singleton dimensions elsewhere so it lines up with ``a``. Each block of
    # ``a`` is then combined with its positions via argtopk_preprocess.
    positions = arange(a.shape[axis], chunks=(a.chunks[axis],), dtype=np.intp)
    expand = tuple(slice(None) if dim == axis else np.newaxis for dim in range(a.ndim))
    paired = a.map_blocks(chunk.argtopk_preprocess, positions[expand], dtype=object)

    select_k = partial(chunk.argtopk, k=k)

    naxis = 1 if isinstance(axis, Number) else len(axis)
    meta = a._meta.astype(np.intp).reshape((0,) * (a.ndim - naxis + 1))

    return reduction(
        paired,
        chunk=select_k,
        combine=select_k,
        aggregate=partial(chunk.argtopk_aggregate, k=k),
        axis=axis,
        keepdims=True,
        dtype=np.intp,
        split_every=split_every,
        concatenate=False,
        output_size=abs(k),
        meta=meta,
    )
@@ -0,0 +1,74 @@
1
+ """Triangular matrix functions for array-expr."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import numpy as np
6
+
7
+ from dask_array._collection import asarray
8
+ from dask.utils import derived_from
9
+
10
+
11
@derived_from(np)
def tril(m, k=0):
    """Zero every entry of ``m`` above diagonal ``k`` of its last two axes."""
    from dask_array._utils import meta_from_array
    from dask_array.creation import tri
    from dask_array.routines._where import where

    m = asarray(m)
    # Boolean mask shaped and chunked like the trailing two axes of ``m``;
    # entries where it is True are kept.
    keep = tri(
        *m.shape[-2:],
        k=k,
        dtype=bool,
        chunks=m.chunks[-2:],
        like=meta_from_array(m),
    )
    fill = np.zeros_like(m._meta, shape=(1,))
    return where(keep, m, fill)
27
+
28
+
29
+ @derived_from(np)
30
+ def triu(m, k=0):
31
+ from dask_array.creation import tri
32
+ from dask_array.routines._where import where
33
+ from dask_array._utils import meta_from_array
34
+
35
+ m = asarray(m)
36
+ mask = tri(
37
+ *m.shape[-2:],
38
+ k=k - 1,
39
+ dtype=bool,
40
+ chunks=m.chunks[-2:],
41
+ like=meta_from_array(m),
42
+ )
43
+
44
+ return where(mask, np.zeros_like(m._meta, shape=(1,)), m)
45
+
46
+
47
@derived_from(np)
def tril_indices(n, k=0, m=None, chunks="auto"):
    """Indices of the lower triangle (diagonal ``k``) of an ``(n, m)`` array."""
    from dask_array.creation import tri
    from dask_array.routines._nonzero import nonzero

    lower = tri(n, m, k=k, dtype=bool, chunks=chunks)
    return nonzero(lower)
53
+
54
+
55
@derived_from(np)
def tril_indices_from(arr, k=0):
    """Lower-triangle indices for the 2-D array ``arr``, chunked like ``arr``."""
    if arr.ndim == 2:
        n_rows, n_cols = arr.shape[-2], arr.shape[-1]
        return tril_indices(n_rows, k=k, m=n_cols, chunks=arr.chunks)
    raise ValueError("input array must be 2-d")
60
+
61
+
62
@derived_from(np)
def triu_indices(n, k=0, m=None, chunks="auto"):
    """Indices of the upper triangle (diagonal ``k``) of an ``(n, m)`` array."""
    from dask_array.creation import tri
    from dask_array.routines._nonzero import nonzero

    # The upper triangle from diagonal ``k`` is the complement of the region
    # selected by tri(..., k=k - 1).
    strictly_below = tri(n, m, k=k - 1, dtype=bool, chunks=chunks)
    return nonzero(~strictly_below)
68
+
69
+
70
@derived_from(np)
def triu_indices_from(arr, k=0):
    """Upper-triangle indices for the 2-D array ``arr``, chunked like ``arr``."""
    if arr.ndim == 2:
        n_rows, n_cols = arr.shape[-2], arr.shape[-1]
        return triu_indices(n_rows, k=k, m=n_cols, chunks=arr.chunks)
    raise ValueError("input array must be 2-d")
@@ -0,0 +1,232 @@
1
+ """Unique implementation for array-expr."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from functools import cached_property
6
+
7
+ import numpy as np
8
+
9
+ from dask._task_spec import List, Task, TaskRef
10
+ from dask_array._collection import (
11
+ asarray,
12
+ concatenate,
13
+ new_collection,
14
+ ravel,
15
+ )
16
+ from dask_array._expr import ArrayExpr
17
+ from dask_array._utils import meta_from_array
18
+ from dask.utils import derived_from
19
+
20
+
21
+ def _unique_internal(ar, indices, counts, return_inverse=False):
22
+ """Helper function for np.unique with structured array output."""
23
+ return_index = indices is not None
24
+ return_counts = counts is not None
25
+
26
+ u = np.unique(ar)
27
+
28
+ dt = [("values", u.dtype)]
29
+ if return_index:
30
+ dt.append(("indices", np.intp))
31
+ if return_inverse:
32
+ dt.append(("inverse", np.intp))
33
+ if return_counts:
34
+ dt.append(("counts", np.intp))
35
+
36
+ r = np.empty(u.shape, dtype=dt)
37
+ r["values"] = u
38
+ if return_inverse:
39
+ r["inverse"] = np.arange(len(r), dtype=np.intp)
40
+ if return_index or return_counts:
41
+ for i, v in enumerate(r["values"]):
42
+ m = ar == v
43
+ if return_index:
44
+ indices[m].min(keepdims=True, out=r["indices"][i : i + 1])
45
+ if return_counts:
46
+ counts[m].sum(keepdims=True, out=r["counts"][i : i + 1])
47
+
48
+ return r
49
+
50
+
51
class UniqueChunked(ArrayExpr):
    """Per-block step of ``unique``.

    Emits one task per block of the 1-D expression ``x``. Each task calls
    ``_unique_internal`` on that block, together with the matching blocks of
    the optional ``indices`` / ``counts`` expressions. Each output block is a
    structured array of dtype ``out_dtype`` whose length is unknown up front.
    """

    _parameters = ["x", "indices", "counts", "out_dtype"]
    _defaults = {"indices": None, "counts": None}

    @cached_property
    def _meta(self):
        return np.empty((0,), dtype=self.out_dtype)

    @cached_property
    def chunks(self):
        # One output block per input block, each of unknown length.
        return (tuple(np.nan for _ in self.x.chunks[0]),)

    @cached_property
    def _name(self):
        return f"unique-chunk-{self.deterministic_token}"

    def _layer(self):
        def block_ref(expr, i):
            # Optional inputs are forwarded as None when absent.
            return None if expr is None else TaskRef((expr._name, i))

        nblocks = len(self.x.chunks[0])
        layer = {}
        for i in range(nblocks):
            key = (self._name, i)
            layer[key] = Task(
                key,
                _unique_internal,
                TaskRef((self.x._name, i)),
                block_ref(self.indices, i),
                block_ref(self.counts, i),
                False,
            )
        return layer

    @property
    def _dependencies(self):
        optional = [self.indices, self.counts]
        return [self.x] + [expr for expr in optional if expr is not None]
88
+
89
+
90
def _unique_aggregate_func(chunks, return_inverse):
    """Merge per-block ``_unique_internal`` results into one structured array.

    The optional ``indices`` / ``counts`` fields are reduced again only when
    the per-block results carry them.
    """
    merged = np.concatenate(chunks)
    names = merged.dtype.names

    def _field(name):
        return merged[name] if name in names else None

    return _unique_internal(
        merged["values"],
        _field("indices"),
        _field("counts"),
        return_inverse=return_inverse,
    )
102
+
103
+
104
class UniqueAggregate(ArrayExpr):
    """Single-task step of ``unique`` that merges every block of ``chunked``.

    Produces one block of unknown length with dtype ``out_dtype``.
    """

    _parameters = ["chunked", "return_inverse", "out_dtype"]
    _defaults = {"return_inverse": False}

    @cached_property
    def _meta(self):
        return np.empty((0,), dtype=self.out_dtype)

    @cached_property
    def chunks(self):
        return ((np.nan,),)

    @cached_property
    def _name(self):
        return f"unique-aggregate-{self.deterministic_token}"

    def _layer(self):
        source = self.chunked._name
        nblocks = len(self.chunked.chunks[0])
        block_refs = [TaskRef((source, i)) for i in range(nblocks)]
        out_key = (self._name, 0)
        task = Task(out_key, _unique_aggregate_func, List(*block_refs), self.return_inverse)
        return {out_key: task}

    @property
    def _dependencies(self):
        return [self.chunked]
132
+
133
+
134
def unique_no_structured_arr(ar, return_index=False, return_inverse=False, return_counts=False):
    """Fallback ``unique`` for array types that cannot hold structured dtypes.

    Only the unique values are supported on this path; requesting any of the
    optional outputs raises ``ValueError``.
    """
    from dask_array._blockwise import Blockwise
    from dask_array._expr import ChunksOverride
    from dask_array.reductions import _tree_reduce

    if any((return_index, return_inverse, return_counts)):
        raise ValueError(
            "dask.array.unique does not support `return_index`, `return_inverse` "
            "or `return_counts` with array types that don't support structured arrays."
        )

    flat = ravel(ar)

    # Step 1: np.unique on every block; the resulting block lengths are unknown.
    per_block = new_collection(Blockwise(np.unique, "i", flat, "i", dtype=flat.dtype))
    unknown_chunks = ((np.nan,) * len(flat.chunks[0]),)
    per_block = new_collection(ChunksOverride(per_block.expr, unknown_chunks))

    def _unique_agg(arrays, axis, keepdims):
        # Concatenate the per-block results and deduplicate once more.
        parts = arrays if isinstance(arrays, list) else [arrays]
        return np.unique(np.concatenate(parts))

    # Step 2: tree-reduce the per-block results into one array.
    return _tree_reduce(
        per_block.expr,
        aggregate=_unique_agg,
        axis=(0,),
        keepdims=False,
        dtype=flat.dtype,
        concatenate=False,
    )
164
+
165
+
166
@derived_from(np)
def unique(ar, return_index=False, return_inverse=False, return_counts=False):
    """Find the unique elements of an array.

    Array types that cannot hold structured dtypes fall back to
    ``unique_no_structured_arr``, which only supports the plain values.

    Returns
    -------
    The unique values, or a tuple with the values followed by the requested
    indices, inverse and counts (in that order).
    """
    from dask_array._numpy_compat import NUMPY_GE_200
    from dask_array.creation import arange, ones

    # Accept any array-like (e.g. lists), not only collections with ``.shape``.
    ar = asarray(ar)

    # Probe whether the backing array type supports structured dtypes.
    try:
        meta = meta_from_array(ar)
        np.empty_like(meta, dtype=[("a", int), ("b", float)])
    except TypeError:
        return unique_no_structured_arr(
            ar,
            return_index=return_index,
            return_inverse=return_inverse,
            return_counts=return_counts,
        )

    orig_shape = ar.shape
    ar = ravel(ar)

    # Field layout must match what _unique_internal emits: values, then
    # indices, inverse and counts when requested. The per-block step never
    # computes the inverse, so its dtype omits that field.
    indices_arr = None
    counts_arr = None
    chunk_fields = [("values", ar.dtype)]
    if return_index:
        indices_arr = arange(ar.shape[0], dtype=np.intp, chunks=ar.chunks[0])
        chunk_fields.append(("indices", np.intp))
    final_fields = list(chunk_fields)
    if return_inverse:
        final_fields.append(("inverse", np.intp))
    if return_counts:
        counts_arr = ones((ar.shape[0],), dtype=np.intp, chunks=ar.chunks[0])
        chunk_fields.append(("counts", np.intp))
        final_fields.append(("counts", np.intp))

    chunk_dtype = np.dtype(chunk_fields)
    final_dtype = np.dtype(final_fields)

    chunked = UniqueChunked(
        ar.expr,
        indices_arr.expr if indices_arr is not None else None,
        counts_arr.expr if counts_arr is not None else None,
        chunk_dtype,
    )
    aggregated = new_collection(UniqueAggregate(chunked, return_inverse, final_dtype))

    result = [aggregated["values"]]
    if return_index:
        result.append(aggregated["indices"])
    if return_inverse:
        # Compare every input element with every unique value; the single
        # matching column contributes its position, all others contribute 0.
        # NOTE(review): ``==`` never matches NaN, so NaN inputs map to 0 here —
        # confirm whether that matters for callers.
        matches = (ar[:, None] == aggregated["values"][None, :]).astype(np.intp)
        inverse = (matches * aggregated["inverse"]).sum(axis=1)
        if NUMPY_GE_200:
            from dask_array._reshape import reshape

            inverse = reshape(inverse, orig_shape)
        result.append(inverse)
    if return_counts:
        result.append(aggregated["counts"])

    if len(result) == 1:
        return result[0]
    return tuple(result)
225
+
226
+
227
@derived_from(np)
def union1d(ar1, ar2):
    """Unique values appearing in either input; both inputs are flattened first."""
    flattened = [ravel(asarray(arr)) for arr in (ar1, ar2)]
    return unique(concatenate(tuple(flattened)))