dask-array 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dask_array/__init__.py +228 -0
- dask_array/_backends.py +76 -0
- dask_array/_backends_array.py +99 -0
- dask_array/_blockwise.py +1410 -0
- dask_array/_broadcast.py +272 -0
- dask_array/_chunk.py +445 -0
- dask_array/_chunk_types.py +54 -0
- dask_array/_collection.py +1644 -0
- dask_array/_concatenate.py +331 -0
- dask_array/_core_utils.py +1365 -0
- dask_array/_dispatch.py +141 -0
- dask_array/_einsum.py +277 -0
- dask_array/_expr.py +544 -0
- dask_array/_expr_flow.py +586 -0
- dask_array/_gufunc.py +805 -0
- dask_array/_histogram.py +617 -0
- dask_array/_map_blocks.py +652 -0
- dask_array/_new_collection.py +10 -0
- dask_array/_numpy_compat.py +135 -0
- dask_array/_overlap.py +1159 -0
- dask_array/_rechunk.py +1050 -0
- dask_array/_reshape.py +710 -0
- dask_array/_routines.py +102 -0
- dask_array/_shuffle.py +448 -0
- dask_array/_stack.py +264 -0
- dask_array/_svg.py +291 -0
- dask_array/_templates.py +29 -0
- dask_array/_test_utils.py +257 -0
- dask_array/_ufunc.py +385 -0
- dask_array/_utils.py +349 -0
- dask_array/_visualize.py +223 -0
- dask_array/_xarray.py +337 -0
- dask_array/core/__init__.py +34 -0
- dask_array/core/_blockwise_funcs.py +312 -0
- dask_array/core/_conversion.py +422 -0
- dask_array/core/_from_graph.py +97 -0
- dask_array/creation/__init__.py +71 -0
- dask_array/creation/_arange.py +121 -0
- dask_array/creation/_diag.py +116 -0
- dask_array/creation/_diagonal.py +241 -0
- dask_array/creation/_eye.py +103 -0
- dask_array/creation/_linspace.py +102 -0
- dask_array/creation/_mesh.py +134 -0
- dask_array/creation/_ones_zeros.py +454 -0
- dask_array/creation/_pad.py +270 -0
- dask_array/creation/_repeat.py +55 -0
- dask_array/creation/_tile.py +36 -0
- dask_array/creation/_tri.py +28 -0
- dask_array/creation/_utils.py +296 -0
- dask_array/fft.py +320 -0
- dask_array/io/__init__.py +39 -0
- dask_array/io/_base.py +10 -0
- dask_array/io/_from_array.py +257 -0
- dask_array/io/_from_delayed.py +95 -0
- dask_array/io/_from_graph.py +54 -0
- dask_array/io/_from_npy_stack.py +67 -0
- dask_array/io/_store.py +336 -0
- dask_array/io/_tiledb.py +159 -0
- dask_array/io/_to_npy_stack.py +65 -0
- dask_array/io/_zarr.py +449 -0
- dask_array/linalg/__init__.py +39 -0
- dask_array/linalg/_cholesky.py +234 -0
- dask_array/linalg/_lu.py +300 -0
- dask_array/linalg/_norm.py +94 -0
- dask_array/linalg/_qr.py +601 -0
- dask_array/linalg/_solve.py +349 -0
- dask_array/linalg/_svd.py +394 -0
- dask_array/linalg/_tensordot.py +334 -0
- dask_array/linalg/_utils.py +74 -0
- dask_array/manipulation/__init__.py +45 -0
- dask_array/manipulation/_expand.py +321 -0
- dask_array/manipulation/_flip.py +92 -0
- dask_array/manipulation/_roll.py +78 -0
- dask_array/manipulation/_transpose.py +309 -0
- dask_array/random/__init__.py +125 -0
- dask_array/random/_choice.py +181 -0
- dask_array/random/_expr.py +256 -0
- dask_array/random/_generator.py +441 -0
- dask_array/random/_random_state.py +259 -0
- dask_array/random/_utils.py +84 -0
- dask_array/reductions/__init__.py +84 -0
- dask_array/reductions/_arg_reduction.py +130 -0
- dask_array/reductions/_common.py +1082 -0
- dask_array/reductions/_cumulative.py +522 -0
- dask_array/reductions/_percentile.py +261 -0
- dask_array/reductions/_reduction.py +725 -0
- dask_array/reductions/_trace.py +56 -0
- dask_array/routines/__init__.py +133 -0
- dask_array/routines/_apply.py +84 -0
- dask_array/routines/_bincount.py +112 -0
- dask_array/routines/_broadcast.py +111 -0
- dask_array/routines/_coarsen.py +115 -0
- dask_array/routines/_diff.py +79 -0
- dask_array/routines/_gradient.py +158 -0
- dask_array/routines/_indexing.py +65 -0
- dask_array/routines/_insert_delete.py +132 -0
- dask_array/routines/_misc.py +122 -0
- dask_array/routines/_nonzero.py +72 -0
- dask_array/routines/_search.py +123 -0
- dask_array/routines/_select.py +113 -0
- dask_array/routines/_statistics.py +171 -0
- dask_array/routines/_topk.py +82 -0
- dask_array/routines/_triangular.py +74 -0
- dask_array/routines/_unique.py +232 -0
- dask_array/routines/_where.py +62 -0
- dask_array/slicing/__init__.py +67 -0
- dask_array/slicing/_basic.py +550 -0
- dask_array/slicing/_blocks.py +138 -0
- dask_array/slicing/_bool_index.py +145 -0
- dask_array/slicing/_setitem.py +329 -0
- dask_array/slicing/_squeeze.py +101 -0
- dask_array/slicing/_utils.py +1133 -0
- dask_array/slicing/_vindex.py +282 -0
- dask_array/stacking/__init__.py +15 -0
- dask_array/stacking/_block.py +83 -0
- dask_array/stacking/_simple.py +58 -0
- dask_array/templates/array.html.j2 +48 -0
- dask_array/tests/__init__.py +0 -0
- dask_array/tests/conftest.py +22 -0
- dask_array/tests/test_api.py +40 -0
- dask_array/tests/test_binary_op_chunks.py +107 -0
- dask_array/tests/test_coarse_slice_through_blockwise.py +362 -0
- dask_array/tests/test_collection.py +799 -0
- dask_array/tests/test_creation.py +1102 -0
- dask_array/tests/test_expr_flow.py +143 -0
- dask_array/tests/test_linalg.py +1130 -0
- dask_array/tests/test_map_blocks_multi_output.py +104 -0
- dask_array/tests/test_rechunk_pushdown.py +214 -0
- dask_array/tests/test_reductions.py +1091 -0
- dask_array/tests/test_routines.py +2853 -0
- dask_array/tests/test_shuffle_chunks.py +67 -0
- dask_array/tests/test_slice_pushdown.py +968 -0
- dask_array/tests/test_slice_through_blockwise.py +678 -0
- dask_array/tests/test_slice_through_overlap.py +366 -0
- dask_array/tests/test_slice_through_reshape.py +272 -0
- dask_array/tests/test_slicing.py +839 -0
- dask_array/tests/test_transpose_slice_pushdown.py +208 -0
- dask_array/tests/test_visualize.py +94 -0
- dask_array/tests/test_xarray.py +193 -0
- dask_array-0.1.0.dist-info/METADATA +48 -0
- dask_array-0.1.0.dist-info/RECORD +144 -0
- dask_array-0.1.0.dist-info/WHEEL +4 -0
- dask_array-0.1.0.dist-info/entry_points.txt +2 -0
- dask_array-0.1.0.dist-info/licenses/LICENSE +29 -0
|
@@ -0,0 +1,309 @@
|
|
|
1
|
+
"""Transpose operations: transpose, swapaxes, moveaxis, rollaxis."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import functools
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
|
|
9
|
+
from dask._task_spec import Task, TaskRef
|
|
10
|
+
from dask_array._blockwise import Blockwise
|
|
11
|
+
from dask_array._utils import meta_from_array
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class Transpose(Blockwise):
    """Blockwise expression that permutes the axes of ``array``.

    Each output block is produced by calling ``np.transpose`` on exactly one
    input block; the input block is located by applying the inverse of
    ``axes`` to the output block id.
    """

    _parameters = ["array", "axes"]
    func = staticmethod(np.transpose)
    align_arrays = False
    adjust_chunks = None
    concatenate = None
    token = "transpose"

    @property
    def new_axes(self):
        """Mapping of newly created axes; always empty for a transpose."""
        return {}

    @property
    def name(self):
        """Expression name (forwarded from ``_name``)."""
        return self._name

    @property
    def _meta_provided(self):
        """Meta of the wrapped array, returned as-is (not permuted)."""
        return self.array._meta

    @functools.cached_property
    def _meta(self):
        """Meta for the transposed result.

        The input meta is transposed when its ``ndim`` matches the number of
        axes; otherwise a fresh meta with the right ``ndim`` and dtype is
        built through ``meta_from_array``.
        """
        source_meta = self.array._meta
        n_axes = len(self.axes)
        if source_meta is not None and getattr(source_meta, "ndim", None) == n_axes:
            return np.transpose(source_meta, self.axes)
        return meta_from_array(None, ndim=n_axes, dtype=self.array.dtype)

    @property
    def dtype(self):
        """Dtype of the wrapped array."""
        return self.array.dtype

    @property
    def out_ind(self):
        """Output index labels: the permutation itself."""
        return self.axes

    @property
    def kwargs(self):
        """Keyword arguments forwarded to ``np.transpose`` for each block."""
        return {"axes": self.axes}

    @property
    def args(self):
        """Blockwise argument pair: the array and its natural index labels."""
        natural_order = tuple(range(self.array.ndim))
        return (self.array, natural_order)

    @functools.cached_property
    def _inverse_axes(self):
        """Inverse permutation of ``axes`` as a tuple."""
        inverse = [0] * len(self.axes)
        for out_pos, in_pos in enumerate(self.axes):
            inverse[in_pos] = out_pos
        return tuple(inverse)

    def _task(self, key, block_id: tuple[int, ...]) -> Task:
        """Build the task that computes the output block ``block_id``.

        For ``axes=(1, 0)`` the output block ``(i, j)`` is computed from the
        input block ``(j, i)``.
        """
        source_block = self._input_block_id(self.array, block_id)
        source_ref = TaskRef((self.array._name, *source_block))
        return Task(key, self.func, source_ref, **self.kwargs)

    def _input_block_id(self, dep, block_id: tuple[int, ...]) -> tuple[int, ...]:
        """Translate an output block id into the matching input block id."""
        inverse = self._inverse_axes
        return tuple(block_id[inverse[dim]] for dim in range(len(block_id)))

    def _simplify_down(self):
        """Apply local rewrites; return the new expression or ``None``.

        Rewrites, tried in order:
        * ``Transpose(Transpose(x))`` becomes one ``Transpose`` with composed axes.
        * An identity permutation is replaced by the wrapped array.
        * ``Transpose(Elemwise(...))`` is pushed onto the elemwise inputs.
        """
        child = self.array
        if isinstance(child, Transpose):
            composed = tuple(child.axes[i] for i in self.axes)
            return Transpose(child.array, composed)

        if self.axes == tuple(range(child.ndim)):
            return child

        from dask_array._blockwise import Elemwise

        if isinstance(child, Elemwise):
            return self._pushdown_through_elemwise()
        return None

    def _pushdown_through_elemwise(self):
        """Rewrite ``Transpose(Elemwise(x, y))`` as ``Elemwise(Transpose(x), Transpose(y))``.

        Returns ``None`` when any non-scalar input (or array-like
        ``where``/``out``) has a different ``ndim`` than the output, because
        broadcasting would require index transformations not handled here.
        """
        from dask_array._blockwise import Elemwise
        from dask_array._core_utils import is_scalar_for_elemwise

        inner = self.array
        perm = self.axes
        ndim_out = len(perm)

        # Bail out on broadcasting inputs.
        if any(
            not is_scalar_for_elemwise(operand) and operand.ndim != ndim_out
            for operand in inner.elemwise_args
        ):
            return None

        # where/out must also match the output dimensionality when array-like.
        if hasattr(inner.where, "ndim") and inner.where.ndim != ndim_out:
            return None
        if hasattr(inner.out, "ndim") and inner.out.ndim != ndim_out:
            return None

        permuted_args = []
        for operand in inner.elemwise_args:
            if is_scalar_for_elemwise(operand):
                permuted_args.append(operand)
            else:
                permuted_args.append(Transpose(operand, perm))

        where_arg = inner.where
        where_arg = Transpose(where_arg, perm) if hasattr(where_arg, "ndim") else where_arg

        out_arg = inner.out
        out_arg = Transpose(out_arg, perm) if hasattr(out_arg, "ndim") else out_arg

        return Elemwise(
            inner.op,
            inner.operand("dtype"),
            inner.operand("name"),
            where_arg,
            out_arg,
            inner.operand("_user_kwargs"),
            *permuted_args,
        )

    def _simplify_up(self, parent, dependents):
        """Let slice and shuffle parents be pushed below this transpose."""
        from dask_array._shuffle import Shuffle
        from dask_array.slicing import SliceSlicesIntegers

        if isinstance(parent, SliceSlicesIntegers):
            return self._accept_slice(parent)
        if isinstance(parent, Shuffle):
            return self._accept_shuffle(parent)
        return None

    def _accept_shuffle(self, shuffle_expr):
        """Push a shuffle below the transpose.

        ``axes[i]`` is the input axis that becomes output axis ``i``, so a
        shuffle of output axis ``k`` is a shuffle of input axis ``axes[k]``.
        """
        from dask_array._shuffle import Shuffle

        perm = self.axes
        source_axis = perm[shuffle_expr.axis]
        pushed = Shuffle(
            self.array,
            shuffle_expr.indexer,
            source_axis,
            shuffle_expr.operand("name"),
        )
        return Transpose(pushed, perm)

    def _accept_slice(self, slice_expr):
        """Push a slice below the transpose.

        The index entry for output axis ``i`` is applied to input axis
        ``axes[i]``. Integer entries drop dimensions, in which case the
        permutation is recomputed for the surviving axes. Returns ``None``
        when the index contains ``None`` (newaxis).
        """
        from numbers import Integral

        from dask_array._new_collection import new_collection

        perm = self.axes
        index = slice_expr.index

        # newaxis entries add dimensions; not handled here.
        if any(entry is None for entry in index):
            return None

        # Pad with full slices so there is one entry per output axis.
        padded = index + (slice(None),) * (self.ndim - len(index))

        # Route each output-axis entry to the input axis it came from.
        source_index = [slice(None)] * len(perm)
        for out_axis, in_axis in enumerate(perm):
            source_index[in_axis] = padded[out_axis]

        sliced = new_collection(self.array)[tuple(source_index)]

        drops_dims = any(isinstance(entry, Integral) for entry in padded)
        if not drops_dims:
            # Dimensionality unchanged: the original permutation still applies.
            return Transpose(sliced.expr, perm)

        # Input axes that survive the slice, listed in output order.
        kept = [
            in_axis
            for out_axis, in_axis in enumerate(perm)
            if not isinstance(padded[out_axis], Integral)
        ]

        if len(kept) <= 1:
            # Zero or one axis left: nothing to permute.
            return sliced.expr

        # After slicing, the surviving input axes are renumbered 0, 1, 2, ...
        ordered = sorted(kept)
        renumber = dict(zip(ordered, range(len(ordered))))
        reduced_perm = tuple(renumber[in_axis] for in_axis in kept)

        if reduced_perm == tuple(range(len(reduced_perm))):
            # The reduced permutation is the identity.
            return sliced.expr

        return Transpose(sliced.expr, reduced_perm)
|
|
222
|
+
|
|
223
|
+
|
|
224
|
+
def transpose(a, axes=None):
    """Reverse or permute the axes of an array.

    ``a`` is converted with ``asanyarray`` and the work is delegated to its
    ``transpose`` method; ``axes`` is forwarded only when it is not ``None``.

    See Also
    --------
    numpy.transpose
    """
    from dask_array.core import asanyarray

    arr = asanyarray(a)
    return arr.transpose() if axes is None else arr.transpose(axes)
|
|
237
|
+
|
|
238
|
+
|
|
239
|
+
def swapaxes(a, axis1, axis2):
    """Interchange two axes of an array.

    Parameters
    ----------
    a : array_like
        Input, converted with ``asanyarray``.
    axis1, axis2 : int
        Axes to swap. Negative values count from the last axis.

    Returns
    -------
    The converted array itself when both axes refer to the same dimension,
    otherwise the result of ``transpose`` with the two axes exchanged.

    Raises
    ------
    Whatever ``normalize_axis_index`` raises for an axis outside
    ``[-a.ndim, a.ndim)``. Previously such values could wrap around silently
    through Python's negative list indexing (e.g. ``axis1=-4`` on a 3-d array
    swapped axis 2), and equal out-of-range axes were accepted unchecked.

    See Also
    --------
    numpy.swapaxes
    """
    from dask_array._numpy_compat import normalize_axis_index
    from dask_array.core import asanyarray

    a = asanyarray(a)
    # Validate and normalise first so that invalid axes are always rejected,
    # including the case where both arguments are equal.
    axis1 = normalize_axis_index(axis1, a.ndim)
    axis2 = normalize_axis_index(axis2, a.ndim)
    if axis1 == axis2:
        return a
    ind = list(range(a.ndim))
    ind[axis1], ind[axis2] = ind[axis2], ind[axis1]
    return transpose(a, ind)
|
|
258
|
+
|
|
259
|
+
|
|
260
|
+
def moveaxis(a, source, destination):
    """Move axes of an array to new positions.

    ``source`` and ``destination`` are normalised with
    ``normalize_axis_tuple`` and must contain the same number of entries.
    Axes not listed in ``source`` keep their relative order.

    See Also
    --------
    numpy.moveaxis
    """
    from dask_array.core import asanyarray
    from dask_array._numpy_compat import normalize_axis_tuple

    a = asanyarray(a)
    ndim = a.ndim
    source = normalize_axis_tuple(source, ndim, "source")
    destination = normalize_axis_tuple(destination, ndim, "destination")
    if len(source) != len(destination):
        raise ValueError("`source` and `destination` arguments must have the same number of elements")

    # Start from the axes that stay put, then insert each moved axis at its
    # destination, processing destinations in ascending order.
    order = []
    for axis in range(ndim):
        if axis not in source:
            order.append(axis)

    placements = sorted(zip(destination, source))
    for target, moved in placements:
        order.insert(target, moved)

    return transpose(a, order)
|
|
282
|
+
|
|
283
|
+
|
|
284
|
+
def rollaxis(a, axis, start=0):
    """Roll the specified axis backwards, until it lies in a given position.

    ``axis`` is normalised with ``normalize_axis_index``. A negative ``start``
    is offset by ``a.ndim`` and must then lie in ``[0, a.ndim]``, otherwise a
    ``ValueError`` is raised. When the axis is already in the requested
    position, ``a[...]`` is returned.

    See Also
    --------
    numpy.rollaxis
    """
    from dask_array.core import asanyarray
    from dask_array._numpy_compat import normalize_axis_index

    a = asanyarray(a)
    ndim = a.ndim
    axis = normalize_axis_index(axis, ndim)

    if start < 0:
        start += ndim
    if start < 0 or start >= ndim + 1:
        msg = "'%s' arg requires %d <= %s < %d, but %d was passed in"
        raise ValueError(msg % ("start", -ndim, "start", ndim + 1, start))

    # Removing ``axis`` shifts every later position down by one.
    if axis < start:
        start -= 1
    if axis == start:
        return a[...]

    axes = [dim for dim in range(ndim) if dim != axis]
    axes.insert(start, axis)
    return transpose(a, axes)
|
|
@@ -0,0 +1,125 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from threading import Lock
|
|
4
|
+
|
|
5
|
+
from ._generator import Generator, default_rng
|
|
6
|
+
from ._random_state import RandomState
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
# Lazy RNG-state machinery
|
|
10
|
+
#
|
|
11
|
+
# Many of the RandomState methods are exported as functions in da.random for
|
|
12
|
+
# backward compatibility reasons. Their usage is discouraged.
|
|
13
|
+
# Use da.random.default_rng() to get a Generator based rng and use its
|
|
14
|
+
# methods instead.
|
|
15
|
+
|
|
16
|
+
# Cache of RandomState instances used by the module-level functions built by
# ``_make_api``; keyed by the value of ``array_creation_dispatch.backend``.
_cached_states: dict[str, RandomState] = {}
# Held while looking up or inserting entries in ``_cached_states``.
_cached_states_lock = Lock()
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def _make_api(attr):
    """Build a module-level function that forwards to ``RandomState.<attr>``.

    The returned wrapper looks up (or lazily creates) the cached
    ``RandomState`` for the current ``array_creation_dispatch.backend`` and
    calls its ``attr`` method with the given arguments. The wrapper takes its
    ``__name__`` and ``__doc__`` from the ``RandomState`` method.
    """

    def wrapper(*args, **kwargs):
        from dask_array._backends_array import array_creation_dispatch

        backend = array_creation_dispatch.backend
        # The lock only guards the cache lookup/insert, not the call itself.
        with _cached_states_lock:
            if backend in _cached_states:
                rs = _cached_states[backend]
            else:
                rs = RandomState()
                _cached_states[backend] = rs
        bound_method = getattr(rs, attr)
        return bound_method(*args, **kwargs)

    template = getattr(RandomState, attr)
    wrapper.__name__ = template.__name__
    wrapper.__doc__ = template.__doc__
    return wrapper
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
# RandomState only
#
# Each name below is a wrapper produced by ``_make_api`` that forwards to the
# RandomState method named by its argument.

seed = _make_api("seed")

beta = _make_api("beta")
binomial = _make_api("binomial")
chisquare = _make_api("chisquare")
choice = _make_api("choice")
exponential = _make_api("exponential")
f = _make_api("f")
gamma = _make_api("gamma")
geometric = _make_api("geometric")
gumbel = _make_api("gumbel")
hypergeometric = _make_api("hypergeometric")
laplace = _make_api("laplace")
logistic = _make_api("logistic")
lognormal = _make_api("lognormal")
logseries = _make_api("logseries")
multinomial = _make_api("multinomial")
negative_binomial = _make_api("negative_binomial")
noncentral_chisquare = _make_api("noncentral_chisquare")
noncentral_f = _make_api("noncentral_f")
normal = _make_api("normal")
pareto = _make_api("pareto")
permutation = _make_api("permutation")
poisson = _make_api("poisson")
power = _make_api("power")
random_sample = _make_api("random_sample")
# ``random`` is an alias: it also forwards to ``RandomState.random_sample``.
random = _make_api("random_sample")
randint = _make_api("randint")
random_integers = _make_api("random_integers")
rayleigh = _make_api("rayleigh")
standard_cauchy = _make_api("standard_cauchy")
standard_exponential = _make_api("standard_exponential")
standard_gamma = _make_api("standard_gamma")
standard_normal = _make_api("standard_normal")
standard_t = _make_api("standard_t")
triangular = _make_api("triangular")
uniform = _make_api("uniform")
vonmises = _make_api("vonmises")
wald = _make_api("wald")
weibull = _make_api("weibull")
zipf = _make_api("zipf")
|
|
80
|
+
|
|
81
|
+
# Public names of this module: the classes/factory imported at the top of the
# file followed by the wrappers created with ``_make_api``.
__all__ = [
    "Generator",
    "RandomState",
    "default_rng",
    "seed",
    "beta",
    "binomial",
    "chisquare",
    "choice",
    "exponential",
    "f",
    "gamma",
    "geometric",
    "gumbel",
    "hypergeometric",
    "laplace",
    "logistic",
    "lognormal",
    "logseries",
    "multinomial",
    "negative_binomial",
    "noncentral_chisquare",
    "noncentral_f",
    "normal",
    "pareto",
    "permutation",
    "poisson",
    "power",
    "random_sample",
    "random",
    "randint",
    "random_integers",
    "rayleigh",
    "standard_cauchy",
    "standard_exponential",
    "standard_gamma",
    "standard_normal",
    "standard_t",
    "triangular",
    "uniform",
    "vonmises",
    "wald",
    "weibull",
    "zipf",
]
|
|
@@ -0,0 +1,181 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from itertools import product
|
|
4
|
+
from numbers import Integral
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
|
|
8
|
+
from dask._task_spec import TaskRef
|
|
9
|
+
from dask_array._collection import Array
|
|
10
|
+
from dask_array.core._conversion import asarray
|
|
11
|
+
from dask_array.io import IO
|
|
12
|
+
from dask_array._core_utils import normalize_chunks
|
|
13
|
+
from dask_array._utils import asarray_safe
|
|
14
|
+
from dask_array._backends_array import array_creation_dispatch
|
|
15
|
+
from dask.utils import cached_property, random_state_data
|
|
16
|
+
|
|
17
|
+
from ._expr import _spawn_bitgens
|
|
18
|
+
from ._generator import Generator
|
|
19
|
+
from ._random_state import RandomState
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def _choice_rng(state_data, a, size, replace, p, axis, shuffle):
    """Draw one block of ``choice`` samples for the Generator code path.

    An rng is obtained from ``state_data`` through ``_rng_from_bitgen`` and
    its ``choice`` method is called with the remaining arguments.
    """
    from ._expr import _rng_from_bitgen

    rng = _rng_from_bitgen(state_data)
    draw_kwargs = {
        "size": size,
        "replace": replace,
        "p": p,
        "axis": axis,
        "shuffle": shuffle,
    }
    return rng.choice(a, **draw_kwargs)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def _choice_rs(state_data, a, size, replace, p):
    """Draw one block of ``choice`` samples for the RandomState code path.

    A ``RandomState`` is created from ``state_data`` via
    ``array_creation_dispatch`` and its ``choice`` method is called.
    """
    rs = array_creation_dispatch.RandomState(state_data)
    draw_kwargs = {"size": size, "replace": replace, "p": p}
    return rs.choice(a, **draw_kwargs)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def _choice_validate_params(state, a, size, replace, p, axis, chunks):
    """Validate and normalize parameters for choice.

    Returns expressions for array/p (or int/None) so they participate in lowering.

    Parameters
    ----------
    state : Generator or RandomState
        Used to pick the backend for the meta when ``a`` is an integer.
    a : int or array-like
        Population size, or a one-dimensional population.
    size : None, int, tuple or list
        Output shape; ``None`` becomes ``()`` and a bare value becomes a
        one-element tuple.
    replace : bool
        Whether sampling is done with replacement.
    p : None or array-like
        Optional one-dimensional probabilities with the same length as ``a``.
    axis : int
        Must be ``0``.
    chunks
        Chunk specification, normalised with ``normalize_chunks``.

    Returns
    -------
    tuple
        ``(a_val, a_expr, size, replace, p_expr, axis, chunks, meta)`` where
        exactly one of ``a_val`` / ``a_expr`` is not ``None``.

    Raises
    ------
    ValueError
        For an unknown ``state`` class, negative integer ``a``, non-1-d ``a``
        or ``p``, mismatching lengths, probabilities that do not sum to 1
        (checked only for non-dask ``p``), or ``axis != 0``.
    NotImplementedError
        For an integer ``a`` with a cupy-backed ``Generator``, or for
        ``replace=False`` with an output split into more than one chunk.
    """
    # Normalize and validate `a`
    if isinstance(a, Integral):
        if isinstance(state, Generator):
            if state._backend_name == "cupy":
                raise NotImplementedError("`choice` not supported for cupy-backed `Generator`.")
            meta = state._backend.random.default_rng().choice(1, size=(), p=None)
        elif isinstance(state, RandomState):
            # On windows the output dtype differs if p is provided or
            # absent, see https://github.com/numpy/numpy/issues/9867
            dummy_p = state._backend.array([1]) if p is not None else p
            meta = state._backend.random.RandomState().choice(1, size=(), p=dummy_p)
        else:
            raise ValueError("Unknown generator class")
        len_a = a
        if a < 0:
            raise ValueError("a must be greater than 0")
        a_expr = None  # No expression for int a
    else:
        a = asarray(a)
        a = a.rechunk(a.shape)
        meta = a._meta
        if a.ndim != 1:
            raise ValueError("a must be one dimensional")
        len_a = len(a)
        a_expr = a.expr  # Store expression so it gets lowered

    # Normalize and validate `p`
    p_expr = None
    if p is not None:
        if not isinstance(p, Array):
            # If p is not a dask array, first check the sum is close
            # to 1 before converting.
            p = asarray_safe(p, like=p)
            if not np.isclose(p.sum(), 1, rtol=1e-7, atol=0):
                raise ValueError("probabilities do not sum to 1")
            p = asarray(p)
        # Always collapse p to a single chunk, exactly as is done for `a`:
        # the generated tasks reference only block 0 of p. Previously a
        # converted (non-dask) p was left with whatever chunks `asarray` chose.
        p = p.rechunk(p.shape)

        if p.ndim != 1:
            raise ValueError("p must be one dimensional")
        if len(p) != len_a:
            raise ValueError("a and p must have the same size")

        p_expr = p.expr  # Store expression so it gets lowered

    if size is None:
        size = ()
    elif not isinstance(size, (tuple, list)):
        size = (size,)

    if axis != 0:
        raise ValueError("axis must be 0 since a is one dimensional")

    chunks = normalize_chunks(chunks, size, dtype=np.float64)
    # Inspect every axis rather than only chunks[0]: this avoids indexing
    # into an empty chunks tuple for scalar output and also catches outputs
    # that are split along a non-leading axis.
    if not replace and any(len(axis_chunks) > 1 for axis_chunks in chunks):
        err_msg = "replace=False is not currently supported for dask.array.choice with multi-chunk output arrays"
        raise NotImplementedError(err_msg)

    # For int a, return the int value; for array a, return None (use a_expr)
    a_val = a if isinstance(a, Integral) else None
    return a_val, a_expr, size, replace, p_expr, axis, chunks, meta
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
class RandomChoice(IO):
    """Expression producing ``choice`` samples through the RandomState path.

    One task is generated per output block; each task calls ``_choice_rs``
    with its own entry of ``state_data`` and its own block shape.
    """

    _parameters = [
        "a_val",  # int value of a (or None if a is an array)
        "a_expr",  # expression for a (or None if a is an int)
        "chunks",
        "_meta",
        "_state",
        "replace",
        "p_expr",  # expression for p (or None)
        "axis",
        "shuffle",
    ]
    _defaults = {"axis": None, "shuffle": None}
    _funcname = "da.random.choice-"

    @cached_property
    def chunks(self):
        """Chunks of the output, taken directly from the operand."""
        return self.operand("chunks")

    @cached_property
    def sizes(self):
        """Shape of every output block, as a list of tuples."""
        return list(product(*self.chunks))

    @cached_property
    def state_data(self):
        """Per-block seed data, one entry for each element of ``sizes``."""
        n_blocks = len(self.sizes)
        return random_state_data(n_blocks, self._state)

    @cached_property
    def _meta(self):
        """Meta of the output, taken directly from the operand."""
        return self.operand("_meta")

    # No custom dependencies() needed - base class finds Expr operands automatically
    # (a_expr and p_expr are included when they're expressions, excluded when None)

    @property
    def _a_arg(self):
        """Value to pass to choice: int or TaskRef to single-chunk array."""
        if self.a_val is None:
            return TaskRef((self.a_expr._name, 0))
        return self.a_val

    @property
    def _p_arg(self):
        """Value to pass to choice: None or TaskRef to single-chunk array."""
        if self.p_expr is not None:
            return TaskRef((self.p_expr._name, 0))
        return None

    def _layer(self) -> dict:
        """Build the task dictionary: one ``_choice_rs`` task per output block."""
        block_ranges = [range(len(bd)) for bd in self.chunks]
        block_keys = product([self._name], *block_ranges)
        layer = {}
        for key, seed_data, block_shape in zip(block_keys, self.state_data, self.sizes):
            layer[key] = (
                _choice_rs,
                seed_data,
                self._a_arg,
                block_shape,
                self.replace,
                self._p_arg,
            )
        return layer
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
class RandomChoiceGenerator(RandomChoice):
    """Variant of ``RandomChoice`` for the Generator path.

    Per-block state comes from ``_spawn_bitgens`` and every task calls
    ``_choice_rng``, which additionally receives ``axis`` and ``shuffle``.
    """

    # Keep axis and shuffle as required parameters (no defaults)
    _defaults = {}

    @cached_property
    def state_data(self):
        """Per-block bit generators, one for each element of ``sizes``."""
        n_blocks = len(self.sizes)
        return _spawn_bitgens(self._state, n_blocks)

    def _layer(self) -> dict:
        """Build the task dictionary: one ``_choice_rng`` task per output block."""
        block_ranges = [range(len(bd)) for bd in self.chunks]
        block_keys = product([self._name], *block_ranges)
        layer = {}
        for key, bitgen, block_shape in zip(block_keys, self.state_data, self.sizes):
            layer[key] = (
                _choice_rng,
                bitgen,
                self._a_arg,
                block_shape,
                self.replace,
                self._p_arg,
                self.axis,
                self.shuffle,
            )
        return layer
|