dask-array 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (144) hide show
  1. dask_array/__init__.py +228 -0
  2. dask_array/_backends.py +76 -0
  3. dask_array/_backends_array.py +99 -0
  4. dask_array/_blockwise.py +1410 -0
  5. dask_array/_broadcast.py +272 -0
  6. dask_array/_chunk.py +445 -0
  7. dask_array/_chunk_types.py +54 -0
  8. dask_array/_collection.py +1644 -0
  9. dask_array/_concatenate.py +331 -0
  10. dask_array/_core_utils.py +1365 -0
  11. dask_array/_dispatch.py +141 -0
  12. dask_array/_einsum.py +277 -0
  13. dask_array/_expr.py +544 -0
  14. dask_array/_expr_flow.py +586 -0
  15. dask_array/_gufunc.py +805 -0
  16. dask_array/_histogram.py +617 -0
  17. dask_array/_map_blocks.py +652 -0
  18. dask_array/_new_collection.py +10 -0
  19. dask_array/_numpy_compat.py +135 -0
  20. dask_array/_overlap.py +1159 -0
  21. dask_array/_rechunk.py +1050 -0
  22. dask_array/_reshape.py +710 -0
  23. dask_array/_routines.py +102 -0
  24. dask_array/_shuffle.py +448 -0
  25. dask_array/_stack.py +264 -0
  26. dask_array/_svg.py +291 -0
  27. dask_array/_templates.py +29 -0
  28. dask_array/_test_utils.py +257 -0
  29. dask_array/_ufunc.py +385 -0
  30. dask_array/_utils.py +349 -0
  31. dask_array/_visualize.py +223 -0
  32. dask_array/_xarray.py +337 -0
  33. dask_array/core/__init__.py +34 -0
  34. dask_array/core/_blockwise_funcs.py +312 -0
  35. dask_array/core/_conversion.py +422 -0
  36. dask_array/core/_from_graph.py +97 -0
  37. dask_array/creation/__init__.py +71 -0
  38. dask_array/creation/_arange.py +121 -0
  39. dask_array/creation/_diag.py +116 -0
  40. dask_array/creation/_diagonal.py +241 -0
  41. dask_array/creation/_eye.py +103 -0
  42. dask_array/creation/_linspace.py +102 -0
  43. dask_array/creation/_mesh.py +134 -0
  44. dask_array/creation/_ones_zeros.py +454 -0
  45. dask_array/creation/_pad.py +270 -0
  46. dask_array/creation/_repeat.py +55 -0
  47. dask_array/creation/_tile.py +36 -0
  48. dask_array/creation/_tri.py +28 -0
  49. dask_array/creation/_utils.py +296 -0
  50. dask_array/fft.py +320 -0
  51. dask_array/io/__init__.py +39 -0
  52. dask_array/io/_base.py +10 -0
  53. dask_array/io/_from_array.py +257 -0
  54. dask_array/io/_from_delayed.py +95 -0
  55. dask_array/io/_from_graph.py +54 -0
  56. dask_array/io/_from_npy_stack.py +67 -0
  57. dask_array/io/_store.py +336 -0
  58. dask_array/io/_tiledb.py +159 -0
  59. dask_array/io/_to_npy_stack.py +65 -0
  60. dask_array/io/_zarr.py +449 -0
  61. dask_array/linalg/__init__.py +39 -0
  62. dask_array/linalg/_cholesky.py +234 -0
  63. dask_array/linalg/_lu.py +300 -0
  64. dask_array/linalg/_norm.py +94 -0
  65. dask_array/linalg/_qr.py +601 -0
  66. dask_array/linalg/_solve.py +349 -0
  67. dask_array/linalg/_svd.py +394 -0
  68. dask_array/linalg/_tensordot.py +334 -0
  69. dask_array/linalg/_utils.py +74 -0
  70. dask_array/manipulation/__init__.py +45 -0
  71. dask_array/manipulation/_expand.py +321 -0
  72. dask_array/manipulation/_flip.py +92 -0
  73. dask_array/manipulation/_roll.py +78 -0
  74. dask_array/manipulation/_transpose.py +309 -0
  75. dask_array/random/__init__.py +125 -0
  76. dask_array/random/_choice.py +181 -0
  77. dask_array/random/_expr.py +256 -0
  78. dask_array/random/_generator.py +441 -0
  79. dask_array/random/_random_state.py +259 -0
  80. dask_array/random/_utils.py +84 -0
  81. dask_array/reductions/__init__.py +84 -0
  82. dask_array/reductions/_arg_reduction.py +130 -0
  83. dask_array/reductions/_common.py +1082 -0
  84. dask_array/reductions/_cumulative.py +522 -0
  85. dask_array/reductions/_percentile.py +261 -0
  86. dask_array/reductions/_reduction.py +725 -0
  87. dask_array/reductions/_trace.py +56 -0
  88. dask_array/routines/__init__.py +133 -0
  89. dask_array/routines/_apply.py +84 -0
  90. dask_array/routines/_bincount.py +112 -0
  91. dask_array/routines/_broadcast.py +111 -0
  92. dask_array/routines/_coarsen.py +115 -0
  93. dask_array/routines/_diff.py +79 -0
  94. dask_array/routines/_gradient.py +158 -0
  95. dask_array/routines/_indexing.py +65 -0
  96. dask_array/routines/_insert_delete.py +132 -0
  97. dask_array/routines/_misc.py +122 -0
  98. dask_array/routines/_nonzero.py +72 -0
  99. dask_array/routines/_search.py +123 -0
  100. dask_array/routines/_select.py +113 -0
  101. dask_array/routines/_statistics.py +171 -0
  102. dask_array/routines/_topk.py +82 -0
  103. dask_array/routines/_triangular.py +74 -0
  104. dask_array/routines/_unique.py +232 -0
  105. dask_array/routines/_where.py +62 -0
  106. dask_array/slicing/__init__.py +67 -0
  107. dask_array/slicing/_basic.py +550 -0
  108. dask_array/slicing/_blocks.py +138 -0
  109. dask_array/slicing/_bool_index.py +145 -0
  110. dask_array/slicing/_setitem.py +329 -0
  111. dask_array/slicing/_squeeze.py +101 -0
  112. dask_array/slicing/_utils.py +1133 -0
  113. dask_array/slicing/_vindex.py +282 -0
  114. dask_array/stacking/__init__.py +15 -0
  115. dask_array/stacking/_block.py +83 -0
  116. dask_array/stacking/_simple.py +58 -0
  117. dask_array/templates/array.html.j2 +48 -0
  118. dask_array/tests/__init__.py +0 -0
  119. dask_array/tests/conftest.py +22 -0
  120. dask_array/tests/test_api.py +40 -0
  121. dask_array/tests/test_binary_op_chunks.py +107 -0
  122. dask_array/tests/test_coarse_slice_through_blockwise.py +362 -0
  123. dask_array/tests/test_collection.py +799 -0
  124. dask_array/tests/test_creation.py +1102 -0
  125. dask_array/tests/test_expr_flow.py +143 -0
  126. dask_array/tests/test_linalg.py +1130 -0
  127. dask_array/tests/test_map_blocks_multi_output.py +104 -0
  128. dask_array/tests/test_rechunk_pushdown.py +214 -0
  129. dask_array/tests/test_reductions.py +1091 -0
  130. dask_array/tests/test_routines.py +2853 -0
  131. dask_array/tests/test_shuffle_chunks.py +67 -0
  132. dask_array/tests/test_slice_pushdown.py +968 -0
  133. dask_array/tests/test_slice_through_blockwise.py +678 -0
  134. dask_array/tests/test_slice_through_overlap.py +366 -0
  135. dask_array/tests/test_slice_through_reshape.py +272 -0
  136. dask_array/tests/test_slicing.py +839 -0
  137. dask_array/tests/test_transpose_slice_pushdown.py +208 -0
  138. dask_array/tests/test_visualize.py +94 -0
  139. dask_array/tests/test_xarray.py +193 -0
  140. dask_array-0.1.0.dist-info/METADATA +48 -0
  141. dask_array-0.1.0.dist-info/RECORD +144 -0
  142. dask_array-0.1.0.dist-info/WHEEL +4 -0
  143. dask_array-0.1.0.dist-info/entry_points.txt +2 -0
  144. dask_array-0.1.0.dist-info/licenses/LICENSE +29 -0
