dask-array 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dask_array/__init__.py +228 -0
- dask_array/_backends.py +76 -0
- dask_array/_backends_array.py +99 -0
- dask_array/_blockwise.py +1410 -0
- dask_array/_broadcast.py +272 -0
- dask_array/_chunk.py +445 -0
- dask_array/_chunk_types.py +54 -0
- dask_array/_collection.py +1644 -0
- dask_array/_concatenate.py +331 -0
- dask_array/_core_utils.py +1365 -0
- dask_array/_dispatch.py +141 -0
- dask_array/_einsum.py +277 -0
- dask_array/_expr.py +544 -0
- dask_array/_expr_flow.py +586 -0
- dask_array/_gufunc.py +805 -0
- dask_array/_histogram.py +617 -0
- dask_array/_map_blocks.py +652 -0
- dask_array/_new_collection.py +10 -0
- dask_array/_numpy_compat.py +135 -0
- dask_array/_overlap.py +1159 -0
- dask_array/_rechunk.py +1050 -0
- dask_array/_reshape.py +710 -0
- dask_array/_routines.py +102 -0
- dask_array/_shuffle.py +448 -0
- dask_array/_stack.py +264 -0
- dask_array/_svg.py +291 -0
- dask_array/_templates.py +29 -0
- dask_array/_test_utils.py +257 -0
- dask_array/_ufunc.py +385 -0
- dask_array/_utils.py +349 -0
- dask_array/_visualize.py +223 -0
- dask_array/_xarray.py +337 -0
- dask_array/core/__init__.py +34 -0
- dask_array/core/_blockwise_funcs.py +312 -0
- dask_array/core/_conversion.py +422 -0
- dask_array/core/_from_graph.py +97 -0
- dask_array/creation/__init__.py +71 -0
- dask_array/creation/_arange.py +121 -0
- dask_array/creation/_diag.py +116 -0
- dask_array/creation/_diagonal.py +241 -0
- dask_array/creation/_eye.py +103 -0
- dask_array/creation/_linspace.py +102 -0
- dask_array/creation/_mesh.py +134 -0
- dask_array/creation/_ones_zeros.py +454 -0
- dask_array/creation/_pad.py +270 -0
- dask_array/creation/_repeat.py +55 -0
- dask_array/creation/_tile.py +36 -0
- dask_array/creation/_tri.py +28 -0
- dask_array/creation/_utils.py +296 -0
- dask_array/fft.py +320 -0
- dask_array/io/__init__.py +39 -0
- dask_array/io/_base.py +10 -0
- dask_array/io/_from_array.py +257 -0
- dask_array/io/_from_delayed.py +95 -0
- dask_array/io/_from_graph.py +54 -0
- dask_array/io/_from_npy_stack.py +67 -0
- dask_array/io/_store.py +336 -0
- dask_array/io/_tiledb.py +159 -0
- dask_array/io/_to_npy_stack.py +65 -0
- dask_array/io/_zarr.py +449 -0
- dask_array/linalg/__init__.py +39 -0
- dask_array/linalg/_cholesky.py +234 -0
- dask_array/linalg/_lu.py +300 -0
- dask_array/linalg/_norm.py +94 -0
- dask_array/linalg/_qr.py +601 -0
- dask_array/linalg/_solve.py +349 -0
- dask_array/linalg/_svd.py +394 -0
- dask_array/linalg/_tensordot.py +334 -0
- dask_array/linalg/_utils.py +74 -0
- dask_array/manipulation/__init__.py +45 -0
- dask_array/manipulation/_expand.py +321 -0
- dask_array/manipulation/_flip.py +92 -0
- dask_array/manipulation/_roll.py +78 -0
- dask_array/manipulation/_transpose.py +309 -0
- dask_array/random/__init__.py +125 -0
- dask_array/random/_choice.py +181 -0
- dask_array/random/_expr.py +256 -0
- dask_array/random/_generator.py +441 -0
- dask_array/random/_random_state.py +259 -0
- dask_array/random/_utils.py +84 -0
- dask_array/reductions/__init__.py +84 -0
- dask_array/reductions/_arg_reduction.py +130 -0
- dask_array/reductions/_common.py +1082 -0
- dask_array/reductions/_cumulative.py +522 -0
- dask_array/reductions/_percentile.py +261 -0
- dask_array/reductions/_reduction.py +725 -0
- dask_array/reductions/_trace.py +56 -0
- dask_array/routines/__init__.py +133 -0
- dask_array/routines/_apply.py +84 -0
- dask_array/routines/_bincount.py +112 -0
- dask_array/routines/_broadcast.py +111 -0
- dask_array/routines/_coarsen.py +115 -0
- dask_array/routines/_diff.py +79 -0
- dask_array/routines/_gradient.py +158 -0
- dask_array/routines/_indexing.py +65 -0
- dask_array/routines/_insert_delete.py +132 -0
- dask_array/routines/_misc.py +122 -0
- dask_array/routines/_nonzero.py +72 -0
- dask_array/routines/_search.py +123 -0
- dask_array/routines/_select.py +113 -0
- dask_array/routines/_statistics.py +171 -0
- dask_array/routines/_topk.py +82 -0
- dask_array/routines/_triangular.py +74 -0
- dask_array/routines/_unique.py +232 -0
- dask_array/routines/_where.py +62 -0
- dask_array/slicing/__init__.py +67 -0
- dask_array/slicing/_basic.py +550 -0
- dask_array/slicing/_blocks.py +138 -0
- dask_array/slicing/_bool_index.py +145 -0
- dask_array/slicing/_setitem.py +329 -0
- dask_array/slicing/_squeeze.py +101 -0
- dask_array/slicing/_utils.py +1133 -0
- dask_array/slicing/_vindex.py +282 -0
- dask_array/stacking/__init__.py +15 -0
- dask_array/stacking/_block.py +83 -0
- dask_array/stacking/_simple.py +58 -0
- dask_array/templates/array.html.j2 +48 -0
- dask_array/tests/__init__.py +0 -0
- dask_array/tests/conftest.py +22 -0
- dask_array/tests/test_api.py +40 -0
- dask_array/tests/test_binary_op_chunks.py +107 -0
- dask_array/tests/test_coarse_slice_through_blockwise.py +362 -0
- dask_array/tests/test_collection.py +799 -0
- dask_array/tests/test_creation.py +1102 -0
- dask_array/tests/test_expr_flow.py +143 -0
- dask_array/tests/test_linalg.py +1130 -0
- dask_array/tests/test_map_blocks_multi_output.py +104 -0
- dask_array/tests/test_rechunk_pushdown.py +214 -0
- dask_array/tests/test_reductions.py +1091 -0
- dask_array/tests/test_routines.py +2853 -0
- dask_array/tests/test_shuffle_chunks.py +67 -0
- dask_array/tests/test_slice_pushdown.py +968 -0
- dask_array/tests/test_slice_through_blockwise.py +678 -0
- dask_array/tests/test_slice_through_overlap.py +366 -0
- dask_array/tests/test_slice_through_reshape.py +272 -0
- dask_array/tests/test_slicing.py +839 -0
- dask_array/tests/test_transpose_slice_pushdown.py +208 -0
- dask_array/tests/test_visualize.py +94 -0
- dask_array/tests/test_xarray.py +193 -0
- dask_array-0.1.0.dist-info/METADATA +48 -0
- dask_array-0.1.0.dist-info/RECORD +144 -0
- dask_array-0.1.0.dist-info/WHEEL +4 -0
- dask_array-0.1.0.dist-info/entry_points.txt +2 -0
- dask_array-0.1.0.dist-info/licenses/LICENSE +29 -0
|
@@ -0,0 +1,208 @@
|
|
|
1
|
+
"""Tests for pushing integer slices through transpose operations."""
|
|
2
|
+
|
|
3
|
+
import pytest
|
|
4
|
+
|
|
5
|
+
import dask_array as da
|
|
6
|
+
from dask_array._test_utils import assert_eq
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def test_transpose_integer_slice_2d():
    """Indexing row 0 of a 2D transpose collapses to a plain column slice.

    ``x.T[0]`` and ``x[:, 0]`` must optimize to expressions with the same name.
    """
    source = da.ones((3, 4), chunks=2)

    # The slice written on top of the transpose, and the hand-written form
    # that never transposes.
    via_transpose = source.T[0]
    direct = source[:, 0]

    # After optimization both expressions must carry the same name.
    optimized_name = via_transpose.expr.optimize()._name
    assert optimized_name == direct.expr.optimize()._name

    # The values must agree as well.
    assert_eq(via_transpose, direct)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def test_transpose_integer_slice_scalar():
    """Picking a single element of a 2D transpose needs no transpose at all.

    ``x.T[0, 0]`` and ``x[0, 0]`` must optimize to expressions with the same name.
    """
    source = da.ones((3, 4), chunks=2)

    # Element access through the transpose, and directly on the source.
    via_transpose = source.T[0, 0]
    direct = source[0, 0]

    # After optimization both expressions must carry the same name.
    optimized_name = via_transpose.expr.optimize()._name
    assert optimized_name == direct.expr.optimize()._name

    # The values must agree as well.
    assert_eq(via_transpose, direct)
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def test_transpose_integer_slice_3d():
    """An integer index on a 3D transpose is pushed below the transpose.

    ``x.T[0]`` must optimize to the same expression as ``x[:, :, 0].T``.
    """
    source = da.ones((2, 3, 4), chunks=2)

    # source.T has shape (4, 3, 2); indexing its first axis leaves (3, 2).
    via_transpose = source.T[0]

    # Hand-written equivalent: drop the last axis of the source first
    # (shape (2, 3)), then transpose what remains to (3, 2).
    slice_then_transpose = source[:, :, 0].T

    # After optimization both expressions must carry the same name.
    optimized_name = via_transpose.expr.optimize()._name
    assert optimized_name == slice_then_transpose.expr.optimize()._name

    # The values must agree as well.
    assert_eq(via_transpose, slice_then_transpose)
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def test_transpose_mixed_slice_integer():
    """An integer index mixed with full slices is pushed through a transpose.

    ``x.T[:, 0, :]`` must optimize to the same expression as
    ``x[:, 0, :].transpose((1, 0))``.
    """
    source = da.ones((3, 4, 5), chunks=2)

    # source.T has shape (5, 4, 3). The integer lands on its middle axis,
    # which is also axis 1 of the source, so the result has shape (5, 3).
    via_transpose = source.T[:, 0, :]

    # Hand-written equivalent: index axis 1 of the source (shape (3, 5)),
    # then swap the two remaining axes to get (5, 3).
    slice_then_transpose = source[:, 0, :].transpose((1, 0))

    # After optimization both expressions must carry the same name.
    optimized_name = via_transpose.expr.optimize()._name
    assert optimized_name == slice_then_transpose.expr.optimize()._name

    # The values must agree as well.
    assert_eq(via_transpose, slice_then_transpose)
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def test_transpose_custom_axes_integer_slice():
    """An integer index after a non-default axes permutation removes the transpose.

    ``x.transpose((2, 0, 1))[0]`` must optimize to the same expression as
    ``x[:, :, 0]``.
    """
    source = da.ones((2, 3, 4), chunks=2)

    # The permuted array has shape (4, 2, 3); indexing its first axis
    # leaves shape (2, 3).
    via_transpose = source.transpose((2, 0, 1))[0]

    # With axes (2, 0, 1): out[0] = in[2], out[1] = in[0], out[2] = in[1].
    # Indexing out[0] consumes in[2]; the surviving axes in[0], in[1] are
    # already in order, so after renumbering the permutation is the identity
    # and a plain slice of the source is all that should remain.
    direct = source[:, :, 0]

    # After optimization both expressions must carry the same name.
    optimized_name = via_transpose.expr.optimize()._name
    assert optimized_name == direct.expr.optimize()._name

    # The values must agree as well.
    assert_eq(via_transpose, direct)
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
def test_transpose_slice_task_count():
    """Pushing a slice through a transpose must shrink the task graph."""
    from dask_array._collection import Array

    source = da.ones((4, 6), chunks=2)
    sliced = source.T[0]

    # Graph of the collection as written: slice(transpose(ones)).
    n_unoptimized = len(dict(sliced.__dask_graph__()))

    # Graph rebuilt from the optimized expression: slice(ones), no transpose.
    optimized_collection = Array(sliced.expr.optimize())
    n_optimized = len(dict(optimized_collection.__dask_graph__()))

    # Dropping the transpose layer has to remove tasks.
    assert n_optimized < n_unoptimized, (
        f"Optimized graph should be smaller: {n_optimized} vs {n_unoptimized}"
    )

    # Exact budgets:
    #   unoptimized = ones(6) + transpose(6) + getitem(2) = 14
    #   optimized   = ones(6) + getitem(2)                = 8
    assert n_optimized == 8, f"Expected 8 tasks (6 ones + 2 getitem), got {n_optimized}"
    assert n_unoptimized == 14, f"Expected 14 unoptimized tasks, got {n_unoptimized}"
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
# --- Transpose through Elemwise Tests ---
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
def test_transpose_pushes_through_elemwise_add():
    """Transposing a sum simplifies to the sum of the transposes.

    ``(x + y).T`` must simplify to the same expression as ``x.T + y.T``.
    """
    lhs = da.ones((10, 5), chunks=5)
    rhs = da.ones((10, 5), chunks=5)

    transposed_sum = (lhs + rhs).T
    sum_of_transposes = lhs.T + rhs.T

    # After simplification both expressions must carry the same name.
    simplified_name = transposed_sum.expr.simplify()._name
    assert simplified_name == sum_of_transposes.expr.simplify()._name

    # The values must agree as well.
    assert_eq(transposed_sum, sum_of_transposes)
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
def test_transpose_pushes_through_elemwise_mul():
    """Transposing a product simplifies to the product of the transposes.

    ``(x * y).T`` must simplify to the same expression as ``x.T * y.T``.
    """
    lhs = da.ones((6, 4), chunks=2)
    rhs = da.ones((6, 4), chunks=2)

    transposed_product = (lhs * rhs).T
    product_of_transposes = lhs.T * rhs.T

    # Structure first, then values.
    simplified_name = transposed_product.expr.simplify()._name
    assert simplified_name == product_of_transposes.expr.simplify()._name
    assert_eq(transposed_product, product_of_transposes)
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
def test_transpose_through_elemwise_broadcasting_no_pushdown():
    """A transpose stays outermost when the elemwise operands differ in ndim."""
    import numpy as np

    from dask_array.manipulation._transpose import Transpose

    matrix = da.ones((6, 4), chunks=2)
    vector = da.ones((4,), chunks=2)  # 1-D operand, broadcast against the matrix

    # (6, 4) + (4,) -> (6, 4); transposing gives (4, 6).
    transposed_sum = (matrix + vector).T

    # Broadcasting cases are not pushed through, so the root of the
    # simplified expression is still the Transpose node.
    simplified = transposed_sum.expr.simplify()
    assert isinstance(simplified, Transpose)

    # The computed values must still match the NumPy reference.
    reference = (np.ones((6, 4)) + np.ones((4,))).T
    assert_eq(transposed_sum, reference)
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
def test_transpose_pushes_through_elemwise_scalar():
    """Transposing ``x + 1`` simplifies to the same expression as ``x.T + 1``."""
    source = da.ones((5, 3), chunks=2)

    transposed_sum = (source + 1).T
    shifted_transpose = source.T + 1

    # Structure first, then values.
    simplified_name = transposed_sum.expr.simplify()._name
    assert simplified_name == shifted_transpose.expr.simplify()._name
    assert_eq(transposed_sum, shifted_transpose)
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
def test_transpose_pushes_through_unary_elemwise():
    """Transposing ``-x`` simplifies to the same expression as ``-(x.T)``."""
    source = da.ones((4, 6), chunks=2)

    transposed_negation = (-source).T
    negated_transpose = -(source.T)

    # Structure first, then values.
    simplified_name = transposed_negation.expr.simplify()._name
    assert simplified_name == negated_transpose.expr.simplify()._name
    assert_eq(transposed_negation, negated_transpose)
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
def test_transpose_custom_axes_through_elemwise():
    """A non-default axes permutation is pushed through an elemwise sum."""
    lhs = da.ones((2, 3, 4), chunks=2)
    rhs = da.ones((2, 3, 4), chunks=2)

    permuted_sum = (lhs + rhs).transpose((2, 0, 1))
    sum_of_permuted = lhs.transpose((2, 0, 1)) + rhs.transpose((2, 0, 1))

    # Structure first, then values.
    simplified_name = permuted_sum.expr.simplify()._name
    assert simplified_name == sum_of_permuted.expr.simplify()._name
    assert_eq(permuted_sum, sum_of_permuted)
