annbatch-0.0.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.



annbatch/sparse.py ADDED
@@ -0,0 +1,160 @@
+ from __future__ import annotations
+
+ import asyncio
+ from importlib.util import find_spec
+ from itertools import accumulate, chain, pairwise
+ from typing import NamedTuple, cast
+
+ import anndata as ad
+ import numpy as np
+ import zarr
+ import zarr.core.sync as zsync
+
+ if find_spec("torch"):
+     from torch.utils.data import IterableDataset as _IterableDataset
+ else:
+
+     class _IterableDataset:
+         pass
+
+
+ from annbatch.abc import AbstractIterableDataset, _assign_methods_to_ensure_unique_docstrings
+ from annbatch.utils import (
+     CSRContainer,
+     MultiBasicIndexer,
+     add_anndata_docstring,
+     add_anndatas_docstring,
+     add_dataset_docstring,
+     add_datasets_docstring,
+ )
+
+
+ class CSRDatasetElems(NamedTuple):
+     """Container for cached objects that will be indexed into to generate CSR matrices"""
+
+     indptr: np.ndarray
+     indices: zarr.AsyncArray
+     data: zarr.AsyncArray
+
+
+ class ZarrSparseDataset(  # noqa: D101
+     AbstractIterableDataset[ad.abc.CSRDataset, CSRContainer], _IterableDataset
+ ):
+     _dataset_elem_cache: dict[int, CSRDatasetElems] = {}
+
+     def _cache_update_callback(self):
+         """Callback for when datasets are added to ensure the cache is updated."""
+         return zsync.sync(self._ensure_cache())
+
+     def _validate(self, datasets: list[ad.abc.CSRDataset]):
+         if not all(isinstance(d, ad.abc.CSRDataset) for d in datasets):
+             raise TypeError("Cannot create sparse dataset using non-CSRDataset data")
+         if not all(cast("ad.abc.CSRDataset", d).backend == "zarr" for d in datasets):
+             raise TypeError(
+                 "Cannot use CSRDataset backed by h5ad at the moment: see https://github.com/zarr-developers/VirtualiZarr/pull/790"
+             )
+
+     async def _create_sparse_elems(self, idx: int) -> CSRDatasetElems:
+         """Fetch the in-memory indptr, and backed indices and data for a given dataset index.
+
+         Parameters
+         ----------
+         idx
+             The index
+
+         Returns
+         -------
+         The constituent elems of the CSR dataset.
+         """
+         indptr = await self._dataset_manager.train_datasets[idx].group._async_group.getitem("indptr")
+         return CSRDatasetElems(
+             *(
+                 await asyncio.gather(
+                     indptr.getitem(Ellipsis),
+                     self._dataset_manager.train_datasets[idx].group._async_group.getitem("indices"),
+                     self._dataset_manager.train_datasets[idx].group._async_group.getitem("data"),
+                 )
+             )
+         )
+
+     async def _ensure_cache(self):
+         """Build up the cache of datasets i.e., in-memory indptr, and backed indices and data."""
+         arr_idxs = [
+             idx for idx in range(len(self._dataset_manager.train_datasets)) if idx not in self._dataset_elem_cache
+         ]
+         all_elems = await asyncio.gather(
+             *(
+                 self._create_sparse_elems(idx)
+                 for idx in range(len(self._dataset_manager.train_datasets))
+                 if idx not in self._dataset_elem_cache
+             )
+         )
+         for idx, elems in zip(arr_idxs, all_elems, strict=True):
+             self._dataset_elem_cache[idx] = elems
+
+     async def _get_sparse_elems(self, dataset_idx: int) -> CSRDatasetElems:
+         """Return the arrays (zarr or otherwise) needed to represent on-disk data at a given index.
+
+         Parameters
+         ----------
+         dataset_idx
+             The index of the dataset whose arrays are sought.
+
+         Returns
+         -------
+         The arrays representing the sparse data.
+         """
+         if dataset_idx not in self._dataset_elem_cache:
+             await self._ensure_cache()
+         return self._dataset_elem_cache[dataset_idx]
+
+     async def _fetch_data(
+         self,
+         slices: list[slice],
+         dataset_idx: int,
+     ) -> CSRContainer:
+         # See https://github.com/scverse/anndata/blob/361325fc621887bf4f381e9412b150fcff599ff7/src/anndata/_core/sparse_dataset.py#L272-L295
+         # for the inspiration of this function.
+         indptr, indices, data = await self._get_sparse_elems(dataset_idx)
+         indptr_indices = [indptr[slice(s.start, s.stop + 1)] for s in slices]
+         indptr_limits = [slice(i[0], i[-1]) for i in indptr_indices]
+         indexer = MultiBasicIndexer(
+             [
+                 zarr.core.indexing.BasicIndexer((l,), shape=data.metadata.shape, chunk_grid=data.metadata.chunk_grid)
+                 for l in indptr_limits
+             ]
+         )
+         data_np, indices_np = await asyncio.gather(
+             data._get_selection(indexer, prototype=zarr.core.buffer.default_buffer_prototype()),
+             indices._get_selection(indexer, prototype=zarr.core.buffer.default_buffer_prototype()),
+         )
+         gaps = (s1.start - s0.stop for s0, s1 in pairwise(indptr_limits))
+         offsets = accumulate(chain([indptr_limits[0].start], gaps))
+         start_indptr = indptr_indices[0] - next(offsets)
+         if len(slices) < 2:  # there is only one slice so no need to concatenate
+             return CSRContainer(
+                 elems=(data_np, indices_np, start_indptr),
+                 shape=(start_indptr.shape[0] - 1, self._dataset_manager.n_var),
+                 dtype=data_np.dtype,
+             )
+         end_indptr = np.concatenate([s[1:] - o for s, o in zip(indptr_indices[1:], offsets, strict=True)])
+         indptr_np = np.concatenate([start_indptr, end_indptr])
+         return CSRContainer(
+             elems=(data_np, indices_np, indptr_np),
+             shape=(indptr_np.shape[0] - 1, self._dataset_manager.n_var),
+             dtype=data_np.dtype,
+         )
+
+
+ _assign_methods_to_ensure_unique_docstrings(ZarrSparseDataset)
+
+ ZarrSparseDataset.__doc__ = AbstractIterableDataset.__init__.__doc__.format(
+     array_type="sparse", child_class="ZarrSparseDataset"
+ )
+ ZarrSparseDataset.add_datasets.__doc__ = add_datasets_docstring.format(on_disk_array_type="anndata.abc.CSRDataset")
+ ZarrSparseDataset.add_dataset.__doc__ = add_dataset_docstring.format(on_disk_array_type="anndata.abc.CSRDataset")
+ ZarrSparseDataset.add_anndatas.__doc__ = add_anndatas_docstring.format(on_disk_array_type="anndata.abc.CSRDataset")
+ ZarrSparseDataset.add_anndata.__doc__ = add_anndata_docstring.format(on_disk_array_type="anndata.abc.CSRDataset")
+ ZarrSparseDataset.__iter__.__doc__ = AbstractIterableDataset.__iter__.__doc__.format(
+     gpu_array="cupyx.scipy.sparse.spmatrix", cpu_array="scipy.sparse.csr_matrix"
+ )
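
For orientation, `_fetch_data` stitches several requested row ranges of one on-disk CSR matrix into a single in-memory CSR: it slices the cached `indptr` in memory, fetches only the matching ranges of `indices` and `data` through zarr, and then shifts each slice's indptr so the pieces line up. The following standalone NumPy/SciPy sketch mirrors the same indptr arithmetic on an in-memory matrix; it is an illustration only, not part of the package, and the helper name `take_row_slices` is hypothetical (in-memory slicing stands in for the zarr reads).

from itertools import accumulate, chain, pairwise

import numpy as np
import scipy.sparse as sp


def take_row_slices(x: sp.csr_matrix, slices: list[slice]) -> sp.csr_matrix:
    """Concatenate several row slices of a CSR matrix using only indptr arithmetic."""
    indptr_indices = [x.indptr[s.start : s.stop + 1] for s in slices]
    # nnz range in data/indices covered by each row slice
    limits = [slice(int(i[0]), int(i[-1])) for i in indptr_indices]
    data = np.concatenate([x.data[l] for l in limits])
    indices = np.concatenate([x.indices[l] for l in limits])
    # shift each slice's indptr so the pieces line up after concatenation
    gaps = (s1.start - s0.stop for s0, s1 in pairwise(limits))
    offsets = accumulate(chain([limits[0].start], gaps))
    start_indptr = indptr_indices[0] - next(offsets)
    if len(slices) < 2:
        indptr = start_indptr
    else:
        end_indptr = np.concatenate([i[1:] - o for i, o in zip(indptr_indices[1:], offsets, strict=True)])
        indptr = np.concatenate([start_indptr, end_indptr])
    n_rows = sum(s.stop - s.start for s in slices)
    return sp.csr_matrix((data, indices, indptr), shape=(n_rows, x.shape[1]))


# take_row_slices(m, [slice(0, 2), slice(5, 8)]) selects rows 0, 1, 5, 6, 7 of m.
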
annbatch/types.py ADDED
@@ -0,0 +1,25 @@
+ from __future__ import annotations
+
+ from types import NoneType
+ from typing import TypeVar
+
+ import anndata as ad
+ import numpy as np
+ import zarr
+ from scipy import sparse as sp
+
+ from annbatch.utils import CSRContainer
+
+ try:
+     from cupy import ndarray as CupyArray
+     from cupyx.scipy.sparse import csr_matrix as CupyCSRMatrix  # pragma: no cover
+ except ImportError:
+     CupyCSRMatrix = NoneType
+     CupyArray = NoneType
+
+ OutputInMemoryArray = sp.csr_matrix | np.ndarray | CupyCSRMatrix | CupyArray
+
+ OnDiskArray = TypeVar("OnDiskArray", ad.abc.CSRDataset, zarr.Array)
+
+
+ InputInMemoryArray = TypeVar("InputInMemoryArray", CSRContainer, np.ndarray)
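
The `NoneType` fallback above keeps both the `OutputInMemoryArray` union and any downstream `isinstance` checks valid when cupy is not installed, since `NoneType` is a real class while `None` is not. A minimal standalone illustration of the pattern (the function `is_cupy_array` is hypothetical, not part of the package):

from types import NoneType

try:
    from cupy import ndarray as CupyArray  # only available when cupy is installed
except ImportError:
    # Binding the name to NoneType keeps isinstance(x, CupyArray) legal;
    # it then matches nothing except None itself.
    CupyArray = NoneType


def is_cupy_array(x: object) -> bool:
    """Return True only for cupy ndarrays; always False when cupy is absent."""
    return x is not None and isinstance(x, CupyArray)
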
annbatch/utils.py ADDED
@@ -0,0 +1,319 @@
+ from __future__ import annotations
+
+ import asyncio
+ import warnings
+ from dataclasses import dataclass
+ from functools import cached_property
+ from importlib.util import find_spec
+ from itertools import islice
+ from types import NoneType
+ from typing import TYPE_CHECKING, Protocol
+
+ import numpy as np
+ import scipy as sp
+ import zarr
+
+ try:
+     from cupy import ndarray as CupyArray
+     from cupyx.scipy.sparse import csr_matrix as CupyCSRMatrix  # pragma: no cover
+ except ImportError:
+     # NoneType (rather than None) keeps the isinstance checks in `to_torch` valid when cupy is absent.
+     CupyArray = NoneType
+     CupyCSRMatrix = NoneType
+
+ if TYPE_CHECKING:
+     from collections import OrderedDict
+     from collections.abc import Awaitable, Callable, Generator, Iterable
+
+     from torch import Tensor
+
+     from annbatch.types import InputInMemoryArray, OutputInMemoryArray
+
+
+ def split_given_size(a: np.ndarray, size: int) -> list[np.ndarray]:
+     """Wrapper around `np.split` to split an array into chunks of length `size` (the last chunk may be shorter)."""
+     return np.split(a, np.arange(size, len(a), size))
+
+
+ @dataclass
+ class CSRContainer:
+     """A low-cost container for moving around the buffers of a CSR object"""
+
+     elems: tuple[np.ndarray, np.ndarray, np.ndarray]
+     shape: tuple[int, int]
+     dtype: np.dtype
+
+
+ def _batched[T](iterable: Iterable[T], n: int) -> Generator[list[T], None, None]:
+     if n < 1:
+         raise ValueError("n must be >= 1")
+     it = iter(iterable)
+     while batch := list(islice(it, n)):
+         yield batch
+
+
+ async def index_datasets(
+     dataset_index_to_slices: OrderedDict[int, list[slice]],
+     fetch_data: Callable[[list[slice], int], Awaitable[CSRContainer | np.ndarray]],
+ ) -> list[InputInMemoryArray]:
+     """Helper function meant to encapsulate asynchronous calls so that we can use the same event loop as zarr.
+
+     Parameters
+     ----------
+     dataset_index_to_slices
+         A lookup from the list-placement index of a dataset to the requested slices.
+     fetch_data
+         The function to do the fetching for a given slice-dataset index pair.
+     """
+     tasks = []
+     for dataset_idx in dataset_index_to_slices.keys():
+         tasks.append(
+             fetch_data(
+                 dataset_index_to_slices[dataset_idx],
+                 dataset_idx,
+             )
+         )
+     return await asyncio.gather(*tasks)
+
+
+ add_datasets_docstring = """\
+ Append datasets to this dataset.
+
+ Parameters
+ ----------
+ datasets
+     List of :class:`{on_disk_array_type}` objects, generally from :attr:`anndata.AnnData.X`.
+ obs
+     List of :class:`numpy.ndarray` labels, generally from :attr:`anndata.AnnData.obs`.
+ """
+
+ add_dataset_docstring = """\
+ Append a dataset to this dataset.
+
+ Parameters
+ ----------
+ dataset
+     :class:`{on_disk_array_type}` object, generally from :attr:`anndata.AnnData.X`.
+ obs
+     :class:`numpy.ndarray` labels for the anndata, generally from :attr:`anndata.AnnData.obs`.
+ """
+
+
+ add_anndatas_docstring = """\
+ Append anndatas to this dataset.
+
+ Parameters
+ ----------
+ anndatas
+     List of :class:`anndata.AnnData` objects, with :class:`{on_disk_array_type}` as the data matrix.
+ obs_keys
+     List of :attr:`anndata.AnnData.obs` column labels.
+ layer_keys
+     List of :attr:`anndata.AnnData.layers` keys; if None, :attr:`anndata.AnnData.X` will be used.
+ """
+
+ add_anndata_docstring = """\
+ Append an anndata to this dataset.
+
+ Parameters
+ ----------
+ anndata
+     :class:`anndata.AnnData` object, with :class:`{on_disk_array_type}` as the data matrix.
+ obs_key
+     :attr:`anndata.AnnData.obs` column label.
+ layer_key
+     :attr:`anndata.AnnData.layers` key; if None, :attr:`anndata.AnnData.X` will be used.
+ """
+
+
+ # TODO: make this part of the public zarr or zarrs-python API.
+ # We can do chunk coalescing in zarrs based on integer arrays, so I think
+ # that would make sense here with ezclump or similar.
+ # Another "solution" would be for zarrs to support integer indexing properly, if that pipeline works,
+ # or to make this an "experimental setting" and use integer indexing for the zarr-python pipeline.
+ # See: https://github.com/zarr-developers/zarr-python/issues/3175 for why this is better than simpler alternatives.
+ class MultiBasicIndexer(zarr.core.indexing.Indexer):
+     """Custom indexer to enable joint fetching of disparate slices"""
+
+     def __init__(self, indexers: list[zarr.core.indexing.Indexer]):
+         self.shape = (sum(i.shape[0] for i in indexers), *indexers[0].shape[1:])
+         self.drop_axes = indexers[0].drop_axes  # maybe?
+         self.indexers = indexers
+
+     def __iter__(self):
+         total = 0
+         for i in self.indexers:
+             for c in i:
+                 out_selection = c[2]
+                 gap = out_selection[0].stop - out_selection[0].start
+                 yield type(c)(c[0], c[1], (slice(total, total + gap), *out_selection[1:]), c[3])
+                 total += gap
+
+
+ def sample_rows(
+     x_list: list[np.ndarray],
+     obs_list: list[np.ndarray] | None,
+     indices: list[np.ndarray] | None = None,
+     *,
+     shuffle: bool = True,
+ ) -> Generator[tuple[np.ndarray, np.ndarray | None], None, None]:
+     """Sample rows from multiple arrays and their corresponding observation arrays.
+
+     Parameters
+     ----------
+     x_list
+         A list of numpy arrays containing the data to sample from.
+     obs_list
+         A list of numpy arrays containing the corresponding observations.
+     indices
+         The list of indices for each element in `x_list`.
+     shuffle
+         Whether to shuffle the rows before sampling.
+
+     Yields
+     ------
+     tuple
+         A tuple containing a row from `x_list` and the corresponding row from `obs_list`.
+     """
+     lengths = np.fromiter((x.shape[0] for x in x_list), dtype=int)
+     cum = np.concatenate(([0], np.cumsum(lengths)))
+     total = cum[-1]
+     idxs = np.arange(total)
+     if shuffle:
+         np.random.default_rng().shuffle(idxs)
+     arr_idxs = np.searchsorted(cum, idxs, side="right") - 1
+     row_idxs = idxs - cum[arr_idxs]
+     for ai, ri in zip(arr_idxs, row_idxs, strict=True):
+         res = [
+             x_list[ai][ri],
+             obs_list[ai][ri] if obs_list is not None else None,
+         ]
+         if indices is not None:
+             yield (*res, indices[ai][ri])
+         else:
+             yield tuple(res)
+
+
+ class WorkerHandle:  # noqa: D101
+     @cached_property
+     def _worker_info(self):
+         if find_spec("torch"):
+             from torch.utils.data import get_worker_info
+
+             return get_worker_info()
+         return None
+
+     @cached_property
+     def _rng(self):
+         if self._worker_info is None:
+             return np.random.default_rng()
+         else:
+             # This is used for the _get_chunks function.
+             # Use the same seed for all workers so that the resulting splits are the same across workers.
+             # The torch default seed is `base_seed + worker_id`; hence, subtract worker_id to get the base seed.
+             return np.random.default_rng(self._worker_info.seed - self._worker_info.id)
+
+     def shuffle(self, obj: np.typing.ArrayLike) -> None:
+         """Perform in-place shuffle.
+
+         Parameters
+         ----------
+         obj
+             The object to be shuffled
+         """
+         self._rng.shuffle(obj)
+
+     def get_part_for_worker(self, obj: np.ndarray) -> np.ndarray:
+         """Get a chunk of an incoming array according to the current worker id.
+
+         Parameters
+         ----------
+         obj
+             Incoming array
+
+         Returns
+         -------
+         An evenly split part of the array corresponding to how many workers there are.
+         """
+         if self._worker_info is None:
+             return obj
+         num_workers, worker_id = self._worker_info.num_workers, self._worker_info.id
+         chunks_split = np.array_split(obj, num_workers)
+         return chunks_split[worker_id]
+
+
+ def check_lt_1(vals: list[int], labels: list[str]) -> None:
+     """Raise a ValueError if any of the values are less than one.
+
+     The format of the error is "{labels[i]} must be greater than or equal to 1, got {vals[i]}"
+     and is raised for the first value found to be less than one.
+
+     Parameters
+     ----------
+     vals
+         The values to check (each must be at least 1).
+     labels
+         The label for the value in the error if the value is less than one.
+
+     Raises
+     ------
+     ValueError
+         If any of the values is less than one.
+     """
+     if any(is_lt_1 := [v < 1 for v in vals]):
+         label, value = next(
+             (label, value)
+             for label, value, check in zip(
+                 labels,
+                 vals,
+                 is_lt_1,
+                 strict=True,
+             )
+             if check
+         )
+         raise ValueError(f"{label} must be greater than or equal to 1, got {value}")
+
+
+ class SupportsShape(Protocol):  # noqa: D101
+     @property
+     def shape(self) -> tuple[int, int] | list[int]: ...  # noqa: D102
+
+
+ def check_var_shapes(objs: list[SupportsShape]) -> None:
+     """Small utility function to check that all objects have the same shape along the second axis"""
+     if not all(objs[0].shape[1] == d.shape[1] for d in objs):
+         raise ValueError("All datasets must have the same shape along the var axis.")
+
+
+ def to_torch(input: OutputInMemoryArray, preload_to_gpu: bool) -> Tensor:
+     """Send the input data to a torch.Tensor"""
+     import torch
+
+     if isinstance(input, torch.Tensor):
+         return input
+     if isinstance(input, sp.sparse.csr_matrix):
+         with warnings.catch_warnings():
+             warnings.filterwarnings("ignore", "Sparse CSR tensor support is in beta state", UserWarning)
+             tensor = torch.sparse_csr_tensor(
+                 torch.from_numpy(input.indptr),
+                 torch.from_numpy(input.indices),
+                 torch.from_numpy(input.data),
+                 input.shape,
+             )
+         if preload_to_gpu:
+             return tensor.cuda(non_blocking=True)
+         return tensor
+     if isinstance(input, np.ndarray):
+         tensor = torch.from_numpy(input)
+         if preload_to_gpu:
+             return tensor.cuda(non_blocking=True)
+         return tensor
+     if isinstance(input, CupyArray):
+         return torch.from_dlpack(input)
+     if isinstance(input, CupyCSRMatrix):
+         with warnings.catch_warnings():
+             warnings.filterwarnings("ignore", "Sparse CSR tensor support is in beta state", UserWarning)
+             return torch.sparse_csr_tensor(
+                 torch.from_dlpack(input.indptr),
+                 torch.from_dlpack(input.indices),
+                 torch.from_dlpack(input.data),
+                 input.shape,
+             )
+     raise TypeError(f"Cannot convert {type(input)} to torch.Tensor")
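
Assuming torch (and optionally cupy) is installed, a brief usage sketch of the helpers defined in utils.py; the example data is illustrative, and only the function names and signatures come from the code above.

import numpy as np
import scipy.sparse as sp

from annbatch.utils import sample_rows, split_given_size, to_torch

# split_given_size: chunks of length 4; the last chunk holds the remainder.
chunks = split_given_size(np.arange(10), size=4)  # lengths 4, 4, 2

# sample_rows: yield (row, label) pairs drawn across several arrays, shuffled.
x_list = [np.random.rand(3, 5), np.random.rand(2, 5)]
obs_list = [np.array(["a", "b", "c"]), np.array(["d", "e"])]
rows = list(sample_rows(x_list, obs_list, shuffle=True))  # 5 pairs in total

# to_torch: dense ndarray -> torch.Tensor, scipy CSR -> torch sparse CSR tensor.
dense = to_torch(np.random.rand(4, 5).astype(np.float32), preload_to_gpu=False)
sparse = to_torch(sp.random(4, 5, density=0.2, format="csr", dtype=np.float32), preload_to_gpu=False)
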