dask-array 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (144)
  1. dask_array/__init__.py +228 -0
  2. dask_array/_backends.py +76 -0
  3. dask_array/_backends_array.py +99 -0
  4. dask_array/_blockwise.py +1410 -0
  5. dask_array/_broadcast.py +272 -0
  6. dask_array/_chunk.py +445 -0
  7. dask_array/_chunk_types.py +54 -0
  8. dask_array/_collection.py +1644 -0
  9. dask_array/_concatenate.py +331 -0
  10. dask_array/_core_utils.py +1365 -0
  11. dask_array/_dispatch.py +141 -0
  12. dask_array/_einsum.py +277 -0
  13. dask_array/_expr.py +544 -0
  14. dask_array/_expr_flow.py +586 -0
  15. dask_array/_gufunc.py +805 -0
  16. dask_array/_histogram.py +617 -0
  17. dask_array/_map_blocks.py +652 -0
  18. dask_array/_new_collection.py +10 -0
  19. dask_array/_numpy_compat.py +135 -0
  20. dask_array/_overlap.py +1159 -0
  21. dask_array/_rechunk.py +1050 -0
  22. dask_array/_reshape.py +710 -0
  23. dask_array/_routines.py +102 -0
  24. dask_array/_shuffle.py +448 -0
  25. dask_array/_stack.py +264 -0
  26. dask_array/_svg.py +291 -0
  27. dask_array/_templates.py +29 -0
  28. dask_array/_test_utils.py +257 -0
  29. dask_array/_ufunc.py +385 -0
  30. dask_array/_utils.py +349 -0
  31. dask_array/_visualize.py +223 -0
  32. dask_array/_xarray.py +337 -0
  33. dask_array/core/__init__.py +34 -0
  34. dask_array/core/_blockwise_funcs.py +312 -0
  35. dask_array/core/_conversion.py +422 -0
  36. dask_array/core/_from_graph.py +97 -0
  37. dask_array/creation/__init__.py +71 -0
  38. dask_array/creation/_arange.py +121 -0
  39. dask_array/creation/_diag.py +116 -0
  40. dask_array/creation/_diagonal.py +241 -0
  41. dask_array/creation/_eye.py +103 -0
  42. dask_array/creation/_linspace.py +102 -0
  43. dask_array/creation/_mesh.py +134 -0
  44. dask_array/creation/_ones_zeros.py +454 -0
  45. dask_array/creation/_pad.py +270 -0
  46. dask_array/creation/_repeat.py +55 -0
  47. dask_array/creation/_tile.py +36 -0
  48. dask_array/creation/_tri.py +28 -0
  49. dask_array/creation/_utils.py +296 -0
  50. dask_array/fft.py +320 -0
  51. dask_array/io/__init__.py +39 -0
  52. dask_array/io/_base.py +10 -0
  53. dask_array/io/_from_array.py +257 -0
  54. dask_array/io/_from_delayed.py +95 -0
  55. dask_array/io/_from_graph.py +54 -0
  56. dask_array/io/_from_npy_stack.py +67 -0
  57. dask_array/io/_store.py +336 -0
  58. dask_array/io/_tiledb.py +159 -0
  59. dask_array/io/_to_npy_stack.py +65 -0
  60. dask_array/io/_zarr.py +449 -0
  61. dask_array/linalg/__init__.py +39 -0
  62. dask_array/linalg/_cholesky.py +234 -0
  63. dask_array/linalg/_lu.py +300 -0
  64. dask_array/linalg/_norm.py +94 -0
  65. dask_array/linalg/_qr.py +601 -0
  66. dask_array/linalg/_solve.py +349 -0
  67. dask_array/linalg/_svd.py +394 -0
  68. dask_array/linalg/_tensordot.py +334 -0
  69. dask_array/linalg/_utils.py +74 -0
  70. dask_array/manipulation/__init__.py +45 -0
  71. dask_array/manipulation/_expand.py +321 -0
  72. dask_array/manipulation/_flip.py +92 -0
  73. dask_array/manipulation/_roll.py +78 -0
  74. dask_array/manipulation/_transpose.py +309 -0
  75. dask_array/random/__init__.py +125 -0
  76. dask_array/random/_choice.py +181 -0
  77. dask_array/random/_expr.py +256 -0
  78. dask_array/random/_generator.py +441 -0
  79. dask_array/random/_random_state.py +259 -0
  80. dask_array/random/_utils.py +84 -0
  81. dask_array/reductions/__init__.py +84 -0
  82. dask_array/reductions/_arg_reduction.py +130 -0
  83. dask_array/reductions/_common.py +1082 -0
  84. dask_array/reductions/_cumulative.py +522 -0
  85. dask_array/reductions/_percentile.py +261 -0
  86. dask_array/reductions/_reduction.py +725 -0
  87. dask_array/reductions/_trace.py +56 -0
  88. dask_array/routines/__init__.py +133 -0
  89. dask_array/routines/_apply.py +84 -0
  90. dask_array/routines/_bincount.py +112 -0
  91. dask_array/routines/_broadcast.py +111 -0
  92. dask_array/routines/_coarsen.py +115 -0
  93. dask_array/routines/_diff.py +79 -0
  94. dask_array/routines/_gradient.py +158 -0
  95. dask_array/routines/_indexing.py +65 -0
  96. dask_array/routines/_insert_delete.py +132 -0
  97. dask_array/routines/_misc.py +122 -0
  98. dask_array/routines/_nonzero.py +72 -0
  99. dask_array/routines/_search.py +123 -0
  100. dask_array/routines/_select.py +113 -0
  101. dask_array/routines/_statistics.py +171 -0
  102. dask_array/routines/_topk.py +82 -0
  103. dask_array/routines/_triangular.py +74 -0
  104. dask_array/routines/_unique.py +232 -0
  105. dask_array/routines/_where.py +62 -0
  106. dask_array/slicing/__init__.py +67 -0
  107. dask_array/slicing/_basic.py +550 -0
  108. dask_array/slicing/_blocks.py +138 -0
  109. dask_array/slicing/_bool_index.py +145 -0
  110. dask_array/slicing/_setitem.py +329 -0
  111. dask_array/slicing/_squeeze.py +101 -0
  112. dask_array/slicing/_utils.py +1133 -0
  113. dask_array/slicing/_vindex.py +282 -0
  114. dask_array/stacking/__init__.py +15 -0
  115. dask_array/stacking/_block.py +83 -0
  116. dask_array/stacking/_simple.py +58 -0
  117. dask_array/templates/array.html.j2 +48 -0
  118. dask_array/tests/__init__.py +0 -0
  119. dask_array/tests/conftest.py +22 -0
  120. dask_array/tests/test_api.py +40 -0
  121. dask_array/tests/test_binary_op_chunks.py +107 -0
  122. dask_array/tests/test_coarse_slice_through_blockwise.py +362 -0
  123. dask_array/tests/test_collection.py +799 -0
  124. dask_array/tests/test_creation.py +1102 -0
  125. dask_array/tests/test_expr_flow.py +143 -0
  126. dask_array/tests/test_linalg.py +1130 -0
  127. dask_array/tests/test_map_blocks_multi_output.py +104 -0
  128. dask_array/tests/test_rechunk_pushdown.py +214 -0
  129. dask_array/tests/test_reductions.py +1091 -0
  130. dask_array/tests/test_routines.py +2853 -0
  131. dask_array/tests/test_shuffle_chunks.py +67 -0
  132. dask_array/tests/test_slice_pushdown.py +968 -0
  133. dask_array/tests/test_slice_through_blockwise.py +678 -0
  134. dask_array/tests/test_slice_through_overlap.py +366 -0
  135. dask_array/tests/test_slice_through_reshape.py +272 -0
  136. dask_array/tests/test_slicing.py +839 -0
  137. dask_array/tests/test_transpose_slice_pushdown.py +208 -0
  138. dask_array/tests/test_visualize.py +94 -0
  139. dask_array/tests/test_xarray.py +193 -0
  140. dask_array-0.1.0.dist-info/METADATA +48 -0
  141. dask_array-0.1.0.dist-info/RECORD +144 -0
  142. dask_array-0.1.0.dist-info/WHEEL +4 -0
  143. dask_array-0.1.0.dist-info/entry_points.txt +2 -0
  144. dask_array-0.1.0.dist-info/licenses/LICENSE +29 -0
