annbatch 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of annbatch might be problematic. Click here for more details.
- annbatch/__init__.py +15 -0
- annbatch/abc.py +228 -0
- annbatch/anndata_manager.py +396 -0
- annbatch/dense.py +63 -0
- annbatch/io.py +474 -0
- annbatch/sparse.py +160 -0
- annbatch/types.py +25 -0
- annbatch/utils.py +319 -0
- annbatch-0.0.1.dist-info/METADATA +214 -0
- annbatch-0.0.1.dist-info/RECORD +12 -0
- annbatch-0.0.1.dist-info/WHEEL +4 -0
- annbatch-0.0.1.dist-info/licenses/LICENSE +21 -0
annbatch/sparse.py
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
from importlib.util import find_spec
|
|
5
|
+
from itertools import accumulate, chain, pairwise
|
|
6
|
+
from typing import NamedTuple, cast
|
|
7
|
+
|
|
8
|
+
import anndata as ad
|
|
9
|
+
import numpy as np
|
|
10
|
+
import zarr
|
|
11
|
+
import zarr.core.sync as zsync
|
|
12
|
+
|
|
13
|
+
if find_spec("torch") is None:

    class _IterableDataset:
        """Placeholder base class used when torch is not installed."""

else:
    from torch.utils.data import IterableDataset as _IterableDataset
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
from annbatch.abc import AbstractIterableDataset, _assign_methods_to_ensure_unique_docstrings
|
|
22
|
+
from annbatch.utils import (
|
|
23
|
+
CSRContainer,
|
|
24
|
+
MultiBasicIndexer,
|
|
25
|
+
add_anndata_docstring,
|
|
26
|
+
add_anndatas_docstring,
|
|
27
|
+
add_dataset_docstring,
|
|
28
|
+
add_datasets_docstring,
|
|
29
|
+
)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class CSRDatasetElems(NamedTuple):
    """Cached pieces of one on-disk CSR dataset, indexed into when building CSR batches.

    The field order matters: callers unpack instances positionally as
    ``indptr, indices, data``.
    """

    # Row pointer array, fully read into memory.
    indptr: np.ndarray
    # Column indices, left on disk and read lazily by selection.
    indices: zarr.AsyncArray
    # Stored non-zero values, left on disk and read lazily by selection.
    data: zarr.AsyncArray
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class ZarrSparseDataset(  # noqa: D101
    AbstractIterableDataset[ad.abc.CSRDataset, CSRContainer], _IterableDataset
):
    # Cache of dataset index -> CSRDatasetElems.
    # This used to be a class-level ``{}`` default, i.e. ONE dict shared by every
    # ZarrSparseDataset. Two instances holding different data would then serve each
    # other's cached arrays for the same index. The property below lazily creates a
    # dict on each instance instead. It also accepts assignment, so code that reads
    # or sets ``self._dataset_elem_cache`` keeps working.
    @property
    def _dataset_elem_cache(self) -> dict[int, CSRDatasetElems]:
        """Mapping from dataset index to its cached :class:`CSRDatasetElems`, owned by this instance."""
        return self.__dict__.setdefault("_dataset_elem_cache_storage", {})

    @_dataset_elem_cache.setter
    def _dataset_elem_cache(self, value: dict[int, CSRDatasetElems]) -> None:
        self.__dict__["_dataset_elem_cache_storage"] = value

    def _cache_update_callback(self):
        """Callback for when datasets are added to ensure the cache is updated.

        Blocks until :meth:`_ensure_cache` completes, via ``zarr.core.sync.sync``.
        """
        return zsync.sync(self._ensure_cache())

    def _validate(self, datasets: list[ad.abc.CSRDataset]):
        """Check that every dataset is a zarr-backed :class:`anndata.abc.CSRDataset`.

        Parameters
        ----------
        datasets
            The datasets about to be added.

        Raises
        ------
        TypeError
            If any dataset is not a CSRDataset, or if any is not backed by zarr.
        """
        if not all(isinstance(d, ad.abc.CSRDataset) for d in datasets):
            raise TypeError("Cannot create sparse dataset using non-CSRDataset data")
        if not all(cast("ad.abc.CSRDataset", d).backend == "zarr" for d in datasets):
            raise TypeError(
                "Cannot use CSRDataset backed by h5ad at the moment: see https://github.com/zarr-developers/VirtualiZarr/pull/790"
            )

    async def _create_sparse_elems(self, idx: int) -> CSRDatasetElems:
        """Fetch the in-memory indptr, and backed indices and data for a given dataset index.

        Parameters
        ----------
        idx
            The index of the dataset in the training datasets.

        Returns
        -------
        The constituent elems of the CSR dataset.
        """
        group = self._dataset_manager.train_datasets[idx].group._async_group
        indptr = await group.getitem("indptr")
        # indptr is read in full (Ellipsis); indices/data stay as async array handles.
        return CSRDatasetElems(
            *(
                await asyncio.gather(
                    indptr.getitem(Ellipsis),
                    group.getitem("indices"),
                    group.getitem("data"),
                )
            )
        )

    async def _ensure_cache(self):
        """Build up the cache of datasets i.e., in-memory indptr, and backed indices and data.

        Only datasets whose index is not cached yet are fetched, concurrently.
        """
        cache = self._dataset_elem_cache
        missing = [idx for idx in range(len(self._dataset_manager.train_datasets)) if idx not in cache]
        if not missing:
            return
        fetched = await asyncio.gather(*(self._create_sparse_elems(idx) for idx in missing))
        cache.update(zip(missing, fetched, strict=True))

    async def _get_sparse_elems(self, dataset_idx: int) -> CSRDatasetElems:
        """Return the arrays (zarr or otherwise) needed to represent on-disk data at a given index.

        Parameters
        ----------
        dataset_idx
            The index of the dataset whose arrays are sought.

        Returns
        -------
        The arrays representing the sparse data.
        """
        if dataset_idx not in self._dataset_elem_cache:
            await self._ensure_cache()
        return self._dataset_elem_cache[dataset_idx]

    async def _fetch_data(
        self,
        slices: list[slice],
        dataset_idx: int,
    ) -> CSRContainer:
        """Read the rows selected by ``slices`` from one dataset into a single CSR container.

        Parameters
        ----------
        slices
            Row ranges to read; the rows are concatenated in the given order. Only each
            slice's ``start`` and ``stop`` are used (``stop + 1`` indexes ``indptr``).
            NOTE(review): assumes a non-empty list of step-less slices with explicit bounds.
        dataset_idx
            The index of the dataset to read from.

        Returns
        -------
        A :class:`CSRContainer` holding ``(data, indices, indptr)`` for the selected rows,
        with shape ``(n_selected_rows, n_var)``.
        """
        # See https://github.com/scverse/anndata/blob/361325fc621887bf4f381e9412b150fcff599ff7/src/anndata/_core/sparse_dataset.py#L272-L295
        # for the inspiration of this function.
        indptr, indices, data = await self._get_sparse_elems(dataset_idx)
        # Per-slice row pointers (in memory), and the matching ranges into data/indices.
        indptr_indices = [indptr[slice(s.start, s.stop + 1)] for s in slices]
        indptr_limits = [slice(i[0], i[-1]) for i in indptr_indices]

        def _joint_indexer(arr: zarr.AsyncArray) -> MultiBasicIndexer:
            # Build against each array's own metadata: indices and data are not
            # guaranteed to share a chunk grid.
            return MultiBasicIndexer(
                [
                    zarr.core.indexing.BasicIndexer((lim,), shape=arr.metadata.shape, chunk_grid=arr.metadata.chunk_grid)
                    for lim in indptr_limits
                ]
            )

        prototype = zarr.core.buffer.default_buffer_prototype()
        data_np, indices_np = await asyncio.gather(
            data._get_selection(_joint_indexer(data), prototype=prototype),
            indices._get_selection(_joint_indexer(indices), prototype=prototype),
        )
        # The fetched buffers are the ranges in indptr_limits laid end to end. offsets[k]
        # must be subtracted from slice k's original indptr values to point into those
        # buffers: the first range's start plus the cumulative gaps skipped before slice k.
        gaps = (s1.start - s0.stop for s0, s1 in pairwise(indptr_limits))
        offsets = accumulate(chain([indptr_limits[0].start], gaps))
        start_indptr = indptr_indices[0] - next(offsets)
        n_var = self._dataset_manager.n_var
        if len(slices) < 2:  # there is only one slice so no need to concatenate
            return CSRContainer(
                elems=(data_np, indices_np, start_indptr),
                shape=(start_indptr.shape[0] - 1, n_var),
                dtype=data_np.dtype,
            )
        # Drop each later slice's first entry: after re-offsetting it equals the previous
        # slice's last entry.
        end_indptr = np.concatenate([s[1:] - o for s, o in zip(indptr_indices[1:], offsets, strict=True)])
        indptr_np = np.concatenate([start_indptr, end_indptr])
        return CSRContainer(
            elems=(data_np, indices_np, indptr_np),
            shape=(indptr_np.shape[0] - 1, n_var),
            dtype=data_np.dtype,
        )
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
# Per its name, this helper makes the inherited methods' docstrings unique to this class.
# It must run before the __doc__ assignments below.
_assign_methods_to_ensure_unique_docstrings(ZarrSparseDataset)

