dask-array 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (144) hide show
  1. dask_array/__init__.py +228 -0
  2. dask_array/_backends.py +76 -0
  3. dask_array/_backends_array.py +99 -0
  4. dask_array/_blockwise.py +1410 -0
  5. dask_array/_broadcast.py +272 -0
  6. dask_array/_chunk.py +445 -0
  7. dask_array/_chunk_types.py +54 -0
  8. dask_array/_collection.py +1644 -0
  9. dask_array/_concatenate.py +331 -0
  10. dask_array/_core_utils.py +1365 -0
  11. dask_array/_dispatch.py +141 -0
  12. dask_array/_einsum.py +277 -0
  13. dask_array/_expr.py +544 -0
  14. dask_array/_expr_flow.py +586 -0
  15. dask_array/_gufunc.py +805 -0
  16. dask_array/_histogram.py +617 -0
  17. dask_array/_map_blocks.py +652 -0
  18. dask_array/_new_collection.py +10 -0
  19. dask_array/_numpy_compat.py +135 -0
  20. dask_array/_overlap.py +1159 -0
  21. dask_array/_rechunk.py +1050 -0
  22. dask_array/_reshape.py +710 -0
  23. dask_array/_routines.py +102 -0
  24. dask_array/_shuffle.py +448 -0
  25. dask_array/_stack.py +264 -0
  26. dask_array/_svg.py +291 -0
  27. dask_array/_templates.py +29 -0
  28. dask_array/_test_utils.py +257 -0
  29. dask_array/_ufunc.py +385 -0
  30. dask_array/_utils.py +349 -0
  31. dask_array/_visualize.py +223 -0
  32. dask_array/_xarray.py +337 -0
  33. dask_array/core/__init__.py +34 -0
  34. dask_array/core/_blockwise_funcs.py +312 -0
  35. dask_array/core/_conversion.py +422 -0
  36. dask_array/core/_from_graph.py +97 -0
  37. dask_array/creation/__init__.py +71 -0
  38. dask_array/creation/_arange.py +121 -0
  39. dask_array/creation/_diag.py +116 -0
  40. dask_array/creation/_diagonal.py +241 -0
  41. dask_array/creation/_eye.py +103 -0
  42. dask_array/creation/_linspace.py +102 -0
  43. dask_array/creation/_mesh.py +134 -0
  44. dask_array/creation/_ones_zeros.py +454 -0
  45. dask_array/creation/_pad.py +270 -0
  46. dask_array/creation/_repeat.py +55 -0
  47. dask_array/creation/_tile.py +36 -0
  48. dask_array/creation/_tri.py +28 -0
  49. dask_array/creation/_utils.py +296 -0
  50. dask_array/fft.py +320 -0
  51. dask_array/io/__init__.py +39 -0
  52. dask_array/io/_base.py +10 -0
  53. dask_array/io/_from_array.py +257 -0
  54. dask_array/io/_from_delayed.py +95 -0
  55. dask_array/io/_from_graph.py +54 -0
  56. dask_array/io/_from_npy_stack.py +67 -0
  57. dask_array/io/_store.py +336 -0
  58. dask_array/io/_tiledb.py +159 -0
  59. dask_array/io/_to_npy_stack.py +65 -0
  60. dask_array/io/_zarr.py +449 -0
  61. dask_array/linalg/__init__.py +39 -0
  62. dask_array/linalg/_cholesky.py +234 -0
  63. dask_array/linalg/_lu.py +300 -0
  64. dask_array/linalg/_norm.py +94 -0
  65. dask_array/linalg/_qr.py +601 -0
  66. dask_array/linalg/_solve.py +349 -0
  67. dask_array/linalg/_svd.py +394 -0
  68. dask_array/linalg/_tensordot.py +334 -0
  69. dask_array/linalg/_utils.py +74 -0
  70. dask_array/manipulation/__init__.py +45 -0
  71. dask_array/manipulation/_expand.py +321 -0
  72. dask_array/manipulation/_flip.py +92 -0
  73. dask_array/manipulation/_roll.py +78 -0
  74. dask_array/manipulation/_transpose.py +309 -0
  75. dask_array/random/__init__.py +125 -0
  76. dask_array/random/_choice.py +181 -0
  77. dask_array/random/_expr.py +256 -0
  78. dask_array/random/_generator.py +441 -0
  79. dask_array/random/_random_state.py +259 -0
  80. dask_array/random/_utils.py +84 -0
  81. dask_array/reductions/__init__.py +84 -0
  82. dask_array/reductions/_arg_reduction.py +130 -0
  83. dask_array/reductions/_common.py +1082 -0
  84. dask_array/reductions/_cumulative.py +522 -0
  85. dask_array/reductions/_percentile.py +261 -0
  86. dask_array/reductions/_reduction.py +725 -0
  87. dask_array/reductions/_trace.py +56 -0
  88. dask_array/routines/__init__.py +133 -0
  89. dask_array/routines/_apply.py +84 -0
  90. dask_array/routines/_bincount.py +112 -0
  91. dask_array/routines/_broadcast.py +111 -0
  92. dask_array/routines/_coarsen.py +115 -0
  93. dask_array/routines/_diff.py +79 -0
  94. dask_array/routines/_gradient.py +158 -0
  95. dask_array/routines/_indexing.py +65 -0
  96. dask_array/routines/_insert_delete.py +132 -0
  97. dask_array/routines/_misc.py +122 -0
  98. dask_array/routines/_nonzero.py +72 -0
  99. dask_array/routines/_search.py +123 -0
  100. dask_array/routines/_select.py +113 -0
  101. dask_array/routines/_statistics.py +171 -0
  102. dask_array/routines/_topk.py +82 -0
  103. dask_array/routines/_triangular.py +74 -0
  104. dask_array/routines/_unique.py +232 -0
  105. dask_array/routines/_where.py +62 -0
  106. dask_array/slicing/__init__.py +67 -0
  107. dask_array/slicing/_basic.py +550 -0
  108. dask_array/slicing/_blocks.py +138 -0
  109. dask_array/slicing/_bool_index.py +145 -0
  110. dask_array/slicing/_setitem.py +329 -0
  111. dask_array/slicing/_squeeze.py +101 -0
  112. dask_array/slicing/_utils.py +1133 -0
  113. dask_array/slicing/_vindex.py +282 -0
  114. dask_array/stacking/__init__.py +15 -0
  115. dask_array/stacking/_block.py +83 -0
  116. dask_array/stacking/_simple.py +58 -0
  117. dask_array/templates/array.html.j2 +48 -0
  118. dask_array/tests/__init__.py +0 -0
  119. dask_array/tests/conftest.py +22 -0
  120. dask_array/tests/test_api.py +40 -0
  121. dask_array/tests/test_binary_op_chunks.py +107 -0
  122. dask_array/tests/test_coarse_slice_through_blockwise.py +362 -0
  123. dask_array/tests/test_collection.py +799 -0
  124. dask_array/tests/test_creation.py +1102 -0
  125. dask_array/tests/test_expr_flow.py +143 -0
  126. dask_array/tests/test_linalg.py +1130 -0
  127. dask_array/tests/test_map_blocks_multi_output.py +104 -0
  128. dask_array/tests/test_rechunk_pushdown.py +214 -0
  129. dask_array/tests/test_reductions.py +1091 -0
  130. dask_array/tests/test_routines.py +2853 -0
  131. dask_array/tests/test_shuffle_chunks.py +67 -0
  132. dask_array/tests/test_slice_pushdown.py +968 -0
  133. dask_array/tests/test_slice_through_blockwise.py +678 -0
  134. dask_array/tests/test_slice_through_overlap.py +366 -0
  135. dask_array/tests/test_slice_through_reshape.py +272 -0
  136. dask_array/tests/test_slicing.py +839 -0
  137. dask_array/tests/test_transpose_slice_pushdown.py +208 -0
  138. dask_array/tests/test_visualize.py +94 -0
  139. dask_array/tests/test_xarray.py +193 -0
  140. dask_array-0.1.0.dist-info/METADATA +48 -0
  141. dask_array-0.1.0.dist-info/RECORD +144 -0
  142. dask_array-0.1.0.dist-info/WHEEL +4 -0
  143. dask_array-0.1.0.dist-info/entry_points.txt +2 -0
  144. dask_array-0.1.0.dist-info/licenses/LICENSE +29 -0
