dask-array 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (144) hide show
  1. dask_array/__init__.py +228 -0
  2. dask_array/_backends.py +76 -0
  3. dask_array/_backends_array.py +99 -0
  4. dask_array/_blockwise.py +1410 -0
  5. dask_array/_broadcast.py +272 -0
  6. dask_array/_chunk.py +445 -0
  7. dask_array/_chunk_types.py +54 -0
  8. dask_array/_collection.py +1644 -0
  9. dask_array/_concatenate.py +331 -0
  10. dask_array/_core_utils.py +1365 -0
  11. dask_array/_dispatch.py +141 -0
  12. dask_array/_einsum.py +277 -0
  13. dask_array/_expr.py +544 -0
  14. dask_array/_expr_flow.py +586 -0
  15. dask_array/_gufunc.py +805 -0
  16. dask_array/_histogram.py +617 -0
  17. dask_array/_map_blocks.py +652 -0
  18. dask_array/_new_collection.py +10 -0
  19. dask_array/_numpy_compat.py +135 -0
  20. dask_array/_overlap.py +1159 -0
  21. dask_array/_rechunk.py +1050 -0
  22. dask_array/_reshape.py +710 -0
  23. dask_array/_routines.py +102 -0
  24. dask_array/_shuffle.py +448 -0
  25. dask_array/_stack.py +264 -0
  26. dask_array/_svg.py +291 -0
  27. dask_array/_templates.py +29 -0
  28. dask_array/_test_utils.py +257 -0
  29. dask_array/_ufunc.py +385 -0
  30. dask_array/_utils.py +349 -0
  31. dask_array/_visualize.py +223 -0
  32. dask_array/_xarray.py +337 -0
  33. dask_array/core/__init__.py +34 -0
  34. dask_array/core/_blockwise_funcs.py +312 -0
  35. dask_array/core/_conversion.py +422 -0
  36. dask_array/core/_from_graph.py +97 -0
  37. dask_array/creation/__init__.py +71 -0
  38. dask_array/creation/_arange.py +121 -0
  39. dask_array/creation/_diag.py +116 -0
  40. dask_array/creation/_diagonal.py +241 -0
  41. dask_array/creation/_eye.py +103 -0
  42. dask_array/creation/_linspace.py +102 -0
  43. dask_array/creation/_mesh.py +134 -0
  44. dask_array/creation/_ones_zeros.py +454 -0
  45. dask_array/creation/_pad.py +270 -0
  46. dask_array/creation/_repeat.py +55 -0
  47. dask_array/creation/_tile.py +36 -0
  48. dask_array/creation/_tri.py +28 -0
  49. dask_array/creation/_utils.py +296 -0
  50. dask_array/fft.py +320 -0
  51. dask_array/io/__init__.py +39 -0
  52. dask_array/io/_base.py +10 -0
  53. dask_array/io/_from_array.py +257 -0
  54. dask_array/io/_from_delayed.py +95 -0
  55. dask_array/io/_from_graph.py +54 -0
  56. dask_array/io/_from_npy_stack.py +67 -0
  57. dask_array/io/_store.py +336 -0
  58. dask_array/io/_tiledb.py +159 -0
  59. dask_array/io/_to_npy_stack.py +65 -0
  60. dask_array/io/_zarr.py +449 -0
  61. dask_array/linalg/__init__.py +39 -0
  62. dask_array/linalg/_cholesky.py +234 -0
  63. dask_array/linalg/_lu.py +300 -0
  64. dask_array/linalg/_norm.py +94 -0
  65. dask_array/linalg/_qr.py +601 -0
  66. dask_array/linalg/_solve.py +349 -0
  67. dask_array/linalg/_svd.py +394 -0
  68. dask_array/linalg/_tensordot.py +334 -0
  69. dask_array/linalg/_utils.py +74 -0
  70. dask_array/manipulation/__init__.py +45 -0
  71. dask_array/manipulation/_expand.py +321 -0
  72. dask_array/manipulation/_flip.py +92 -0
  73. dask_array/manipulation/_roll.py +78 -0
  74. dask_array/manipulation/_transpose.py +309 -0
  75. dask_array/random/__init__.py +125 -0
  76. dask_array/random/_choice.py +181 -0
  77. dask_array/random/_expr.py +256 -0
  78. dask_array/random/_generator.py +441 -0
  79. dask_array/random/_random_state.py +259 -0
  80. dask_array/random/_utils.py +84 -0
  81. dask_array/reductions/__init__.py +84 -0
  82. dask_array/reductions/_arg_reduction.py +130 -0
  83. dask_array/reductions/_common.py +1082 -0
  84. dask_array/reductions/_cumulative.py +522 -0
  85. dask_array/reductions/_percentile.py +261 -0
  86. dask_array/reductions/_reduction.py +725 -0
  87. dask_array/reductions/_trace.py +56 -0
  88. dask_array/routines/__init__.py +133 -0
  89. dask_array/routines/_apply.py +84 -0
  90. dask_array/routines/_bincount.py +112 -0
  91. dask_array/routines/_broadcast.py +111 -0
  92. dask_array/routines/_coarsen.py +115 -0
  93. dask_array/routines/_diff.py +79 -0
  94. dask_array/routines/_gradient.py +158 -0
  95. dask_array/routines/_indexing.py +65 -0
  96. dask_array/routines/_insert_delete.py +132 -0
  97. dask_array/routines/_misc.py +122 -0
  98. dask_array/routines/_nonzero.py +72 -0
  99. dask_array/routines/_search.py +123 -0
  100. dask_array/routines/_select.py +113 -0
  101. dask_array/routines/_statistics.py +171 -0
  102. dask_array/routines/_topk.py +82 -0
  103. dask_array/routines/_triangular.py +74 -0
  104. dask_array/routines/_unique.py +232 -0
  105. dask_array/routines/_where.py +62 -0
  106. dask_array/slicing/__init__.py +67 -0
  107. dask_array/slicing/_basic.py +550 -0
  108. dask_array/slicing/_blocks.py +138 -0
  109. dask_array/slicing/_bool_index.py +145 -0
  110. dask_array/slicing/_setitem.py +329 -0
  111. dask_array/slicing/_squeeze.py +101 -0
  112. dask_array/slicing/_utils.py +1133 -0
  113. dask_array/slicing/_vindex.py +282 -0
  114. dask_array/stacking/__init__.py +15 -0
  115. dask_array/stacking/_block.py +83 -0
  116. dask_array/stacking/_simple.py +58 -0
  117. dask_array/templates/array.html.j2 +48 -0
  118. dask_array/tests/__init__.py +0 -0
  119. dask_array/tests/conftest.py +22 -0
  120. dask_array/tests/test_api.py +40 -0
  121. dask_array/tests/test_binary_op_chunks.py +107 -0
  122. dask_array/tests/test_coarse_slice_through_blockwise.py +362 -0
  123. dask_array/tests/test_collection.py +799 -0
  124. dask_array/tests/test_creation.py +1102 -0
  125. dask_array/tests/test_expr_flow.py +143 -0
  126. dask_array/tests/test_linalg.py +1130 -0
  127. dask_array/tests/test_map_blocks_multi_output.py +104 -0
  128. dask_array/tests/test_rechunk_pushdown.py +214 -0
  129. dask_array/tests/test_reductions.py +1091 -0
  130. dask_array/tests/test_routines.py +2853 -0
  131. dask_array/tests/test_shuffle_chunks.py +67 -0
  132. dask_array/tests/test_slice_pushdown.py +968 -0
  133. dask_array/tests/test_slice_through_blockwise.py +678 -0
  134. dask_array/tests/test_slice_through_overlap.py +366 -0
  135. dask_array/tests/test_slice_through_reshape.py +272 -0
  136. dask_array/tests/test_slicing.py +839 -0
  137. dask_array/tests/test_transpose_slice_pushdown.py +208 -0
  138. dask_array/tests/test_visualize.py +94 -0
  139. dask_array/tests/test_xarray.py +193 -0
  140. dask_array-0.1.0.dist-info/METADATA +48 -0
  141. dask_array-0.1.0.dist-info/RECORD +144 -0
  142. dask_array-0.1.0.dist-info/WHEEL +4 -0
  143. dask_array-0.1.0.dist-info/entry_points.txt +2 -0
  144. dask_array-0.1.0.dist-info/licenses/LICENSE +29 -0