# The class docstring is rendered from the abstract base's __init__ template.
ZarrSparseDataset.__doc__ = AbstractIterableDataset.__init__.__doc__.format(
    array_type="sparse", child_class="ZarrSparseDataset"
)

# Every add_* template is filled with the same on-disk array type.
for _method_name, _doc_template in (
    ("add_datasets", add_datasets_docstring),
    ("add_dataset", add_dataset_docstring),
    ("add_anndatas", add_anndatas_docstring),
    ("add_anndata", add_anndata_docstring),
):
    getattr(ZarrSparseDataset, _method_name).__doc__ = _doc_template.format(
        on_disk_array_type="anndata.abc.CSRDataset"
    )
del _method_name, _doc_template

ZarrSparseDataset.__iter__.__doc__ = AbstractIterableDataset.__iter__.__doc__.format(
    gpu_array="cupyx.scipy.sparse.spmatrix", cpu_array="scipy.sparse.csr_matrix"
)
|
annbatch/types.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from types import NoneType
|
|
4
|
+
from typing import TypeVar
|
|
5
|
+
|
|
6
|
+
import anndata as ad
|
|
7
|
+
import numpy as np
|
|
8
|
+
import zarr
|
|
9
|
+
from scipy import sparse as sp
|
|
10
|
+
|
|
11
|
+
from annbatch.utils import CSRContainer
|
|
12
|
+
|
|
13
|
+
try:
    from cupy import ndarray as CupyArray
    from cupyx.scipy.sparse import csr_matrix as CupyCSRMatrix  # pragma: no cover