dask_array/_gufunc.py ADDED
@@ -0,0 +1,805 @@
1
+ from __future__ import annotations
2
+
3
+ import functools
4
+ import re
5
+
6
+ import numpy as np
7
+ from tlz import concat, merge, unique
8
+
9
+ from dask_array._new_collection import new_collection
10
+ from dask_array._collection import asarray, blockwise
11
+ from dask_array._expr import ArrayExpr
12
+ from dask_array._chunk import getitem
13
+ from dask_array._core_utils import apply_infer_dtype, normalize_chunks
14
+ from dask_array._utils import meta_from_array
15
+ from dask.core import flatten
16
+
17
# Modified version of `numpy.lib.function_base._parse_gufunc_signature`
# Modifications:
# - Allow for zero input arguments
# See https://docs.scipy.org/doc/numpy/reference/c-api/generalized-ufuncs.html

# One core dimension name, e.g. ``n``.
_DIMENSION_NAME = r"\w+"
# Comma-separated (possibly empty) list of dimension names, trailing comma allowed.
_CORE_DIMENSION_LIST = "(?:" + _DIMENSION_NAME + "(?:," + _DIMENSION_NAME + ")*,?)?"
# One parenthesised argument such as ``(m,n)`` or ``()``.
_ARGUMENT = r"\(" + _CORE_DIMENSION_LIST + r"\)"
# Zero or more comma-separated input arguments, trailing comma allowed.
_INPUT_ARGUMENTS = "(?:" + _ARGUMENT + "(?:," + _ARGUMENT + ")*,?)?"
# One or more comma-separated output arguments. Use `'{0:}(?:,{0:})*,?'` if
# gufunc-signature should be allowed for length 1 tuple returns.
_OUTPUT_ARGUMENTS = _ARGUMENT + "(?:," + _ARGUMENT + ")*"
# Whole signature: ``<inputs>-><outputs>`` anchored at both ends.
_SIGNATURE = "^" + _INPUT_ARGUMENTS + "->" + _OUTPUT_ARGUMENTS + "$"


def _parse_gufunc_signature(signature):
    """
    Parse a generalized-ufunc signature string into core dimension names.

    Whitespace anywhere in ``signature`` is ignored.

    Arguments
    ---------
    signature : string
        Generalized universal function signature, e.g., ``(m,n),(n,p)->(m,p)``
        for ``np.matmul``.

    Returns
    -------
    ins : list of tuple of str
        Core dimension names of every input argument (may be empty).
    outs : tuple of str, or list of tuple of str
        For exactly one output, the bare tuple of its core dimension names;
        for several outputs, a list with one tuple per output.

    Raises
    ------
    ValueError
        If ``signature`` does not match the gufunc signature grammar.
    """
    signature = re.sub(r"\s+", "", signature)
    if re.match(_SIGNATURE, signature) is None:
        raise ValueError(f"Not a valid gufunc signature: {signature}")
    inputs_text, outputs_text = signature.split("->")

    def _dims_per_argument(text):
        # One tuple of dimension names per parenthesised argument in ``text``.
        return [tuple(re.findall(_DIMENSION_NAME, group)) for group in re.findall(_ARGUMENT, text)]

    ins = _dims_per_argument(inputs_text)
    outs = _dims_per_argument(outputs_text)
    if len(outs) == 1 and outputs_text[-1] != ",":
        # A single output is returned unwrapped.
        return ins, outs[0]
    return ins, outs
