dask-array 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (144):
  1. dask_array/__init__.py +228 -0
  2. dask_array/_backends.py +76 -0
  3. dask_array/_backends_array.py +99 -0
  4. dask_array/_blockwise.py +1410 -0
  5. dask_array/_broadcast.py +272 -0
  6. dask_array/_chunk.py +445 -0
  7. dask_array/_chunk_types.py +54 -0
  8. dask_array/_collection.py +1644 -0
  9. dask_array/_concatenate.py +331 -0
  10. dask_array/_core_utils.py +1365 -0
  11. dask_array/_dispatch.py +141 -0
  12. dask_array/_einsum.py +277 -0
  13. dask_array/_expr.py +544 -0
  14. dask_array/_expr_flow.py +586 -0
  15. dask_array/_gufunc.py +805 -0
  16. dask_array/_histogram.py +617 -0
  17. dask_array/_map_blocks.py +652 -0
  18. dask_array/_new_collection.py +10 -0
  19. dask_array/_numpy_compat.py +135 -0
  20. dask_array/_overlap.py +1159 -0
  21. dask_array/_rechunk.py +1050 -0
  22. dask_array/_reshape.py +710 -0
  23. dask_array/_routines.py +102 -0
  24. dask_array/_shuffle.py +448 -0
  25. dask_array/_stack.py +264 -0
  26. dask_array/_svg.py +291 -0
  27. dask_array/_templates.py +29 -0
  28. dask_array/_test_utils.py +257 -0
  29. dask_array/_ufunc.py +385 -0
  30. dask_array/_utils.py +349 -0
  31. dask_array/_visualize.py +223 -0
  32. dask_array/_xarray.py +337 -0
  33. dask_array/core/__init__.py +34 -0
  34. dask_array/core/_blockwise_funcs.py +312 -0
  35. dask_array/core/_conversion.py +422 -0
  36. dask_array/core/_from_graph.py +97 -0
  37. dask_array/creation/__init__.py +71 -0
  38. dask_array/creation/_arange.py +121 -0
  39. dask_array/creation/_diag.py +116 -0
  40. dask_array/creation/_diagonal.py +241 -0
  41. dask_array/creation/_eye.py +103 -0
  42. dask_array/creation/_linspace.py +102 -0
  43. dask_array/creation/_mesh.py +134 -0
  44. dask_array/creation/_ones_zeros.py +454 -0
  45. dask_array/creation/_pad.py +270 -0
  46. dask_array/creation/_repeat.py +55 -0
  47. dask_array/creation/_tile.py +36 -0
  48. dask_array/creation/_tri.py +28 -0
  49. dask_array/creation/_utils.py +296 -0
  50. dask_array/fft.py +320 -0
  51. dask_array/io/__init__.py +39 -0
  52. dask_array/io/_base.py +10 -0
  53. dask_array/io/_from_array.py +257 -0
  54. dask_array/io/_from_delayed.py +95 -0
  55. dask_array/io/_from_graph.py +54 -0
  56. dask_array/io/_from_npy_stack.py +67 -0
  57. dask_array/io/_store.py +336 -0
  58. dask_array/io/_tiledb.py +159 -0
  59. dask_array/io/_to_npy_stack.py +65 -0
  60. dask_array/io/_zarr.py +449 -0
  61. dask_array/linalg/__init__.py +39 -0
  62. dask_array/linalg/_cholesky.py +234 -0
  63. dask_array/linalg/_lu.py +300 -0
  64. dask_array/linalg/_norm.py +94 -0
  65. dask_array/linalg/_qr.py +601 -0
  66. dask_array/linalg/_solve.py +349 -0
  67. dask_array/linalg/_svd.py +394 -0
  68. dask_array/linalg/_tensordot.py +334 -0
  69. dask_array/linalg/_utils.py +74 -0
  70. dask_array/manipulation/__init__.py +45 -0
  71. dask_array/manipulation/_expand.py +321 -0
  72. dask_array/manipulation/_flip.py +92 -0
  73. dask_array/manipulation/_roll.py +78 -0
  74. dask_array/manipulation/_transpose.py +309 -0
  75. dask_array/random/__init__.py +125 -0
  76. dask_array/random/_choice.py +181 -0
  77. dask_array/random/_expr.py +256 -0
  78. dask_array/random/_generator.py +441 -0
  79. dask_array/random/_random_state.py +259 -0
  80. dask_array/random/_utils.py +84 -0
  81. dask_array/reductions/__init__.py +84 -0
  82. dask_array/reductions/_arg_reduction.py +130 -0
  83. dask_array/reductions/_common.py +1082 -0
  84. dask_array/reductions/_cumulative.py +522 -0
  85. dask_array/reductions/_percentile.py +261 -0
  86. dask_array/reductions/_reduction.py +725 -0
  87. dask_array/reductions/_trace.py +56 -0
  88. dask_array/routines/__init__.py +133 -0
  89. dask_array/routines/_apply.py +84 -0
  90. dask_array/routines/_bincount.py +112 -0
  91. dask_array/routines/_broadcast.py +111 -0
  92. dask_array/routines/_coarsen.py +115 -0
  93. dask_array/routines/_diff.py +79 -0
  94. dask_array/routines/_gradient.py +158 -0
  95. dask_array/routines/_indexing.py +65 -0
  96. dask_array/routines/_insert_delete.py +132 -0
  97. dask_array/routines/_misc.py +122 -0
  98. dask_array/routines/_nonzero.py +72 -0
  99. dask_array/routines/_search.py +123 -0
  100. dask_array/routines/_select.py +113 -0
  101. dask_array/routines/_statistics.py +171 -0
  102. dask_array/routines/_topk.py +82 -0
  103. dask_array/routines/_triangular.py +74 -0
  104. dask_array/routines/_unique.py +232 -0
  105. dask_array/routines/_where.py +62 -0
  106. dask_array/slicing/__init__.py +67 -0
  107. dask_array/slicing/_basic.py +550 -0
  108. dask_array/slicing/_blocks.py +138 -0
  109. dask_array/slicing/_bool_index.py +145 -0
  110. dask_array/slicing/_setitem.py +329 -0
  111. dask_array/slicing/_squeeze.py +101 -0
  112. dask_array/slicing/_utils.py +1133 -0
  113. dask_array/slicing/_vindex.py +282 -0
  114. dask_array/stacking/__init__.py +15 -0
  115. dask_array/stacking/_block.py +83 -0
  116. dask_array/stacking/_simple.py +58 -0
  117. dask_array/templates/array.html.j2 +48 -0
  118. dask_array/tests/__init__.py +0 -0
  119. dask_array/tests/conftest.py +22 -0
  120. dask_array/tests/test_api.py +40 -0
  121. dask_array/tests/test_binary_op_chunks.py +107 -0
  122. dask_array/tests/test_coarse_slice_through_blockwise.py +362 -0
  123. dask_array/tests/test_collection.py +799 -0
  124. dask_array/tests/test_creation.py +1102 -0
  125. dask_array/tests/test_expr_flow.py +143 -0
  126. dask_array/tests/test_linalg.py +1130 -0
  127. dask_array/tests/test_map_blocks_multi_output.py +104 -0
  128. dask_array/tests/test_rechunk_pushdown.py +214 -0
  129. dask_array/tests/test_reductions.py +1091 -0
  130. dask_array/tests/test_routines.py +2853 -0
  131. dask_array/tests/test_shuffle_chunks.py +67 -0
  132. dask_array/tests/test_slice_pushdown.py +968 -0
  133. dask_array/tests/test_slice_through_blockwise.py +678 -0
  134. dask_array/tests/test_slice_through_overlap.py +366 -0
  135. dask_array/tests/test_slice_through_reshape.py +272 -0
  136. dask_array/tests/test_slicing.py +839 -0
  137. dask_array/tests/test_transpose_slice_pushdown.py +208 -0
  138. dask_array/tests/test_visualize.py +94 -0
  139. dask_array/tests/test_xarray.py +193 -0
  140. dask_array-0.1.0.dist-info/METADATA +48 -0
  141. dask_array-0.1.0.dist-info/RECORD +144 -0
  142. dask_array-0.1.0.dist-info/WHEEL +4 -0
  143. dask_array-0.1.0.dist-info/entry_points.txt +2 -0
  144. dask_array-0.1.0.dist-info/licenses/LICENSE +29 -0
