dask-array 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dask_array/__init__.py +228 -0
- dask_array/_backends.py +76 -0
- dask_array/_backends_array.py +99 -0
- dask_array/_blockwise.py +1410 -0
- dask_array/_broadcast.py +272 -0
- dask_array/_chunk.py +445 -0
- dask_array/_chunk_types.py +54 -0
- dask_array/_collection.py +1644 -0
- dask_array/_concatenate.py +331 -0
- dask_array/_core_utils.py +1365 -0
- dask_array/_dispatch.py +141 -0
- dask_array/_einsum.py +277 -0
- dask_array/_expr.py +544 -0
- dask_array/_expr_flow.py +586 -0
- dask_array/_gufunc.py +805 -0
- dask_array/_histogram.py +617 -0
- dask_array/_map_blocks.py +652 -0
- dask_array/_new_collection.py +10 -0
- dask_array/_numpy_compat.py +135 -0
- dask_array/_overlap.py +1159 -0
- dask_array/_rechunk.py +1050 -0
- dask_array/_reshape.py +710 -0
- dask_array/_routines.py +102 -0
- dask_array/_shuffle.py +448 -0
- dask_array/_stack.py +264 -0
- dask_array/_svg.py +291 -0
- dask_array/_templates.py +29 -0
- dask_array/_test_utils.py +257 -0
- dask_array/_ufunc.py +385 -0
- dask_array/_utils.py +349 -0
- dask_array/_visualize.py +223 -0
- dask_array/_xarray.py +337 -0
- dask_array/core/__init__.py +34 -0
- dask_array/core/_blockwise_funcs.py +312 -0
- dask_array/core/_conversion.py +422 -0
- dask_array/core/_from_graph.py +97 -0
- dask_array/creation/__init__.py +71 -0
- dask_array/creation/_arange.py +121 -0
- dask_array/creation/_diag.py +116 -0
- dask_array/creation/_diagonal.py +241 -0
- dask_array/creation/_eye.py +103 -0
- dask_array/creation/_linspace.py +102 -0
- dask_array/creation/_mesh.py +134 -0
- dask_array/creation/_ones_zeros.py +454 -0
- dask_array/creation/_pad.py +270 -0
- dask_array/creation/_repeat.py +55 -0
- dask_array/creation/_tile.py +36 -0
- dask_array/creation/_tri.py +28 -0
- dask_array/creation/_utils.py +296 -0
- dask_array/fft.py +320 -0
- dask_array/io/__init__.py +39 -0
- dask_array/io/_base.py +10 -0
- dask_array/io/_from_array.py +257 -0
- dask_array/io/_from_delayed.py +95 -0
- dask_array/io/_from_graph.py +54 -0
- dask_array/io/_from_npy_stack.py +67 -0
- dask_array/io/_store.py +336 -0
- dask_array/io/_tiledb.py +159 -0
- dask_array/io/_to_npy_stack.py +65 -0
- dask_array/io/_zarr.py +449 -0
- dask_array/linalg/__init__.py +39 -0
- dask_array/linalg/_cholesky.py +234 -0
- dask_array/linalg/_lu.py +300 -0
- dask_array/linalg/_norm.py +94 -0
- dask_array/linalg/_qr.py +601 -0
- dask_array/linalg/_solve.py +349 -0
- dask_array/linalg/_svd.py +394 -0
- dask_array/linalg/_tensordot.py +334 -0
- dask_array/linalg/_utils.py +74 -0
- dask_array/manipulation/__init__.py +45 -0
- dask_array/manipulation/_expand.py +321 -0
- dask_array/manipulation/_flip.py +92 -0
- dask_array/manipulation/_roll.py +78 -0
- dask_array/manipulation/_transpose.py +309 -0
- dask_array/random/__init__.py +125 -0
- dask_array/random/_choice.py +181 -0
- dask_array/random/_expr.py +256 -0
- dask_array/random/_generator.py +441 -0
- dask_array/random/_random_state.py +259 -0
- dask_array/random/_utils.py +84 -0
- dask_array/reductions/__init__.py +84 -0
- dask_array/reductions/_arg_reduction.py +130 -0
- dask_array/reductions/_common.py +1082 -0
- dask_array/reductions/_cumulative.py +522 -0
- dask_array/reductions/_percentile.py +261 -0
- dask_array/reductions/_reduction.py +725 -0
- dask_array/reductions/_trace.py +56 -0
- dask_array/routines/__init__.py +133 -0
- dask_array/routines/_apply.py +84 -0
- dask_array/routines/_bincount.py +112 -0
- dask_array/routines/_broadcast.py +111 -0
- dask_array/routines/_coarsen.py +115 -0
- dask_array/routines/_diff.py +79 -0
- dask_array/routines/_gradient.py +158 -0
- dask_array/routines/_indexing.py +65 -0
- dask_array/routines/_insert_delete.py +132 -0
- dask_array/routines/_misc.py +122 -0
- dask_array/routines/_nonzero.py +72 -0
- dask_array/routines/_search.py +123 -0
- dask_array/routines/_select.py +113 -0
- dask_array/routines/_statistics.py +171 -0
- dask_array/routines/_topk.py +82 -0
- dask_array/routines/_triangular.py +74 -0
- dask_array/routines/_unique.py +232 -0
- dask_array/routines/_where.py +62 -0
- dask_array/slicing/__init__.py +67 -0
- dask_array/slicing/_basic.py +550 -0
- dask_array/slicing/_blocks.py +138 -0
- dask_array/slicing/_bool_index.py +145 -0
- dask_array/slicing/_setitem.py +329 -0
- dask_array/slicing/_squeeze.py +101 -0
- dask_array/slicing/_utils.py +1133 -0
- dask_array/slicing/_vindex.py +282 -0
- dask_array/stacking/__init__.py +15 -0
- dask_array/stacking/_block.py +83 -0
- dask_array/stacking/_simple.py +58 -0
- dask_array/templates/array.html.j2 +48 -0
- dask_array/tests/__init__.py +0 -0
- dask_array/tests/conftest.py +22 -0
- dask_array/tests/test_api.py +40 -0
- dask_array/tests/test_binary_op_chunks.py +107 -0
- dask_array/tests/test_coarse_slice_through_blockwise.py +362 -0
- dask_array/tests/test_collection.py +799 -0
- dask_array/tests/test_creation.py +1102 -0
- dask_array/tests/test_expr_flow.py +143 -0
- dask_array/tests/test_linalg.py +1130 -0
- dask_array/tests/test_map_blocks_multi_output.py +104 -0
- dask_array/tests/test_rechunk_pushdown.py +214 -0
- dask_array/tests/test_reductions.py +1091 -0
- dask_array/tests/test_routines.py +2853 -0
- dask_array/tests/test_shuffle_chunks.py +67 -0
- dask_array/tests/test_slice_pushdown.py +968 -0
- dask_array/tests/test_slice_through_blockwise.py +678 -0
- dask_array/tests/test_slice_through_overlap.py +366 -0
- dask_array/tests/test_slice_through_reshape.py +272 -0
- dask_array/tests/test_slicing.py +839 -0
- dask_array/tests/test_transpose_slice_pushdown.py +208 -0
- dask_array/tests/test_visualize.py +94 -0
- dask_array/tests/test_xarray.py +193 -0
- dask_array-0.1.0.dist-info/METADATA +48 -0
- dask_array-0.1.0.dist-info/RECORD +144 -0
- dask_array-0.1.0.dist-info/WHEEL +4 -0
- dask_array-0.1.0.dist-info/entry_points.txt +2 -0
- dask_array-0.1.0.dist-info/licenses/LICENSE +29 -0
dask_array/_expr_flow.py
ADDED
|
@@ -0,0 +1,586 @@
|
|
|
1
|
+
"""Expression flow visualization - shows data transformation pipeline."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass, field
|
|
6
|
+
from math import prod
|
|
7
|
+
|
|
8
|
+
from dask.utils import funcname
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@dataclass
class FlowNode:
    """One card of the flow diagram: a run of expressions sharing one shape.

    Fields are listed in constructor (positional) order.
    """

    # Array shape shared by every expression collapsed into this node.
    shape: tuple
    # Chunk structure recorded for this node.
    chunks: tuple
    # Display labels of the collapsed operations, in the order they were added.
    operations: list[str] = field(default_factory=list)
    # The expression objects behind ``operations``, in the same order.
    expressions: list = field(default_factory=list)
    # Estimated size in bytes; 0 when it could not be computed.
    nbytes: int = 0
    # Grid position filled in by the layout pass.
    row: int = 0
    col: int = 0

    @property
    def ndim(self):
        """Number of dimensions, i.e. ``len(self.shape)``."""
        dims = self.shape
        return len(dims)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
@dataclass
class FlowEdge:
    """Directed link drawn from one flow node to another."""

    # Node the data comes from.
    source: "FlowNode"
    # Node the data flows into.
    target: "FlowNode"
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def _get_operation_name(expr) -> str:
    """Return a short, human-readable label for ``expr``.

    ``FromArray`` expressions are labelled "Load". Otherwise, when the
    expression has a ``_name`` containing a dash, the part before the last
    dash is cleaned up: underscores become spaces, "-aggregate" and
    "-partial" markers are dropped, and remaining dashes become spaces. The
    result is title-cased. If nothing usable remains, the class name is
    returned.
    """
    type_name = funcname(type(expr))
    if type_name == "FromArray":
        return "Load"

    try:
        full_name = expr._name
    except AttributeError:
        return type_name

    if "-" not in full_name:
        return type_name

    label = full_name.rsplit("-", 1)[0].replace("_", " ")
    for marker in ("-aggregate", "-partial"):
        label = label.replace(marker, "")
    label = label.strip().replace("-", " ")
    return label.title() if label else type_name
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def _is_reduction_intermediate(expr) -> bool:
    """Return True for ``PartialReduce`` expressions whose dimensions are all tiny.

    This is a deliberately conservative heuristic for hiding tree-reduction
    intermediates. A ``PartialReduce`` qualifies only if it is non-scalar and
    every dimension lies in 1..16, small enough that it is presumably a
    chunk count rather than user data.
    """
    from dask_array.reductions._reduction import PartialReduce

    chunk_count_max = 16

    if isinstance(expr, PartialReduce):
        dims = expr.shape
        if dims:
            return all(0 < size <= chunk_count_max for size in dims)
    return False
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def _walk_expr_tree(expr, visited=None):
|
|
81
|
+
"""Walk expression tree depth-first, yielding expressions from leaves to root."""
|
|
82
|
+
if visited is None:
|
|
83
|
+
visited = set()
|
|
84
|
+
|
|
85
|
+
expr_id = id(expr)
|
|
86
|
+
if expr_id in visited:
|
|
87
|
+
return
|
|
88
|
+
visited.add(expr_id)
|
|
89
|
+
|
|
90
|
+
# Get array dependencies
|
|
91
|
+
deps = [op for op in expr.dependencies() if hasattr(op, "chunks")]
|
|
92
|
+
|
|
93
|
+
# Visit children first (leaves to root)
|
|
94
|
+
for dep in deps:
|
|
95
|
+
yield from _walk_expr_tree(dep, visited)
|
|
96
|
+
|
|
97
|
+
yield expr
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
def _get_expr_inputs(expr):
|
|
101
|
+
"""Get the direct array expression inputs to this expression."""
|
|
102
|
+
# Use dependencies() instead of operands - handles fused nodes correctly
|
|
103
|
+
return [dep for dep in expr.dependencies() if hasattr(dep, "chunks")]
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
def _estimate_nbytes(e):
    """Return ``prod(e.shape) * e.dtype.itemsize``, or None when it is unknown.

    Any exception (e.g. a missing ``dtype``) yields None. A NaN product,
    which is what NaN (unknown) dimensions produce, also yields None.
    """
    try:
        nbytes = prod(e.shape) * e.dtype.itemsize
    except Exception:
        # Best-effort: byte counts are only cosmetic in the diagram.
        return None
    if nbytes != nbytes:  # NaN is the only value not equal to itself
        return None
    return nbytes


def build_flow_graph(expr):
    """Build a flow graph from an expression tree.

    Returns a tuple of (nodes, edges) where:
    - nodes: list of FlowNode objects, in dependency (topological) order
    - edges: list of FlowEdge objects, without duplicates

    Nodes are grouped by shape, with consecutive same-shape operations
    collapsed into a single node only when they form a linear chain. That
    means the expression has exactly one array input, the input has the same
    shape, and no other expression consumes that input. Reduction
    intermediates (per ``_is_reduction_intermediate``) are hidden, except
    for ``expr`` itself; edges are traced through hidden expressions to the
    nearest visible node.
    """
    # Collect all expressions, dependencies before their consumers
    all_exprs = list(_walk_expr_tree(expr))

    # Evaluate each expression's array inputs once. The cache is keyed by
    # identity, matching the walk's id()-based deduplication; every input of
    # a walked expression is itself walked, so lookups always succeed.
    inputs_of = {id(e): _get_expr_inputs(e) for e in all_exprs}

    # Filter out intermediate reduction shapes (but never filter the root expression)
    root_id = id(expr)
    filtered_exprs = [e for e in all_exprs if id(e) == root_id or not _is_reduction_intermediate(e)]

    if not filtered_exprs:
        return [], []

    # Number of consumers of each expression. Hidden intermediates count too:
    # if a reduction branches off an expression, that expression is not the
    # middle of a linear chain.
    consumer_counts = {}
    for e in all_exprs:
        for inp in inputs_of[id(e)]:
            consumer_counts[id(inp)] = consumer_counts.get(id(inp), 0) + 1

    expr_to_node = {}
    nodes = []

    for e in filtered_exprs:
        shape = e.shape
        inputs = inputs_of[id(e)]

        # Extend only a genuine linear chain: single input, already placed in
        # a node, same shape, and consumed by nothing but ``e``. The
        # single-consumer rule also guarantees the input is that node's most
        # recently added expression, so sibling branches get their own cards.
        input_node = expr_to_node.get(inputs[0]) if len(inputs) == 1 else None
        can_extend = (
            input_node is not None
            and shape == input_node.shape
            and consumer_counts[id(inputs[0])] == 1
        )
        nbytes = _estimate_nbytes(e)

        if can_extend:
            node = input_node
            node.operations.append(_get_operation_name(e))
            node.expressions.append(e)
            # Update nbytes to reflect the latest expression (dtype may change)
            if nbytes is not None:
                node.nbytes = nbytes
        else:
            node = FlowNode(
                shape=shape,
                chunks=e.chunks,
                operations=[_get_operation_name(e)],
                expressions=[e],
                nbytes=0 if nbytes is None else nbytes,
            )
            nodes.append(node)
        expr_to_node[e] = node

    # Build edges based on expression dependencies, tracing through filtered
    # intermediates to find the actual visible sources
    edges = []
    seen_edges = set()

    def find_source_nodes(x, visited=None):
        """Trace back through filtered expressions to the nearest visible nodes."""
        if visited is None:
            visited = set()
        if id(x) in visited:
            return []
        visited.add(id(x))

        node = expr_to_node.get(x)
        if node is not None:
            return [node]

        # This expression was filtered out - look through to its inputs
        results = []
        for inp in inputs_of[id(x)]:
            results.extend(find_source_nodes(inp, visited))
        return results

    for e in filtered_exprs:
        target_node = expr_to_node.get(e)
        if target_node is None:
            continue

        for inp in inputs_of[id(e)]:
            for source_node in find_source_nodes(inp):
                # Compare by identity: FlowNode is a dataclass, so ``!=`` would
                # compare every field, including the expression lists.
                if source_node is target_node:
                    continue
                edge_key = (id(source_node), id(target_node))
                if edge_key not in seen_edges:
                    seen_edges.add(edge_key)
                    edges.append(FlowEdge(source=source_node, target=target_node))

    # Assign row/column positions for layout
    _assign_layout(nodes, edges)

    return nodes, edges
|
|
215
|
+
|
|
216
|
+
|
|
217
|
+
def _assign_layout(nodes, edges):
|
|
218
|
+
"""Assign row and column positions to nodes for rendering.
|
|
219
|
+
|
|
220
|
+
Uses a simple algorithm:
|
|
221
|
+
- Nodes with no incoming edges start at column 0
|
|
222
|
+
- Each node's column = max(input columns) + 1
|
|
223
|
+
- Nodes at the same column are stacked in rows
|
|
224
|
+
"""
|
|
225
|
+
if not nodes:
|
|
226
|
+
return
|
|
227
|
+
|
|
228
|
+
# Build adjacency info
|
|
229
|
+
node_inputs = {id(n): [] for n in nodes}
|
|
230
|
+
for edge in edges:
|
|
231
|
+
node_inputs[id(edge.target)].append(edge.source)
|
|
232
|
+
|
|
233
|
+
# Assign columns (topological order)
|
|
234
|
+
node_col = {}
|
|
235
|
+
for node in nodes:
|
|
236
|
+
inputs = node_inputs[id(node)]
|
|
237
|
+
if not inputs:
|
|
238
|
+
node_col[id(node)] = 0
|
|
239
|
+
else:
|
|
240
|
+
max_input_col = max(node_col.get(id(inp), 0) for inp in inputs)
|
|
241
|
+
node_col[id(node)] = max_input_col + 1
|
|
242
|
+
node.col = node_col[id(node)]
|
|
243
|
+
|
|
244
|
+
# Assign rows within each column
|
|
245
|
+
col_counts = {}
|
|
246
|
+
for node in nodes:
|
|
247
|
+
col = node.col
|
|
248
|
+
row = col_counts.get(col, 0)
|
|
249
|
+
node.row = row
|
|
250
|
+
col_counts[col] = row + 1
|
|
251
|
+
|
|
252
|
+
|
|
253
|
+
def count_operations(expr) -> int:
    """Return how many distinct array expressions are reachable from ``expr``, itself included."""
    return sum(1 for _ in _walk_expr_tree(expr))
|
|
256
|
+
|
|
257
|
+
|
|
258
|
+
def _format_bytes(nbytes: int) -> str:
|
|
259
|
+
"""Format bytes with 2 significant figures."""
|
|
260
|
+
for unit, threshold in [
|
|
261
|
+
("PiB", 2**50),
|
|
262
|
+
("TiB", 2**40),
|
|
263
|
+
("GiB", 2**30),
|
|
264
|
+
("MiB", 2**20),
|
|
265
|
+
("kiB", 2**10),
|
|
266
|
+
]:
|
|
267
|
+
if nbytes >= threshold:
|
|
268
|
+
value = nbytes / threshold
|
|
269
|
+
if value >= 10:
|
|
270
|
+
return f"{value:.0f} {unit}"
|
|
271
|
+
else:
|
|
272
|
+
return f"{value:.1f} {unit}"
|
|
273
|
+
return f"{nbytes} B"
|
|
274
|
+
|
|
275
|
+
|
|
276
|
+
def _format_shape(shape: tuple) -> str:
|
|
277
|
+
"""Format shape tuple for display."""
|
|
278
|
+
if not shape:
|
|
279
|
+
return "scalar"
|
|
280
|
+
return f"({', '.join(str(s) for s in shape)})"
|
|
281
|
+
|
|
282
|
+
|
|
283
|
+
# Card geometry in SVG user units, fixed so every card has the same size.
CARD_WIDTH: int = 130
CARD_HEIGHT: int = 150
# Size handed to the chunk-grid renderer for the thumbnail inside a card.
CARD_SVG_REGION: int = 55
# Vertical gap between cards stacked in the same column.
CARD_GAP: int = 20
# Horizontal space between columns, where the arrows are drawn.
ARROW_WIDTH: int = 50
|
|
289
|
+
|
|
290
|
+
|
|
291
|
+
def _compute_emphasis(nodes, threshold: float = 0.5) -> dict:
|
|
292
|
+
"""Compute which nodes should be emphasized based on array size.
|
|
293
|
+
|
|
294
|
+
Returns a dict mapping node id to bool (True = emphasize).
|
|
295
|
+
Nodes with nbytes > threshold * max_nbytes are emphasized.
|
|
296
|
+
"""
|
|
297
|
+
valid_bytes = [n.nbytes for n in nodes if n.nbytes > 0]
|
|
298
|
+
if not valid_bytes:
|
|
299
|
+
return {id(n): True for n in nodes}
|
|
300
|
+
|
|
301
|
+
max_bytes = max(valid_bytes)
|
|
302
|
+
if max_bytes <= 0:
|
|
303
|
+
return {id(n): True for n in nodes}
|
|
304
|
+
|
|
305
|
+
return {id(n): n.nbytes > threshold * max_bytes for n in nodes}
|
|
306
|
+
|
|
307
|
+
|
|
308
|
+
def render_flow_svg(expr) -> str:
    """Render expression flow as an SVG diagram with card-based layout.

    Parameters
    ----------
    expr : ArrayExpr
        The expression to visualize

    Returns
    -------
    str
        An ``<svg>`` element as a string, with one card per flow node and
        arrows for the edges, suitable for embedding in HTML. If the
        expression produces no flow nodes, ``"<div>Empty expression</div>"``
        is returned instead.
    """
    nodes, edges = build_flow_graph(expr)
    if not nodes:
        return "<div>Empty expression</div>"

    max_col = max(n.col for n in nodes) + 1
    max_row = max(n.row for n in nodes) + 1

    # Compute which nodes to emphasize
    emphasis = _compute_emphasis(nodes)

    # Compute global max dimension for consistent scaling across all SVGs
    all_shapes = [n.shape for n in nodes if n.shape]
    global_max_dim = max(max(s) for s in all_shapes) if all_shapes else 1

    # Calculate SVG dimensions
    padding = 24
    col_width = CARD_WIDTH + ARROW_WIDTH
    row_height = CARD_HEIGHT + CARD_GAP
    svg_width = max_col * col_width - ARROW_WIDTH + 2 * padding
    svg_height = max_row * row_height - CARD_GAP + 2 * padding

    # Build node position map (center of each card)
    node_positions = {}
    for node in nodes:
        x = padding + node.col * col_width + CARD_WIDTH / 2
        y = padding + node.row * row_height + CARD_HEIGHT / 2
        node_positions[id(node)] = (x, y)

    # Start SVG
    svg_parts = [
        f'<svg width="{svg_width}" height="{svg_height}" '
        f'style="font-family: system-ui;" xmlns="http://www.w3.org/2000/svg">'
    ]

    # Add styles using JupyterLab CSS variables with light-mode fallbacks
    svg_parts.append("""<style>