@@ -0,0 +1,121 @@
1
+ from __future__ import annotations
2
+
3
+ import functools
4
+ from functools import partial
5
+
6
+ import numpy as np
7
+
8
+ from dask_array._new_collection import new_collection
9
+ from dask._task_spec import Task
10
+ from dask_array._expr import ArrayExpr
11
+ from dask_array._chunk import arange as _arange
12
+ from dask_array._core_utils import normalize_chunks
13
+ from dask_array._utils import meta_from_array
14
+
15
+
16
class Arange(ArrayExpr):
    """Array expression for a chunked 1-D range of evenly spaced values.

    Operands are ``start``, ``stop`` and ``step`` followed by the optional
    ``chunks``, ``like``, ``dtype`` and ``kwargs``. Each output block is
    produced by the chunk-level ``_arange`` helper.
    """

    _parameters = ["start", "stop", "step", "chunks", "like", "dtype", "kwargs"]
    _defaults = {"chunks": "auto", "like": None, "dtype": None, "kwargs": None}

    @functools.cached_property
    def num_rows(self):
        """Number of elements in the range; clamped so it is never negative."""
        span = self.stop - self.start
        return int(max(np.ceil(span / self.step), 0))

    @functools.cached_property
    def dtype(self):
        """Explicit ``dtype`` operand if given, otherwise the dtype NumPy infers."""
        requested = self.operand("dtype")
        if requested is None:
            # Probe NumPy with zeros of the same Python types as start/stop:
            # this yields the inferred dtype without overflowing when the
            # real start/stop are very large integers.
            zero_start = type(self.start)(0)
            zero_stop = type(self.stop)(0)
            return np.arange(zero_start, zero_stop, self.step).dtype
        return np.dtype(requested)

    @functools.cached_property
    def _meta(self):
        """1-D meta array derived from ``like`` with the resolved dtype."""
        return meta_from_array(self.like, ndim=1, dtype=self.dtype)

    @functools.cached_property
    def chunks(self):
        """Normalized chunking for a 1-D array of ``num_rows`` elements."""
        return normalize_chunks(self.operand("chunks"), (self.num_rows,), dtype=self.dtype)

    def _layer(self) -> dict:
        """Build one task per output block.

        Block bounds are computed from the running element count
        (``start + count * step``) rather than by repeatedly adding to the
        previous bound.
        """
        first, stride = self.start, self.step
        block_func = partial(_arange, like=self.like)

        graph = {}
        seen = 0  # number of elements emitted by the preceding blocks
        for index, size in enumerate(self.chunks[0]):
            key = (self._name, index)
            lower = first + (seen * stride)
            upper = first + ((seen + size) * stride)
            graph[key] = Task(key, block_func, lower, upper, stride, size, self.dtype)
            seen += size
        return graph