except ImportError:
    # cupy is optional: substitute NoneType so the union below is still a valid type expression.
    CupyArray = CupyCSRMatrix = NoneType

# In-memory output array types: scipy/numpy on CPU, or cupy when it is installed.
OutputInMemoryArray = sp.csr_matrix | np.ndarray | CupyCSRMatrix | CupyArray

# Backed (on-disk) array types a dataset may wrap.
OnDiskArray = TypeVar("OnDiskArray", ad.abc.CSRDataset, zarr.Array)


# In-memory forms fetched out of an on-disk array.
InputInMemoryArray = TypeVar("InputInMemoryArray", CSRContainer, np.ndarray)
|
annbatch/utils.py
ADDED
|
@@ -0,0 +1,319 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import warnings
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
from functools import cached_property
|
|
7
|
+
from importlib.util import find_spec
|
|
8
|
+
from itertools import islice
|
|
9
|
+
from typing import TYPE_CHECKING, Protocol
|
|
10
|
+
|
|
11
|
+
import numpy as np
|
|
12
|
+
import scipy as sp
|
|
13
|
+
import zarr
|
|
14
|
+
|
|
15
|
+
try:
|
|
16
|
+
from cupy import ndarray as CupyArray
|
|
17
|
+
from cupyx.scipy.sparse import csr_matrix as CupyCSRMatrix # pragma: no cover
|
|
18
|
+
except ImportError:
|
|
19
|
+
CupyArray = None
|
|
20
|
+
CupyCSRMatrix = None
|
|
21
|
+
|
|
22
|
+
if TYPE_CHECKING:
|
|
23
|
+
from collections import OrderedDict
|
|
24
|
+
from collections.abc import Awaitable, Callable, Generator, Iterable
|
|
25
|
+
|
|
26
|
+
from torch import Tensor
|
|
27
|
+
|
|
28
|
+
from annbatch.types import InputInMemoryArray, OutputInMemoryArray
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def split_given_size(a: np.ndarray, size: int) -> list[np.ndarray]:
    """Split ``a`` along its first axis into consecutive chunks of length ``size``.

    Thin wrapper around :func:`numpy.split`. Every chunk has exactly ``size`` rows,
    except possibly the last one, which holds the remainder. An empty ``a`` yields a
    single empty chunk.

    Parameters
    ----------
    a
        The array to split.
    size
        The length of each chunk (not the number of chunks).

    Returns
    -------
    The list of chunks, in order.

    Raises
    ------
    ValueError
        If ``size`` is less than 1.
    """
    if size < 1:
        # np.arange would raise ZeroDivisionError for 0 and silently return [a] for negatives.
        raise ValueError(f"size must be >= 1, got {size}")
    return np.split(a, np.arange(size, len(a), size))
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
@dataclass
class CSRContainer:
    """Lightweight holder for the raw buffers of a CSR matrix, cheap to pass around.

    Attributes
    ----------
    elems
        The CSR buffers in ``(data, indices, indptr)`` order.
    shape
        The ``(n_rows, n_cols)`` shape of the matrix.
    dtype
        The dtype of the stored values.
    """

    elems: tuple[np.ndarray, np.ndarray, np.ndarray]  # (data, indices, indptr)
    shape: tuple[int, int]  # (n_rows, n_cols)
    dtype: np.dtype  # dtype of the data buffer
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def _batched[T](iterable: Iterable[T], n: int) -> Generator[list[T], None, None]:
|
|
46
|
+
if n < 1:
|
|
47
|
+
raise ValueError("n must be >= 1")
|
|
48
|
+
it = iter(iterable)
|
|
49
|
+
while batch := list(islice(it, n)):
|
|
50
|
+
yield batch
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
async def index_datasets(
    dataset_index_to_slices: OrderedDict[int, list[slice]],
    fetch_data: Callable[[list[slice], int], Awaitable[CSRContainer | np.ndarray]],
) -> list[InputInMemoryArray]:
    """Run ``fetch_data`` for every requested dataset concurrently on the current event loop.

    Keeping the fetches inside one coroutine lets them share the event loop that zarr uses.

    Parameters
    ----------
    dataset_index_to_slices
        Maps the list-placement index of a dataset to the slices requested from it.
    fetch_data
        Coroutine function called as ``fetch_data(slices, dataset_idx)`` for each entry.

    Returns
    -------
    One result per entry, in the mapping's iteration order.
    """
    # Create every coroutine up front, then await them together so the fetches overlap.
    return await asyncio.gather(
        *(fetch_data(slices, dataset_idx) for dataset_idx, slices in dataset_index_to_slices.items())
    )
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
# Docstring templates for the dataset classes' ``add_*`` methods. Each template is
# rendered with ``.format(on_disk_array_type=...)``.
add_datasets_docstring = """\
Append datasets to this dataset.

Parameters
----------
datasets
    List of :class:`{on_disk_array_type}` objects, generally from :attr:`anndata.AnnData.X`.
obs
    List of :class:`numpy.ndarray` labels, generally from :attr:`anndata.AnnData.obs`.
"""