.flow-card { fill: var(--jp-layout-color1, #fafaf9); }
.flow-card-emphasized { fill: var(--jp-layout-color2, #fff7ed); }
.flow-card-border { stroke: var(--jp-border-color1, #d6d3d1); }
.flow-card-border-emphasized { stroke: #fb923c; }
.flow-text-title { fill: var(--jp-ui-font-color1, #44403c); }
.flow-text-info { fill: var(--jp-ui-font-color2, #57534e); }
.flow-text-secondary { fill: var(--jp-ui-font-color3, #a8a29e); }
.flow-divider { stroke: var(--jp-border-color2, #e7e5e4); }
.flow-arrow-line { stroke: var(--jp-ui-font-color3, #a8a29e); }
.flow-arrow-head { fill: var(--jp-ui-font-color3, #a8a29e); }
.flow-arrow-path { stroke: var(--jp-ui-font-color3, #a8a29e); fill: none; }
</style>""")

    # Draw arrows first (so they appear behind cards)
    for edge in edges:
        src_x, src_y = node_positions[id(edge.source)]
        tgt_x, tgt_y = node_positions[id(edge.target)]

        # Arrow from right edge of source to left edge of target
        x1 = src_x + CARD_WIDTH / 2 + 4
        y1 = src_y
        x2 = tgt_x - CARD_WIDTH / 2 - 4
        y2 = tgt_y

        # Calculate column span - long arrows get vertical offset
        col_span = edge.target.col - edge.source.col

        if abs(y1 - y2) < 5:
            if col_span <= 1:
                # Simple horizontal arrow - straight line with arrowhead
                svg_parts.append(
                    f'<line class="flow-arrow-line" x1="{x1}" y1="{y1}" x2="{x2 - 8}" y2="{y2}" stroke="#a8a29e" stroke-width="2"/>'
                )
                svg_parts.append(
                    f'<polygon class="flow-arrow-head" points="{x2},{y2} {x2 - 8},{y2 - 4} {x2 - 8},{y2 + 4}" fill="#a8a29e"/>'
                )
            else:
                # Long-span arrow - horizontal line offset below, cards cover the middle
                y_offset = 12 + (col_span - 2) * 6
                y_line = y1 + y_offset
                svg_parts.append(
                    f'<line class="flow-arrow-line" x1="{x1}" y1="{y_line}" x2="{x2 - 8}" y2="{y_line}" stroke="#a8a29e" stroke-width="2"/>'
                )
                svg_parts.append(
                    f'<polygon class="flow-arrow-head" points="{x2},{y_line} {x2 - 8},{y_line - 4} {x2 - 8},{y_line + 4}" fill="#a8a29e"/>'
                )
        else:
            # Curved arrow for cross-row connections - use dot instead of arrowhead
            mid_x = (x1 + x2) / 2
            svg_parts.append(
                f'<path class="flow-arrow-path" d="M {x1} {y1} C {mid_x} {y1}, {mid_x} {y2}, {x2} {y2}" '
                f'stroke="#a8a29e" stroke-width="2" fill="none"/>'
            )
            svg_parts.append(f'<circle class="flow-arrow-head" cx="{x2}" cy="{y2}" r="4" fill="#a8a29e"/>')

    # Draw cards - use consistent SVG size across all cards
    for node in nodes:
        cx, cy = node_positions[id(node)]
        card_x = cx - CARD_WIDTH / 2
        card_y = cy - CARD_HEIGHT / 2
        emphasized = emphasis.get(id(node), False)
        svg_parts.append(_render_card(node, card_x, card_y, emphasized, global_max_dim))

    svg_parts.append("</svg>")
    return "\n".join(svg_parts)
|
|
427
|
+
|
|
428
|
+
|
|
429
|
+
def _render_card(node: FlowNode, x: float, y: float, emphasized: bool, global_max_dim: int) -> str:
    """Render a single flow node as an SVG card.

    The card is ``CARD_WIDTH`` x ``CARD_HEIGHT`` with its top-left corner at
    ``(x, y)``. From top to bottom it shows:

    - the operation label;
    - a chunk-grid thumbnail, or a small dot when the chunks are empty
      (nothing is drawn there if the thumbnail renderer raises
      ``NotImplementedError``/``ValueError``);
    - the shape and, when known, the byte size.

    Parameters
    ----------
    node : FlowNode
        Node to draw.
    x, y : float
        Top-left corner of the card.
    emphasized : bool
        Whether to use the highlighted (large array) styling.
    global_max_dim : int
        Largest dimension over all cards; used so thumbnails are scaled
        comparably across cards.

    Returns
    -------
    str
        Newline-joined SVG fragments for the card.
    """
    from html import escape

    from dask_array._svg import svg, ratio_response

    parts = []

    # Card styling based on emphasis; text classes are the same either way.
    if emphasized:
        # Emphasized: visible border, warm orange tint for large arrays
        fill = "#fff7ed"  # orange-50 - noticeable warm tint
        stroke = "#fb923c"  # orange-400 - matches array color
        stroke_width = "2"
        card_class = "flow-card-emphasized flow-card-border-emphasized"
    else:
        # Normal: subtle gray background, visible border
        fill = "#fafaf9"  # stone-50 - subtle off-white
        stroke = "#d6d3d1"  # stone-300 - visible but not harsh
        stroke_width = "1"
        card_class = "flow-card flow-card-border"
    title_class = "flow-text-title"
    info_class = "flow-text-info"
    secondary_class = "flow-text-secondary"

    parts.append(
        f'<rect class="{card_class}" x="{x}" y="{y}" width="{CARD_WIDTH}" height="{CARD_HEIGHT}" '
        f'rx="6" fill="{fill}" stroke="{stroke}" stroke-width="{stroke_width}"/>'
    )

    # Text colors based on emphasis
    title_color = "#44403c" if emphasized else "#78716c"
    info_color = "#57534e" if emphasized else "#a8a29e"
    secondary_color = "#a8a29e" if emphasized else "#d6d3d1"

    # Operation name at top (title): first and last operation of the chain
    ops = node.operations
    if len(ops) >= 2:
        ops_str = f"{ops[0]} → {ops[-1]}"
    else:
        ops_str = ops[0] if ops else ""

    # Truncate if too long
    if len(ops_str) > 18:
        ops_str = ops_str[:16] + "…"

    # Labels are derived from expression ``_name`` strings that this module
    # does not control. Escape them (after truncating, so no entity is cut in
    # half) so characters like "<" or "&" cannot break the SVG markup.
    ops_str = escape(ops_str)

    parts.append(
        f'<text class="{title_class}" x="{x + CARD_WIDTH / 2}" y="{y + 20}" '
        f'text-anchor="middle" font-size="11" font-weight="600" fill="{title_color}">'
        f"{ops_str}</text>"
    )

    # Divider line
    parts.append(
        f'<line class="flow-divider" x1="{x + 10}" y1="{y + 30}" x2="{x + CARD_WIDTH - 10}" y2="{y + 30}" '
        f'stroke="#e7e5e4" stroke-width="1"/>'
    )

    # SVG visualization (centered in middle region)
    svg_y = y + 35
    svg_region_height = 70
    svg_region_width = CARD_WIDTH - 20
    try:
        if node.chunks and all(node.chunks):
            # Compute sizes using global reference so dimensions are comparable across cards
            shape = node.shape
            # Ratio of global max to each dimension, with logarithmic compression
            ratios = [global_max_dim / max(0.1, d) for d in shape]
            ratios = [ratio_response(r) for r in ratios]
            sizes = tuple(CARD_SVG_REGION / r for r in ratios)

            node_svg = svg(node.chunks, size=CARD_SVG_REGION, sizes=sizes, labels=False)
            parts.append(
                f'<foreignObject x="{x + 10}" y="{svg_y}" width="{svg_region_width}" height="{svg_region_height}">'
                f'<div xmlns="http://www.w3.org/1999/xhtml" style="display:flex;justify-content:center;align-items:center;height:100%;overflow:hidden;">'
                f"{node_svg}"
                f"</div></foreignObject>"
            )
        else:
            # Scalar - show small circle
            cx = x + CARD_WIDTH / 2
            cy = svg_y + svg_region_height / 2
            parts.append(f'<circle cx="{cx}" cy="{cy}" r="8" fill="#fb923c" fill-opacity="0.7"/>')
    except (NotImplementedError, ValueError):
        # Fallback - leave the thumbnail area empty
        pass

    # Divider line before info section
    parts.append(
        f'<line class="flow-divider" x1="{x + 10}" y1="{y + 110}" x2="{x + CARD_WIDTH - 10}" y2="{y + 110}" '
        f'stroke="#e7e5e4" stroke-width="1"/>'
    )

    # Bottom section: shape and bytes, left-aligned
    shape_str = _format_shape(node.shape)
    bytes_str = _format_bytes(node.nbytes) if node.nbytes > 0 else ""
    left_margin = x + 12

    parts.append(
        f'<text class="{info_class}" x="{left_margin}" y="{y + 128}" '
        f'text-anchor="start" font-size="10" fill="{info_color}">'
        f"{shape_str}</text>"
    )

    if bytes_str:
        # Same font size as shape, but bold for emphasized (large) arrays
        bytes_weight = 'font-weight="600"' if emphasized else ""
        parts.append(
            f'<text class="{secondary_class}" x="{left_margin}" y="{y + 142}" '
            f'text-anchor="start" font-size="10" {bytes_weight} fill="{secondary_color}">'
            f"{bytes_str}</text>"
        )

    return "\n".join(parts)
|
|
546
|
+
|
|
547
|
+
|
|
548
|
+
class FlowDiagram:
    """Displayable wrapper around an expression's flow diagram.

    In Jupyter, ``_repr_html_`` renders the SVG diagram; it is computed on
    first use and cached. In a terminal, ``repr()`` gives a one-line summary.
    """

    def __init__(self, expr):
        self._expr = expr
        self._html_cache = None

    def _repr_html_(self) -> str:
        """Return the diagram markup for Jupyter, rendering it the first time."""
        cached = self._html_cache
        if cached is None:
            cached = render_flow_svg(self._expr)
            self._html_cache = cached
        return cached

    def __repr__(self) -> str:
        """Return a text summary: operation count, node count and node shapes."""
        nodes, _edges = build_flow_graph(self._expr)
        n_ops = count_operations(self._expr)
        shape_str = " → ".join(map(str, (node.shape for node in nodes)))
        return f"Expression: {n_ops} operations, {len(nodes)} shape(s): {shape_str}"
|
|
568
|
+
|
|
569
|
+
|
|
570
|
+
def expr_flow(expr) -> FlowDiagram:
    """Create a flow diagram visualization of an expression.

    Parameters
    ----------
    expr : ArrayExpr or Array
        The expression or array to visualize

    Returns
    -------
    FlowDiagram
        A displayable flow diagram (works in Jupyter and terminal)
    """
    # Objects carrying an ``_expr`` attribute (Array collections) are
    # unwrapped to their expression; anything else is used as-is.
    target = getattr(expr, "_expr", expr)
    return FlowDiagram(target)
|