63
+
64
+
65
# Marks "start was not passed", so that ``arange(stop=5)`` can be told apart
# from ``arange(5)``.
_arange_sentinel = object()


def arange(start=_arange_sentinel, stop=None, step=1, *, chunks="auto", like=None, dtype=None):
    """
    Return evenly spaced values in the half-open interval ``[start, stop)``.

    ``start`` is included and ``stop`` is excluded, mirroring Python's
    ``range`` but producing a dask array.

    With a non-integer ``step`` (for example 0.1) the results will often not
    be consistent; prefer ``linspace`` in that situation.

    Parameters
    ----------
    start : int, optional
        First value of the sequence. Defaults to 0. When it is the only
        positional argument it is interpreted as ``stop``.
    stop : int
        End of the interval; this value itself is excluded.
    step : int, optional
        Spacing between consecutive values. Defaults to 1.
    chunks : int
        Number of samples per block. The last block is shorter when
        ``len(array) % chunks != 0``. Defaults to "auto", which picks chunk
        sizes automatically.
    dtype : numpy.dtype
        Output dtype. Defaults to ``None``, meaning it is inferred from
        ``start``, ``stop`` and ``step``.
    like : array type or ``None``
        Array to extract meta from. Defaults to ``None``.

    Returns
    -------
    samples : dask array

    Raises
    ------
    TypeError
        If neither ``start`` nor ``stop`` is supplied.

    See Also
    --------
    dask.array.linspace
    """
    start_missing = start is _arange_sentinel
    if start_missing and stop is None:
        raise TypeError("arange() requires stop to be specified.")

    if start_missing:
        # Only ``stop`` was given (as a keyword argument).
        start = 0
    elif stop is None:
        # A single positional argument is the stop value.
        start, stop = 0, start

    # When start is a very large int (~2**63) and step is a small float,
    # ``start + n * step`` loses precision. Detect that case, build the range
    # from zero instead, and shift the result afterwards.
    loses_precision = start != 0 and not np.isclose(start + step - start, step, atol=0)
    if loses_precision:
        shifted = arange(0, stop - start, step, chunks=chunks, dtype=dtype, like=like)
        return shifted + start

    return new_collection(Arange(start, stop, step, chunks, like, dtype))
