dask-array 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (144) hide show
  1. dask_array/__init__.py +228 -0
  2. dask_array/_backends.py +76 -0
  3. dask_array/_backends_array.py +99 -0
  4. dask_array/_blockwise.py +1410 -0
  5. dask_array/_broadcast.py +272 -0
  6. dask_array/_chunk.py +445 -0
  7. dask_array/_chunk_types.py +54 -0
  8. dask_array/_collection.py +1644 -0
  9. dask_array/_concatenate.py +331 -0
  10. dask_array/_core_utils.py +1365 -0
  11. dask_array/_dispatch.py +141 -0
  12. dask_array/_einsum.py +277 -0
  13. dask_array/_expr.py +544 -0
  14. dask_array/_expr_flow.py +586 -0
  15. dask_array/_gufunc.py +805 -0
  16. dask_array/_histogram.py +617 -0
  17. dask_array/_map_blocks.py +652 -0
  18. dask_array/_new_collection.py +10 -0
  19. dask_array/_numpy_compat.py +135 -0
  20. dask_array/_overlap.py +1159 -0
  21. dask_array/_rechunk.py +1050 -0
  22. dask_array/_reshape.py +710 -0
  23. dask_array/_routines.py +102 -0
  24. dask_array/_shuffle.py +448 -0
  25. dask_array/_stack.py +264 -0
  26. dask_array/_svg.py +291 -0
  27. dask_array/_templates.py +29 -0
  28. dask_array/_test_utils.py +257 -0
  29. dask_array/_ufunc.py +385 -0
  30. dask_array/_utils.py +349 -0
  31. dask_array/_visualize.py +223 -0
  32. dask_array/_xarray.py +337 -0
  33. dask_array/core/__init__.py +34 -0
  34. dask_array/core/_blockwise_funcs.py +312 -0
  35. dask_array/core/_conversion.py +422 -0
  36. dask_array/core/_from_graph.py +97 -0
  37. dask_array/creation/__init__.py +71 -0
  38. dask_array/creation/_arange.py +121 -0
  39. dask_array/creation/_diag.py +116 -0
  40. dask_array/creation/_diagonal.py +241 -0
  41. dask_array/creation/_eye.py +103 -0
  42. dask_array/creation/_linspace.py +102 -0
  43. dask_array/creation/_mesh.py +134 -0
  44. dask_array/creation/_ones_zeros.py +454 -0
  45. dask_array/creation/_pad.py +270 -0
  46. dask_array/creation/_repeat.py +55 -0
  47. dask_array/creation/_tile.py +36 -0
  48. dask_array/creation/_tri.py +28 -0
  49. dask_array/creation/_utils.py +296 -0
  50. dask_array/fft.py +320 -0
  51. dask_array/io/__init__.py +39 -0
  52. dask_array/io/_base.py +10 -0
  53. dask_array/io/_from_array.py +257 -0
  54. dask_array/io/_from_delayed.py +95 -0
  55. dask_array/io/_from_graph.py +54 -0
  56. dask_array/io/_from_npy_stack.py +67 -0
  57. dask_array/io/_store.py +336 -0
  58. dask_array/io/_tiledb.py +159 -0
  59. dask_array/io/_to_npy_stack.py +65 -0
  60. dask_array/io/_zarr.py +449 -0
  61. dask_array/linalg/__init__.py +39 -0
  62. dask_array/linalg/_cholesky.py +234 -0
  63. dask_array/linalg/_lu.py +300 -0
  64. dask_array/linalg/_norm.py +94 -0
  65. dask_array/linalg/_qr.py +601 -0
  66. dask_array/linalg/_solve.py +349 -0
  67. dask_array/linalg/_svd.py +394 -0
  68. dask_array/linalg/_tensordot.py +334 -0
  69. dask_array/linalg/_utils.py +74 -0
  70. dask_array/manipulation/__init__.py +45 -0
  71. dask_array/manipulation/_expand.py +321 -0
  72. dask_array/manipulation/_flip.py +92 -0
  73. dask_array/manipulation/_roll.py +78 -0
  74. dask_array/manipulation/_transpose.py +309 -0
  75. dask_array/random/__init__.py +125 -0
  76. dask_array/random/_choice.py +181 -0
  77. dask_array/random/_expr.py +256 -0
  78. dask_array/random/_generator.py +441 -0
  79. dask_array/random/_random_state.py +259 -0
  80. dask_array/random/_utils.py +84 -0
  81. dask_array/reductions/__init__.py +84 -0
  82. dask_array/reductions/_arg_reduction.py +130 -0
  83. dask_array/reductions/_common.py +1082 -0
  84. dask_array/reductions/_cumulative.py +522 -0
  85. dask_array/reductions/_percentile.py +261 -0
  86. dask_array/reductions/_reduction.py +725 -0
  87. dask_array/reductions/_trace.py +56 -0
  88. dask_array/routines/__init__.py +133 -0
  89. dask_array/routines/_apply.py +84 -0
  90. dask_array/routines/_bincount.py +112 -0
  91. dask_array/routines/_broadcast.py +111 -0
  92. dask_array/routines/_coarsen.py +115 -0
  93. dask_array/routines/_diff.py +79 -0
  94. dask_array/routines/_gradient.py +158 -0
  95. dask_array/routines/_indexing.py +65 -0
  96. dask_array/routines/_insert_delete.py +132 -0
  97. dask_array/routines/_misc.py +122 -0
  98. dask_array/routines/_nonzero.py +72 -0
  99. dask_array/routines/_search.py +123 -0
  100. dask_array/routines/_select.py +113 -0
  101. dask_array/routines/_statistics.py +171 -0
  102. dask_array/routines/_topk.py +82 -0
  103. dask_array/routines/_triangular.py +74 -0
  104. dask_array/routines/_unique.py +232 -0
  105. dask_array/routines/_where.py +62 -0
  106. dask_array/slicing/__init__.py +67 -0
  107. dask_array/slicing/_basic.py +550 -0
  108. dask_array/slicing/_blocks.py +138 -0
  109. dask_array/slicing/_bool_index.py +145 -0
  110. dask_array/slicing/_setitem.py +329 -0
  111. dask_array/slicing/_squeeze.py +101 -0
  112. dask_array/slicing/_utils.py +1133 -0
  113. dask_array/slicing/_vindex.py +282 -0
  114. dask_array/stacking/__init__.py +15 -0
  115. dask_array/stacking/_block.py +83 -0
  116. dask_array/stacking/_simple.py +58 -0
  117. dask_array/templates/array.html.j2 +48 -0
  118. dask_array/tests/__init__.py +0 -0
  119. dask_array/tests/conftest.py +22 -0
  120. dask_array/tests/test_api.py +40 -0
  121. dask_array/tests/test_binary_op_chunks.py +107 -0
  122. dask_array/tests/test_coarse_slice_through_blockwise.py +362 -0
  123. dask_array/tests/test_collection.py +799 -0
  124. dask_array/tests/test_creation.py +1102 -0
  125. dask_array/tests/test_expr_flow.py +143 -0
  126. dask_array/tests/test_linalg.py +1130 -0
  127. dask_array/tests/test_map_blocks_multi_output.py +104 -0
  128. dask_array/tests/test_rechunk_pushdown.py +214 -0
  129. dask_array/tests/test_reductions.py +1091 -0
  130. dask_array/tests/test_routines.py +2853 -0
  131. dask_array/tests/test_shuffle_chunks.py +67 -0
  132. dask_array/tests/test_slice_pushdown.py +968 -0
  133. dask_array/tests/test_slice_through_blockwise.py +678 -0
  134. dask_array/tests/test_slice_through_overlap.py +366 -0
  135. dask_array/tests/test_slice_through_reshape.py +272 -0
  136. dask_array/tests/test_slicing.py +839 -0
  137. dask_array/tests/test_transpose_slice_pushdown.py +208 -0
  138. dask_array/tests/test_visualize.py +94 -0
  139. dask_array/tests/test_xarray.py +193 -0
  140. dask_array-0.1.0.dist-info/METADATA +48 -0
  141. dask_array-0.1.0.dist-info/RECORD +144 -0
  142. dask_array-0.1.0.dist-info/WHEEL +4 -0
  143. dask_array-0.1.0.dist-info/entry_points.txt +2 -0
  144. dask_array-0.1.0.dist-info/licenses/LICENSE +29 -0
