annbatch-0.0.1-py3-none-any.whl
- annbatch/__init__.py +15 -0
- annbatch/abc.py +228 -0
- annbatch/anndata_manager.py +396 -0
- annbatch/dense.py +63 -0
- annbatch/io.py +474 -0
- annbatch/sparse.py +160 -0
- annbatch/types.py +25 -0
- annbatch/utils.py +319 -0
- annbatch-0.0.1.dist-info/METADATA +214 -0
- annbatch-0.0.1.dist-info/RECORD +12 -0
- annbatch-0.0.1.dist-info/WHEEL +4 -0
- annbatch-0.0.1.dist-info/licenses/LICENSE +21 -0
annbatch/dense.py
ADDED
@@ -0,0 +1,63 @@
from __future__ import annotations

from importlib.util import find_spec
from typing import cast

import numpy as np
import zarr

if find_spec("torch"):
    from torch.utils.data import IterableDataset as _IterableDataset
else:

    class _IterableDataset:
        pass


from annbatch.abc import AbstractIterableDataset, _assign_methods_to_ensure_unique_docstrings
from annbatch.utils import (
    MultiBasicIndexer,
    add_anndata_docstring,
    add_anndatas_docstring,
    add_dataset_docstring,
    add_datasets_docstring,
)


class ZarrDenseDataset(AbstractIterableDataset[zarr.Array, np.ndarray], _IterableDataset):  # noqa: D101
    async def _fetch_data(self, slices: list[slice], dataset_idx: int) -> np.ndarray:
        dataset = self._dataset_manager.train_datasets[dataset_idx]
        indexer = MultiBasicIndexer(
            [
                zarr.core.indexing.BasicIndexer(
                    (s, Ellipsis),
                    shape=dataset.metadata.shape,
                    chunk_grid=dataset.metadata.chunk_grid,
                )
                for s in slices
            ]
        )
        res = cast(
            "np.ndarray",
            await dataset._async_array._get_selection(indexer, prototype=zarr.core.buffer.default_buffer_prototype()),
        )
        return res

    def _validate(self, datasets: list[zarr.Array]):
        if not all(isinstance(d, zarr.Array) for d in datasets):
            raise TypeError("Cannot create dense dataset without using a zarr.Array")


_assign_methods_to_ensure_unique_docstrings(ZarrDenseDataset)


ZarrDenseDataset.__doc__ = AbstractIterableDataset.__init__.__doc__.format(
    array_type="dense", child_class="ZarrDenseDataset"
)
ZarrDenseDataset.add_datasets.__doc__ = add_datasets_docstring.format(on_disk_array_type="zarr.Array")
ZarrDenseDataset.add_dataset.__doc__ = add_dataset_docstring.format(on_disk_array_type="zarr.Array")
ZarrDenseDataset.add_anndatas.__doc__ = add_anndatas_docstring.format(on_disk_array_type="zarr.Array")
ZarrDenseDataset.add_anndata.__doc__ = add_anndata_docstring.format(on_disk_array_type="zarr.Array")
ZarrDenseDataset.__iter__.__doc__ = AbstractIterableDataset.__iter__.__doc__.format(
    gpu_array="cupy.ndarray", cpu_array="numpy.ndarray"
)
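
The class above only implements fetching (`_fetch_data`) and validation (`_validate`); construction, batching, and iteration are inherited from `AbstractIterableDataset` in `annbatch/abc.py`, which this diff does not expand. As a rough usage sketch (the constructor options and the top-level re-export are assumptions, not shown in this diff):

import zarr

from annbatch import ZarrDenseDataset  # assumed re-export via annbatch/__init__.py

# Open a dense X array from a collection dataset written by annbatch/io.py below.
x = zarr.open_array("path/to/output/zarr_store/dataset_0.zarr/X", mode="r")

ds = ZarrDenseDataset()  # hypothetical defaults; the real options live on AbstractIterableDataset.__init__
ds.add_dataset(x)  # registers an on-disk zarr.Array; _validate rejects anything else

for batch in ds:  # per the docstring wiring above, yields numpy.ndarray (cupy.ndarray on GPU)
    ...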
annbatch/io.py
ADDED
@@ -0,0 +1,474 @@
from __future__ import annotations

import json
import random
import warnings
from collections import defaultdict
from pathlib import Path
from typing import TYPE_CHECKING

import anndata as ad
import dask.array as da
import numpy as np
import pandas as pd
import scipy.sparse as sp
import zarr
from anndata.experimental.backed import Dataset2D
from dask.array.core import Array as DaskArray
from tqdm import tqdm
from zarr.codecs import BloscCodec, BloscShuffle

if TYPE_CHECKING:
    from collections.abc import Callable, Iterable, Mapping
    from os import PathLike
    from typing import Any, Literal

    from zarr.abc.codec import BytesBytesCodec


def _round_down(num: int, divisor: int):
    return num - (num % divisor)


def write_sharded(
    group: zarr.Group,
    adata: ad.AnnData,
    *,
    sparse_chunk_size: int = 32768,
    sparse_shard_size: int = 134_217_728,
    dense_chunk_size: int = 1024,
    dense_shard_size: int = 4194304,
    compressors: Iterable[BytesBytesCodec] = (BloscCodec(cname="lz4", clevel=3, shuffle=BloscShuffle.shuffle),),
):
    """Write a sharded zarr store from a single AnnData object.

    Parameters
    ----------
    group
        The destination group; must be zarr v3.
    adata
        The source AnnData object.
    sparse_chunk_size
        Chunk size of `indices` and `data` inside a shard.
    sparse_shard_size
        Shard size, i.e., the number of elements in a single sparse `data` or `indices` file.
    dense_chunk_size
        Number of obs elements per dense chunk along the first axis.
    dense_shard_size
        Number of obs elements per dense shard along the first axis.
    compressors
        The compressors to pass to `zarr`.
    """
    ad.settings.zarr_write_format = 3

    def callback(
        write_func: ad.experimental.Write,
        store: zarr.Group,
        elem_name: str,
        elem: ad.typing.RWAble,
        dataset_kwargs: Mapping[str, Any],
        *,
        iospec: ad.experimental.IOSpec,
    ):
        # Ensure we're not overriding anything here
        dataset_kwargs = dataset_kwargs.copy()
        if iospec.encoding_type in {"array"} and (
            any(n in store.name for n in {"obsm", "layers", "obsp"}) or "X" == elem_name
        ):
            # Get either the desired size or the next multiple down to ensure divisibility of chunks and shards
            shard_size = min(dense_shard_size, _round_down(elem.shape[0], dense_chunk_size))
            chunk_size = min(dense_chunk_size, _round_down(elem.shape[0], dense_chunk_size))
            # If the shape is less than the computed size (impossible given the rounding?) or the rounding created a 0-size chunk, then error
            if elem.shape[0] < chunk_size or chunk_size == 0:
                raise ValueError(
                    f"Choose a dense shard obs {dense_shard_size} and chunk obs {dense_chunk_size} with non-zero size less than the number of observations {elem.shape[0]}"
                )
            dataset_kwargs = {
                **dataset_kwargs,
                "shards": (shard_size,) + elem.shape[1:],  # only shard over 1st dim
                "chunks": (chunk_size,) + elem.shape[1:],  # only chunk over 1st dim
                "compressors": compressors,
            }
        elif iospec.encoding_type in {"csr_matrix", "csc_matrix"}:
            dataset_kwargs = {
                **dataset_kwargs,
                "shards": (sparse_shard_size,),
                "chunks": (sparse_chunk_size,),
                "compressors": compressors,
            }
        write_func(store, elem_name, elem, dataset_kwargs=dataset_kwargs)

    ad.experimental.write_dispatched(group, "/", adata, callback=callback)
    zarr.consolidate_metadata(group.store)


def _check_for_mismatched_keys(paths_or_anndatas: Iterable[PathLike[str] | ad.AnnData] | Iterable[str | ad.AnnData]):
    num_raw_in_adata = 0
    found_keys: dict[str, defaultdict[str, int]] = {
        "layers": defaultdict(lambda: 0),
        "obsm": defaultdict(lambda: 0),
        "obs": defaultdict(lambda: 0),
    }
    for path_or_anndata in paths_or_anndatas:
        if not isinstance(path_or_anndata, ad.AnnData):
            adata = ad.experimental.read_lazy(path_or_anndata)
        else:
            adata = path_or_anndata
        for elem_name, key_count in found_keys.items():
            curr_keys = set(getattr(adata, elem_name).keys())
            for key in curr_keys:
                key_count[key] += 1
        if adata.raw is not None:
            num_raw_in_adata += 1
    if num_raw_in_adata != len(paths_or_anndatas) and num_raw_in_adata != 0:
        warnings.warn(
            f"Found raw keys not present in all anndatas {paths_or_anndatas}, consider deleting raw or moving it to a shared layer/X location via `load_adata`",
            stacklevel=2,
        )
    for elem_name, key_count in found_keys.items():
        elem_keys_mismatched = [
            key for key, count in key_count.items() if (count != len(paths_or_anndatas) and count != 0)
        ]
        if len(elem_keys_mismatched) > 0:
            warnings.warn(
                f"Found {elem_name} keys {elem_keys_mismatched} not present in all anndatas {paths_or_anndatas}, consider stopping and using the `load_adata` argument to alter {elem_name} accordingly.",
                stacklevel=2,
            )


def _lazy_load_anndatas(
    paths: Iterable[PathLike[str]] | Iterable[str],
    load_adata: Callable[[PathLike[str] | str], ad.AnnData] = ad.experimental.read_lazy,
):
    adatas = []
    for i, path in enumerate(paths):
        adata = load_adata(path)
        # Concatenating Dataset2D drops categoricals
        if isinstance(adata.obs, Dataset2D):
            adata.obs = adata.obs.to_memory()
        adata.obs["src_path"] = pd.Categorical.from_codes([i] * adata.shape[0], categories=[str(p) for p in paths])
        adatas.append(adata)
    if len(adatas) == 1:
        return adatas[0]
    return ad.concat(adatas, join="outer")


def _create_chunks_for_shuffling(adata: ad.AnnData, shuffle_n_obs_per_dataset: int = 1_048_576, shuffle: bool = True):
    chunk_boundaries = np.cumsum([0] + list(adata.X.chunks[0]))
    slices = [
        slice(int(start), int(end)) for start, end in zip(chunk_boundaries[:-1], chunk_boundaries[1:], strict=True)
    ]
    if shuffle:
        random.shuffle(slices)
    idxs = np.concatenate([np.arange(s.start, s.stop) for s in slices])
    idxs = np.array_split(idxs, np.ceil(len(idxs) / shuffle_n_obs_per_dataset))

    return idxs


def _compute_blockwise(x: DaskArray) -> sp.spmatrix:
    """.compute() for large datasets is bad: https://github.com/scverse/annbatch/pull/75"""
    return sp.vstack(da.compute(*list(x.blocks)))


def _persist_adata_in_memory(adata: ad.AnnData) -> ad.AnnData:
    if isinstance(adata.X, DaskArray):
        if isinstance(adata.X._meta, sp.csr_matrix | sp.csr_array):
            adata.X = _compute_blockwise(adata.X)
        else:
            adata.X = adata.X.compute()
    if isinstance(adata.obs, Dataset2D):
        adata.obs = adata.obs.to_memory()
    if isinstance(adata.var, Dataset2D):
        adata.var = adata.var.to_memory()

    if adata.raw is not None:
        adata_raw = adata.raw.to_adata()
        if isinstance(adata_raw.X, DaskArray):
            if isinstance(adata_raw.X._meta, sp.csr_array | sp.csr_matrix):
                adata_raw.X = _compute_blockwise(adata_raw.X)
            else:
                adata_raw.X = adata_raw.X.compute()
        if isinstance(adata_raw.var, Dataset2D):
            adata_raw.var = adata_raw.var.to_memory()
        if isinstance(adata_raw.obs, Dataset2D):
            adata_raw.obs = adata_raw.obs.to_memory()
        del adata.raw
        adata.raw = adata_raw

    for k, elem in adata.obsm.items():
        # TODO: handle `Dataset2D` in `obsm` and `varm` that are
        if isinstance(elem, DaskArray):
            # a DaskArray is never itself a sparse matrix; inspect its meta
            if isinstance(elem._meta, sp.csr_matrix | sp.csr_array):
                adata.obsm[k] = _compute_blockwise(elem)
            else:
                adata.obsm[k] = elem.compute()

    for k, elem in adata.layers.items():
        if isinstance(elem, DaskArray):
            if isinstance(elem._meta, sp.csr_matrix | sp.csr_array):
                adata.layers[k] = _compute_blockwise(elem)
            else:
                adata.layers[k] = elem.compute()

    return adata


DATASET_PREFIX = "dataset"


def create_anndata_collection(
    adata_paths: Iterable[PathLike[str]] | Iterable[str],
    output_path: PathLike[str] | str,
    *,
    load_adata: Callable[[PathLike[str] | str], ad.AnnData] = ad.experimental.read_lazy,
    var_subset: Iterable[str] | None = None,
    zarr_sparse_chunk_size: int = 32768,
    zarr_sparse_shard_size: int = 134_217_728,
    zarr_dense_chunk_size: int = 1024,
    zarr_dense_shard_size: int = 4_194_304,
    zarr_compressor: Iterable[BytesBytesCodec] = (BloscCodec(cname="lz4", clevel=3, shuffle=BloscShuffle.shuffle),),
    h5ad_compressor: Literal["gzip", "lzf"] | None = "gzip",
    n_obs_per_dataset: int = 2_097_152,
    shuffle: bool = True,
    should_denseify: bool = False,
    output_format: Literal["h5ad", "zarr"] = "zarr",
):
    """Take AnnData paths and create an on-disk set of AnnData datasets with uniform var spaces at the desired path, with `n_obs_per_dataset` rows per store.

    The set of AnnData datasets is collectively referred to as a "collection", where each dataset is called `dataset_i.{zarr,h5ad}`.
    The main purpose of this function is to create shuffled sharded zarr datasets, which is its default behavior.
    However, it can also output h5ad datasets, as well as unshuffled datasets.
    The var space is outer-joined by default, but can be subsetted via `var_subset`.
    A key `src_path` is added to `obs` to indicate where each row came from.
    We highly recommend making your indexes unique across files; this function will call {meth}`AnnData.obs_names_make_unique`.
    Memory usage can be controlled via `n_obs_per_dataset`, since that many rows are read into memory before being written to disk.

    Parameters
    ----------
    adata_paths
        Paths to the AnnData files used to create the zarr store.
    output_path
        Path to the output zarr store.
    load_adata
        Function to customize lazy-loading the individual input anndata files. By default, {func}`anndata.experimental.read_lazy` is used.
        If you only need a subset of the input anndata files' elems (e.g., only `X` and `obs`), you can provide a custom function here to speed up loading and harmonize your data.
        The input to the function is a path to an anndata file, and the output is an anndata object which has `X` as a {class}`dask.array.Array`.
    var_subset
        Subset of gene names to include in the store. If None, all genes are included.
        Genes are subset based on the `var_names` attribute of the concatenated AnnData object.
    zarr_sparse_chunk_size
        Size of the chunks to use for the `indices` and `data` of a sparse matrix in the zarr store.
    zarr_sparse_shard_size
        Size of the shards to use for the `indices` and `data` of a sparse matrix in the zarr store.
    zarr_dense_chunk_size
        Number of observations per dense zarr chunk, i.e., chunking is only done along the first axis of the array.
    zarr_dense_shard_size
        Number of observations per dense zarr shard, i.e., sharding is only done along the first axis of the array.
    zarr_compressor
        Compressors to use to compress the data in the zarr store.
    h5ad_compressor
        Compressors to use to compress the data in the h5ad store. See `anndata.write_h5ad`.
    n_obs_per_dataset
        Number of observations to load into memory at once for shuffling / pre-processing.
        The higher this number, the more memory is used, but the better the shuffling.
        This corresponds to the size of the shards created.
    shuffle
        Whether to shuffle the data before writing it to the store.
    should_denseify
        Whether to write as dense on disk. There is no need to set this for sparse data; it is only for testing.
    output_format
        Format of the output store. Can be either "zarr" or "h5ad".

    Examples
    --------
    >>> import anndata as ad
    >>> from annbatch import create_anndata_collection
    >>> # create a custom load function to only keep `.X`, `.obs` and `.var` in the output store
    >>> def read_lazy_x_and_obs_only(path):
    ...     adata = ad.experimental.read_lazy(path)
    ...     return ad.AnnData(
    ...         X=adata.X,
    ...         obs=adata.obs.to_memory(),
    ...         var=adata.var.to_memory(),
    ...     )

    >>> datasets = [
    ...     "path/to/first_adata.h5ad",
    ...     "path/to/second_adata.h5ad",
    ...     "path/to/third_adata.h5ad",
    ... ]
    >>> create_anndata_collection(
    ...     datasets,
    ...     "path/to/output/zarr_store",
    ...     load_adata=read_lazy_x_and_obs_only,
    ... )
    """
    Path(output_path).mkdir(parents=True, exist_ok=True)
    ad.settings.zarr_write_format = 3
    _check_for_mismatched_keys(adata_paths)
    adata_concat = _lazy_load_anndatas(adata_paths, load_adata=load_adata)
    adata_concat.obs_names_make_unique()
    chunks = _create_chunks_for_shuffling(adata_concat, n_obs_per_dataset, shuffle=shuffle)

    if var_subset is None:
        var_subset = adata_concat.var_names

    for i, chunk in enumerate(tqdm(chunks)):
        var_mask = adata_concat.var_names.isin(var_subset)
        # np.sort: it's more efficient to access elements sequentially from dask arrays.
        # The data will be shuffled later on; we just want the elements at this point.
        adata_chunk = adata_concat[np.sort(chunk), :][:, var_mask].copy()
        adata_chunk = _persist_adata_in_memory(adata_chunk)
        if shuffle:
            # shuffle adata in memory to break up individual chunks
            idxs = np.random.default_rng().permutation(np.arange(len(adata_chunk)))
            adata_chunk = adata_chunk[idxs]
        # convert to dense format before writing to disk
        if should_denseify:
            # Need to convert back to a dask array to avoid memory issues when converting large sparse matrices to dense
            adata_chunk = adata_chunk.copy()
            adata_chunk.X = da.from_array(
                adata_chunk.X, chunks=(zarr_dense_chunk_size, -1), meta=adata_chunk.X
            ).map_blocks(lambda xx: xx.toarray(), dtype=adata_chunk.X.dtype)

        if output_format == "zarr":
            f = zarr.open_group(Path(output_path) / f"{DATASET_PREFIX}_{i}.zarr", mode="w")
            write_sharded(
                f,
                adata_chunk,
                sparse_chunk_size=zarr_sparse_chunk_size,
                sparse_shard_size=zarr_sparse_shard_size,
                dense_chunk_size=zarr_dense_chunk_size,
                dense_shard_size=zarr_dense_shard_size,
                compressors=zarr_compressor,
            )
        elif output_format == "h5ad":
            adata_chunk.write_h5ad(Path(output_path) / f"{DATASET_PREFIX}_{i}.h5ad", compression=h5ad_compressor)
        else:
            raise ValueError(f"Unrecognized output_format: {output_format}. Only 'zarr' and 'h5ad' are supported.")


def _get_array_encoding_type(path: PathLike[str] | str) -> str:
    shards = list(Path(path).glob(f"{DATASET_PREFIX}_*.zarr"))
    with open(shards[0] / "X" / "zarr.json") as f:
        encoding = json.load(f)
    return encoding["attributes"]["encoding-type"]


def add_to_collection(
    adata_paths: Iterable[PathLike[str]] | Iterable[str],
    output_path: PathLike[str] | str,
    load_adata: Callable[[PathLike[str] | str], ad.AnnData] = ad.read_h5ad,
    zarr_sparse_chunk_size: int = 32768,
    zarr_sparse_shard_size: int = 134_217_728,
    zarr_dense_chunk_size: int = 1024,
    zarr_dense_shard_size: int = 4_194_304,
    zarr_compressor: Iterable[BytesBytesCodec] = (BloscCodec(cname="lz4", clevel=3, shuffle=BloscShuffle.shuffle),),
    should_sparsify_output_in_memory: bool = False,
) -> None:
    """Add anndata files to an existing collection of sharded anndata zarr datasets.

    The var space of the source anndata files will be adapted to the target store.

    Parameters
    ----------
    adata_paths
        Paths to the anndata files to be appended to the collection of output chunks.
    output_path
        Path to the output zarr store.
    load_adata
        Function to customize loading the individual input anndata files. By default, {func}`anndata.read_h5ad` is used.
        If you only need a subset of the input anndata files' elems (e.g., only `X` and `obs`), you can provide a custom function here to speed up loading and harmonize your data.
        The input to the function is a path to an anndata file, and the output is an anndata object.
        If the input data is too large to fit into memory, you should use `ad.experimental.read_lazy` instead.
    zarr_sparse_chunk_size
        Size of the chunks to use for the `indices` and `data` of a sparse matrix in the zarr store.
    zarr_sparse_shard_size
        Size of the shards to use for the `indices` and `data` of a sparse matrix in the zarr store.
    zarr_dense_chunk_size
        Number of observations per dense zarr chunk, i.e., chunking is only done along the first axis of the array.
    zarr_dense_shard_size
        Number of observations per dense zarr shard, i.e., sharding is only done along the first axis of the array.
    zarr_compressor
        Compressors to use to compress the data in the zarr store.
    should_sparsify_output_in_memory
        This option is only for testing the appending of sparse files to dense stores.
        To save memory, the blocks of a dense on-disk store can be sparsified for in-memory processing.

    Examples
    --------
    >>> import anndata as ad
    >>> from annbatch import add_to_collection
    >>> datasets = [
    ...     "path/to/first_adata.h5ad",
    ...     "path/to/second_adata.h5ad",
    ...     "path/to/third_adata.h5ad",
    ... ]
    >>> add_to_collection(
    ...     datasets,
    ...     "path/to/output/zarr_store",
    ...     load_adata=ad.read_h5ad,  # replace with ad.experimental.read_lazy if data does not fit into memory
    ... )
    """
    shards = list(Path(output_path).glob(f"{DATASET_PREFIX}_*.zarr"))
    if len(shards) == 0:
        raise ValueError(
            "Store at `output_path` does not exist or is empty. Please run `create_anndata_collection` first."
        )
    encoding = _get_array_encoding_type(output_path)
    if encoding == "array":
        print("Detected array encoding type. Will convert to dense format before writing.")
    # Check for mismatched keys among the inputs.
    _check_for_mismatched_keys(adata_paths)

    adata_concat = _lazy_load_anndatas(adata_paths, load_adata=load_adata)
    # Check for mismatched keys between shards and the inputs.
    _check_for_mismatched_keys([adata_concat] + shards)
    if isinstance(adata_concat.X, DaskArray):
        chunks = _create_chunks_for_shuffling(adata_concat, np.ceil(len(adata_concat) / len(shards)), shuffle=True)
    else:
        chunks = np.array_split(np.random.default_rng().permutation(len(adata_concat)), len(shards))

    adata_concat.obs_names_make_unique()
    if encoding == "array":
        if not should_sparsify_output_in_memory:
            if isinstance(adata_concat.X, sp.spmatrix):
                adata_concat.X = adata_concat.X.toarray()
            elif isinstance(adata_concat.X, DaskArray) and isinstance(adata_concat.X._meta, sp.spmatrix):
                adata_concat.X = adata_concat.X.map_blocks(
                    lambda x: x.toarray(), meta=np.ndarray, dtype=adata_concat.X.dtype
                )
    elif encoding == "csr_matrix":
        if isinstance(adata_concat.X, np.ndarray):
            adata_concat.X = sp.csr_matrix(adata_concat.X)
        elif isinstance(adata_concat.X, DaskArray) and isinstance(adata_concat.X._meta, np.ndarray):
            adata_concat.X = adata_concat.X.map_blocks(
                sp.csr_matrix, meta=sp.csr_matrix(np.array([0], dtype=adata_concat.X.dtype))
            )

    for shard, chunk in tqdm(zip(shards, chunks, strict=False), total=len(shards)):
        if should_sparsify_output_in_memory and encoding == "array":
            adata_shard = _lazy_load_anndatas([shard])
            adata_shard.X = adata_shard.X.map_blocks(sp.csr_matrix).compute()
        else:
            adata_shard = ad.read_zarr(shard)

        adata = ad.concat(
            [adata_shard, adata_concat[chunk, :][:, adata_concat.var.index.isin(adata_shard.var.index)]], join="outer"
        )
        idxs_shuffled = np.random.default_rng().permutation(len(adata))
        adata = adata[idxs_shuffled, :].copy()  # this significantly speeds up writing to disk
        if should_sparsify_output_in_memory and encoding == "array":
            adata.X = adata.X.map_blocks(lambda x: x.toarray(), meta=np.array([0], dtype=adata.X.dtype)).compute()

        f = zarr.open_group(shard, mode="w")
        write_sharded(
            f,
            adata,
            sparse_chunk_size=zarr_sparse_chunk_size,
            sparse_shard_size=zarr_sparse_shard_size,
            dense_chunk_size=zarr_dense_chunk_size,
            dense_shard_size=zarr_dense_shard_size,
            compressors=zarr_compressor,
        )
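
`write_sharded` is the low-level writer that both `create_anndata_collection` and `add_to_collection` delegate to. A minimal sketch of calling it directly on a single in-memory AnnData, with placeholder paths and the default chunk/shard sizes spelled out:

import anndata as ad
import zarr

from annbatch.io import write_sharded  # module path per this diff; a top-level re-export may also exist

adata = ad.read_h5ad("path/to/first_adata.h5ad")  # placeholder path
group = zarr.open_group("path/to/output/sharded.zarr", mode="w")  # destination must be a zarr v3 group
write_sharded(
    group,
    adata,
    sparse_chunk_size=32_768,  # `data`/`indices` elements per chunk within a shard
    sparse_shard_size=134_217_728,  # `data`/`indices` elements per shard file
    dense_chunk_size=1_024,  # obs rows per dense chunk (first axis only)
    dense_shard_size=4_194_304,  # obs rows per dense shard (first axis only)
)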