@@ -0,0 +1,116 @@
1
+ from __future__ import annotations
2
+
3
+ import functools
4
+
5
+ import numpy as np
6
+
7
+ from dask_array._new_collection import new_collection
8
+ from dask._task_spec import Task, TaskRef
9
+ from dask_array._collection import asarray
10
+ from dask_array._expr import ArrayExpr
11
+ from dask_array._utils import meta_from_array
12
+ from dask.utils import derived_from
13
+
14
+
15
class Diag1D(ArrayExpr):
    """Create a diagonal matrix from a 1D array (k=0 case only).

    The output is chunked with the 1-D input's chunks along both axes, so the
    blocks on the block-diagonal hold the data and all others are zero.
    """

    _parameters = ["x"]

    @functools.cached_property
    def _meta(self):
        """2-D meta array derived from the input."""
        return meta_from_array(self.x, ndim=2)

    @functools.cached_property
    def dtype(self):
        """Same dtype as the input."""
        return self.x.dtype

    @functools.cached_property
    def chunks(self):
        """The input's 1-D chunks, used for both rows and columns."""
        sizes = self.x.chunks[0]
        return (sizes, sizes)

    def _layer(self) -> dict:
        """One task per output block: ``np.diag`` on the diagonal, zeros elsewhere."""
        source = self.x
        sizes = source.chunks[0]
        template = self._meta

        graph = {}
        for row, height in enumerate(sizes):
            for col, width in enumerate(sizes):
                key = (self._name, row, col)
                if row != col:
                    # Off-diagonal block: zeros shaped like this block, built
                    # from the meta array.
                    graph[key] = Task(key, np.zeros_like, template, shape=(height, width))
                else:
                    graph[key] = Task(key, np.diag, TaskRef((source._name, row)))
        return graph
47
+
48
+
49
class Diag2DSimple(ArrayExpr):
    """Extract diagonal from a 2D array with square chunks (k=0 case only).

    Only the blocks on the block-diagonal of the input are read; each one
    contributes one output block.
    """

    _parameters = ["x"]

    @functools.cached_property
    def _meta(self):
        """1-D meta array derived from the input."""
        return meta_from_array(self.x, ndim=1)

    @functools.cached_property
    def dtype(self):
        """Same dtype as the input."""
        return self.x.dtype

    @functools.cached_property
    def chunks(self):
        """Row chunks of the input, as the single output axis."""
        row_chunks = self.x.chunks[0]
        return (row_chunks,)

    def _layer(self) -> dict:
        """Apply ``np.diag`` to the i-th block of the i-th block row."""
        block_rows = self.x.__dask_keys__()
        return {
            (self._name, i): Task((self._name, i), np.diag, TaskRef(block_row[i]))
            for i, block_row in enumerate(block_rows)
        }
75
+
76
+
77
# NOTE: documented with comments rather than a docstring because the function
# is decorated with ``derived_from(np)``.
#
# v : 1-D or 2-D NumPy array or dask array.
# k : diagonal offset (0 = main diagonal, >0 above, <0 below).
#
# Returns a dask array: a square matrix with ``v`` on the k-th diagonal when
# ``v`` is 1-D, or the k-th diagonal of ``v`` when it is 2-D.
#
# Raises TypeError for any other input type and ValueError when ``v`` is
# neither 1-D nor 2-D.
@derived_from(np)
def diag(v, k=0):
    from dask_array._collection import Array

    from ._diagonal import diagonal
    from ._pad import pad

    if not isinstance(v, (np.ndarray, Array)):
        raise TypeError(f"v must be a dask array or numpy array, got {type(v)}")

    # NumPy input: compute eagerly with NumPy, then wrap the result.
    # After the type guard above, anything that is not a dask Array is an
    # ndarray, so no further duck-type check is needed here.
    if isinstance(v, np.ndarray):
        if v.ndim not in (1, 2):
            raise ValueError("Array must be 1d or 2d only")
        return asarray(np.diag(v, k))

    v = asarray(v)

    if v.ndim == 2:
        # 2D case: extract a diagonal. Square chunking on the main diagonal
        # has a cheaper dedicated expression.
        if k == 0 and v.chunks[0] == v.chunks[1]:
            return new_collection(Diag2DSimple(v.expr))
        return diagonal(v, k)

    if v.ndim != 1:
        raise ValueError("Array must be 1d or 2d only")

    # 1D case: build the main-diagonal matrix, then pad it to shift the
    # diagonal up/right (k > 0) or down/left (k < 0).
    if k == 0:
        return new_collection(Diag1D(v.expr))
    if k > 0:
        return pad(diag(v), [[0, k], [k, 0]], mode="constant")
    return pad(diag(v), [[-k, 0], [0, -k]], mode="constant")