@@ -0,0 +1,586 @@
1
+ """Expression flow visualization - shows data transformation pipeline."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass, field
6
+ from math import prod
7
+
8
+ from dask.utils import funcname
9
+
10
+
11
@dataclass
class FlowNode:
    """One card of the flow diagram.

    A node stands for one array shape and carries every operation that was
    collapsed into it, together with the layout slot it is drawn in.
    """

    # Array shape shared by all operations in this node
    shape: tuple
    # Chunk structure recorded when the node was created
    chunks: tuple
    # Display names of the collapsed operations, in the order they were added
    operations: list[str] = field(default_factory=list)
    # Expression objects the operations were derived from
    expressions: list = field(default_factory=list)
    # Estimated size in bytes (0 when it could not be computed)
    nbytes: int = 0
    # Grid position used by the renderer
    row: int = 0
    col: int = 0

    @property
    def ndim(self):
        """Number of dimensions, i.e. ``len(shape)``."""
        return len(self.shape)
26
+
27
+
28
@dataclass
class FlowEdge:
    """Directed link between two flow nodes (data moves source -> target)."""

    # Node the data comes from
    source: FlowNode
    # Node the data flows into
    target: FlowNode
34
+
35
+
36
def _get_operation_name(expr) -> str:
    """Return a human-readable label for the operation behind ``expr``.

    ``FromArray`` expressions are always labelled ``"Load"``. Otherwise the
    label is derived from ``expr._name`` when that attribute exists and
    contains a ``-``: everything after the last ``-`` is dropped, underscores
    become spaces, the ``-aggregate`` / ``-partial`` markers are removed,
    remaining dashes become spaces and the result is title-cased. If that
    leaves nothing usable, the expression's class name is returned.
    """
    class_name = funcname(type(expr))
    if class_name == "FromArray":
        return "Load"

    try:
        name = expr._name
    except AttributeError:
        return class_name

    # Only names of the form "<prefix>-<suffix>" carry a usable prefix
    head, dash, _suffix = name.rpartition("-")
    if not dash:
        return class_name

    label = (
        head.replace("_", " ")
        .replace("-aggregate", "")
        .replace("-partial", "")
        .strip()
        .replace("-", " ")
    )
    return label.title() if label else class_name
59
+
60
+
61
def _is_reduction_intermediate(expr) -> bool:
    """Tell whether ``expr`` looks like an intermediate step of a reduction.

    Deliberately conservative: the answer is True only for ``PartialReduce``
    expressions with a non-empty shape in which every dimension is positive
    and at most ``CHUNK_COUNT_MAX``.
    """
    from dask_array.reductions._reduction import PartialReduce

    # Dimensions this small are treated as chunk counts rather than user data
    CHUNK_COUNT_MAX = 16

    if isinstance(expr, PartialReduce):
        dims = expr.shape
        if dims:
            return all(0 < d <= CHUNK_COUNT_MAX for d in dims)
    return False
78
+
79
+
80
+ def _walk_expr_tree(expr, visited=None):
81
+ """Walk expression tree depth-first, yielding expressions from leaves to root."""
82
+ if visited is None:
83
+ visited = set()
84
+
85
+ expr_id = id(expr)
86
+ if expr_id in visited:
87
+ return
88
+ visited.add(expr_id)
89
+
90
+ # Get array dependencies
91
+ deps = [op for op in expr.dependencies() if hasattr(op, "chunks")]
92
+
93
+ # Visit children first (leaves to root)
94
+ for dep in deps:
95
+ yield from _walk_expr_tree(dep, visited)
96
+
97
+ yield expr
98
+
99
+
100
+ def _get_expr_inputs(expr):
101
+ """Get the direct array expression inputs to this expression."""
102
+ # Use dependencies() instead of operands - handles fused nodes correctly
103
+ return [dep for dep in expr.dependencies() if hasattr(dep, "chunks")]
104
+
105
+
106
def build_flow_graph(expr):
    """Build a flow graph from an expression tree.

    Returns a tuple ``(nodes, edges)`` where ``nodes`` is a list of
    ``FlowNode`` objects and ``edges`` a list of ``FlowEdge`` objects. Both
    lists are empty when no expression survives filtering.

    Expressions are visited leaves-to-root. Reduction intermediates are
    dropped (the root expression is never dropped). An expression is merged
    into an existing node only when it has exactly one array input, that
    input already belongs to a node, and that node has the same shape, i.e.
    the two form a linear same-shape chain. Every other expression starts a
    new node. Row/column positions are assigned before returning.
    """
    root_id = id(expr)
    filtered_exprs = [
        e
        for e in _walk_expr_tree(expr)
        if id(e) == root_id or not _is_reduction_intermediate(e)
    ]
    if not filtered_exprs:
        return [], []

    def estimate_nbytes(shape, e, fallback):
        # Best effort: if the size cannot be computed from shape and dtype,
        # keep whatever value the caller already had.
        try:
            return prod(shape) * e.dtype.itemsize
        except Exception:
            return fallback

    expr_to_node = {}
    nodes = []

    for e in filtered_exprs:
        shape = e.shape
        inputs = _get_expr_inputs(e)
        # Inputs that were filtered out have no node and are skipped here
        input_nodes = [expr_to_node[inp] for inp in inputs if inp in expr_to_node]

        can_extend = (
            len(inputs) == 1
            and len(input_nodes) == 1
            and shape == input_nodes[0].shape
        )

        if can_extend:
            node = input_nodes[0]
            node.operations.append(_get_operation_name(e))
            node.expressions.append(e)
            # Size follows the latest expression in the chain
            node.nbytes = estimate_nbytes(shape, e, node.nbytes)
        else:
            node = FlowNode(
                shape=shape,
                chunks=e.chunks,
                operations=[_get_operation_name(e)],
                expressions=[e],
                nbytes=estimate_nbytes(shape, e, 0),
            )
            nodes.append(node)
        expr_to_node[e] = node

    def find_source_nodes(start, visited=None):
        """Trace back through filtered expressions to the nearest nodes."""
        if visited is None:
            visited = set()
        if id(start) in visited:
            return []
        visited.add(id(start))

        found = expr_to_node.get(start)
        if found is not None:
            return [found]

        # ``start`` was filtered out: look through it to its own inputs
        results = []
        for inp in _get_expr_inputs(start):
            results.extend(find_source_nodes(inp, visited))
        return results

    edges = []
    seen_edges = set()
    for e in filtered_exprs:
        target_node = expr_to_node[e]
        for inp in _get_expr_inputs(e):
            for source_node in find_source_nodes(inp):
                # Compare by identity: dataclass equality would compare every
                # field, including the lists of expression objects.
                if source_node is target_node:
                    continue
                edge_key = (id(source_node), id(target_node))
                if edge_key not in seen_edges:
                    seen_edges.add(edge_key)
                    edges.append(FlowEdge(source=source_node, target=target_node))

    _assign_layout(nodes, edges)

    return nodes, edges
215
+
216
+
217
+ def _assign_layout(nodes, edges):
218
+ """Assign row and column positions to nodes for rendering.
219
+
220
+ Uses a simple algorithm:
221
+ - Nodes with no incoming edges start at column 0
222
+ - Each node's column = max(input columns) + 1
223
+ - Nodes at the same column are stacked in rows
224
+ """
225
+ if not nodes:
226
+ return
227
+
228
+ # Build adjacency info
229
+ node_inputs = {id(n): [] for n in nodes}
230
+ for edge in edges:
231
+ node_inputs[id(edge.target)].append(edge.source)
232
+
233
+ # Assign columns (topological order)
234
+ node_col = {}
235
+ for node in nodes:
236
+ inputs = node_inputs[id(node)]
237
+ if not inputs:
238
+ node_col[id(node)] = 0
239
+ else:
240
+ max_input_col = max(node_col.get(id(inp), 0) for inp in inputs)
241
+ node_col[id(node)] = max_input_col + 1
242
+ node.col = node_col[id(node)]
243
+
244
+ # Assign rows within each column
245
+ col_counts = {}
246
+ for node in nodes:
247
+ col = node.col
248
+ row = col_counts.get(col, 0)
249
+ node.row = row
250
+ col_counts[col] = row + 1
251
+
252
+
253
def count_operations(expr) -> int:
    """Return how many expressions the walk over ``expr``'s tree yields."""
    return sum(1 for _ in _walk_expr_tree(expr))
