dask-array 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (144) hide show
  1. dask_array/__init__.py +228 -0
  2. dask_array/_backends.py +76 -0
  3. dask_array/_backends_array.py +99 -0
  4. dask_array/_blockwise.py +1410 -0
  5. dask_array/_broadcast.py +272 -0
  6. dask_array/_chunk.py +445 -0
  7. dask_array/_chunk_types.py +54 -0
  8. dask_array/_collection.py +1644 -0
  9. dask_array/_concatenate.py +331 -0
  10. dask_array/_core_utils.py +1365 -0
  11. dask_array/_dispatch.py +141 -0
  12. dask_array/_einsum.py +277 -0
  13. dask_array/_expr.py +544 -0
  14. dask_array/_expr_flow.py +586 -0
  15. dask_array/_gufunc.py +805 -0
  16. dask_array/_histogram.py +617 -0
  17. dask_array/_map_blocks.py +652 -0
  18. dask_array/_new_collection.py +10 -0
  19. dask_array/_numpy_compat.py +135 -0
  20. dask_array/_overlap.py +1159 -0
  21. dask_array/_rechunk.py +1050 -0
  22. dask_array/_reshape.py +710 -0
  23. dask_array/_routines.py +102 -0
  24. dask_array/_shuffle.py +448 -0
  25. dask_array/_stack.py +264 -0
  26. dask_array/_svg.py +291 -0
  27. dask_array/_templates.py +29 -0
  28. dask_array/_test_utils.py +257 -0
  29. dask_array/_ufunc.py +385 -0
  30. dask_array/_utils.py +349 -0
  31. dask_array/_visualize.py +223 -0
  32. dask_array/_xarray.py +337 -0
  33. dask_array/core/__init__.py +34 -0
  34. dask_array/core/_blockwise_funcs.py +312 -0
  35. dask_array/core/_conversion.py +422 -0
  36. dask_array/core/_from_graph.py +97 -0
  37. dask_array/creation/__init__.py +71 -0
  38. dask_array/creation/_arange.py +121 -0
  39. dask_array/creation/_diag.py +116 -0
  40. dask_array/creation/_diagonal.py +241 -0
  41. dask_array/creation/_eye.py +103 -0
  42. dask_array/creation/_linspace.py +102 -0
  43. dask_array/creation/_mesh.py +134 -0
  44. dask_array/creation/_ones_zeros.py +454 -0
  45. dask_array/creation/_pad.py +270 -0
  46. dask_array/creation/_repeat.py +55 -0
  47. dask_array/creation/_tile.py +36 -0
  48. dask_array/creation/_tri.py +28 -0
  49. dask_array/creation/_utils.py +296 -0
  50. dask_array/fft.py +320 -0
  51. dask_array/io/__init__.py +39 -0
  52. dask_array/io/_base.py +10 -0
  53. dask_array/io/_from_array.py +257 -0
  54. dask_array/io/_from_delayed.py +95 -0
  55. dask_array/io/_from_graph.py +54 -0
  56. dask_array/io/_from_npy_stack.py +67 -0
  57. dask_array/io/_store.py +336 -0
  58. dask_array/io/_tiledb.py +159 -0
  59. dask_array/io/_to_npy_stack.py +65 -0
  60. dask_array/io/_zarr.py +449 -0
  61. dask_array/linalg/__init__.py +39 -0
  62. dask_array/linalg/_cholesky.py +234 -0
  63. dask_array/linalg/_lu.py +300 -0
  64. dask_array/linalg/_norm.py +94 -0
  65. dask_array/linalg/_qr.py +601 -0
  66. dask_array/linalg/_solve.py +349 -0
  67. dask_array/linalg/_svd.py +394 -0
  68. dask_array/linalg/_tensordot.py +334 -0
  69. dask_array/linalg/_utils.py +74 -0
  70. dask_array/manipulation/__init__.py +45 -0
  71. dask_array/manipulation/_expand.py +321 -0
  72. dask_array/manipulation/_flip.py +92 -0
  73. dask_array/manipulation/_roll.py +78 -0
  74. dask_array/manipulation/_transpose.py +309 -0
  75. dask_array/random/__init__.py +125 -0
  76. dask_array/random/_choice.py +181 -0
  77. dask_array/random/_expr.py +256 -0
  78. dask_array/random/_generator.py +441 -0
  79. dask_array/random/_random_state.py +259 -0
  80. dask_array/random/_utils.py +84 -0
  81. dask_array/reductions/__init__.py +84 -0
  82. dask_array/reductions/_arg_reduction.py +130 -0
  83. dask_array/reductions/_common.py +1082 -0
  84. dask_array/reductions/_cumulative.py +522 -0
  85. dask_array/reductions/_percentile.py +261 -0
  86. dask_array/reductions/_reduction.py +725 -0
  87. dask_array/reductions/_trace.py +56 -0
  88. dask_array/routines/__init__.py +133 -0
  89. dask_array/routines/_apply.py +84 -0
  90. dask_array/routines/_bincount.py +112 -0
  91. dask_array/routines/_broadcast.py +111 -0
  92. dask_array/routines/_coarsen.py +115 -0
  93. dask_array/routines/_diff.py +79 -0
  94. dask_array/routines/_gradient.py +158 -0
  95. dask_array/routines/_indexing.py +65 -0
  96. dask_array/routines/_insert_delete.py +132 -0
  97. dask_array/routines/_misc.py +122 -0
  98. dask_array/routines/_nonzero.py +72 -0
  99. dask_array/routines/_search.py +123 -0
  100. dask_array/routines/_select.py +113 -0
  101. dask_array/routines/_statistics.py +171 -0
  102. dask_array/routines/_topk.py +82 -0
  103. dask_array/routines/_triangular.py +74 -0
  104. dask_array/routines/_unique.py +232 -0
  105. dask_array/routines/_where.py +62 -0
  106. dask_array/slicing/__init__.py +67 -0
  107. dask_array/slicing/_basic.py +550 -0
  108. dask_array/slicing/_blocks.py +138 -0
  109. dask_array/slicing/_bool_index.py +145 -0
  110. dask_array/slicing/_setitem.py +329 -0
  111. dask_array/slicing/_squeeze.py +101 -0
  112. dask_array/slicing/_utils.py +1133 -0
  113. dask_array/slicing/_vindex.py +282 -0
  114. dask_array/stacking/__init__.py +15 -0
  115. dask_array/stacking/_block.py +83 -0
  116. dask_array/stacking/_simple.py +58 -0
  117. dask_array/templates/array.html.j2 +48 -0
  118. dask_array/tests/__init__.py +0 -0
  119. dask_array/tests/conftest.py +22 -0
  120. dask_array/tests/test_api.py +40 -0
  121. dask_array/tests/test_binary_op_chunks.py +107 -0
  122. dask_array/tests/test_coarse_slice_through_blockwise.py +362 -0
  123. dask_array/tests/test_collection.py +799 -0
  124. dask_array/tests/test_creation.py +1102 -0
  125. dask_array/tests/test_expr_flow.py +143 -0
  126. dask_array/tests/test_linalg.py +1130 -0
  127. dask_array/tests/test_map_blocks_multi_output.py +104 -0
  128. dask_array/tests/test_rechunk_pushdown.py +214 -0
  129. dask_array/tests/test_reductions.py +1091 -0
  130. dask_array/tests/test_routines.py +2853 -0
  131. dask_array/tests/test_shuffle_chunks.py +67 -0
  132. dask_array/tests/test_slice_pushdown.py +968 -0
  133. dask_array/tests/test_slice_through_blockwise.py +678 -0
  134. dask_array/tests/test_slice_through_overlap.py +366 -0
  135. dask_array/tests/test_slice_through_reshape.py +272 -0
  136. dask_array/tests/test_slicing.py +839 -0
  137. dask_array/tests/test_transpose_slice_pushdown.py +208 -0
  138. dask_array/tests/test_visualize.py +94 -0
  139. dask_array/tests/test_xarray.py +193 -0
  140. dask_array-0.1.0.dist-info/METADATA +48 -0
  141. dask_array-0.1.0.dist-info/RECORD +144 -0
  142. dask_array-0.1.0.dist-info/WHEEL +4 -0
  143. dask_array-0.1.0.dist-info/entry_points.txt +2 -0
  144. dask_array-0.1.0.dist-info/licenses/LICENSE +29 -0
