annbatch 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of annbatch might be problematic. Click here for more details.

annbatch/dense.py ADDED
@@ -0,0 +1,63 @@
1
+ from __future__ import annotations
2
+
3
+ from importlib.util import find_spec
4
+ from typing import cast
5
+
6
+ import numpy as np
7
+ import zarr
8
+
9
+ if find_spec("torch"):
10
+ from torch.utils.data import IterableDataset as _IterableDataset
11
+ else:
12
+
13
+ class _IterableDataset:
14
+ pass
15
+
16
+
17
+ from annbatch.abc import AbstractIterableDataset, _assign_methods_to_ensure_unique_docstrings
18
+ from annbatch.utils import (
19
+ MultiBasicIndexer,
20
+ add_anndata_docstring,
21
+ add_anndatas_docstring,
22
+ add_dataset_docstring,
23
+ add_datasets_docstring,
24
+ )
25
+
26
+
27
class ZarrDenseDataset(AbstractIterableDataset[zarr.Array, np.ndarray], _IterableDataset):  # noqa: D101
    async def _fetch_data(self, slices: list[slice], dataset_idx: int) -> np.ndarray:
        """Read several row ranges of one training array through a single async selection call.

        Parameters
        ----------
        slices
            Slices over the first axis; every remaining axis is selected in full (``...``).
        dataset_idx
            Position of the array in ``self._dataset_manager.train_datasets``.

        Returns
        -------
        The result of zarr's ``_get_selection`` for the combined indexer, cast to ``np.ndarray``.
        """
        array = self._dataset_manager.train_datasets[dataset_idx]
        # One BasicIndexer per requested slice, combined by MultiBasicIndexer and handed to a
        # single `_get_selection` call.
        combined_indexer = MultiBasicIndexer(
            [
                zarr.core.indexing.BasicIndexer(
                    (row_range, ...),
                    shape=array.metadata.shape,
                    chunk_grid=array.metadata.chunk_grid,
                )
                for row_range in slices
            ]
        )
        buffer_prototype = zarr.core.buffer.default_buffer_prototype()
        selection = await array._async_array._get_selection(combined_indexer, prototype=buffer_prototype)
        return cast("np.ndarray", selection)

    def _validate(self, datasets: list[zarr.Array]):
        """Raise ``TypeError`` unless every entry of ``datasets`` is a ``zarr.Array``."""
        if any(not isinstance(candidate, zarr.Array) for candidate in datasets):
            raise TypeError("Cannot create dense dataset without using a zarr.Array")
49
+
50
+
51
# Must run before the docstring assignments below: it is called on ZarrDenseDataset first,
# presumably so the shared base-class methods get per-class copies whose docstrings can be
# overwritten safely (confirm in annbatch.abc).
_assign_methods_to_ensure_unique_docstrings(ZarrDenseDataset)

# Fill the generic docstring templates with dense/zarr-specific wording.
ZarrDenseDataset.__doc__ = AbstractIterableDataset.__init__.__doc__.format(
    array_type="dense", child_class="ZarrDenseDataset"
)
for _attr_name, _doc_template in (
    ("add_datasets", add_datasets_docstring),
    ("add_dataset", add_dataset_docstring),
    ("add_anndatas", add_anndatas_docstring),
    ("add_anndata", add_anndata_docstring),
):
    getattr(ZarrDenseDataset, _attr_name).__doc__ = _doc_template.format(on_disk_array_type="zarr.Array")