256
+
257
+
258
+ def _format_bytes(nbytes: int) -> str:
259
+ """Format bytes with 2 significant figures."""
260
+ for unit, threshold in [
261
+ ("PiB", 2**50),
262
+ ("TiB", 2**40),
263
+ ("GiB", 2**30),
264
+ ("MiB", 2**20),
265
+ ("kiB", 2**10),
266
+ ]:
267
+ if nbytes >= threshold:
268
+ value = nbytes / threshold
269
+ if value >= 10:
270
+ return f"{value:.0f} {unit}"
271
+ else:
272
+ return f"{value:.1f} {unit}"
273
+ return f"{nbytes} B"
274
+
275
+
276
+ def _format_shape(shape: tuple) -> str:
277
+ """Format shape tuple for display."""
278
+ if not shape:
279
+ return "scalar"
280
+ return f"({', '.join(str(s) for s in shape)})"
281
+
282
+
283
# Card geometry. Every card uses the same fixed values so the layout stays uniform.
CARD_WIDTH = 130  # width of one card
CARD_HEIGHT = 150  # height of one card
CARD_SVG_REGION = 55  # size budget for the array drawing inside a card
CARD_GAP = 20  # vertical gap between cards stacked in one column
ARROW_WIDTH = 50  # horizontal gap between columns, where arrows are drawn
289
+
290
+
291
+ def _compute_emphasis(nodes, threshold: float = 0.5) -> dict:
292
+ """Compute which nodes should be emphasized based on array size.
293
+
294
+ Returns a dict mapping node id to bool (True = emphasize).
295
+ Nodes with nbytes > threshold * max_nbytes are emphasized.
296
+ """
297
+ valid_bytes = [n.nbytes for n in nodes if n.nbytes > 0]
298
+ if not valid_bytes:
299
+ return {id(n): True for n in nodes}
300
+
301
+ max_bytes = max(valid_bytes)
302
+ if max_bytes <= 0:
303
+ return {id(n): True for n in nodes}
304
+
305
+ return {id(n): n.nbytes > threshold * max_bytes for n in nodes}
306
+
307
+
308
def _render_arrow(edge, node_positions) -> list:
    """Return the SVG fragments for one edge between two card centers.

    Same-row edges are straight arrows with a triangular head; an edge that
    spans more than one column is drawn on a line offset below the card
    centers. Cross-row edges are curves ending in a dot.
    """
    src_x, src_y = node_positions[id(edge.source)]
    tgt_x, tgt_y = node_positions[id(edge.target)]

    # From the right edge of the source card to the left edge of the target card
    x1 = src_x + CARD_WIDTH / 2 + 4
    y1 = src_y
    x2 = tgt_x - CARD_WIDTH / 2 - 4
    y2 = tgt_y

    col_span = edge.target.col - edge.source.col

    if abs(y1 - y2) < 5:
        if col_span <= 1:
            return [
                f'<line class="flow-arrow-line" x1="{x1}" y1="{y1}" x2="{x2 - 8}" y2="{y2}" stroke="#a8a29e" stroke-width="2"/>',
                f'<polygon class="flow-arrow-head" points="{x2},{y2} {x2 - 8},{y2 - 4} {x2 - 8},{y2 + 4}" fill="#a8a29e"/>',
            ]
        # Long span: the offset grows with the span; cards drawn later cover
        # the middle of the line
        y_offset = 12 + (col_span - 2) * 6
        y_line = y1 + y_offset
        return [
            f'<line class="flow-arrow-line" x1="{x1}" y1="{y_line}" x2="{x2 - 8}" y2="{y_line}" stroke="#a8a29e" stroke-width="2"/>',
            f'<polygon class="flow-arrow-head" points="{x2},{y_line} {x2 - 8},{y_line - 4} {x2 - 8},{y_line + 4}" fill="#a8a29e"/>',
        ]

    # Cross-row connection: cubic curve with a dot instead of an arrowhead
    mid_x = (x1 + x2) / 2
    return [
        f'<path class="flow-arrow-path" d="M {x1} {y1} C {mid_x} {y1}, {mid_x} {y2}, {x2} {y2}" '
        f'stroke="#a8a29e" stroke-width="2" fill="none"/>',
        f'<circle class="flow-arrow-head" cx="{x2}" cy="{y2}" r="4" fill="#a8a29e"/>',
    ]