dask_array/_stack.py ADDED
@@ -0,0 +1,264 @@
1
+ """Stack operation - expression and collection function."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import functools
6
+ from itertools import product
7
+
8
+ import numpy as np
9
+ from toolz import concat, first
10
+
11
+ from dask_array._new_collection import new_collection
12
+ from dask._task_spec import Task, TaskRef
13
+ from dask_array._expr import ArrayExpr, unify_chunks_expr
14
+ from dask_array._chunk import getitem
15
+ from dask_array._utils import meta_from_array
16
+
17
+
18
class Stack(ArrayExpr):
    """Expression that stacks several array expressions along a new axis.

    The named operands are ``array`` (the first input), ``axis`` (position
    of the new dimension in the output) and ``meta``.  Any further inputs
    are stored as trailing positional operands after the named ones.
    """

    _parameters = ["array", "axis", "meta"]

    @functools.cached_property
    def args(self):
        """All stacked inputs: ``array`` followed by the trailing operands."""
        trailing = self.operands[len(self._parameters) :]
        return [self.array] + trailing

    @functools.cached_property
    def _meta(self):
        """The meta array supplied as the ``meta`` operand."""
        return self.operand("meta")

    @functools.cached_property
    def chunks(self):
        """Chunks of the first input with a run of size-1 chunks inserted at ``axis``.

        Only the first input's chunks are consulted; one size-1 chunk is
        added per stacked input.
        """
        base = self.array.chunks
        stacked = (1,) * len(self.args)
        return base[: self.axis] + (stacked,) + base[self.axis :]

    @functools.cached_property
    def _name(self):
        prefix = "stack-"
        return prefix + self.deterministic_token

    def _layer(self) -> dict:
        """Build the task graph: one ``getitem`` task per output block.

        Each output block is the matching block of one input, indexed with
        ``None`` at ``axis`` so that it gains the new length-1 dimension.
        """
        axis = self.axis
        input_ndim = self._meta.ndim - 1
        input_names = [arr.name for arr in self.args]
        block_ranges = [range(len(dim_chunks)) for dim_chunks in self.chunks]

        # The indexer is the same for every block, so build it once.
        everything = slice(None, None, None)
        expand = (everything,) * axis + (None,) + (everything,) * (input_ndim - axis)

        graph = {}
        for out_key in product([self._name], *block_ranges):
            # The block index along the stacked axis selects the input;
            # the remaining block indices address the block inside it.
            src_key = (input_names[out_key[axis + 1]],) + out_key[1 : axis + 1] + out_key[axis + 2 :]
            graph[out_key] = Task(out_key, getitem, TaskRef(src_key), expand)
        return graph

    def _simplify_up(self, parent, dependents):
        """Allow slice and shuffle operations to push through Stack."""
        from dask_array._shuffle import Shuffle
        from dask_array.slicing import SliceSlicesIntegers

        if isinstance(parent, SliceSlicesIntegers):
            return self._accept_slice(parent)
        return self._accept_shuffle(parent) if isinstance(parent, Shuffle) else None

    def _accept_shuffle(self, shuffle_expr):
        """Accept a shuffle being pushed through Stack.

        A shuffle along the stacked axis is refused (``None`` is returned).
        For any other axis every input is shuffled individually, with the
        axis number shifted down by one when it lies after the stacked
        axis, because the inputs have one dimension fewer than the output.
        """
        from dask_array._shuffle import Shuffle

        stack_axis = self.axis
        requested_axis = shuffle_expr.axis

        if requested_axis == stack_axis:
            return None

        inner_axis = requested_axis - 1 if requested_axis > stack_axis else requested_axis

        indexer = shuffle_expr.indexer
        shuffle_name = shuffle_expr.operand("name")
        head, *tail = [Shuffle(arr, indexer, inner_axis, shuffle_name) for arr in self.args]

        return type(self)(head, stack_axis, self._meta, *tail)

    def _accept_slice(self, slice_expr):
        """Accept a slice being pushed through Stack.

        The slice on the stacked axis selects a contiguous subset of the
        inputs; slices on the other axes are applied to each selected
        input.  ``None`` is returned (no rewrite) when the index contains
        integers or ``None``, when the stacked-axis slice has a step other
        than 1, or when it selects no inputs.
        """
        from numbers import Integral

        from dask_array._new_collection import new_collection
        from dask_array._utils import meta_from_array

        axis = self.axis
        inputs = self.args
        index = slice_expr.index

        # Pad index to full length (output has one more dim than inputs)
        full_index = index + (slice(None),) * (self.ndim - len(index))

        # Only plain slices are handled; integers and None change the rank
        for entry in full_index:
            if isinstance(entry, Integral) or entry is None:
                return None

        on_stack_axis = full_index[axis]
        if not isinstance(on_stack_axis, slice):
            return None

        start, stop, step = on_stack_axis.indices(len(inputs))
        if step != 1:
            return None

        chosen = inputs[start:stop]
        if not chosen:
            return None

        # Index for the inputs: the full index without the stacked axis
        rest_index = full_index[:axis] + full_index[axis + 1 :]

        if any(part != slice(None) for part in rest_index):
            pieces = [new_collection(arr)[rest_index].expr for arr in chosen]
        else:
            pieces = list(chosen)

        # Compute new meta from dtype/ndim. Input metas may use different dummy
        # shapes even when the real arrays have compatible shapes.
        dtype = np.result_type(*[piece.dtype for piece in pieces])
        new_meta = meta_from_array(None, ndim=self.ndim, dtype=dtype)

        return type(self)(pieces[0], axis, new_meta, *pieces[1:])
174
+
175
+
176
def stack(seq, axis=0, allow_unknown_chunksizes=False):
    """
    Stack arrays along a new axis

    Given a sequence of dask arrays, form a new dask array by stacking them
    along a new dimension (axis=0 by default)

    Parameters
    ----------
    seq: list of dask.arrays
    axis: int
        Dimension along which to align all of the arrays. Negative values
        count from the end of the output dimensions.
    allow_unknown_chunksizes: bool
        Allow unknown chunksizes, such as come from converting from dask
        dataframes.  Dask.array is unable to verify that chunks line up.  If
        data comes from differently aligned sources then this can cause
        unexpected results.

    Raises
    ------
    ValueError
        If ``seq`` is empty, if the arrays do not all share one shape (unless
        ``allow_unknown_chunksizes`` is set), or if ``axis`` is out of range
        for the stacked result.

    Examples
    --------

    Create slices

    >>> import dask_array as da
    >>> import numpy as np

    >>> data = [da.from_array(np.ones((4, 4)), chunks=(2, 2))
    ...         for i in range(3)]

    >>> x = da.stack(data, axis=0)
    >>> x.shape
    (3, 4, 4)

    >>> da.stack(data, axis=1).shape
    (4, 3, 4)

    >>> da.stack(data, axis=-1).shape
    (4, 4, 3)

    Result is a new dask Array

    See Also
    --------
    concatenate
    """
    # Lazy import to avoid circular dependency
    from dask_array.core import asarray

    seq = [asarray(a, allow_unknown_chunksizes=allow_unknown_chunksizes) for a in seq]

    if not seq:
        raise ValueError("Need array(s) to stack")
    if not allow_unknown_chunksizes and not all(x.shape == seq[0].shape for x in seq):
        idx = first(i for i in enumerate(seq) if i[1].shape != seq[0].shape)
        raise ValueError(
            "Stacked arrays must have the same shape. The first array had shape "
            f"{seq[0].shape}, while array {idx[0] + 1} has shape {idx[1].shape}."
        )

    ndim = seq[0].ndim
    # The result has ndim + 1 dimensions, so the new axis may sit at any of
    # the positions 0..ndim (or their negative equivalents).  Reject anything
    # else up front: a too-large axis would otherwise fail with an unrelated
    # IndexError below and a too-negative one would stay negative and build a
    # malformed expression.
    if not -(ndim + 1) <= axis <= ndim:
        raise ValueError(
            f"axis={axis} is out of bounds for stacking arrays with {ndim} "
            f"dimensions; expected {-(ndim + 1)} <= axis <= {ndim}"
        )
    if axis < 0:
        axis = ndim + axis + 1

    dtype = np.result_type(*[x.dtype for x in seq])
    meta = meta_from_array(None, ndim=ndim + 1, dtype=dtype)
    seq = [x.astype(meta.dtype) for x in seq]

    # Drop zero-size inputs, unless that would leave nothing to stack.  Since
    # ``seq`` is non-empty at this point, ``seq2`` is always non-empty too.
    seq2 = [a for a in seq if a.size]
    if not seq2:
        seq2 = seq

    # Give every input the same chunking before building the expression
    ind = list(range(ndim))
    uc_args = list(concat((x.expr, ind) for x in seq2))
    _, seq2, _ = unify_chunks_expr(*uc_args)

    assert len({a.chunks for a in seq2}) == 1  # same chunks

    return new_collection(Stack(seq2[0], axis, meta, *seq2[1:]))
dask_array/_svg.py ADDED
@@ -0,0 +1,291 @@
1
+ from __future__ import annotations
2
+
3
+ import math
4
+ import re
5
+ from functools import lru_cache
6
+
7
+ import numpy as np
8
+
9
+ from dask.utils import cached_cumsum
10
+
11
+
12
@lru_cache(maxsize=512)
def svg(chunks, size=200, labels=True, **kwargs):
    """Render the chunk structure of a Dask Array as an SVG image

    Results are memoised with ``lru_cache``, so all arguments must be
    hashable.

    Parameters
    ----------
    chunks: tuple
    size: int
        Rough size of the image
    labels: bool
        Whether to include dimension labels (default True)
    **kwargs
        Passed on unchanged to the renderer chosen for the number of
        dimensions

    Returns
    -------
    text: An svg string depicting the array as a grid of chunks

    Raises
    ------
    NotImplementedError
        If any chunk size is unknown (NaN), if any dimension has length
        zero, or if there are no dimensions at all
    """
    shape = tuple(sum(dim_chunks) for dim_chunks in chunks)

    if np.isnan(shape).any():  # don't support unknown sizes
        raise NotImplementedError(
            "Can't generate SVG with unknown chunk sizes.\n\n A possible solution is with x.compute_chunk_sizes()"
        )
    if not all(shape):
        raise NotImplementedError("Can't generate SVG with 0-length dimensions")

    ndim = len(chunks)
    if ndim == 0:
        raise NotImplementedError("Can't generate SVG with 0 dimensions")

    # Pick the renderer matching the dimensionality
    if ndim == 1:
        render = svg_1d
    elif ndim == 2:
        render = svg_2d
    elif ndim == 3:
        render = svg_3d
    else:
        render = svg_nd
    return render(chunks, size=size, labels=labels, **kwargs)
45
+
46
+
47
# Modern styling constants - orange from Dask logo
# svg_grid picks one of the three fill shades according to its ``face``
# argument ("top" -> FILL_COLOR_TOP, "side" -> FILL_COLOR_DARK, otherwise
# FILL_COLOR).
STROKE_COLOR = "#78716c"  # stone-500 - grid lines
FILL_COLOR = "#fb923c"  # orange-400 - front face
FILL_COLOR_TOP = "#fdba74"  # orange-300 - top face (lighter)
FILL_COLOR_DARK = "#ea580c"  # orange-600 - side face / dense arrays
# Attribute string interpolated into the <text> label elements
TEXT_STYLE = 'font-size="1.0rem" font-weight="400" text-anchor="middle" fill="currentColor"'
53
+
54
+ # SVG definitions for shadow filter
55
+ SVG_DEFS = """ <defs>
56
+ <filter id="shadow" x="-20%" y="-20%" width="140%" height="140%">
57
+ <feDropShadow dx="2" dy="2" stdDeviation="2" flood-opacity="0.15"/>
58
+ </filter>
59
+ </defs>
60
+ """
61
+
62
+
63
def svg_2d(chunks, offset=(0, 0), skew=(0, 0), size=200, sizes=None, face=None, labels=True):
    """Draw a two-dimensional chunk grid as a complete SVG document

    Parameters
    ----------
    chunks: tuple
        Chunks of the two dimensions (rows first, then columns)
    offset: tuple
        Passed to ``svg_grid``
    skew: tuple
        Passed to ``svg_grid``
    size: int
        Rough size of the image; used to derive ``sizes`` when those are
        not given
    sizes: tuple, optional
        Pixel size per dimension; computed with ``draw_sizes`` if falsy
    face: str, optional
        Passed to ``svg_grid``
    labels: bool
        Whether to add the dimension lengths as text next to the grid

    Returns
    -------
    text: the SVG markup as a string
    """
    shape = tuple(sum(dim_chunks) for dim_chunks in chunks)
    if not sizes:
        sizes = draw_sizes(shape, size=size)
    y, x = grid_points(chunks, sizes)

    grid, (_, max_x, _, max_y) = svg_grid(x, y, offset=offset, skew=skew, size=size, face=face)

    # Leave more room around the drawing when labels are shown
    margin = 50 if labels else 10
    header = f'<svg width="{int(max_x + margin)}" height="{int(max_y + margin)}" style="stroke:{STROKE_COLOR};stroke-width:1">\n'
    header += SVG_DEFS
    footer = "\n</svg>"

    text = []
    if labels:
        # Long row counts are written vertically
        rotate = -90 if shape[0] >= 100 else 0
        text.append("")
        text.append(" <!-- Text -->")
        text.append(f' <text x="{max_x / 2}" y="{max_y + 20}" {TEXT_STYLE}>{shape[1]}</text>')
        text.append(
            f' <text x="{max_x + 20}" y="{max_y / 2}" {TEXT_STYLE} transform="rotate({rotate},{max_x + 20},{max_y / 2})">{shape[0]}</text>'
        )

    body = "\n".join(grid + text)
    return header + body + footer
92
+
93
+
94
def svg_3d(chunks, size=200, sizes=None, offset=(0, 0), labels=True):
    """Draw a three-dimensional chunk grid as a complete SVG document

    The drawing is made of three ``svg_grid`` faces: a skewed side face, a
    skewed top face and an unskewed front face.

    Parameters
    ----------
    chunks: tuple
        Chunks of the three dimensions
    size: int
        Rough size of the image; used to derive ``sizes`` when those are
        not given
    sizes: tuple, optional
        Pixel size per dimension; computed with ``draw_sizes`` if falsy
    offset: tuple
        Displacement of the whole drawing in SVG coordinates
    labels: bool
        Whether to add the dimension lengths as text next to the faces

    Returns
    -------
    text: the SVG markup as a string
    """
    shape = tuple(sum(dim_chunks) for dim_chunks in chunks)
    if not sizes:
        sizes = draw_sizes(shape, size=size)
    x, y, z = grid_points(chunks, sizes)
    ox, oy = offset

    depth = x / 1.7  # the first dimension is drawn foreshortened
    back_corner = (ox + 10, oy + 0)

    # Left face (side) - darker
    side, (side_min_x, side_max_x, _, side_max_y) = svg_grid(
        depth, y, offset=back_corner, skew=(1, 0), size=size, face="side"
    )

    # Top face - lighter
    top, (_, _, _, top_extent) = svg_grid(z, depth, offset=back_corner, skew=(0, 1), size=size, face="top")

    # Front face - normal, shifted past the top face
    front, (min_z, max_z, min_y, max_y) = svg_grid(
        z, y, offset=(ox + top_extent + 10, oy + top_extent), skew=(0, 0), size=size, face="front"
    )

    margin = 50 if labels else 10
    header = f'<svg width="{int(max_z + margin)}" height="{int(max_y + margin)}" style="stroke:{STROKE_COLOR};stroke-width:1">\n'
    header += SVG_DEFS
    footer = "\n</svg>"

    text = []
    if labels:
        # Long lengths on the vertical edge are written vertically
        rotate = -90 if shape[1] >= 100 else 0

        front_mid_x = (min_z + max_z) / 2
        front_mid_y = (min_y + max_y) / 2
        side_label_x = (side_min_x + side_max_x) / 2 - 10
        side_label_y = side_max_y - (side_max_x - side_min_x) / 2 + 20

        text = [
            "",
            " <!-- Text -->",
            f' <text x="{front_mid_x}" y="{max_y + 20}" {TEXT_STYLE}>{shape[2]}</text>',
            f' <text x="{max_z + 20}" y="{front_mid_y}" {TEXT_STYLE} transform="rotate({rotate},{max_z + 20},{front_mid_y})">{shape[1]}</text>',
            f' <text x="{side_label_x}" y="{side_label_y}" {TEXT_STYLE} transform="rotate(45,{side_label_x},{side_label_y})">{shape[0]}</text>',
        ]

    body = "\n".join(side + top + front + text)
    return header + body + footer
133
+
134
+
135
def svg_nd(chunks, size=200, labels=True):
    """Draw an array with more than three dimensions as a row of panels

    The dimensions are split into groups of two or three, each group is
    rendered with ``svg`` and the panels are placed side by side in one SVG
    document.

    Parameters
    ----------
    chunks: tuple
    size: int
        Rough size of the image, passed to ``draw_sizes``
    labels: bool
        Whether each panel includes dimension labels

    Returns
    -------
    text: the SVG markup as a string
    """
    # Pad with a dummy dimension so that no group ends up one-dimensional
    if len(chunks) % 3 == 1:
        chunks = ((1,),) + chunks
    shape = tuple(sum(dim_chunks) for dim_chunks in chunks)
    sizes = draw_sizes(shape, size=size)

    pending_chunks, pending_sizes = chunks, sizes
    panels = []
    left = 0
    total_height = 0
    while pending_chunks:
        n = len(pending_chunks) % 3 or 3
        panel = svg(pending_chunks[:n], sizes=pending_sizes[:n], offset=(left, 0), labels=labels)
        pending_chunks, pending_sizes = pending_chunks[n:], pending_sizes[n:]

        panel_lines = panel.split("\n")
        # The panel's own <svg ...> header carries its width and height
        panel_header = panel_lines[0]
        height = float(re.search(r'height="(\d*\.?\d*)"', panel_header).group(1))
        width = float(re.search(r'width="(\d*\.?\d*)"', panel_header).group(1))
        total_height = max(total_height, height)
        left += width + 10

        # remove header and footer
        panels.append("\n".join(panel_lines[1:-1]))

    header = f'<svg width="{int(left)}" height="{int(total_height)}" style="stroke:{STROKE_COLOR};stroke-width:1">\n'
    header += SVG_DEFS
    footer = "\n</svg>"
    return header + "\n\n".join(panels) + footer
166
+
167
+
168
def svg_lines(x1, y1, x2, y2, max_n=20):
    """Convert points into lines of text for an SVG plot

    Line ``i`` runs from ``(x1[i], y1[i])`` to ``(x2[i], y2[i])``;
    coordinates are truncated to integers.  When there are more than
    ``max_n`` lines only ``max_n`` evenly spaced ones are emitted.  The
    first and the last emitted line get a thicker stroke.

    Parameters
    ----------
    x1, y1, x2, y2: sequences of numbers of equal length
    max_n: int
        Maximum number of lines to emit

    Returns
    -------
    lines: list of str
        One ``<line>`` element per emitted line; empty when there is
        nothing to draw

    Examples
    --------
    >>> svg_lines([0, 1], [0, 0], [10, 11], [1, 1])  # doctest: +NORMALIZE_WHITESPACE
    [' <line x1="0" y1="0" x2="10" y2="1" style="stroke-width:2" />',
     ' <line x1="1" y1="0" x2="11" y2="1" style="stroke-width:2" />']
    """
    n = len(x1)

    if n > max_n:
        indices = np.linspace(0, n - 1, max_n, dtype="int")
    else:
        indices = range(n)

    lines = [
        f' <line x1="{int(x1[i])}" y1="{int(y1[i])}" x2="{int(x2[i])}" y2="{int(y2[i])}" />' for i in indices
    ]

    # Nothing to emphasise (and nothing to index) when no line was emitted
    if not lines:
        return lines

    # Thicken the two outermost lines.  Using a set of positions means a
    # lone line (first == last) is styled once instead of receiving the
    # ``style`` attribute twice.
    for pos in {0, len(lines) - 1}:
        lines[pos] = lines[pos].replace(" /", ' style="stroke-width:2" /')
    return lines
192
+
193
+
194
def svg_grid(x, y, offset=(0, 0), skew=(0, 0), size=200, face=None):
    """Create lines of SVG text that show a grid

    Parameters
    ----------
    x: numpy.ndarray
        Positions of the vertical grid lines
    y: numpy.ndarray
        Positions of the horizontal grid lines
    offset: tuple
        translational displacement of the grid in SVG coordinates
    skew: tuple
        Shear factors; the first shifts y in proportion to x, the second
        shifts x in proportion to y
    size: int
        At most ``size // 6`` lines are drawn per direction
    face: str, optional
        For 3D: "top", "front", or "side" to determine shading

    Returns
    -------
    lines: list of str
        Horizontal lines, vertical lines and the filled polygon
    bounds: tuple
        ``(min_x, max_x, min_y, max_y)`` computed from the end points of
        the horizontal lines
    """
    off_x, off_y = offset[0], offset[1]
    shear_y, shear_x = skew[0], skew[1]
    max_n = size // 6

    # Horizontal lines: one per entry of y, spanning the full x extent
    hx1 = np.zeros_like(y) + off_x
    hy1 = y + off_y
    hx2 = np.full_like(y, x[-1]) + off_x
    hy2 = y + off_y

    if shear_y:
        hy2 += x.max() * shear_y
    if shear_x:
        hx1 += shear_x * y
        hx2 += shear_x * y

    min_x = min(hx1.min(), hx2.min())
    max_x = max(hx1.max(), hx2.max())
    min_y = min(hy1.min(), hy2.min())
    max_y = max(hy1.max(), hy2.max())

    h_lines = ["", " <!-- Horizontal lines -->"] + svg_lines(hx1, hy1, hx2, hy2, max_n)

    # Vertical lines: one per entry of x, spanning the full y extent
    vx1 = x + off_x
    vy1 = np.zeros_like(x) + off_y
    vx2 = x + off_x
    vy2 = np.full_like(x, y[-1]) + off_y

    if shear_y:
        vy1 += shear_y * x
        vy2 += shear_y * x
    if shear_x:
        vx2 += shear_x * y.max()

    v_lines = ["", " <!-- Vertical lines -->"] + svg_lines(vx1, vy1, vx2, vy2, max_n)

    # Determine fill color based on face
    color = FILL_COLOR_TOP if face == "top" else FILL_COLOR_DARK if face == "side" else FILL_COLOR

    # The outline joins the end points of the first and last vertical line
    corners = f"{vx1[0]},{vy1[0]} {vx1[-1]},{vy1[-1]} {vx2[-1]},{vy2[-1]} {vx2[0]},{vy2[0]}"
    rect = [
        "",
        " <!-- Colored Rectangle -->",
        f' <polygon points="{corners}" style="fill:{color};fill-opacity:0.7;stroke-width:0" filter="url(#shadow)"/>',
    ]

    return h_lines + v_lines + rect, (min_x, max_x, min_y, max_y)
257
+
258
+
259
def svg_1d(chunks, sizes=None, labels=True, **kwargs):
    """Draw a one-dimensional array as a 2D grid with a single row

    A leading dimension consisting of one chunk of length 1 is prepended
    and the result is rendered by ``svg_2d``.  ``sizes`` is accepted but
    not forwarded.
    """
    as_row = ((1,),) + chunks
    return svg_2d(as_row, labels=labels, **kwargs)
261
+
262
+
263
def grid_points(chunks, sizes):
    """Return the chunk boundary positions of every dimension

    For each dimension the cumulative chunk sums (starting at zero) are
    rescaled so that the last boundary equals that dimension's entry in
    ``sizes``.
    """
    boundaries = [np.array(cached_cumsum(dim_chunks, initial_zero=True)) for dim_chunks in chunks]
    points = []
    for edges, extent in zip(boundaries, sizes):
        points.append(edges * extent / edges[-1])
    return points
267
+
268
+
269
def draw_sizes(shape, size=200):
    """Get size in pixels for all dimensions

    The largest dimension is drawn at ``size``; the others are shrunk by
    the ratio to the largest dimension after it has been damped with
    ``ratio_response``.
    """
    largest = max(shape)
    # Lengths below 0.1 are clamped so that the ratio stays finite
    raw_ratios = [largest / max(0.1, length) for length in shape]
    return tuple(size / ratio_response(ratio) for ratio in raw_ratios)
275
+
276
+
277
def ratio_response(x):
    """Damp a size ratio for display

    Real size ratios can span several orders of magnitude, which is hard
    to perceive in a drawing.

    Ratios below ``e`` are returned unchanged, ratios up to 100 are
    compressed logarithmically, and anything larger is treated as 100.
    """
    if x < math.e:
        return x
    # Beyond 100 the response stops growing
    capped = x if x <= 100 else 100
    return math.log(capped + 12.4)  # f(e) == e
dask_array/_templates.py ADDED
@@ -0,0 +1,29 @@
1
+ """Template loading for array HTML representations."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import os.path
6
+
7
+ from jinja2 import Environment, FileSystemLoader, Template
8
+ from jinja2.exceptions import TemplateNotFound
9
+
10
+ from dask.utils import typename
11
+
12
# Extra Jinja2 filters that get_template registers on its environment
FILTERS = {
    "type": type,
    "typename": typename,
}

# Absolute path of the "templates" directory located next to this module
TEMPLATE_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "templates")
18
+
19
+
20
def get_template(name: str) -> Template:
    """Load a Jinja2 template from the templates directory.

    A fresh environment is created on every call, reading from
    ``TEMPLATE_PATH`` and with the filters in ``FILTERS`` registered.

    Parameters
    ----------
    name: str
        File name of the template, relative to ``TEMPLATE_PATH``

    Returns
    -------
    template: jinja2.Template

    Raises
    ------
    jinja2.exceptions.TemplateNotFound
        If no template called ``name`` exists in ``TEMPLATE_PATH``
    """
    loader = FileSystemLoader([TEMPLATE_PATH])
    environment = Environment(loader=loader)
    environment.filters.update(FILTERS)

    try:
        return environment.get_template(name)
    except TemplateNotFound as e:
        # TemplateNotFound takes the template name first and the text to
        # display as ``message``.  Passing both keeps ``str(exc)`` as the
        # explanatory sentence while ``exc.name`` and ``exc.templates``
        # still identify the template that was requested.
        raise TemplateNotFound(name, message=f"Unable to find {name} in {TEMPLATE_PATH}") from e