54
+
55
+
56
def _validate_normalize_axes(axes, axis, keepdims, input_coredimss, output_coredimss):
    """
    Validate the `axes`/`axis`/`keepdims` arguments and normalize them.

    Refer to [1]_ for the semantics of the three keywords.

    Arguments
    ---------
    axes: list of tuples (or ints), or None
        Axes for every input and, optionally, every output argument. A bare
        integer entry is shorthand for a one-element tuple. Python ints and
        NumPy integer scalars are both accepted.
    axis: int or None
        A single axis shared by all arguments that have core dimensions.
        Mutually exclusive with ``axes``. Python ints and NumPy integer
        scalars are both accepted.
    keepdims: bool
        Keep the reduced core dimensions in the (scalar) outputs.
    input_coredimss: list of tuple of str
        Core dimension names of each input argument.
    output_coredimss: list of tuple of str, or tuple of str
        Core dimension names of the outputs; a bare tuple denotes a single
        output.

    Returns
    -------
    input_axes: list of tuple of int
    output_axes: list of tuple of int

    Raises
    ------
    ValueError
        If the keyword combination is invalid or does not agree with the core
        dimensions.

    References
    ----------
    .. [1] https://docs.scipy.org/doc/numpy/reference/ufuncs.html#optional-keyword-arguments
    """
    nin = len(input_coredimss)
    nout = 1 if not isinstance(output_coredimss, list) else len(output_coredimss)

    if axes is not None and axis is not None:
        raise ValueError("Only one of `axis` or `axes` keyword arguments should be given")
    if axes and not isinstance(axes, list):
        raise ValueError("`axes` has to be of type list")

    # From here on, treat outputs uniformly as a list of core-dimension tuples.
    output_coredimss = output_coredimss if nout > 1 else [output_coredimss]
    filtered_core_dims = list(filter(len, input_coredimss))
    nr_outputs_with_coredims = len([True for x in output_coredimss if len(x) > 0])

    if keepdims:
        if nr_outputs_with_coredims > 0:
            raise ValueError("`keepdims` can only be used for scalar outputs")
        if not filtered_core_dims:
            # Without any input core dimension there is nothing to keep; this
            # used to fail with an IndexError on ``filtered_core_dims[0]``.
            raise ValueError("`keepdims` can only be used if at least one input has core dimensions")
        # Each (scalar) output keeps the core dimensions of the inputs.
        output_coredimss = len(output_coredimss) * [filtered_core_dims[0]]

    core_dims = input_coredimss + output_coredimss
    if axis is not None:
        # Accept NumPy integer scalars (e.g. ``np.int64``) as well as Python ints.
        if not isinstance(axis, (int, np.integer)):
            raise ValueError("`axis` argument has to be an integer value")
        axis = int(axis)
        if filtered_core_dims:
            cd0 = filtered_core_dims[0]
            if len(cd0) != 1:
                raise ValueError("`axis` can be used only, if one core dimension is present")
            for cd in filtered_core_dims:
                if cd0 != cd:
                    raise ValueError("To use `axis`, all core dimensions have to be equal")

    # Expand defaults or axis
    if axes is None:
        if axis is not None:
            axes = [(axis,) if cd else tuple() for cd in core_dims]
        else:
            axes = [tuple(range(-len(icd), 0)) for icd in core_dims]
    elif not isinstance(axes, list):
        raise ValueError("`axes` argument has to be a list")
    # A bare integer entry is shorthand for a one-element tuple.
    axes = [(int(a),) if isinstance(a, (int, np.integer)) else a for a in axes]

    # Output entries may be omitted only when every output is a scalar.
    if ((nr_outputs_with_coredims == 0) and (nin != len(axes)) and (nin + nout != len(axes))) or (
        (nr_outputs_with_coredims > 0) and (nin + nout != len(axes))
    ):
        raise ValueError("The number of `axes` entries is not equal the number of input and output arguments")

    # Treat outputs
    output_axes = axes[nin:]
    output_axes = output_axes if output_axes else [tuple(range(-len(ocd), 0)) for ocd in output_coredimss]
    input_axes = axes[:nin]

    # Assert we have as many axes as input core dimensions
    for idx, (iax, icd) in enumerate(zip(input_axes, input_coredimss)):
        if len(iax) != len(icd):
            raise ValueError(
                f"The number of `axes` entries for argument #{idx} is not equal "
                "the number of respective input core dimensions in signature"
            )
    if not keepdims:
        for idx, (oax, ocd) in enumerate(zip(output_axes, output_coredimss)):
            if len(oax) != len(ocd):
                raise ValueError(
                    f"The number of `axes` entries for argument #{idx} is not equal "
                    "the number of respective output core dimensions in signature"
                )
    elif input_coredimss:
        icd0 = input_coredimss[0]
        for icd in input_coredimss:
            if icd0 != icd:
                raise ValueError("To use `keepdims`, all core dimensions have to be equal")
        # Kept dimensions are placed where the first input's core axes were.
        iax0 = input_axes[0]
        output_axes = [iax0 for _ in output_coredimss]

    return input_axes, output_axes
