dask-array 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dask_array/__init__.py +228 -0
- dask_array/_backends.py +76 -0
- dask_array/_backends_array.py +99 -0
- dask_array/_blockwise.py +1410 -0
- dask_array/_broadcast.py +272 -0
- dask_array/_chunk.py +445 -0
- dask_array/_chunk_types.py +54 -0
- dask_array/_collection.py +1644 -0
- dask_array/_concatenate.py +331 -0
- dask_array/_core_utils.py +1365 -0
- dask_array/_dispatch.py +141 -0
- dask_array/_einsum.py +277 -0
- dask_array/_expr.py +544 -0
- dask_array/_expr_flow.py +586 -0
- dask_array/_gufunc.py +805 -0
- dask_array/_histogram.py +617 -0
- dask_array/_map_blocks.py +652 -0
- dask_array/_new_collection.py +10 -0
- dask_array/_numpy_compat.py +135 -0
- dask_array/_overlap.py +1159 -0
- dask_array/_rechunk.py +1050 -0
- dask_array/_reshape.py +710 -0
- dask_array/_routines.py +102 -0
- dask_array/_shuffle.py +448 -0
- dask_array/_stack.py +264 -0
- dask_array/_svg.py +291 -0
- dask_array/_templates.py +29 -0
- dask_array/_test_utils.py +257 -0
- dask_array/_ufunc.py +385 -0
- dask_array/_utils.py +349 -0
- dask_array/_visualize.py +223 -0
- dask_array/_xarray.py +337 -0
- dask_array/core/__init__.py +34 -0
- dask_array/core/_blockwise_funcs.py +312 -0
- dask_array/core/_conversion.py +422 -0
- dask_array/core/_from_graph.py +97 -0
- dask_array/creation/__init__.py +71 -0
- dask_array/creation/_arange.py +121 -0
- dask_array/creation/_diag.py +116 -0
- dask_array/creation/_diagonal.py +241 -0
- dask_array/creation/_eye.py +103 -0
- dask_array/creation/_linspace.py +102 -0
- dask_array/creation/_mesh.py +134 -0
- dask_array/creation/_ones_zeros.py +454 -0
- dask_array/creation/_pad.py +270 -0
- dask_array/creation/_repeat.py +55 -0
- dask_array/creation/_tile.py +36 -0
- dask_array/creation/_tri.py +28 -0
- dask_array/creation/_utils.py +296 -0
- dask_array/fft.py +320 -0
- dask_array/io/__init__.py +39 -0
- dask_array/io/_base.py +10 -0
- dask_array/io/_from_array.py +257 -0
- dask_array/io/_from_delayed.py +95 -0
- dask_array/io/_from_graph.py +54 -0
- dask_array/io/_from_npy_stack.py +67 -0
- dask_array/io/_store.py +336 -0
- dask_array/io/_tiledb.py +159 -0
- dask_array/io/_to_npy_stack.py +65 -0
- dask_array/io/_zarr.py +449 -0
- dask_array/linalg/__init__.py +39 -0
- dask_array/linalg/_cholesky.py +234 -0
- dask_array/linalg/_lu.py +300 -0
- dask_array/linalg/_norm.py +94 -0
- dask_array/linalg/_qr.py +601 -0
- dask_array/linalg/_solve.py +349 -0
- dask_array/linalg/_svd.py +394 -0
- dask_array/linalg/_tensordot.py +334 -0
- dask_array/linalg/_utils.py +74 -0
- dask_array/manipulation/__init__.py +45 -0
- dask_array/manipulation/_expand.py +321 -0
- dask_array/manipulation/_flip.py +92 -0
- dask_array/manipulation/_roll.py +78 -0
- dask_array/manipulation/_transpose.py +309 -0
- dask_array/random/__init__.py +125 -0
- dask_array/random/_choice.py +181 -0
- dask_array/random/_expr.py +256 -0
- dask_array/random/_generator.py +441 -0
- dask_array/random/_random_state.py +259 -0
- dask_array/random/_utils.py +84 -0
- dask_array/reductions/__init__.py +84 -0
- dask_array/reductions/_arg_reduction.py +130 -0
- dask_array/reductions/_common.py +1082 -0
- dask_array/reductions/_cumulative.py +522 -0
- dask_array/reductions/_percentile.py +261 -0
- dask_array/reductions/_reduction.py +725 -0
- dask_array/reductions/_trace.py +56 -0
- dask_array/routines/__init__.py +133 -0
- dask_array/routines/_apply.py +84 -0
- dask_array/routines/_bincount.py +112 -0
- dask_array/routines/_broadcast.py +111 -0
- dask_array/routines/_coarsen.py +115 -0
- dask_array/routines/_diff.py +79 -0
- dask_array/routines/_gradient.py +158 -0
- dask_array/routines/_indexing.py +65 -0
- dask_array/routines/_insert_delete.py +132 -0
- dask_array/routines/_misc.py +122 -0
- dask_array/routines/_nonzero.py +72 -0
- dask_array/routines/_search.py +123 -0
- dask_array/routines/_select.py +113 -0
- dask_array/routines/_statistics.py +171 -0
- dask_array/routines/_topk.py +82 -0
- dask_array/routines/_triangular.py +74 -0
- dask_array/routines/_unique.py +232 -0
- dask_array/routines/_where.py +62 -0
- dask_array/slicing/__init__.py +67 -0
- dask_array/slicing/_basic.py +550 -0
- dask_array/slicing/_blocks.py +138 -0
- dask_array/slicing/_bool_index.py +145 -0
- dask_array/slicing/_setitem.py +329 -0
- dask_array/slicing/_squeeze.py +101 -0
- dask_array/slicing/_utils.py +1133 -0
- dask_array/slicing/_vindex.py +282 -0
- dask_array/stacking/__init__.py +15 -0
- dask_array/stacking/_block.py +83 -0
- dask_array/stacking/_simple.py +58 -0
- dask_array/templates/array.html.j2 +48 -0
- dask_array/tests/__init__.py +0 -0
- dask_array/tests/conftest.py +22 -0
- dask_array/tests/test_api.py +40 -0
- dask_array/tests/test_binary_op_chunks.py +107 -0
- dask_array/tests/test_coarse_slice_through_blockwise.py +362 -0
- dask_array/tests/test_collection.py +799 -0
- dask_array/tests/test_creation.py +1102 -0
- dask_array/tests/test_expr_flow.py +143 -0
- dask_array/tests/test_linalg.py +1130 -0
- dask_array/tests/test_map_blocks_multi_output.py +104 -0
- dask_array/tests/test_rechunk_pushdown.py +214 -0
- dask_array/tests/test_reductions.py +1091 -0
- dask_array/tests/test_routines.py +2853 -0
- dask_array/tests/test_shuffle_chunks.py +67 -0
- dask_array/tests/test_slice_pushdown.py +968 -0
- dask_array/tests/test_slice_through_blockwise.py +678 -0
- dask_array/tests/test_slice_through_overlap.py +366 -0
- dask_array/tests/test_slice_through_reshape.py +272 -0
- dask_array/tests/test_slicing.py +839 -0
- dask_array/tests/test_transpose_slice_pushdown.py +208 -0
- dask_array/tests/test_visualize.py +94 -0
- dask_array/tests/test_xarray.py +193 -0
- dask_array-0.1.0.dist-info/METADATA +48 -0
- dask_array-0.1.0.dist-info/RECORD +144 -0
- dask_array-0.1.0.dist-info/WHEEL +4 -0
- dask_array-0.1.0.dist-info/entry_points.txt +2 -0
- dask_array-0.1.0.dist-info/licenses/LICENSE +29 -0
dask_array/_dispatch.py
ADDED
|
@@ -0,0 +1,141 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Dispatch registries for dask_array.
|
|
3
|
+
|
|
4
|
+
This module provides Dispatch objects for array operations that need to be
|
|
5
|
+
dispatched based on array type (numpy, cupy, sparse, etc.).
|
|
6
|
+
|
|
7
|
+
concatenate_lookup and tensordot_lookup are defined in _core_utils.py but
|
|
8
|
+
re-exported here for convenience.
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
from __future__ import annotations
|
|
12
|
+
|
|
13
|
+
import numpy as np
|
|
14
|
+
|
|
15
|
+
from dask.utils import Dispatch
|
|
16
|
+
|
|
17
|
+
# Re-export from _core_utils for convenience
|
|
18
|
+
from dask_array._core_utils import concatenate_lookup, tensordot_lookup
|
|
19
|
+
|
|
20
|
+
# Dispatch registries for array operations
# Each name below is a ``dask.utils.Dispatch`` instance created with a label
# string. Implementations are attached further down in this module through
# ``<registry>.register(types, func)``, keyed by array type.
take_lookup = Dispatch("take")
einsum_lookup = Dispatch("einsum")
empty_lookup = Dispatch("empty")
divide_lookup = Dispatch("divide")
percentile_lookup = Dispatch("percentile")
numel_lookup = Dispatch("numel")
nannumel_lookup = Dispatch("nannumel")
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
# --- numpy implementations ---
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def _divide(x1, x2, out=None, dtype=None):
|
|
34
|
+
"""Implementation of numpy.divide that works with dtype kwarg."""
|
|
35
|
+
x = np.divide(x1, x2, out)
|
|
36
|
+
if dtype is not None:
|
|
37
|
+
x = x.astype(dtype)
|
|
38
|
+
return x
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def _percentile(a, q, method="linear"):
|
|
42
|
+
"""
|
|
43
|
+
Chunk-level percentile calculation.
|
|
44
|
+
|
|
45
|
+
Returns (percentile_values, n) tuple where n is the number of elements.
|
|
46
|
+
Used for combining percentiles from multiple chunks.
|
|
47
|
+
"""
|
|
48
|
+
from collections.abc import Iterator
|
|
49
|
+
|
|
50
|
+
n = len(a)
|
|
51
|
+
if not len(a):
|
|
52
|
+
return None, n
|
|
53
|
+
if isinstance(q, Iterator):
|
|
54
|
+
q = list(q)
|
|
55
|
+
if a.dtype.name == "category":
|
|
56
|
+
result = np.percentile(a.cat.codes, q, method=method)
|
|
57
|
+
import pandas as pd
|
|
58
|
+
|
|
59
|
+
return (
|
|
60
|
+
pd.Categorical.from_codes(result, a.dtype.categories, a.dtype.ordered),
|
|
61
|
+
n,
|
|
62
|
+
)
|
|
63
|
+
if type(a.dtype).__name__ == "DatetimeTZDtype":
|
|
64
|
+
import pandas as pd
|
|
65
|
+
|
|
66
|
+
if isinstance(a, (pd.Series, pd.Index)):
|
|
67
|
+
a = a.values
|
|
68
|
+
|
|
69
|
+
if np.issubdtype(a.dtype, np.datetime64):
|
|
70
|
+
values = a
|
|
71
|
+
if type(a).__name__ in ("Series", "Index"):
|
|
72
|
+
a2 = values.astype("i8")
|
|
73
|
+
else:
|
|
74
|
+
a2 = values.view("i8")
|
|
75
|
+
result = np.percentile(a2, q, method=method).astype(values.dtype)
|
|
76
|
+
if q[0] == 0:
|
|
77
|
+
# https://github.com/dask/dask/issues/6864
|
|
78
|
+
result[0] = min(result[0], values.min())
|
|
79
|
+
return result, n
|
|
80
|
+
if not np.issubdtype(a.dtype, np.number):
|
|
81
|
+
method = "nearest"
|
|
82
|
+
return np.percentile(a, q, method=method), n
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def _numel(x, **kwargs):
|
|
86
|
+
"""
|
|
87
|
+
A reduction to count the number of elements.
|
|
88
|
+
|
|
89
|
+
Returns ndarray result (coerces to numpy).
|
|
90
|
+
"""
|
|
91
|
+
import math
|
|
92
|
+
|
|
93
|
+
shape = x.shape
|
|
94
|
+
keepdims = kwargs.get("keepdims", False)
|
|
95
|
+
axis = kwargs.get("axis")
|
|
96
|
+
dtype = kwargs.get("dtype", np.float64)
|
|
97
|
+
|
|
98
|
+
if axis is None:
|
|
99
|
+
prod = np.prod(shape, dtype=dtype)
|
|
100
|
+
if keepdims is False:
|
|
101
|
+
return prod
|
|
102
|
+
|
|
103
|
+
return np.full(shape=(1,) * len(shape), fill_value=prod, dtype=dtype)
|
|
104
|
+
|
|
105
|
+
if not isinstance(axis, (tuple, list)):
|
|
106
|
+
axis = [axis]
|
|
107
|
+
|
|
108
|
+
prod = math.prod(shape[dim] for dim in axis)
|
|
109
|
+
if keepdims is True:
|
|
110
|
+
new_shape = tuple(shape[dim] if dim not in axis else 1 for dim in range(len(shape)))
|
|
111
|
+
else:
|
|
112
|
+
new_shape = tuple(shape[dim] for dim in range(len(shape)) if dim not in axis)
|
|
113
|
+
|
|
114
|
+
return np.broadcast_to(np.array(prod, dtype=dtype), new_shape)
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def _nannumel(x, **kwargs):
|
|
118
|
+
"""A reduction to count the number of elements, excluding nans"""
|
|
119
|
+
return np.sum(~(np.isnan(x)), **kwargs)
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
# --- Register numpy implementations ---

