scdataloader 1.1.3__py3-none-any.whl → 1.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
scdataloader/mapped.py DELETED
@@ -1,540 +0,0 @@
- from __future__ import annotations
-
- from collections import Counter
- from functools import reduce
- from pathlib import Path
- from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Union
-
- import numpy as np
- import pandas as pd
- from lamindb_setup.core.upath import UPath
-
- from lamindb.core.storage._anndata_accessor import (
-     ArrayType,
-     ArrayTypes,
-     GroupType,
-     GroupTypes,
-     StorageType,
-     _safer_read_index,
-     registry,
- )
-
- if TYPE_CHECKING:
-     from lamindb_setup.core.types import UPathStr
-
-
- class _Connect:
-     def __init__(self, storage):
-         if isinstance(storage, UPath):
-             self.conn, self.store = registry.open("h5py", storage)
-             self.to_close = True
-         else:
-             self.conn, self.store = None, storage
-             self.to_close = False
-
-     def __enter__(self):
-         return self.store
-
-     def __exit__(self, exc_type, exc_val, exc_tb):
-         self.close()
-
-     def close(self):
-         if not self.to_close:
-             return
-         if hasattr(self.store, "close"):
-             self.store.close()
-         if hasattr(self.conn, "close"):
-             self.conn.close()
-
-
- class MappedCollection:
-     """Map-style collection for use in data loaders.
-
-     This class virtually concatenates `AnnData` arrays as a `pytorch map-style dataset
-     <https://pytorch.org/docs/stable/data.html#map-style-datasets>`__.
-
-     If your `AnnData` collection is in the cloud, move them into a local cache
-     first for faster access.
-
-     `__getitem__` of the `MappedCollection` object takes a single integer index
-     and returns a dictionary with the observation data sample for this index from
-     the `AnnData` objects in `path_list`. The dictionary has keys for `layers_keys`
-     (`.X` is in `"X"`), `obs_keys`, `obsm_keys` (under `f"obsm_{key}"`) and also `"_store_idx"`
-     for the index of the `AnnData` object containing this observation sample.
-
-     .. note::
-
-         For a guide, see :doc:`docs:scrna5`.
-
-         For more convenient use within :class:`~lamindb.core.MappedCollection`,
-         see :meth:`~lamindb.Collection.mapped`.
-
-         This currently only works for collections of `AnnData` objects.
-
-         The implementation was influenced by the `SCimilarity
-         <https://github.com/Genentech/scimilarity>`__ data loader.
-
-
-     Args:
-         path_list: A list of paths to `AnnData` objects stored in `.h5ad` or `.zarr` formats.
-         layers_keys: Keys from the ``.layers`` slot. ``layers_keys=None`` or ``"X"`` in the list
-             retrieves ``.X``.
-         obsm_keys: Keys from the ``.obsm`` slots.
-         obs_keys: Keys from the ``.obs`` slots.
-         join: `"inner"` or `"outer"` virtual joins. If ``None`` is passed,
-             does not join.
-         encode_labels: Encode labels into integers.
-             Can be a list with elements from ``obs_keys``.
-         unknown_label: Encode this label to -1.
-             Can be a dictionary with keys from ``obs_keys`` if ``encode_labels=True``
-             or from ``encode_labels`` if it is a list.
-         cache_categories: Enable caching categories of ``obs_keys`` for faster access.
-         parallel: Enable sampling with multiple processes.
-         dtype: Convert numpy arrays from ``.X``, ``.layers`` and ``.obsm``
-             to this dtype.
-     """
-
-     def __init__(
-         self,
-         path_list: list[UPathStr],
-         layers_keys: str | list[str] | None = None,
-         obs_keys: str | list[str] | None = None,
-         obsm_keys: str | list[str] | None = None,
-         join: Literal["inner", "outer"] | None = "inner",
-         encode_labels: bool | list[str] = True,
-         unknown_label: str | dict[str, str] | None = None,
-         cache_categories: bool = True,
-         parallel: bool = False,
-         dtype: str | None = None,
-     ):
-         if join not in {None, "inner", "outer"}:  # pragma: nocover
-             raise ValueError(
- f"join must be one of None, 'inner, or 'outer' but was {type(join)}"
112
- )
-
-         if layers_keys is None:
-             self.layers_keys = ["X"]
-         else:
-             self.layers_keys = (
-                 [layers_keys] if isinstance(layers_keys, str) else layers_keys
-             )
-
-         obsm_keys = [obsm_keys] if isinstance(obsm_keys, str) else obsm_keys
-         self.obsm_keys = obsm_keys
-
-         obs_keys = [obs_keys] if isinstance(obs_keys, str) else obs_keys
-         self.obs_keys = obs_keys
-
-         if isinstance(encode_labels, list):
-             if len(encode_labels) == 0:
-                 encode_labels = False
-             elif obs_keys is None or not all(
-                 enc_label in obs_keys for enc_label in encode_labels
-             ):
-                 raise ValueError(
-                     "All elements of `encode_labels` should be in `obs_keys`."
-                 )
-         else:
-             if encode_labels:
-                 encode_labels = obs_keys if obs_keys is not None else False
-         self.encode_labels = encode_labels
-
-         if encode_labels and isinstance(unknown_label, dict):
-             if not all(unkey in encode_labels for unkey in unknown_label):  # type: ignore
-                 raise ValueError(
-                     "All keys of `unknown_label` should be in `encode_labels` and `obs_keys`."
-                 )
-         self.unknown_label = unknown_label
-
-         self.storages = []  # type: ignore
-         self.conns = []  # type: ignore
-         self.parallel = parallel
-         self._path_list = path_list
-         self._make_connections(path_list, parallel)
-
-         self.n_obs_list = []
-         for storage in self.storages:
-             with _Connect(storage) as store:
-                 X = store["X"]
-                 index = (
-                     store["var"]["ensembl_gene_id"]
-                     if "ensembl_gene_id" in store["var"]
-                     else store["var"]["_index"]
-                 )
-                 if join is None:
-                     if not all(
-                         [
-                             i <= j
-                             for i, j in zip(
-                                 index[:99],
-                                 index[1:100],
-                             )
-                         ]
-                     ):
-                         raise ValueError("The variables are not sorted.")
-                 if isinstance(X, ArrayTypes):  # type: ignore
-                     self.n_obs_list.append(X.shape[0])
-                 else:
-                     self.n_obs_list.append(X.attrs["shape"][0])
-         self.n_obs = sum(self.n_obs_list)
-
-         self.indices = np.hstack([np.arange(n_obs) for n_obs in self.n_obs_list])
-         self.storage_idx = np.repeat(np.arange(len(self.storages)), self.n_obs_list)
-
-         self.join_vars = join
-         self.var_indices = None
-         self.var_joint = None
-         self.n_vars_list = None
-         self.n_vars = None
-         if self.join_vars is not None:
-             self._make_join_vars()
-             self.n_vars = len(self.var_joint)
-
-         if self.obs_keys is not None:
-             if cache_categories:
-                 self._cache_categories(self.obs_keys)
-             else:
-                 self._cache_cats: dict = {}
-             self.encoders: dict = {}
-             if self.encode_labels:
-                 self._make_encoders(self.encode_labels)  # type: ignore
-
-         self._dtype = dtype
-         self._closed = False
-
-     def _make_connections(self, path_list: list, parallel: bool):
-         for path in path_list:
-             path = UPath(path)
-             if path.exists() and path.is_file():  # type: ignore
-                 if parallel:
-                     conn, storage = None, path
-                 else:
-                     conn, storage = registry.open("h5py", path)
-             else:
-                 conn, storage = registry.open("zarr", path)
-             self.conns.append(conn)
-             self.storages.append(storage)
-
-     def _cache_categories(self, obs_keys: list):
-         self._cache_cats = {}
-         decode = np.frompyfunc(lambda x: x.decode("utf-8"), 1, 1)
-         for label in obs_keys:
-             self._cache_cats[label] = []
-             for storage in self.storages:
-                 with _Connect(storage) as store:
-                     cats = self._get_categories(store, label)
-                     if cats is not None:
-                         cats = decode(cats) if isinstance(cats[0], bytes) else cats[...]
-                     self._cache_cats[label].append(cats)
-
-     def _make_encoders(self, encode_labels: list):
-         for label in encode_labels:
-             cats = self.get_merged_categories(label)
-             encoder = {}
-             if isinstance(self.unknown_label, dict):
-                 unknown_label = self.unknown_label.get(label, None)
-             else:
-                 unknown_label = self.unknown_label
-             if unknown_label is not None and unknown_label in cats:
-                 cats.remove(unknown_label)
-                 encoder[unknown_label] = -1
-             cats = list(cats)
-             cats.sort()
-             encoder.update({cat: i for i, cat in enumerate(cats)})
-             self.encoders[label] = encoder
-
-     def _make_join_vars(self):
-         var_list = []
-         self.n_vars_list = []
-         for storage in self.storages:
-             with _Connect(storage) as store:
-                 vars = _safer_read_index(store["var"])
-                 var_list.append(vars)
-                 self.n_vars_list.append(len(vars))
-
-         self.var_joint = None
-         vars_eq = all(var_list[0].equals(vrs) for vrs in var_list[1:])
-         if vars_eq:
-             self.join_vars = None
-             self.var_joint = var_list[0]
-             return
-
-         if self.join_vars == "inner":
-             self.var_joint = reduce(pd.Index.intersection, var_list)
-             if len(self.var_joint) == 0:
-                 raise ValueError(
- "The provided AnnData objects don't have shared varibales.\n"
266
- "Use join='outer'."
-                 )
-             self.var_indices = [vrs.get_indexer(self.var_joint) for vrs in var_list]
-         elif self.join_vars == "outer":
-             self.var_joint = reduce(pd.Index.union, var_list)
-             self.var_indices = [self.var_joint.get_indexer(vrs) for vrs in var_list]
-
-     def _check_aligned_vars(self, vars: list):
-         i = 0
-         for storage in self.storages:
-             with _Connect(storage) as store:
-                 if len(set(_safer_read_index(store["var"]).tolist()) - set(vars)) == 0:
-                     i += 1
-         print("{}% are aligned".format(i * 100 / len(self.storages)))
-
-     def __len__(self):
-         return self.n_obs
-
-     @property
-     def shape(self):
-         """Shape of the (virtually aligned) dataset."""
-         return (self.n_obs, self.n_vars)
-
-     @property
-     def original_shapes(self):
-         """Shapes of the underlying AnnData objects."""
-         if self.n_vars_list is None:
-             n_vars_list = [None] * len(self.n_obs_list)
-         else:
-             n_vars_list = self.n_vars_list
-         return list(zip(self.n_obs_list, n_vars_list))
-
-     def __getitem__(self, idx: int):
-         obs_idx = self.indices[idx]
-         storage_idx = self.storage_idx[idx]
-         if self.var_indices is not None:
-             var_idxs_join = self.var_indices[storage_idx]
-         else:
-             var_idxs_join = None
-
-         with _Connect(self.storages[storage_idx]) as store:
-             out = {}
-             for layers_key in self.layers_keys:
-                 lazy_data = (
-                     store["X"] if layers_key == "X" else store["layers"][layers_key]
-                 )
-                 out[layers_key] = self._get_data_idx(
-                     lazy_data, obs_idx, self.join_vars, var_idxs_join, self.n_vars
-                 )
-             if self.obsm_keys is not None:
-                 for obsm_key in self.obsm_keys:
-                     lazy_data = store["obsm"][obsm_key]
-                     out[f"obsm_{obsm_key}"] = self._get_data_idx(lazy_data, obs_idx)
-             out["_store_idx"] = storage_idx
-             if self.obs_keys is not None:
-                 for label in self.obs_keys:
-                     if label in self._cache_cats:
-                         cats = self._cache_cats[label][storage_idx]
-                         if cats is None:
-                             cats = []
-                     else:
-                         cats = None
-                     label_idx = self._get_obs_idx(store, obs_idx, label, cats)
-                     if label in self.encoders:
-                         label_idx = self.encoders[label][label_idx]
-                     out[label] = label_idx
-         return out
-
-     def _get_data_idx(
-         self,
-         lazy_data: ArrayType | GroupType,  # type: ignore
-         idx: int,
-         join_vars: Literal["inner", "outer"] | None = None,
-         var_idxs_join: list | None = None,
-         n_vars_out: int | None = None,
-     ):
-         """Get the index for the data."""
-         if isinstance(lazy_data, ArrayTypes):  # type: ignore
-             lazy_data_idx = lazy_data[idx]  # type: ignore
-             if join_vars is None:
-                 result = lazy_data_idx
-                 if self._dtype is not None:
-                     result = result.astype(self._dtype, copy=False)
-             elif join_vars == "outer":
-                 dtype = lazy_data_idx.dtype if self._dtype is None else self._dtype
-                 result = np.zeros(n_vars_out, dtype=dtype)
-                 result[var_idxs_join] = lazy_data_idx
-             else:  # inner join
-                 result = lazy_data_idx[var_idxs_join]
-                 if self._dtype is not None:
-                     result = result.astype(self._dtype, copy=False)
-             return result
-         else:  # assume csr_matrix here
-             data = lazy_data["data"]  # type: ignore
-             indices = lazy_data["indices"]  # type: ignore
-             indptr = lazy_data["indptr"]  # type: ignore
-             s = slice(*(indptr[idx : idx + 2]))
-             data_s = data[s]
-             dtype = data_s.dtype if self._dtype is None else self._dtype
-             if join_vars == "outer":
-                 lazy_data_idx = np.zeros(n_vars_out, dtype=dtype)
-                 lazy_data_idx[var_idxs_join[indices[s]]] = data_s
-             else:
-                 lazy_data_idx = np.zeros(lazy_data.attrs["shape"][1], dtype=dtype)  # type: ignore
-                 lazy_data_idx[indices[s]] = data_s
-                 if join_vars == "inner":
-                     lazy_data_idx = lazy_data_idx[var_idxs_join]
-             return lazy_data_idx
-
-     def _get_obs_idx(
-         self,
-         storage: StorageType,
-         idx: int,
-         label_key: str,
-         categories: list | None = None,
-     ):
-         """Get the index for the label by key."""
-         obs = storage["obs"]  # type: ignore
-         # how backwards compatible do we want to be here actually?
-         if isinstance(obs, ArrayTypes):  # type: ignore
-             label = obs[idx][obs.dtype.names.index(label_key)]
-         else:
-             labels = obs[label_key]
-             if isinstance(labels, ArrayTypes):  # type: ignore
-                 label = labels[idx]
-             else:
-                 label = labels["codes"][idx]
-         if categories is not None:
-             cats = categories
-         else:
-             cats = self._get_categories(storage, label_key)
-         if cats is not None and len(cats) > 0:
-             label = cats[label]
-         if isinstance(label, bytes):
-             label = label.decode("utf-8")
-         return label
-
-     def get_label_weights(self, obs_keys: str | list[str], scaler: int = 10):
-         """Get all weights for the given label keys."""
-         if isinstance(obs_keys, str):
-             obs_keys = [obs_keys]
-         labels_list = []
-         for label_key in obs_keys:
-             labels_to_str = self.get_merged_labels(label_key).astype(str).astype("O")
-             labels_list.append(labels_to_str)
-         if len(labels_list) > 1:
-             labels = reduce(lambda a, b: a + b, labels_list)
-         else:
-             labels = labels_list[0]
-
-         counter = Counter(labels)  # type: ignore
-         rn = {n: i for i, n in enumerate(counter.keys())}
-         labels = np.array([rn[label] for label in labels])
-         counter = np.array(list(counter.values()))
-         weights = scaler / (counter + scaler)
-         return weights, labels
-
-     def get_merged_labels(self, label_key: str):
-         """Get merged labels for `label_key` from all `.obs`."""
-         labels_merge = []
-         decode = np.frompyfunc(lambda x: x.decode("utf-8"), 1, 1)
-         for i, storage in enumerate(self.storages):
-             with _Connect(storage) as store:
-                 codes = self._get_codes(store, label_key)
-                 labels = decode(codes) if isinstance(codes[0], bytes) else codes
-                 if label_key in self._cache_cats:
-                     cats = self._cache_cats[label_key][i]
-                 else:
-                     cats = self._get_categories(store, label_key)
-                 if cats is not None:
-                     cats = decode(cats) if isinstance(cats[0], bytes) else cats
-                     labels = cats[labels]
-                 labels_merge.append(labels)
-         return np.hstack(labels_merge)
-
-     def get_merged_categories(self, label_key: str):
-         """Get merged categories for `label_key` from all `.obs`."""
-         cats_merge = set()
-         decode = np.frompyfunc(lambda x: x.decode("utf-8"), 1, 1)
-         for i, storage in enumerate(self.storages):
-             with _Connect(storage) as store:
-                 if label_key in self._cache_cats:
-                     cats = self._cache_cats[label_key][i]
-                 else:
-                     cats = self._get_categories(store, label_key)
-                 if cats is not None:
-                     cats = decode(cats) if isinstance(cats[0], bytes) else cats
-                     cats_merge.update(cats)
-                 else:
-                     codes = self._get_codes(store, label_key)
-                     codes = decode(codes) if isinstance(codes[0], bytes) else codes
-                     cats_merge.update(codes)
-         return cats_merge
-
-     def _get_categories(self, storage: StorageType, label_key: str):  # type: ignore
-         """Get categories."""
-         obs = storage["obs"]  # type: ignore
-         if isinstance(obs, ArrayTypes):  # type: ignore
-             cat_key_uns = f"{label_key}_categories"
-             if cat_key_uns in storage["uns"]:  # type: ignore
-                 return storage["uns"][cat_key_uns]  # type: ignore
-             else:
-                 return None
-         else:
-             if "__categories" in obs:
-                 cats = obs["__categories"]
-                 if label_key in cats:
-                     return cats[label_key]
-                 else:
-                     return None
-             labels = obs[label_key]
-             if isinstance(labels, GroupTypes):  # type: ignore
-                 if "categories" in labels:
-                     return labels["categories"]
-                 else:
-                     return None
-             else:
-                 if "categories" in labels.attrs:
-                     return labels.attrs["categories"]
-                 else:
-                     return None
-         return None
-
-     def _get_codes(self, storage: StorageType, label_key: str):  # type: ignore
-         """Get codes."""
-         obs = storage["obs"]  # type: ignore
-         if isinstance(obs, ArrayTypes):  # type: ignore
-             label = obs[label_key]
-         else:
-             label = obs[label_key]
-         if isinstance(label, ArrayTypes):  # type: ignore
-             return label[...]
-         else:
-             return label["codes"][...]
-
-     def close(self):
-         """Close connections to array streaming backend.
-
-         No effect if `parallel=True`.
-         """
-         for storage in self.storages:
-             if hasattr(storage, "close"):
-                 storage.close()
-         for conn in self.conns:
-             if hasattr(conn, "close"):
-                 conn.close()
-         self._closed = True
-
-     @property
-     def closed(self):
-         """Check if connections to array streaming backend are closed.
-
-         Does not matter if `parallel=True`.
-         """
-         return self._closed
-
-     def __enter__(self):
-         return self
-
-     def __exit__(self, exc_type, exc_val, exc_tb):
-         self.close()
-
-     @staticmethod
-     def torch_worker_init_fn(worker_id):
-         """`worker_init_fn` for `torch.utils.data.DataLoader`.
-
-         Improves performance for `num_workers > 1`.
-         """
-         from torch.utils.data import get_worker_info
-
-         mapped = get_worker_info().dataset
-         mapped.parallel = False
-         mapped.storages = []
-         mapped.conns = []
-         mapped._make_connections(mapped._path_list, parallel=False)
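
For context on what this removal takes away, here is a minimal usage sketch of the deleted class with a PyTorch DataLoader. The file names and the "cell_type" obs column are placeholders for illustration; this is not code from either package version.

from torch.utils.data import DataLoader

from scdataloader.mapped import MappedCollection  # module as it existed in 1.1.3

dataset = MappedCollection(
    path_list=["part1.h5ad", "part2.h5ad"],  # hypothetical local files
    obs_keys=["cell_type"],  # integer-encoded, since encode_labels=True by default
    join="inner",  # intersect variables across the files
    parallel=True,  # defer opening connections to the worker processes
)

loader = DataLoader(
    dataset,
    batch_size=64,
    num_workers=2,
    worker_init_fn=dataset.torch_worker_init_fn,  # reopen stores per worker
)

for batch in loader:
    x = batch["X"]  # densified, virtually joined expression rows
    y = batch["cell_type"]  # integer-encoded labels
    break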
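The sparse branch of `_get_data_idx` densifies a single CSR row directly from the `data`/`indices`/`indptr` arrays, so only one row's values are ever read from the store. A standalone sketch of that slicing logic against an in-memory scipy matrix:

import numpy as np
import scipy.sparse as sp

mat = sp.random(5, 8, density=0.3, format="csr", random_state=0)
idx = 2  # row to fetch

s = slice(*mat.indptr[idx : idx + 2])  # extent of row idx within data/indices
row = np.zeros(mat.shape[1], dtype=mat.dtype)
row[mat.indices[s]] = mat.data[s]  # scatter the stored values into a dense row

assert np.allclose(row, mat[idx].toarray().ravel())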
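`_make_encoders` sorts the merged categories and assigns consecutive integer codes, pinning `unknown_label` to -1 when present. A condensed sketch with invented categories:

# Toy categories; "unknown" plays the role of unknown_label.
cats = {"B cell", "T cell", "unknown"}
unknown_label = "unknown"

encoder = {}
if unknown_label in cats:
    cats.discard(unknown_label)
    encoder[unknown_label] = -1  # unknowns map to -1
encoder.update({c: i for i, c in enumerate(sorted(cats))})
# encoder == {'unknown': -1, 'B cell': 0, 'T cell': 1}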
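`get_label_weights` computes inverse-frequency weights as `scaler / (count + scaler)`: rare labels approach a weight of 1, dominant labels approach 0. A toy recomputation with made-up labels:

from collections import Counter

import numpy as np

labels = np.array(["T cell"] * 90 + ["B cell"] * 10, dtype=object)
scaler = 10

counter = Counter(labels)
rn = {n: i for i, n in enumerate(counter.keys())}
encoded = np.array([rn[label] for label in labels])
counts = np.array(list(counter.values()))
weights = scaler / (counts + scaler)  # [0.1, 0.5]: "B cell" gets 5x the weight

per_sample = weights[encoded]  # per-observation weights

The per-sample weights are the shape expected by `torch.utils.data.WeightedRandomSampler`, which is one way such weights are typically consumed.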