dask-array 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (144) hide show
  1. dask_array/__init__.py +228 -0
  2. dask_array/_backends.py +76 -0
  3. dask_array/_backends_array.py +99 -0
  4. dask_array/_blockwise.py +1410 -0
  5. dask_array/_broadcast.py +272 -0
  6. dask_array/_chunk.py +445 -0
  7. dask_array/_chunk_types.py +54 -0
  8. dask_array/_collection.py +1644 -0
  9. dask_array/_concatenate.py +331 -0
  10. dask_array/_core_utils.py +1365 -0
  11. dask_array/_dispatch.py +141 -0
  12. dask_array/_einsum.py +277 -0
  13. dask_array/_expr.py +544 -0
  14. dask_array/_expr_flow.py +586 -0
  15. dask_array/_gufunc.py +805 -0
  16. dask_array/_histogram.py +617 -0
  17. dask_array/_map_blocks.py +652 -0
  18. dask_array/_new_collection.py +10 -0
  19. dask_array/_numpy_compat.py +135 -0
  20. dask_array/_overlap.py +1159 -0
  21. dask_array/_rechunk.py +1050 -0
  22. dask_array/_reshape.py +710 -0
  23. dask_array/_routines.py +102 -0
  24. dask_array/_shuffle.py +448 -0
  25. dask_array/_stack.py +264 -0
  26. dask_array/_svg.py +291 -0
  27. dask_array/_templates.py +29 -0
  28. dask_array/_test_utils.py +257 -0
  29. dask_array/_ufunc.py +385 -0
  30. dask_array/_utils.py +349 -0
  31. dask_array/_visualize.py +223 -0
  32. dask_array/_xarray.py +337 -0
  33. dask_array/core/__init__.py +34 -0
  34. dask_array/core/_blockwise_funcs.py +312 -0
  35. dask_array/core/_conversion.py +422 -0
  36. dask_array/core/_from_graph.py +97 -0
  37. dask_array/creation/__init__.py +71 -0
  38. dask_array/creation/_arange.py +121 -0
  39. dask_array/creation/_diag.py +116 -0
  40. dask_array/creation/_diagonal.py +241 -0
  41. dask_array/creation/_eye.py +103 -0
  42. dask_array/creation/_linspace.py +102 -0
  43. dask_array/creation/_mesh.py +134 -0
  44. dask_array/creation/_ones_zeros.py +454 -0
  45. dask_array/creation/_pad.py +270 -0
  46. dask_array/creation/_repeat.py +55 -0
  47. dask_array/creation/_tile.py +36 -0
  48. dask_array/creation/_tri.py +28 -0
  49. dask_array/creation/_utils.py +296 -0
  50. dask_array/fft.py +320 -0
  51. dask_array/io/__init__.py +39 -0
  52. dask_array/io/_base.py +10 -0
  53. dask_array/io/_from_array.py +257 -0
  54. dask_array/io/_from_delayed.py +95 -0
  55. dask_array/io/_from_graph.py +54 -0
  56. dask_array/io/_from_npy_stack.py +67 -0
  57. dask_array/io/_store.py +336 -0
  58. dask_array/io/_tiledb.py +159 -0
  59. dask_array/io/_to_npy_stack.py +65 -0
  60. dask_array/io/_zarr.py +449 -0
  61. dask_array/linalg/__init__.py +39 -0
  62. dask_array/linalg/_cholesky.py +234 -0
  63. dask_array/linalg/_lu.py +300 -0
  64. dask_array/linalg/_norm.py +94 -0
  65. dask_array/linalg/_qr.py +601 -0
  66. dask_array/linalg/_solve.py +349 -0
  67. dask_array/linalg/_svd.py +394 -0
  68. dask_array/linalg/_tensordot.py +334 -0
  69. dask_array/linalg/_utils.py +74 -0
  70. dask_array/manipulation/__init__.py +45 -0
  71. dask_array/manipulation/_expand.py +321 -0
  72. dask_array/manipulation/_flip.py +92 -0
  73. dask_array/manipulation/_roll.py +78 -0
  74. dask_array/manipulation/_transpose.py +309 -0
  75. dask_array/random/__init__.py +125 -0
  76. dask_array/random/_choice.py +181 -0
  77. dask_array/random/_expr.py +256 -0
  78. dask_array/random/_generator.py +441 -0
  79. dask_array/random/_random_state.py +259 -0
  80. dask_array/random/_utils.py +84 -0
  81. dask_array/reductions/__init__.py +84 -0
  82. dask_array/reductions/_arg_reduction.py +130 -0
  83. dask_array/reductions/_common.py +1082 -0
  84. dask_array/reductions/_cumulative.py +522 -0
  85. dask_array/reductions/_percentile.py +261 -0
  86. dask_array/reductions/_reduction.py +725 -0
  87. dask_array/reductions/_trace.py +56 -0
  88. dask_array/routines/__init__.py +133 -0
  89. dask_array/routines/_apply.py +84 -0
  90. dask_array/routines/_bincount.py +112 -0
  91. dask_array/routines/_broadcast.py +111 -0
  92. dask_array/routines/_coarsen.py +115 -0
  93. dask_array/routines/_diff.py +79 -0
  94. dask_array/routines/_gradient.py +158 -0
  95. dask_array/routines/_indexing.py +65 -0
  96. dask_array/routines/_insert_delete.py +132 -0
  97. dask_array/routines/_misc.py +122 -0
  98. dask_array/routines/_nonzero.py +72 -0
  99. dask_array/routines/_search.py +123 -0
  100. dask_array/routines/_select.py +113 -0
  101. dask_array/routines/_statistics.py +171 -0
  102. dask_array/routines/_topk.py +82 -0
  103. dask_array/routines/_triangular.py +74 -0
  104. dask_array/routines/_unique.py +232 -0
  105. dask_array/routines/_where.py +62 -0
  106. dask_array/slicing/__init__.py +67 -0
  107. dask_array/slicing/_basic.py +550 -0
  108. dask_array/slicing/_blocks.py +138 -0
  109. dask_array/slicing/_bool_index.py +145 -0
  110. dask_array/slicing/_setitem.py +329 -0
  111. dask_array/slicing/_squeeze.py +101 -0
  112. dask_array/slicing/_utils.py +1133 -0
  113. dask_array/slicing/_vindex.py +282 -0
  114. dask_array/stacking/__init__.py +15 -0
  115. dask_array/stacking/_block.py +83 -0
  116. dask_array/stacking/_simple.py +58 -0
  117. dask_array/templates/array.html.j2 +48 -0
  118. dask_array/tests/__init__.py +0 -0
  119. dask_array/tests/conftest.py +22 -0
  120. dask_array/tests/test_api.py +40 -0
  121. dask_array/tests/test_binary_op_chunks.py +107 -0
  122. dask_array/tests/test_coarse_slice_through_blockwise.py +362 -0
  123. dask_array/tests/test_collection.py +799 -0
  124. dask_array/tests/test_creation.py +1102 -0
  125. dask_array/tests/test_expr_flow.py +143 -0
  126. dask_array/tests/test_linalg.py +1130 -0
  127. dask_array/tests/test_map_blocks_multi_output.py +104 -0
  128. dask_array/tests/test_rechunk_pushdown.py +214 -0
  129. dask_array/tests/test_reductions.py +1091 -0
  130. dask_array/tests/test_routines.py +2853 -0
  131. dask_array/tests/test_shuffle_chunks.py +67 -0
  132. dask_array/tests/test_slice_pushdown.py +968 -0
  133. dask_array/tests/test_slice_through_blockwise.py +678 -0
  134. dask_array/tests/test_slice_through_overlap.py +366 -0
  135. dask_array/tests/test_slice_through_reshape.py +272 -0
  136. dask_array/tests/test_slicing.py +839 -0
  137. dask_array/tests/test_transpose_slice_pushdown.py +208 -0
  138. dask_array/tests/test_visualize.py +94 -0
  139. dask_array/tests/test_xarray.py +193 -0
  140. dask_array-0.1.0.dist-info/METADATA +48 -0
  141. dask_array-0.1.0.dist-info/RECORD +144 -0
  142. dask_array-0.1.0.dist-info/WHEEL +4 -0
  143. dask_array-0.1.0.dist-info/entry_points.txt +2 -0
  144. dask_array-0.1.0.dist-info/licenses/LICENSE +29 -0
