dask-array 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dask_array/__init__.py +228 -0
- dask_array/_backends.py +76 -0
- dask_array/_backends_array.py +99 -0
- dask_array/_blockwise.py +1410 -0
- dask_array/_broadcast.py +272 -0
- dask_array/_chunk.py +445 -0
- dask_array/_chunk_types.py +54 -0
- dask_array/_collection.py +1644 -0
- dask_array/_concatenate.py +331 -0
- dask_array/_core_utils.py +1365 -0
- dask_array/_dispatch.py +141 -0
- dask_array/_einsum.py +277 -0
- dask_array/_expr.py +544 -0
- dask_array/_expr_flow.py +586 -0
- dask_array/_gufunc.py +805 -0
- dask_array/_histogram.py +617 -0
- dask_array/_map_blocks.py +652 -0
- dask_array/_new_collection.py +10 -0
- dask_array/_numpy_compat.py +135 -0
- dask_array/_overlap.py +1159 -0
- dask_array/_rechunk.py +1050 -0
- dask_array/_reshape.py +710 -0
- dask_array/_routines.py +102 -0
- dask_array/_shuffle.py +448 -0
- dask_array/_stack.py +264 -0
- dask_array/_svg.py +291 -0
- dask_array/_templates.py +29 -0
- dask_array/_test_utils.py +257 -0
- dask_array/_ufunc.py +385 -0
- dask_array/_utils.py +349 -0
- dask_array/_visualize.py +223 -0
- dask_array/_xarray.py +337 -0
- dask_array/core/__init__.py +34 -0
- dask_array/core/_blockwise_funcs.py +312 -0
- dask_array/core/_conversion.py +422 -0
- dask_array/core/_from_graph.py +97 -0
- dask_array/creation/__init__.py +71 -0
- dask_array/creation/_arange.py +121 -0
- dask_array/creation/_diag.py +116 -0
- dask_array/creation/_diagonal.py +241 -0
- dask_array/creation/_eye.py +103 -0
- dask_array/creation/_linspace.py +102 -0
- dask_array/creation/_mesh.py +134 -0
- dask_array/creation/_ones_zeros.py +454 -0
- dask_array/creation/_pad.py +270 -0
- dask_array/creation/_repeat.py +55 -0
- dask_array/creation/_tile.py +36 -0
- dask_array/creation/_tri.py +28 -0
- dask_array/creation/_utils.py +296 -0
- dask_array/fft.py +320 -0
- dask_array/io/__init__.py +39 -0
- dask_array/io/_base.py +10 -0
- dask_array/io/_from_array.py +257 -0
- dask_array/io/_from_delayed.py +95 -0
- dask_array/io/_from_graph.py +54 -0
- dask_array/io/_from_npy_stack.py +67 -0
- dask_array/io/_store.py +336 -0
- dask_array/io/_tiledb.py +159 -0
- dask_array/io/_to_npy_stack.py +65 -0
- dask_array/io/_zarr.py +449 -0
- dask_array/linalg/__init__.py +39 -0
- dask_array/linalg/_cholesky.py +234 -0
- dask_array/linalg/_lu.py +300 -0
- dask_array/linalg/_norm.py +94 -0
- dask_array/linalg/_qr.py +601 -0
- dask_array/linalg/_solve.py +349 -0
- dask_array/linalg/_svd.py +394 -0
- dask_array/linalg/_tensordot.py +334 -0
- dask_array/linalg/_utils.py +74 -0
- dask_array/manipulation/__init__.py +45 -0
- dask_array/manipulation/_expand.py +321 -0
- dask_array/manipulation/_flip.py +92 -0
- dask_array/manipulation/_roll.py +78 -0
- dask_array/manipulation/_transpose.py +309 -0
- dask_array/random/__init__.py +125 -0
- dask_array/random/_choice.py +181 -0
- dask_array/random/_expr.py +256 -0
- dask_array/random/_generator.py +441 -0
- dask_array/random/_random_state.py +259 -0
- dask_array/random/_utils.py +84 -0
- dask_array/reductions/__init__.py +84 -0
- dask_array/reductions/_arg_reduction.py +130 -0
- dask_array/reductions/_common.py +1082 -0
- dask_array/reductions/_cumulative.py +522 -0
- dask_array/reductions/_percentile.py +261 -0
- dask_array/reductions/_reduction.py +725 -0
- dask_array/reductions/_trace.py +56 -0
- dask_array/routines/__init__.py +133 -0
- dask_array/routines/_apply.py +84 -0
- dask_array/routines/_bincount.py +112 -0
- dask_array/routines/_broadcast.py +111 -0
- dask_array/routines/_coarsen.py +115 -0
- dask_array/routines/_diff.py +79 -0
- dask_array/routines/_gradient.py +158 -0
- dask_array/routines/_indexing.py +65 -0
- dask_array/routines/_insert_delete.py +132 -0
- dask_array/routines/_misc.py +122 -0
- dask_array/routines/_nonzero.py +72 -0
- dask_array/routines/_search.py +123 -0
- dask_array/routines/_select.py +113 -0
- dask_array/routines/_statistics.py +171 -0
- dask_array/routines/_topk.py +82 -0
- dask_array/routines/_triangular.py +74 -0
- dask_array/routines/_unique.py +232 -0
- dask_array/routines/_where.py +62 -0
- dask_array/slicing/__init__.py +67 -0
- dask_array/slicing/_basic.py +550 -0
- dask_array/slicing/_blocks.py +138 -0
- dask_array/slicing/_bool_index.py +145 -0
- dask_array/slicing/_setitem.py +329 -0
- dask_array/slicing/_squeeze.py +101 -0
- dask_array/slicing/_utils.py +1133 -0
- dask_array/slicing/_vindex.py +282 -0
- dask_array/stacking/__init__.py +15 -0
- dask_array/stacking/_block.py +83 -0
- dask_array/stacking/_simple.py +58 -0
- dask_array/templates/array.html.j2 +48 -0
- dask_array/tests/__init__.py +0 -0
- dask_array/tests/conftest.py +22 -0
- dask_array/tests/test_api.py +40 -0
- dask_array/tests/test_binary_op_chunks.py +107 -0
- dask_array/tests/test_coarse_slice_through_blockwise.py +362 -0
- dask_array/tests/test_collection.py +799 -0
- dask_array/tests/test_creation.py +1102 -0
- dask_array/tests/test_expr_flow.py +143 -0
- dask_array/tests/test_linalg.py +1130 -0
- dask_array/tests/test_map_blocks_multi_output.py +104 -0
- dask_array/tests/test_rechunk_pushdown.py +214 -0
- dask_array/tests/test_reductions.py +1091 -0
- dask_array/tests/test_routines.py +2853 -0
- dask_array/tests/test_shuffle_chunks.py +67 -0
- dask_array/tests/test_slice_pushdown.py +968 -0
- dask_array/tests/test_slice_through_blockwise.py +678 -0
- dask_array/tests/test_slice_through_overlap.py +366 -0
- dask_array/tests/test_slice_through_reshape.py +272 -0
- dask_array/tests/test_slicing.py +839 -0
- dask_array/tests/test_transpose_slice_pushdown.py +208 -0
- dask_array/tests/test_visualize.py +94 -0
- dask_array/tests/test_xarray.py +193 -0
- dask_array-0.1.0.dist-info/METADATA +48 -0
- dask_array-0.1.0.dist-info/RECORD +144 -0
- dask_array-0.1.0.dist-info/WHEEL +4 -0
- dask_array-0.1.0.dist-info/entry_points.txt +2 -0
- dask_array-0.1.0.dist-info/licenses/LICENSE +29 -0
|
@@ -0,0 +1,257 @@
|
|
|
1
|
+
"""Test utilities for dask_array.
|
|
2
|
+
|
|
3
|
+
These functions are used in tests to compare dask arrays with numpy arrays.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
import itertools
|
|
9
|
+
import math
|
|
10
|
+
|
|
11
|
+
import numpy as np
|
|
12
|
+
|
|
13
|
+
from dask.base import is_dask_collection
|
|
14
|
+
|
|
15
|
+
from dask_array._utils import is_arraylike, is_cupy_type
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def normalize_to_array(x):
    """Return ``x`` converted through ``.get()`` when it is a CuPy object.

    Anything that ``is_cupy_type`` does not recognise is handed back
    unchanged.
    """
    return x.get() if is_cupy_type(x) else x
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def allclose(a, b, equal_nan=False, **kwargs):
    """Return whether ``a`` and ``b`` are element-wise equal within a tolerance.

    Both inputs first pass through ``normalize_to_array``.  Unless ``a`` has
    object dtype, the comparison is delegated to ``np.ma.allclose`` (when
    either input exposes a ``mask`` attribute; ``equal_nan`` is not forwarded
    on that path) or to ``np.allclose``.  Object-dtype input is compared
    exactly; with ``equal_nan`` a NaN in ``a`` only matches a NaN in ``b``.

    Extra keyword arguments are forwarded to the numpy comparison functions.
    """
    a, b = normalize_to_array(a), normalize_to_array(b)

    if getattr(a, "dtype", None) != "O":
        if hasattr(a, "mask") or hasattr(b, "mask"):
            return np.ma.allclose(a, b, masked_equal=True, **kwargs)
        return np.allclose(a, b, equal_nan=equal_nan, **kwargs)

    # Object dtype from here on: tolerances do not apply.
    if not equal_nan:
        return (a == b).all()
    return a.shape == b.shape and all(
        np.isnan(rhs) if np.isnan(lhs) else lhs == rhs
        for lhs, rhs in zip(a.flat, b.flat)
    )
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def same_keys(a, b):
    """Return True when ``a.dask`` and ``b.dask`` contain the same keys.

    The keys of both graphs are sorted before comparison.  String keys are
    padded to ``(key, -1, -1, -1)`` for sorting purposes only, so they can be
    ordered alongside tuple keys.
    """

    def _sort_key(graph_key):
        return (graph_key, -1, -1, -1) if isinstance(graph_key, str) else graph_key

    keys_a = sorted(a.dask, key=_sort_key)
    keys_b = sorted(b.dask, key=_sort_key)
    return keys_a == keys_b
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def _not_empty(x):
|
|
53
|
+
return x.shape and 0 not in x.shape
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def assert_eq_shape(a, b, check_ndim=True, check_nan=True):
    """Assert that shapes ``a`` and ``b`` agree, with special handling of NaN.

    Parameters
    ----------
    a, b : sequence of numbers
        Shapes to compare dimension by dimension (pairs come from ``zip``).
    check_ndim : bool
        When true, the two shapes must have the same length.
    check_nan : bool
        Applies to pairs where at least one side is NaN: when true both sides
        must be NaN; when false such pairs are skipped.

    Raises
    ------
    AssertionError
        If any enabled check fails.
    """
    if check_ndim:
        assert len(a) == len(b)

    for dim_a, dim_b in zip(a, b):
        if not (math.isnan(dim_a) or math.isnan(dim_b)):
            assert dim_a == dim_b
            continue
        if check_nan:
            assert math.isnan(dim_a) == math.isnan(dim_b)
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def _check_chunks(x, check_ndim=True, scheduler=None):
    """Persist ``x`` and verify every block against ``x.chunks`` and ``x.dtype``.

    Each block is looked up in the persisted graph under ``(name, *index)``.
    Objects exposing ``result`` are resolved through it, and blocks without a
    ``dtype`` are wrapped in an object-dtype numpy array before checking.

    Returns the persisted collection so the caller can reuse it.
    """
    persisted = x.persist(scheduler=scheduler)
    graph = persisted.dask  # Cache to avoid repeated graph materialization
    chunks = persisted.chunks

    axis_ranges = [range(len(axis_chunks)) for axis_chunks in chunks]
    for block_index in itertools.product(*axis_ranges):
        block = graph[(persisted.name, *block_index)]
        if hasattr(block, "result"):  # it's a future
            block = block.result()
        if not hasattr(block, "dtype"):
            block = np.array(block, dtype="O")

        expected_shape = tuple(
            axis_chunks[pos] for axis_chunks, pos in zip(chunks, block_index)
        )
        assert_eq_shape(expected_shape, block.shape, check_ndim=check_ndim, check_nan=False)
        assert block.dtype == persisted.dtype, "maybe you forgot to pass the scheduler to `assert_eq`?"

    return persisted
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def _get_dt_meta_computed(
    x,
    check_shape=True,
    check_graph=True,
    check_chunks=True,
    check_ndim=True,
    scheduler=None,
):
    """Return ``(value, dtype, meta, computed)`` for an array-like ``x``.

    For a dask array-like collection, ``x`` is (optionally chunk-checked and)
    computed; ``value`` is the computed result after ``todense()`` / object
    array wrapping, ``dtype`` is the collection's declared dtype, ``meta`` is
    its ``_meta`` attribute (or None) and ``computed`` is the raw result of
    ``compute``.  For anything else, ``value`` is ``x`` (wrapped in an
    object-dtype array if it has no ``dtype``), and ``meta`` and ``computed``
    are None.

    ``check_graph`` is accepted but not used in this function.
    """
    if not (is_dask_collection(x) and is_arraylike(x)):
        # Plain input: nothing to compute and no metadata to report.
        if not hasattr(x, "dtype"):
            x = np.array(x, dtype="O")
        return x, getattr(x, "dtype", None), None, None

    source = x
    assert source.dtype is not None
    declared_dtype = source.dtype
    # Note: check_graph is ignored in array-expr mode as it triggers
    # expensive graph regeneration. The HLG validation can be enabled
    # when needed for debugging.
    meta = getattr(source, "_meta", None)

    if check_chunks:
        # Continue with the persisted version to avoid computing it twice.
        x = _check_chunks(x, check_ndim=check_ndim, scheduler=scheduler)
    computed = x.compute(scheduler=scheduler)

    value = computed.todense() if hasattr(computed, "todense") else computed
    if not hasattr(value, "dtype"):
        value = np.array(value, dtype="O")

    if _not_empty(value):
        assert value.dtype == source.dtype
    if check_shape:
        assert_eq_shape(source.shape, value.shape, check_nan=False)

    return value, declared_dtype, meta, computed
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
def _assert_meta_consistency(label, original, value, meta, computed):
    """Run the per-operand metadata checks of ``assert_eq``.

    ``label`` is the operand name used in failure messages (``"a"`` or
    ``"b"``), ``original`` the object as passed by the caller, ``value`` the
    normalized computed array, ``meta`` the collection's ``_meta`` (or None)
    and ``computed`` the raw computed result.  Does nothing when ``original``
    has no ``_meta`` attribute.
    """
    if not hasattr(original, "_meta"):
        return

    msg = (
        f"compute()-ing '{label}' changes its number of dimensions "
        f"(before: {original._meta.ndim}, after: {value.ndim})"
    )
    assert original._meta.ndim == value.ndim, msg

    if meta is None:
        return

    msg = f"compute()-ing '{label}' changes its type (before: {type(original._meta)}, after: {type(meta)})"
    assert type(original._meta) == type(meta), msg

    if np.isscalar(meta) or np.isscalar(computed):
        return

    msg = (
        f"compute()-ing '{label}' results in a different type than implied by its metadata "
        f"(meta: {type(meta)}, computed: {type(computed)})"
    )
    assert type(meta) == type(computed), msg


def assert_eq(
    a,
    b,
    check_shape=True,
    check_graph=True,
    check_meta=True,
    check_chunks=True,
    check_ndim=True,
    check_type=True,
    check_dtype=True,
    equal_nan=True,
    scheduler="sync",
    **kwargs,
):
    """Assert that two array-likes are equal.

    Dask arrays are computed before comparison; lists, ints and floats are
    converted with ``np.array``.  Depending on the flags, dtype, shape, type
    and metadata consistency are verified before the values are compared
    with ``allclose``.  If a ``TypeError`` is raised during those checks the
    function falls back to an exact ``a == b`` comparison.

    Parameters
    ----------
    a, b : array-like
        Arrays to compare
    check_shape : bool
        Whether to check that computed shapes match the declared ones
    check_graph : bool
        Forwarded to ``_get_dt_meta_computed`` (graph validation is
        currently not implemented locally)
    check_meta : bool
        Whether to check metadata consistency
    check_chunks : bool
        Whether to check chunk shapes
    check_ndim : bool
        Whether to check that ndims match
    check_type : bool
        Whether to check that types match
    check_dtype : bool
        Whether to check that dtypes match (compared as strings)
    equal_nan : bool
        Whether to treat NaN values as equal
    scheduler : str
        Scheduler to use for computing dask arrays
    **kwargs
        Forwarded to ``allclose``

    Returns
    -------
    bool
        True if arrays are equal

    Raises
    ------
    AssertionError
        If any enabled check or the value comparison fails.
    """
    a_original, b_original = a, b

    if isinstance(a, (list, int, float)):
        a = np.array(a)
    if isinstance(b, (list, int, float)):
        b = np.array(b)

    shared_checks = dict(
        check_shape=check_shape,
        check_graph=check_graph,
        check_chunks=check_chunks,
        check_ndim=check_ndim,
        scheduler=scheduler,
    )
    a, adt, a_meta, a_computed = _get_dt_meta_computed(a, **shared_checks)
    b, bdt, b_meta, b_computed = _get_dt_meta_computed(b, **shared_checks)

    if check_dtype and str(adt) != str(bdt):
        raise AssertionError(f"a and b have different dtypes: (a: {adt}, b: {bdt})")

    try:
        assert a.shape == b.shape, f"a and b have different shapes (a: {a.shape}, b: {b.shape})"

        if check_type:
            # Zero-dimensional results are compared by their scalar item type.
            _a = a if a.shape else a.item()
            _b = b if b.shape else b.item()
            assert type(_a) == type(_b), f"a and b have different types (a: {type(_a)}, b: {type(_b)})"

        if check_meta:
            if hasattr(a, "_meta") and hasattr(b, "_meta"):
                assert_eq(a._meta, b._meta)
            _assert_meta_consistency("a", a_original, a, a_meta, a_computed)
            _assert_meta_consistency("b", b_original, b, b_meta, b_computed)

        msg = "found values in 'a' and 'b' which differ by more than the allowed amount"
        assert allclose(a, b, equal_nan=equal_nan, **kwargs), msg
        return True
    except TypeError:
        # Fall through to the exact comparison below.
        pass

    exact = a == b
    assert exact.all() if isinstance(exact, np.ndarray) else exact

    return True
|
dask_array/_ufunc.py
ADDED
|
@@ -0,0 +1,385 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import functools
|
|
4
|
+
from functools import partial
|
|
5
|
+
from operator import getitem
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
|
|
9
|
+
from dask import core as dask_core
|
|
10
|
+
from dask_array._new_collection import new_collection
|
|
11
|
+
from dask_array._collection import Array, asarray, blockwise, elemwise
|
|
12
|
+
from dask_array._expr import ArrayExpr
|
|
13
|
+
from dask_array._core_utils import apply_infer_dtype
|
|
14
|
+
from dask.base import is_dask_collection
|
|
15
|
+
from dask.tokenize import _tokenize_deterministic, normalize_token
|
|
16
|
+
from dask.utils import derived_from, funcname
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def wrap_elemwise(numpy_ufunc, source=np):
    """Build a wrapper around ``numpy_ufunc`` that dispatches on ``_elemwise``.

    The wrapper hands the call to the ``_elemwise`` method of the first
    positional argument that has one; when no argument does, ``numpy_ufunc``
    is called directly.  The wrapper takes the name of ``numpy_ufunc`` and is
    passed through ``derived_from(source)`` before being returned.
    """

    def wrapped(*args, **kwargs):
        for candidate in args:
            if hasattr(candidate, "_elemwise"):
                return candidate._elemwise(numpy_ufunc, *args, **kwargs)
        return numpy_ufunc(*args, **kwargs)

    # The name is copied by hand (original note: functools.wraps cannot wrap
    # ufunc in Python 2.x).
    wrapped.__name__ = numpy_ufunc.__name__
    decorate = derived_from(source)
    return decorate(wrapped)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class da_frompyfunc:
    """Wrapper around ``np.frompyfunc`` that can be rebuilt from its arguments.

    The original Python callable and the ``nin``/``nout`` counts are kept so
    that ``__reduce__`` and ``__dask_tokenize__`` can be expressed in terms of
    them.  Public attributes not found on the wrapper are looked up on the
    underlying numpy ufunc.
    """

    def __init__(self, func, nin, nout):
        self._ufunc = np.frompyfunc(func, nin, nout)
        self._func = func
        self.nin, self.nout = nin, nout
        self._name = funcname(func)
        self.__name__ = f"frompyfunc-{self._name}"

    def outer(self, *args, **kwargs):
        """Outer product - tokenizable because da_frompyfunc has __dask_tokenize__."""
        outer_impl = self._ufunc.outer
        return outer_impl(*args, **kwargs)

    def __repr__(self):
        return f"da.frompyfunc<{self._name}, {self.nin}, {self.nout}>"

    def __dask_tokenize__(self):
        # Token is derived from the wrapped callable and the in/out counts.
        func_token = normalize_token(self._func)
        return (func_token, self.nin, self.nout)

    def __reduce__(self):
        # Rebuild from the constructor arguments.
        ctor_args = (self._func, self.nin, self.nout)
        return (da_frompyfunc, ctor_args)

    def __call__(self, *args, **kwargs):
        return self._ufunc(*args, **kwargs)

    def __getattr__(self, a):
        # Only public names are delegated to the wrapped ufunc.
        if a.startswith("_"):
            raise AttributeError(f"{type(self).__name__!r} object has no attribute {a!r}")
        return getattr(self._ufunc, a)

    def __dir__(self):
        # Class attributes, instance attributes and the wrapped ufunc's attributes.
        names = {*dir(type(self)), *self.__dict__, *dir(self._ufunc)}
        return list(names)
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
@derived_from(np)
def frompyfunc(func, nin, nout):
    # Wraps ``func`` in a ``da_frompyfunc`` and then in the ``ufunc`` wrapper.
    # Only a single output is supported; ``nout > 1`` raises
    # NotImplementedError.
    if nout > 1:
        raise NotImplementedError("frompyfunc with more than one output")
    serializable = da_frompyfunc(func, nin, nout)
    return ufunc(serializable)
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
class ufunc:
    # Wrapper around a numpy ufunc (or a ``da_frompyfunc``) that dispatches to
    # the ``_elemwise`` method of its arguments when one is available.
    #
    # NOTE(review): this class is documented with comments rather than a class
    # docstring because ``derived_from(np)(self)`` in ``__init__`` is applied
    # to the instance and may pick up a class-level ``__doc__`` -- confirm
    # against ``dask.utils.derived_from`` before adding one.

    # Attributes that are read straight from the wrapped ufunc.
    _forward_attrs = {
        "nin",
        "nargs",
        "nout",
        "ntypes",
        "identity",
        "signature",
        "types",
    }

    def __init__(self, ufunc):
        """Wrap ``ufunc``.

        Raises
        ------
        TypeError
            If ``ufunc`` is neither an ``np.ufunc`` nor a ``da_frompyfunc``.
        """
        if not isinstance(ufunc, (np.ufunc, da_frompyfunc)):
            raise TypeError(f"must be an instance of `ufunc` or `da_frompyfunc`, got `{type(ufunc).__name__}`")
        self._ufunc = ufunc
        self.__name__ = ufunc.__name__
        if isinstance(ufunc, np.ufunc):
            derived_from(np)(self)

    def __dask_tokenize__(self):
        return self.__name__, normalize_token(self._ufunc)

    def __getattr__(self, key):
        # Only the whitelisted ufunc attributes are forwarded.
        if key in self._forward_attrs:
            return getattr(self._ufunc, key)
        raise AttributeError(f"{type(self).__name__!r} object has no attribute {key!r}")

    def __dir__(self):
        return list(self._forward_attrs.union(dir(type(self)), self.__dict__))

    def __repr__(self):
        return repr(self._ufunc)

    def __call__(self, *args, **kwargs):
        """Apply the ufunc, dispatching to ``_elemwise`` when possible.

        Only the ``out``, ``where`` and ``dtype`` keyword arguments are
        accepted.  Each positional argument exposing ``_elemwise`` is tried
        in order until one returns something other than ``NotImplemented``;
        without any such argument the wrapped ufunc is called directly.

        Raises
        ------
        TypeError
            On unsupported keyword arguments, or when every ``_elemwise``
            candidate returns ``NotImplemented``.
        """
        # Validate kwargs - only allow known ufunc kwargs
        valid_kwargs = {"out", "where", "dtype"}
        extra_kwargs = set(kwargs) - valid_kwargs
        if extra_kwargs:
            raise TypeError(f"{self.__name__} does not take the following keyword arguments {sorted(extra_kwargs)}")

        dsks = [arg for arg in args if hasattr(arg, "_elemwise")]
        if not dsks:
            return self._ufunc(*args, **kwargs)

        for dsk in dsks:
            result = dsk._elemwise(self._ufunc, *args, **kwargs)
            # NotImplemented is a singleton, so identity is the right test.
            if result is not NotImplemented:
                return result
        raise TypeError(f"Parameters of such types are not supported by {self.__name__}")

    @derived_from(np.ufunc)
    def outer(self, A, B, **kwargs):
        # Outer application of a binary ufunc.  Plain (non-dask) inputs are
        # passed straight to the wrapped ufunc's ``outer``; otherwise both
        # inputs are converted with ``asarray`` and combined blockwise, with
        # A's axes first and B's axes after them.
        # Raises ValueError for non-binary ufuncs or when ``out`` is given,
        # and NotImplementedError for dask collections that are not ``Array``.
        if self.nin != 2:
            raise ValueError("outer product only supported for binary functions")
        if "out" in kwargs:
            raise ValueError("`out` kwarg not supported")

        A_is_dask = is_dask_collection(A)
        B_is_dask = is_dask_collection(B)
        if not A_is_dask and not B_is_dask:
            return self._ufunc.outer(A, B, **kwargs)
        elif (A_is_dask and not isinstance(A, Array)) or (B_is_dask and not isinstance(B, Array)):
            raise NotImplementedError("Dask objects besides `dask.array.Array` are not supported at this time.")

        A = asarray(A)
        B = asarray(B)
        ndim = A.ndim + B.ndim
        out_inds = tuple(range(ndim))
        A_inds = out_inds[: A.ndim]
        B_inds = out_inds[A.ndim :]

        dtype = apply_infer_dtype(self._ufunc.outer, [A, B], kwargs, "ufunc.outer", suggest_dtype=False)

        # ``dtype`` is bound into the per-block function so that it is not
        # forwarded a second time through ``**kwargs`` below.
        if "dtype" in kwargs:
            func = partial(self._ufunc.outer, dtype=kwargs.pop("dtype"))
        else:
            func = self._ufunc.outer

        return blockwise(
            func,
            out_inds,
            A,
            A_inds,
            B,
            B_inds,
            dtype=dtype,
            name=self.__name__ + ".outer-" + _tokenize_deterministic(self._ufunc),
            **kwargs,
        )
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
# Module-level ufunc wrappers.  The list follows the NumPy ufunc reference:
# https://numpy.org/doc/stable/reference/ufuncs.html

# --- math operations ---
add = ufunc(np.add)
subtract = ufunc(np.subtract)
multiply = ufunc(np.multiply)
divide = ufunc(np.divide)
logaddexp = ufunc(np.logaddexp)
logaddexp2 = ufunc(np.logaddexp2)
true_divide = ufunc(np.true_divide)
floor_divide = ufunc(np.floor_divide)
negative = ufunc(np.negative)
positive = ufunc(np.positive)
power = ufunc(np.power)
float_power = ufunc(np.float_power)
remainder = ufunc(np.remainder)
mod = ufunc(np.mod)
fmod = ufunc(np.fmod)
conjugate = ufunc(np.conjugate)
conj = conjugate  # alias bound to the same wrapper object
exp = ufunc(np.exp)
exp2 = ufunc(np.exp2)
log = ufunc(np.log)
log2 = ufunc(np.log2)
log10 = ufunc(np.log10)
log1p = ufunc(np.log1p)
expm1 = ufunc(np.expm1)
sqrt = ufunc(np.sqrt)
square = ufunc(np.square)
cbrt = ufunc(np.cbrt)
reciprocal = ufunc(np.reciprocal)

# --- trigonometric functions ---
sin = ufunc(np.sin)
cos = ufunc(np.cos)
tan = ufunc(np.tan)
arcsin = ufunc(np.arcsin)
arccos = ufunc(np.arccos)
arctan = ufunc(np.arctan)
arctan2 = ufunc(np.arctan2)
hypot = ufunc(np.hypot)
sinh = ufunc(np.sinh)
cosh = ufunc(np.cosh)
tanh = ufunc(np.tanh)
arcsinh = ufunc(np.arcsinh)
arccosh = ufunc(np.arccosh)
arctanh = ufunc(np.arctanh)
deg2rad = ufunc(np.deg2rad)
rad2deg = ufunc(np.rad2deg)

# --- comparison functions ---
greater = ufunc(np.greater)
greater_equal = ufunc(np.greater_equal)
less = ufunc(np.less)
less_equal = ufunc(np.less_equal)
not_equal = ufunc(np.not_equal)
equal = ufunc(np.equal)
# Infinity tests are expressed as equality against the signed infinity.
isneginf = partial(equal, -np.inf)
isposinf = partial(equal, np.inf)
logical_and = ufunc(np.logical_and)
logical_or = ufunc(np.logical_or)
logical_xor = ufunc(np.logical_xor)
logical_not = ufunc(np.logical_not)
maximum = ufunc(np.maximum)
minimum = ufunc(np.minimum)
fmax = ufunc(np.fmax)
fmin = ufunc(np.fmin)

# --- bitwise functions ---
bitwise_and = ufunc(np.bitwise_and)
bitwise_or = ufunc(np.bitwise_or)
bitwise_xor = ufunc(np.bitwise_xor)
bitwise_not = ufunc(np.bitwise_not)
invert = bitwise_not  # alias
left_shift = ufunc(np.left_shift)
right_shift = ufunc(np.right_shift)

# --- floating functions ---
# ``modf`` and ``frexp`` return two outputs and are defined as functions
# further down in this module.
isfinite = ufunc(np.isfinite)
isinf = ufunc(np.isinf)
isnan = ufunc(np.isnan)
signbit = ufunc(np.signbit)
copysign = ufunc(np.copysign)
nextafter = ufunc(np.nextafter)
spacing = ufunc(np.spacing)
ldexp = ufunc(np.ldexp)
floor = ufunc(np.floor)
ceil = ufunc(np.ceil)
trunc = ufunc(np.trunc)

# --- more math routines ---
# https://numpy.org/doc/stable/reference/routines.math.html
degrees = ufunc(np.degrees)
radians = ufunc(np.radians)
rint = ufunc(np.rint)
fabs = ufunc(np.fabs)
sign = ufunc(np.sign)
absolute = ufunc(np.absolute)
abs = absolute  # alias; intentionally shadows the builtin in this namespace

# --- non-ufunc elementwise functions ---
clip = wrap_elemwise(np.clip)
isreal = wrap_elemwise(np.isreal)
iscomplex = wrap_elemwise(np.iscomplex)
real = wrap_elemwise(np.real)
imag = wrap_elemwise(np.imag)
fix = wrap_elemwise(np.fix)
i0 = wrap_elemwise(np.i0)
sinc = wrap_elemwise(np.sinc)
nan_to_num = wrap_elemwise(np.nan_to_num)
|
|
284
|
+
|
|
285
|
+
|
|
286
|
+
@derived_from(np)
def angle(x, deg=0):
    # ``deg`` is normalised to a bool.  Inputs exposing ``_elemwise`` are
    # dispatched through it (with the flag passed positionally); anything
    # else goes straight to ``np.angle``.
    in_degrees = bool(deg)
    if not hasattr(x, "_elemwise"):
        return np.angle(x, deg=in_degrees)
    return x._elemwise(np.angle, x, in_degrees)
|
|
292
|
+
|
|
293
|
+
|
|
294
|
+
class DoubleOutputs(ArrayExpr):
    """Expression selecting one element of each block of ``array``.

    Every block key of ``array`` is mapped to ``getitem(block, index)``.
    The metadata is obtained by calling ``func`` on a small array built from
    ``meta`` and taking element ``index`` of the result.
    """

    _parameters = ["array", "index", "meta", "name", "func"]

    @functools.cached_property
    def _meta(self):
        template = self.operand("meta")
        # One element along every axis, same dtype as the template.
        probe = np.empty_like(template, shape=(1,) * template.ndim, dtype=template.dtype)
        outputs = self.operand("func")(probe)
        return outputs[self.operand("index")]

    @functools.cached_property
    def chunks(self):
        # Chunking is taken unchanged from the input array.
        return self.array.chunks

    @functools.cached_property
    def _name(self):
        prefix = self.operand("name")
        return prefix + _tokenize_deterministic(*self.operands)

    def _layer(self) -> dict:
        name = self._name
        layer = {}
        for key in dask_core.flatten(self.array.__dask_keys__()):
            # Keep the block coordinates, swap in this expression's name.
            layer[(name, *key[1:])] = (getitem, key, self.operand("index"))
        return layer
|
|
317
|
+
|
|
318
|
+
|
|
319
|
+
@derived_from(np)
def frexp(x):
    # ``np.frexp`` yields two outputs per block; they are computed once and
    # then split into two collections via ``DoubleOutputs``.
    # The object dtype is only a placeholder: not actually object dtype, just
    # need to specify something.
    paired = elemwise(np.frexp, x, dtype=object)
    meta = getattr(x, "_meta", x)
    mantissa = DoubleOutputs(paired, 0, meta, "mantissa-", np.frexp)
    exponent = DoubleOutputs(paired, 1, meta, "exponent-", np.frexp)
    return new_collection(mantissa), new_collection(exponent)
|
|
326
|
+
|
|
327
|
+
|
|
328
|
+
@derived_from(np)
def modf(x):
    # ``np.modf`` yields two outputs per block; they are computed once and
    # then split into two collections via ``DoubleOutputs``.
    # The object dtype is only a placeholder: not actually object dtype, just
    # need to specify something.
    paired = elemwise(np.modf, x, dtype=object)
    meta = getattr(x, "_meta", x)
    first = DoubleOutputs(paired, 0, meta, "modf1-", np.modf)
    second = DoubleOutputs(paired, 1, meta, "modf2-", np.modf)
    return new_collection(first), new_collection(second)
|
|
335
|
+
|
|
336
|
+
|
|
337
|
+
@derived_from(np)
def divmod(x, y):
    # Quotient and remainder come from the operands' own ``//`` and ``%``.
    return x // y, x % y
|
|
342
|
+
|
|
343
|
+
|
|
344
|
+
def round(a, decimals=0):
    """Round an array to the given number of decimals.

    ``a`` is converted with ``asarray`` and ``np.round`` is applied
    elementwise; the result keeps the dtype of the converted input.
    """
    arr = asarray(a)
    result_dtype = arr.dtype
    return elemwise(np.round, arr, dtype=result_dtype, decimals=decimals)
|
|
348
|
+
|
|
349
|
+
|
|
350
|
+
@derived_from(np)
def around(x, decimals=0):
    """Evenly round to the given number of decimals."""
    # Thin alias for this module's ``round``.
    rounded = round(x, decimals=decimals)
    return rounded
|
|
354
|
+
|
|
355
|
+
|
|
356
|
+
def _asarray_isnull(values):
|
|
357
|
+
import pandas as pd
|
|
358
|
+
|
|
359
|
+
return np.asarray(pd.isnull(values))
|
|
360
|
+
|
|
361
|
+
|
|
362
|
+
def isnull(values):
    """pandas.isnull for dask arrays

    ``_asarray_isnull`` is applied elementwise and the result is declared
    with ``bool`` dtype.
    """
    # Import up front so a missing pandas raises ImportError here, at call
    # time, rather than when the blocks are computed.
    import pandas  # noqa: F401

    return elemwise(_asarray_isnull, values, dtype="bool")
|
|
368
|
+
|
|
369
|
+
|
|
370
|
+
def notnull(values):
    """pandas.notnull for dask arrays

    Computed as the inversion (``~``) of ``isnull(values)``.
    """
    null_mask = isnull(values)
    return ~null_mask
|
|
373
|
+
|
|
374
|
+
|
|
375
|
+
@derived_from(np)
def isclose(arr1, arr2, rtol=1e-5, atol=1e-8, equal_nan=False):
    """Returns a boolean array where two arrays are element-wise equal within a tolerance."""
    # The tolerances are bound into ``np.isclose`` so the elementwise call
    # only has to pass the two arrays.
    tolerances = {"rtol": rtol, "atol": atol, "equal_nan": equal_nan}
    compare = partial(np.isclose, **tolerances)
    return elemwise(compare, arr1, arr2, dtype="bool")
|
|
380
|
+
|
|
381
|
+
|
|
382
|
+
@derived_from(np)
def allclose(arr1, arr2, rtol=1e-5, atol=1e-8, equal_nan=False):
    """Returns True if two arrays are element-wise equal within a tolerance."""
    # Reduce the elementwise ``isclose`` mask with ``.all()``.
    close_mask = isclose(arr1, arr2, rtol=rtol, atol=atol, equal_nan=equal_nan)
    return close_mask.all()
|