dask-array 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (144) hide show
  1. dask_array/__init__.py +228 -0
  2. dask_array/_backends.py +76 -0
  3. dask_array/_backends_array.py +99 -0
  4. dask_array/_blockwise.py +1410 -0
  5. dask_array/_broadcast.py +272 -0
  6. dask_array/_chunk.py +445 -0
  7. dask_array/_chunk_types.py +54 -0
  8. dask_array/_collection.py +1644 -0
  9. dask_array/_concatenate.py +331 -0
  10. dask_array/_core_utils.py +1365 -0
  11. dask_array/_dispatch.py +141 -0
  12. dask_array/_einsum.py +277 -0
  13. dask_array/_expr.py +544 -0
  14. dask_array/_expr_flow.py +586 -0
  15. dask_array/_gufunc.py +805 -0
  16. dask_array/_histogram.py +617 -0
  17. dask_array/_map_blocks.py +652 -0
  18. dask_array/_new_collection.py +10 -0
  19. dask_array/_numpy_compat.py +135 -0
  20. dask_array/_overlap.py +1159 -0
  21. dask_array/_rechunk.py +1050 -0
  22. dask_array/_reshape.py +710 -0
  23. dask_array/_routines.py +102 -0
  24. dask_array/_shuffle.py +448 -0
  25. dask_array/_stack.py +264 -0
  26. dask_array/_svg.py +291 -0
  27. dask_array/_templates.py +29 -0
  28. dask_array/_test_utils.py +257 -0
  29. dask_array/_ufunc.py +385 -0
  30. dask_array/_utils.py +349 -0
  31. dask_array/_visualize.py +223 -0
  32. dask_array/_xarray.py +337 -0
  33. dask_array/core/__init__.py +34 -0
  34. dask_array/core/_blockwise_funcs.py +312 -0
  35. dask_array/core/_conversion.py +422 -0
  36. dask_array/core/_from_graph.py +97 -0
  37. dask_array/creation/__init__.py +71 -0
  38. dask_array/creation/_arange.py +121 -0
  39. dask_array/creation/_diag.py +116 -0
  40. dask_array/creation/_diagonal.py +241 -0
  41. dask_array/creation/_eye.py +103 -0
  42. dask_array/creation/_linspace.py +102 -0
  43. dask_array/creation/_mesh.py +134 -0
  44. dask_array/creation/_ones_zeros.py +454 -0
  45. dask_array/creation/_pad.py +270 -0
  46. dask_array/creation/_repeat.py +55 -0
  47. dask_array/creation/_tile.py +36 -0
  48. dask_array/creation/_tri.py +28 -0
  49. dask_array/creation/_utils.py +296 -0
  50. dask_array/fft.py +320 -0
  51. dask_array/io/__init__.py +39 -0
  52. dask_array/io/_base.py +10 -0
  53. dask_array/io/_from_array.py +257 -0
  54. dask_array/io/_from_delayed.py +95 -0
  55. dask_array/io/_from_graph.py +54 -0
  56. dask_array/io/_from_npy_stack.py +67 -0
  57. dask_array/io/_store.py +336 -0
  58. dask_array/io/_tiledb.py +159 -0
  59. dask_array/io/_to_npy_stack.py +65 -0
  60. dask_array/io/_zarr.py +449 -0
  61. dask_array/linalg/__init__.py +39 -0
  62. dask_array/linalg/_cholesky.py +234 -0
  63. dask_array/linalg/_lu.py +300 -0
  64. dask_array/linalg/_norm.py +94 -0
  65. dask_array/linalg/_qr.py +601 -0
  66. dask_array/linalg/_solve.py +349 -0
  67. dask_array/linalg/_svd.py +394 -0
  68. dask_array/linalg/_tensordot.py +334 -0
  69. dask_array/linalg/_utils.py +74 -0
  70. dask_array/manipulation/__init__.py +45 -0
  71. dask_array/manipulation/_expand.py +321 -0
  72. dask_array/manipulation/_flip.py +92 -0
  73. dask_array/manipulation/_roll.py +78 -0
  74. dask_array/manipulation/_transpose.py +309 -0
  75. dask_array/random/__init__.py +125 -0
  76. dask_array/random/_choice.py +181 -0
  77. dask_array/random/_expr.py +256 -0
  78. dask_array/random/_generator.py +441 -0
  79. dask_array/random/_random_state.py +259 -0
  80. dask_array/random/_utils.py +84 -0
  81. dask_array/reductions/__init__.py +84 -0
  82. dask_array/reductions/_arg_reduction.py +130 -0
  83. dask_array/reductions/_common.py +1082 -0
  84. dask_array/reductions/_cumulative.py +522 -0
  85. dask_array/reductions/_percentile.py +261 -0
  86. dask_array/reductions/_reduction.py +725 -0
  87. dask_array/reductions/_trace.py +56 -0
  88. dask_array/routines/__init__.py +133 -0
  89. dask_array/routines/_apply.py +84 -0
  90. dask_array/routines/_bincount.py +112 -0
  91. dask_array/routines/_broadcast.py +111 -0
  92. dask_array/routines/_coarsen.py +115 -0
  93. dask_array/routines/_diff.py +79 -0
  94. dask_array/routines/_gradient.py +158 -0
  95. dask_array/routines/_indexing.py +65 -0
  96. dask_array/routines/_insert_delete.py +132 -0
  97. dask_array/routines/_misc.py +122 -0
  98. dask_array/routines/_nonzero.py +72 -0
  99. dask_array/routines/_search.py +123 -0
  100. dask_array/routines/_select.py +113 -0
  101. dask_array/routines/_statistics.py +171 -0
  102. dask_array/routines/_topk.py +82 -0
  103. dask_array/routines/_triangular.py +74 -0
  104. dask_array/routines/_unique.py +232 -0
  105. dask_array/routines/_where.py +62 -0
  106. dask_array/slicing/__init__.py +67 -0
  107. dask_array/slicing/_basic.py +550 -0
  108. dask_array/slicing/_blocks.py +138 -0
  109. dask_array/slicing/_bool_index.py +145 -0
  110. dask_array/slicing/_setitem.py +329 -0
  111. dask_array/slicing/_squeeze.py +101 -0
  112. dask_array/slicing/_utils.py +1133 -0
  113. dask_array/slicing/_vindex.py +282 -0
  114. dask_array/stacking/__init__.py +15 -0
  115. dask_array/stacking/_block.py +83 -0
  116. dask_array/stacking/_simple.py +58 -0
  117. dask_array/templates/array.html.j2 +48 -0
  118. dask_array/tests/__init__.py +0 -0
  119. dask_array/tests/conftest.py +22 -0
  120. dask_array/tests/test_api.py +40 -0
  121. dask_array/tests/test_binary_op_chunks.py +107 -0
  122. dask_array/tests/test_coarse_slice_through_blockwise.py +362 -0
  123. dask_array/tests/test_collection.py +799 -0
  124. dask_array/tests/test_creation.py +1102 -0
  125. dask_array/tests/test_expr_flow.py +143 -0
  126. dask_array/tests/test_linalg.py +1130 -0
  127. dask_array/tests/test_map_blocks_multi_output.py +104 -0
  128. dask_array/tests/test_rechunk_pushdown.py +214 -0
  129. dask_array/tests/test_reductions.py +1091 -0
  130. dask_array/tests/test_routines.py +2853 -0
  131. dask_array/tests/test_shuffle_chunks.py +67 -0
  132. dask_array/tests/test_slice_pushdown.py +968 -0
  133. dask_array/tests/test_slice_through_blockwise.py +678 -0
  134. dask_array/tests/test_slice_through_overlap.py +366 -0
  135. dask_array/tests/test_slice_through_reshape.py +272 -0
  136. dask_array/tests/test_slicing.py +839 -0
  137. dask_array/tests/test_transpose_slice_pushdown.py +208 -0
  138. dask_array/tests/test_visualize.py +94 -0
  139. dask_array/tests/test_xarray.py +193 -0
  140. dask_array-0.1.0.dist-info/METADATA +48 -0
  141. dask_array-0.1.0.dist-info/RECORD +144 -0
  142. dask_array-0.1.0.dist-info/WHEEL +4 -0
  143. dask_array-0.1.0.dist-info/entry_points.txt +2 -0
  144. dask_array-0.1.0.dist-info/licenses/LICENSE +29 -0