# Keep the module namespace free of the loop temporaries.
del _attr_name, _doc_template
ZarrDenseDataset.__iter__.__doc__ = AbstractIterableDataset.__iter__.__doc__.format(
    gpu_array="cupy.ndarray", cpu_array="numpy.ndarray"
)
annbatch/io.py ADDED
@@ -0,0 +1,474 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import random
5
+ import warnings
6
+ from collections import defaultdict
7
+ from pathlib import Path
8
+ from typing import TYPE_CHECKING
9
+
10
+ import anndata as ad
11
+ import dask.array as da
12
+ import numpy as np
13
+ import pandas as pd
14
+ import scipy.sparse as sp
15
+ import zarr
16
+ from anndata.experimental.backed import Dataset2D
17
+ from dask.array.core import Array as DaskArray
18
+ from tqdm import tqdm
19
+ from zarr.codecs import BloscCodec, BloscShuffle
20
+
21
+ if TYPE_CHECKING:
22
+ from collections.abc import Callable, Iterable, Mapping
23
+ from os import PathLike
24
+ from typing import Any, Literal
25
+
26
+ from zarr.abc.codec import BytesBytesCodec
27
+
28
+
29
+ def _round_down(num: int, divisor: int):
30
+ return num - (num % divisor)
31
+
32
+
33
def write_sharded(
    group: zarr.Group,
    adata: ad.AnnData,
    *,
    sparse_chunk_size: int = 32768,
    sparse_shard_size: int = 134_217_728,
    dense_chunk_size: int = 1024,
    dense_shard_size: int = 4194304,
    compressors: Iterable[BytesBytesCodec] = (BloscCodec(cname="lz4", clevel=3, shuffle=BloscShuffle.shuffle),),
):
    """Write a sharded zarr store from a single AnnData object.

    Dense arrays stored as ``X`` or inside ``layers``, ``obsm`` or ``obsp`` are chunked and sharded
    along the first (obs) axis only. CSR/CSC matrices are written with the sparse chunk/shard sizes.
    Every other element is written with anndata's default settings. Afterwards the metadata of
    ``group.store`` is consolidated.

    Note that this sets the global ``anndata.settings.zarr_write_format`` to 3.

    Parameters
    ----------
    group
        The destination group, must be zarr v3
    adata
        The source anndata object
    sparse_chunk_size
        Chunk size of `indices` and `data` inside a shard.
    sparse_shard_size
        Shard size i.e., number of elements in a single sparse `data` or `indices` file.
    dense_chunk_size
        Number of obs elements per dense chunk along the first axis
    dense_shard_size
        Number of obs elements per dense shard along the first axis.
        It is capped at the number of observations and rounded down to a multiple of
        ``dense_chunk_size``, since a shard must hold a whole number of chunks.
    compressors
        The compressors to pass to `zarr`.

    Raises
    ------
    ValueError
        If a dense obs-aligned array has fewer observations than ``dense_chunk_size``,
        if ``dense_chunk_size`` is not positive, or if ``dense_shard_size`` is smaller
        than ``dense_chunk_size``.
    """
    ad.settings.zarr_write_format = 3
    # The same codecs are handed to every array written below; materialize them once so a
    # one-shot iterator is not exhausted by the first array (leaving later arrays uncompressed).
    compressors = tuple(compressors)

    def callback(
        write_func: ad.experimental.Write,
        store: zarr.Group,
        elem_name: str,
        elem: ad.typing.RWAble,
        dataset_kwargs: Mapping[str, Any],
        *,
        iospec: ad.experimental.IOSpec,
    ):
        # Work on a copy so the mapping provided by anndata is never modified.
        dataset_kwargs = dict(dataset_kwargs)
        is_obs_aligned_dense = iospec.encoding_type == "array" and (
            any(n in store.name for n in ("obsm", "layers", "obsp")) or elem_name == "X"
        )
        if is_obs_aligned_dense:
            n_obs = elem.shape[0]
            error_msg = (
                f"Choose a dense shard obs {dense_shard_size} and chunk obs {dense_chunk_size} "
                f"with non-zero size less than the number of observations {n_obs}"
            )
            if dense_chunk_size <= 0 or n_obs < dense_chunk_size:
                raise ValueError(error_msg)
            # zarr requires a shard to contain a whole number of chunks: cap the shard at the number
            # of observations and round it down to the next multiple of the chunk size.
            shard_size = _round_down(min(dense_shard_size, n_obs), dense_chunk_size)
            if shard_size == 0:
                raise ValueError(error_msg)
            dataset_kwargs = {
                **dataset_kwargs,
                "shards": (shard_size,) + elem.shape[1:],  # only shard over 1st dim
                "chunks": (dense_chunk_size,) + elem.shape[1:],  # only chunk over 1st dim
                "compressors": compressors,
            }
        elif iospec.encoding_type in {"csr_matrix", "csc_matrix"}:
            dataset_kwargs = {
                **dataset_kwargs,
                "shards": (sparse_shard_size,),
                "chunks": (sparse_chunk_size,),
                "compressors": compressors,
            }
        write_func(store, elem_name, elem, dataset_kwargs=dataset_kwargs)

    ad.experimental.write_dispatched(group, "/", adata, callback=callback)
    zarr.consolidate_metadata(group.store)