@@ -0,0 +1,241 @@
1
+ from __future__ import annotations
2
+
3
+ import functools
4
+ from functools import partial
5
+
6
+ import numpy as np
7
+
8
+ from dask_array._new_collection import new_collection
9
+ from dask._task_spec import Task, TaskRef
10
+ from dask_array._collection import asarray
11
+ from dask_array._expr import ArrayExpr
12
+ from dask_array._utils import meta_from_array
13
+ from dask.utils import derived_from
14
+
15
+
16
class Diagonal(ArrayExpr):
    """Extract a diagonal from a multi-dimensional array.

    Operands are the input ``x``, the diagonal ``offset`` and the two axes
    (``axis1``, ``axis2``) spanning the 2-D planes the diagonal is taken from.
    The output keeps the remaining axes of ``x`` in order and appends the
    diagonal as the last axis.
    """

    _parameters = ["x", "offset", "axis1", "axis2"]
    _defaults = {"offset": 0, "axis1": 0, "axis2": 1}

    @functools.cached_property
    def _axis1_normalized(self):
        """``axis1`` with a negative value wrapped to a non-negative index."""
        axis = self.axis1
        if axis < 0:
            axis = self.x.ndim + axis
        return axis

    @functools.cached_property
    def _axis2_normalized(self):
        """``axis2`` with a negative value wrapped to a non-negative index."""
        axis = self.axis2
        if axis < 0:
            axis = self.x.ndim + axis
        return axis

    @functools.cached_property
    def _effective_axes(self):
        """Return (axis1, axis2, k) with axis1 < axis2.

        Swapping the two axes is compensated for by negating the offset.
        """
        axis1, axis2 = self._axis1_normalized, self._axis2_normalized
        k = self.offset
        if axis1 > axis2:
            axis1, axis2 = axis2, axis1
            k = -self.offset
        return axis1, axis2, k

    @functools.cached_property
    def _diag_info(self):
        """Compute diagonal metadata.

        Returns a dict with the effective axes and offset, the global
        row/column where the diagonal starts, the row where it stops, its
        length (``<= 0`` when the diagonal is empty), the free (non-diagonal)
        axes and every block index over those free axes.
        """
        from itertools import product

        x = self.x
        axis1, axis2, k = self._effective_axes

        kdiag_row_start = max(0, -k)
        kdiag_col_start = max(0, k)
        kdiag_row_stop = min(x.shape[axis1], x.shape[axis2] - k)
        len_kdiag = kdiag_row_stop - kdiag_row_start

        free_axes = set(range(x.ndim)) - {axis1, axis2}
        # Iterate the free axes in ascending order explicitly: positions in
        # each free index are later spliced back into input block keys by
        # axis position, so the order must not depend on set iteration.
        free_indices = list(product(*(range(x.numblocks[i]) for i in sorted(free_axes))))
        ndims_free = len(free_axes)

        return {
            "axis1": axis1,
            "axis2": axis2,
            "k": k,
            "len_kdiag": len_kdiag,
            "kdiag_row_start": kdiag_row_start,
            "kdiag_col_start": kdiag_col_start,
            "kdiag_row_stop": kdiag_row_stop,
            "free_axes": free_axes,
            "free_indices": free_indices,
            "ndims_free": ndims_free,
        }

    @functools.cached_property
    def _kdiag_blocks(self):
        """Follow the k-diagonal through the block grid of (axis1, axis2).

        Returns a tuple of ``(row_block, col_block, local_k, length)`` entries
        in diagonal order, one per input block the diagonal crosses:
        ``local_k`` is the offset to pass to ``np.diagonal`` for that block
        and ``length`` is the number of diagonal elements it contributes.
        All values are plain Python ints. Empty when the diagonal is empty.
        """
        info = self._diag_info
        if info["len_kdiag"] <= 0:
            return ()

        x = self.x
        axis1, axis2 = info["axis1"], info["axis2"]
        row_chunks, col_chunks = x.chunks[axis1], x.chunks[axis2]

        # Global start/stop position of every block along each diagonal axis.
        row_stops = np.cumsum(row_chunks)
        row_starts = np.roll(row_stops, 1)
        row_starts[0] = 0

        col_stops = np.cumsum(col_chunks)
        col_starts = np.roll(col_stops, 1)
        col_starts[0] = 0

        r_start = info["kdiag_row_start"]
        c_start = info["kdiag_col_start"]

        # Locate the single block containing the first diagonal element; the
        # one-element unpacking fails loudly if that is not unique.
        row_mask = (row_starts <= r_start) & (r_start < row_stops)
        col_mask = (col_starts <= c_start) & (c_start < col_stops)
        (row_block,) = np.arange(x.numblocks[axis1])[row_mask]
        (col_block,) = np.arange(x.numblocks[axis2])[col_mask]
        row_block, col_block = int(row_block), int(col_block)

        blocks = []
        while r_start < x.shape[axis1] and c_start < x.shape[axis2]:
            nrows, ncols = row_chunks[row_block], col_chunks[col_block]
            local_r_start = r_start - row_starts[row_block]
            local_c_start = c_start - col_starts[col_block]
            # Offset of the diagonal inside this block: negative when it
            # enters through the left edge, otherwise through the top edge.
            local_k = -local_r_start if local_r_start > 0 else local_c_start
            local_row_end = min(nrows, ncols - local_k)

            blocks.append(
                (row_block, col_block, int(local_k), int(local_row_end - local_r_start))
            )

            # Global position where the diagonal leaves this block.
            r_start = local_row_end + row_starts[row_block]
            c_start = min(ncols, nrows + local_k) + col_starts[col_block]
            # Step to the next block along whichever axis was exhausted
            # (both, when the diagonal exits through a corner).
            advance_row = r_start == row_stops[row_block]
            advance_col = c_start == col_stops[col_block]
            if advance_row:
                row_block += 1
            if advance_col:
                col_block += 1

        return tuple(blocks)

    @functools.cached_property
    def _meta(self):
        """Meta array with one dimension per free axis plus the diagonal axis."""
        return meta_from_array(self.x, ndim=self._diag_info["ndims_free"] + 1)

    @functools.cached_property
    def dtype(self):
        """Same dtype as the input."""
        return self.x.dtype

    @functools.cached_property
    def chunks(self):
        """Chunks of the free axes followed by the chunks of the diagonal axis."""
        info = self._diag_info
        diag_axes = (info["axis1"], info["axis2"])
        free_chunks = tuple(
            axis_chunks
            for axis, axis_chunks in enumerate(self.x.chunks)
            if axis not in diag_axes
        )

        if info["len_kdiag"] <= 0:
            # Empty diagonal: a single zero-length chunk on the last axis.
            return free_chunks + ((0,),)

        kdiag_chunks = tuple(length for _, _, _, length in self._kdiag_blocks)
        return free_chunks + (kdiag_chunks,)

    def _layer(self) -> dict:
        """Build the task graph.

        For an empty diagonal every output block is an empty array of the
        right shape; otherwise each output block is ``np.diagonal`` applied to
        the input block the diagonal passes through.
        """
        from dask_array._utils import is_cupy_type

        dsk = {}
        info = self._diag_info
        x = self.x
        axis1, axis2 = info["axis1"], info["axis2"]
        free_indices = info["free_indices"]
        ndims_free = info["ndims_free"]

        if info["len_kdiag"] <= 0:
            xp = np
            if is_cupy_type(x._meta):
                import cupy

                xp = cupy

            out_chunks = self.chunks
            for free_idx in free_indices:
                shape = tuple(out_chunks[axis][free_idx[axis]] for axis in range(ndims_free))
                key = (self._name,) + free_idx + (0,)
                dsk[key] = Task(key, partial(xp.empty, dtype=x.dtype), shape + (0,))  # type: ignore[misc]
            return dsk

        for i, (row_block, col_block, local_k, _length) in enumerate(self._kdiag_blocks):
            for free_idx in free_indices:
                # Re-insert the two diagonal block coordinates at their axis
                # positions among the free-axis block coordinates.
                input_idx = (
                    free_idx[:axis1]
                    + (row_block,)
                    + free_idx[axis1 : axis2 - 1]
                    + (col_block,)
                    + free_idx[axis2 - 1 :]
                )
                key = (self._name,) + free_idx + (i,)
                dsk[key] = Task(
                    key,
                    np.diagonal,
                    TaskRef((x._name,) + input_idx),
                    local_k,
                    axis1,
                    axis2,
                )

        return dsk