add_dataset_docstring = """\
Append a dataset to this dataset.

Parameters
----------
dataset
    :class:`{on_disk_array_type}` object, generally from :attr:`anndata.AnnData.X`.
obs
    :class:`numpy.ndarray` labels for the anndata, generally from :attr:`anndata.AnnData.obs`.
"""


add_anndatas_docstring = """\
Append anndatas to this dataset.

Parameters
----------
anndatas
    List of :class:`anndata.AnnData` objects, with :class:`{on_disk_array_type}` as the data matrix.
obs_keys
    List of :attr:`anndata.AnnData.obs` column labels.
layer_keys
    List of :attr:`anndata.AnnData.layers` keys, and if None, :attr:`anndata.AnnData.X` will be used.
"""

add_anndata_docstring = """\
Append an anndata to this dataset.

Parameters
----------
anndata
    :class:`anndata.AnnData` object, with :class:`{on_disk_array_type}` as the data matrix.
obs_key
    :attr:`anndata.AnnData.obs` column label.
layer_key
    :attr:`anndata.AnnData.layers` key, and if None, :attr:`anndata.AnnData.X` will be used.
"""
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
# TODO: make this part of the public zarr or zarrs-python API.
# We can do chunk coalescing in zarrs based on integer arrays, so I think
# it would make sense to do that there with ezclump or similar.
# Another "solution" would be for zarrs to support integer indexing properly, if that pipeline works,
# or make this an "experimental setting" and to use integer indexing for the zarr-python pipeline.
# See: https://github.com/zarr-developers/zarr-python/issues/3175 for why this is better than simpler alternatives.
class MultiBasicIndexer(zarr.core.indexing.Indexer):
    """Custom indexer to enable joint fetching of disparate slices.

    Chains the chunk projections of several sub-indexers. Each projection's output
    selection is rewritten so that the sub-selections land back to back, in order, along
    the first axis of a single output buffer.
    """

    def __init__(self, indexers: list[zarr.core.indexing.Indexer]):
        """Combine ``indexers`` into one indexer.

        Parameters
        ----------
        indexers
            The sub-indexers, in output order.

        Raises
        ------
        ValueError
            If ``indexers`` is empty.
        """
        if not indexers:
            raise ValueError("MultiBasicIndexer requires at least one indexer")
        # The output length along axis 0 is the sum of the sub-selections; trailing dims come from the first.
        self.shape = (sum(i.shape[0] for i in indexers), *indexers[0].shape[1:])
        # NOTE(review): assumes every sub-indexer has the same drop_axes as the first.
        self.drop_axes = indexers[0].drop_axes
        self.indexers = indexers

    def __iter__(self):
        """Yield every sub-indexer's chunk projections with shifted output selections."""
        total = 0  # rows of output already claimed by earlier projections
        for indexer in self.indexers:
            for projection in indexer:
                out_selection = projection[2]
                # Assumes a step-less slice on axis 0, so stop - start is the row count.
                gap = out_selection[0].stop - out_selection[0].start
                yield type(projection)(
                    projection[0],
                    projection[1],
                    (slice(total, total + gap), *out_selection[1:]),
                    projection[3],
                )
                total += gap
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
def sample_rows(
|
|
152
|
+
x_list: list[np.ndarray],
|
|
153
|
+
obs_list: list[np.ndarray] | None,
|
|
154
|
+
indices: list[np.ndarray] | None = None,
|
|
155
|
+
*,
|
|
156
|
+
shuffle: bool = True,
|
|
157
|
+
) -> Generator[tuple[np.ndarray, np.ndarray | None], None, None]:
|
|
158
|
+
"""Samples rows from multiple arrays and their corresponding observation arrays.
|
|
159
|
+
|
|
160
|
+
Parameters
|
|
161
|
+
----------
|
|
162
|
+
x_list
|
|
163
|
+
A list of numpy arrays containing the data to sample from.
|
|
164
|
+
obs_list
|
|
165
|
+
A list of numpy arrays containing the corresponding observations.
|
|
166
|
+
indices
|
|
167
|
+
the list of indexes for each element in `x_list/`
|
|
168
|
+
shuffle
|
|
169
|
+
Whether to shuffle the rows before sampling.
|
|
170
|
+
|
|
171
|
+
Yields
|
|
172
|
+
------
|
|
173
|
+
tuple
|
|
174
|
+
A tuple containing a row from `x_list` and the corresponding row from `obs_list`.
|
|
175
|
+
"""
|
|
176
|
+
lengths = np.fromiter((x.shape[0] for x in x_list), dtype=int)
|
|
177
|
+
cum = np.concatenate(([0], np.cumsum(lengths)))
|
|
178
|
+
total = cum[-1]
|
|
179
|
+
idxs = np.arange(total)
|
|
180
|
+
if shuffle:
|
|
181
|
+
np.random.default_rng().shuffle(idxs)
|
|
182
|
+
arr_idxs = np.searchsorted(cum, idxs, side="right") - 1
|
|
183
|
+
row_idxs = idxs - cum[arr_idxs]
|
|
184
|
+
for ai, ri in zip(arr_idxs, row_idxs, strict=True):
|
|
185
|
+
res = [
|
|
186
|
+
x_list[ai][ri],
|
|
187
|
+
obs_list[ai][ri] if obs_list is not None else None,
|
|
188
|
+
]
|
|
189
|
+
if indices is not None:
|
|
190
|
+
yield (*res, indices[ai][ri])
|
|
191
|
+
else:
|
|
192
|
+
yield tuple(res)
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
class WorkerHandle:
    """Worker-aware shuffling and array partitioning for use under a torch ``DataLoader``."""

    @cached_property
    def _worker_info(self):
        """The torch worker info, or ``None`` when torch is missing or not in a worker."""
        if find_spec("torch"):
            from torch.utils.data import get_worker_info

            return get_worker_info()
        return None

    @cached_property
    def _rng(self):
        """Random generator, seeded identically across the workers of one ``DataLoader``."""
        if self._worker_info is None:
            return np.random.default_rng()
        # This is used for the _get_chunks function.
        # Use the same seed for all workers so that the resulting splits are the same across workers.
        # torch's default seed is `base_seed + worker_id`. Hence, subtract worker_id to get the base seed.
        return np.random.default_rng(self._worker_info.seed - self._worker_info.id)

    def shuffle(self, obj: np.typing.ArrayLike) -> None:
        """Perform in-place shuffle.

        Parameters
        ----------
        obj
            The object to be shuffled; it must be mutable in place.
        """
        self._rng.shuffle(obj)

    def get_part_for_worker(self, obj: np.ndarray) -> np.ndarray:
        """Get a chunk of an incoming array according to the current worker id.

        Parameters
        ----------
        obj
            Incoming array.

        Returns
        -------
        ``obj`` itself when not running in a worker. Otherwise this worker's part of
        ``obj`` after :func:`numpy.array_split` into one part per worker; part sizes
        differ by at most one.
        """
        if self._worker_info is None:
            return obj
        num_workers, worker_id = self._worker_info.num_workers, self._worker_info.id
        chunks_split = np.array_split(obj, num_workers)
        return chunks_split[worker_id]