dask_array/_utils.py ADDED
@@ -0,0 +1,349 @@
1
+ from __future__ import annotations
2
+
3
+ import contextlib
4
+ import numbers
5
+ import warnings
6
+ from typing import Any
7
+
8
+ import numpy as np
9
+ from numpy.exceptions import AxisError
10
+
11
+ from dask.base import is_dask_collection
12
+ from dask.utils import has_keyword
13
+
14
+
15
+ def typename(typ: Any, short: bool = False) -> str:
16
+ """Return the name of a type.
17
+
18
+ Examples
19
+ --------
20
+ >>> typename(int)
21
+ 'int'
22
+ """
23
+ if not isinstance(typ, type):
24
+ return typename(type(typ))
25
+ try:
26
+ if not typ.__module__ or typ.__module__ == "builtins":
27
+ return typ.__name__
28
+ else:
29
+ if short:
30
+ module, *_ = typ.__module__.split(".")
31
+ else:
32
+ module = typ.__module__
33
+ return f"{module}.{typ.__name__}"
34
+ except AttributeError:
35
+ return str(typ)
36
+
37
+
38
+ def is_cupy_type(x) -> bool:
39
+ """Check if x is a CuPy array type."""
40
+ return "cupy" in str(type(x))
41
+
42
+
43
+ def is_arraylike(x) -> bool:
44
+ """Is this object a numpy array or something similar?
45
+
46
+ This function tests specifically for an object that already has
47
+ array attributes (e.g. np.ndarray, dask.array.Array, cupy.ndarray,
48
+ sparse.COO), **NOT** for something that can be coerced into an
49
+ array object (e.g. Python lists and tuples).
50
+
51
+ Examples
52
+ --------
53
+ >>> import numpy as np
54
+ >>> is_arraylike(np.ones(5))
55
+ True
56
+ >>> is_arraylike(np.ones(()))
57
+ True
58
+ >>> is_arraylike(5)
59
+ False
60
+ >>> is_arraylike('cat')
61
+ False
62
+ """
63
+ is_duck_array = hasattr(x, "__array_function__") or hasattr(x, "__array_ufunc__")
64
+
65
+ return bool(
66
+ hasattr(x, "shape")
67
+ and isinstance(x.shape, tuple)
68
+ and hasattr(x, "dtype")
69
+ and not any(is_dask_collection(n) for n in x.shape)
70
+ # We special case scipy.sparse and cupyx.scipy.sparse arrays as having partial
71
+ # support for them is useful in scenarios where we mostly call `map_partitions`
72
+ # or `map_blocks` with scikit-learn functions on dask arrays and dask dataframes.
73
+ and (is_duck_array or "scipy.sparse" in typename(type(x)))
74
+ )
75
+
76
+
77
+ def meta_from_array(x, ndim=None, dtype=None):
78
+ """Normalize an array to appropriate meta object.
79
+
80
+ Parameters
81
+ ----------
82
+ x: array-like, callable
83
+ Either an object that looks sufficiently like a Numpy array,
84
+ or a callable that accepts shape and dtype keywords
85
+ ndim: int
86
+ Number of dimensions of the array
87
+ dtype: Numpy dtype
88
+ A valid input for ``np.dtype``
89
+
90
+ Returns
91
+ -------
92
+ array-like with zero elements of the correct dtype
93
+ """
94
+ # If using x._meta, x must be a Dask Array, some libraries (e.g. zarr)
95
+ # implement a _meta attribute that are incompatible with Dask Array._meta
96
+ if hasattr(x, "_meta") and is_dask_collection(x) and is_arraylike(x):
97
+ x = x._meta
98
+
99
+ if dtype is None and x is None:
100
+ raise ValueError("You must specify the meta or dtype of the array")
101
+
102
+ if np.isscalar(x):
103
+ x = np.array(x)
104
+
105
+ if x is None:
106
+ x = np.ndarray
107
+ elif dtype is None and hasattr(x, "dtype"):
108
+ dtype = x.dtype
109
+
110
+ if isinstance(x, type):
111
+ x = x(shape=(0,) * (ndim or 0), dtype=dtype)
112
+
113
+ if isinstance(x, (list, tuple)):
114
+ ndims = [(0 if isinstance(a, numbers.Number) else a.ndim if hasattr(a, "ndim") else len(a)) for a in x]
115
+ a = [a if nd == 0 else meta_from_array(a, nd) for a, nd in zip(x, ndims)]
116
+ return a if isinstance(x, list) else tuple(x)
117
+
118
+ if not hasattr(x, "shape") or not hasattr(x, "dtype") or not isinstance(x.shape, tuple):
119
+ return x
120
+
121
+ if ndim is None:
122
+ ndim = x.ndim
123
+
124
+ try:
125
+ meta = x[tuple(slice(0, 0, None) for _ in range(x.ndim))]
126
+ if meta.ndim != ndim:
127
+ if ndim > x.ndim:
128
+ meta = meta[(Ellipsis,) + tuple(None for _ in range(ndim - meta.ndim))]
129
+ meta = meta[tuple(slice(0, 0, None) for _ in range(meta.ndim))]
130
+ elif ndim == 0:
131
+ meta = meta.sum()
132
+ else:
133
+ meta = meta.reshape((0,) * ndim)
134
+ if meta is np.ma.masked:
135
+ meta = np.ma.array(np.empty((0,) * ndim, dtype=dtype or x.dtype), mask=True)
136
+ except Exception:
137
+ meta = np.empty((0,) * ndim, dtype=dtype or x.dtype)
138
+
139
+ if np.isscalar(meta):
140
+ meta = np.array(meta)
141
+
142
+ if dtype and meta.dtype != dtype:
143
+ try:
144
+ meta = meta.astype(dtype)
145
+ except ValueError as e:
146
+ if (
147
+ any(
148
+ s in str(e)
149
+ for s in [
150
+ "invalid literal",
151
+ "could not convert string to float",
152
+ ]
153
+ )
154
+ and meta.dtype.kind in "SU"
155
+ ):
156
+ meta = np.array([]).astype(dtype)
157
+ else:
158
+ raise e
159
+
160
+ return meta
161
+
162
+
163
+ def validate_axis(axis, ndim):
164
+ """Validate an input to axis= keywords."""
165
+ if isinstance(axis, (tuple, list)):
166
+ return tuple(validate_axis(ax, ndim) for ax in axis)
167
+ if not isinstance(axis, numbers.Integral):
168
+ raise TypeError(f"Axis value must be an integer, got {axis}")
169
+ if axis < -ndim or axis >= ndim:
170
+ raise AxisError(f"Axis {axis} is out of bounds for array of dimension {ndim}")
171
+ if axis < 0:
172
+ axis += ndim
173
+ return axis
174
+
175
+
176
+ def arange_safe(*args, like, **kwargs):
177
+ """Use the `like=` from `np.arange` to create a new array dispatching
178
+ to the downstream library. If that fails, falls back to the
179
+ default NumPy behavior, resulting in a `numpy.ndarray`.
180
+ """
181
+ if like is None:
182
+ return np.arange(*args, **kwargs)
183
+ else:
184
+ try:
185
+ return np.arange(*args, like=meta_from_array(like), **kwargs)
186
+ except TypeError:
187
+ return np.arange(*args, **kwargs)
188
+
189
+
190
+ def _array_like_safe(np_func, da_func, a, like, **kwargs):
191
+ """Helper for array_safe, asarray_safe, asanyarray_safe."""
192
+ from dask_array._collection import Array
193
+
194
+ if like is a and hasattr(a, "__array_function__"):
195
+ return a
196
+
197
+ if isinstance(like, Array):
198
+ return da_func(a, **kwargs)
199
+
200
+ if isinstance(a, Array) and is_cupy_type(a._meta):
201
+ a = a.compute(scheduler="sync")
202
+
203
+ if hasattr(like, "__array_function__"):
204
+ return np_func(a, like=like, **kwargs)
205
+
206
+ if type(like).__module__.startswith("scipy.sparse"):
207
+ # e.g. scipy.sparse.csr_matrix
208
+ kwargs.pop("order", None)
209
+ if np.isscalar(a):
210
+ a = np.array([[a]])
211
+ return type(like)(a, **kwargs)
212
+
213
+ # Unknown namespace with no __array_function__ support.
214
+ # Quietly disregard like= parameter.
215
+ return np_func(a, **kwargs)
216
+
217
+
218
+ def array_safe(a, like, **kwargs):
219
+ """If `a` is `dask_array.Array`, return `dask_array.asarray(a, **kwargs)`,
220
+ otherwise return `np.asarray(a, like=like, **kwargs)`, dispatching
221
+ the call to the library that implements the like array.
222
+ """
223
+ from dask_array.core import array
224
+
225
+ return _array_like_safe(np.array, array, a, like, **kwargs)
226
+
227
+
228
+ def asarray_safe(a, like, **kwargs):
229
+ """If a is dask_array.Array, return dask_array.asarray(a, **kwargs),
230
+ otherwise return np.asarray(a, like=like, **kwargs), dispatching
231
+ the call to the library that implements the like array.
232
+ """
233
+ from dask_array.core import asarray
234
+
235
+ return _array_like_safe(np.asarray, asarray, a, like, **kwargs)
236
+
237
+
238
+ def asanyarray_safe(a, like, **kwargs):
239
+ """If a is dask_array.Array, return dask_array.asanyarray(a, **kwargs),
240
+ otherwise return np.asanyarray(a, like=like, **kwargs), dispatching
241
+ the call to the library that implements the like array.
242
+ """
243
+ from dask_array.core import asanyarray
244
+
245
+ return _array_like_safe(np.asanyarray, asanyarray, a, like, **kwargs)
246
+
247
+
248
+ def svd_flip(u, v, u_based_decision=False):
249
+ """Sign correction to ensure deterministic output from SVD.
250
+
251
+ This function is useful for orienting eigenvectors such that
252
+ they all lie in a shared but arbitrary half-space. This makes
253
+ it possible to ensure that results are equivalent across SVD
254
+ implementations and random number generator states.
255
+
256
+ Parameters
257
+ ----------
258
+ u : (M, K) array_like
259
+ Left singular vectors (in columns)
260
+ v : (K, N) array_like
261
+ Right singular vectors (in rows)
262
+ u_based_decision: bool
263
+ Whether or not to choose signs based
264
+ on `u` rather than `v`, by default False
265
+
266
+ Returns
267
+ -------
268
+ u : (M, K) array_like
269
+ Left singular vectors with corrected sign
270
+ v: (K, N) array_like
271
+ Right singular vectors with corrected sign
272
+ """
273
+ if u_based_decision:
274
+ dtype = u.dtype
275
+ signs = np.sum(u, axis=0, keepdims=True)
276
+ else:
277
+ dtype = v.dtype
278
+ signs = np.sum(v, axis=1, keepdims=True).T
279
+ signs = 2.0 * ((signs >= 0) - 0.5).astype(dtype)
280
+ u, v = u * signs, v * signs.T
281
+ return u, v
282
+
283
+
284
+ def solve_triangular_safe(a, b, lower=False):
285
+ """Solve triangular system using scipy.linalg (or cupyx for GPU)."""
286
+ if is_cupy_type(a):
287
+ import cupyx.scipy.linalg
288
+
289
+ return cupyx.scipy.linalg.solve_triangular(a, b, lower=lower)
290
+ else:
291
+ import scipy.linalg
292
+
293
+ return scipy.linalg.solve_triangular(a, b, lower=lower)
294
+
295
+
296
+ def compute_meta(func, _dtype, *args, **kwargs):
297
+ """Compute metadata for an operation."""
298
+ from dask_array._expr import ArrayExpr
299
+
300
+ with np.errstate(all="ignore"), warnings.catch_warnings():
301
+ warnings.simplefilter("ignore", category=RuntimeWarning)
302
+
303
+ args_meta = [
304
+ (x._meta if isinstance(x, ArrayExpr) else meta_from_array(x) if is_arraylike(x) else x) for x in args
305
+ ]
306
+ kwargs_meta = {
307
+ k: (v._meta if isinstance(v, ArrayExpr) else meta_from_array(v) if is_arraylike(v) else v)
308
+ for k, v in kwargs.items()
309
+ }
310
+
311
+ # todo: look for alternative to this, causes issues when using map_blocks()
312
+ # with np.vectorize, such as dask.array.routines._isnonzero_vec().
313
+ if isinstance(func, np.vectorize):
314
+ meta = func(*args_meta)
315
+ else:
316
+ try:
317
+ # some reduction functions need to know they are computing meta
318
+ if has_keyword(func, "computing_meta"):
319
+ kwargs_meta["computing_meta"] = True
320
+ meta = func(*args_meta, **kwargs_meta)
321
+ except TypeError as e:
322
+ if any(
323
+ s in str(e)
324
+ for s in [
325
+ "unexpected keyword argument",
326
+ "is an invalid keyword for",
327
+ "Did not understand the following kwargs",
328
+ ]
329
+ ):
330
+ raise
331
+ else:
332
+ return None
333
+ except ValueError as e:
334
+ # min/max functions have no identity, just use the same input type when there's only one
335
+ if len(args_meta) == 1 and "zero-size array to reduction operation" in str(e):
336
+ meta = args_meta[0]
337
+ else:
338
+ return None
339
+ except Exception:
340
+ return None
341
+
342
+ if _dtype and getattr(meta, "dtype", None) != _dtype:
343
+ with contextlib.suppress(AttributeError):
344
+ meta = meta.astype(_dtype)
345
+
346
+ if np.isscalar(meta):
347
+ meta = np.array(meta)
348
+
349
+ return meta
@@ -0,0 +1,223 @@
1
+ """Rich-based visualization for array expressions."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import io
6
+ import math
7
+ from math import isnan, nan, prod
8
+
9
+ from dask.utils import funcname
10
+
11
+ # Color coding using Tango palette for readability
12
+ # Orange (warm) = sources, new data entering the computation
13
+ # Blue (cool) = reducers, data being reduced/consumed
14
+ SOURCE_COLOR = "#ce5c00" # Tango orange dark
15
+ REDUCER_COLOR = "#3465a4" # Tango sky blue
16
+
17
+
18
+ def format_bytes(nbytes: float) -> str:
19
+ """Format bytes with 2 significant figures."""
20
+ if math.isnan(nbytes):
21
+ return "?"
22
+
23
+ for unit, threshold in [
24
+ ("PiB", 2**50),
25
+ ("TiB", 2**40),
26
+ ("GiB", 2**30),
27
+ ("MiB", 2**20),
28
+ ("kiB", 2**10),
29
+ ]:
30
+ if nbytes >= threshold:
31
+ value = nbytes / threshold
32
+ if value >= 10:
33
+ return f"{value:.0f} {unit}"
34
+ else:
35
+ return f"{value:.1f} {unit}"
36
+ return f"{int(nbytes)} B"
37
+
38
+
39
+ class ExprTable:
40
+ """Wrapper for rich Table with Jupyter and terminal display support."""
41
+
42
+ def __init__(self, table):
43
+ self._table = table
44
+ self._html_cache = None
45
+ self._text_cache = None
46
+
47
+ def _repr_html_(self):
48
+ """Jupyter notebook display."""
49
+ if self._html_cache is None:
50
+ from rich.console import Console
51
+
52
+ console = Console(file=io.StringIO(), record=True, width=120, force_jupyter=False)
53
+ console.print(self._table)
54
+ self._html_cache = console.export_html(inline_styles=True, code_format="<pre>{code}</pre>")
55
+ return self._html_cache
56
+
57
+ def __repr__(self):
58
+ """Terminal display."""
59
+ if self._text_cache is None:
60
+ from rich.console import Console
61
+
62
+ console = Console(file=io.StringIO(), force_terminal=True, force_jupyter=False, width=120)
63
+ console.print(self._table)
64
+ self._text_cache = console.file.getvalue().rstrip()
65
+ return self._text_cache
66
+
67
+ def __str__(self):
68
+ return self.__repr__()
69
+
70
+ def print(self):
71
+ """Print to the current console."""
72
+ from rich.console import Console
73
+
74
+ Console().print(self._table)
75
+
76
+
77
+ def _walk_expr(expr, prefix: str = "", is_last: bool = True):
78
+ """Walk expression tree depth-first, yielding (expr, display_prefix) pairs."""
79
+ yield expr, prefix
80
+
81
+ deps = [op for op in expr.dependencies() if hasattr(op, "chunks")]
82
+
83
+ for i, child in enumerate(deps):
84
+ is_last_child = i == len(deps) - 1
85
+ if prefix == "":
86
+ child_prefix = ""
87
+ else:
88
+ child_prefix = prefix[:-2] + (" " if is_last else "│ ")
89
+ branch = "└ " if is_last_child else "├ "
90
+ yield from _walk_expr(child, child_prefix + branch, is_last_child)
91
+
92
+
93
+ def _compute_row_emphasis(values: list[float], threshold: float = 0.5) -> list[bool]:
94
+ """Compute which rows should be emphasized based on relative values."""
95
+ valid_values = [v for v in values if not math.isnan(v)]
96
+ if not valid_values:
97
+ return [True] * len(values)
98
+
99
+ max_value = max(valid_values)
100
+ if max_value <= 0:
101
+ return [True] * len(values)
102
+
103
+ return [not math.isnan(v) and v > threshold * max_value for v in values]
104
+
105
+
106
+ def _get_op_display_name(node, use_label_for: frozenset) -> str:
107
+ """Get the display name for an operation."""
108
+ class_name = funcname(type(node))
109
+
110
+ if class_name not in use_label_for:
111
+ return class_name
112
+
113
+ # Extract prefix from _name (everything before the hash)
114
+ expr_name = node._name
115
+ if "-" in expr_name:
116
+ parts = expr_name.rsplit("-", 1)
117
+ if len(parts) == 2 and len(parts[1]) >= 8:
118
+ label = parts[0]
119
+ label = label.replace("_", " ")
120
+ for suffix in ["-aggregate", "-partial"]:
121
+ if suffix in label:
122
+ label = label.replace(suffix, "")
123
+ label = label.replace("-", " ").strip()
124
+ return label.title()
125
+
126
+ return class_name
127
+
128
+
129
+ def _get_op_color(node) -> str | None:
130
+ """Determine operation color based on class hierarchy and data flow."""
131
+ from dask_array._expr import ArrayExpr
132
+ from dask_array.reductions._reduction import PartialReduce
133
+ from dask_array.slicing._basic import Slice
134
+
135
+ # Sources: no ArrayExpr dependencies (data enters here)
136
+ deps = [op for op in node.operands if isinstance(op, ArrayExpr)]
137
+ if not deps:
138
+ return SOURCE_COLOR
139
+
140
+ # Reducers: PartialReduce or Slice subclasses (data shrinks here)
141
+ if isinstance(node, (PartialReduce, Slice)):
142
+ return REDUCER_COLOR
143
+
144
+ return None
145
+
146
+
147
+ def _get_nbytes(node) -> float:
148
+ """Get the number of bytes for an expression, or NaN if unknown."""
149
+ try:
150
+ shape = node.shape
151
+ if any(isnan(s) for s in shape):
152
+ return nan
153
+ return prod(shape) * node.dtype.itemsize
154
+ except Exception:
155
+ return nan
156
+
157
+
158
+ # Operations where we prefer showing the _name prefix as the primary name
159
+ _USE_LABEL_AS_NAME = frozenset({"Blockwise", "PartialReduce", "Elemwise", "Random", "SliceSlicesIntegers"})
160
+
161
+
162
+ def expr_table(expr, color: bool = True) -> ExprTable:
163
+ """
164
+ Display expression tree as a table.
165
+
166
+ Parameters
167
+ ----------
168
+ expr : ArrayExpr
169
+ The expression to visualize
170
+ color : bool
171
+ Whether to color-code operations by type
172
+
173
+ Returns
174
+ -------
175
+ ExprTable
176
+ A displayable table object (works in Jupyter and terminal)
177
+ """
178
+ from rich.table import Table
179
+ from rich.text import Text
180
+
181
+ table = Table(
182
+ show_header=True,
183
+ header_style="dim",
184
+ box=None,
185
+ padding=(0, 2),
186
+ collapse_padding=True,
187
+ )
188
+
189
+ table.add_column("Operation", no_wrap=True)
190
+ table.add_column("Shape", justify="right", no_wrap=True)
191
+ table.add_column("Bytes", justify="right", no_wrap=True)
192
+ table.add_column("Chunks", justify="right", no_wrap=True)
193
+
194
+ # Collect nodes and compute emphasis based on bytes
195
+ nodes_and_prefixes = list(_walk_expr(expr))
196
+ node_bytes = [_get_nbytes(node) for node, _ in nodes_and_prefixes]
197
+ row_emphasis = _compute_row_emphasis(node_bytes)
198
+
199
+ for (node, prefix), nbytes, emphasize in zip(nodes_and_prefixes, node_bytes, row_emphasis):
200
+ display_name = _get_op_display_name(node, _USE_LABEL_AS_NAME)
201
+ data_style = None if color and emphasize else "dim"
202
+
203
+ if color:
204
+ op_color = _get_op_color(node)
205
+ op_style = f"bold {op_color}" if op_color else "bold"
206
+ op_text = Text()
207
+ op_text.append(prefix, style="dim")
208
+ op_text.append(display_name, style=op_style)
209
+ else:
210
+ op_text = f"{prefix}{display_name}"
211
+
212
+ # Format shape and chunks
213
+ shape_str = "()" if not node.shape else f"({', '.join(str(s) for s in node.shape)})"
214
+ chunks_str = "×".join(str(c[0] if c else 0) for c in node.chunks) if node.chunks else ""
215
+
216
+ table.add_row(
217
+ op_text,
218
+ Text(shape_str, style=data_style),
219
+ Text(format_bytes(nbytes), style=data_style),
220
+ Text(chunks_str, style=data_style),
221
+ )
222
+
223
+ return ExprTable(table)