dask_array/_dispatch.py ADDED
@@ -0,0 +1,141 @@
1
+ """
2
+ Dispatch registries for dask_array.
3
+
4
+ This module provides Dispatch objects for array operations that need to be
5
+ dispatched based on array type (numpy, cupy, sparse, etc.).
6
+
7
+ concatenate_lookup and tensordot_lookup are defined in _core_utils.py but
8
+ re-exported here for convenience.
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ import numpy as np
14
+
15
+ from dask.utils import Dispatch
16
+
17
+ # Re-export from _core_utils for convenience
18
+ from dask_array._core_utils import concatenate_lookup, tensordot_lookup
19
+
20
# Dispatch registries for array operations. Each registry maps a chunk type
# to the implementation used for that type; the numpy implementations are
# registered further down in this module.
(
    take_lookup,
    einsum_lookup,
    empty_lookup,
    divide_lookup,
    percentile_lookup,
    numel_lookup,
    nannumel_lookup,
) = (
    Dispatch(_op_name)
    for _op_name in ("take", "einsum", "empty", "divide", "percentile", "numel", "nannumel")
)
28
+
29
+
30
+ # --- numpy implementations ---
31
+
32
+
33
+ def _divide(x1, x2, out=None, dtype=None):
34
+ """Implementation of numpy.divide that works with dtype kwarg."""
35
+ x = np.divide(x1, x2, out)
36
+ if dtype is not None:
37
+ x = x.astype(dtype)
38
+ return x
39
+
40
+
41
+ def _percentile(a, q, method="linear"):
42
+ """
43
+ Chunk-level percentile calculation.
44
+
45
+ Returns (percentile_values, n) tuple where n is the number of elements.
46
+ Used for combining percentiles from multiple chunks.
47
+ """
48
+ from collections.abc import Iterator
49
+
50
+ n = len(a)
51
+ if not len(a):
52
+ return None, n
53
+ if isinstance(q, Iterator):
54
+ q = list(q)
55
+ if a.dtype.name == "category":
56
+ result = np.percentile(a.cat.codes, q, method=method)
57
+ import pandas as pd
58
+
59
+ return (
60
+ pd.Categorical.from_codes(result, a.dtype.categories, a.dtype.ordered),
61
+ n,
62
+ )
63
+ if type(a.dtype).__name__ == "DatetimeTZDtype":
64
+ import pandas as pd
65
+
66
+ if isinstance(a, (pd.Series, pd.Index)):
67
+ a = a.values
68
+
69
+ if np.issubdtype(a.dtype, np.datetime64):
70
+ values = a
71
+ if type(a).__name__ in ("Series", "Index"):
72
+ a2 = values.astype("i8")
73
+ else:
74
+ a2 = values.view("i8")
75
+ result = np.percentile(a2, q, method=method).astype(values.dtype)
76
+ if q[0] == 0:
77
+ # https://github.com/dask/dask/issues/6864
78
+ result[0] = min(result[0], values.min())
79
+ return result, n
80
+ if not np.issubdtype(a.dtype, np.number):
81
+ method = "nearest"
82
+ return np.percentile(a, q, method=method), n
83
+
84
+
85
+ def _numel(x, **kwargs):
86
+ """
87
+ A reduction to count the number of elements.
88
+
89
+ Returns ndarray result (coerces to numpy).
90
+ """
91
+ import math
92
+
93
+ shape = x.shape
94
+ keepdims = kwargs.get("keepdims", False)
95
+ axis = kwargs.get("axis")
96
+ dtype = kwargs.get("dtype", np.float64)
97
+
98
+ if axis is None:
99
+ prod = np.prod(shape, dtype=dtype)
100
+ if keepdims is False:
101
+ return prod
102
+
103
+ return np.full(shape=(1,) * len(shape), fill_value=prod, dtype=dtype)
104
+
105
+ if not isinstance(axis, (tuple, list)):
106
+ axis = [axis]
107
+
108
+ prod = math.prod(shape[dim] for dim in axis)
109
+ if keepdims is True:
110
+ new_shape = tuple(shape[dim] if dim not in axis else 1 for dim in range(len(shape)))
111
+ else:
112
+ new_shape = tuple(shape[dim] for dim in range(len(shape)) if dim not in axis)
113
+
114
+ return np.broadcast_to(np.array(prod, dtype=dtype), new_shape)
115
+
116
+
117
+ def _nannumel(x, **kwargs):
118
+ """A reduction to count the number of elements, excluding nans"""
119
+ return np.sum(~(np.isnan(x)), **kwargs)
120
+
121
+
122
# --- Register numpy implementations ---
# Table of (registry, chunk type(s), implementation); entries are registered
# in exactly this order. The loop names are deleted afterwards so they do not
# linger in the module namespace.
for _registry, _chunk_types, _impl in (
    (take_lookup, (object, np.ndarray, np.ma.masked_array), np.take),
    (einsum_lookup, (object, np.ndarray), np.einsum),
    (empty_lookup, (object, np.ndarray), np.empty),
    (empty_lookup, np.ma.masked_array, np.ma.empty),
    (divide_lookup, (object, np.ndarray), _divide),
    (divide_lookup, np.ma.masked_array, np.ma.divide),
    (percentile_lookup, np.ndarray, _percentile),
    (numel_lookup, (object, np.ndarray), _numel),
    (nannumel_lookup, (object, np.ndarray), _nannumel),
):
    _registry.register(_chunk_types, _impl)
del _registry, _chunk_types, _impl
133
+
134
+
135
# --- Register masked array numel ---


@numel_lookup.register(np.ma.masked_array)
def _numel_masked(x, **kwargs):
    """Numel implementation for masked arrays.

    Sums an array of ones shaped like ``x``; reduction keywords are forwarded
    to ``np.sum``. NOTE(review): this presumably counts only unmasked entries
    because ``np.ones_like`` keeps the masked-array subclass/mask — confirm.
    """
    ones = np.ones_like(x)
    return np.sum(ones, **kwargs)
dask_array/_einsum.py ADDED
@@ -0,0 +1,277 @@
1
+ """Einstein summation for array-expr."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import math
6
+
7
+ import numpy as np
8
+
9
+ from dask import config
10
+ from dask.utils import cached_max, derived_from
11
+
12
+ from dask_array._dispatch import einsum_lookup
13
+
14
# Labels accepted in einsum subscripts (same alphabet as numpy): the ASCII
# lowercase letters followed by the uppercase ones. The set form is used for
# fast membership checks.
einsum_symbols = "abcdefghijklmnopqrstuvwxyz" + "abcdefghijklmnopqrstuvwxyz".upper()
einsum_symbols_set = {*einsum_symbols}
17
+
18
+
19
+ def chunk_einsum(*operands, **kwargs):
20
+ """Chunk-level einsum computation.
21
+
22
+ This function is used by blockwise to compute einsum on individual chunks.
23
+ It dispatches to the appropriate einsum implementation based on array type.
24
+ """
25
+ subscripts = kwargs.pop("subscripts")
26
+ ncontract_inds = kwargs.pop("ncontract_inds")
27
+ dtype = kwargs.pop("kernel_dtype")
28
+ einsum = einsum_lookup.dispatch(type(operands[0]))
29
+ chunk = einsum(subscripts, *operands, dtype=dtype, **kwargs)
30
+
31
+ # Avoid concatenate=True in blockwise by adding 1's
32
+ # for the contracted dimensions
33
+ return chunk.reshape(chunk.shape + (1,) * ncontract_inds)
34
+
35
+
36
def _calculate_new_chunksizes(old_chunks, new_chunks, changeable_dimensions, target_size):
    """Calculate new chunk sizes for einsum rechunking.

    Thin wrapper around the helper of the same name in ``dask_array._shuffle``.
    The module is imported at call time (presumably to avoid an import cycle —
    confirm), and every argument is forwarded positionally and unchanged.
    """
    from dask_array import _shuffle

    return _shuffle._calculate_new_chunksizes(old_chunks, new_chunks, changeable_dimensions, target_size)