103
+
104
+
105
def _check_for_mismatched_keys(paths_or_anndatas: Iterable[PathLike[str] | ad.AnnData] | Iterable[str | ad.AnnData]):
    """Warn when ``.raw`` or ``layers``/``obsm``/``obs`` keys are present in only some of the inputs.

    Paths are opened with :func:`anndata.experimental.read_lazy`; AnnData objects are used as is.
    A key (or ``.raw``) is reported when it occurs in at least one, but not every, input.

    Parameters
    ----------
    paths_or_anndatas
        Paths to AnnData stores and/or AnnData objects. Any iterable is accepted; it is
        materialized into a list once so it can be counted and shown in the warnings.
    """
    # Materialize once: counting the inputs needs a sized collection, which a generator is not.
    inputs = list(paths_or_anndatas)
    n_inputs = len(inputs)
    num_raw_in_adata = 0
    found_keys: dict[str, defaultdict[str, int]] = {
        "layers": defaultdict(int),
        "obsm": defaultdict(int),
        "obs": defaultdict(int),
    }
    for path_or_anndata in inputs:
        if isinstance(path_or_anndata, ad.AnnData):
            adata = path_or_anndata
        else:
            adata = ad.experimental.read_lazy(path_or_anndata)
        for elem_name, key_count in found_keys.items():
            for key in set(getattr(adata, elem_name).keys()):
                key_count[key] += 1
        if adata.raw is not None:
            num_raw_in_adata += 1
    if 0 < num_raw_in_adata < n_inputs:
        warnings.warn(
            f"Found raw keys not present in all anndatas {inputs}, consider deleting raw or moving it to a shared layer/X location via `load_adata`",
            stacklevel=2,
        )
    for elem_name, key_count in found_keys.items():
        elem_keys_mismatched = [key for key, count in key_count.items() if 0 < count < n_inputs]
        if elem_keys_mismatched:
            warnings.warn(
                f"Found {elem_name} keys {elem_keys_mismatched} not present in all anndatas {inputs}, consider stopping and using the `load_adata` argument to alter {elem_name} accordingly.",
                stacklevel=2,
            )
137
+
138
+
139
def _lazy_load_anndatas(
    paths: Iterable[PathLike[str]] | Iterable[str],
    load_adata: Callable[[PathLike[str] | str], ad.AnnData] = ad.experimental.read_lazy,
):
    """Load every path with ``load_adata`` and outer-join the results along the obs axis.

    Each loaded object receives an ``obs["src_path"]`` categorical column naming the path its rows
    came from; all objects share the same categories (the string form of every path).

    Parameters
    ----------
    paths
        Paths to the AnnData files. Any iterable is accepted; it is materialized once.
    load_adata
        Function that loads one path into an AnnData object.

    Returns
    -------
    The loaded AnnData when exactly one path is given, otherwise ``anndata.concat(..., join="outer")``
    of all loaded objects.

    Raises
    ------
    ValueError
        If ``paths`` is empty.
    """
    # Materialize once: the paths are iterated for loading and for the category list,
    # which would silently break with a one-shot iterator.
    paths = list(paths)
    if not paths:
        raise ValueError("No AnnData paths were provided.")
    # Shared by every object so `src_path` has an identical categorical dtype across inputs;
    # built once rather than once per path.
    src_path_categories = [str(p) for p in paths]
    adatas = []
    for i, path in enumerate(paths):
        adata = load_adata(path)
        # Concatenating Dataset2D drops categoricals, so bring obs into memory first.
        if isinstance(adata.obs, Dataset2D):
            adata.obs = adata.obs.to_memory()
        adata.obs["src_path"] = pd.Categorical.from_codes(
            np.full(adata.shape[0], i), categories=src_path_categories
        )
        adatas.append(adata)
    if len(adatas) == 1:
        return adatas[0]
    return ad.concat(adatas, join="outer")
