dask-array 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dask_array/__init__.py +228 -0
- dask_array/_backends.py +76 -0
- dask_array/_backends_array.py +99 -0
- dask_array/_blockwise.py +1410 -0
- dask_array/_broadcast.py +272 -0
- dask_array/_chunk.py +445 -0
- dask_array/_chunk_types.py +54 -0
- dask_array/_collection.py +1644 -0
- dask_array/_concatenate.py +331 -0
- dask_array/_core_utils.py +1365 -0
- dask_array/_dispatch.py +141 -0
- dask_array/_einsum.py +277 -0
- dask_array/_expr.py +544 -0
- dask_array/_expr_flow.py +586 -0
- dask_array/_gufunc.py +805 -0
- dask_array/_histogram.py +617 -0
- dask_array/_map_blocks.py +652 -0
- dask_array/_new_collection.py +10 -0
- dask_array/_numpy_compat.py +135 -0
- dask_array/_overlap.py +1159 -0
- dask_array/_rechunk.py +1050 -0
- dask_array/_reshape.py +710 -0
- dask_array/_routines.py +102 -0
- dask_array/_shuffle.py +448 -0
- dask_array/_stack.py +264 -0
- dask_array/_svg.py +291 -0
- dask_array/_templates.py +29 -0
- dask_array/_test_utils.py +257 -0
- dask_array/_ufunc.py +385 -0
- dask_array/_utils.py +349 -0
- dask_array/_visualize.py +223 -0
- dask_array/_xarray.py +337 -0
- dask_array/core/__init__.py +34 -0
- dask_array/core/_blockwise_funcs.py +312 -0
- dask_array/core/_conversion.py +422 -0
- dask_array/core/_from_graph.py +97 -0
- dask_array/creation/__init__.py +71 -0
- dask_array/creation/_arange.py +121 -0
- dask_array/creation/_diag.py +116 -0
- dask_array/creation/_diagonal.py +241 -0
- dask_array/creation/_eye.py +103 -0
- dask_array/creation/_linspace.py +102 -0
- dask_array/creation/_mesh.py +134 -0
- dask_array/creation/_ones_zeros.py +454 -0
- dask_array/creation/_pad.py +270 -0
- dask_array/creation/_repeat.py +55 -0
- dask_array/creation/_tile.py +36 -0
- dask_array/creation/_tri.py +28 -0
- dask_array/creation/_utils.py +296 -0
- dask_array/fft.py +320 -0
- dask_array/io/__init__.py +39 -0
- dask_array/io/_base.py +10 -0
- dask_array/io/_from_array.py +257 -0
- dask_array/io/_from_delayed.py +95 -0
- dask_array/io/_from_graph.py +54 -0
- dask_array/io/_from_npy_stack.py +67 -0
- dask_array/io/_store.py +336 -0
- dask_array/io/_tiledb.py +159 -0
- dask_array/io/_to_npy_stack.py +65 -0
- dask_array/io/_zarr.py +449 -0
- dask_array/linalg/__init__.py +39 -0
- dask_array/linalg/_cholesky.py +234 -0
- dask_array/linalg/_lu.py +300 -0
- dask_array/linalg/_norm.py +94 -0
- dask_array/linalg/_qr.py +601 -0
- dask_array/linalg/_solve.py +349 -0
- dask_array/linalg/_svd.py +394 -0
- dask_array/linalg/_tensordot.py +334 -0
- dask_array/linalg/_utils.py +74 -0
- dask_array/manipulation/__init__.py +45 -0
- dask_array/manipulation/_expand.py +321 -0
- dask_array/manipulation/_flip.py +92 -0
- dask_array/manipulation/_roll.py +78 -0
- dask_array/manipulation/_transpose.py +309 -0
- dask_array/random/__init__.py +125 -0
- dask_array/random/_choice.py +181 -0
- dask_array/random/_expr.py +256 -0
- dask_array/random/_generator.py +441 -0
- dask_array/random/_random_state.py +259 -0
- dask_array/random/_utils.py +84 -0
- dask_array/reductions/__init__.py +84 -0
- dask_array/reductions/_arg_reduction.py +130 -0
- dask_array/reductions/_common.py +1082 -0
- dask_array/reductions/_cumulative.py +522 -0
- dask_array/reductions/_percentile.py +261 -0
- dask_array/reductions/_reduction.py +725 -0
- dask_array/reductions/_trace.py +56 -0
- dask_array/routines/__init__.py +133 -0
- dask_array/routines/_apply.py +84 -0
- dask_array/routines/_bincount.py +112 -0
- dask_array/routines/_broadcast.py +111 -0
- dask_array/routines/_coarsen.py +115 -0
- dask_array/routines/_diff.py +79 -0
- dask_array/routines/_gradient.py +158 -0
- dask_array/routines/_indexing.py +65 -0
- dask_array/routines/_insert_delete.py +132 -0
- dask_array/routines/_misc.py +122 -0
- dask_array/routines/_nonzero.py +72 -0
- dask_array/routines/_search.py +123 -0
- dask_array/routines/_select.py +113 -0
- dask_array/routines/_statistics.py +171 -0
- dask_array/routines/_topk.py +82 -0
- dask_array/routines/_triangular.py +74 -0
- dask_array/routines/_unique.py +232 -0
- dask_array/routines/_where.py +62 -0
- dask_array/slicing/__init__.py +67 -0
- dask_array/slicing/_basic.py +550 -0
- dask_array/slicing/_blocks.py +138 -0
- dask_array/slicing/_bool_index.py +145 -0
- dask_array/slicing/_setitem.py +329 -0
- dask_array/slicing/_squeeze.py +101 -0
- dask_array/slicing/_utils.py +1133 -0
- dask_array/slicing/_vindex.py +282 -0
- dask_array/stacking/__init__.py +15 -0
- dask_array/stacking/_block.py +83 -0
- dask_array/stacking/_simple.py +58 -0
- dask_array/templates/array.html.j2 +48 -0
- dask_array/tests/__init__.py +0 -0
- dask_array/tests/conftest.py +22 -0
- dask_array/tests/test_api.py +40 -0
- dask_array/tests/test_binary_op_chunks.py +107 -0
- dask_array/tests/test_coarse_slice_through_blockwise.py +362 -0
- dask_array/tests/test_collection.py +799 -0
- dask_array/tests/test_creation.py +1102 -0
- dask_array/tests/test_expr_flow.py +143 -0
- dask_array/tests/test_linalg.py +1130 -0
- dask_array/tests/test_map_blocks_multi_output.py +104 -0
- dask_array/tests/test_rechunk_pushdown.py +214 -0
- dask_array/tests/test_reductions.py +1091 -0
- dask_array/tests/test_routines.py +2853 -0
- dask_array/tests/test_shuffle_chunks.py +67 -0
- dask_array/tests/test_slice_pushdown.py +968 -0
- dask_array/tests/test_slice_through_blockwise.py +678 -0
- dask_array/tests/test_slice_through_overlap.py +366 -0
- dask_array/tests/test_slice_through_reshape.py +272 -0
- dask_array/tests/test_slicing.py +839 -0
- dask_array/tests/test_transpose_slice_pushdown.py +208 -0
- dask_array/tests/test_visualize.py +94 -0
- dask_array/tests/test_xarray.py +193 -0
- dask_array-0.1.0.dist-info/METADATA +48 -0
- dask_array-0.1.0.dist-info/RECORD +144 -0
- dask_array-0.1.0.dist-info/WHEEL +4 -0
- dask_array-0.1.0.dist-info/entry_points.txt +2 -0
- dask_array-0.1.0.dist-info/licenses/LICENSE +29 -0
dask_array/_utils.py
ADDED
|
@@ -0,0 +1,349 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import contextlib
|
|
4
|
+
import numbers
|
|
5
|
+
import warnings
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
from numpy.exceptions import AxisError
|
|
10
|
+
|
|
11
|
+
from dask.base import is_dask_collection
|
|
12
|
+
from dask.utils import has_keyword
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def typename(typ: Any, short: bool = False) -> str:
    """Return the name of a type.

    Parameters
    ----------
    typ
        A class, or an instance whose class should be named.
    short
        If True, keep only the first component of the defining module
        (``"pkg.sub.Cls"`` becomes ``"pkg.Cls"``).

    Returns
    -------
    str
        The bare class name when ``__module__`` is empty or ``"builtins"``,
        otherwise ``"<module>.<name>"``. Falls back to ``str(typ)`` if the
        type lacks ``__module__`` or ``__name__``.

    Examples
    --------
    >>> typename(int)
    'int'
    """
    if not isinstance(typ, type):
        # Instances are named after their class. ``short`` is forwarded so the
        # flag is honoured for instances as well as for classes.
        return typename(type(typ), short=short)
    try:
        module = typ.__module__
        if not module or module == "builtins":
            return typ.__name__
        if short:
            module = module.split(".")[0]
        return f"{module}.{typ.__name__}"
    except AttributeError:
        return str(typ)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def is_cupy_type(x) -> bool:
    """Report whether the string form of ``type(x)`` mentions ``"cupy"``.

    This is a purely textual check on ``str(type(x))``; CuPy itself is never
    imported.
    """
    type_repr = str(type(x))
    return "cupy" in type_repr
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def is_arraylike(x) -> bool:
    """Is this object a numpy array or something similar?

    The check is for objects that already carry array attributes
    (e.g. np.ndarray, dask.array.Array, cupy.ndarray, sparse.COO),
    **NOT** for objects that could merely be coerced into an array
    (e.g. Python lists and tuples).

    Examples
    --------
    >>> import numpy as np
    >>> is_arraylike(np.ones(5))
    True
    >>> is_arraylike(np.ones(()))
    True
    >>> is_arraylike(5)
    False
    >>> is_arraylike('cat')
    False
    """
    # Must expose a tuple ``shape`` and a ``dtype``.
    if not hasattr(x, "shape") or not isinstance(x.shape, tuple):
        return False
    if not hasattr(x, "dtype"):
        return False

    # A shape containing dask collections is rejected.
    if any(is_dask_collection(dim) for dim in x.shape):
        return False

    if hasattr(x, "__array_function__") or hasattr(x, "__array_ufunc__"):
        return True

    # We special case scipy.sparse and cupyx.scipy.sparse arrays as having partial
    # support for them is useful in scenarios where we mostly call `map_partitions`
    # or `map_blocks` with scikit-learn functions on dask arrays and dask dataframes.
    return "scipy.sparse" in typename(type(x))
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def meta_from_array(x, ndim=None, dtype=None):
    """Normalize an array to appropriate meta object.

    Parameters
    ----------
    x: array-like, callable
        Either an object that looks sufficiently like a Numpy array,
        or a callable that accepts shape and dtype keywords
    ndim: int
        Number of dimensions of the array
    dtype: Numpy dtype
        A valid input for ``np.dtype``

    Returns
    -------
    array-like with zero elements of the correct dtype

    For a ``list`` or ``tuple`` input, a container of the same kind is
    returned in which every non-scalar element has been replaced by its own
    meta (numbers are passed through unchanged). Objects lacking a tuple
    ``shape`` or a ``dtype`` are returned as-is.

    Raises
    ------
    ValueError
        If both ``x`` and ``dtype`` are None, or if casting the meta to
        ``dtype`` fails for a reason other than string-to-number parsing.
    """
    # If using x._meta, x must be a Dask Array, some libraries (e.g. zarr)
    # implement a _meta attribute that are incompatible with Dask Array._meta
    if hasattr(x, "_meta") and is_dask_collection(x) and is_arraylike(x):
        x = x._meta

    if dtype is None and x is None:
        raise ValueError("You must specify the meta or dtype of the array")

    if np.isscalar(x):
        x = np.array(x)

    if x is None:
        x = np.ndarray
    elif dtype is None and hasattr(x, "dtype"):
        dtype = x.dtype

    if isinstance(x, type):
        # An array class: instantiate an empty array of the requested rank.
        x = x(shape=(0,) * (ndim or 0), dtype=dtype)

    if isinstance(x, (list, tuple)):
        ndims = [
            (0 if isinstance(a, numbers.Number) else a.ndim if hasattr(a, "ndim") else len(a))
            for a in x
        ]
        metas = [a if nd == 0 else meta_from_array(a, nd) for a, nd in zip(x, ndims)]
        # Return the converted elements in the same container kind as the
        # input (previously tuples were returned unconverted).
        return metas if isinstance(x, list) else tuple(metas)

    if not hasattr(x, "shape") or not hasattr(x, "dtype") or not isinstance(x.shape, tuple):
        return x

    if ndim is None:
        ndim = x.ndim

    try:
        # Zero-length slice along every axis keeps the array type and dtype.
        meta = x[tuple(slice(0, 0, None) for _ in range(x.ndim))]
        if meta.ndim != ndim:
            if ndim > x.ndim:
                # Append new axes, then empty them out again.
                meta = meta[(Ellipsis,) + tuple(None for _ in range(ndim - meta.ndim))]
                meta = meta[tuple(slice(0, 0, None) for _ in range(meta.ndim))]
            elif ndim == 0:
                meta = meta.sum()
            else:
                meta = meta.reshape((0,) * ndim)
        if meta is np.ma.masked:
            meta = np.ma.array(np.empty((0,) * ndim, dtype=dtype or x.dtype), mask=True)
    except Exception:
        # Best effort: any failure above falls back to an empty NumPy array.
        meta = np.empty((0,) * ndim, dtype=dtype or x.dtype)

    if np.isscalar(meta):
        meta = np.array(meta)

    # NOTE(review): this relies on the truthiness of ``dtype``; confirm that
    # is the intended test for every kind of dtype argument callers pass.
    if dtype and meta.dtype != dtype:
        try:
            meta = meta.astype(dtype)
        except ValueError as e:
            if (
                any(
                    s in str(e)
                    for s in [
                        "invalid literal",
                        "could not convert string to float",
                    ]
                )
                and meta.dtype.kind in "SU"
            ):
                # String meta that cannot be parsed as the target dtype:
                # start from a fresh empty array instead.
                meta = np.array([]).astype(dtype)
            else:
                raise e

    return meta
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
def validate_axis(axis, ndim):
|
|
164
|
+
"""Validate an input to axis= keywords."""
|
|
165
|
+
if isinstance(axis, (tuple, list)):
|
|
166
|
+
return tuple(validate_axis(ax, ndim) for ax in axis)
|
|
167
|
+
if not isinstance(axis, numbers.Integral):
|
|
168
|
+
raise TypeError(f"Axis value must be an integer, got {axis}")
|
|
169
|
+
if axis < -ndim or axis >= ndim:
|
|
170
|
+
raise AxisError(f"Axis {axis} is out of bounds for array of dimension {ndim}")
|
|
171
|
+
if axis < 0:
|
|
172
|
+
axis += ndim
|
|
173
|
+
return axis
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
def arange_safe(*args, like, **kwargs):
    """Call ``np.arange``, dispatching through ``like=`` when possible.

    With ``like=None`` this is a plain ``np.arange(*args, **kwargs)``.
    Otherwise the meta of ``like`` is passed as ``np.arange``'s ``like=``
    argument; if that raises ``TypeError`` the call is repeated without
    ``like=``, resulting in a `numpy.ndarray`.
    """
    if like is None:
        return np.arange(*args, **kwargs)

    try:
        like_meta = meta_from_array(like)
        return np.arange(*args, like=like_meta, **kwargs)
    except TypeError:
        # Fall back to default NumPy behaviour.
        return np.arange(*args, **kwargs)
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
def _array_like_safe(np_func, da_func, a, like, **kwargs):
    """Helper for array_safe, asarray_safe, asanyarray_safe.

    Parameters
    ----------
    np_func
        NumPy-style constructor used for non-dask ``like`` objects.
    da_func
        Constructor used when ``like`` is a ``dask_array`` ``Array``.
    a
        The object to convert.
    like
        Reference object that selects which constructor is used.
    **kwargs
        Forwarded to whichever constructor is called.
    """
    # Imported lazily inside the function rather than at module import time.
    from dask_array._collection import Array

    # ``a`` already is the reference object and supports dispatch: no-op.
    if like is a and hasattr(a, "__array_function__"):
        return a

    if isinstance(like, Array):
        return da_func(a, **kwargs)

    # A cupy-backed dask array is computed eagerly before conversion.
    if isinstance(a, Array) and is_cupy_type(a._meta):
        a = a.compute(scheduler="sync")

    if hasattr(like, "__array_function__"):
        return np_func(a, like=like, **kwargs)

    like_type = type(like)
    if not like_type.__module__.startswith("scipy.sparse"):
        # Unknown namespace with no __array_function__ support.
        # Quietly disregard like= parameter.
        return np_func(a, **kwargs)

    # e.g. scipy.sparse.csr_matrix
    kwargs.pop("order", None)
    if np.isscalar(a):
        a = np.array([[a]])
    return like_type(a, **kwargs)