def render_flow_svg(expr) -> str:
    """Render expression flow as an SVG diagram with card-based layout.

    Parameters
    ----------
    expr : ArrayExpr
        The expression to visualize

    Returns
    -------
    str
        SVG markup showing the data flow, or the placeholder
        ``<div>Empty expression</div>`` when the flow graph has no nodes
    """
    nodes, edges = build_flow_graph(expr)
    if not nodes:
        return "<div>Empty expression</div>"

    max_col = max(n.col for n in nodes) + 1
    max_row = max(n.row for n in nodes) + 1

    # Which cards get the highlighted style
    emphasis = _compute_emphasis(nodes)

    # One global reference dimension so drawings are comparable across cards
    all_shapes = [n.shape for n in nodes if n.shape]
    global_max_dim = max(max(s) for s in all_shapes) if all_shapes else 1

    # Overall canvas size
    padding = 24
    col_width = CARD_WIDTH + ARROW_WIDTH
    row_height = CARD_HEIGHT + CARD_GAP
    svg_width = max_col * col_width - ARROW_WIDTH + 2 * padding
    svg_height = max_row * row_height - CARD_GAP + 2 * padding

    # Center point of every card, keyed by id(node)
    node_positions = {}
    for node in nodes:
        x = padding + node.col * col_width + CARD_WIDTH / 2
        y = padding + node.row * row_height + CARD_HEIGHT / 2
        node_positions[id(node)] = (x, y)

    svg_parts = [
        f'<svg width="{svg_width}" height="{svg_height}" '
        f'style="font-family: system-ui;" xmlns="http://www.w3.org/2000/svg">'
    ]

    # Styles use JupyterLab CSS variables with light-mode fallbacks
    svg_parts.append("""<style>
        .flow-card { fill: var(--jp-layout-color1, #fafaf9); }
        .flow-card-emphasized { fill: var(--jp-layout-color2, #fff7ed); }
        .flow-card-border { stroke: var(--jp-border-color1, #d6d3d1); }
        .flow-card-border-emphasized { stroke: #fb923c; }
        .flow-text-title { fill: var(--jp-ui-font-color1, #44403c); }
        .flow-text-info { fill: var(--jp-ui-font-color2, #57534e); }
        .flow-text-secondary { fill: var(--jp-ui-font-color3, #a8a29e); }
        .flow-divider { stroke: var(--jp-border-color2, #e7e5e4); }
        .flow-arrow-line { stroke: var(--jp-ui-font-color3, #a8a29e); }
        .flow-arrow-head { fill: var(--jp-ui-font-color3, #a8a29e); }
        .flow-arrow-path { stroke: var(--jp-ui-font-color3, #a8a29e); fill: none; }
    </style>""")

    # Arrows go first so the cards are painted on top of them
    for edge in edges:
        svg_parts.extend(_render_arrow(edge, node_positions))

    for node in nodes:
        cx, cy = node_positions[id(node)]
        card_x = cx - CARD_WIDTH / 2
        card_y = cy - CARD_HEIGHT / 2
        emphasized = emphasis.get(id(node), False)
        svg_parts.append(_render_card(node, card_x, card_y, emphasized, global_max_dim))

    svg_parts.append("</svg>")
    return "\n".join(svg_parts)