154
+
155
+
156
+ def _create_chunks_for_shuffling(adata: ad.AnnData, shuffle_n_obs_per_dataset: int = 1_048_576, shuffle: bool = True):
157
+ chunk_boundaries = np.cumsum([0] + list(adata.X.chunks[0]))
158
+ slices = [
159
+ slice(int(start), int(end)) for start, end in zip(chunk_boundaries[:-1], chunk_boundaries[1:], strict=True)
160
+ ]
161
+ if shuffle:
162
+ random.shuffle(slices)
163
+ idxs = np.concatenate([np.arange(s.start, s.stop) for s in slices])
164
+ idxs = np.array_split(idxs, np.ceil(len(idxs) / shuffle_n_obs_per_dataset))
165
+
166
+ return idxs
167
+
168
+
169
def _compute_blockwise(x: DaskArray) -> sp.spmatrix:
    """Materialize a sparse dask array by computing its blocks together and stacking them vertically.

    Used instead of a plain ``.compute()``, which behaves badly for large datasets:
    https://github.com/scverse/annbatch/pull/75
    """
    computed_blocks = da.compute(*x.blocks)
    return sp.vstack(computed_blocks)
172
+
173
+
174
def _compute_dask_array(x: DaskArray):
    """Compute a dask array into memory; CSR-backed arrays are computed block-wise via `_compute_blockwise`."""
    if isinstance(x._meta, sp.csr_matrix | sp.csr_array):
        return _compute_blockwise(x)
    return x.compute()


def _persist_adata_in_memory(adata: ad.AnnData) -> ad.AnnData:
    """Replace the lazily-loaded parts of ``adata`` with in-memory equivalents, in place.

    Handled elements:

    - ``X`` and every ``layers`` / ``obsm`` entry that is a dask array is computed
      (CSR-backed arrays block by block);
    - ``obs`` and ``var`` stored as ``Dataset2D`` are loaded with ``to_memory()``;
    - ``raw`` is rebuilt from ``raw.to_adata()`` with its ``X``, ``var`` and ``obs`` handled the same way.

    ``Dataset2D`` values inside ``obsm`` and other elements (``varm``, ``obsp``, ``uns``, ...)
    are left untouched.

    Returns
    -------
    The same ``adata`` object.
    """
    if isinstance(adata.X, DaskArray):
        adata.X = _compute_dask_array(adata.X)
    if isinstance(adata.obs, Dataset2D):
        adata.obs = adata.obs.to_memory()
    if isinstance(adata.var, Dataset2D):
        adata.var = adata.var.to_memory()

    if adata.raw is not None:
        adata_raw = adata.raw.to_adata()
        if isinstance(adata_raw.X, DaskArray):
            adata_raw.X = _compute_dask_array(adata_raw.X)
        if isinstance(adata_raw.var, Dataset2D):
            adata_raw.var = adata_raw.var.to_memory()
        if isinstance(adata_raw.obs, Dataset2D):
            adata_raw.obs = adata_raw.obs.to_memory()
        del adata.raw
        adata.raw = adata_raw

    # TODO: handle `Dataset2D` values in `obsm` and `varm`.
    # Iterate over snapshots because entries are replaced while looping. The sparse check must look
    # at the dask array's `_meta` (the block type), never at the dask array itself.
    for k, elem in list(adata.obsm.items()):
        if isinstance(elem, DaskArray):
            adata.obsm[k] = _compute_dask_array(elem)

    for k, elem in list(adata.layers.items()):
        if isinstance(elem, DaskArray):
            adata.layers[k] = _compute_dask_array(elem)

    return adata