@@ -0,0 +1,259 @@
1
+ from __future__ import annotations
2
+
3
+ import contextlib
4
+ import importlib
5
+ import numbers
6
+
7
+ import numpy as np
8
+
9
+ from dask_array._new_collection import new_collection
10
+ from dask_array.creation import arange
11
+ from dask_array._backends_array import array_creation_dispatch
12
+ from dask.utils import derived_from, typename
13
+
14
+ from ._utils import _wrap_func
15
+
16
+
17
class RandomState:
    """
    Mersenne Twister pseudo-random number generator

    This object contains state to deterministically generate pseudo-random
    numbers from a variety of probability distributions. It is identical to
    ``np.random.RandomState`` except that all functions also take a ``chunks=``
    keyword argument.

    Parameters
    ----------
    seed: Number
        Object to pass to RandomState to serve as deterministic seed
    RandomState: Callable[seed] -> RandomState
        A callable that, when provided with a ``seed`` keyword provides an
        object that operates identically to ``np.random.RandomState`` (the
        default). This might also be a function that returns a
        ``mkl_random``, or ``cupy.random.RandomState`` object.

    Examples
    --------
    >>> import dask_array as da
    >>> state = da.random.RandomState(1234) # a seed
    >>> x = state.normal(10, 0.1, size=3, chunks=(2,))
    >>> x.compute()
    array([10.01867852, 10.04812289, 9.89649746])

    See Also
    --------
    np.random.RandomState
    """

    def __init__(self, seed=None, RandomState=None):
        """Store a seeded NumPy state and the backend ``RandomState`` factory."""
        # Host-side NumPy state; ``seed``, ``permutation`` and ``choice`` below
        # use it directly.
        self._numpy_state = np.random.RandomState(seed)
        # Backend factory: the explicit argument when given, otherwise whatever
        # ``array_creation_dispatch`` currently exposes as ``RandomState``.
        self._RandomState = array_creation_dispatch.RandomState if RandomState is None else RandomState

    @property
    def _backend(self):
        """Module imported from the first dotted component of the factory's type name."""
        # Assumes typename(self._RandomState) starts with
        # an importable array-library name (e.g. "numpy" or "cupy")
        _backend_name = typename(self._RandomState).split(".")[0]
        return importlib.import_module(_backend_name)

    def seed(self, seed=None):
        """Re-seed the host-side NumPy state held by this object."""
        self._numpy_state.seed(seed)

    # Every distribution method below is a thin wrapper: it forwards its
    # distribution parameters, ``size`` and ``chunks`` to ``_wrap_func`` under
    # the name of the corresponding NumPy method.  The methods carry no
    # docstring of their own; NOTE(review): ``derived_from`` presumably fills it
    # in from ``np.random.RandomState`` -- confirm against dask.utils.

    @derived_from(np.random.RandomState, skipblocks=1)
    def beta(self, a, b, size=None, chunks="auto", **kwargs):
        return _wrap_func(self, "beta", a, b, size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def binomial(self, n, p, size=None, chunks="auto", **kwargs):
        return _wrap_func(self, "binomial", n, p, size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def chisquare(self, df, size=None, chunks="auto", **kwargs):
        return _wrap_func(self, "chisquare", df, size=size, chunks=chunks, **kwargs)

    # NOTE(review): the suppress block presumably lets the class body keep
    # executing when the decorator raises AttributeError (e.g. a
    # ``np.random.RandomState`` without ``choice``); in that case ``choice`` is
    # simply not defined on this class -- confirm.
    with contextlib.suppress(AttributeError):

        @derived_from(np.random.RandomState, skipblocks=1)
        def choice(self, a, size=None, replace=True, p=None, chunks="auto"):
            from ._choice import RandomChoice, _choice_validate_params

            # The validator is called with a fixed axis of 0; the returned
            # ``size`` and ``axis`` are not passed on to ``RandomChoice``.
            (
                a_val,
                a_expr,
                size,
                replace,
                p_expr,
                axis,  # np.random.RandomState.choice does not use axis
                chunks,
                meta,
            ) = _choice_validate_params(self, a, size, replace, p, 0, chunks)

            return new_collection(RandomChoice(a_val, a_expr, chunks, meta, self._numpy_state, replace, p_expr))

    @derived_from(np.random.RandomState, skipblocks=1)
    def exponential(self, scale=1.0, size=None, chunks="auto", **kwargs):
        return _wrap_func(self, "exponential", scale, size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def f(self, dfnum, dfden, size=None, chunks="auto", **kwargs):
        return _wrap_func(self, "f", dfnum, dfden, size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def gamma(self, shape, scale=1.0, size=None, chunks="auto", **kwargs):
        return _wrap_func(self, "gamma", shape, scale, size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def geometric(self, p, size=None, chunks="auto", **kwargs):
        return _wrap_func(self, "geometric", p, size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def gumbel(self, loc=0.0, scale=1.0, size=None, chunks="auto", **kwargs):
        return _wrap_func(self, "gumbel", loc, scale, size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def hypergeometric(self, ngood, nbad, nsample, size=None, chunks="auto", **kwargs):
        return _wrap_func(
            self,
            "hypergeometric",
            ngood,
            nbad,
            nsample,
            size=size,
            chunks=chunks,
            **kwargs,
        )

    @derived_from(np.random.RandomState, skipblocks=1)
    def laplace(self, loc=0.0, scale=1.0, size=None, chunks="auto", **kwargs):
        return _wrap_func(self, "laplace", loc, scale, size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def logistic(self, loc=0.0, scale=1.0, size=None, chunks="auto", **kwargs):
        return _wrap_func(self, "logistic", loc, scale, size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def lognormal(self, mean=0.0, sigma=1.0, size=None, chunks="auto", **kwargs):
        return _wrap_func(self, "lognormal", mean, sigma, size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def logseries(self, p, size=None, chunks="auto", **kwargs):
        return _wrap_func(self, "logseries", p, size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def multinomial(self, n, pvals, size=None, chunks="auto", **kwargs):
        # ``extra_chunks`` carries one chunk of length ``len(pvals)``;
        # NOTE(review): presumably the trailing per-category axis of the
        # result -- confirm against the consumers of ``extra_chunks``.
        return _wrap_func(
            self,
            "multinomial",
            n,
            pvals,
            size=size,
            chunks=chunks,
            extra_chunks=((len(pvals),),),
            **kwargs,
        )

    @derived_from(np.random.RandomState, skipblocks=1)
    def negative_binomial(self, n, p, size=None, chunks="auto", **kwargs):
        return _wrap_func(self, "negative_binomial", n, p, size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def noncentral_chisquare(self, df, nonc, size=None, chunks="auto", **kwargs):
        return _wrap_func(self, "noncentral_chisquare", df, nonc, size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def noncentral_f(self, dfnum, dfden, nonc, size=None, chunks="auto", **kwargs):
        return _wrap_func(self, "noncentral_f", dfnum, dfden, nonc, size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def normal(self, loc=0.0, scale=1.0, size=None, chunks="auto", **kwargs):
        return _wrap_func(self, "normal", loc, scale, size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def pareto(self, a, size=None, chunks="auto", **kwargs):
        return _wrap_func(self, "pareto", a, size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def permutation(self, x):
        from dask_array.slicing._utils import shuffle_slice

        # A plain number is expanded to ``arange(x)`` before being permuted.
        if isinstance(x, numbers.Number):
            x = arange(x, chunks="auto")

        # The permutation index is drawn immediately, at call time, from the
        # host-side NumPy state; only the reordering of ``x`` goes through
        # ``shuffle_slice``.
        index = np.arange(len(x))
        self._numpy_state.shuffle(index)
        return shuffle_slice(x, index)

    @derived_from(np.random.RandomState, skipblocks=1)
    def poisson(self, lam=1.0, size=None, chunks="auto", **kwargs):
        return _wrap_func(self, "poisson", lam, size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def power(self, a, size=None, chunks="auto", **kwargs):
        return _wrap_func(self, "power", a, size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def randint(self, low, high=None, size=None, chunks="auto", dtype="l", **kwargs):
        return _wrap_func(self, "randint", low, high, size=size, chunks=chunks, dtype=dtype, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def random_integers(self, low, high=None, size=None, chunks="auto", **kwargs):
        return _wrap_func(self, "random_integers", low, high, size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def random_sample(self, size=None, chunks="auto", **kwargs):
        return _wrap_func(self, "random_sample", size=size, chunks=chunks, **kwargs)

    # ``random`` is an alias bound to the very same function object.
    random = random_sample

    @derived_from(np.random.RandomState, skipblocks=1)
    def rayleigh(self, scale=1.0, size=None, chunks="auto", **kwargs):
        return _wrap_func(self, "rayleigh", scale, size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def standard_cauchy(self, size=None, chunks="auto", **kwargs):
        return _wrap_func(self, "standard_cauchy", size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def standard_exponential(self, size=None, chunks="auto", **kwargs):
        return _wrap_func(self, "standard_exponential", size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def standard_gamma(self, shape, size=None, chunks="auto", **kwargs):
        return _wrap_func(self, "standard_gamma", shape, size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def standard_normal(self, size=None, chunks="auto", **kwargs):
        return _wrap_func(self, "standard_normal", size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def standard_t(self, df, size=None, chunks="auto", **kwargs):
        return _wrap_func(self, "standard_t", df, size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def tomaxint(self, size=None, chunks="auto", **kwargs):
        return _wrap_func(self, "tomaxint", size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def triangular(self, left, mode, right, size=None, chunks="auto", **kwargs):
        return _wrap_func(self, "triangular", left, mode, right, size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def uniform(self, low=0.0, high=1.0, size=None, chunks="auto", **kwargs):
        return _wrap_func(self, "uniform", low, high, size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def vonmises(self, mu, kappa, size=None, chunks="auto", **kwargs):
        return _wrap_func(self, "vonmises", mu, kappa, size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def wald(self, mean, scale, size=None, chunks="auto", **kwargs):
        return _wrap_func(self, "wald", mean, scale, size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def weibull(self, a, size=None, chunks="auto", **kwargs):
        return _wrap_func(self, "weibull", a, size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def zipf(self, a, size=None, chunks="auto", **kwargs):
        return _wrap_func(self, "zipf", a, size=size, chunks=chunks, **kwargs)
@@ -0,0 +1,84 @@
1
+ from __future__ import annotations
2
+
3
+ import importlib
4
+
5
+ import numpy as np
6
+
7
+ from dask_array._new_collection import new_collection
8
+ from dask_array._collection import Array
9
+ from dask_array._core_utils import normalize_chunks
10
+ from dask.utils import typename
11
+
12
+
13
def _rng_from_bitgen(bitgen):
    """Return ``<library>.random.default_rng(bitgen)`` for the bit generator's library.

    The library is taken to be the first dotted component of
    ``typename(bitgen)`` and is imported by that name, so the type name must
    start with an importable package (e.g. "numpy" or "cupy").
    """
    root_package = typename(bitgen).partition(".")[0]
    library = importlib.import_module(root_package)
    return library.random.default_rng(bitgen)
19
+
20
+
21
def _shuffle(bit_generator, x, axis=0):
    """Shuffle ``x`` in place along ``axis`` and advance ``bit_generator``.

    The shuffle runs on a freshly constructed bit generator of the same type
    that starts from ``bit_generator``'s current state.
    """
    scratch = type(bit_generator)()
    scratch.state = bit_generator.state
    _rng_from_bitgen(scratch).shuffle(x, axis=axis)
    # Write the advanced state back so the next call does not repeat this draw.
    bit_generator.state = scratch.state
30
+
31
+
32
def _broadcast_array_arg(arg, size, target_chunks):
    """Broadcast an array-valued argument to ``size`` and rechunk it.

    NumPy arrays with a non-empty shape are first wrapped with ``from_array``
    as a single chunk; ``Array`` inputs are used as they are.  Anything else
    (scalars, 0-d NumPy arrays, other objects) is returned unchanged.
    """
    from dask_array._broadcast import broadcast_to
    from dask_array.core._conversion import from_array

    shaped_ndarray = isinstance(arg, np.ndarray) and bool(arg.shape)
    if not shaped_ndarray and not isinstance(arg, Array):
        return arg

    if shaped_ndarray:
        arg = from_array(arg, chunks=arg.shape)
    return broadcast_to(arg, size).rechunk(target_chunks)
43
+
44
+
45
def _wrap_func(rng, funcname, *args, size=None, chunks="auto", extra_chunks=(), **kwargs):
    """Build the random-array collection for ``rng.<funcname>(*args, **kwargs)``.

    A ``size`` that is neither a tuple nor a list is wrapped into a 1-tuple.
    Array-valued arguments (NumPy arrays or ``Array`` objects with a non-empty
    shape) are checked for broadcast compatibility -- ``np.broadcast_shapes``
    raises ``ValueError`` otherwise -- and, when ``size`` is omitted, their
    broadcast shape becomes the output size.  They are then broadcast and
    rechunked to the output chunks.

    ``"normal"`` and ``"poisson"`` are routed to their dedicated expression
    classes; every other name goes through the generic ``Random`` expression.
    """
    from ._expr import Random, RandomNormal, RandomPoisson

    if size is not None and not isinstance(size, (tuple, list)):
        size = (size,)

    # Shapes of every array-valued argument, positional first, then keyword.
    array_shapes = [
        value.shape
        for value in (*args, *kwargs.values())
        if isinstance(value, (np.ndarray, Array)) and value.shape
    ]

    if array_shapes:
        if size is None:
            size = np.broadcast_shapes(*array_shapes)
        else:
            # Validation only: raises ValueError if the shapes are incompatible.
            np.broadcast_shapes(*array_shapes, size)

        target_chunks = normalize_chunks(chunks, size, dtype=kwargs.get("dtype", np.float64))
        args = tuple(_broadcast_array_arg(value, size, target_chunks) for value in args)
        kwargs = {key: _broadcast_array_arg(value, size, target_chunks) for key, value in kwargs.items()}

    if funcname == "normal":
        # A keyword value takes precedence over the positional one here.
        loc = kwargs.pop("loc", args[0] if args else 0.0)
        scale = kwargs.pop("scale", args[1] if len(args) > 1 else 1.0)
        return new_collection(RandomNormal(rng, size, chunks, extra_chunks, loc, scale))

    if funcname == "poisson":
        # The positional value takes precedence over the keyword one here.
        lam = args[0] if args else kwargs.pop("lam", 1.0)
        return new_collection(RandomPoisson(rng, size, chunks, extra_chunks, lam))

    return new_collection(Random(rng, funcname, size, chunks, extra_chunks, args, kwargs))
@@ -0,0 +1,84 @@
1
+ from __future__ import annotations
2
+
3
+ from dask_array.reductions._arg_reduction import arg_reduction
4
+ from dask_array.reductions._common import (
5
+ all,
6
+ any,
7
+ argmax,
8
+ argmin,
9
+ nannumel,
10
+ nanargmax,
11
+ nanargmin,
12
+ max,
13
+ mean,
14
+ median,
15
+ min,
16
+ moment,
17
+ nanmax,
18
+ nanmean,
19
+ nanmedian,
20
+ nanmin,
21
+ nanprod,
22
+ nanquantile,
23
+ nanstd,
24
+ nansum,
25
+ nanvar,
26
+ numel,
27
+ prod,
28
+ quantile,
29
+ std,
30
+ sum,
31
+ var,
32
+ )
33
+ from dask_array.reductions._cumulative import (
34
+ cumprod,
35
+ cumreduction,
36
+ cumsum,
37
+ nancumprod,
38
+ nancumsum,
39
+ )
40
+ from dask_array.reductions._reduction import (
41
+ _tree_reduce,
42
+ reduction,
43
+ )
44
+ from dask_array.reductions._trace import trace
45
+ from dask_array.reductions._percentile import nanpercentile, percentile
46
+
47
+ __all__ = [
48
+ "all",
49
+ "any",
50
+ "arg_reduction",
51
+ "argmax",
52
+ "argmin",
53
+ "cumprod",
54
+ "cumreduction",
55
+ "cumsum",
56
+ "max",
57
+ "mean",
58
+ "median",
59
+ "min",
60
+ "moment",
61
+ "nanargmax",
62
+ "nanargmin",
63
+ "nancumprod",
64
+ "nancumsum",
65
+ "nanmax",
66
+ "nanmean",
67
+ "nanmedian",
68
+ "nanmin",
69
+ "nanpercentile",
70
+ "nanprod",
71
+ "nanquantile",
72
+ "nanstd",
73
+ "nansum",
74
+ "nanvar",
75
+ "percentile",
76
+ "prod",
77
+ "quantile",
78
+ "reduction",
79
+ "std",
80
+ "sum",
81
+ "trace",
82
+ "var",
83
+ "_tree_reduce",
84
+ ]
@@ -0,0 +1,130 @@
1
+ from __future__ import annotations
2
+
3
+ import operator
4
+ from itertools import product, repeat
5
+ from numbers import Integral
6
+
7
+ import numpy as np
8
+ from tlz import accumulate, pluck
9
+
10
+ from dask_array._expr import ArrayExpr
11
+ from dask_array._utils import is_arraylike, validate_axis
12
+ from dask.tokenize import _tokenize_deterministic
13
+ from dask.utils import cached_property
14
+
15
+
16
class ArgChunk(ArrayExpr):
    """Per-block first step of an arg reduction (argmin/argmax).

    ``chunk_func`` is applied to every block of ``array`` together with the
    reduction ``axis`` and that block's offset information, so that later
    steps can turn block-local positions into global indices.
    """

    _parameters = ["array", "chunk_func", "axis", "ravel"]

    @cached_property
    def _name(self):
        token = _tokenize_deterministic(self.array, self.chunk_func, self.axis, self.ravel)
        return "arg-chunk-" + token

    @cached_property
    def _meta(self):
        from dask_array._utils import asarray_safe, meta_from_array

        # Probe argmin on a one-element array shaped like the input's meta and
        # reuse the probe itself when it is array-like.
        probe = np.argmin(asarray_safe([1], like=meta_from_array(self.array)))
        if not is_arraylike(probe):
            # Otherwise fall back to an empty array of the index dtype.
            return np.array([], dtype=np.intp)
        return probe

    @cached_property
    def chunks(self):
        # Along each reduced axis every block shrinks to length 1; the other
        # axes keep the input chunking.
        reduced_axes = self.axis
        return tuple(
            tuple(1 for _ in sizes) if dim in reduced_axes else sizes
            for dim, sizes in enumerate(self.array.chunks)
        )

    def _layer(self):
        source = self.array

        block_ids = product(*(range(count) for count in source.numblocks))
        # Start offset of every block along every dimension, in block order.
        block_starts = product(*(accumulate(operator.add, sizes[:-1], 0) for sizes in source.chunks))
        if self.ravel:
            # Full offset tuple paired with the overall array shape.
            offset_info = zip(block_starts, repeat(source.shape))
        else:
            # Only the offset along the (single) reduction axis.
            offset_info = pluck(self.axis[0], block_starts)

        return {
            (self._name, *block_id): (
                self.chunk_func,
                (source.name, *block_id),
                self.axis,
                offset,
            )
            for block_id, offset in zip(block_ids, offset_info)
        }
67
+
68
+
69
def arg_reduction(x, chunk, combine, agg, axis=None, keepdims=False, split_every=None, out=None):
    """Generic driver for arg reductions (argmin/argmax style) in array-expr.

    Parameters
    ----------
    x : Array
    chunk : callable
        Partialed ``arg_chunk``.
    combine : callable
        Partialed ``arg_combine``.
    agg : callable
        Partialed ``arg_agg``.
    axis : int, optional
        ``None`` reduces over every axis.
    keepdims : bool, optional
        Forwarded to ``_tree_reduce``.
    split_every : int or dict, optional
        Forwarded to ``_tree_reduce``.
    out : optional
        Passed to ``_handle_out`` together with the result.

    Raises
    ------
    TypeError
        If ``axis`` is neither ``None`` nor an integer.
    ValueError
        If a reduced axis has more than one chunk and unknown chunk sizes.
    """
    from dask_array.core._blockwise_funcs import _handle_out
    from dask_array.reductions._reduction import _tree_reduce
    from dask_array._utils import asarray_safe, meta_from_array

    if axis is not None and not isinstance(axis, Integral):
        raise TypeError(f"axis must be either `None` or int, got '{axis}'")

    if axis is None:
        axis = tuple(range(x.ndim))
        ravel = True
    else:
        axis = (validate_axis(axis, x.ndim),)
        ravel = x.ndim == 1

    for reduced in axis:
        sizes = x.chunks[reduced]
        if len(sizes) > 1 and np.isnan(sizes).any():
            raise ValueError(
                "Arg-reductions do not work with arrays that have "
                "unknown chunksizes. At some point in your computation "
                "this array lost chunking information.\n\n"
                "A possible solution is with \n"
                " x.compute_chunk_sizes()"
            )

    # Expression for the initial per-block step.
    chunked = ArgChunk(x.expr, chunk, axis, ravel)

    # Result dtype: whatever argmin yields for an array like the input's meta.
    probe = np.argmin(asarray_safe([1], like=meta_from_array(x)))
    dtype = probe.dtype if hasattr(probe, "dtype") else np.dtype(type(probe))

    reduced_result = _tree_reduce(
        chunked,
        agg,
        axis,
        keepdims=keepdims,
        dtype=dtype,
        split_every=split_every,
        combine=combine,
    )
    return _handle_out(out, reduced_result)