216
+
217
+
218
# NOTE: documented with comments rather than a docstring because the function
# is decorated with ``derived_from(np)``.
#
# a      : array with at least two dimensions (must expose ``ndim``).
# offset : diagonal offset.
# axis1, axis2 : the two distinct axes spanning the planes the diagonal is
#                taken from; negative values count from the last axis.
#
# Raises ValueError when ``a`` has fewer than two dimensions or both axes
# refer to the same dimension, and AxisError when an axis is out of bounds in
# either direction.
@derived_from(np)
def diagonal(a, offset=0, axis1=0, axis2=1):
    from dask_array._numpy_compat import AxisError

    if a.ndim < 2:
        raise ValueError("diag requires an array of at least two dimensions")

    def _axis_fmt(axis, name, ndim):
        # Wrap a negative axis, then reject anything outside [0, ndim).
        # The upper bound matters too: otherwise an axis >= ndim would only
        # fail later, with an unrelated error, when the expression indexes
        # the input's shape.
        normalized = ndim + axis if axis < 0 else axis
        if not 0 <= normalized < ndim:
            msg = "{}: axis {} is out of bounds for array of dimension {}"
            raise AxisError(msg.format(name, axis, ndim))
        return normalized

    axis1_norm = _axis_fmt(axis1, "axis1", a.ndim)
    axis2_norm = _axis_fmt(axis2, "axis2", a.ndim)

    if axis1_norm == axis2_norm:
        raise ValueError("axis1 and axis2 cannot be the same")

    a = asarray(a)
    # The expression receives the axes as given and normalizes them itself.
    return new_collection(Diagonal(a.expr, offset, axis1, axis2))