215
+
216
+
217
# Filename stem shared by every dataset of a collection: datasets are written as
# "<DATASET_PREFIX>_<i>.zarr" / "<DATASET_PREFIX>_<i>.h5ad" and existing zarr datasets are
# discovered by globbing "<DATASET_PREFIX>_*.zarr".
DATASET_PREFIX = "dataset"
218
+
219
+
220
def create_anndata_collection(
    adata_paths: Iterable[PathLike[str]] | Iterable[str],
    output_path: PathLike[str] | str,
    *,
    load_adata: Callable[[PathLike[str] | str], ad.AnnData] = ad.experimental.read_lazy,
    var_subset: Iterable[str] | None = None,
    zarr_sparse_chunk_size: int = 32768,
    zarr_sparse_shard_size: int = 134_217_728,
    zarr_dense_chunk_size: int = 1024,
    zarr_dense_shard_size: int = 4_194_304,
    zarr_compressor: Iterable[BytesBytesCodec] = (BloscCodec(cname="lz4", clevel=3, shuffle=BloscShuffle.shuffle),),
    h5ad_compressor: Literal["gzip", "lzf"] | None = "gzip",
    n_obs_per_dataset: int = 2_097_152,
    shuffle: bool = True,
    should_denseify: bool = False,
    output_format: Literal["h5ad", "zarr"] = "zarr",
):
    """Take AnnData paths, create an on-disk set of AnnData datasets with uniform var spaces at the desired path with about `n_obs_per_dataset` rows per dataset.

    The set of AnnData datasets is collectively referred to as a "collection" where each dataset is called `dataset_i.{zarr,h5ad}`.
    The main purpose of this function is to create shuffled sharded zarr datasets, which is the default behavior of this function.
    However, this function can also output h5ad datasets and unshuffled datasets.
    The var space is by default outer-joined, but can be subsetted by `var_subset`.
    A key `src_path` is added to `obs` to indicate which input file each row came from.
    We highly recommend making your indexes unique across files; this function calls {meth}`AnnData.obs_names_make_unique`.
    Memory usage is controlled by `n_obs_per_dataset`, since that many rows are read into memory before each dataset is written to disk.

    Parameters
    ----------
    adata_paths
        Paths to the AnnData files used to create the collection.
    output_path
        Directory in which the collection's datasets are written (created if missing).
    load_adata
        Function to customize lazy-loading the individual input anndata files. By default, {func}`anndata.experimental.read_lazy` is used.
        If you only need a subset of the input anndata files' elems (e.g., only `X` and `obs`), you can provide a custom function here to speed up loading and harmonize your data.
        The input to the function is a path to an anndata file, and the output is an anndata object which has `X` as a {class}`dask.array.Array`.
    var_subset
        Subset of gene names to include in the store. If None, all genes are included.
        Genes are subset based on the `var_names` attribute of the concatenated AnnData object.
    zarr_sparse_chunk_size
        Size of the chunks to use for the `indices` and `data` of a sparse matrix in the zarr store.
    zarr_sparse_shard_size
        Size of the shards to use for the `indices` and `data` of a sparse matrix in the zarr store.
    zarr_dense_chunk_size
        Number of observations per dense zarr chunk i.e., chunking is only done along the first axis of the array.
    zarr_dense_shard_size
        Number of observations per dense zarr shard i.e., sharding is only done along the first axis of the array.
    zarr_compressor
        Compressors to use to compress the data in the zarr store.
    h5ad_compressor
        Compressors to use to compress the data in the h5ad store. See anndata.write_h5ad.
    n_obs_per_dataset
        Number of observations to load into memory at once for shuffling / pre-processing.
        The higher this number, the more memory is used, but the better the shuffling.
        The rows are split evenly into ``ceil(n_obs / n_obs_per_dataset)`` output datasets.
    shuffle
        Whether to shuffle the data before writing it to the store.
    should_denseify
        Whether to write as dense on disk. There's no need to set this for sparse data, it is only for testing.
    output_format
        Format of the output store. Can be either "zarr" or "h5ad".

    Raises
    ------
    ValueError
        If `output_format` is neither "zarr" nor "h5ad" (checked before any data is read).

    Examples
    --------
    >>> import anndata as ad
    >>> from annbatch import create_anndata_collection
    # create a custom load function to only keep `.X`, `.obs` and `.var` in the output store
    >>> def read_lazy_x_and_obs_only(path):
    ...     adata = ad.experimental.read_lazy(path)
    ...     return ad.AnnData(
    ...         X=adata.X,
    ...         obs=adata.obs.to_memory(),
    ...         var=adata.var.to_memory(),
    ...)

    >>> datasets = [
    ...     "path/to/first_adata.h5ad",
    ...     "path/to/second_adata.h5ad",
    ...     "path/to/third_adata.h5ad",
    ... ]
    >>> create_anndata_collection(
    ...     datasets,
    ...     "path/to/output/zarr_store",
    ...     load_adata=read_lazy_x_and_obs_only,
    ...)
    """
    # Fail fast, before any (potentially very expensive) loading or directory creation.
    if output_format not in ("zarr", "h5ad"):
        raise ValueError(f"Unrecognized output_format: {output_format}. Only 'zarr' and 'h5ad' are supported.")
    # These are used more than once (key check + loading, and every output dataset), so a
    # one-shot iterator would be silently exhausted after its first use.
    adata_paths = list(adata_paths)
    zarr_compressor = tuple(zarr_compressor)
    output_dir = Path(output_path)
    output_dir.mkdir(parents=True, exist_ok=True)
    ad.settings.zarr_write_format = 3
    _check_for_mismatched_keys(adata_paths)
    adata_concat = _lazy_load_anndatas(adata_paths, load_adata=load_adata)
    adata_concat.obs_names_make_unique()
    chunks = _create_chunks_for_shuffling(adata_concat, n_obs_per_dataset, shuffle=shuffle)

    if var_subset is None:
        var_subset = adata_concat.var_names
    # Identical for every output dataset: compute once (this also keeps a one-shot `var_subset`
    # iterable from being consumed by the first dataset only).
    var_mask = adata_concat.var_names.isin(var_subset)

    for i, chunk in enumerate(tqdm(chunks)):
        # np.sort: It's more efficient to access elements sequentially from dask arrays
        # The data will be shuffled later on, we just want the elements at this point
        adata_chunk = adata_concat[np.sort(chunk), :][:, var_mask].copy()
        adata_chunk = _persist_adata_in_memory(adata_chunk)
        if shuffle:
            # shuffle adata in memory to break up individual chunks
            idxs = np.random.default_rng().permutation(np.arange(len(adata_chunk)))
            adata_chunk = adata_chunk[idxs]
        # convert to dense format before writing to disk
        if should_denseify:
            # Need to convert back to dask array to avoid memory issues when converting large sparse matrices to dense
            adata_chunk = adata_chunk.copy()
            adata_chunk.X = da.from_array(
                adata_chunk.X, chunks=(zarr_dense_chunk_size, -1), meta=adata_chunk.X
            ).map_blocks(lambda xx: xx.toarray(), dtype=adata_chunk.X.dtype)

        if output_format == "zarr":
            f = zarr.open_group(output_dir / f"{DATASET_PREFIX}_{i}.zarr", mode="w")
            write_sharded(
                f,
                adata_chunk,
                sparse_chunk_size=zarr_sparse_chunk_size,
                sparse_shard_size=zarr_sparse_shard_size,
                dense_chunk_size=zarr_dense_chunk_size,
                dense_shard_size=zarr_dense_shard_size,
                compressors=zarr_compressor,
            )
        else:  # "h5ad", validated above
            adata_chunk.write_h5ad(output_dir / f"{DATASET_PREFIX}_{i}.h5ad", compression=h5ad_compressor)