|
|
241
|
+
|
|
242
|
+
|
|
243
|
+
def check_lt_1(vals: list[int], labels: list[str]) -> None:
|
|
244
|
+
"""Raise a ValueError if any of the values are less than one.
|
|
245
|
+
|
|
246
|
+
The format of the error is "{labels[i]} must be greater than 1, got {values[i]}"
|
|
247
|
+
and is raised based on the first found less than one value.
|
|
248
|
+
|
|
249
|
+
Parameters
|
|
250
|
+
----------
|
|
251
|
+
vals
|
|
252
|
+
The values to check < 1
|
|
253
|
+
labels
|
|
254
|
+
The label for the value in the error if the value is less than one.
|
|
255
|
+
|
|
256
|
+
Raises
|
|
257
|
+
------
|
|
258
|
+
ValueError: _description_
|
|
259
|
+
"""
|
|
260
|
+
if any(is_lt_1 := [v < 1 for v in vals]):
|
|
261
|
+
label, value = next(
|
|
262
|
+
(label, value)
|
|
263
|
+
for label, value, check in zip(
|
|
264
|
+
labels,
|
|
265
|
+
vals,
|
|
266
|
+
is_lt_1,
|
|
267
|
+
strict=True,
|
|
268
|
+
)
|
|
269
|
+
if check
|
|
270
|
+
)
|
|
271
|
+
raise ValueError(f"{label} must be greater than 1, got {value}")
|
|
272
|
+
|
|
273
|
+
|
|
274
|
+
class SupportsShape(Protocol): # noqa: D101
|
|
275
|
+
@property
|
|
276
|
+
def shape(self) -> tuple[int, int] | list[int]: ... # noqa: D102
|
|
277
|
+
|
|
278
|
+
|
|
279
|
+
def check_var_shapes(objs: list[SupportsShape]) -> None:
    """Check that all objects have the same size along the second (var) axis.

    An empty list passes trivially.

    Parameters
    ----------
    objs
        Objects exposing a ``shape`` with at least two entries.

    Raises
    ------
    ValueError
        If the objects do not all share the same ``shape[1]``.
    """
    var_sizes = {obj.shape[1] for obj in objs}
    if len(var_sizes) > 1:
        raise ValueError(
            f"All datasets must have same shape along the var axis, got sizes {sorted(var_sizes)}."
        )