|
|
216
|
+
|
|
217
|
+
|
|
218
|
+
def array_safe(a, like, **kwargs):
    """Convert ``a`` with ``array`` semantics, dispatching on ``like``.

    Delegates to ``_array_like_safe`` with ``np.array`` as the NumPy
    constructor and ``dask_array.core.array`` as the dask constructor, so
    the call is routed to the library that implements the ``like`` array.
    """
    from dask_array.core import array as dask_constructor

    return _array_like_safe(np.array, dask_constructor, a, like, **kwargs)
|
|
226
|
+
|
|
227
|
+
|
|
228
|
+
def asarray_safe(a, like, **kwargs):
    """Convert ``a`` with ``asarray`` semantics, dispatching on ``like``.

    Delegates to ``_array_like_safe`` with ``np.asarray`` as the NumPy
    constructor and ``dask_array.core.asarray`` as the dask constructor, so
    the call is routed to the library that implements the ``like`` array.
    """
    from dask_array.core import asarray as dask_constructor

    return _array_like_safe(np.asarray, dask_constructor, a, like, **kwargs)
|
|
236
|
+
|
|
237
|
+
|
|
238
|
+
def asanyarray_safe(a, like, **kwargs):
    """Convert ``a`` with ``asanyarray`` semantics, dispatching on ``like``.

    Delegates to ``_array_like_safe`` with ``np.asanyarray`` as the NumPy
    constructor and ``dask_array.core.asanyarray`` as the dask constructor,
    so the call is routed to the library that implements the ``like`` array.
    """
    from dask_array.core import asanyarray as dask_constructor

    return _array_like_safe(np.asanyarray, dask_constructor, a, like, **kwargs)