|
|
@@ -0,0 +1,94 @@
|
|
|
1
|
+
"""Tests for expression visualization."""
|
|
2
|
+
|
|
3
|
+
import dask_array as da
|
|
4
|
+
from dask_array._visualize import expr_table
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def test_expr_table_contains_shapes():
    """The rendered table lists both the input shape and the scalar result shape."""
    source = da.ones((10, 100), chunks=(5, 50))
    total = source.sum()

    rendered = str(expr_table(total.expr))

    # Shape of the source array.
    assert "(10, 100)" in rendered
    # Shape of the fully reduced (scalar) output.
    assert "()" in rendered
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def test_expr_table_contains_bytes():
    """The rendered table reports byte sizes for the input and the scalar result."""
    source = da.ones((10, 100), chunks=(5, 50))
    total = source.sum()

    rendered = str(expr_table(total.expr))

    # Input: 10 * 100 float64 values = 8000 bytes, shown as 7.8 kiB.
    assert "kiB" in rendered
    # Output: one float64 scalar = 8 bytes.
    assert "8 B" in rendered
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def test_expr_table_contains_operation_names():
    """The rendered table names the operations that make up the expression."""
    source = da.ones((10, 10), chunks=5)
    shifted = source + 1
    total = shifted.sum()

    rendered = str(expr_table(total.expr))

    assert "Reduction" in rendered
    assert "Ones" in rendered
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def test_expr_table_styling_emphasis():
    """Small rows are dimmed and large rows are left emphasized in the table.

    Row 0 is the scalar reduction, row 1 is the large source array.
    """
    source = da.ones((100, 100), chunks=50)
    total = source.sum()  # collapses to a scalar

    # Reach into the wrapped rich table to inspect per-cell styling.
    columns = expr_table(total.expr)._table.columns
    op_cells = columns[0]._cells
    shape_cells = columns[1]._cells
    size_cells = columns[2]._cells

    # Row 0 (scalar reduction): tiny next to the source, so its shape and
    # byte cells are dimmed while the operation name stays bold.
    assert "bold" in op_cells[0].spans[0].style
    assert shape_cells[0].style == "dim"
    assert size_cells[0].style == "dim"

    # Row 1 (large source): nothing is dimmed and the name is still bold.
    assert any("bold" in span.style for span in op_cells[1].spans)
    assert shape_cells[1].style is None
    assert size_cells[1].style is None
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def test_expr_table_html_output():
    """The table's ``_repr_html_`` yields preformatted HTML naming the operations."""
    source = da.ones((10, 10), chunks=5)
    total = source.sum()

    markup = expr_table(total.expr)._repr_html_()

    assert "<pre>" in markup
    assert "Reduction" in markup
    assert "Ones" in markup
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def test_expr_repr_html():
    """An array expression exposes ``_repr_html_`` for Jupyter display."""
    source = da.ones((10, 10), chunks=5)
    total = source.sum()

    # Called on the expression itself, not on a table built from it.
    markup = total.expr._repr_html_()

    assert "<pre>" in markup
    assert "Reduction" in markup