150
+
151
+
152
def apply_gufunc(
    func,
    signature,
    *args,
    axes=None,
    axis=None,
    keepdims=False,
    output_dtypes=None,
    output_sizes=None,
    vectorize=None,
    allow_rechunk=False,
    meta=None,
    **kwargs,
):
    """
    Apply a generalized ufunc or similar python function to arrays.

    ``signature`` determines if the function consumes or produces core
    dimensions. The remaining dimensions in given input arrays (``*args``)
    are considered loop dimensions and are required to broadcast
    naturally against each other.

    In other terms, this function is like ``np.vectorize``, but for
    the blocks of dask arrays. If the function itself shall also
    be vectorized use ``vectorize=True`` for convenience.

    Parameters
    ----------
    func : callable
        Function to call like ``func(*args, **kwargs)`` on input arrays
        (``*args``) that returns an array or tuple of arrays. If multiple
        arguments with non-matching dimensions are supplied, this function is
        expected to vectorize (broadcast) over axes of positional arguments in
        the style of NumPy universal functions [1]_ (if this is not the case,
        set ``vectorize=True``). If this function returns multiple outputs,
        ``signature`` has to declare the same number of outputs
        (e.g. ``"(i)->(),()"``).
    signature: string
        Specifies what core dimensions are consumed and produced by ``func``.
        According to the specification of numpy.gufunc signature [2]_
    *args : numeric
        Input arrays or scalars to the callable function.
    axes: List of tuples, optional, keyword only
        A list of tuples with indices of axes a generalized ufunc should operate on.
        For instance, for a signature of ``"(i,j),(j,k)->(i,k)"`` appropriate for
        matrix multiplication, the base elements are two-dimensional matrices
        and these are taken to be stored in the two last axes of each argument. The
        corresponding axes keyword would be ``[(-2, -1), (-2, -1), (-2, -1)]``.
        For simplicity, for generalized ufuncs that operate on 1-dimensional arrays
        (vectors), a single integer is accepted instead of a single-element tuple,
        and for generalized ufuncs for which all outputs are scalars, the output
        tuples can be omitted.
    axis: int, optional, keyword only
        A single axis over which a generalized ufunc should operate. This is a short-cut
        for ufuncs that operate over a single, shared core dimension, equivalent to passing
        in axes with entries of (axis,) for each single-core-dimension argument and ``()`` for
        all others. For instance, for a signature ``"(i),(i)->()"``, it is equivalent to passing
        in ``axes=[(axis,), (axis,), ()]``.
    keepdims: bool, optional, keyword only
        If this is set to True, axes which are reduced over will be left in the result as
        a dimension with size one, so that the result will broadcast correctly against the
        inputs. This option can only be used for generalized ufuncs that operate on inputs
        that all have the same number of core dimensions and with outputs that have no core
        dimensions, i.e., with signatures like ``"(i),(i)->()"`` or ``"(m,m)->()"``.
        If used, the location of the dimensions in the output can be controlled with axes
        and axis.
    output_dtypes : Optional, dtype or list of dtypes, keyword only
        Valid numpy dtype specification or list thereof.
        If not given, a call of ``func`` with a small set of data
        is performed in order to try to automatically determine the
        output dtypes.
    output_sizes : dict, optional, keyword only
        Optional mapping from dimension names to sizes for outputs. Only used if
        new core dimensions (not found on inputs) appear on outputs.
    vectorize: bool, keyword only
        If set to ``True``, ``np.vectorize`` is applied to ``func`` for
        convenience. Defaults to ``False``.
    allow_rechunk: Optional, bool, keyword only
        Allows rechunking, otherwise chunk sizes need to match and core
        dimensions are to consist only of one chunk.
        Warning: enabling this can increase memory usage significantly.
        Defaults to ``False``.
    meta: Optional, tuple, keyword only
        tuple of empty ndarrays describing the shape and dtype of the output of the gufunc.
        Defaults to ``None``.
    **kwargs : dict
        Extra keyword arguments to pass to `func`

    Returns
    -------
    Single dask.array.Array or tuple of dask.array.Array

    Raises
    ------
    TypeError
        If ``signature`` is not a string.
    ValueError
        If ``signature`` is malformed, the number of ``args`` does not match
        it, ``meta``/``output_dtypes`` disagree with the number of outputs,
        the ``axes``/``axis``/``keepdims`` combination is invalid, or the
        inputs have incompatible sizes or chunks (see ``allow_rechunk``).

    Examples
    --------
    >>> import dask_array as da
    >>> import numpy as np
    >>> def stats(x):
    ...     return np.mean(x, axis=-1), np.std(x, axis=-1)
    >>> a = da.random.normal(size=(10,20,30), chunks=(5, 10, 30))
    >>> mean, std = da.apply_gufunc(stats, "(i)->(),()", a)
    >>> mean.compute().shape
    (10, 20)


    >>> def outer_product(x, y):
    ...     return np.einsum("i,j->ij", x, y)
    >>> a = da.random.normal(size=(   20,30), chunks=(10, 30))
    >>> b = da.random.normal(size=(10, 1,40), chunks=(5, 1, 40))
    >>> c = da.apply_gufunc(outer_product, "(i),(j)->(i,j)", a, b, vectorize=True)
    >>> c.compute().shape
    (10, 20, 30, 40)

    References
    ----------
    .. [1] https://docs.scipy.org/doc/numpy/reference/ufuncs.html
    .. [2] https://docs.scipy.org/doc/numpy/reference/c-api/generalized-ufuncs.html
    """
    # Input processing:
    ## Signature
    if not isinstance(signature, str):
        raise TypeError("`signature` has to be of type string")
    # NumPy versions before https://github.com/numpy/numpy/pull/19627
    # would not ignore whitespace characters in `signature` like they
    # are supposed to. We remove the whitespace here as a workaround.
    signature = re.sub(r"\s+", "", signature)
    input_coredimss, output_coredimss = _parse_gufunc_signature(signature)

    ## Check the number of inputs before anything calls `func` (dtype
    ## inference below does), so a mismatch gets a clear error message.
    if len(input_coredimss) != len(args):
        raise ValueError(
            f"According to `signature`, `func` requires {len(input_coredimss)} arguments, but {len(args)} given"
        )

    ## Determine nout: nout = None for functions of one direct return; nout = int for return tuples
    nout = None if not isinstance(output_coredimss, list) else len(output_coredimss)

    ## Consolidate onto `meta`
    if meta is not None and output_dtypes is not None:
        raise ValueError("Only one of `meta` and `output_dtypes` should be given (`meta` is preferred).")
    if meta is None:
        if output_dtypes is None:
            ## Infer `output_dtypes`
            if vectorize:
                tempfunc = np.vectorize(func, signature=signature)
            else:
                tempfunc = func
            output_dtypes = apply_infer_dtype(tempfunc, args, kwargs, "apply_gufunc", "output_dtypes", nout)

        ## Turn `output_dtypes` into `meta`
        if nout is None and isinstance(output_dtypes, (tuple, list)) and len(output_dtypes) == 1:
            output_dtypes = output_dtypes[0]
        sample = args[0] if args else None
        if nout is None:
            meta = meta_from_array(sample, dtype=output_dtypes)
        else:
            meta = tuple(meta_from_array(sample, dtype=odt) for odt in output_dtypes)

    ## Normalize `meta` format
    meta = meta_from_array(meta)
    if isinstance(meta, list):
        meta = tuple(meta)

    ## Validate `meta`
    if nout is None:
        if isinstance(meta, tuple):
            if len(meta) == 1:
                meta = meta[0]
            else:
                raise ValueError(
                    "For a function with one output, must give a single item for `output_dtypes`/`meta`, "
                    "not a tuple or list."
                )
    else:
        if not isinstance(meta, tuple):
            raise ValueError(
                f"For a function with {nout} outputs, must give a tuple or list for `output_dtypes`/`meta`, "
                "not a single item."
            )
        if len(meta) != nout:
            raise ValueError(
                f"For a function with {nout} outputs, must give a tuple or list of {nout} items for "
                f"`output_dtypes`/`meta`, not {len(meta)}."
            )

    ## Vectorize function, if required
    if vectorize:
        otypes = [x.dtype for x in meta] if isinstance(meta, tuple) else [meta.dtype]
        func = np.vectorize(func, signature=signature, otypes=otypes)

    ## Miscellaneous
    if output_sizes is None:
        output_sizes = {}

    ## Axes
    input_axes, output_axes = _validate_normalize_axes(axes, axis, keepdims, input_coredimss, output_coredimss)

    # Main code:
    ## Cast all input arrays to dask
    args = [asarray(a) for a in args]

    ## Axes: transpose input arguments so that every core axis is trailing
    transposed_args = []
    for arg, iax in zip(args, input_axes):
        shape = arg.shape
        iax = tuple(a if a < 0 else a - len(shape) for a in iax)
        tidc = tuple(i for i in range(-len(shape), 0) if i not in iax) + iax
        transposed_args.append(arg.transpose(tidc))
    args = transposed_args

    ## Assess input args for loop dims
    input_shapes = [a.shape for a in args]
    input_chunkss = [a.chunks for a in args]
    num_loopdims = [len(s) - len(cd) for s, cd in zip(input_shapes, input_coredimss)]
    max_loopdims = max(num_loopdims) if num_loopdims else None
    core_input_shapes = [dict(zip(icd, s[n:])) for s, n, icd in zip(input_shapes, num_loopdims, input_coredimss)]
    core_shapes = merge(*core_input_shapes)
    core_shapes.update(output_sizes)

    # Loop dims are right-aligned across inputs, as in NumPy broadcasting.
    loop_input_dimss = [tuple(f"__loopdim{d}__" for d in range(max_loopdims - n, max_loopdims)) for n in num_loopdims]
    input_dimss = [l + c for l, c in zip(loop_input_dimss, input_coredimss)]

    loop_output_dims = max(loop_input_dimss, key=len) if loop_input_dimss else tuple()

    ## Assess input args for same size and chunk sizes
    ### Collect sizes and chunksizes of all dims in all arrays
    dimsizess = {}
    chunksizess = {}
    for dims, shape, chunksizes in zip(input_dimss, input_shapes, input_chunkss):
        for dim, size, chunksize in zip(dims, shape, chunksizes):
            dimsizess.setdefault(dim, []).append(size)
            chunksizess.setdefault(dim, []).append(chunksize)
    ### Assert correct partitioning, for case:
    for dim, sizes in dimsizess.items():
        #### Check that the arrays have same length for same dimensions or dimension `1`
        if set(sizes) | {1} != {1, max(sizes)}:
            raise ValueError(f"Dimension `'{dim}'` with different lengths in arrays")
        if not allow_rechunk:
            chunksizes = chunksizess[dim]
            #### Check if core dimensions consist of only one chunk
            if (dim in core_shapes) and (chunksizes[0][0] < core_shapes[dim]):
                raise ValueError(
                    f"Core dimension `'{dim}'` consists of multiple chunks. To fix, rechunk into "
                    "a single chunk along this dimension or set `allow_rechunk=True`, but beware "
                    "that this may increase memory usage significantly."
                )
            #### Check if loop dimensions consist of same chunksizes, when they have sizes > 1
            relevant_chunksizes = list(unique(c for s, c in zip(sizes, chunksizes) if s > 1))
            if len(relevant_chunksizes) > 1:
                raise ValueError(f"Dimension `'{dim}'` with different chunksize present")

    ## Apply function - use blockwise here
    arginds = list(concat(zip(args, input_dimss)))

    ### Use existing `blockwise` but only with loopdims to enforce
    ### concatenation for coredims that appear also at the output
    ### Modifying `blockwise` could improve things here.
    tmp = blockwise(func, loop_output_dims, *arginds, concatenate=True, meta=meta, **kwargs)

    # NOTE: we likely could just use `meta` instead of `tmp._meta`,
    # but we use it and validate it anyway just to be sure nothing odd has happened.
    metas = tmp._meta
    if nout is None:
        assert not isinstance(metas, (list, tuple)), (
            f"meta changed from single output to multiple output during blockwise: {meta} -> {metas}"
        )
        metas = (metas,)
    else:
        assert isinstance(metas, (list, tuple)), (
            f"meta changed from multiple output to single output during blockwise: {meta} -> {metas}"
        )
        assert len(metas) == nout, f"Number of outputs changed from {nout} to {len(metas)} during blockwise"

    ## Prepare output shapes
    loop_output_shape = tmp.shape
    loop_output_chunks = tmp.chunks
    # Only the first key is needed to recover the blockwise name prefix.
    # Split on the LAST "-" (as GUfuncLeafExpr._name does) so a prefix that
    # itself contains "-" does not break the unpacking.
    first_key = next(iter(flatten(tmp.__dask_keys__())))
    name = first_key[0].rsplit("-", 1)[0]

    ### *) Treat direct output
    if nout is None:
        output_coredimss = [output_coredimss]

    ## Split output
    leaf_arrs = []
    for i, (ocd, oax, output_meta) in enumerate(zip(output_coredimss, output_axes, metas)):
        leaf_arr = new_collection(
            GUfuncLeafExpr(
                tmp,
                i,
                name,
                loop_output_chunks,
                core_shapes,
                ocd,
                nout,
                output_meta,
                loop_output_shape,
            )
        )

        ### Axes:
        if keepdims:
            # Re-insert the reduced core dims as trailing length-one axes.
            slices = len(leaf_arr.shape) * (slice(None),) + len(oax) * (np.newaxis,)
            leaf_arr = leaf_arr[slices]

        # Move the trailing core axes to the positions requested in `oax`;
        # the remaining (loop) axes fill the free slots in order.
        tidcs = [None] * len(leaf_arr.shape)
        for ii, oa in zip(range(-len(oax), 0), oax):
            tidcs[oa] = ii
        j = 0
        for ii, current in enumerate(tidcs):
            if current is None:
                tidcs[ii] = j
                j += 1
        leaf_arr = leaf_arr.transpose(tidcs)
        leaf_arrs.append(leaf_arr)

    return (*leaf_arrs,) if nout else leaf_arrs[0]  # Undo *) from above
