annbatch-0.0.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


annbatch/__init__.py ADDED
@@ -0,0 +1,15 @@
+ from importlib.metadata import version
+
+ from .dense import ZarrDenseDataset
+ from .io import add_to_collection, create_anndata_collection, write_sharded
+ from .sparse import ZarrSparseDataset
+
+ __version__ = version("annbatch")
+
+ __all__ = [
+     "ZarrSparseDataset",
+     "ZarrDenseDataset",
+     "write_sharded",
+     "add_to_collection",
+     "create_anndata_collection",
+ ]
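
For orientation, here is a minimal usage sketch of the API exported above, based on the docstring example in annbatch/abc.py further down. The store path and group name are illustrative assumptions; `anndata.io.sparse_dataset` and `zarr.open` are the entry points the package's own docstrings recommend for opening on-disk arrays.

    import zarr
    from anndata.io import sparse_dataset
    from annbatch import ZarrSparseDataset

    # Hypothetical zarr store containing a sparse CSR matrix under "X".
    X = sparse_dataset(zarr.open("my_store.zarr")["X"])

    ds = ZarrSparseDataset(
        batch_size=4096,
        chunk_size=32,
        preload_nchunks=512,
    ).add_dataset(X)

    for batch in ds:
        ...  # each yielded batch is a torch.Tensor when to_torch=True
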
annbatch/abc.py ADDED
@@ -0,0 +1,228 @@
+ from __future__ import annotations
+
+ from abc import ABCMeta, abstractmethod
+ from functools import wraps
+ from importlib.util import find_spec
+ from typing import TYPE_CHECKING
+
+ from annbatch.anndata_manager import AnnDataManager
+ from annbatch.utils import (
+     WorkerHandle,
+     add_anndata_docstring,
+     add_anndatas_docstring,
+     add_dataset_docstring,
+     add_datasets_docstring,
+     check_lt_1,
+ )
+
+ if TYPE_CHECKING:
+     from collections.abc import Iterator
+     from typing import Self
+
+     import anndata as ad
+     import numpy as np
+     from torch import Tensor
+
+     from annbatch.types import InputInMemoryArray, OnDiskArray, OutputInMemoryArray
+
+
+ class AbstractIterableDataset[OnDiskArray, InputInMemoryArray](metaclass=ABCMeta):  # noqa: D101
+     _shuffle: bool
+     _preload_nchunks: int
+     _worker_handle: WorkerHandle
+     _chunk_size: int
+     _dataset_manager: AnnDataManager[OnDiskArray, InputInMemoryArray]
+
+     def __init__(
+         self,
+         *,
+         chunk_size: int = 512,
+         preload_nchunks: int = 32,
+         shuffle: bool = True,
+         return_index: bool = False,
+         batch_size: int = 1,
+         preload_to_gpu: bool = True,
+         drop_last: bool = False,
+         to_torch: bool = find_spec("torch") is not None,
+     ):
+         """A loader for on-disk {array_type} data.
+
+         This loader batches together slice requests to the underlying {array_type} stores to achieve higher performance.
+         The custom code that performs this batching will be upstreamed into anndata at some point and will then no longer rely on private zarr APIs.
+         The loader is agnostic to the on-disk chunking/sharding, but it may be advisable to align it with the in-memory chunk size for dense data.
+
+         On its own, the dataset class is quite performant for "chunked loading", i.e., `chunk_size > 1`.
+         When `chunk_size == 1`, a :class:`torch.utils.data.DataLoader` should wrap the dataset object.
+         In that case, do not use `add_anndata` or `add_anndatas` due to https://github.com/scverse/anndata/issues/2021.
+         Instead, use :func:`anndata.io.sparse_dataset` or :func:`zarr.open` to get only the array you need.
+
+         If `preload_to_gpu` is True and `to_torch` is False, the yielded type is a `cupy` matrix.
+         If `to_torch` is True, the yielded type is a :class:`torch.Tensor`.
+         If both `preload_to_gpu` and `to_torch` are False, the return type is the CPU class for {array_type}.
+
+         Parameters
+         ----------
+         chunk_size
+             The obs size (i.e., along axis 0) of contiguous array data to fetch.
+         preload_nchunks
+             The number of chunks of contiguous array data to fetch.
+         shuffle
+             Whether or not to shuffle the data.
+         return_index
+             Whether or not to yield the index on each iteration.
+         batch_size
+             Batch size to yield from the dataset.
+         preload_to_gpu
+             Whether or not to use cupy for non-I/O array operations, like vstack and indexing, once the data is in memory internally.
+             This option entails greater GPU memory usage but is faster, at least for sparse operations.
+             :func:`torch.vstack` does not support CSR sparse matrices, hence the current internal use of cupy.
+             Setting this to `False` is advisable when using the :class:`torch.utils.data.DataLoader` wrapper, or potentially with dense data.
+             For top performance, this should be used in conjunction with `to_torch`, followed by :meth:`torch.Tensor.to_dense` if you wish to densify.
+         drop_last
+             Set to True to drop the last incomplete batch if the dataset size is not divisible by the batch size.
+             If False and the dataset size is not divisible by the batch size, the last batch will be smaller.
+             Leave as False when using in conjunction with a :class:`torch.utils.data.DataLoader`.
+         to_torch
+             Whether to return a :class:`torch.Tensor` as the output.
+             The transfer should be zero-copy regardless of source, and transfer to CUDA, when applicable, is non-blocking.
+             Defaults to True if `torch` is installed.
+
+         Examples
+         --------
+         >>> from annbatch import {child_class}
+         >>> ds = {child_class}(
+         ...     batch_size=4096,
+         ...     chunk_size=32,
+         ...     preload_nchunks=512,
+         ... ).add_anndata(my_anndata)
+         >>> for batch in ds:
+         ...     # optionally convert to dense
+         ...     # batch = batch.to_dense()
+         ...     do_fit(batch)
+         """
+         check_lt_1(
+             [
+                 chunk_size,
+                 preload_nchunks,
+             ],
+             ["Chunk size", "Preload chunks"],
+         )
+         if batch_size > (chunk_size * preload_nchunks):
+             raise NotImplementedError(
+                 "Cannot yield batches bigger than the iterated in-memory size i.e., batch_size > (chunk_size * preload_nchunks)."
+             )
+
+         for package, arg, arg_name in [
+             ("torch", to_torch, f"{to_torch=}"),
+             ("cupy", preload_to_gpu, f"{preload_to_gpu=}"),
+         ]:
+             if arg and not find_spec(package):
+                 raise ImportError(
+                     f"Could not find {package} dependency even though {arg_name}. Try `pip install {package}`"
+                 )
+         self._dataset_manager = AnnDataManager(
+             # TODO: https://github.com/scverse/anndata/issues/2021
+             # on_add=self._cache_update_callback,
+             return_index=return_index,
+             batch_size=batch_size,
+             preload_to_gpu=preload_to_gpu,
+             drop_last=drop_last,
+             to_torch=to_torch,
+         )
+         self._chunk_size = chunk_size
+         self._preload_nchunks = preload_nchunks
+         self._shuffle = shuffle
+         self._worker_handle = WorkerHandle()
+
+     async def _cache_update_callback(self):
+         return None
+
+     @abstractmethod
+     async def _fetch_data(self, slices: list[slice], dataset_idx: int) -> InputInMemoryArray:
+         """Fetch the data for the given slices from the arrays representing a dataset on disk.
+
+         Parameters
+         ----------
+         slices
+             The indexing slices to fetch.
+         dataset_idx
+             The index of the dataset to fetch from.
+
+         Returns
+         -------
+         The in-memory array data.
+         """
+         ...
+
+     # TODO: validations once the sparse and dense are merged with the AnnDataManager
+     def add_anndatas(  # noqa: D102
+         self,
+         adatas: list[ad.AnnData],
+         layer_keys: list[str | None] | str | None = None,
+         obs_keys: list[str] | str | None = None,
+     ) -> Self:
+         self._dataset_manager.add_anndatas(adatas, layer_keys=layer_keys, obs_keys=obs_keys)
+         return self
+
+     def add_anndata(  # noqa: D102
+         self,
+         adata: ad.AnnData,
+         layer_key: str | None = None,
+         obs_key: str | None = None,
+     ) -> Self:
+         self._dataset_manager.add_anndata(adata, layer_key=layer_key, obs_key=obs_key)
+         return self
+
+     @abstractmethod
+     def _validate(self, datasets: list[OnDiskArray]) -> None: ...
+
+     def add_datasets(self, datasets: list[OnDiskArray], obs: list[np.ndarray] | None = None) -> Self:  # noqa: D102
+         self._validate(datasets)
+         self._dataset_manager.add_datasets(datasets, obs)
+         return self
+
+     def add_dataset(self, dataset: OnDiskArray, obs: np.ndarray | None = None) -> Self:  # noqa: D102
+         self._validate([dataset])
+         self._dataset_manager.add_dataset(dataset, obs)
+         return self
+
+     def __len__(self) -> int:
+         return self._dataset_manager.n_obs
+
+     def __iter__(
+         self,
+     ) -> Iterator[
+         tuple[OutputInMemoryArray, None | np.ndarray]
+         | tuple[OutputInMemoryArray | Tensor, None | np.ndarray, np.ndarray]
+     ]:
+         """
+         Iterate over the on-disk datasets, returning :class:`{gpu_array}` or :class:`{cpu_array}` depending on whether or not `preload_to_gpu` is set.
+
+         Will convert to a :class:`torch.Tensor` if `to_torch` is True.
+
+         Yields
+         ------
+         An in-memory array, optionally with its label and location in the global store.
+         """
+         yield from self._dataset_manager.iter(
+             self._chunk_size,
+             self._worker_handle,
+             self._preload_nchunks,
+             self._shuffle,
+             self._fetch_data,
+         )
+
+
+ AbstractIterableDataset.add_dataset.__doc__ = add_dataset_docstring
+ AbstractIterableDataset.add_datasets.__doc__ = add_datasets_docstring
+ AbstractIterableDataset.add_anndata.__doc__ = add_anndata_docstring
+ AbstractIterableDataset.add_anndatas.__doc__ = add_anndatas_docstring
+
+
+ def _assign_methods_to_ensure_unique_docstrings(typ):
+     """Copy the listed methods onto `typ`: both children of AbstractIterableDataset inherit but do not override them, so they must be copied to give each child unique docstrings."""
+     for name in ["add_datasets", "add_dataset", "add_anndatas", "add_anndata", "__init__", "__iter__"]:
+
+         @wraps(getattr(AbstractIterableDataset, name))
+         def func(self, *args, name=name, **kwargs):
+             return getattr(super(typ, self), name)(*args, **kwargs)
+
+         setattr(typ, name, func)
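
The `_assign_methods_to_ensure_unique_docstrings` helper above exists so that each subclass receives its own copy of the inherited methods, whose `__doc__` can then be set independently. A standalone sketch of the same technique (the class and method names here are illustrative, not from the package):

    from functools import wraps

    class Base:
        def add(self, x):
            """Base docstring."""
            return x

    class Child(Base):
        pass

    # Copy the inherited method onto Child; @wraps carries over the original
    # metadata, after which Child's copy can receive its own docstring
    # without affecting Base.add.
    @wraps(Base.add)
    def _add(self, *args, **kwargs):
        return Base.add(self, *args, **kwargs)

    Child.add = _add
    Child.add.__doc__ = "Child-specific docstring."

    assert Base.add.__doc__ == "Base docstring."
    assert Child.add.__doc__ == "Child-specific docstring."

Note that the package's helper binds the loop variable through a `name=name` default argument so each generated wrapper dispatches to the correct method rather than to whichever name the loop saw last.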
annbatch/anndata_manager.py ADDED
@@ -0,0 +1,396 @@
+ from __future__ import annotations
+
+ import math
+ from collections import OrderedDict, defaultdict
+ from types import NoneType
+ from typing import TYPE_CHECKING, cast
+
+ import anndata as ad
+ import numpy as np
+ import zarr.core.sync as zsync
+ from scipy import sparse as sp
+
+ from annbatch.types import InputInMemoryArray, OnDiskArray, OutputInMemoryArray
+ from annbatch.utils import (
+     CSRContainer,
+     WorkerHandle,
+     _batched,
+     add_anndata_docstring,
+     add_anndatas_docstring,
+     add_dataset_docstring,
+     add_datasets_docstring,
+     check_lt_1,
+     check_var_shapes,
+     index_datasets,
+     split_given_size,
+     to_torch,
+ )
+
+ try:
+     from cupy import ndarray as CupyArray
+     from cupyx.scipy.sparse import csr_matrix as CupyCSRMatrix  # pragma: no cover
+ except ImportError:
+     CupyCSRMatrix = NoneType
+     CupyArray = NoneType
+
+ if TYPE_CHECKING:
+     from collections.abc import Awaitable, Callable, Iterator
+     from types import ModuleType
+
+ accepted_on_disk_types = OnDiskArray.__constraints__
+
+
+ class AnnDataManager[OnDiskArray, InputInMemoryArray]:  # noqa: D101
+     train_datasets: list[OnDiskArray] = []
+     labels: list[np.ndarray] | None = None
+     _return_index: bool = False
+     _on_add: Callable | None = None
+     _batch_size: int = 1
+     _shapes: list[tuple[int, int]] = []
+     _preload_to_gpu: bool = True
+     _drop_last: bool = False
+     _to_torch: bool = True
+     _used_anndata_adder: bool = False
+
+     def __init__(
+         self,
+         *,
+         on_add: Callable | None = None,
+         return_index: bool = False,
+         batch_size: int = 1,
+         preload_to_gpu: bool = True,
+         drop_last: bool = False,
+         to_torch: bool = True,
+     ):
+         self._on_add = on_add
+         self._return_index = return_index
+         self._batch_size = batch_size
+         self._preload_to_gpu = preload_to_gpu
+         self._to_torch = to_torch
+         self._drop_last = drop_last
+
+     @property
+     def _sp_module(self) -> ModuleType:
+         if self._preload_to_gpu:
+             try:
+                 import cupyx.scipy.sparse as cpx  # pragma: no cover
+
+                 return cpx
+             except ImportError:
+                 raise ImportError(
+                     "Cannot find cupy module even though `preload_to_gpu` argument was set to `True`"
+                 ) from None
+         return sp
+
+     @property
+     def _np_module(self) -> ModuleType:
+         if self._preload_to_gpu:
+             try:
+                 import cupy as cp
+
+                 return cp
+             except ImportError:
+                 raise ImportError(
+                     "Cannot find cupy module even though `preload_to_gpu` argument was set to `True`"
+                 ) from None
+
+         return np
+
+     @property
+     def dataset_type(self) -> type[OnDiskArray]:  # noqa: D102
+         return type(self.train_datasets[0])
+
+     @property
+     def n_obs(self) -> int:  # noqa: D102
+         return sum(shape[0] for shape in self._shapes)
+
+     @property
+     def n_var(self) -> int:  # noqa: D102
+         return self._shapes[0][1]
+
+
111
+ def add_anndatas( # noqa: D102
112
+ self,
113
+ adatas: list[ad.AnnData],
114
+ layer_keys: list[str | None] | str | None = None,
115
+ obs_keys: list[str] | str | None = None,
116
+ ) -> None:
117
+ self._used_anndata_adder = True
118
+ if isinstance(layer_keys, str | None):
119
+ layer_keys = [layer_keys] * len(adatas)
120
+ if isinstance(obs_keys, str | None):
121
+ obs_keys = [obs_keys] * len(adatas)
122
+ elem_to_keys = dict(zip(["layer", "obs"], [layer_keys, obs_keys], strict=True))
123
+ check_lt_1(
124
+ [len(adatas)] + sum((([len(k)] if k is not None else []) for k in elem_to_keys.values()), []),
125
+ ["Number of anndatas"]
126
+ + sum(
127
+ ([f"Number of {label} keys"] if keys is not None else [] for keys, label in elem_to_keys.items()),
128
+ [],
129
+ ),
130
+ )
131
+ for adata, obs_key, layer_key in zip(adatas, obs_keys, layer_keys, strict=True):
132
+ kwargs = {"obs_key": obs_key, "layer_key": layer_key}
133
+ self.add_anndata(adata, **kwargs)
+
+     def add_anndata(  # noqa: D102
+         self,
+         adata: ad.AnnData,
+         layer_key: str | None = None,
+         obs_key: str | None = None,
+     ) -> None:
+         self._used_anndata_adder = True
+         dataset = adata.X if layer_key is None else adata.layers[layer_key]
+         if not isinstance(dataset, accepted_on_disk_types):
+             raise TypeError(f"Found {type(dataset)} but only {accepted_on_disk_types} are usable")
+         obs = adata.obs[obs_key].to_numpy() if obs_key is not None else None
+         self.add_dataset(cast("OnDiskArray", dataset), obs)
+
+     def add_datasets(self, datasets: list[OnDiskArray], obs: list[np.ndarray] | None = None) -> None:  # noqa: D102
+         if obs is None:
+             obs = [None] * len(datasets)
+         for ds, o in zip(datasets, obs, strict=True):
+             self.add_dataset(ds, o)
+
+     def add_dataset(self, dataset: OnDiskArray, obs: np.ndarray | None = None) -> None:  # noqa: D102
+         if len(self.train_datasets) > 0:
+             if self.labels is None and obs is not None:
+                 raise ValueError(
+                     f"Cannot add a dataset with obs label {obs} when training datasets have already been added without labels"
+                 )
+             if self.labels is not None and obs is None:
+                 raise ValueError(
+                     "Cannot add a dataset with no obs label when training datasets have already been added with labels"
+                 )
+         if not isinstance(dataset, accepted_types := accepted_on_disk_types):
+             raise TypeError(f"Cannot add a dataset of type {type(dataset)}, only {accepted_types} are allowed")
+         if len(self.train_datasets) > 0 and not isinstance(dataset, self.dataset_type):
+             raise TypeError(
+                 f"Cannot add a dataset whose data of type {type(dataset)} is not an instance of the expected type {self.dataset_type}"
+             )
+         datasets = self.train_datasets + [dataset]
+         check_var_shapes(datasets)
+         self._shapes = self._shapes + [dataset.shape]
+         self.train_datasets = datasets
+         if self.labels is not None:  # labels exist
+             self.labels += [obs]
+         elif obs is not None:  # labels don't exist yet, but are being added for the first time
+             self.labels = [obs]
+         if self._on_add is not None:
+             self._on_add()
+
+     def _get_relative_obs_indices(self, index: slice, *, use_original_space: bool = False) -> list[tuple[slice, int]]:
+         """Generate slices relative to the individual datasets given a global slice index over all datasets.
+
+         For a given slice indexer of axis 0, return new slices relative to the on-disk
+         data they represent, given the total number of observations, along with the index of
+         the underlying dataset within this manager's `train_datasets`.
+
+         For example, given the slice index (10, 15) and 4 datasets each of size 5 along axis 0,
+         this function returns [(slice(0, 5), 2)], representing slice (0, 5) along axis 0 of dataset 2.
+
+         Parameters
+         ----------
+         index
+             The queried slice.
+         use_original_space
+             Whether or not the slices should be reindexed against the anndata objects.
+
+         Returns
+         -------
+         Slices relative to the datasets they represent, each with the index of said dataset in `train_datasets`.
+         """
+         min_idx = index.start
+         max_idx = index.stop
+         curr_pos = 0
+         slices = []
+         for idx, (n_obs, _) in enumerate(self._shapes):
+             array_start = curr_pos
+             array_end = curr_pos + n_obs
+
+             start = max(min_idx, array_start)
+             stop = min(max_idx, array_end)
+             if start < stop:
+                 if use_original_space:
+                     slices.append((slice(start, stop), idx))
+                 else:
+                     relative_start = start - array_start
+                     relative_stop = stop - array_start
+                     slices.append((slice(relative_start, relative_stop), idx))
+             curr_pos += n_obs
+         return slices
+
+     def _slices_to_slices_with_array_index(
+         self, slices: list[slice], *, use_original_space: bool = False
+     ) -> OrderedDict[int, list[slice]]:
+         """Given a list of slices, build the lookup between on-disk datasets and the slices relative to each dataset.
+
+         Parameters
+         ----------
+         slices
+             Global slices to map onto the on-disk datasets.
+         use_original_space
+             Whether or not the slices should be reindexed against the anndata objects.
+
+         Returns
+         -------
+         A lookup between each dataset and its indexing slices, ordered by keys.
+         """
+         dataset_index_to_slices: defaultdict[int, list[slice]] = defaultdict(list)
+         for sl in slices:
+             for relative_obs_indices in self._get_relative_obs_indices(sl, use_original_space=use_original_space):
+                 dataset_index_to_slices[relative_obs_indices[1]] += [relative_obs_indices[0]]
+         keys = sorted(dataset_index_to_slices.keys())
+         dataset_index_to_slices_sorted = OrderedDict()
+         for k in keys:
+             dataset_index_to_slices_sorted[k] = dataset_index_to_slices[k]
+         return dataset_index_to_slices_sorted
+
+     def _get_chunks(self, chunk_size: int, worker_handle: WorkerHandle, shuffle: bool) -> np.ndarray:
+         """Get a potentially shuffled list of chunk ids, accounting for the fact that this dataset might be inside a worker.
+
+         Returns
+         -------
+         A :class:`numpy.ndarray` of chunk ids.
+         """
+         chunks = np.arange(math.ceil(self.n_obs / chunk_size))
+         if shuffle:
+             worker_handle.shuffle(chunks)
+
+         return worker_handle.get_part_for_worker(chunks)
+
+     def iter(
+         self,
+         chunk_size: int,
+         worker_handle: WorkerHandle,
+         preload_nchunks: int,
+         shuffle: bool,
+         fetch_data: Callable[[list[slice], int], Awaitable[np.ndarray | CSRContainer]],
+     ) -> Iterator[
+         tuple[OutputInMemoryArray, None | np.ndarray] | tuple[OutputInMemoryArray, None | np.ndarray, np.ndarray]
+     ]:
+         """Iterate over the on-disk datasets.
+
+         Yields
+         ------
+         A batch of in-memory data, optionally with its labels and its indices into the global store.
+         """
+         check_lt_1(
+             [len(self.train_datasets), self.n_obs],
+             ["Number of datasets", "Number of observations"],
+         )
+         # In order to handle data returned where (chunk_size * preload_nchunks) mod batch_size != 0,
+         # we must keep track of the leftover data.
+         in_memory_data = None
+         in_memory_labels = None
+         in_memory_indices = None
+         mod = self._sp_module if issubclass(self.dataset_type, ad.abc.CSRDataset) else np
+         for chunk_indices in _batched(self._get_chunks(chunk_size, worker_handle, shuffle), preload_nchunks):
+             slices = [
+                 slice(
+                     index * chunk_size,
+                     min(self.n_obs, (index + 1) * chunk_size),
+                 )
+                 for index in chunk_indices
+             ]
+             dataset_index_to_slices = self._slices_to_slices_with_array_index(slices)
+             # Fetch the data over slices
+             chunks: list[InputInMemoryArray] = zsync.sync(index_datasets(dataset_index_to_slices, fetch_data))
+             if any(isinstance(c, CSRContainer) for c in chunks):
+                 chunks_converted: list[OutputInMemoryArray] = [
+                     self._sp_module.csr_matrix(
+                         tuple(self._np_module.asarray(e) for e in c.elems),
+                         shape=c.shape,
+                         dtype="float64" if self._preload_to_gpu else c.dtype,
+                     )
+                     for c in chunks
+                 ]
+             else:
+                 chunks_converted = [self._np_module.asarray(c) for c in chunks]
+             # Accumulate labels
+             labels: None | list[np.ndarray] = None
+             if self.labels is not None:
+                 labels = []
+                 for dataset_idx in dataset_index_to_slices.keys():
+                     labels += [
+                         self.labels[dataset_idx][
+                             np.concatenate([np.arange(s.start, s.stop) for s in dataset_index_to_slices[dataset_idx]])
+                         ]
+                     ]
+             # Accumulate indices if necessary
+             indices: None | list[np.ndarray] = None
+             if self._return_index:
+                 dataset_index_to_slices = self._slices_to_slices_with_array_index(slices, use_original_space=True)
+                 dataset_indices = dataset_index_to_slices.keys()
+                 indices = [
+                     np.concatenate(
+                         [
+                             np.arange(
+                                 s.start,
+                                 s.stop,
+                             )
+                             for s in dataset_index_to_slices[index]
+                         ]
+                     )
+                     for index in dataset_indices
+                 ]
+             # Do batch returns, handling leftover data as necessary
+             in_memory_data = (
+                 mod.vstack(chunks_converted)
+                 if in_memory_data is None
+                 else mod.vstack([in_memory_data, *chunks_converted])
+             )
+             if self.labels is not None:
+                 in_memory_labels = (
+                     np.concatenate(labels) if in_memory_labels is None else np.concatenate([in_memory_labels, *labels])
+                 )
+             if self._return_index:
+                 in_memory_indices = (
+                     np.concatenate(indices)
+                     if in_memory_indices is None
+                     else np.concatenate([in_memory_indices, *indices])
+                 )
+             # Create random indices into in_memory_data and then index into it.
+             # If there is "leftover" at the end (see the modulo op),
+             # save it for the next iteration.
+             batch_indices = np.arange(in_memory_data.shape[0])
+             if shuffle:
+                 np.random.default_rng().shuffle(batch_indices)
+             splits = split_given_size(batch_indices, self._batch_size)
+             for i, s in enumerate(splits):
+                 if s.shape[0] == self._batch_size:
+                     res = [
+                         in_memory_data[s],
+                         in_memory_labels[s] if self.labels is not None else None,
+                     ]
+                     if self._return_index:
+                         res += [in_memory_indices[s]]
+                     if self._to_torch:
+                         res[0] = to_torch(res[0], self._preload_to_gpu)
+                     yield tuple(res)
+                 if i == (len(splits) - 1):  # end of iteration, leftover data needs to be kept
+                     if (s.shape[0] % self._batch_size) != 0:
+                         in_memory_data = in_memory_data[s]
+                         if in_memory_labels is not None:
+                             in_memory_labels = in_memory_labels[s]
+                         if in_memory_indices is not None:
+                             in_memory_indices = in_memory_indices[s]
+                     else:
+                         in_memory_data = None
+                         in_memory_labels = None
+                         in_memory_indices = None
+         if in_memory_data is not None and not self._drop_last:  # handle any leftover data
+             res = [
+                 in_memory_data,
+                 in_memory_labels if self.labels is not None else None,
+             ]
+             if self._return_index:
+                 res += [in_memory_indices]
+             if self._to_torch:
+                 res[0] = to_torch(res[0], self._preload_to_gpu)
+             yield tuple(res)
+
+
+ AnnDataManager.add_datasets.__doc__ = add_datasets_docstring
+ AnnDataManager.add_dataset.__doc__ = add_dataset_docstring
+ AnnDataManager.add_anndatas.__doc__ = add_anndatas_docstring
+ AnnDataManager.add_anndata.__doc__ = add_anndata_docstring
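
To make the global-to-relative slice mapping of `_get_relative_obs_indices` concrete, here is a small standalone re-implementation reproducing the worked example from its docstring (an independent sketch, not the package's code):

    def relative_obs_indices(index: slice, shapes: list[tuple[int, int]]) -> list[tuple[slice, int]]:
        """Map a global axis-0 slice onto (relative slice, dataset index) pairs."""
        out = []
        curr_pos = 0
        for idx, (n_obs, _) in enumerate(shapes):
            # Intersect the global slice with this dataset's global extent.
            start = max(index.start, curr_pos)
            stop = min(index.stop, curr_pos + n_obs)
            if start < stop:
                out.append((slice(start - curr_pos, stop - curr_pos), idx))
            curr_pos += n_obs
        return out

    # Four datasets of 5 observations each; the global slice (10, 15)
    # falls entirely inside dataset 2 as its local slice (0, 5).
    shapes = [(5, 3)] * 4
    assert relative_obs_indices(slice(10, 15), shapes) == [(slice(0, 5), 2)]

    # A slice straddling a dataset boundary is split across datasets 0 and 1.
    assert relative_obs_indices(slice(3, 7), shapes) == [(slice(3, 5), 0), (slice(0, 2), 1)]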