@@ -0,0 +1,104 @@
1
+ import numpy as np
2
+
3
+ import dask
4
+ import dask_array as da
5
+ from dask_array._map_blocks import map_blocks_multi_output
6
+
7
+
8
def split_block(spec, block):
    """Derive two named outputs from a single block.

    Parameters
    ----------
    spec
        Per-block spec supplied by the caller; not used here.
    block
        Array-like supporting ``* 2`` and ``.sum(axis=1)``.

    Returns
    -------
    dict
        ``"double"`` maps to the block multiplied by two and ``"row_sum"``
        maps to the block summed along axis 1.
    """
    doubled = block * 2
    row_totals = block.sum(axis=1)
    return {"double": doubled, "row_sum": row_totals}
13
+
14
+
15
def test_map_blocks_multi_output_computes_projected_arrays():
    """Both projected outputs are ``da.Array`` objects with the expected values.

    The values are checked three ways: computed directly, after
    ``dask.optimize``, and after ``dask.persist`` on the single-threaded
    scheduler.
    """
    source = np.arange(8).reshape(4, 2)
    want_double = source * 2
    want_row_sum = source.sum(axis=1)

    x = da.from_array(source, chunks=(2, 2))
    # One spec entry (value ``None``) per block of the 2x1 block grid.
    block_specs = dict.fromkeys(((i, j) for i in range(2) for j in range(1)), None)

    double_spec = {
        "key": "double",
        "indices": ("x", "y"),
        "chunks": x.chunks,
        "dtype": x.dtype,
    }
    row_sum_spec = {
        "key": "row_sum",
        "indices": ("x",),
        "chunks": (x.chunks[0],),
        "dtype": x.dtype,
    }

    double, row_sum = map_blocks_multi_output(
        split_block,
        [x.expr],
        [("x", "y")],
        ("x", "y"),
        block_specs,
        [double_spec, row_sum_spec],
        token="split",
    )

    assert isinstance(double, da.Array)
    assert isinstance(row_sum, da.Array)

    def check_values(doubled, summed):
        # Compare one (double, row_sum) pair against the NumPy reference.
        np.testing.assert_array_equal(doubled.compute(), want_double)
        np.testing.assert_array_equal(summed.compute(), want_row_sum)

    check_values(double, row_sum)
    check_values(*dask.optimize(double, row_sum))
    check_values(*dask.persist(double, row_sum, scheduler="single-threaded"))