|
|
246
|
+
|
|
247
|
+
|
|
248
|
+
def svd_flip(u, v, u_based_decision=False):
    """Sign correction to ensure deterministic output from SVD.

    Orients the singular vectors so that they all lie in a shared but
    arbitrary half-space, which makes results comparable across SVD
    implementations and random number generator states.

    Parameters
    ----------
    u : (M, K) array_like
        Left singular vectors (in columns)
    v : (K, N) array_like
        Right singular vectors (in rows)
    u_based_decision: bool
        Whether or not to choose signs based
        on `u` rather than `v`, by default False

    Returns
    -------
    u : (M, K) array_like
        Left singular vectors with corrected sign
    v: (K, N) array_like
        Right singular vectors with corrected sign
    """
    # Sum each singular vector of the chosen factor; the result is a (1, K) row.
    if u_based_decision:
        reference = u
        totals = np.sum(u, axis=0, keepdims=True)
    else:
        reference = v
        totals = np.sum(v, axis=1, keepdims=True).T

    # Map "sum >= 0" to +1 and "sum < 0" to -1, in the reference dtype.
    signs = 2.0 * ((totals >= 0) - 0.5).astype(reference.dtype)
    return u * signs, v * signs.T
|
|
282
|
+
|
|
283
|
+
|
|
284
|
+
def solve_triangular_safe(a, b, lower=False):
    """Solve a triangular system with ``solve_triangular``.

    Uses ``cupyx.scipy.linalg`` when ``a`` is detected as a CuPy type and
    ``scipy.linalg`` otherwise. The backend is imported only when needed.
    """
    if not is_cupy_type(a):
        import scipy.linalg

        return scipy.linalg.solve_triangular(a, b, lower=lower)

    import cupyx.scipy.linalg

    return cupyx.scipy.linalg.solve_triangular(a, b, lower=lower)
