dask-array 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dask_array/__init__.py +228 -0
- dask_array/_backends.py +76 -0
- dask_array/_backends_array.py +99 -0
- dask_array/_blockwise.py +1410 -0
- dask_array/_broadcast.py +272 -0
- dask_array/_chunk.py +445 -0
- dask_array/_chunk_types.py +54 -0
- dask_array/_collection.py +1644 -0
- dask_array/_concatenate.py +331 -0
- dask_array/_core_utils.py +1365 -0
- dask_array/_dispatch.py +141 -0
- dask_array/_einsum.py +277 -0
- dask_array/_expr.py +544 -0
- dask_array/_expr_flow.py +586 -0
- dask_array/_gufunc.py +805 -0
- dask_array/_histogram.py +617 -0
- dask_array/_map_blocks.py +652 -0
- dask_array/_new_collection.py +10 -0
- dask_array/_numpy_compat.py +135 -0
- dask_array/_overlap.py +1159 -0
- dask_array/_rechunk.py +1050 -0
- dask_array/_reshape.py +710 -0
- dask_array/_routines.py +102 -0
- dask_array/_shuffle.py +448 -0
- dask_array/_stack.py +264 -0
- dask_array/_svg.py +291 -0
- dask_array/_templates.py +29 -0
- dask_array/_test_utils.py +257 -0
- dask_array/_ufunc.py +385 -0
- dask_array/_utils.py +349 -0
- dask_array/_visualize.py +223 -0
- dask_array/_xarray.py +337 -0
- dask_array/core/__init__.py +34 -0
- dask_array/core/_blockwise_funcs.py +312 -0
- dask_array/core/_conversion.py +422 -0
- dask_array/core/_from_graph.py +97 -0
- dask_array/creation/__init__.py +71 -0
- dask_array/creation/_arange.py +121 -0
- dask_array/creation/_diag.py +116 -0
- dask_array/creation/_diagonal.py +241 -0
- dask_array/creation/_eye.py +103 -0
- dask_array/creation/_linspace.py +102 -0
- dask_array/creation/_mesh.py +134 -0
- dask_array/creation/_ones_zeros.py +454 -0
- dask_array/creation/_pad.py +270 -0
- dask_array/creation/_repeat.py +55 -0
- dask_array/creation/_tile.py +36 -0
- dask_array/creation/_tri.py +28 -0
- dask_array/creation/_utils.py +296 -0
- dask_array/fft.py +320 -0
- dask_array/io/__init__.py +39 -0
- dask_array/io/_base.py +10 -0
- dask_array/io/_from_array.py +257 -0
- dask_array/io/_from_delayed.py +95 -0
- dask_array/io/_from_graph.py +54 -0
- dask_array/io/_from_npy_stack.py +67 -0
- dask_array/io/_store.py +336 -0
- dask_array/io/_tiledb.py +159 -0
- dask_array/io/_to_npy_stack.py +65 -0
- dask_array/io/_zarr.py +449 -0
- dask_array/linalg/__init__.py +39 -0
- dask_array/linalg/_cholesky.py +234 -0
- dask_array/linalg/_lu.py +300 -0
- dask_array/linalg/_norm.py +94 -0
- dask_array/linalg/_qr.py +601 -0
- dask_array/linalg/_solve.py +349 -0
- dask_array/linalg/_svd.py +394 -0
- dask_array/linalg/_tensordot.py +334 -0
- dask_array/linalg/_utils.py +74 -0
- dask_array/manipulation/__init__.py +45 -0
- dask_array/manipulation/_expand.py +321 -0
- dask_array/manipulation/_flip.py +92 -0
- dask_array/manipulation/_roll.py +78 -0
- dask_array/manipulation/_transpose.py +309 -0
- dask_array/random/__init__.py +125 -0
- dask_array/random/_choice.py +181 -0
- dask_array/random/_expr.py +256 -0
- dask_array/random/_generator.py +441 -0
- dask_array/random/_random_state.py +259 -0
- dask_array/random/_utils.py +84 -0
- dask_array/reductions/__init__.py +84 -0
- dask_array/reductions/_arg_reduction.py +130 -0
- dask_array/reductions/_common.py +1082 -0
- dask_array/reductions/_cumulative.py +522 -0
- dask_array/reductions/_percentile.py +261 -0
- dask_array/reductions/_reduction.py +725 -0
- dask_array/reductions/_trace.py +56 -0
- dask_array/routines/__init__.py +133 -0
- dask_array/routines/_apply.py +84 -0
- dask_array/routines/_bincount.py +112 -0
- dask_array/routines/_broadcast.py +111 -0
- dask_array/routines/_coarsen.py +115 -0
- dask_array/routines/_diff.py +79 -0
- dask_array/routines/_gradient.py +158 -0
- dask_array/routines/_indexing.py +65 -0
- dask_array/routines/_insert_delete.py +132 -0
- dask_array/routines/_misc.py +122 -0
- dask_array/routines/_nonzero.py +72 -0
- dask_array/routines/_search.py +123 -0
- dask_array/routines/_select.py +113 -0
- dask_array/routines/_statistics.py +171 -0
- dask_array/routines/_topk.py +82 -0
- dask_array/routines/_triangular.py +74 -0
- dask_array/routines/_unique.py +232 -0
- dask_array/routines/_where.py +62 -0
- dask_array/slicing/__init__.py +67 -0
- dask_array/slicing/_basic.py +550 -0
- dask_array/slicing/_blocks.py +138 -0
- dask_array/slicing/_bool_index.py +145 -0
- dask_array/slicing/_setitem.py +329 -0
- dask_array/slicing/_squeeze.py +101 -0
- dask_array/slicing/_utils.py +1133 -0
- dask_array/slicing/_vindex.py +282 -0
- dask_array/stacking/__init__.py +15 -0
- dask_array/stacking/_block.py +83 -0
- dask_array/stacking/_simple.py +58 -0
- dask_array/templates/array.html.j2 +48 -0
- dask_array/tests/__init__.py +0 -0
- dask_array/tests/conftest.py +22 -0
- dask_array/tests/test_api.py +40 -0
- dask_array/tests/test_binary_op_chunks.py +107 -0
- dask_array/tests/test_coarse_slice_through_blockwise.py +362 -0
- dask_array/tests/test_collection.py +799 -0
- dask_array/tests/test_creation.py +1102 -0
- dask_array/tests/test_expr_flow.py +143 -0
- dask_array/tests/test_linalg.py +1130 -0
- dask_array/tests/test_map_blocks_multi_output.py +104 -0
- dask_array/tests/test_rechunk_pushdown.py +214 -0
- dask_array/tests/test_reductions.py +1091 -0
- dask_array/tests/test_routines.py +2853 -0
- dask_array/tests/test_shuffle_chunks.py +67 -0
- dask_array/tests/test_slice_pushdown.py +968 -0
- dask_array/tests/test_slice_through_blockwise.py +678 -0
- dask_array/tests/test_slice_through_overlap.py +366 -0
- dask_array/tests/test_slice_through_reshape.py +272 -0
- dask_array/tests/test_slicing.py +839 -0
- dask_array/tests/test_transpose_slice_pushdown.py +208 -0
- dask_array/tests/test_visualize.py +94 -0
- dask_array/tests/test_xarray.py +193 -0
- dask_array-0.1.0.dist-info/METADATA +48 -0
- dask_array-0.1.0.dist-info/RECORD +144 -0
- dask_array-0.1.0.dist-info/WHEEL +4 -0
- dask_array-0.1.0.dist-info/entry_points.txt +2 -0
- dask_array-0.1.0.dist-info/licenses/LICENSE +29 -0
|
@@ -0,0 +1,101 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import functools
|
|
4
|
+
from itertools import product
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
|
|
8
|
+
from dask._task_spec import Task, TaskRef
|
|
9
|
+
from dask_array._expr import ArrayExpr
|
|
10
|
+
from dask_array._utils import meta_from_array
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class Squeeze(ArrayExpr):
    """Remove axes of length one from array.

    Operands (in ``_parameters`` order):

    array : ArrayExpr
        The input expression.
    axis : None, int, or tuple of ints
        Axes to drop. ``None`` drops every axis whose length is one.
        Negative values are wrapped modulo ``array.ndim``.
    """

    _parameters = ["array", "axis"]

    @functools.cached_property
    def _name(self):
        return f"squeeze-{self.deterministic_token}"

    @functools.cached_property
    def _meta(self):
        # The output has one dimension per surviving chunk axis.
        return meta_from_array(self.array._meta, ndim=len(self.chunks))

    @functools.cached_property
    def _axis_set(self):
        """Normalized axis as a set of non-negative integers."""
        axis = self.axis
        if axis is None:
            return {i for i, d in enumerate(self.array.shape) if d == 1}
        if not isinstance(axis, tuple):
            axis = (axis,)
        # Normalize negative indices
        return {i % self.array.ndim for i in axis}

    @functools.cached_property
    def chunks(self):
        # Drop the chunk tuples of the squeezed axes; keep the others in order.
        return tuple(
            c for i, c in enumerate(self.array.chunks) if i not in self._axis_set
        )

    def _layer(self) -> dict:
        """Build one ``np.squeeze`` task per output block.

        The input has more dimensions than the output. Each output block
        index is mapped to an input block index by inserting a fixed block
        position at every squeezed axis.
        """
        in_chunks = self.array.chunks
        out_chunks = self.chunks
        axis_set = self._axis_set

        # The squeeze axes are the same for every block (relative to the
        # input block's dimensions), so compute them once.
        chunk_axis = tuple(sorted(axis_set))

        # A length-one axis may still be split into several blocks, some of
        # them zero-sized (e.g. chunks (0, 1)). Read from the block that
        # actually holds the element; block 0 could be empty, and squeezing
        # a zero-length axis fails. Fall back to block 0 when no block has
        # size exactly 1.
        squeezed_block = {
            i: next((j for j, size in enumerate(in_chunks[i]) if size == 1), 0)
            for i in axis_set
        }

        dsk = {}
        for out_idx in product(*(range(len(c)) for c in out_chunks)):
            in_idx = []
            out_pos = 0
            for i in range(len(in_chunks)):
                if i in axis_set:
                    in_idx.append(squeezed_block[i])
                else:
                    in_idx.append(out_idx[out_pos])
                    out_pos += 1

            out_key = (self._name,) + tuple(out_idx)
            in_key = (self.array._name,) + tuple(in_idx)
            dsk[out_key] = Task(out_key, np.squeeze, TaskRef(in_key), axis=chunk_axis)

        return dsk
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def squeeze(a, axis=None):
    """Remove axes of length one from array.

    Parameters
    ----------
    a : Array
        Input array
    axis : None or int or tuple of ints, optional
        Selects a subset of entries of length one in the shape. ``None``
        selects every axis of length one. Negative values count from the
        last axis.

    Returns
    -------
    squeezed : Array
        The input with the selected length-one axes removed.

    Raises
    ------
    ValueError
        If a selected axis does not have length one.
    """
    from dask_array._new_collection import new_collection
    from dask_array._utils import validate_axis

    if axis is None:
        axis = tuple(i for i, d in enumerate(a.shape) if d == 1)
    elif not isinstance(axis, tuple):
        axis = (axis,)

    # Validate (and normalize negatives) before indexing a.shape. Otherwise an
    # out-of-range or non-integer axis surfaces as an IndexError/TypeError
    # from tuple indexing instead of being rejected by validate_axis.
    axis = validate_axis(axis, a.ndim)

    if any(a.shape[i] != 1 for i in axis):
        raise ValueError("cannot squeeze axis with size other than one")

    return new_collection(Squeeze(a.expr, axis))
|