annbatch-0.0.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- annbatch/__init__.py +15 -0
- annbatch/abc.py +228 -0
- annbatch/anndata_manager.py +396 -0
- annbatch/dense.py +63 -0
- annbatch/io.py +474 -0
- annbatch/sparse.py +160 -0
- annbatch/types.py +25 -0
- annbatch/utils.py +319 -0
- annbatch-0.0.1.dist-info/METADATA +214 -0
- annbatch-0.0.1.dist-info/RECORD +12 -0
- annbatch-0.0.1.dist-info/WHEEL +4 -0
- annbatch-0.0.1.dist-info/licenses/LICENSE +21 -0
annbatch/__init__.py
ADDED
@@ -0,0 +1,15 @@
+from importlib.metadata import version
+
+from .dense import ZarrDenseDataset
+from .io import add_to_collection, create_anndata_collection, write_sharded
+from .sparse import ZarrSparseDataset
+
+__version__ = version("annbatch")
+
+__all__ = [
+    "ZarrSparseDataset",
+    "ZarrDenseDataset",
+    "write_sharded",
+    "add_to_collection",
+    "create_anndata_collection",
+]
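For orientation, here is a minimal usage sketch of the public API exported above. This is not from the package docs: the store path is a placeholder, and `zarr.open` / `anndata.io.sparse_dataset` are the access paths the library's own docstrings recommend for on-disk data.

import zarr
from anndata.io import sparse_dataset

from annbatch import ZarrSparseDataset

# Open the X matrix of an on-disk AnnData zarr store lazily.
# "my_store.zarr" is a hypothetical path for this sketch.
g = zarr.open("my_store.zarr")

ds = ZarrSparseDataset(
    batch_size=4096,
    chunk_size=32,
    preload_nchunks=512,
).add_dataset(sparse_dataset(g["X"]))

# Each iteration yields (X, labels); labels is None unless obs labels were added.
for X, labels in ds:
    ...  # X is a torch.Tensor when `to_torch` is enabled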
annbatch/abc.py
ADDED
@@ -0,0 +1,228 @@
+from __future__ import annotations
+
+from abc import ABCMeta, abstractmethod
+from functools import wraps
+from importlib.util import find_spec
+from typing import TYPE_CHECKING
+
+from annbatch.anndata_manager import AnnDataManager
+from annbatch.utils import (
+    WorkerHandle,
+    add_anndata_docstring,
+    add_anndatas_docstring,
+    add_dataset_docstring,
+    add_datasets_docstring,
+    check_lt_1,
+)
+
+if TYPE_CHECKING:
+    from collections.abc import Iterator
+    from typing import Self
+
+    import anndata as ad
+    import numpy as np
+    from torch import Tensor
+
+    from annbatch.types import InputInMemoryArray, OnDiskArray, OutputInMemoryArray
+
+
+class AbstractIterableDataset[OnDiskArray, InputInMemoryArray](metaclass=ABCMeta):  # noqa: D101
+    _shuffle: bool
+    _preload_nchunks: int
+    _worker_handle: WorkerHandle
+    _chunk_size: int
+    _dataset_manager: AnnDataManager[OnDiskArray, InputInMemoryArray]
+
+    def __init__(
+        self,
+        *,
+        chunk_size: int = 512,
+        preload_nchunks: int = 32,
+        shuffle: bool = True,
+        return_index: bool = False,
+        batch_size: int = 1,
+        preload_to_gpu: bool = True,
+        drop_last: bool = False,
+        to_torch: bool = find_spec("torch") is not None,
+    ):
+        """A loader for on-disk {array_type} data.
+
+        This loader batches together slice requests to the underlying {array_type} stores to achieve higher performance.
+        The custom code for this task will be upstreamed into anndata at some point and will then no longer rely on private zarr APIs.
+        The loader is agnostic to the on-disk chunking/sharding, but it may be advisable to align it with the in-memory chunk size for dense data.
+
+        The dataset class on its own is quite performant for "chunked loading", i.e., `chunk_size > 1`.
+        When `chunk_size == 1`, a :class:`torch.utils.data.DataLoader` should wrap the dataset object.
+        In this case, do not use the `add_anndata` or `add_anndatas` option due to https://github.com/scverse/anndata/issues/2021.
+        Instead use :func:`anndata.io.sparse_dataset` or :func:`zarr.open` to only get the array you need.
+
+        If `preload_to_gpu` is True and `to_torch` is False, the yielded type is a `cupy` matrix.
+        If `to_torch` is True, the yielded type is a :class:`torch.Tensor`.
+        If both `preload_to_gpu` and `to_torch` are False, then the return type is the CPU class for {array_type}.
+
+        Parameters
+        ----------
+        chunk_size
+            The obs size (i.e., axis 0) of contiguous array data to fetch.
+        preload_nchunks
+            The number of chunks of contiguous array data to fetch.
+        shuffle
+            Whether or not to shuffle the data.
+        return_index
+            Whether or not to yield the index on each iteration.
+        batch_size
+            Batch size to yield from the dataset.
+        preload_to_gpu
+            Whether or not to use cupy for non-IO array operations like vstack and indexing once the data is in memory internally.
+            This option entails greater GPU memory usage, but is faster, at least for sparse operations.
+            :func:`torch.vstack` does not support CSR sparse matrices, hence the current use of cupy internally.
+            Setting this to `False` is advisable when using the :class:`torch.utils.data.DataLoader` wrapper or potentially with dense data.
+            For top performance, this should be used in conjunction with `to_torch` and then :meth:`torch.Tensor.to_dense` if you wish to densify.
+        drop_last
+            Set to True to drop the last incomplete batch if the dataset size is not divisible by the batch size.
+            If False and the size of the dataset is not divisible by the batch size, then the last batch will be smaller.
+            Leave as False when using in conjunction with a :class:`torch.utils.data.DataLoader`.
+        to_torch
+            Whether to return a `torch.Tensor` as the output.
+            Data transferred should be 0-copy independent of source, and transfer to cuda, when applicable, is non-blocking.
+            Defaults to True if `torch` is installed.
+
+        Examples
+        --------
+        >>> from annbatch import {child_class}
+        >>> ds = {child_class}(
+        ...     batch_size=4096,
+        ...     chunk_size=32,
+        ...     preload_nchunks=512,
+        ... ).add_anndata(my_anndata)
+        >>> for batch in ds:
+        ...     # optionally convert to dense
+        ...     # batch = batch.to_dense()
+        ...     do_fit(batch)
+        """
+        check_lt_1(
+            [
+                chunk_size,
+                preload_nchunks,
+            ],
+            ["Chunk size", "Preload chunks"],
+        )
+        if batch_size > (chunk_size * preload_nchunks):
+            raise NotImplementedError(
+                "Cannot yield batches bigger than the iterated in-memory size, i.e., batch_size > (chunk_size * preload_nchunks)."
+            )
+
+        for package, arg, arg_name in [
+            ("torch", to_torch, f"{to_torch=}"),
+            ("cupy", preload_to_gpu, f"{preload_to_gpu=}"),
+        ]:
+            if arg and not find_spec(package):
+                raise ImportError(
+                    f"Could not find {package} dependency even though {arg_name}. Try `pip install {package}`"
+                )
+        self._dataset_manager = AnnDataManager(
+            # TODO: https://github.com/scverse/anndata/issues/2021
+            # on_add=self._cache_update_callback,
+            return_index=return_index,
+            batch_size=batch_size,
+            preload_to_gpu=preload_to_gpu,
+            drop_last=drop_last,
+            to_torch=to_torch,
+        )
+        self._chunk_size = chunk_size
+        self._preload_nchunks = preload_nchunks
+        self._shuffle = shuffle
+        self._worker_handle = WorkerHandle()
+
+    async def _cache_update_callback(self):
+        return None
+
+    @abstractmethod
+    async def _fetch_data(self, slices: list[slice], dataset_idx: int) -> InputInMemoryArray:
+        """Fetch the data for the given slices from the arrays representing a dataset on disk.
+
+        Parameters
+        ----------
+        slices: The indexing slices to fetch.
+        dataset_idx: The index of the dataset to fetch from.
+
+        Returns
+        -------
+        The in-memory array data.
+        """
+        ...
+
+    # TODO: validations once the sparse and dense are merged with the AnnDataManager
+    def add_anndatas(  # noqa: D102
+        self,
+        adatas: list[ad.AnnData],
+        layer_keys: list[str | None] | str | None = None,
+        obs_keys: list[str] | str | None = None,
+    ) -> Self:
+        self._dataset_manager.add_anndatas(adatas, layer_keys=layer_keys, obs_keys=obs_keys)
+        return self
+
+    def add_anndata(  # noqa: D102
+        self,
+        adata: ad.AnnData,
+        layer_key: str | None = None,
+        obs_key: str | None = None,
+    ) -> Self:
+        self._dataset_manager.add_anndata(adata, layer_key=layer_key, obs_key=obs_key)
+        return self
+
+    @abstractmethod
+    def _validate(self, datasets: list[OnDiskArray]) -> None: ...
+
+    def add_datasets(self, datasets: list[OnDiskArray], obs: list[np.ndarray] | None = None) -> Self:  # noqa: D102
+        self._validate(datasets)
+        self._dataset_manager.add_datasets(datasets, obs)
+        return self
+
+    def add_dataset(self, dataset: OnDiskArray, obs: np.ndarray | None = None) -> Self:  # noqa: D102
+        self._validate([dataset])
+        self._dataset_manager.add_dataset(dataset, obs)
+        return self
+
+    def __len__(self) -> int:
+        return self._dataset_manager.n_obs
+
+    def __iter__(
+        self,
+    ) -> Iterator[
+        tuple[OutputInMemoryArray, None | np.ndarray]
+        | tuple[OutputInMemoryArray | Tensor, None | np.ndarray, np.ndarray]
+    ]:
+        """
+        Iterate over the on-disk datasets, returning :class:`{gpu_array}` or :class:`{cpu_array}` depending on whether or not `preload_to_gpu` is set.
+
+        Will convert to a :class:`torch.Tensor` if `to_torch` is True.
+
+        Yields
+        ------
+        An in-memory array, optionally with its label and location in the global store.
+        """
+        yield from self._dataset_manager.iter(
+            self._chunk_size,
+            self._worker_handle,
+            self._preload_nchunks,
+            self._shuffle,
+            self._fetch_data,
+        )
+
+
+AbstractIterableDataset.add_dataset.__doc__ = add_dataset_docstring
+AbstractIterableDataset.add_datasets.__doc__ = add_datasets_docstring
+AbstractIterableDataset.add_anndata.__doc__ = add_anndata_docstring
+AbstractIterableDataset.add_anndatas.__doc__ = add_anndatas_docstring
+
+
+def _assign_methods_to_ensure_unique_docstrings(typ):
+    """Because both children of AbstractIterableDataset inherit but do not override the methods listed, the methods need to be copied to ensure unique docstrings."""
+    for name in ["add_datasets", "add_dataset", "add_anndatas", "add_anndata", "__init__", "__iter__"]:
+
+        @wraps(getattr(AbstractIterableDataset, name))
+        def func(self, *args, name=name, **kwargs):
+            return getattr(super(typ, self), name)(*args, **kwargs)
+
+        setattr(typ, name, func)
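To make the abstract contract concrete, here is a hedged sketch of a hypothetical subclass. It is illustrative only: the shipped implementations are `ZarrDenseDataset` and `ZarrSparseDataset`, which batch slice requests far more aggressively than this naive version.

import numpy as np
import zarr


class MyDenseDataset(AbstractIterableDataset):
    def _validate(self, datasets: list[zarr.Array]) -> None:
        # Reject anything that is not a 2D zarr array.
        if any(not isinstance(d, zarr.Array) or d.ndim != 2 for d in datasets):
            raise ValueError("Only 2D zarr arrays are supported")

    async def _fetch_data(self, slices: list[slice], dataset_idx: int) -> np.ndarray:
        # Materialize the requested row ranges of one on-disk dataset and
        # stack them into a single contiguous in-memory chunk.
        ds = self._dataset_manager.train_datasets[dataset_idx]
        return np.vstack([ds[s] for s in slices])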
annbatch/anndata_manager.py
ADDED
@@ -0,0 +1,396 @@
+from __future__ import annotations
+
+import math
+from collections import OrderedDict, defaultdict
+from types import NoneType
+from typing import TYPE_CHECKING, cast
+
+import anndata as ad
+import numpy as np
+import zarr.core.sync as zsync
+from scipy import sparse as sp
+
+from annbatch.types import InputInMemoryArray, OnDiskArray, OutputInMemoryArray
+from annbatch.utils import (
+    CSRContainer,
+    WorkerHandle,
+    _batched,
+    add_anndata_docstring,
+    add_anndatas_docstring,
+    add_dataset_docstring,
+    add_datasets_docstring,
+    check_lt_1,
+    check_var_shapes,
+    index_datasets,
+    split_given_size,
+    to_torch,
+)
+
+try:
+    from cupy import ndarray as CupyArray
+    from cupyx.scipy.sparse import csr_matrix as CupyCSRMatrix  # pragma: no cover
+except ImportError:
+    CupyCSRMatrix = NoneType
+    CupyArray = NoneType
+
+if TYPE_CHECKING:
+    from collections.abc import Awaitable, Callable, Iterator
+    from types import ModuleType
+
+accepted_on_disk_types = OnDiskArray.__constraints__
+
+
+class AnnDataManager[OnDiskArray, InputInMemoryArray]:  # noqa: D101
+    train_datasets: list[OnDiskArray] = []
+    labels: list[np.ndarray] | None = None
+    _return_index: bool = False
+    _on_add: Callable | None = None
+    _batch_size: int = 1
+    _shapes: list[tuple[int, int]] = []
+    _preload_to_gpu: bool = True
+    _drop_last: bool = False
+    _to_torch: bool = True
+    _used_anndata_adder: bool = False
+
+    def __init__(
+        self,
+        *,
+        on_add: Callable | None = None,
+        return_index: bool = False,
+        batch_size: int = 1,
+        preload_to_gpu: bool = True,
+        drop_last: bool = False,
+        to_torch: bool = True,
+    ):
+        self._on_add = on_add
+        self._return_index = return_index
+        self._batch_size = batch_size
+        self._preload_to_gpu = preload_to_gpu
+        self._to_torch = to_torch
+        self._drop_last = drop_last
+
+    @property
+    def _sp_module(self) -> ModuleType:
+        if self._preload_to_gpu:
+            try:
+                import cupyx.scipy.sparse as cpx  # pragma: no cover
+
+                return cpx
+            except ImportError:
+                raise ImportError(
+                    "Cannot find cupy module even though `preload_to_gpu` argument was set to `True`"
+                ) from None
+        return sp
+
+    @property
+    def _np_module(self) -> ModuleType:
+        if self._preload_to_gpu:
+            try:
+                import cupy as cp
+
+                return cp
+            except ImportError:
+                raise ImportError(
+                    "Cannot find cupy module even though `preload_to_gpu` argument was set to `True`"
+                ) from None
+
+        return np
+
+    @property
+    def dataset_type(self) -> type[OnDiskArray]:  # noqa: D102
+        return type(self.train_datasets[0])
+
+    @property
+    def n_obs(self) -> int:  # noqa: D102
+        return sum(shape[0] for shape in self._shapes)
+
+    @property
+    def n_var(self) -> int:  # noqa: D102
+        return self._shapes[0][1]
+
+    def add_anndatas(  # noqa: D102
+        self,
+        adatas: list[ad.AnnData],
+        layer_keys: list[str | None] | str | None = None,
+        obs_keys: list[str] | str | None = None,
+    ) -> None:
+        self._used_anndata_adder = True
+        if isinstance(layer_keys, str | None):
+            layer_keys = [layer_keys] * len(adatas)
+        if isinstance(obs_keys, str | None):
+            obs_keys = [obs_keys] * len(adatas)
+        elem_to_keys = dict(zip(["layer", "obs"], [layer_keys, obs_keys], strict=True))
+        check_lt_1(
+            [len(adatas)] + sum((([len(k)] if k is not None else []) for k in elem_to_keys.values()), []),
+            ["Number of anndatas"]
+            + sum(
+                ([f"Number of {label} keys"] if keys is not None else [] for label, keys in elem_to_keys.items()),
+                [],
+            ),
+        )
+        for adata, obs_key, layer_key in zip(adatas, obs_keys, layer_keys, strict=True):
+            kwargs = {"obs_key": obs_key, "layer_key": layer_key}
+            self.add_anndata(adata, **kwargs)
+
+    def add_anndata(  # noqa: D102
+        self,
+        adata: ad.AnnData,
+        layer_key: str | None = None,
+        obs_key: str | None = None,
+    ) -> None:
+        self._used_anndata_adder = True
+        dataset = adata.X if layer_key is None else adata.layers[layer_key]
+        if not isinstance(dataset, accepted_on_disk_types):
+            raise TypeError(f"Found {type(dataset)} but only {accepted_on_disk_types} are usable")
+        obs = adata.obs[obs_key].to_numpy() if obs_key is not None else None
+        self.add_dataset(cast("OnDiskArray", dataset), obs)
+
+    def add_datasets(self, datasets: list[OnDiskArray], obs: list[np.ndarray] | None = None) -> None:  # noqa: D102
+        if obs is None:
+            obs = [None] * len(datasets)
+        for ds, o in zip(datasets, obs, strict=True):
+            self.add_dataset(ds, o)
+
+    def add_dataset(self, dataset: OnDiskArray, obs: np.ndarray | None = None) -> None:  # noqa: D102
+        if len(self.train_datasets) > 0:
+            if self.labels is None and obs is not None:
+                raise ValueError(
+                    f"Cannot add a dataset with obs label {obs} when training datasets have already been added without labels"
+                )
+            if self.labels is not None and obs is None:
+                raise ValueError(
+                    "Cannot add a dataset with no obs label when training datasets have already been added with labels"
+                )
+        if not isinstance(dataset, accepted_types := accepted_on_disk_types):
+            raise TypeError(f"Cannot add a dataset of type {type(dataset)}, only {accepted_types} are allowed")
+        if len(self.train_datasets) > 0 and not isinstance(dataset, self.dataset_type):
+            raise TypeError(
+                f"Cannot add a dataset whose data of type {type(dataset)} was not an instance of expected type {self.dataset_type}"
+            )
+        datasets = self.train_datasets + [dataset]
+        check_var_shapes(datasets)
+        self._shapes = self._shapes + [dataset.shape]
+        self.train_datasets = datasets
+        if self.labels is not None:  # labels exist
+            self.labels += [obs]
+        elif obs is not None:  # labels don't exist yet, but are being added for the first time
+            self.labels = [obs]
+        if self._on_add is not None:
+            self._on_add()
+
+    def _get_relative_obs_indices(self, index: slice, *, use_original_space: bool = False) -> list[tuple[slice, int]]:
+        """Generate slices relative to the individual datasets given a global slice index over all datasets.
+
+        For a given slice indexer of axis 0, return a new slice relative to the on-disk
+        data it represents given the number of total observations, as well as the index of
+        the underlying data on disk from the argument `sparse_datasets` to the initializer.
+
+        For example, given slice index (10, 15), for 4 datasets each with size 5 on axis zero,
+        this function returns ((0, 5), 2) representing slice (0, 5) along axis zero of sparse dataset 2.
+
+        Parameters
+        ----------
+        index
+            The queried slice.
+        use_original_space
+            Whether or not the slices should be reindexed against the anndata objects.
+
+        Returns
+        -------
+        A slice relative to the dataset it represents as well as the index of said dataset in `sparse_datasets`.
+        """
+        min_idx = index.start
+        max_idx = index.stop
+        curr_pos = 0
+        slices = []
+        for idx, (n_obs, _) in enumerate(self._shapes):
+            array_start = curr_pos
+            array_end = curr_pos + n_obs
+
+            start = max(min_idx, array_start)
+            stop = min(max_idx, array_end)
+            if start < stop:
+                if use_original_space:
+                    slices.append((slice(start, stop), idx))
+                else:
+                    relative_start = start - array_start
+                    relative_stop = stop - array_start
+                    slices.append((slice(relative_start, relative_stop), idx))
+            curr_pos += n_obs
+        return slices
+
+    def _slices_to_slices_with_array_index(
+        self, slices: list[slice], *, use_original_space: bool = False
+    ) -> OrderedDict[int, list[slice]]:
+        """Given a list of slices, build the lookup between on-disk datasets and slices relative to each dataset.
+
+        Parameters
+        ----------
+        slices
+            Slices to be made relative to the on-disk datasets.
+        use_original_space
+            Whether or not the slices should be reindexed against the anndata objects.
+
+        Returns
+        -------
+        A lookup between the dataset and its indexing slices, ordered by keys.
+        """
+        dataset_index_to_slices: defaultdict[int, list[slice]] = defaultdict(list)
+        for slice_ in slices:
+            for relative_obs_indices in self._get_relative_obs_indices(slice_, use_original_space=use_original_space):
+                dataset_index_to_slices[relative_obs_indices[1]] += [relative_obs_indices[0]]
+        keys = sorted(dataset_index_to_slices.keys())
+        dataset_index_to_slices_sorted = OrderedDict()
+        for k in keys:
+            dataset_index_to_slices_sorted[k] = dataset_index_to_slices[k]
+        return dataset_index_to_slices_sorted
+
+    def _get_chunks(self, chunk_size: int, worker_handle: WorkerHandle, shuffle: bool) -> np.ndarray:
+        """Get a potentially shuffled list of chunk ids, accounting for the fact that this dataset might be inside a worker.
+
+        Returns
+        -------
+        A :class:`numpy.ndarray` of chunk ids.
+        """
+        chunks = np.arange(math.ceil(self.n_obs / chunk_size))
+        if shuffle:
+            worker_handle.shuffle(chunks)
+
+        return worker_handle.get_part_for_worker(chunks)
+
+    def iter(
+        self,
+        chunk_size: int,
+        worker_handle: WorkerHandle,
+        preload_nchunks: int,
+        shuffle: bool,
+        fetch_data: Callable[[list[slice], int], Awaitable[np.ndarray | CSRContainer]],
+    ) -> Iterator[
+        tuple[OutputInMemoryArray, None | np.ndarray] | tuple[OutputInMemoryArray, None | np.ndarray, np.ndarray]
+    ]:
+        """Iterate over the on-disk datasets.
+
+        Yields
+        ------
+        Batches of in-memory array data, optionally with labels and global indices.
+        """
+        check_lt_1(
+            [len(self.train_datasets), self.n_obs],
+            ["Number of datasets", "Number of observations"],
+        )
+        # In order to handle data returned where (chunk_size * preload_nchunks) mod batch_size != 0,
+        # we must keep track of the leftover data.
+        in_memory_data = None
+        in_memory_labels = None
+        in_memory_indices = None
+        mod = self._sp_module if issubclass(self.dataset_type, ad.abc.CSRDataset) else np
+        for chunk_indices in _batched(self._get_chunks(chunk_size, worker_handle, shuffle), preload_nchunks):
+            slices = [
+                slice(
+                    index * chunk_size,
+                    min(self.n_obs, (index + 1) * chunk_size),
+                )
+                for index in chunk_indices
+            ]
+            dataset_index_to_slices = self._slices_to_slices_with_array_index(slices)
+            # Fetch the data over slices
+            chunks: list[InputInMemoryArray] = zsync.sync(index_datasets(dataset_index_to_slices, fetch_data))
+            if any(isinstance(c, CSRContainer) for c in chunks):
+                chunks_converted: list[OutputInMemoryArray] = [
+                    self._sp_module.csr_matrix(
+                        tuple(self._np_module.asarray(e) for e in c.elems),
+                        shape=c.shape,
+                        dtype="float64" if self._preload_to_gpu else c.dtype,
+                    )
+                    for c in chunks
+                ]
+            else:
+                chunks_converted = [self._np_module.asarray(c) for c in chunks]
+            # Accumulate labels
+            labels: None | list[np.ndarray] = None
+            if self.labels is not None:
+                labels = []
+                for dataset_idx in dataset_index_to_slices.keys():
+                    labels += [
+                        self.labels[dataset_idx][
+                            np.concatenate([np.arange(s.start, s.stop) for s in dataset_index_to_slices[dataset_idx]])
+                        ]
+                    ]
+            # Accumulate indices if necessary
+            indices: None | list[np.ndarray] = None
+            if self._return_index:
+                dataset_index_to_slices = self._slices_to_slices_with_array_index(slices, use_original_space=True)
+                dataset_indices = dataset_index_to_slices.keys()
+                indices = [
+                    np.concatenate(
+                        [
+                            np.arange(
+                                s.start,
+                                s.stop,
+                            )
+                            for s in dataset_index_to_slices[index]
+                        ]
+                    )
+                    for index in dataset_indices
+                ]
+            # Do batch returns, handling leftover data as necessary
+            in_memory_data = (
+                mod.vstack(chunks_converted)
+                if in_memory_data is None
+                else mod.vstack([in_memory_data, *chunks_converted])
+            )
+            if self.labels is not None:
+                in_memory_labels = (
+                    np.concatenate(labels) if in_memory_labels is None else np.concatenate([in_memory_labels, *labels])
+                )
+            if self._return_index:
+                in_memory_indices = (
+                    np.concatenate(indices)
+                    if in_memory_indices is None
+                    else np.concatenate([in_memory_indices, *indices])
+                )
+            # Create random indices into in_memory_data and then index into it.
+            # If there is "leftover" at the end (see the modulo op),
+            # save it for the next iteration.
+            batch_indices = np.arange(in_memory_data.shape[0])
+            if shuffle:
+                np.random.default_rng().shuffle(batch_indices)
+            splits = split_given_size(batch_indices, self._batch_size)
+            for i, s in enumerate(splits):
+                if s.shape[0] == self._batch_size:
+                    res = [
+                        in_memory_data[s],
+                        in_memory_labels[s] if self.labels is not None else None,
+                    ]
+                    if self._return_index:
+                        res += [in_memory_indices[s]]
+                    if self._to_torch:
+                        res[0] = to_torch(res[0], self._preload_to_gpu)
+                    yield tuple(res)
+                if i == (len(splits) - 1):  # end of iteration, leftover data needs to be kept
+                    if (s.shape[0] % self._batch_size) != 0:
+                        in_memory_data = in_memory_data[s]
+                        if in_memory_labels is not None:
+                            in_memory_labels = in_memory_labels[s]
+                        if in_memory_indices is not None:
+                            in_memory_indices = in_memory_indices[s]
+                    else:
+                        in_memory_data = None
+                        in_memory_labels = None
+                        in_memory_indices = None
+        if in_memory_data is not None and not self._drop_last:  # handle any leftover data
+            res = [
+                in_memory_data,
+                in_memory_labels if self.labels is not None else None,
+            ]
+            if self._return_index:
+                res += [in_memory_indices]
+            if self._to_torch:
+                res[0] = to_torch(res[0], self._preload_to_gpu)
+            yield tuple(res)
+
+
+AnnDataManager.add_datasets.__doc__ = add_datasets_docstring
+AnnDataManager.add_dataset.__doc__ = add_dataset_docstring
+AnnDataManager.add_anndatas.__doc__ = add_anndatas_docstring
+AnnDataManager.add_anndata.__doc__ = add_anndata_docstring
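To see the global-to-local slice bookkeeping of `_get_relative_obs_indices` in isolation, here is a small self-contained re-derivation using the docstring's own example (four datasets of five observations each; the shapes are hypothetical):

# Global slice(10, 15) over four stacked datasets of 5 obs each should land
# entirely in dataset 2 as local slice(0, 5).
shapes = [(5, 100)] * 4
index = slice(10, 15)

curr_pos = 0
for idx, (n_obs, _) in enumerate(shapes):
    start = max(index.start, curr_pos)
    stop = min(index.stop, curr_pos + n_obs)
    if start < stop:  # the query overlaps this dataset
        print(idx, slice(start - curr_pos, stop - curr_pos))
    curr_pos += n_obs
# prints: 2 slice(0, 5)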