487
+
488
+
489
class GUfuncLeafExpr(ArrayExpr):
    """
    Expression selecting output number ``i`` of a gufunc blockwise result.

    Every block of ``array`` becomes one block of this expression. The block
    indices are kept, and one ``0`` index is appended per output core
    dimension in ``ocd``, because each output core dimension is a single
    chunk. When ``nout`` is truthy (the function returns a tuple), the block
    is taken as ``getitem(block, i)``. Otherwise the key is aliased unchanged.
    """

    _parameters = [
        "array",  # blockwise result produced by ``apply_gufunc``
        "i",  # index of this output among the function's outputs
        "name_prefix",  # name prefix of ``array``'s keys
        "loop_output_chunks",  # chunks of the loop (non-core) dimensions
        "core_shapes",  # mapping: core dimension name -> size
        "ocd",  # core dimension names of this output
        "nout",  # number of outputs, or None for a single direct return
        "input_meta",  # meta object describing this output
        "loop_output_shape",  # shape of the loop (non-core) dimensions
    ]

    @functools.cached_property
    def _meta(self):
        ndim = len(self._shape)
        return meta_from_array(self.input_meta, ndim)

    @functools.cached_property
    def _name(self):
        # Reuse the token (text after the last "-") of the blockwise result.
        token = self.array._name.split("-")[-1]
        return f"{self.name_prefix}_{self.i}-{token}"

    @functools.cached_property
    def _shape(self):
        core_sizes = tuple(self.core_shapes[dim_name] for dim_name in self.ocd)
        return self.loop_output_shape + core_sizes

    @functools.cached_property
    def chunks(self):
        # Each core dimension uses its full size as its chunk size.
        core_sizes = tuple(self.core_shapes[dim_name] for dim_name in self.ocd)
        return normalize_chunks(self.loop_output_chunks + core_sizes, self._shape, dtype=self._meta.dtype)

    def _layer(self):
        trailing_zeros = (0,) * len(self.ocd)
        layer = {}
        for source_key in flatten(self.array.__dask_keys__()):
            target_key = (self._name, *source_key[1:], *trailing_zeros)
            if self.nout:
                layer[target_key] = (getitem, source_key, self.i)
            else:
                layer[target_key] = source_key
        return layer