@@ -0,0 +1,103 @@
1
+ from __future__ import annotations
2
+
3
+ import functools
4
+
5
+ import numpy as np
6
+
7
+ from dask_array._new_collection import new_collection
8
+ from dask._task_spec import Task
9
+ from dask_array._expr import ArrayExpr
10
+ from dask_array._core_utils import normalize_chunks
11
+
12
+
13
class Eye(ArrayExpr):
    """Array expression for a chunked 2-D array with ones on the k-th diagonal.

    Operands are the number of rows ``N``, the number of columns ``M``
    (``None`` means ``N``), the diagonal index ``k``, the ``dtype`` and the
    ``chunks`` specification.
    """

    _parameters = ["N", "M", "k", "dtype", "chunks"]
    _defaults = {"M": None, "k": 0, "dtype": float, "chunks": "auto"}

    @functools.cached_property
    def _M(self):
        """Number of columns, falling back to ``N`` when ``M`` is ``None``."""
        return self.M if self.M is not None else self.N

    @functools.cached_property
    def dtype(self):
        """The ``dtype`` operand as a NumPy dtype (``float`` when falsy)."""
        return np.dtype(self.operand("dtype") or float)

    @functools.cached_property
    def _meta(self):
        """Empty 2-D NumPy array carrying the output dtype."""
        return np.empty((0, 0), dtype=self.dtype)

    @functools.cached_property
    def chunks(self):
        """Normalized (row chunks, column chunks) for the ``(N, M)`` shape."""
        vchunks, hchunks = normalize_chunks(self.operand("chunks"), shape=(self.N, self._M), dtype=self.dtype)
        return (vchunks, hchunks)

    @functools.cached_property
    def _chunk_size(self):
        """Size of the first vertical chunk.

        Not used by ``_layer`` (which works from per-block offsets); kept so
        that existing attribute access keeps working.
        """
        return self.chunks[0][0]

    def _layer(self) -> dict:
        """Build one task per block: ``np.eye`` where the diagonal passes, zeros elsewhere.

        The diagonal offset inside a block is derived from the block's actual
        row and column offsets. Offsets are accumulated from the chunk sizes
        rather than computed from a single chunk size, so placement stays
        correct when row and column chunk sizes differ.
        """
        dsk = {}
        vchunks, hchunks = self.chunks
        k = self.k
        dtype = self.dtype

        row_offset = 0  # global index of the first row of the current block
        for i, vchunk in enumerate(vchunks):
            col_offset = 0  # global index of the first column of the block
            for j, hchunk in enumerate(hchunks):
                key = (self._name, i, j)
                # Global condition ``col - row == k`` rewritten in block-local
                # coordinates.
                local_k = k + row_offset - col_offset
                # The diagonal touches the block iff some local (r, c) with
                # 0 <= r < vchunk and 0 <= c < hchunk satisfies c - r == local_k.
                if -vchunk < local_k < hchunk:
                    task = Task(
                        key,
                        np.eye,
                        vchunk,
                        hchunk,
                        local_k,
                        dtype,
                    )
                else:
                    task = Task(key, np.zeros, (vchunk, hchunk), dtype)
                dsk[key] = task
                col_offset += hchunk
            row_offset += vchunk
        return dsk