54
+
55
+
56
def test_map_blocks_multi_output_shares_block_calls():
    """Computing both outputs together records each block spec exactly once."""
    seen_specs = []

    def record_block(spec, block):
        # Log every invocation so the number of block calls can be checked.
        seen_specs.append(spec)
        return {"a": block + 1, "b": block + 2}

    x = da.from_array(np.arange(6), chunks=(3,))
    outputs = [
        {"key": key, "indices": ("x",), "chunks": x.chunks, "dtype": x.dtype}
        for key in ("a", "b")
    ]

    a, b = map_blocks_multi_output(
        record_block,
        [x.expr],
        [("x",)],
        ("x",),
        {(0,): 0, (1,): 1},
        outputs,
        token="record",
    )

    got_a, got_b = dask.compute(a, b, scheduler="single-threaded")

    base = np.arange(6)
    np.testing.assert_array_equal(got_a, base + 1)
    np.testing.assert_array_equal(got_b, base + 2)
    # Two blocks, two recorded specs: the function ran once per block.
    assert sorted(seen_specs) == [0, 1]
83
+
84
+
85
def test_map_blocks_multi_output_single_projection_omits_other_projection_keys():
    """The graph of output ``a`` has no tuple key named after output ``b``."""
    x = da.from_array(np.arange(6), chunks=(3,))
    outputs = [
        {"key": key, "indices": ("x",), "chunks": x.chunks, "dtype": x.dtype}
        for key in ("a", "b")
    ]

    a, b = map_blocks_multi_output(
        lambda spec, block: {"a": block + 1, "b": block + 2},
        [x.expr],
        [("x",)],
        ("x",),
        {(0,): None, (1,): None},
        outputs,
        token="cull",
    )

    # Only tuple keys are inspected; their first element is compared
    # against the name of the other projection.
    tuple_keys = [key for key in set(a.__dask_graph__()) if isinstance(key, tuple)]

    assert all(not key[0].startswith(b.name) for key in tuple_keys)
    np.testing.assert_array_equal(a.compute(), np.arange(6) + 1)