|
|
294
|
+
|
|
295
|
+
|
|
296
|
+
def compute_meta(func, _dtype, *args, **kwargs):
    """Compute metadata for an operation.

    ``func`` is called on the metas of its arguments (NumPy floating-point
    errors and ``RuntimeWarning`` are silenced while doing so) and the
    result, cast to ``_dtype`` when given and possible, is returned.

    Returns
    -------
    The meta produced by ``func``, or ``None`` when ``func`` fails on the
    meta inputs. A ``TypeError`` that complains about an unexpected keyword
    argument is re-raised rather than converted to ``None``.
    """
    from dask_array._expr import ArrayExpr

    def _as_meta(obj):
        # Expressions carry their meta; other array-likes are normalised;
        # everything else is passed through untouched.
        if isinstance(obj, ArrayExpr):
            return obj._meta
        if is_arraylike(obj):
            return meta_from_array(obj)
        return obj

    with np.errstate(all="ignore"), warnings.catch_warnings():
        warnings.simplefilter("ignore", category=RuntimeWarning)

        args_meta = [_as_meta(arg) for arg in args]
        kwargs_meta = {key: _as_meta(value) for key, value in kwargs.items()}

        # todo: look for alternative to this, causes issues when using map_blocks()
        # with np.vectorize, such as dask.array.routines._isnonzero_vec().
        if isinstance(func, np.vectorize):
            meta = func(*args_meta)
        else:
            try:
                # some reduction functions need to know they are computing meta
                if has_keyword(func, "computing_meta"):
                    kwargs_meta["computing_meta"] = True
                meta = func(*args_meta, **kwargs_meta)
            except TypeError as err:
                message = str(err)
                keyword_markers = (
                    "unexpected keyword argument",
                    "is an invalid keyword for",
                    "Did not understand the following kwargs",
                )
                if any(marker in message for marker in keyword_markers):
                    raise
                return None
            except ValueError as err:
                # min/max functions have no identity, just use the same input type when there's only one
                if len(args_meta) != 1 or "zero-size array to reduction operation" not in str(err):
                    return None
                meta = args_meta[0]
            except Exception:
                return None

        if _dtype and getattr(meta, "dtype", None) != _dtype:
            # Results without ``astype`` are left as they are.
            with contextlib.suppress(AttributeError):
                meta = meta.astype(_dtype)

        if np.isscalar(meta):
            meta = np.array(meta)

        return meta