41
+
42
+
43
+ def _parse_einsum_input(operands, asarray):
44
+ """Parse einsum input, adapted from numpy/dask.
45
+
46
+ This is a copy of parse_einsum_input from einsumfuncs.py but uses
47
+ the provided asarray function to ensure correct array type.
48
+ """
49
+ if len(operands) == 0:
50
+ raise ValueError("No input operands")
51
+
52
+ if isinstance(operands[0], str):
53
+ subscripts = operands[0].replace(" ", "")
54
+ operands = [asarray(o) for o in operands[1:]]
55
+
56
+ # Ensure all characters are valid
57
+ for s in subscripts:
58
+ if s in ".,->":
59
+ continue
60
+ if s not in einsum_symbols_set:
61
+ raise ValueError(f"Character {s} is not a valid symbol.")
62
+
63
+ else:
64
+ tmp_operands = list(operands)
65
+ operand_list = []
66
+ subscript_list = []
67
+ for _ in range(len(operands) // 2):
68
+ operand_list.append(tmp_operands.pop(0))
69
+ subscript_list.append(tmp_operands.pop(0))
70
+
71
+ output_list = tmp_operands[-1] if len(tmp_operands) else None
72
+ operands = [asarray(v) for v in operand_list]
73
+ subscripts = ""
74
+ last = len(subscript_list) - 1
75
+ for num, sub in enumerate(subscript_list):
76
+ for s in sub:
77
+ if s is Ellipsis:
78
+ subscripts += "..."
79
+ elif isinstance(s, int):
80
+ subscripts += einsum_symbols[s]
81
+ else:
82
+ raise TypeError("For this input type lists must contain either int or Ellipsis")
83
+ if num != last:
84
+ subscripts += ","
85
+
86
+ if output_list is not None:
87
+ subscripts += "->"
88
+ for s in output_list:
89
+ if s is Ellipsis:
90
+ subscripts += "..."
91
+ elif isinstance(s, int):
92
+ subscripts += einsum_symbols[s]
93
+ else:
94
+ raise TypeError("For this input type lists must contain either int or Ellipsis")
95
+ # Check for proper "->"
96
+ if ("-" in subscripts) or (">" in subscripts):
97
+ invalid = (subscripts.count("-") > 1) or (subscripts.count(">") > 1)
98
+ if invalid or (subscripts.count("->") != 1):
99
+ raise ValueError("Subscripts can only contain one '->'.")
100
+
101
+ # Parse ellipses
102
+ if "." in subscripts:
103
+ used = subscripts.replace(".", "").replace(",", "").replace("->", "")
104
+ unused = list(einsum_symbols_set - set(used))
105
+ ellipse_inds = "".join(unused)
106
+ longest = 0
107
+
108
+ if "->" in subscripts:
109
+ input_tmp, output_sub = subscripts.split("->")
110
+ split_subscripts = input_tmp.split(",")
111
+ out_sub = True
112
+ else:
113
+ split_subscripts = subscripts.split(",")
114
+ out_sub = False
115
+
116
+ for num, sub in enumerate(split_subscripts):
117
+ if "." in sub:
118
+ if (sub.count(".") != 3) or (sub.count("...") != 1):
119
+ raise ValueError("Invalid Ellipses.")
120
+
121
+ # Take into account numerical values
122
+ if operands[num].shape == ():
123
+ ellipse_count = 0
124
+ else:
125
+ ellipse_count = max(operands[num].ndim, 1)
126
+ ellipse_count -= len(sub) - 3
127
+
128
+ if ellipse_count > longest:
129
+ longest = ellipse_count
130
+
131
+ if ellipse_count < 0:
132
+ raise ValueError("Ellipses lengths do not match.")
133
+ elif ellipse_count == 0:
134
+ split_subscripts[num] = sub.replace("...", "")
135
+ else:
136
+ rep_inds = ellipse_inds[-ellipse_count:]
137
+ split_subscripts[num] = sub.replace("...", rep_inds)
138
+
139
+ subscripts = ",".join(split_subscripts)
140
+ if longest == 0:
141
+ out_ellipse = ""
142
+ else:
143
+ out_ellipse = ellipse_inds[-longest:]
144
+
145
+ if out_sub:
146
+ subscripts += "->" + output_sub.replace("...", out_ellipse)
147
+ else:
148
+ # Special care for outputless ellipses
149
+ output_subscript = ""
150
+ tmp_subscripts = subscripts.replace(",", "")
151
+ for s in sorted(set(tmp_subscripts)):
152
+ if s not in einsum_symbols_set:
153
+ raise ValueError(f"Character {s} is not a valid symbol.")
154
+ if tmp_subscripts.count(s) == 1:
155
+ output_subscript += s
156
+ normal_inds = "".join(sorted(set(output_subscript) - set(out_ellipse)))
157
+
158
+ subscripts += f"->{out_ellipse}{normal_inds}"
159
+
160
+ # Build output string if does not exist
161
+ if "->" in subscripts:
162
+ input_subscripts, output_subscript = subscripts.split("->")
163
+ else:
164
+ input_subscripts = subscripts
165
+ # Build output subscripts
166
+ tmp_subscripts = subscripts.replace(",", "")
167
+ output_subscript = ""
168
+ for s in sorted(set(tmp_subscripts)):
169
+ if s not in einsum_symbols_set:
170
+ raise ValueError(f"Character {s} is not a valid symbol.")
171
+ if tmp_subscripts.count(s) == 1:
172
+ output_subscript += s
173
+
174
+ # Make sure output subscripts are in the input
175
+ for char in output_subscript:
176
+ if char not in input_subscripts:
177
+ raise ValueError(f"Output character {char} did not appear in the input")
178
+
179
+ # Make sure number operands is equivalent to the number of terms
180
+ if len(input_subscripts.split(",")) != len(operands):
181
+ raise ValueError("Number of einsum subscripts must be equal to the number of operands.")
182
+
183
+ return (input_subscripts, output_subscript, operands)
184
+
185
+
186
@derived_from(np)
def einsum(*operands, dtype=None, optimize=False, split_every=None, **kwargs):
    """Dask added an additional keyword-only argument ``split_every``.

    split_every: int >= 2 or dict(axis: int), optional
        Determines the depth of the recursive aggregation.
        Defaults to ``None`` which would let dask heuristically
        decide a good default.
    """
    from dask_array._collection import asarray, blockwise

    # The caller's dtype (possibly None) is what each chunk kernel receives;
    # ``dtype`` itself is overwritten below with the inferred graph dtype.
    einsum_dtype = dtype

    # Parse operands, converting to dask arrays using array-expr asarray
    inputs, outputs, ops = _parse_einsum_input(operands, asarray)

    subscripts = "->".join((inputs, outputs))

    # Infer the output dtype from operands
    if dtype is None:
        dtype = np.result_type(*[o.dtype for o in ops])

    if optimize is not False:
        # Avoid computation of dask arrays within np.einsum_path
        # by passing in small numpy arrays broadcasted
        # up to the right shape
        fake_ops = [np.broadcast_to(o.dtype.type(0), shape=o.shape) for o in ops]
        # The resulting contraction path is forwarded to every chunk kernel.
        optimize, _ = np.einsum_path(subscripts, *fake_ops, optimize=optimize)

    inputs = [tuple(i) for i in inputs.split(",")]

    # Set of all indices
    all_inds = {a for i in inputs for a in i}

    # Which indices are contracted? Keep a set for membership tests and a
    # sorted tuple wherever order matters: iterating a set of str depends on
    # hash randomization, which would make the blockwise output index order
    # (and therefore the graph token) differ between interpreter sessions.
    contract_set = all_inds - set(outputs)
    contract_inds = tuple(sorted(contract_set))
    ncontract_inds = len(contract_inds)

    if len(inputs) > 1 and len(outputs) > 0:
        # Calculate the increase in chunk size compared to the largest input chunk
        max_chunk_sizes, max_chunk_size_input = {}, 1
        for op, op_inds in zip(ops, inputs):
            max_chunk_size_input = max(math.prod(map(cached_max, op.chunks)), max_chunk_size_input)
            max_chunk_sizes.update(
                {
                    ind: max(cached_max(op.chunks[i]), max_chunk_sizes.get(ind, 1))
                    for i, ind in enumerate(op_inds)
                    if ind not in contract_set
                }
            )

        max_chunk_size_output = math.prod(max_chunk_sizes.values())
        factor = max_chunk_size_output / (max_chunk_size_input * config.get("array.chunk-size-tolerance"))

        # Rechunk inputs to make input chunks smaller to avoid an increase in
        # output chunks
        new_ops = []
        for op, op_inds in zip(ops, inputs):
            changeable_dimensions = {ctr for ctr, i in enumerate(op_inds) if i in outputs}
            f = max(factor ** (len(changeable_dimensions) / len(outputs)), 1)
            new_chunks = _calculate_new_chunksizes(
                op.chunks,
                list(op.chunks),
                changeable_dimensions,
                math.prod(map(cached_max, op.chunks)) / f,
            )
            new_ops.append(op.rechunk(new_chunks))
        ops = new_ops

    # Introduce the contracted indices into the blockwise product
    # so that we get numpy arrays, not lists
    result = blockwise(
        chunk_einsum,
        tuple(outputs) + contract_inds,
        *(a for ap in zip(ops, inputs) for a in ap),
        # blockwise parameters
        adjust_chunks=dict.fromkeys(contract_inds, 1),
        dtype=dtype,
        # np.einsum parameters
        subscripts=subscripts,
        kernel_dtype=einsum_dtype,
        ncontract_inds=ncontract_inds,
        optimize=optimize,
        **kwargs,
    )

    # Now reduce over any extra contraction dimensions (the trailing
    # length-1 axes added per block by chunk_einsum)
    if ncontract_inds > 0:
        size = len(outputs)
        return result.sum(axis=list(range(size, size + ncontract_inds)), split_every=split_every)

    return result