@@ -0,0 +1,214 @@
1
+ """Tests for rechunk pushdown optimizations."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import numpy as np
6
+ import pytest
7
+
8
+ import dask_array as da
9
+ from dask_array.io import FromArray
10
+ from dask_array._test_utils import assert_eq
11
+
12
+
13
def test_rechunk_dict_simplifies_through_from_array():
    """Rechunk(FromArray) simplifies to FromArray for tuple and dict chunks."""
    darr = da.from_array(np.ones((100, 50)), chunks=(25, 25))

    # The same target chunking, spelled as a tuple and as a dict.
    for target in ((50, 50), {0: 50, 1: 50}):
        simplified = darr.rechunk(target).expr.simplify()
        assert isinstance(simplified, FromArray)
23
+
24
+
25
def test_rechunk_dict_partial_dims():
    """A dict naming only axis 0 rechunks that axis and keeps axis 1 chunks."""
    darr = da.from_array(np.ones((100, 50)), chunks=(25, 25))

    simplified = darr.rechunk({0: 50}).expr.simplify()

    assert isinstance(simplified, FromArray)
    # Axis 0 becomes two chunks of 50; axis 1 keeps its original 25s.
    assert simplified.chunks == ((50, 50), (25, 25))
35
+
36
+
37
def test_rechunk_dict_correctness():
    """Tuple and dict chunk specs give equal results that match the source."""
    np_arr = np.arange(100).reshape(10, 10)
    darr = da.from_array(np_arr, chunks=(5, 5))

    via_tuple = darr.rechunk((2, 3))
    via_dict = darr.rechunk({0: 2, 1: 3})

    assert_eq(via_tuple, via_dict)
    assert_eq(via_tuple, np_arr)
47
+
48
+
49
def test_rechunk_dict_through_elemwise():
    """A dict rechunk of ``x + y`` simplifies to an Elemwise expression."""
    from dask_array._blockwise import Elemwise

    lhs = da.ones((10, 10), chunks=(5, 5))
    rhs = da.ones((10, 10), chunks=(5, 5))

    rechunked = (lhs + rhs).rechunk({0: 2, 1: 2})

    # After simplification the top-level expression is the Elemwise.
    assert isinstance(rechunked.expr.simplify(), Elemwise)
60
+
61
+
62
def test_rechunk_dict_through_elemwise_correctness():
    """Rechunking ``x + y`` with tuple or dict chunks matches the NumPy sum."""
    np_x = np.arange(100).reshape(10, 10)
    np_y = np.arange(100, 200).reshape(10, 10)
    expected = np_x + np_y

    x = da.from_array(np_x, chunks=(5, 5))
    y = da.from_array(np_y, chunks=(5, 5))

    # Build both variants first, then compare each against the reference.
    candidates = [(x + y).rechunk(target) for target in ((2, 3), {0: 2, 1: 3})]
    for candidate in candidates:
        assert_eq(candidate, expected)
75
+
76
+
77
def test_rechunk_pushdown_broadcast_elemwise():
    """Test rechunk optimization through elemwise with broadcast args.

    Covers a single rechunked result and a joint ``dask.compute`` of two
    rechunked reductions. Both cases check values as well as shapes, so a
    pushdown that produced correctly-shaped but wrong data would fail.
    """
    import dask

    # Array with broadcast dimension
    a = da.ones((10, 10), chunks=(5, 5))
    b = da.ones((10,), chunks=(5,))  # broadcasts along axis 0

    # Elemwise with broadcast, then rechunk
    c = (a + b).rechunk((2, 2))

    # Verify chunks and result
    assert c.chunks == ((2, 2, 2, 2, 2), (2, 2, 2, 2, 2))
    result = c.compute()
    assert result.shape == (10, 10)
    assert (result == 2).all()

    # Joint compute case
    d = (a * b).rechunk((2, 2)).sum(axis=0)
    e = (a - b).rechunk((2, 2)).mean(axis=0)
    r1, r2 = dask.compute(d, e)
    assert r1.shape == (10,)
    assert r2.shape == (10,)
    # ones * ones summed over the 10 rows is exactly 10 per column;
    # ones - ones is all zeros, so every column mean is exactly 0.
    assert (r1 == 10).all()
    assert (r2 == 0).all()
100
+
101
+
102
def test_rechunk_pushdown_through_transpose():
    """Rechunk after a 3-D transpose simplifies like transpose-after-rechunk."""
    # Axes (2, 0, 1) turn shape (2, 3, 4) into (4, 2, 3):
    #   output axis 0 <- input axis 2
    #   output axis 1 <- input axis 0
    #   output axis 2 <- input axis 1
    axes = (2, 0, 1)
    x = da.ones((2, 3, 4), chunks=(1, 1, 2))

    # Rechunk the transposed output to (2, 1, 3).
    result = x.transpose(axes).rechunk((2, 1, 3))

    # Mapping the output chunks back through the permutation:
    #   input axis 0 takes output axis 1's chunk size (1)
    #   input axis 1 takes output axis 2's chunk size (3)
    #   input axis 2 takes output axis 0's chunk size (2)
    # so the equivalent input rechunk is (1, 3, 2).
    expected = x.rechunk((1, 3, 2)).transpose(axes)

    result_name = result.expr.simplify()._name
    expected_name = expected.expr.simplify()._name
    assert result_name == expected_name
124
+
125
+
126
def test_rechunk_pushdown_through_transpose_simple():
    """Rechunk after a 2-D ``.T`` simplifies like ``.T`` after rechunk."""
    x = da.arange(12, chunks=4).reshape(3, 4).rechunk((1, 2))

    # ``.T`` swaps the two axes; rechunk the transposed array to (2, 3).
    result = x.T.rechunk((2, 3))

    # Pushed below the transpose, the same request becomes (3, 2):
    #   input axis 0 -> output axis 1 (chunk size 3)
    #   input axis 1 -> output axis 0 (chunk size 2)
    expected = x.rechunk((3, 2)).T

    result_name = result.expr.simplify()._name
    expected_name = expected.expr.simplify()._name
    assert result_name == expected_name
141
+
142
+
143
def test_rechunk_pushdown_through_transpose_dict():
    """A dict rechunk on a transposed axis maps back to the input axis."""
    axes = (2, 0, 1)  # output axis 0 <- input axis 2
    x = da.ones((2, 3, 4), chunks=(1, 1, 2))

    # Rechunk only output axis 0, to chunk size 2.
    result = x.transpose(axes).rechunk({0: 2})

    # Output axis 0 is input axis 2, so the equivalent expression
    # rechunks input axis 2 before transposing.
    expected = x.rechunk({2: 2}).transpose(axes)

    result_name = result.expr.simplify()._name
    expected_name = expected.expr.simplify()._name
    assert result_name == expected_name
157
+
158
+
159
+ # =============================================================================
160
+ # Regression tests
161
+ # =============================================================================
162
+
163
+
164
def test_rechunk_noop_preserves_identity():
    """Rechunking to the chunks an array already has returns the same expression.

    Regression test: rechunk used to build a new expression even when the
    target chunks matched the input chunks exactly.
    """
    x = da.ones((10, 10), chunks=(5, 5))

    # Three spellings of "keep the current chunks".
    unchanged = [
        x.rechunk((5, 5)),
        x.rechunk({0: 5, 1: 5}),
        x.rechunk((None, None)),
    ]

    for y in unchanged:
        assert x.expr is y.expr
    assert all(y.name == x.name for y in unchanged)
181
+
182
+
183
def test_rechunk_noop_negative_index():
    """A no-op rechunk keyed by negative axes returns the same expression."""
    x = da.ones((10, 10), chunks=(5, 5))

    target = {-1: 5, -2: 5}
    y = x.rechunk(target)

    assert y.expr is x.expr
190
+
191
+
192
def test_rechunk_multistep_no_cycle():
    """A rechunk that splits one axis and merges the other still computes.

    Regression test: when a rechunk needed several steps, incorrect task
    name propagation between the steps produced a cyclic task graph.
    """
    y = da.ones((16, 50), chunks=(16, 1)).rechunk((3, 10))

    # This used to raise "RuntimeError: Cycle detected".
    result = y.compute()

    assert result.shape == (16, 50)
    assert (result == 1).all()
206
+
207
+
208
def test_rechunk_split_and_merge_correctness():
    """A split-and-merge rechunk reproduces the source array's values."""
    rows, cols = 16, 50
    np_arr = np.arange(rows * cols).reshape(rows, cols)

    # Start with one column per chunk, then split rows and merge columns.
    rechunked = da.from_array(np_arr, chunks=(rows, 1)).rechunk((3, 10))

    assert_eq(rechunked, np_arr)