350
+
351
+
352
def _get_array_encoding_type(path: PathLike[str] | str) -> str:
    """Return the anndata ``encoding-type`` attribute of ``X`` in a collection's zarr datasets.

    Only the first dataset in sorted filename order is inspected; the datasets of a collection
    are assumed to share one encoding.

    Raises
    ------
    ValueError
        If ``path`` contains no ``{DATASET_PREFIX}_*.zarr`` datasets.
    """
    # Sorted so the inspected dataset does not depend on filesystem listing order.
    shards = sorted(Path(path).glob(f"{DATASET_PREFIX}_*.zarr"))
    if not shards:
        raise ValueError(f"No `{DATASET_PREFIX}_*.zarr` datasets found in {path}.")
    # JSON is UTF-8 (RFC 8259); do not depend on the platform's locale encoding.
    metadata = json.loads((shards[0] / "X" / "zarr.json").read_text(encoding="utf-8"))
    return metadata["attributes"]["encoding-type"]
357
+
358
+
359
def add_to_collection(
    adata_paths: Iterable[PathLike[str]] | Iterable[str],
    output_path: PathLike[str] | str,
    load_adata: Callable[[PathLike[str] | str], ad.AnnData] = ad.read_h5ad,
    zarr_sparse_chunk_size: int = 32768,
    zarr_sparse_shard_size: int = 134_217_728,
    zarr_dense_chunk_size: int = 1024,
    zarr_dense_shard_size: int = 4_194_304,
    zarr_compressor: Iterable[BytesBytesCodec] = (BloscCodec(cname="lz4", clevel=3, shuffle=BloscShuffle.shuffle),),
    should_sparsify_output_in_memory: bool = False,
) -> None:
    """Add anndata files to an existing collection of sharded anndata zarr datasets.

    The var space of the source anndata files will be adapted to the target store:
    only input vars already present in a dataset are kept.
    The new rows are split across the existing datasets; each dataset is then reshuffled and rewritten in place.
    `X` of the inputs is converted to match the collection's on-disk encoding (dense or CSR).

    Parameters
    ----------
    adata_paths
        Paths to the anndata files to be appended to the collection of output chunks.
    output_path
        Directory containing a zarr collection created by `create_anndata_collection`.
    load_adata
        Function to customize loading the individual input anndata files. By default, {func}`anndata.read_h5ad` is used.
        If you only need a subset of the input anndata files' elems (e.g., only `X` and `obs`), you can provide a custom function here to speed up loading and harmonize your data.
        The input to the function is a path to an anndata file, and the output is an anndata object.
        If the input data is too large to fit into memory, you should use `ad.experimental.read_lazy` instead.
    zarr_sparse_chunk_size
        Size of the chunks to use for the `indices` and `data` of a sparse matrix in the zarr store.
    zarr_sparse_shard_size
        Size of the shards to use for the `indices` and `data` of a sparse matrix in the zarr store.
    zarr_dense_chunk_size
        Number of observations per dense zarr chunk i.e., chunking is only done along the first axis of the array.
    zarr_dense_shard_size
        Number of observations per dense zarr shard i.e., sharding is only done along the first axis of the array.
    zarr_compressor
        Compressors to use to compress the data in the zarr store.
    should_sparsify_output_in_memory
        This option is for testing only appending sparse files to dense stores.
        To save memory, the blocks of a dense on-disk store can be sparsified for in-memory processing.

    Raises
    ------
    ValueError
        If `output_path` contains no `dataset_*.zarr` datasets.

    Examples
    --------
    >>> import anndata as ad
    >>> from annbatch import add_to_collection
    >>> datasets = [
    ...     "path/to/first_adata.h5ad",
    ...     "path/to/second_adata.h5ad",
    ...     "path/to/third_adata.h5ad",
    ... ]
    >>> add_to_collection(
    ...     datasets,
    ...     "path/to/output/zarr_store",
    ...     load_adata=ad.read_h5ad, # replace with ad.experimental.read_lazy if data does not fit into memory
    ...)
    """
    shards = list(Path(output_path).glob(f"{DATASET_PREFIX}_*.zarr"))
    if len(shards) == 0:
        raise ValueError(
            "Store at `output_path` does not exist or is empty. Please run `create_anndata_collection` first."
        )
    encoding = _get_array_encoding_type(output_path)
    if encoding == "array":
        print("Detected array encoding type. Will convert to dense format before writing.")
    # Used for the key check and for loading (and every dataset write), so a one-shot
    # iterator would be silently exhausted after its first use.
    adata_paths = list(adata_paths)
    zarr_compressor = tuple(zarr_compressor)
    # Check for mismatched keys among the inputs.
    _check_for_mismatched_keys(adata_paths)

    adata_concat = _lazy_load_anndatas(adata_paths, load_adata=load_adata)
    # Check for mismatched keys between shards and the inputs.
    _check_for_mismatched_keys([adata_concat] + shards)
    if isinstance(adata_concat.X, DaskArray):
        # Roughly equal share of the new rows per existing dataset (ceiling division, as an int).
        n_obs_per_shard = int(np.ceil(len(adata_concat) / len(shards)))
        chunks = _create_chunks_for_shuffling(adata_concat, n_obs_per_shard, shuffle=True)
    else:
        chunks = np.array_split(np.random.default_rng().permutation(len(adata_concat)), len(shards))

    adata_concat.obs_names_make_unique()
    # Match the collection's on-disk encoding. `sp.issparse` covers both the spmatrix and the
    # sparray (e.g. csr_array) classes.
    if encoding == "array":
        if not should_sparsify_output_in_memory:
            if sp.issparse(adata_concat.X):
                adata_concat.X = adata_concat.X.toarray()
            elif isinstance(adata_concat.X, DaskArray) and sp.issparse(adata_concat.X._meta):
                adata_concat.X = adata_concat.X.map_blocks(
                    lambda x: x.toarray(), meta=np.ndarray, dtype=adata_concat.X.dtype
                )
    elif encoding == "csr_matrix":
        if isinstance(adata_concat.X, np.ndarray):
            adata_concat.X = sp.csr_matrix(adata_concat.X)
        elif isinstance(adata_concat.X, DaskArray) and isinstance(adata_concat.X._meta, np.ndarray):
            adata_concat.X = adata_concat.X.map_blocks(
                sp.csr_matrix, meta=sp.csr_matrix(np.array([0], dtype=adata_concat.X.dtype))
            )

    for shard, chunk in tqdm(zip(shards, chunks, strict=False), total=len(shards)):
        if should_sparsify_output_in_memory and encoding == "array":
            # Read the shard lazily and keep its stored obs (including its original `src_path`),
            # only sparsifying X for in-memory processing.
            # NOTE(review): only obs and X are brought into memory here; any other lazily-read element
            # of the shard would still reference files that `mode="w"` below deletes — confirm this
            # path is only used with X/obs-only stores.
            adata_shard = ad.experimental.read_lazy(shard)
            if isinstance(adata_shard.obs, Dataset2D):
                adata_shard.obs = adata_shard.obs.to_memory()
            adata_shard.X = adata_shard.X.map_blocks(sp.csr_matrix).compute()
        else:
            adata_shard = ad.read_zarr(shard)

        adata = ad.concat(
            [adata_shard, adata_concat[chunk, :][:, adata_concat.var.index.isin(adata_shard.var.index)]], join="outer"
        )
        idxs_shuffled = np.random.default_rng().permutation(len(adata))
        adata = adata[idxs_shuffled, :].copy()  # this significantly speeds up writing to disk
        if should_sparsify_output_in_memory and encoding == "array":
            # Densify again before writing into the dense store. X is lazy when the inputs were read
            # lazily, but an in-memory scipy matrix when they were loaded eagerly (e.g. `ad.read_h5ad`).
            if isinstance(adata.X, DaskArray):
                adata.X = adata.X.map_blocks(lambda x: x.toarray(), meta=np.array([0], dtype=adata.X.dtype)).compute()
            elif sp.issparse(adata.X):
                adata.X = adata.X.toarray()

        f = zarr.open_group(shard, mode="w")
        write_sharded(
            f,
            adata,
            sparse_chunk_size=zarr_sparse_chunk_size,
            sparse_shard_size=zarr_sparse_shard_size,
            dense_chunk_size=zarr_dense_chunk_size,
            dense_shard_size=zarr_dense_shard_size,
            compressors=zarr_compressor,
        )