64
+
65
+
66
def eye(N, chunks="auto", M=None, k=0, dtype=float):
    """
    Return a 2-D Array with ones on the diagonal and zeros elsewhere.

    Parameters
    ----------
    N : int
        Number of rows in the output.
    chunks : int, str
        How to chunk the array. Must be one of the following forms:

        - A blocksize like 1000.
        - A size in bytes, like "100 MiB" which will choose a uniform
          block-like shape
        - The word "auto" which acts like the above, but uses a configuration
          value ``array.chunk-size`` for the chunk size
    M : int, optional
        Number of columns in the output. If None, defaults to `N`.
    k : int, optional
        Index of the diagonal: 0 (the default) refers to the main diagonal,
        a positive value refers to an upper diagonal, and a negative value
        to a lower diagonal.
    dtype : data-type, optional
        Data-type of the returned array. ``None`` is treated as ``float``.

    Returns
    -------
    I : Array of shape (N,M)
        An array where all elements are equal to zero, except for the `k`-th
        diagonal, whose values are equal to one.

    Raises
    ------
    ValueError
        If ``chunks`` is neither an int nor a string.
    """
    if not isinstance(chunks, (int, str)):
        raise ValueError("chunks must be an int or string")

    resolved_dtype = float if dtype is None else dtype
    return new_collection(Eye(N, M, k, resolved_dtype, chunks))
@@ -0,0 +1,102 @@
1
+ from __future__ import annotations
2
+
3
+ import functools
4
+ from functools import partial
5
+
6
+ import numpy as np
7
+
8
+ from dask_array._new_collection import new_collection
9
+ from dask._task_spec import Task
10
+ from dask_array._chunk import linspace as _linspace
11
+
12
+ from ._arange import Arange
13
+
14
+
15
class Linspace(Arange):
    """Array expression for ``num`` evenly spaced samples between ``start`` and ``stop``.

    Reuses the ``_meta`` and ``chunks`` logic of ``Arange``; ``num_rows``,
    ``dtype``, ``step`` and the task layer are specific to linspace.
    """

    _parameters = ["start", "stop", "num", "endpoint", "chunks", "dtype"]
    _defaults = {"num": 50, "endpoint": True, "chunks": "auto", "dtype": None}
    # There is no ``like`` operand here; the inherited ``_meta`` reads this
    # class attribute instead.
    like = None

    @functools.cached_property
    def num_rows(self):
        """Number of samples, taken directly from the ``num`` operand."""
        return self.operand("num")

    @functools.cached_property
    def dtype(self):
        """Explicit ``dtype`` operand if given, otherwise NumPy's linspace default."""
        requested = self.operand("dtype")
        if requested is None:
            return np.linspace(0, 1, 1).dtype
        return np.dtype(requested)

    @functools.cached_property
    def step(self):
        """Spacing between consecutive samples, as a float."""
        span = self.stop - self.start

        intervals = (self.num_rows - 1) if self.endpoint else self.num_rows
        # Guard against dividing by zero (e.g. a single sample with endpoint).
        divisor = 1 if intervals == 0 else intervals

        return float(span) / divisor

    def _layer(self) -> dict:
        """Build one task per block, each covering its own sub-interval."""
        make_block = partial(_linspace, endpoint=self.endpoint, dtype=self.dtype)

        graph = {}
        left = self.start
        for index, size in enumerate(self.chunks[0]):
            # With the endpoint included, ``size`` samples span ``size - 1`` steps.
            steps_in_block = size - 1 if self.endpoint else size
            right = left + (steps_in_block * self.step)
            block = Task(
                (self._name, index),
                make_block,
                left,
                right,
                size,
            )
            graph[block.key] = block
            # The next block starts one full step past this block's last sample.
            left = left + (self.step * size)
        return graph
59
+
60
+
61
def linspace(start, stop, num=50, endpoint=True, retstep=False, chunks="auto", dtype=None):
    """
    Return `num` evenly spaced values over the closed interval [`start`,
    `stop`].

    Parameters
    ----------
    start : scalar
        First value of the sequence.
    stop : scalar
        Last value of the sequence.
    num : int, optional
        Number of samples in the returned dask array, endpoints included.
        Converted with ``int``. Default is 50.
    endpoint : bool, optional
        When True, ``stop`` is the final sample; otherwise it is excluded.
        Default is True.
    retstep : bool, optional
        When True, return ``(samples, step)`` where ``step`` is the spacing
        between samples. Default is False.
    chunks : int
        Number of samples per block. The last block is shorter when
        `num % blocksize != 0`.
    dtype : dtype, optional
        Type of the output array.

    Returns
    -------
    samples : dask array
    step : float, optional
        Only returned if ``retstep`` is True. Size of spacing between samples.

    See Also
    --------
    dask.array.arange
    """
    count = int(num)
    samples = new_collection(Linspace(start, stop, count, endpoint, chunks, dtype))
    if not retstep:
        return samples
    return samples, samples.expr.step