# Most registries attach the numpy implementation to both ``object`` and
# ``np.ndarray``.
# NOTE(review): ``object`` presumably acts as the catch-all for array types
# without a dedicated registration — confirm against dask.utils.Dispatch.
take_lookup.register((object, np.ndarray, np.ma.masked_array), np.take)
einsum_lookup.register((object, np.ndarray), np.einsum)
empty_lookup.register((object, np.ndarray), np.empty)
# Masked arrays get the ``np.ma`` variants of ``empty`` and ``divide``.
empty_lookup.register(np.ma.masked_array, np.ma.empty)
divide_lookup.register((object, np.ndarray), _divide)
divide_lookup.register(np.ma.masked_array, np.ma.divide)
# Unlike the others, percentile is registered for ``np.ndarray`` only.
percentile_lookup.register(np.ndarray, _percentile)
numel_lookup.register((object, np.ndarray), _numel)
nannumel_lookup.register((object, np.ndarray), _nannumel)
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
# --- Register masked array numel ---
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
@numel_lookup.register(np.ma.masked_array)
def _numel_masked(x, **kwargs):
    """Numel implementation for masked arrays.

    Sums an array of ones shaped like ``x``; keyword arguments go straight
    to ``np.sum``.
    """
    # NOTE(review): presumably ``ones_like`` carries the mask over so masked
    # entries are left out of the count — confirm.
    ones = np.ones_like(x)
    return np.sum(ones, **kwargs)