528
+
529
+
530
+ class gufunc:
531
+ """
532
+ Binds `pyfunc` into ``dask.array.apply_gufunc`` when called.
533
+
534
+ Parameters
535
+ ----------
536
+ pyfunc : callable
537
+ Function to call like ``func(*args, **kwargs)`` on input arrays
538
+ (``*args``) that returns an array or tuple of arrays. If multiple
539
+ arguments with non-matching dimensions are supplied, this function is
540
+ expected to vectorize (broadcast) over axes of positional arguments in
541
+ the style of NumPy universal functions [1]_ (if this is not the case,
542
+ set ``vectorize=True``). If this function returns multiple outputs,
543
+ ``output_core_dims`` has to be set as well.
544
+ signature : String, keyword only
545
+ Specifies what core dimensions are consumed and produced by ``func``.
546
+ According to the specification of numpy.gufunc signature [2]_
547
+ axes: List of tuples, optional, keyword only
548
+ A list of tuples with indices of axes a generalized ufunc should operate on.
549
+ For instance, for a signature of ``"(i,j),(j,k)->(i,k)"`` appropriate for
550
+ matrix multiplication, the base elements are two-dimensional matrices
551
+ and these are taken to be stored in the two last axes of each argument. The
552
+ corresponding axes keyword would be ``[(-2, -1), (-2, -1), (-2, -1)]``.
553
+ For simplicity, for generalized ufuncs that operate on 1-dimensional arrays
554
+ (vectors), a single integer is accepted instead of a single-element tuple,
555
+ and for generalized ufuncs for which all outputs are scalars, the output
556
+ tuples can be omitted.
557
+ axis: int, optional, keyword only
558
+ A single axis over which a generalized ufunc should operate. This is a short-cut
559
+ for ufuncs that operate over a single, shared core dimension, equivalent to passing
560
+ in axes with entries of (axis,) for each single-core-dimension argument and ``()`` for
561
+ all others. For instance, for a signature ``"(i),(i)->()"``, it is equivalent to passing
562
+ in ``axes=[(axis,), (axis,), ()]``.
563
+ keepdims: bool, optional, keyword only
564
+ If this is set to True, axes which are reduced over will be left in the result as
565
+ a dimension with size one, so that the result will broadcast correctly against the
566
+ inputs. This option can only be used for generalized ufuncs that operate on inputs
567
+ that all have the same number of core dimensions and with outputs that have no core
568
+ dimensions , i.e., with signatures like ``"(i),(i)->()"`` or ``"(m,m)->()"``.
569
+ If used, the location of the dimensions in the output can be controlled with axes
570
+ and axis.
571
+ output_dtypes : Optional, dtype or list of dtypes, keyword only
572
+ Valid numpy dtype specification or list thereof.
573
+ If not given, a call of ``func`` with a small set of data
574
+ is performed in order to try to automatically determine the
575
+ output dtypes.
576
+ output_sizes : dict, optional, keyword only
577
+ Optional mapping from dimension names to sizes for outputs. Only used if
578
+ new core dimensions (not found on inputs) appear on outputs.
579
+ vectorize: bool, keyword only
580
+ If set to ``True``, ``np.vectorize`` is applied to ``func`` for
581
+ convenience. Defaults to ``False``.
582
+ allow_rechunk: Optional, bool, keyword only
583
+ Allows rechunking, otherwise chunk sizes need to match and core
584
+ dimensions are to consist only of one chunk.
585
+ Warning: enabling this can increase memory usage significantly.
586
+ Defaults to ``False``.
587
+ meta: Optional, tuple, keyword only
588
+ tuple of empty ndarrays describing the shape and dtype of the output of the gufunc.
589
+ Defaults to ``None``.
590
+
591
+ Returns
592
+ -------
593
+ Wrapped function
594
+
595
+ Examples
596
+ --------
597
+ >>> import dask_array as da
598
+ >>> import numpy as np
599
+ >>> a = da.random.normal(size=(10,20,30), chunks=(5, 10, 30))
600
+ >>> def stats(x):
601
+ ... return np.mean(x, axis=-1), np.std(x, axis=-1)
602
+ >>> gustats = da.gufunc(stats, signature="(i)->(),()", output_dtypes=(float, float))
603
+ >>> mean, std = gustats(a)
604
+ >>> mean.compute().shape
605
+ (10, 20)
606
+
607
+ >>> a = da.random.normal(size=( 20,30), chunks=(10, 30))
608
+ >>> b = da.random.normal(size=(10, 1,40), chunks=(5, 1, 40))
609
+ >>> def outer_product(x, y):
610
+ ... return np.einsum("i,j->ij", x, y)
611
+ >>> guouter_product = da.gufunc(outer_product, signature="(i),(j)->(i,j)", output_dtypes=float, vectorize=True)
612
+ >>> c = guouter_product(a, b)
613
+ >>> c.compute().shape
614
+ (10, 20, 30, 40)
615
+
616
+ >>> a = da.ones((1, 5, 10), chunks=(-1, -1, -1))
617
+ >>> def stats(x):
618
+ ... return np.atleast_1d(x.mean()), np.atleast_1d(x.max())
619
+ >>> meta = (np.array((), dtype=np.float64), np.array((), dtype=np.float64))
620
+ >>> gustats = da.gufunc(stats, signature="(i,j)->(),()", meta=meta)
621
+ >>> result = gustats(a)
622
+ >>> result[0].compute().shape
623
+ (1,)
624
+ >>> result[1].compute().shape
625
+ (1,)
626
+
627
+ References
628
+ ----------
629
+ .. [1] https://numpy.org/doc/stable/reference/ufuncs.html
630
+ .. [2] https://numpy.org/doc/stable/reference/c-api/generalized-ufuncs.html
631
+ """
632
+
633
+ def __init__(
634
+ self,
635
+ pyfunc,
636
+ *,
637
+ signature=None,
638
+ vectorize=False,
639
+ axes=None,
640
+ axis=None,
641
+ keepdims=False,
642
+ output_sizes=None,
643
+ output_dtypes=None,
644
+ allow_rechunk=False,
645
+ meta=None,
646
+ ):
647
+ self.pyfunc = pyfunc
648
+ self.signature = signature
649
+ self.vectorize = vectorize
650
+ self.axes = axes
651
+ self.axis = axis
652
+ self.keepdims = keepdims
653
+ self.output_sizes = output_sizes
654
+ self.output_dtypes = output_dtypes
655
+ self.allow_rechunk = allow_rechunk
656
+ self.meta = meta
657
+
658
+ self.__doc__ = f"""
659
+ Bound ``dask.array.gufunc``
660
+ func: ``{self.pyfunc}``
661
+ signature: ``'{self.signature}'``
662
+
663
+ Parameters
664
+ ----------
665
+ *args : numpy/dask arrays or scalars
666
+ Arrays to which to apply to ``func``. Core dimensions as specified in
667
+ ``signature`` need to come last.
668
+ **kwargs : dict
669
+ Extra keyword arguments to pass to ``func``
670
+
671
+ Returns
672
+ -------
673
+ Single dask.array.Array or tuple of dask.array.Array
674
+ """
675
+
676
+ def __call__(self, *args, allow_rechunk=False, **kwargs):
677
+ return apply_gufunc(
678
+ self.pyfunc,
679
+ self.signature,
680
+ *args,
681
+ vectorize=self.vectorize,
682
+ axes=self.axes,
683
+ axis=self.axis,
684
+ keepdims=self.keepdims,
685
+ output_sizes=self.output_sizes,
686
+ output_dtypes=self.output_dtypes,
687
+ allow_rechunk=self.allow_rechunk or allow_rechunk,
688
+ meta=self.meta,
689
+ **kwargs,
690
+ )
691
+
692
+
693
def as_gufunc(signature=None, **kwargs):
    """
    Decorator for ``dask.array.gufunc``.

    Parameters
    ----------
    signature : String
        Specifies what core dimensions are consumed and produced by ``func``.
        According to the specification of numpy.gufunc signature [2]_
    axes: List of tuples, optional, keyword only
        A list of tuples with indices of axes a generalized ufunc should operate on.
        For instance, for a signature of ``"(i,j),(j,k)->(i,k)"`` appropriate for
        matrix multiplication, the base elements are two-dimensional matrices
        and these are taken to be stored in the two last axes of each argument. The
        corresponding axes keyword would be ``[(-2, -1), (-2, -1), (-2, -1)]``.
        For simplicity, for generalized ufuncs that operate on 1-dimensional arrays
        (vectors), a single integer is accepted instead of a single-element tuple,
        and for generalized ufuncs for which all outputs are scalars, the output
        tuples can be omitted.
    axis: int, optional, keyword only
        A single axis over which a generalized ufunc should operate. This is a short-cut
        for ufuncs that operate over a single, shared core dimension, equivalent to passing
        in axes with entries of (axis,) for each single-core-dimension argument and ``()`` for
        all others. For instance, for a signature ``"(i),(i)->()"``, it is equivalent to passing
        in ``axes=[(axis,), (axis,), ()]``.
    keepdims: bool, optional, keyword only
        If this is set to True, axes which are reduced over will be left in the result as
        a dimension with size one, so that the result will broadcast correctly against the
        inputs. This option can only be used for generalized ufuncs that operate on inputs
        that all have the same number of core dimensions and with outputs that have no core
        dimensions, i.e., with signatures like ``"(i),(i)->()"`` or ``"(m,m)->()"``.
        If used, the location of the dimensions in the output can be controlled with axes
        and axis.
    output_dtypes : Optional, dtype or list of dtypes, keyword only
        Valid numpy dtype specification or list thereof.
        If not given, a call of ``func`` with a small set of data
        is performed in order to try to automatically determine the
        output dtypes.
    output_sizes : dict, optional, keyword only
        Optional mapping from dimension names to sizes for outputs. Only used if
        new core dimensions (not found on inputs) appear on outputs.
    vectorize: bool, keyword only
        If set to ``True``, ``np.vectorize`` is applied to ``func`` for
        convenience. Defaults to ``False``.
    allow_rechunk: Optional, bool, keyword only
        Allows rechunking, otherwise chunk sizes need to match and core
        dimensions are to consist only of one chunk.
        Warning: enabling this can increase memory usage significantly.
        Defaults to ``False``.
    meta: Optional, tuple, keyword only
        tuple of empty ndarrays describing the shape and dtype of the output of the gufunc.
        Defaults to ``None``.

    Returns
    -------
    Decorator for `pyfunc` that itself returns a `gufunc`.

    Raises
    ------
    TypeError
        If ``kwargs`` contains a keyword other than the options listed above.
        The message names the rejected keyword(s).

    Examples
    --------
    >>> import dask_array as da
    >>> import numpy as np
    >>> a = da.random.normal(size=(10, 20, 30), chunks=(5, 10, 30))
    >>> @da.as_gufunc("(i)->(),()", output_dtypes=(float, float))
    ... def stats(x):
    ...     return np.mean(x, axis=-1), np.std(x, axis=-1)
    >>> mean, std = stats(a)
    >>> mean.compute().shape
    (10, 20)

    >>> a = da.random.normal(size=(20, 30), chunks=(10, 30))
    >>> b = da.random.normal(size=(10, 1, 40), chunks=(5, 1, 40))
    >>> @da.as_gufunc("(i),(j)->(i,j)", output_dtypes=float, vectorize=True)
    ... def outer_product(x, y):
    ...     return np.einsum("i,j->ij", x, y)
    >>> c = outer_product(a, b)
    >>> c.compute().shape
    (10, 20, 30, 40)

    References
    ----------
    .. [1] https://numpy.org/doc/stable/reference/ufuncs.html
    .. [2] https://numpy.org/doc/stable/reference/c-api/generalized-ufuncs.html
    """
    _allowedkeys = {
        "vectorize",
        "axes",
        "axis",
        "keepdims",
        "output_sizes",
        "output_dtypes",
        "allow_rechunk",
        "meta",
    }
    # Reject unknown options as soon as ``as_gufunc`` is called, before any
    # function is decorated. Name the offending keys so typos are easy to spot.
    unexpected = kwargs.keys() - _allowedkeys
    if unexpected:
        raise TypeError(
            "Unsupported keyword argument(s) provided: "
            + ", ".join(sorted(unexpected))
        )

    def _as_gufunc(pyfunc):
        return gufunc(pyfunc, signature=signature, **kwargs)

    _as_gufunc.__doc__ = f"""
        Decorator to make ``dask.array.gufunc``
        signature: ``'{signature}'``

        Parameters
        ----------
        pyfunc : callable
            Function matching signature ``'{signature}'``.

        Returns
        -------
        ``dask.array.gufunc``
        """
    return _as_gufunc