@@ -0,0 +1,309 @@
1
+ """Transpose operations: transpose, swapaxes, moveaxis, rollaxis."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import functools
6
+
7
+ import numpy as np
8
+
9
+ from dask._task_spec import Task, TaskRef
10
+ from dask_array._blockwise import Blockwise
11
+ from dask_array._utils import meta_from_array
12
+
13
+
14
class Transpose(Blockwise):
    """Blockwise expression that applies ``np.transpose`` to every block.

    Operands are ``array`` (the input expression) and ``axes`` (the axis
    permutation handed to ``np.transpose``).
    """

    _parameters = ["array", "axes"]
    func = staticmethod(np.transpose)
    align_arrays = False
    adjust_chunks = None
    concatenate = None
    token = "transpose"

    @property
    def new_axes(self):
        """A transpose never introduces new axes."""
        return {}

    @property
    def name(self):
        return self._name

    @property
    def _meta_provided(self):
        return self.array._meta

    @functools.cached_property
    def _meta(self):
        """Meta of the result: the transposed input meta when it is usable."""
        ndim = len(self.axes)
        source_meta = self.array._meta
        if source_meta is not None and getattr(source_meta, "ndim", None) == ndim:
            return np.transpose(source_meta, self.axes)
        # Missing meta or rank mismatch: build a fresh meta of the right rank.
        return meta_from_array(None, ndim=ndim, dtype=self.array.dtype)

    @property
    def dtype(self):
        return self.array.dtype

    @property
    def out_ind(self):
        return self.axes

    @property
    def kwargs(self):
        return {"axes": self.axes}

    @property
    def args(self):
        return (self.array, tuple(range(self.array.ndim)))

    @functools.cached_property
    def _inverse_axes(self):
        """Inverse permutation of ``axes``."""
        inverse = [0] * len(self.axes)
        for position, axis in enumerate(self.axes):
            inverse[axis] = position
        return tuple(inverse)

    def _task(self, key, block_id: tuple[int, ...]) -> Task:
        """Build the task that produces output block ``block_id``."""
        # For axes=(1, 0), output block (i, j) is computed from input block (j, i).
        source_block = self._input_block_id(self.array, block_id)
        source_ref = TaskRef((self.array._name, *source_block))
        return Task(key, self.func, source_ref, **self.kwargs)

    def _input_block_id(self, dep, block_id: tuple[int, ...]) -> tuple[int, ...]:
        """Translate an output block id into the input block id it reads."""
        inverse = self._inverse_axes
        return tuple(block_id[inverse[d]] for d in range(len(block_id)))

    def _simplify_down(self):
        """Return a simpler equivalent expression, or ``None`` if none applies."""
        inner = self.array

        # Transpose(Transpose(x)) -> a single Transpose with composed axes
        if isinstance(inner, Transpose):
            composed = tuple(inner.axes[i] for i in self.axes)
            return Transpose(inner.array, composed)

        # Identity permutation -> the input itself
        if self.axes == tuple(range(inner.ndim)):
            return inner

        # Transpose(Elemwise(x, y)) -> Elemwise(Transpose(x), Transpose(y))
        from dask_array._blockwise import Elemwise

        if isinstance(inner, Elemwise):
            return self._pushdown_through_elemwise()
        return None

    def _pushdown_through_elemwise(self):
        """Push the transpose below an elemwise by transposing each input.

        Returns ``None`` when any array operand (including ``where``/``out``)
        has a different number of dimensions than the output.
        """
        from dask_array._blockwise import Elemwise
        from dask_array._core_utils import is_scalar_for_elemwise

        elemwise = self.array
        axes = self.axes
        out_ndim = len(axes)

        # Broadcasting inputs would need index transformations we don't
        # handle, so every non-scalar argument must already have full rank.
        ranks_match = all(
            is_scalar_for_elemwise(arg) or arg.ndim == out_ndim
            for arg in elemwise.elemwise_args
        )
        if not ranks_match:
            return None

        # The same restriction applies to where/out when they are arrays.
        if hasattr(elemwise.where, "ndim") and elemwise.where.ndim != out_ndim:
            return None
        if hasattr(elemwise.out, "ndim") and elemwise.out.ndim != out_ndim:
            return None

        def _transposed_if_array(operand):
            return Transpose(operand, axes) if hasattr(operand, "ndim") else operand

        new_args = []
        for arg in elemwise.elemwise_args:
            if is_scalar_for_elemwise(arg):
                new_args.append(arg)
            else:
                new_args.append(Transpose(arg, axes))

        new_where = _transposed_if_array(elemwise.where)
        new_out = _transposed_if_array(elemwise.out)

        return Elemwise(
            elemwise.op,
            elemwise.operand("dtype"),
            elemwise.operand("name"),
            new_where,
            new_out,
            elemwise.operand("_user_kwargs"),
            *new_args,
        )

    def _simplify_up(self, parent, dependents):
        """Let slice and shuffle parents be pushed through this Transpose."""
        from dask_array._shuffle import Shuffle
        from dask_array.slicing import SliceSlicesIntegers

        if isinstance(parent, SliceSlicesIntegers):
            return self._accept_slice(parent)
        if isinstance(parent, Shuffle):
            return self._accept_shuffle(parent)
        return None

    def _accept_shuffle(self, shuffle_expr):
        """Accept a shuffle being pushed through Transpose.

        The shuffle is re-expressed on the input axis that feeds the
        shuffled output axis, and the transpose is applied afterwards.
        """
        from dask_array._shuffle import Shuffle

        axes = self.axes
        # axes[i] is the input axis that becomes output axis i, so shuffling
        # output axis k means shuffling input axis axes[k].
        input_axis = axes[shuffle_expr.axis]

        shuffled_input = Shuffle(
            self.array,
            shuffle_expr.indexer,
            input_axis,
            shuffle_expr.operand("name"),
        )
        return Transpose(shuffled_input, axes)

    def _accept_slice(self, slice_expr):
        """Accept a slice being pushed through Transpose.

        The output index is permuted into input-axis order, applied to the
        input, and the (possibly reduced) transpose is re-applied on top.
        Returns ``None`` when the index contains ``None`` (newaxis).
        """
        from numbers import Integral

        from dask_array._new_collection import new_collection

        axes = self.axes
        index = slice_expr.index

        # None/newaxis adds dimensions, which is not handled here.
        if any(entry is None for entry in index):
            return None

        # Pad the index with full slices up to the output rank.
        padding = (slice(None),) * (self.ndim - len(index))
        full_index = index + padding

        # Output axis i carries full_index[i] and comes from input axis
        # axes[i], so that entry belongs at position axes[i] of the input index.
        input_index = [slice(None)] * len(axes)
        for out_axis, in_axis in enumerate(axes):
            input_index[in_axis] = full_index[out_axis]

        sliced_input = new_collection(self.array)[tuple(input_index)]

        drops_dims = any(isinstance(entry, Integral) for entry in full_index)
        if not drops_dims:
            # Rank is unchanged, so the original permutation still applies.
            return Transpose(sliced_input.expr, axes)

        # Integer entries remove dimensions: collect, in output order, the
        # input axes that survive the slice.
        surviving = [
            in_axis
            for out_axis, in_axis in enumerate(axes)
            if not isinstance(full_index[out_axis], Integral)
        ]

        if len(surviving) <= 1:
            # With zero or one dimension left there is nothing to permute.
            return sliced_input.expr

        # After slicing, the surviving input axes are renumbered 0, 1, 2, ...
        # in increasing order of their old position.
        renumbered = {old: new for new, old in enumerate(sorted(surviving))}
        new_axes = tuple(renumbered[in_axis] for in_axis in surviving)

        if new_axes == tuple(range(len(new_axes))):
            # The reduced permutation is the identity.
            return sliced_input.expr

        return Transpose(sliced_input.expr, new_axes)
222
+
223
+
224
def transpose(a, axes=None):
    """Reverse or permute the axes of an array.

    ``a`` is converted with ``asanyarray``; its ``transpose`` method is then
    called with ``axes`` when given and without arguments otherwise.

    See Also
    --------
    numpy.transpose
    """
    from dask_array.core import asanyarray

    arr = asanyarray(a)
    return arr.transpose() if axes is None else arr.transpose(axes)
237
+
238
+
239
def swapaxes(a, axis1, axis2):
    """Interchange two axes of an array.

    Parameters
    ----------
    a : array_like
        Input, converted with ``asanyarray``.
    axis1, axis2 : int
        The axes to swap. Negative values count from the last axis.

    Returns
    -------
    The input itself when both axes refer to the same dimension, otherwise
    the result of ``transpose`` with the two axes exchanged.

    Raises
    ------
    Whatever ``normalize_axis_index`` raises for an axis outside
    ``[-a.ndim, a.ndim)``.

    See Also
    --------
    numpy.swapaxes
    """
    from dask_array._numpy_compat import normalize_axis_index
    from dask_array.core import asanyarray

    a = asanyarray(a)
    # Validate and normalise both axes up front. Simply adding ``ndim`` to a
    # negative axis lets values below ``-ndim`` wrap around through Python's
    # negative list indexing and silently swap the wrong axis.
    axis1 = normalize_axis_index(axis1, a.ndim)
    axis2 = normalize_axis_index(axis2, a.ndim)
    if axis1 == axis2:
        return a
    ind = list(range(a.ndim))
    ind[axis1], ind[axis2] = ind[axis2], ind[axis1]
    return transpose(a, ind)
258
+
259
+
260
def moveaxis(a, source, destination):
    """Move axes of an array to new positions.

    ``source`` and ``destination`` are normalised to tuples of axes and must
    have the same length; the remaining axes keep their relative order.

    See Also
    --------
    numpy.moveaxis
    """
    from dask_array.core import asanyarray
    from dask_array._numpy_compat import normalize_axis_tuple

    a = asanyarray(a)
    ndim = a.ndim
    source = normalize_axis_tuple(source, ndim, "source")
    destination = normalize_axis_tuple(destination, ndim, "destination")
    if len(source) != len(destination):
        raise ValueError("`source` and `destination` arguments must have the same number of elements")

    # Start from the axes that are not moved, in their original order ...
    order = [axis for axis in range(ndim) if axis not in source]

    # ... then insert each moved axis at its destination, lowest position first.
    placements = sorted(zip(destination, source))
    for dest, src in placements:
        order.insert(dest, src)

    return transpose(a, order)
282
+
283
+
284
def rollaxis(a, axis, start=0):
    """Roll the specified axis backwards, until it lies in a given position.

    Raises ``ValueError`` when ``start`` (after adding ``a.ndim`` to a
    negative value) is not within ``0 <= start < a.ndim + 1``.

    See Also
    --------
    numpy.rollaxis
    """
    from dask_array.core import asanyarray
    from dask_array._numpy_compat import normalize_axis_index

    a = asanyarray(a)
    ndim = a.ndim
    axis = normalize_axis_index(axis, ndim)
    if start < 0:
        start += ndim
    if not (0 <= start < ndim + 1):
        template = "'%s' arg requires %d <= %s < %d, but %d was passed in"
        raise ValueError(template % ("start", -ndim, "start", ndim + 1, start))
    if axis < start:
        # Removing ``axis`` shifts every later position down by one.
        start -= 1
    if axis == start:
        return a[...]
    order = list(range(0, ndim))
    order.remove(axis)
    order.insert(start, axis)
    return transpose(a, order)
@@ -0,0 +1,125 @@
1
+ from __future__ import annotations
2
+
3
+ from threading import Lock
4
+
5
+ from ._generator import Generator, default_rng
6
+ from ._random_state import RandomState
7
+
8
+
9
+ # Lazy RNG-state machinery
10
+ #
11
+ # Many of the RandomState methods are exported as functions in da.random for
12
+ # backward compatibility reasons. Their usage is discouraged.
13
+ # Use da.random.default_rng() to get a Generator based rng and use its
14
+ # methods instead.
15
+
16
+ _cached_states: dict[str, RandomState] = {}
17
+ _cached_states_lock = Lock()
18
+
19
+
20
def _make_api(attr):
    """Build a module-level function that forwards to ``RandomState.<attr>``.

    The returned wrapper looks up (or creates, under ``_cached_states_lock``)
    the ``RandomState`` cached for the current array-creation backend and
    calls its ``attr`` method with the given arguments.
    """

    def wrapper(*args, **kwargs):
        from dask_array._backends_array import array_creation_dispatch

        backend = array_creation_dispatch.backend
        with _cached_states_lock:
            if backend not in _cached_states:
                _cached_states[backend] = RandomState()
            state = _cached_states[backend]
        # The method itself is invoked after the lock has been released.
        return getattr(state, attr)(*args, **kwargs)

    target = getattr(RandomState, attr)
    wrapper.__name__ = target.__name__
    wrapper.__doc__ = target.__doc__
    return wrapper
35
+
36
+
37
+ # RandomState only
38
+
39
+ seed = _make_api("seed")
40
+
41
+ beta = _make_api("beta")
42
+ binomial = _make_api("binomial")
43
+ chisquare = _make_api("chisquare")
44
+ choice = _make_api("choice")
45
+ exponential = _make_api("exponential")
46
+ f = _make_api("f")
47
+ gamma = _make_api("gamma")
48
+ geometric = _make_api("geometric")
49
+ gumbel = _make_api("gumbel")
50
+ hypergeometric = _make_api("hypergeometric")
51
+ laplace = _make_api("laplace")
52
+ logistic = _make_api("logistic")
53
+ lognormal = _make_api("lognormal")
54
+ logseries = _make_api("logseries")
55
+ multinomial = _make_api("multinomial")
56
+ negative_binomial = _make_api("negative_binomial")
57
+ noncentral_chisquare = _make_api("noncentral_chisquare")
58
+ noncentral_f = _make_api("noncentral_f")
59
+ normal = _make_api("normal")
60
+ pareto = _make_api("pareto")
61
+ permutation = _make_api("permutation")
62
+ poisson = _make_api("poisson")
63
+ power = _make_api("power")
64
+ random_sample = _make_api("random_sample")
65
+ random = _make_api("random_sample")
66
+ randint = _make_api("randint")
67
+ random_integers = _make_api("random_integers")
68
+ rayleigh = _make_api("rayleigh")
69
+ standard_cauchy = _make_api("standard_cauchy")
70
+ standard_exponential = _make_api("standard_exponential")
71
+ standard_gamma = _make_api("standard_gamma")
72
+ standard_normal = _make_api("standard_normal")
73
+ standard_t = _make_api("standard_t")
74
+ triangular = _make_api("triangular")
75
+ uniform = _make_api("uniform")
76
+ vonmises = _make_api("vonmises")
77
+ wald = _make_api("wald")
78
+ weibull = _make_api("weibull")
79
+ zipf = _make_api("zipf")
80
+
81
+ __all__ = [
82
+ "Generator",
83
+ "RandomState",
84
+ "default_rng",
85
+ "seed",
86
+ "beta",
87
+ "binomial",
88
+ "chisquare",
89
+ "choice",
90
+ "exponential",
91
+ "f",
92
+ "gamma",
93
+ "geometric",
94
+ "gumbel",
95
+ "hypergeometric",
96
+ "laplace",
97
+ "logistic",
98
+ "lognormal",
99
+ "logseries",
100
+ "multinomial",
101
+ "negative_binomial",
102
+ "noncentral_chisquare",
103
+ "noncentral_f",
104
+ "normal",
105
+ "pareto",
106
+ "permutation",
107
+ "poisson",
108
+ "power",
109
+ "random_sample",
110
+ "random",
111
+ "randint",
112
+ "random_integers",
113
+ "rayleigh",
114
+ "standard_cauchy",
115
+ "standard_exponential",
116
+ "standard_gamma",
117
+ "standard_normal",
118
+ "standard_t",
119
+ "triangular",
120
+ "uniform",
121
+ "vonmises",
122
+ "wald",
123
+ "weibull",
124
+ "zipf",
125
+ ]
@@ -0,0 +1,181 @@
1
+ from __future__ import annotations
2
+
3
+ from itertools import product
4
+ from numbers import Integral
5
+
6
+ import numpy as np
7
+
8
+ from dask._task_spec import TaskRef
9
+ from dask_array._collection import Array
10
+ from dask_array.core._conversion import asarray
11
+ from dask_array.io import IO
12
+ from dask_array._core_utils import normalize_chunks
13
+ from dask_array._utils import asarray_safe
14
+ from dask_array._backends_array import array_creation_dispatch
15
+ from dask.utils import cached_property, random_state_data
16
+
17
+ from ._expr import _spawn_bitgens
18
+ from ._generator import Generator
19
+ from ._random_state import RandomState
20
+
21
+
22
def _choice_rng(state_data, a, size, replace, p, axis, shuffle):
    """Draw one block of samples via the object built by ``_rng_from_bitgen``."""
    from ._expr import _rng_from_bitgen

    rng = _rng_from_bitgen(state_data)
    options = {
        "size": size,
        "replace": replace,
        "p": p,
        "axis": axis,
        "shuffle": shuffle,
    }
    return rng.choice(a, **options)
27
+
28
+
29
def _choice_rs(state_data, a, size, replace, p):
    """Draw one block of samples from a backend ``RandomState`` built from ``state_data``."""
    random_state = array_creation_dispatch.RandomState(state_data)
    return random_state.choice(a, size=size, replace=replace, p=p)
32
+
33
+
34
def _choice_validate_params(state, a, size, replace, p, axis, chunks):
    """Validate and normalize parameters for choice.

    Returns expressions for array/p (or int/None) so they participate in lowering.

    Parameters
    ----------
    state : Generator or RandomState
        Only consulted when ``a`` is an integer, to derive the output meta.
    a : int or array_like
        Population size, or a one-dimensional population.
    size : None, int, tuple or list
        Output shape; ``None`` becomes ``()`` and a bare value becomes a 1-tuple.
    replace : bool
        Whether sampling is done with replacement.
    p : None or array_like
        One-dimensional probabilities with the same length as the population.
    axis : int
        Must be ``0``.
    chunks
        Chunk specification handed to ``normalize_chunks``.

    Returns
    -------
    tuple
        ``(a_val, a_expr, size, replace, p_expr, axis, chunks, meta)`` where
        exactly one of ``a_val`` / ``a_expr`` is not ``None``.

    Raises
    ------
    ValueError
        For a negative integer ``a``, a non one-dimensional ``a`` or ``p``,
        probabilities that do not sum to 1, mismatched lengths, ``axis != 0``
        or an unknown ``state`` class.
    NotImplementedError
        For a cupy-backed ``Generator`` with integer ``a``, or for
        ``replace=False`` with an output split into more than one chunk.
    """
    # Normalize and validate `a`
    if isinstance(a, Integral):
        if isinstance(state, Generator):
            if state._backend_name == "cupy":
                raise NotImplementedError("`choice` not supported for cupy-backed `Generator`.")
            meta = state._backend.random.default_rng().choice(1, size=(), p=None)
        elif isinstance(state, RandomState):
            # On Windows the output dtype differs if p is provided or
            # absent, see https://github.com/numpy/numpy/issues/9867
            dummy_p = state._backend.array([1]) if p is not None else p
            meta = state._backend.random.RandomState().choice(1, size=(), p=dummy_p)
        else:
            raise ValueError("Unknown generator class")
        len_a = a
        # NOTE(review): the message says "greater than 0" but a == 0 is
        # accepted here; left unchanged - confirm the intended bound.
        if a < 0:
            raise ValueError("a must be greater than 0")
        a_expr = None  # No expression for int a
    else:
        a = asarray(a)
        a = a.rechunk(a.shape)
        meta = a._meta
        if a.ndim != 1:
            raise ValueError("a must be one dimensional")
        len_a = len(a)
        a_expr = a.expr  # Store expression so it gets lowered

    # Normalize and validate `p`
    p_expr = None
    if p is not None:
        if not isinstance(p, Array):
            # If p is not a dask array, first check the sum is close
            # to 1 before converting.
            p = asarray_safe(p, like=p)
            if not np.isclose(p.sum(), 1, rtol=1e-7, atol=0):
                raise ValueError("probabilities do not sum to 1")
            p = asarray(p)
        else:
            p = p.rechunk(p.shape)

        if p.ndim != 1:
            raise ValueError("p must be one dimensional")
        if len(p) != len_a:
            raise ValueError("a and p must have the same size")

        p_expr = p.expr  # Store expression so it gets lowered

    if size is None:
        size = ()
    elif not isinstance(size, (tuple, list)):
        size = (size,)

    if axis != 0:
        raise ValueError("axis must be 0 since a is one dimensional")

    chunks = normalize_chunks(chunks, size, dtype=np.float64)
    # Each output block is sampled independently, so sampling without
    # replacement is only meaningful when the whole output is a single block.
    # Checking every axis also covers the scalar case (size=None), where
    # ``chunks`` is empty and indexing ``chunks[0]`` would raise IndexError.
    if not replace and any(len(axis_chunks) > 1 for axis_chunks in chunks):
        err_msg = "replace=False is not currently supported for dask.array.choice with multi-chunk output arrays"
        raise NotImplementedError(err_msg)

    # For int a, return the int value; for array a, return None (use a_expr)
    a_val = a if isinstance(a, Integral) else None
    return a_val, a_expr, size, replace, p_expr, axis, chunks, meta
101
+
102
+
103
class RandomChoice(IO):
    """IO expression producing ``choice`` samples, one ``_choice_rs`` task per output block."""

    _parameters = [
        "a_val",  # int value of a (or None if a is an array)
        "a_expr",  # expression for a (or None if a is an int)
        "chunks",
        "_meta",
        "_state",
        "replace",
        "p_expr",  # expression for p (or None)
        "axis",
        "shuffle",
    ]
    _defaults = {"axis": None, "shuffle": None}
    _funcname = "da.random.choice-"

    @cached_property
    def chunks(self):
        return self.operand("chunks")

    @cached_property
    def sizes(self):
        """Shape of every output block, in ``itertools.product`` order over the chunks."""
        return list(product(*self.chunks))

    @cached_property
    def state_data(self):
        """One entry of random state data per output block."""
        n_blocks = len(self.sizes)
        return random_state_data(n_blocks, self._state)

    @cached_property
    def _meta(self):
        return self.operand("_meta")

    # No custom dependencies() needed - base class finds Expr operands automatically
    # (a_expr and p_expr are included when they're expressions, excluded when None)

    @property
    def _a_arg(self):
        """Value to pass to choice: int or TaskRef to single-chunk array."""
        if self.a_val is None:
            return TaskRef((self.a_expr._name, 0))
        return self.a_val

    @property
    def _p_arg(self):
        """Value to pass to choice: None or TaskRef to single-chunk array."""
        if self.p_expr is not None:
            return TaskRef((self.p_expr._name, 0))
        return None

    def _layer(self) -> dict:
        """Map every output key to its ``_choice_rs`` task tuple."""
        block_ranges = [range(len(bd)) for bd in self.chunks]
        keys = product([self._name], *block_ranges)
        layer = {}
        for key, state, size in zip(keys, self.state_data, self.sizes):
            layer[key] = (_choice_rs, state, self._a_arg, size, self.replace, self._p_arg)
        return layer
157
+
158
+
159
class RandomChoiceGenerator(RandomChoice):
    """``RandomChoice`` variant whose tasks call ``_choice_rng`` with spawned bit generators."""

    # Keep axis and shuffle as required parameters (no defaults)
    _defaults = {}

    @cached_property
    def state_data(self):
        """One spawned bit generator per output block."""
        n_blocks = len(self.sizes)
        return _spawn_bitgens(self._state, n_blocks)

    def _layer(self) -> dict:
        """Map every output key to its ``_choice_rng`` task tuple."""
        block_ranges = [range(len(bd)) for bd in self.chunks]
        keys = product([self._name], *block_ranges)
        layer = {}
        for key, bitgen, size in zip(keys, self.state_data, self.sizes):
            layer[key] = (
                _choice_rng,
                bitgen,
                self._a_arg,
                size,
                self.replace,
                self._p_arg,
                self.axis,
                self.shuffle,
            )
        return layer