427
+
428
+
429
def _render_card(node: FlowNode, x: float, y: float, emphasized: bool, global_max_dim: int) -> str:
    """Render a single flow node as an SVG card.

    Parameters
    ----------
    node : FlowNode
        Node to draw; its operations, chunks, shape and nbytes are used.
    x, y : float
        Top-left corner of the card.
    emphasized : bool
        Use the highlighted colours and a bold size label.
    global_max_dim : int
        Reference dimension used to scale the array drawing so that cards
        are comparable with each other.

    Returns
    -------
    str
        SVG fragments for the card, joined by newlines. From top to bottom:
        title, divider, array drawing (or a dot when the node has no usable
        chunks), divider, shape and size labels.
    """
    from html import escape

    from dask_array._svg import ratio_response, svg

    # Colours are also given inline so the card renders without the stylesheet
    if emphasized:
        fill, stroke, stroke_width = "#fff7ed", "#fb923c", "2"
        card_class = "flow-card-emphasized flow-card-border-emphasized"
        title_color, info_color, secondary_color = "#44403c", "#57534e", "#a8a29e"
    else:
        fill, stroke, stroke_width = "#fafaf9", "#d6d3d1", "1"
        card_class = "flow-card flow-card-border"
        title_color, info_color, secondary_color = "#78716c", "#a8a29e", "#d6d3d1"
    title_class = "flow-text-title"
    info_class = "flow-text-info"
    secondary_class = "flow-text-secondary"

    parts = [
        f'<rect class="{card_class}" x="{x}" y="{y}" width="{CARD_WIDTH}" height="{CARD_HEIGHT}" '
        f'rx="6" fill="{fill}" stroke="{stroke}" stroke-width="{stroke_width}"/>'
    ]

    # Title: the single operation, or "first → last" for a collapsed chain
    ops = node.operations
    if len(ops) >= 2:
        ops_str = f"{ops[0]} → {ops[-1]}"
    else:
        ops_str = ops[0] if ops else ""

    # Truncate if too long
    if len(ops_str) > 18:
        ops_str = ops_str[:16] + "…"

    # Operation names are free-form text: escape them so "&" or "<" cannot
    # corrupt the markup. Escape after truncating so no entity is cut in half.
    ops_text = escape(ops_str, quote=False)

    parts.append(
        f'<text class="{title_class}" x="{x + CARD_WIDTH / 2}" y="{y + 20}" '
        f'text-anchor="middle" font-size="11" font-weight="600" fill="{title_color}">'
        f"{ops_text}</text>"
    )

    # Divider below the title
    parts.append(
        f'<line class="flow-divider" x1="{x + 10}" y1="{y + 30}" x2="{x + CARD_WIDTH - 10}" y2="{y + 30}" '
        f'stroke="#e7e5e4" stroke-width="1"/>'
    )

    # Array drawing, centered in the middle region of the card
    svg_y = y + 35
    svg_region_height = 70
    svg_region_width = CARD_WIDTH - 20
    try:
        if node.chunks and all(node.chunks):
            # Size each dimension relative to the global reference, passing the
            # ratio through ratio_response before dividing
            ratios = [ratio_response(global_max_dim / max(0.1, d)) for d in node.shape]
            sizes = tuple(CARD_SVG_REGION / r for r in ratios)

            node_svg = svg(node.chunks, size=CARD_SVG_REGION, sizes=sizes, labels=False)
            parts.append(
                f'<foreignObject x="{x + 10}" y="{svg_y}" width="{svg_region_width}" height="{svg_region_height}">'
                f'<div xmlns="http://www.w3.org/1999/xhtml" style="display:flex;justify-content:center;align-items:center;height:100%;overflow:hidden;">'
                f"{node_svg}"
                f"</div></foreignObject>"
            )
        else:
            # No usable chunks (e.g. a scalar): show a small dot instead
            cx = x + CARD_WIDTH / 2
            cy = svg_y + svg_region_height / 2
            parts.append(f'<circle cx="{cx}" cy="{cy}" r="8" fill="#fb923c" fill-opacity="0.7"/>')
    except (NotImplementedError, ValueError):
        # The drawing is optional: leave the region empty if it cannot be drawn
        pass

    # Divider above the info section
    parts.append(
        f'<line class="flow-divider" x1="{x + 10}" y1="{y + 110}" x2="{x + CARD_WIDTH - 10}" y2="{y + 110}" '
        f'stroke="#e7e5e4" stroke-width="1"/>'
    )

    # Bottom section: shape and bytes, left-aligned
    shape_str = _format_shape(node.shape)
    bytes_str = _format_bytes(node.nbytes) if node.nbytes > 0 else ""
    left_margin = x + 12

    parts.append(
        f'<text class="{info_class}" x="{left_margin}" y="{y + 128}" '
        f'text-anchor="start" font-size="10" fill="{info_color}">'
        f"{shape_str}</text>"
    )

    if bytes_str:
        # Same font size as the shape, but bold on emphasized cards
        bytes_weight = 'font-weight="600"' if emphasized else ""
        parts.append(
            f'<text class="{secondary_class}" x="{left_margin}" y="{y + 142}" '
            f'text-anchor="start" font-size="10" {bytes_weight} fill="{secondary_color}">'
            f"{bytes_str}</text>"
        )

    return "\n".join(parts)