|
dask_array/_visualize.py
ADDED
|
@@ -0,0 +1,223 @@
|
|
|
1
|
+
"""Rich-based visualization for array expressions."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import io
|
|
6
|
+
import math
|
|
7
|
+
from math import isnan, nan, prod
|
|
8
|
+
|
|
9
|
+
from dask.utils import funcname
|
|
10
|
+
|
|
11
|
+
# Tango-palette colours used to tag operations in the rendered table:
#   orange (warm) -> sources, new data entering the computation
#   blue (cool)   -> reducers, data being reduced/consumed
SOURCE_COLOR = "#ce5c00"  # Tango orange dark
REDUCER_COLOR = "#3465a4"  # Tango sky blue
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def format_bytes(nbytes: float) -> str:
    """Render a byte count as a short human-readable string.

    NaN is shown as ``"?"``. Values of at least 1 kiB are scaled to the
    largest fitting binary unit (up to PiB) and printed with no decimals
    when the scaled value is >= 10, otherwise with one decimal. Smaller
    values are printed as whole bytes.
    """
    if math.isnan(nbytes):
        return "?"

    # Largest unit first so the first match is the best fit.
    units = (
        ("PiB", 2**50),
        ("TiB", 2**40),
        ("GiB", 2**30),
        ("MiB", 2**20),
        ("kiB", 2**10),
    )
    for unit, threshold in units:
        if nbytes < threshold:
            continue
        value = nbytes / threshold
        decimals = 0 if value >= 10 else 1
        return f"{value:.{decimals}f} {unit}"

    return f"{int(nbytes)} B"
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class ExprTable:
    """Wrapper for rich Table with Jupyter and terminal display support.

    The HTML and text renderings are produced on first use and cached on
    the instance. ``rich`` is imported only when a rendering is needed.
    """

    def __init__(self, table):
        self._table = table
        # Filled in lazily by ``_repr_html_`` / ``__repr__``.
        self._html_cache = None
        self._text_cache = None

    def _repr_html_(self):
        """Jupyter notebook display."""
        if self._html_cache is not None:
            return self._html_cache

        from rich.console import Console

        # Record into an in-memory buffer, then export what was printed as HTML.
        recorder = Console(file=io.StringIO(), record=True, width=120, force_jupyter=False)
        recorder.print(self._table)
        self._html_cache = recorder.export_html(inline_styles=True, code_format="<pre>{code}</pre>")
        return self._html_cache

    def __repr__(self):
        """Terminal display."""
        if self._text_cache is not None:
            return self._text_cache

        from rich.console import Console

        # Render into an in-memory buffer with terminal output forced on.
        renderer = Console(file=io.StringIO(), force_terminal=True, force_jupyter=False, width=120)
        renderer.print(self._table)
        self._text_cache = renderer.file.getvalue().rstrip()
        return self._text_cache

    def __str__(self):
        """Same text as ``__repr__``."""
        return self.__repr__()

    def print(self):
        """Print to the current console."""
        from rich.console import Console

        live_console = Console()
        live_console.print(self._table)
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def _walk_expr(expr, prefix: str = "", is_last: bool = True):
|
|
78
|
+
"""Walk expression tree depth-first, yielding (expr, display_prefix) pairs."""
|
|
79
|
+
yield expr, prefix
|
|
80
|
+
|
|
81
|
+
deps = [op for op in expr.dependencies() if hasattr(op, "chunks")]
|
|
82
|
+
|
|
83
|
+
for i, child in enumerate(deps):
|
|
84
|
+
is_last_child = i == len(deps) - 1
|
|
85
|
+
if prefix == "":
|
|
86
|
+
child_prefix = ""
|
|
87
|
+
else:
|
|
88
|
+
child_prefix = prefix[:-2] + (" " if is_last else "│ ")
|
|
89
|
+
branch = "└ " if is_last_child else "├ "
|
|
90
|
+
yield from _walk_expr(child, child_prefix + branch, is_last_child)
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def _compute_row_emphasis(values: list[float], threshold: float = 0.5) -> list[bool]:
|
|
94
|
+
"""Compute which rows should be emphasized based on relative values."""
|
|
95
|
+
valid_values = [v for v in values if not math.isnan(v)]
|
|
96
|
+
if not valid_values:
|
|
97
|
+
return [True] * len(values)
|
|
98
|
+
|
|
99
|
+
max_value = max(valid_values)
|
|
100
|
+
if max_value <= 0:
|
|
101
|
+
return [True] * len(values)
|
|
102
|
+
|
|
103
|
+
return [not math.isnan(v) and v > threshold * max_value for v in values]
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
def _get_op_display_name(node, use_label_for: frozenset) -> str:
    """Get the display name for an operation.

    The class name of ``node`` is used unless it appears in
    ``use_label_for``. In that case the part of ``node._name`` before the
    final ``-`` is turned into a title-cased label, provided the part after
    it is at least 8 characters long (treated as the hash); otherwise the
    class name is returned.
    """
    class_name = funcname(type(node))
    if class_name not in use_label_for:
        return class_name

    # Extract prefix from _name (everything before the hash)
    head, dash, tail = node._name.rpartition("-")
    if not dash or len(tail) < 8:
        return class_name

    label = head.replace("_", " ")
    for suffix in ("-aggregate", "-partial"):
        label = label.replace(suffix, "")
    return label.replace("-", " ").strip().title()
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
def _get_op_color(node) -> str | None:
    """Pick the highlight colour for an operation, if any.

    Returns ``SOURCE_COLOR`` when none of ``node.operands`` is an
    ``ArrayExpr``, ``REDUCER_COLOR`` when ``node`` is a ``PartialReduce``
    or ``Slice`` instance, and ``None`` otherwise.
    """
    from dask_array._expr import ArrayExpr
    from dask_array.reductions._reduction import PartialReduce
    from dask_array.slicing._basic import Slice

    # Sources: no ArrayExpr dependencies (data enters here)
    has_array_input = any(isinstance(op, ArrayExpr) for op in node.operands)
    if not has_array_input:
        return SOURCE_COLOR

    # Reducers: PartialReduce or Slice subclasses (data shrinks here)
    return REDUCER_COLOR if isinstance(node, (PartialReduce, Slice)) else None
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
def _get_nbytes(node) -> float:
|
|
148
|
+
"""Get the number of bytes for an expression, or NaN if unknown."""
|
|
149
|
+
try:
|
|
150
|
+
shape = node.shape
|
|
151
|
+
if any(isnan(s) for s in shape):
|
|
152
|
+
return nan
|
|
153
|
+
return prod(shape) * node.dtype.itemsize
|
|
154
|
+
except Exception:
|
|
155
|
+
return nan
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
# Operations where we prefer showing the _name prefix as the primary name
|
|
159
|
+
_USE_LABEL_AS_NAME = frozenset({"Blockwise", "PartialReduce", "Elemwise", "Random", "SliceSlicesIntegers"})
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
def expr_table(expr, color: bool = True) -> ExprTable:
    """
    Display expression tree as a table.

    One row is emitted per node of the tree, in depth-first order, with
    columns for the operation name, shape, byte size and chunk sizes.

    Parameters
    ----------
    expr : ArrayExpr
        The expression to visualize
    color : bool
        Whether to color-code operations by type

    Returns
    -------
    ExprTable
        A displayable table object (works in Jupyter and terminal)
    """
    from rich.table import Table
    from rich.text import Text

    table = Table(show_header=True, header_style="dim", box=None, padding=(0, 2), collapse_padding=True)
    table.add_column("Operation", no_wrap=True)
    for heading in ("Shape", "Bytes", "Chunks"):
        table.add_column(heading, justify="right", no_wrap=True)

    # Collect nodes and compute emphasis based on bytes
    walked = list(_walk_expr(expr))
    sizes = [_get_nbytes(node) for node, _ in walked]
    emphasis = _compute_row_emphasis(sizes)

    for (node, prefix), nbytes, emphasize in zip(walked, sizes, emphasis):
        display_name = _get_op_display_name(node, _USE_LABEL_AS_NAME)
        # Data cells are dimmed unless colouring is on and the row is emphasized.
        data_style = None if (color and emphasize) else "dim"

        if color:
            op_color = _get_op_color(node)
            op_cell = Text()
            op_cell.append(prefix, style="dim")
            op_cell.append(display_name, style=f"bold {op_color}" if op_color else "bold")
        else:
            op_cell = f"{prefix}{display_name}"

        # Format shape and chunks
        if node.shape:
            shape_str = f"({', '.join(str(s) for s in node.shape)})"
        else:
            shape_str = "()"

        if node.chunks:
            # Show the first chunk size along each axis (0 for an empty axis).
            chunks_str = "×".join(str(c[0] if c else 0) for c in node.chunks)
        else:
            chunks_str = ""

        table.add_row(
            op_cell,
            Text(shape_str, style=data_style),
            Text(format_bytes(nbytes), style=data_style),
            Text(chunks_str, style=data_style),
        )

    return ExprTable(table)
|