|
dask_array/_einsum.py
ADDED
|
@@ -0,0 +1,277 @@
|
|
|
1
|
+
"""Einstein summation for array-expr."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import math
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
|
|
9
|
+
from dask import config
|
|
10
|
+
from dask.utils import cached_max, derived_from
|
|
11
|
+
|
|
12
|
+
from dask_array._dispatch import einsum_lookup
|
|
13
|
+
|
|
14
|
+
# Valid characters for einsum subscripts (from numpy)
# Kept as an ordered string because list-style subscripts are translated by
# integer position into it.
einsum_symbols = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
# Set form of the same characters, used for membership tests and for
# working out which letters are still unused.
einsum_symbols_set = set(einsum_symbols)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def chunk_einsum(*operands, **kwargs):
    """Run einsum on one set of chunks.

    Called by blockwise for each block. Three control entries are removed
    from ``kwargs`` (``subscripts``, ``ncontract_inds``, ``kernel_dtype``);
    whatever remains is forwarded to the einsum implementation selected by
    the type of the first operand.
    """
    subscripts = kwargs.pop("subscripts")
    n_contracted = kwargs.pop("ncontract_inds")
    kernel_dtype = kwargs.pop("kernel_dtype")

    kernel = einsum_lookup.dispatch(type(operands[0]))
    block = kernel(subscripts, *operands, dtype=kernel_dtype, **kwargs)

    # Append one length-1 axis per contracted index so blockwise does not
    # need concatenate=True.
    padded_shape = block.shape + (1,) * n_contracted
    return block.reshape(padded_shape)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def _calculate_new_chunksizes(old_chunks, new_chunks, changeable_dimensions, target_size):
    """Delegate einsum rechunk sizing to the implementation in ``_shuffle``.

    The import is done at call time rather than at module import.
    """
    from dask_array._shuffle import _calculate_new_chunksizes as shuffle_chunksizes

    resized = shuffle_chunksizes(old_chunks, new_chunks, changeable_dimensions, target_size)
    return resized
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def _implicit_output_subscript(joined_inputs):
    """Return, in sorted order, the indices occurring exactly once in ``joined_inputs``."""
    output = ""
    for s in sorted(set(joined_inputs)):
        if s not in einsum_symbols_set:
            raise ValueError(f"Character {s} is not a valid symbol.")
        if joined_inputs.count(s) == 1:
            output += s
    return output


def _sublist_to_subscript(sublist):
    """Translate one list-style subscript (ints and Ellipsis) into its string form."""
    pieces = []
    for s in sublist:
        if s is Ellipsis:
            pieces.append("...")
        elif isinstance(s, int):
            pieces.append(einsum_symbols[s])
        else:
            raise TypeError("For this input type lists must contain either int or Ellipsis")
    return "".join(pieces)


def _parse_einsum_input(operands, asarray):
    """Parse einsum input, adapted from numpy/dask.

    This is a copy of parse_einsum_input from einsumfuncs.py but uses
    the provided asarray function to ensure correct array type.

    Parameters
    ----------
    operands : sequence
        Either ``(subscripts_string, op0, op1, ...)`` or the interleaved
        form ``(op0, sublist0, op1, sublist1, ..., [output_sublist])``.
    asarray : callable
        Applied to every array operand.

    Returns
    -------
    tuple
        ``(input_subscripts, output_subscript, operands)`` where ellipses
        have been expanded to explicit letters and ``operands`` is the list
        of converted arrays.

    Raises
    ------
    ValueError
        For missing operands, invalid characters, malformed ``->`` or
        ellipses, output indices absent from the input, or a mismatch
        between the number of subscript terms and operands.
    TypeError
        When a list-style subscript holds something other than an int or
        Ellipsis.
    """
    if len(operands) == 0:
        raise ValueError("No input operands")

    if isinstance(operands[0], str):
        subscripts = operands[0].replace(" ", "")
        operands = [asarray(o) for o in operands[1:]]

        # Ensure all characters are valid
        for s in subscripts:
            if s in ".,->":
                continue
            if s not in einsum_symbols_set:
                raise ValueError(f"Character {s} is not a valid symbol.")

    else:
        # Interleaved form: operand/sublist pairs, optionally followed by a
        # single trailing output sublist.
        n_pairs = len(operands) // 2
        operand_list = operands[0 : 2 * n_pairs : 2]
        subscript_list = operands[1 : 2 * n_pairs : 2]
        output_list = operands[-1] if len(operands) % 2 else None

        operands = [asarray(v) for v in operand_list]
        subscripts = ",".join(_sublist_to_subscript(sub) for sub in subscript_list)
        if output_list is not None:
            subscripts += "->" + _sublist_to_subscript(output_list)

    # Check for proper "->"
    if ("-" in subscripts) or (">" in subscripts):
        invalid = (subscripts.count("-") > 1) or (subscripts.count(">") > 1)
        if invalid or (subscripts.count("->") != 1):
            raise ValueError("Subscripts can only contain one '->'.")

    # Parse ellipses
    if "." in subscripts:
        used = set(subscripts) - set(".,->")
        # Walk the ordered symbol string instead of iterating a set, so the
        # letters standing in for "..." are the same on every interpreter
        # run regardless of string hash randomisation.
        ellipse_inds = "".join(s for s in einsum_symbols if s not in used)
        longest = 0

        if "->" in subscripts:
            input_tmp, output_sub = subscripts.split("->")
            split_subscripts = input_tmp.split(",")
            out_sub = True
        else:
            split_subscripts = subscripts.split(",")
            out_sub = False

        for num, sub in enumerate(split_subscripts):
            if "." in sub:
                if (sub.count(".") != 3) or (sub.count("...") != 1):
                    raise ValueError("Invalid Ellipses.")

                if num >= len(operands):
                    # More subscript terms than operands: report it with the
                    # same error as the final check rather than an IndexError.
                    raise ValueError("Number of einsum subscripts must be equal to the number of operands.")

                # Take into account numerical values
                if operands[num].shape == ():
                    ellipse_count = 0
                else:
                    ellipse_count = max(operands[num].ndim, 1)
                    ellipse_count -= len(sub) - 3

                if ellipse_count > longest:
                    longest = ellipse_count

                if ellipse_count < 0:
                    raise ValueError("Ellipses lengths do not match.")
                elif ellipse_count == 0:
                    split_subscripts[num] = sub.replace("...", "")
                else:
                    # Taking the tail aligns shorter ellipses with the
                    # right-most dimensions of longer ones.
                    rep_inds = ellipse_inds[-ellipse_count:]
                    split_subscripts[num] = sub.replace("...", rep_inds)

        subscripts = ",".join(split_subscripts)
        if longest == 0:
            out_ellipse = ""
        else:
            out_ellipse = ellipse_inds[-longest:]

        if out_sub:
            subscripts += "->" + output_sub.replace("...", out_ellipse)
        else:
            # Special care for outputless ellipses
            output_subscript = _implicit_output_subscript(subscripts.replace(",", ""))
            normal_inds = "".join(sorted(set(output_subscript) - set(out_ellipse)))

            subscripts += f"->{out_ellipse}{normal_inds}"

    # Build output string if does not exist
    if "->" in subscripts:
        input_subscripts, output_subscript = subscripts.split("->")
    else:
        input_subscripts = subscripts
        output_subscript = _implicit_output_subscript(subscripts.replace(",", ""))

    # Make sure output subscripts are in the input
    for char in output_subscript:
        if char not in input_subscripts:
            raise ValueError(f"Output character {char} did not appear in the input")

    # Make sure number operands is equivalent to the number of terms
    if len(input_subscripts.split(",")) != len(operands):
        raise ValueError("Number of einsum subscripts must be equal to the number of operands.")

    return (input_subscripts, output_subscript, operands)
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
@derived_from(np)
def einsum(*operands, dtype=None, optimize=False, split_every=None, **kwargs):
    """Dask added an additional keyword-only argument ``split_every``.

    split_every: int >= 2 or dict(axis: int), optional
        Determines the depth of the recursive aggregation.
        Defaults to ``None`` which would let dask heuristically
        decide a good default.
    """
    from dask_array._collection import asarray, blockwise

    # The chunk kernel receives the caller's dtype untouched (possibly None);
    # blockwise gets an inferred dtype further down.
    kernel_dtype = dtype

    # Parse operands, converting them with the array-expr asarray
    input_spec, outputs, ops = _parse_einsum_input(operands, asarray)
    subscripts = f"{input_spec}->{outputs}"

    # Infer the output dtype from the operands when none was requested
    if dtype is None:
        dtype = np.result_type(*(o.dtype for o in ops))

    if optimize is not False:
        # np.einsum_path only needs shapes and dtypes, so hand it small
        # broadcast numpy stand-ins instead of the real operands.
        placeholders = [np.broadcast_to(o.dtype.type(0), shape=o.shape) for o in ops]
        optimize, _ = np.einsum_path(subscripts, *placeholders, optimize=optimize)

    inputs = [tuple(term) for term in input_spec.split(",")]

    # Every index that appears in any input term
    all_inds = {ind for term in inputs for ind in term}

    # Indices missing from the output are contracted
    contract_inds = all_inds - set(outputs)
    ncontract_inds = len(contract_inds)

    if len(inputs) > 1 and len(outputs) > 0:
        # Compare the largest possible output chunk with the largest input chunk
        largest_input_chunk = 1
        largest_per_index = {}
        for op, term in zip(ops, inputs):
            largest_input_chunk = max(math.prod(map(cached_max, op.chunks)), largest_input_chunk)
            candidates = {
                ind: max(cached_max(op.chunks[pos]), largest_per_index.get(ind, 1))
                for pos, ind in enumerate(term)
                if ind not in contract_inds
            }
            largest_per_index.update(candidates)

        largest_output_chunk = math.prod(largest_per_index.values())
        tolerance = config.get("array.chunk-size-tolerance")
        factor = largest_output_chunk / (largest_input_chunk * tolerance)

        # Shrink the input chunks so the output chunks do not grow
        rechunked = []
        for op, term in zip(ops, inputs):
            changeable_dimensions = {pos for pos, ind in enumerate(term) if ind in outputs}
            shrink = max(factor ** (len(changeable_dimensions) / len(outputs)), 1)
            target_size = math.prod(map(cached_max, op.chunks)) / shrink
            new_chunks = _calculate_new_chunksizes(
                op.chunks,
                list(op.chunks),
                changeable_dimensions,
                target_size,
            )
            rechunked.append(op.rechunk(new_chunks))
        ops = rechunked

    # blockwise expects its array arguments as array, index, array, index, ...
    interleaved = []
    for op, term in zip(ops, inputs):
        interleaved += [op, term]

    # The contracted indices are kept in the blockwise output (with chunk
    # size 1) so each block is a numpy array rather than a list.
    product = blockwise(
        chunk_einsum,
        tuple(outputs) + tuple(contract_inds),
        *interleaved,
        # blockwise parameters
        adjust_chunks=dict.fromkeys(contract_inds, 1),
        dtype=dtype,
        # np.einsum parameters
        subscripts=subscripts,
        kernel_dtype=kernel_dtype,
        ncontract_inds=ncontract_inds,
        optimize=optimize,
        **kwargs,
    )

    if ncontract_inds == 0:
        return product

    # Sum away the trailing axes that stand for the contracted indices
    first_extra = len(outputs)
    extra_axes = list(range(first_extra, first_extra + ncontract_inds))
    return product.sum(axis=extra_axes, split_every=split_every)
|