|
|
@@ -0,0 +1,193 @@
|
|
|
1
|
+
"""Tests for xarray ChunkManager integration.
|
|
2
|
+
|
|
3
|
+
This module tests the DaskArrayExprManager which registers as the "dask"
|
|
4
|
+
chunk manager, replacing xarray's built-in DaskManager. This allows it to
|
|
5
|
+
handle dask_array.Array types.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
import pytest
|
|
10
|
+
|
|
11
|
+
xr = pytest.importorskip("xarray")
|
|
12
|
+
|
|
13
|
+
import dask_array as da
|
|
14
|
+
from dask_array._xarray import DaskArrayExprManager
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class TestDaskArrayExprManager:
    """Unit tests that call DaskArrayExprManager methods directly."""

    def test_init(self):
        """A fresh manager targets ``da.Array`` and reports itself available."""
        mgr = DaskArrayExprManager()
        assert mgr.array_cls is da.Array
        assert mgr.available is True

    def test_is_chunked_array(self):
        """``da.Array`` is recognised as chunked; a NumPy array is not."""
        mgr = DaskArrayExprManager()
        chunked = da.ones((10, 10), chunks=(5, 5))
        assert mgr.is_chunked_array(chunked)
        assert not mgr.is_chunked_array(np.ones((10, 10)))

    def test_is_chunked_array_legacy_dask(self):
        """A legacy ``dask.array.Array`` is not claimed by this manager."""
        import dask.array as legacy_da

        mgr = DaskArrayExprManager()
        legacy = legacy_da.ones((10, 10), chunks=(5, 5))
        assert not mgr.is_chunked_array(legacy)

    def test_chunks(self):
        """``chunks`` returns the per-axis chunk sizes of the array."""
        mgr = DaskArrayExprManager()
        chunked = da.ones((10, 10), chunks=(5, 5))
        assert mgr.chunks(chunked) == ((5, 5), (5, 5))

    def test_normalize_chunks(self):
        """A per-axis chunk size is expanded into explicit chunk tuples."""
        mgr = DaskArrayExprManager()
        normalized = mgr.normalize_chunks((5, 5), shape=(10, 10))
        assert normalized == ((5, 5), (5, 5))

    def test_from_array(self):
        """``from_array`` wraps a NumPy array as a chunked ``da.Array``."""
        mgr = DaskArrayExprManager()
        data = np.arange(100).reshape(10, 10)
        wrapped = mgr.from_array(data, chunks=(5, 5))
        assert isinstance(wrapped, da.Array)
        assert wrapped.chunks == ((5, 5), (5, 5))

    def test_from_array_lazy_indexing_adapter_uses_numpy_meta(self):
        """Wrapping an xarray lazy-indexing adapter yields NumPy ``_meta``."""
        from xarray.core import indexing

        mgr = DaskArrayExprManager()

        # Build the adapter stack: NumPy data, lazily indexed with a
        # full-slice outer indexer, behind an implicit-to-explicit adapter.
        backing = indexing.NumpyIndexingAdapter(np.ones((2, 3)))
        full_slices = indexing.OuterIndexer((slice(None), slice(None)))
        lazily_indexed = indexing.LazilyIndexedArray(backing, full_slices)
        adapter = indexing.ImplicitToExplicitIndexingAdapter(
            lazily_indexed, indexing.OuterIndexer
        )

        wrapped = mgr.from_array(adapter, chunks=(1, 3))
        scaled = wrapped * 1.0

        # Meta must be a real ndarray both on the wrapped array and after
        # an arithmetic operation, with an empty 2D shape on the result.
        assert isinstance(wrapped.expr._meta, np.ndarray)
        assert isinstance(scaled.expr._meta, np.ndarray)
        assert scaled.expr._meta.shape == (0, 0)

    def test_compute(self):
        """``compute`` returns a one-element sequence holding the values."""
        mgr = DaskArrayExprManager()
        chunked = da.ones((10, 10), chunks=(5, 5))
        [computed] = mgr.compute(chunked)
        np.testing.assert_array_equal(computed, np.ones((10, 10)))

    def test_array_api(self):
        """The exposed array API namespace provides basic creation routines."""
        mgr = DaskArrayExprManager()
        namespace = mgr.array_api
        assert hasattr(namespace, "ones")
        assert hasattr(namespace, "zeros")
        assert hasattr(namespace, "full_like")
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
class TestXarrayIntegration:
    """End-to-end tests that use dask_array.Array through xarray objects."""

    def test_manager_discoverable(self):
        """xarray lists this manager under the name 'dask'."""
        from xarray.namedarray.parallelcompat import list_chunkmanagers

        registered = list_chunkmanagers()
        assert "dask" in registered
        # The 'dask' entry must be this manager, not xarray's built-in one.
        assert isinstance(registered["dask"], DaskArrayExprManager)

    def test_get_chunked_array_type_selects_manager_once(self):
        """xarray resolves a dask_array.Array to this chunk manager."""
        from xarray.namedarray.parallelcompat import get_chunked_array_type

        chunked = da.ones((10, 10), chunks=(5, 5))

        selected = get_chunked_array_type(chunked)
        assert isinstance(selected, DaskArrayExprManager)

    def test_dask_new_collection_roundtrip(self):
        """Dask's ``new_collection`` rebuilds a dask_array.Array from its expr."""
        from dask._collections import new_collection

        original = da.arange(6, chunks=(3,)) + 1

        rebuilt = new_collection(original.expr)

        assert isinstance(rebuilt, da.Array)
        np.testing.assert_array_equal(rebuilt.compute(), np.arange(6) + 1)

    def test_dataarray_from_dask_array(self):
        """A DataArray built on a dask_array.Array keeps its shape and chunks."""
        chunked = da.ones((10, 20), chunks=(5, 10))
        wrapped = xr.DataArray(chunked, dims=["x", "y"])

        assert wrapped.shape == (10, 20)
        assert wrapped.chunks == ((5, 5), (10, 10))

    def test_dataarray_compute(self):
        """Computing a DataArray backed by dask_array.Array gives the values."""
        chunked = da.arange(100, chunks=25).reshape(10, 10)
        wrapped = xr.DataArray(chunked, dims=["x", "y"])

        computed = wrapped.compute()

        reference = np.arange(100).reshape(10, 10)
        np.testing.assert_array_equal(computed.values, reference)

    def test_dataarray_operations(self):
        """Arithmetic and reductions on a wrapped DataArray compute correctly."""
        chunked = da.ones((10, 20), chunks=(5, 10))
        wrapped = xr.DataArray(chunked, dims=["x", "y"])

        # Arithmetic: ones + 1 is 2.0 everywhere.
        shifted = (wrapped + 1).compute()
        np.testing.assert_array_equal(shifted.values, np.full((10, 20), 2.0))

        # Reduction: 10 * 20 ones sum to 200.
        total = wrapped.sum().compute()
        assert total.values == 200.0

    def test_dataarray_rechunk(self):
        """``DataArray.chunk`` rechunks the underlying array per dimension."""
        chunked = da.ones((10, 20), chunks=(5, 10))
        wrapped = xr.DataArray(chunked, dims=["x", "y"])

        rechunked = wrapped.chunk({"x": 10, "y": 5})

        assert rechunked.chunks == ((10,), (5, 5, 5, 5))

    def test_dataset_from_dask_arrays(self):
        """A Dataset holding several dask_array.Arrays builds and computes."""
        ones_var = da.ones((10, 20), chunks=(5, 10))
        zeros_var = da.zeros((10, 20), chunks=(5, 10))

        variables = {
            "var1": (["x", "y"], ones_var),
            "var2": (["x", "y"], zeros_var),
        }
        dataset = xr.Dataset(variables)

        assert dataset["var1"].shape == (10, 20)
        assert dataset["var2"].shape == (10, 20)

        computed = dataset.compute()
        np.testing.assert_array_equal(computed["var1"].values, np.ones((10, 20)))
        np.testing.assert_array_equal(computed["var2"].values, np.zeros((10, 20)))

    def test_apply_ufunc(self):
        """``xr.apply_ufunc`` with ``dask="parallelized"`` keeps a da.Array."""
        chunked = da.ones((10, 20), chunks=(5, 10))
        wrapped = xr.DataArray(chunked, dims=["x", "y"])

        doubled = xr.apply_ufunc(
            lambda block: block * 2,
            wrapped,
            dask="parallelized",
            output_dtypes=[float],
        )

        # The result must still be lazy and backed by dask_array.Array.
        assert isinstance(doubled.data, da.Array)

        values = doubled.compute().values
        np.testing.assert_array_equal(values, np.full((10, 20), 2.0))