546
+
547
+
548
class FlowDiagram:
    """Displayable handle on an expression's flow diagram.

    Provides ``_repr_html_`` for rich display and ``__repr__`` for a one-line
    text summary.
    """

    def __init__(self, expr):
        self._expr = expr
        # Rendered markup, filled in on the first rich display
        self._html_cache = None

    def _repr_html_(self) -> str:
        """Return the rendered diagram, rendering it only on the first call."""
        cached = self._html_cache
        if cached is None:
            cached = self._html_cache = render_flow_svg(self._expr)
        return cached

    def __repr__(self) -> str:
        """Summarize the operation count and the shape of each flow node."""
        nodes, _edges = build_flow_graph(self._expr)
        n_ops = count_operations(self._expr)
        shape_str = " → ".join(str(node.shape) for node in nodes)
        return f"Expression: {n_ops} operations, {len(nodes)} shape(s): {shape_str}"
568
+
569
+
570
def expr_flow(expr) -> FlowDiagram:
    """Create a flow diagram visualization of an expression.

    Parameters
    ----------
    expr : ArrayExpr or Array
        The expression to visualize. An object that has an ``_expr``
        attribute is unwrapped to that attribute first.

    Returns
    -------
    FlowDiagram
        A displayable flow diagram (works in Jupyter and terminal)
    """
    target = expr._expr if hasattr(expr, "_expr") else expr
    return FlowDiagram(target)