|
|
283
|
+
|
|
284
|
+
|
|
285
|
+
def to_torch(input: OutputInMemoryArray, preload_to_gpu: bool) -> Tensor:
    """Send the input data to a torch.Tensor.

    Parameters
    ----------
    input
        One of the following:

        - a :class:`torch.Tensor`, which is returned unchanged;
        - a :class:`scipy.sparse.csr_matrix`;
        - a :class:`numpy.ndarray`;
        - when cupy is installed, a cupy ndarray or cupyx CSR matrix.
    preload_to_gpu
        For scipy/numpy inputs, move the result with ``.cuda(non_blocking=True)``.
        Cupy inputs are converted through DLPack and are not moved.

    Returns
    -------
    A dense tensor, or a sparse CSR tensor for CSR inputs.

    Raises
    ------
    TypeError
        If ``input`` is not one of the supported types.
    """
    import torch

    if isinstance(input, torch.Tensor):
        return input
    if isinstance(input, sp.sparse.csr_matrix):
        with warnings.catch_warnings():
            # Silence torch's "beta state" UserWarning for sparse CSR tensors.
            warnings.filterwarnings("ignore", "Sparse CSR tensor support is in beta state", UserWarning)
            tensor = torch.sparse_csr_tensor(
                torch.from_numpy(input.indptr),
                torch.from_numpy(input.indices),
                torch.from_numpy(input.data),
                input.shape,
            )
        if preload_to_gpu:
            return tensor.cuda(non_blocking=True)
        return tensor
    if isinstance(input, np.ndarray):
        tensor = torch.from_numpy(input)
        if preload_to_gpu:
            return tensor.cuda(non_blocking=True)
        return tensor
    # Without cupy the fallbacks are ``None``; isinstance(x, None) would itself raise
    # an unrelated TypeError, so guard these checks.
    if CupyArray is not None and isinstance(input, CupyArray):
        return torch.from_dlpack(input)
    if CupyCSRMatrix is not None and isinstance(input, CupyCSRMatrix):
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", "Sparse CSR tensor support is in beta state", UserWarning)
            return torch.sparse_csr_tensor(
                torch.from_dlpack(input.indptr),
                torch.from_dlpack(input.indices),
                torch.from_dlpack(input.data),
                input.shape,
            )
    raise TypeError(f"Cannot convert {type(input)} to torch.Tensor")
|