|
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: dask-array
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Parallel arrays with task graphs
|
|
5
|
+
Project-URL: Homepage, https://github.com/mrocklin/dask-array
|
|
6
|
+
Project-URL: Source, https://github.com/mrocklin/dask-array
|
|
7
|
+
Project-URL: Issues, https://github.com/mrocklin/dask-array/issues
|
|
8
|
+
License-Expression: BSD-3-Clause
|
|
9
|
+
License-File: LICENSE
|
|
10
|
+
Requires-Python: >=3.10
|
|
11
|
+
Requires-Dist: dask>=2024.1.0
|
|
12
|
+
Requires-Dist: jinja2>=3.1.6
|
|
13
|
+
Requires-Dist: numpy!=2.4.0,>=2.0.0
|
|
14
|
+
Requires-Dist: rich
|
|
15
|
+
Requires-Dist: toolz>=0.8.2
|
|
16
|
+
Provides-Extra: complete
|
|
17
|
+
Requires-Dist: h5py; extra == 'complete'
|
|
18
|
+
Requires-Dist: scipy; extra == 'complete'
|
|
19
|
+
Requires-Dist: tiledb; extra == 'complete'
|
|
20
|
+
Requires-Dist: zarr>=2.0; extra == 'complete'
|
|
21
|
+
Provides-Extra: sparse
|
|
22
|
+
Requires-Dist: numba>=0.60; (python_version < '3.13') and extra == 'sparse'
|
|
23
|
+
Requires-Dist: sparse; (python_version < '3.13') and extra == 'sparse'
|
|
24
|
+
Provides-Extra: test
|
|
25
|
+
Requires-Dist: pytest; extra == 'test'
|
|
26
|
+
Requires-Dist: pytest-cov; extra == 'test'
|
|
27
|
+
Requires-Dist: pytest-xdist; extra == 'test'
|
|
28
|
+
Requires-Dist: xarray; extra == 'test'
|
|
29
|
+
Description-Content-Type: text/markdown
|
|
30
|
+
|
|
31
|
+
# dask-array
|
|
32
|
+
|
|
33
|
+
Expression-based dask array implementation.
|
|
34
|
+
|
|
35
|
+
## Installation
|
|
36
|
+
|
|
37
|
+
```bash
|
|
38
|
+
pip install dask-array
|
|
39
|
+
```
|
|
40
|
+
|
|
41
|
+
## Usage
|
|
42
|
+
|
|
43
|
+
```python
|
|
44
|
+
import dask_array as da
|
|
45
|
+
|
|
46
|
+
x = da.ones((1000, 1000), chunks=(100, 100))
|
|
47
|
+
result = x.sum().compute()
|
|
48
|
+
```
|