lamindb 0.76.8__py3-none-any.whl → 0.76.9__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to their public registry, and is provided for informational purposes only.
Files changed (61)
  1. lamindb/__init__.py +113 -113
  2. lamindb/_artifact.py +1205 -1205
  3. lamindb/_can_validate.py +579 -579
  4. lamindb/_collection.py +389 -387
  5. lamindb/_curate.py +1601 -1601
  6. lamindb/_feature.py +155 -155
  7. lamindb/_feature_set.py +242 -242
  8. lamindb/_filter.py +23 -23
  9. lamindb/_finish.py +256 -256
  10. lamindb/_from_values.py +382 -382
  11. lamindb/_is_versioned.py +40 -40
  12. lamindb/_parents.py +476 -476
  13. lamindb/_query_manager.py +125 -125
  14. lamindb/_query_set.py +362 -362
  15. lamindb/_record.py +649 -649
  16. lamindb/_run.py +57 -57
  17. lamindb/_save.py +308 -308
  18. lamindb/_storage.py +14 -14
  19. lamindb/_transform.py +127 -127
  20. lamindb/_ulabel.py +56 -56
  21. lamindb/_utils.py +9 -9
  22. lamindb/_view.py +72 -72
  23. lamindb/core/__init__.py +94 -94
  24. lamindb/core/_context.py +574 -574
  25. lamindb/core/_data.py +438 -438
  26. lamindb/core/_feature_manager.py +867 -867
  27. lamindb/core/_label_manager.py +253 -253
  28. lamindb/core/_mapped_collection.py +631 -597
  29. lamindb/core/_settings.py +187 -187
  30. lamindb/core/_sync_git.py +138 -138
  31. lamindb/core/_track_environment.py +27 -27
  32. lamindb/core/datasets/__init__.py +59 -59
  33. lamindb/core/datasets/_core.py +581 -571
  34. lamindb/core/datasets/_fake.py +36 -36
  35. lamindb/core/exceptions.py +90 -90
  36. lamindb/core/fields.py +12 -12
  37. lamindb/core/loaders.py +164 -164
  38. lamindb/core/schema.py +56 -56
  39. lamindb/core/storage/__init__.py +25 -25
  40. lamindb/core/storage/_anndata_accessor.py +740 -740
  41. lamindb/core/storage/_anndata_sizes.py +41 -41
  42. lamindb/core/storage/_backed_access.py +98 -98
  43. lamindb/core/storage/_tiledbsoma.py +204 -204
  44. lamindb/core/storage/_valid_suffixes.py +21 -21
  45. lamindb/core/storage/_zarr.py +110 -110
  46. lamindb/core/storage/objects.py +62 -62
  47. lamindb/core/storage/paths.py +172 -172
  48. lamindb/core/subsettings/__init__.py +12 -12
  49. lamindb/core/subsettings/_creation_settings.py +38 -38
  50. lamindb/core/subsettings/_transform_settings.py +21 -21
  51. lamindb/core/types.py +19 -19
  52. lamindb/core/versioning.py +158 -158
  53. lamindb/integrations/__init__.py +12 -12
  54. lamindb/integrations/_vitessce.py +107 -107
  55. lamindb/setup/__init__.py +14 -14
  56. lamindb/setup/core/__init__.py +4 -4
  57. {lamindb-0.76.8.dist-info → lamindb-0.76.9.dist-info}/LICENSE +201 -201
  58. {lamindb-0.76.8.dist-info → lamindb-0.76.9.dist-info}/METADATA +4 -4
  59. lamindb-0.76.9.dist-info/RECORD +60 -0
  60. {lamindb-0.76.8.dist-info → lamindb-0.76.9.dist-info}/WHEEL +1 -1
  61. lamindb-0.76.8.dist-info/RECORD +0 -60
--- lamindb/core/_mapped_collection.py (0.76.8)
+++ lamindb/core/_mapped_collection.py (0.76.9)
@@ -1,597 +1,631 @@
-from __future__ import annotations
-
-from collections import Counter
-from functools import reduce
-from pathlib import Path
-from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Union
-
-import numpy as np
-import pandas as pd
-from lamin_utils import logger
-from lamindb_setup.core.upath import UPath
-
-from .storage._anndata_accessor import (
-    ArrayType,
-    ArrayTypes,
-    GroupType,
-    GroupTypes,
-    StorageType,
-    _safer_read_index,
-    get_spec,
-    registry,
-)
-
-if TYPE_CHECKING:
-    from lamindb_setup.core.types import UPathStr
-
-
-class _Connect:
-    def __init__(self, storage):
-        if isinstance(storage, UPath):
-            self.conn, self.store = registry.open("h5py", storage)
-            self.to_close = True
-        else:
-            self.conn, self.store = None, storage
-            self.to_close = False
-
-    def __enter__(self):
-        return self.store
-
-    def __exit__(self, exc_type, exc_val, exc_tb):
-        self.close()
-
-    def close(self):
-        if not self.to_close:
-            return
-        if hasattr(self.store, "close"):
-            self.store.close()
-        if hasattr(self.conn, "close"):
-            self.conn.close()
-
-
-class MappedCollection:
-    """Map-style collection for use in data loaders.
-
-    This class virtually concatenates `AnnData` arrays as a `pytorch map-style dataset
-    <https://pytorch.org/docs/stable/data.html#map-style-datasets>`__.
-
-    If your `AnnData` collection is in the cloud, move them into a local cache
-    first for faster access.
-
-    `__getitem__` of the `MappedCollection` object takes a single integer index
-    and returns a dictionary with the observation data sample for this index from
-    the `AnnData` objects in `path_list`. The dictionary has keys for `layers_keys`
-    (`.X` is in `"X"`), `obs_keys`, `obsm_keys` (under `f"obsm_{key}"`) and also `"_store_idx"`
-    for the index of the `AnnData` object containing this observation sample.
-
-    .. note::
-
-        For a guide, see :doc:`docs:scrna5`.
-
-        For more convenient use within a :class:`~lamindb.Collection`,
-        see :meth:`~lamindb.Collection.mapped`.
-
-        This currently only works for collections of `AnnData` objects.
-
-        The implementation was influenced by the `SCimilarity
-        <https://github.com/Genentech/scimilarity>`__ data loader.
-
-
-    Args:
-        path_list: A list of paths to `AnnData` objects stored in `.h5ad` or `.zarr` formats.
-        layers_keys: Keys from the ``.layers`` slot. ``layers_keys=None`` or ``"X"`` in the list
-            retrieves ``.X``.
-        obsm_keys: Keys from the ``.obsm`` slots.
-        obs_keys: Keys from the ``.obs`` slots.
-        join: `"inner"` or `"outer"` virtual joins. If ``None`` is passed,
-            does not join.
-        encode_labels: Encode labels into integers.
-            Can be a list with elements from ``obs_keys``.
-        unknown_label: Encode this label to -1.
-            Can be a dictionary with keys from ``obs_keys`` if ``encode_labels=True``
-            or from ``encode_labels`` if it is a list.
-        cache_categories: Enable caching categories of ``obs_keys`` for faster access.
-        parallel: Enable sampling with multiple processes.
-        dtype: Convert numpy arrays from ``.X``, ``.layers`` and ``.obsm`` to this dtype.
-    """
-
-    def __init__(
-        self,
-        path_list: list[UPathStr],
-        layers_keys: str | list[str] | None = None,
-        obs_keys: str | list[str] | None = None,
-        obsm_keys: str | list[str] | None = None,
-        join: Literal["inner", "outer"] | None = "inner",
-        encode_labels: bool | list[str] = True,
-        unknown_label: str | dict[str, str] | None = None,
-        cache_categories: bool = True,
-        parallel: bool = False,
-        dtype: str | None = None,
-    ):
-        if join not in {None, "inner", "outer"}:  # pragma: nocover
-            raise ValueError(
-                f"join must be one of None, 'inner', or 'outer' but was {join}"
-            )
-
-        if layers_keys is None:
-            self.layers_keys = ["X"]
-        else:
-            self.layers_keys = (
-                [layers_keys] if isinstance(layers_keys, str) else layers_keys
-            )
-
-        obsm_keys = [obsm_keys] if isinstance(obsm_keys, str) else obsm_keys
-        self.obsm_keys = obsm_keys
-
-        obs_keys = [obs_keys] if isinstance(obs_keys, str) else obs_keys
-        self.obs_keys = obs_keys
-
-        if isinstance(encode_labels, list):
-            if len(encode_labels) == 0:
-                encode_labels = False
-            elif obs_keys is None or not all(
-                enc_label in obs_keys for enc_label in encode_labels
-            ):
-                raise ValueError(
-                    "All elements of `encode_labels` should be in `obs_keys`."
-                )
-        else:
-            if encode_labels:
-                encode_labels = obs_keys if obs_keys is not None else False
-        self.encode_labels = encode_labels
-
-        if encode_labels and isinstance(unknown_label, dict):
-            if not all(unkey in encode_labels for unkey in unknown_label):  # type: ignore
-                raise ValueError(
-                    "All keys of `unknown_label` should be in `encode_labels` and `obs_keys`."
-                )
-        self.unknown_label = unknown_label
-
-        self.storages = []  # type: ignore
-        self.conns = []  # type: ignore
-        self.parallel = parallel
-        self.path_list = path_list
-        self._make_connections(path_list, parallel)
-
-        self.n_obs_list = []
-        for i, storage in enumerate(self.storages):
-            with _Connect(storage) as store:
-                X = store["X"]
-                store_path = self.path_list[i]
-                self._check_csc_raise_error(X, "X", store_path)
-                if isinstance(X, ArrayTypes):  # type: ignore
-                    self.n_obs_list.append(X.shape[0])
-                else:
-                    self.n_obs_list.append(X.attrs["shape"][0])
-                for layer_key in self.layers_keys:
-                    if layer_key == "X":
-                        continue
-                    self._check_csc_raise_error(
-                        store["layers"][layer_key],
-                        f"layers/{layer_key}",
-                        store_path,
-                    )
-                if self.obsm_keys is not None:
-                    for obsm_key in self.obsm_keys:
-                        self._check_csc_raise_error(
-                            store["obsm"][obsm_key],
-                            f"obsm/{obsm_key}",
-                            store_path,
-                        )
-        self.n_obs = sum(self.n_obs_list)
-
-        self.indices = np.hstack([np.arange(n_obs) for n_obs in self.n_obs_list])
-        self.storage_idx = np.repeat(np.arange(len(self.storages)), self.n_obs_list)
-
-        self.join_vars: Literal["inner", "outer"] | None = join
-        self.var_indices: list | None = None
-        self.var_joint: pd.Index | None = None
-        self.n_vars_list: list | None = None
-        self.var_list: list | None = None
-        self.n_vars: int | None = None
-        if self.join_vars is not None:
-            self._make_join_vars()
-            self.n_vars = len(self.var_joint)
-
-        if self.obs_keys is not None:
-            if cache_categories:
-                self._cache_categories(self.obs_keys)
-            else:
-                self._cache_cats: dict = {}
-            self.encoders: dict = {}
-            if self.encode_labels:
-                self._make_encoders(self.encode_labels)  # type: ignore
-
-        self._dtype = dtype
-        self._closed = False
-
-    def _make_connections(self, path_list: list, parallel: bool):
-        for path in path_list:
-            path = UPath(path)
-            if path.exists() and path.is_file():  # type: ignore
-                if parallel:
-                    conn, storage = None, path
-                else:
-                    conn, storage = registry.open("h5py", path)
-            else:
-                conn, storage = registry.open("zarr", path)
-            self.conns.append(conn)
-            self.storages.append(storage)
-
-    def _cache_categories(self, obs_keys: list):
-        self._cache_cats = {}
-        decode = np.frompyfunc(lambda x: x.decode("utf-8"), 1, 1)
-        for label in obs_keys:
-            self._cache_cats[label] = []
-            for storage in self.storages:
-                with _Connect(storage) as store:
-                    cats = self._get_categories(store, label)
-                    if cats is not None:
-                        cats = decode(cats) if isinstance(cats[0], bytes) else cats[...]
-                    self._cache_cats[label].append(cats)
-
-    def _make_encoders(self, encode_labels: list):
-        for label in encode_labels:
-            cats = self.get_merged_categories(label)
-            encoder = {}
-            if isinstance(self.unknown_label, dict):
-                unknown_label = self.unknown_label.get(label, None)
-            else:
-                unknown_label = self.unknown_label
-            if unknown_label is not None and unknown_label in cats:
-                cats.remove(unknown_label)
-                encoder[unknown_label] = -1
-            encoder.update({cat: i for i, cat in enumerate(cats)})
-            self.encoders[label] = encoder
-
-    def _read_vars(self):
-        self.var_list = []
-        self.n_vars_list = []
-        for storage in self.storages:
-            with _Connect(storage) as store:
-                vars = _safer_read_index(store["var"])
-                self.var_list.append(vars)
-                self.n_vars_list.append(len(vars))
-
-    def _make_join_vars(self):
-        if self.var_list is None:
-            self._read_vars()
-        vars_eq = all(self.var_list[0].equals(vrs) for vrs in self.var_list[1:])
-        if vars_eq:
-            self.join_vars = None
-            self.var_joint = self.var_list[0]
-            return
-
-        if self.join_vars == "inner":
-            self.var_joint = reduce(pd.Index.intersection, self.var_list)
-            if len(self.var_joint) == 0:
-                raise ValueError(
-                    "The provided AnnData objects don't have shared variables.\n"
-                    "Use join='outer'."
-                )
-            self.var_indices = [
-                vrs.get_indexer(self.var_joint) for vrs in self.var_list
-            ]
-        elif self.join_vars == "outer":
-            self.var_joint = reduce(pd.Index.union, self.var_list)
-            self.var_indices = [
-                self.var_joint.get_indexer(vrs) for vrs in self.var_list
-            ]
-
-    def check_vars_sorted(self, ascending: bool = True) -> bool:
-        """Returns `True` if all variables are sorted in all objects."""
-        if self.var_list is None:
-            self._read_vars()
-        if ascending:
-            vrs_sort_status = (vrs.is_monotonic_increasing for vrs in self.var_list)
-        else:
-            vrs_sort_status = (vrs.is_monotonic_decreasing for vrs in self.var_list)
-        return all(vrs_sort_status)
-
-    def check_vars_non_aligned(self, vars: pd.Index | list) -> list[int]:
-        """Returns indices of objects with non-aligned variables.
-
-        Args:
-            vars: Check alignment against these variables.
-        """
-        if self.var_list is None:
-            self._read_vars()
-        vars = pd.Index(vars)
-        return [i for i, vrs in enumerate(self.var_list) if not vrs.equals(vars)]
-
-    def _check_csc_raise_error(
-        self, elem: GroupType | ArrayType, key: str, path: UPathStr
-    ):
-        if isinstance(elem, ArrayTypes):  # type: ignore
-            return
-        if get_spec(elem).encoding_type == "csc_matrix":
-            if not self.parallel:
-                self.close()
-            raise ValueError(
-                f"{key} in {path} is a csc matrix, `MappedCollection` doesn't support this format yet."
-            )
-
-    def __len__(self):
-        return self.n_obs
-
-    @property
-    def shape(self) -> tuple[int, int]:
-        """Shape of the (virtually aligned) dataset."""
-        return (self.n_obs, self.n_vars)
-
-    @property
-    def original_shapes(self) -> list[tuple[int, int]]:
-        """Shapes of the underlying AnnData objects."""
-        if self.n_vars_list is None:
-            n_vars_list = [None] * len(self.n_obs_list)
-        else:
-            n_vars_list = self.n_vars_list
-        return list(zip(self.n_obs_list, n_vars_list))
-
-    def __getitem__(self, idx: int):
-        obs_idx = self.indices[idx]
-        storage_idx = self.storage_idx[idx]
-        if self.var_indices is not None:
-            var_idxs_join = self.var_indices[storage_idx]
-        else:
-            var_idxs_join = None
-
-        with _Connect(self.storages[storage_idx]) as store:
-            out = {}
-            for layers_key in self.layers_keys:
-                lazy_data = (
-                    store["X"] if layers_key == "X" else store["layers"][layers_key]
-                )
-                out[layers_key] = self._get_data_idx(
-                    lazy_data, obs_idx, self.join_vars, var_idxs_join, self.n_vars
-                )
-            if self.obsm_keys is not None:
-                for obsm_key in self.obsm_keys:
-                    lazy_data = store["obsm"][obsm_key]
-                    out[f"obsm_{obsm_key}"] = self._get_data_idx(lazy_data, obs_idx)
-            out["_store_idx"] = storage_idx
-            if self.obs_keys is not None:
-                for label in self.obs_keys:
-                    if label in self._cache_cats:
-                        cats = self._cache_cats[label][storage_idx]
-                        if cats is None:
-                            cats = []
-                    else:
-                        cats = None
-                    label_idx = self._get_obs_idx(store, obs_idx, label, cats)
-                    if label in self.encoders:
-                        label_idx = self.encoders[label][label_idx]
-                    out[label] = label_idx
-        return out
-
-    def _get_data_idx(
-        self,
-        lazy_data: ArrayType | GroupType,  # type: ignore
-        idx: int,
-        join_vars: Literal["inner", "outer"] | None = None,
-        var_idxs_join: list | None = None,
-        n_vars_out: int | None = None,
-    ):
-        """Get the index for the data."""
-        if isinstance(lazy_data, ArrayTypes):  # type: ignore
-            lazy_data_idx = lazy_data[idx]  # type: ignore
-            if join_vars is None:
-                result = lazy_data_idx
-                if self._dtype is not None:
-                    result = result.astype(self._dtype, copy=False)
-            elif join_vars == "outer":
-                dtype = lazy_data_idx.dtype if self._dtype is None else self._dtype
-                result = np.zeros(n_vars_out, dtype=dtype)
-                result[var_idxs_join] = lazy_data_idx
-            else:  # inner join
-                result = lazy_data_idx[var_idxs_join]
-                if self._dtype is not None:
-                    result = result.astype(self._dtype, copy=False)
-            return result
-        else:  # assume csr_matrix here
-            data = lazy_data["data"]  # type: ignore
-            indices = lazy_data["indices"]  # type: ignore
-            indptr = lazy_data["indptr"]  # type: ignore
-            s = slice(*(indptr[idx : idx + 2]))
-            data_s = data[s]
-            dtype = data_s.dtype if self._dtype is None else self._dtype
-            if join_vars == "outer":
-                lazy_data_idx = np.zeros(n_vars_out, dtype=dtype)
-                lazy_data_idx[var_idxs_join[indices[s]]] = data_s
-            else:
-                lazy_data_idx = np.zeros(lazy_data.attrs["shape"][1], dtype=dtype)  # type: ignore
-                lazy_data_idx[indices[s]] = data_s
-                if join_vars == "inner":
-                    lazy_data_idx = lazy_data_idx[var_idxs_join]
-            return lazy_data_idx
-
-    def _get_obs_idx(
-        self,
-        storage: StorageType,
-        idx: int,
-        label_key: str,
-        categories: list | None = None,
-    ):
-        """Get the index for the label by key."""
-        obs = storage["obs"]  # type: ignore
-        # how backwards compatible do we want to be here actually?
-        if isinstance(obs, ArrayTypes):  # type: ignore
-            label = obs[idx][obs.dtype.names.index(label_key)]
-        else:
-            labels = obs[label_key]
-            if isinstance(labels, ArrayTypes):  # type: ignore
-                label = labels[idx]
-            else:
-                label = labels["codes"][idx]
-        if categories is not None:
-            cats = categories
-        else:
-            cats = self._get_categories(storage, label_key)
-        if cats is not None and len(cats) > 0:
-            label = cats[label]
-        if isinstance(label, bytes):
-            label = label.decode("utf-8")
-        return label
-
-    def get_label_weights(
-        self,
-        obs_keys: str | list[str],
-        scaler: float | None = None,
-        return_categories: bool = False,
-    ):
-        """Get all weights for the given label keys.
-
-        This counts the occurrences of each label and returns
-        weights for each obs label according to the formula `1 / num of this label in the data`.
-        If `scaler` is provided, then `scaler / (scaler + num of this label in the data)`.
-
-        Args:
-            obs_keys: A key in the ``.obs`` slots or a list of keys. If a list is provided,
-                the labels from the obs keys will be concatenated with the ``"__"`` delimiter.
-            scaler: Use this number to scale the provided weights.
-            return_categories: If `False`, returns weights for each observation,
-                can be directly passed to a sampler. If `True`, returns a dictionary with
-                unique categories for labels (concatenated if `obs_keys` is a list)
-                and their weights.
-        """
-        if isinstance(obs_keys, str):
-            obs_keys = [obs_keys]
-        labels_list = []
-        for label_key in obs_keys:
-            labels_to_str = self.get_merged_labels(label_key).astype(str).astype("O")
-            labels_list.append(labels_to_str)
-        if len(labels_list) > 1:
-            labels = ["__".join(labels_obs) for labels_obs in zip(*labels_list)]
-        else:
-            labels = labels_list[0]
-        counter = Counter(labels)
-        if return_categories:
-            return {
-                k: 1.0 / v if scaler is None else scaler / (v + scaler)
-                for k, v in counter.items()
-            }
-        counts = np.array([counter[label] for label in labels])
-        if scaler is None:
-            weights = 1.0 / counts
-        else:
-            weights = scaler / (counts + scaler)
-        return weights
-
-    def get_merged_labels(self, label_key: str):
-        """Get merged labels for `label_key` from all `.obs`."""
-        labels_merge = []
-        decode = np.frompyfunc(lambda x: x.decode("utf-8"), 1, 1)
-        for i, storage in enumerate(self.storages):
-            with _Connect(storage) as store:
-                codes = self._get_codes(store, label_key)
-                labels = decode(codes) if isinstance(codes[0], bytes) else codes
-                if label_key in self._cache_cats:
-                    cats = self._cache_cats[label_key][i]
-                else:
-                    cats = self._get_categories(store, label_key)
-                if cats is not None:
-                    cats = decode(cats) if isinstance(cats[0], bytes) else cats
-                    labels = cats[labels]
-                labels_merge.append(labels)
-        return np.hstack(labels_merge)
-
-    def get_merged_categories(self, label_key: str):
-        """Get merged categories for `label_key` from all `.obs`."""
-        cats_merge = set()
-        decode = np.frompyfunc(lambda x: x.decode("utf-8"), 1, 1)
-        for i, storage in enumerate(self.storages):
-            with _Connect(storage) as store:
-                if label_key in self._cache_cats:
-                    cats = self._cache_cats[label_key][i]
-                else:
-                    cats = self._get_categories(store, label_key)
-                if cats is not None:
-                    cats = decode(cats) if isinstance(cats[0], bytes) else cats
-                    cats_merge.update(cats)
-                else:
-                    codes = self._get_codes(store, label_key)
-                    codes = decode(codes) if isinstance(codes[0], bytes) else codes
-                    cats_merge.update(codes)
-        return sorted(cats_merge)
-
-    def _get_categories(self, storage: StorageType, label_key: str):  # type: ignore
-        """Get categories."""
-        obs = storage["obs"]  # type: ignore
-        if isinstance(obs, ArrayTypes):  # type: ignore
-            cat_key_uns = f"{label_key}_categories"
-            if cat_key_uns in storage["uns"]:  # type: ignore
-                return storage["uns"][cat_key_uns]  # type: ignore
-            else:
-                return None
-        else:
-            if "__categories" in obs:
-                cats = obs["__categories"]
-                if label_key in cats:
-                    return cats[label_key]
-                else:
-                    return None
-            labels = obs[label_key]
-            if isinstance(labels, GroupTypes):  # type: ignore
-                if "categories" in labels:
-                    return labels["categories"]
-                else:
-                    return None
-            else:
-                if "categories" in labels.attrs:
-                    return labels.attrs["categories"]
-                else:
-                    return None
-        return None
-
-    def _get_codes(self, storage: StorageType, label_key: str):  # type: ignore
-        """Get codes."""
-        obs = storage["obs"]  # type: ignore
-        if isinstance(obs, ArrayTypes):  # type: ignore
-            label = obs[label_key]
-        else:
-            label = obs[label_key]
-            if isinstance(label, ArrayTypes):  # type: ignore
-                return label[...]
-            else:
-                return label["codes"][...]
-
-    def close(self):
-        """Close connections to array streaming backend.
-
-        No effect if `parallel=True`.
-        """
-        for storage in self.storages:
-            if hasattr(storage, "close"):
-                storage.close()
-        for conn in self.conns:
-            if hasattr(conn, "close"):
-                conn.close()
-        self._closed = True
-
-    @property
-    def closed(self) -> bool:
-        """Check if connections to array streaming backend are closed.
-
-        Irrelevant if `parallel=True`.
-        """
-        return self._closed
-
-    def __enter__(self):
-        return self
-
-    def __exit__(self, exc_type, exc_val, exc_tb):
-        self.close()
-
-    @staticmethod
-    def torch_worker_init_fn(worker_id):
-        """`worker_init_fn` for `torch.utils.data.DataLoader`.
-
-        Improves performance for `num_workers > 1`.
-        """
-        from torch.utils.data import get_worker_info
-
-        mapped = get_worker_info().dataset
-        mapped.parallel = False
-        mapped.storages = []
-        mapped.conns = []
-        mapped._make_connections(mapped.path_list, parallel=False)
+from __future__ import annotations
+
+from collections import Counter
+from functools import reduce
+from pathlib import Path
+from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Union
+
+import numpy as np
+import pandas as pd
+from lamin_utils import logger
+from lamindb_setup.core.upath import UPath
+
+from .storage._anndata_accessor import (
+    ArrayType,
+    ArrayTypes,
+    GroupType,
+    GroupTypes,
+    StorageType,
+    _safer_read_index,
+    get_spec,
+    registry,
+)
+
+if TYPE_CHECKING:
+    from lamindb_setup.core.types import UPathStr
+
+
+class _Connect:
+    def __init__(self, storage):
+        if isinstance(storage, UPath):
+            self.conn, self.store = registry.open("h5py", storage)
+            self.to_close = True
+        else:
+            self.conn, self.store = None, storage
+            self.to_close = False
+
+    def __enter__(self):
+        return self.store
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.close()
+
+    def close(self):
+        if not self.to_close:
+            return
+        if hasattr(self.store, "close"):
+            self.store.close()
+        if hasattr(self.conn, "close"):
+            self.conn.close()
+
+
+_decode = np.frompyfunc(lambda x: x.decode("utf-8"), 1, 1)
+
+
+class MappedCollection:
+    """Map-style collection for use in data loaders.
+
+    This class virtually concatenates `AnnData` arrays as a `pytorch map-style dataset
+    <https://pytorch.org/docs/stable/data.html#map-style-datasets>`__.
+
+    If your `AnnData` collection is in the cloud, move them into a local cache
+    first for faster access.
+
+    `__getitem__` of the `MappedCollection` object takes a single integer index
+    and returns a dictionary with the observation data sample for this index from
+    the `AnnData` objects in `path_list`. The dictionary has keys for `layers_keys`
+    (`.X` is in `"X"`), `obs_keys`, `obsm_keys` (under `f"obsm_{key}"`) and also `"_store_idx"`
+    for the index of the `AnnData` object containing this observation sample.
+
+    .. note::
+
+        For a guide, see :doc:`docs:scrna5`.
+
+        For more convenient use within a :class:`~lamindb.Collection`,
+        see :meth:`~lamindb.Collection.mapped`.
+
+        This currently only works for collections of `AnnData` objects.
+
+        The implementation was influenced by the `SCimilarity
+        <https://github.com/Genentech/scimilarity>`__ data loader.
+
+
+    Args:
+        path_list: A list of paths to `AnnData` objects stored in `.h5ad` or `.zarr` formats.
+        layers_keys: Keys from the ``.layers`` slot. ``layers_keys=None`` or ``"X"`` in the list
+            retrieves ``.X``.
+        obsm_keys: Keys from the ``.obsm`` slots.
+        obs_keys: Keys from the ``.obs`` slots.
+        obs_filter: Select only observations with these values for the given obs column.
+            Should be a tuple with an obs column name as the first element
+            and filtering values (a string or a tuple of strings) as the second element.
+        join: `"inner"` or `"outer"` virtual joins. If ``None`` is passed,
+            does not join.
+        encode_labels: Encode labels into integers.
+            Can be a list with elements from ``obs_keys``.
+        unknown_label: Encode this label to -1.
+            Can be a dictionary with keys from ``obs_keys`` if ``encode_labels=True``
+            or from ``encode_labels`` if it is a list.
+        cache_categories: Enable caching categories of ``obs_keys`` for faster access.
+        parallel: Enable sampling with multiple processes.
+        dtype: Convert numpy arrays from ``.X``, ``.layers`` and ``.obsm`` to this dtype.
+    """
+
+    def __init__(
+        self,
+        path_list: list[UPathStr],
+        layers_keys: str | list[str] | None = None,
+        obs_keys: str | list[str] | None = None,
+        obsm_keys: str | list[str] | None = None,
+        obs_filter: tuple[str, str | tuple[str, ...]] | None = None,
+        join: Literal["inner", "outer"] | None = "inner",
+        encode_labels: bool | list[str] = True,
+        unknown_label: str | dict[str, str] | None = None,
+        cache_categories: bool = True,
+        parallel: bool = False,
+        dtype: str | None = None,
+    ):
+        if join not in {None, "inner", "outer"}:  # pragma: nocover
+            raise ValueError(
+                f"join must be one of None, 'inner', or 'outer' but was {join}"
+            )
+
+        self.filtered = obs_filter is not None
+        if self.filtered and len(obs_filter) != 2:
+            raise ValueError(
+                "obs_filter should be a tuple with obs column name "
+                "as the first element and filtering values as the second element"
+            )
+
+        if layers_keys is None:
+            self.layers_keys = ["X"]
+        else:
+            self.layers_keys = (
+                [layers_keys] if isinstance(layers_keys, str) else layers_keys
+            )
+
+        obsm_keys = [obsm_keys] if isinstance(obsm_keys, str) else obsm_keys
+        self.obsm_keys = obsm_keys
+
+        obs_keys = [obs_keys] if isinstance(obs_keys, str) else obs_keys
+        self.obs_keys = obs_keys
+
+        if isinstance(encode_labels, list):
+            if len(encode_labels) == 0:
+                encode_labels = False
+            elif obs_keys is None or not all(
+                enc_label in obs_keys for enc_label in encode_labels
+            ):
+                raise ValueError(
+                    "All elements of `encode_labels` should be in `obs_keys`."
+                )
+        else:
+            if encode_labels:
+                encode_labels = obs_keys if obs_keys is not None else False
+        self.encode_labels = encode_labels
+
+        if encode_labels and isinstance(unknown_label, dict):
+            if not all(unkey in encode_labels for unkey in unknown_label):  # type: ignore
+                raise ValueError(
+                    "All keys of `unknown_label` should be in `encode_labels` and `obs_keys`."
+                )
+        self.unknown_label = unknown_label
+
+        self.storages = []  # type: ignore
+        self.conns = []  # type: ignore
+        self.parallel = parallel
+        self.path_list = path_list
+        self._make_connections(path_list, parallel)
+
+        self._cache_cats: dict = {}
+        if self.obs_keys is not None:
+            if cache_categories:
+                self._cache_categories(self.obs_keys)
+            self.encoders: dict = {}
+            if self.encode_labels:
+                self._make_encoders(self.encode_labels)  # type: ignore
+
+        self.n_obs_list = []
+        self.indices_list = []
+        for i, storage in enumerate(self.storages):
+            with _Connect(storage) as store:
+                X = store["X"]
+                store_path = self.path_list[i]
+                self._check_csc_raise_error(X, "X", store_path)
+                if self.filtered:
+                    obs_filter_key, obs_filter_values = obs_filter
+                    indices_storage = np.where(
+                        np.isin(
+                            self._get_labels(store, obs_filter_key), obs_filter_values
+                        )
+                    )[0]
+                    n_obs_storage = len(indices_storage)
+                else:
+                    if isinstance(X, ArrayTypes):  # type: ignore
+                        n_obs_storage = X.shape[0]
+                    else:
+                        n_obs_storage = X.attrs["shape"][0]
+                    indices_storage = np.arange(n_obs_storage)
+                self.n_obs_list.append(n_obs_storage)
+                self.indices_list.append(indices_storage)
+                for layer_key in self.layers_keys:
+                    if layer_key == "X":
+                        continue
+                    self._check_csc_raise_error(
+                        store["layers"][layer_key],
+                        f"layers/{layer_key}",
+                        store_path,
+                    )
+                if self.obsm_keys is not None:
+                    for obsm_key in self.obsm_keys:
+                        self._check_csc_raise_error(
+                            store["obsm"][obsm_key],
+                            f"obsm/{obsm_key}",
+                            store_path,
+                        )
+        self.n_obs = sum(self.n_obs_list)
+
+        self.indices = np.hstack(self.indices_list)
+        self.storage_idx = np.repeat(np.arange(len(self.storages)), self.n_obs_list)
+
+        self.join_vars: Literal["inner", "outer"] | None = join
+        self.var_indices: list | None = None
+        self.var_joint: pd.Index | None = None
+        self.n_vars_list: list | None = None
+        self.var_list: list | None = None
+        self.n_vars: int | None = None
+        if self.join_vars is not None:
+            self._make_join_vars()
+            self.n_vars = len(self.var_joint)
+
+        self._dtype = dtype
+        self._closed = False
+
+    def _make_connections(self, path_list: list, parallel: bool):
+        for path in path_list:
+            path = UPath(path)
+            if path.exists() and path.is_file():  # type: ignore
+                if parallel:
+                    conn, storage = None, path
+                else:
+                    conn, storage = registry.open("h5py", path)
+            else:
+                conn, storage = registry.open("zarr", path)
+            self.conns.append(conn)
+            self.storages.append(storage)
+
+    def _cache_categories(self, obs_keys: list):
+        self._cache_cats = {}
+        for label in obs_keys:
+            self._cache_cats[label] = []
+            for storage in self.storages:
+                with _Connect(storage) as store:
+                    cats = self._get_categories(store, label)
+                    if cats is not None:
+                        cats = (
+                            _decode(cats) if isinstance(cats[0], bytes) else cats[...]
+                        )
+                    self._cache_cats[label].append(cats)
+
+    def _make_encoders(self, encode_labels: list):
+        for label in encode_labels:
+            cats = self.get_merged_categories(label)
+            encoder = {}
+            if isinstance(self.unknown_label, dict):
+                unknown_label = self.unknown_label.get(label, None)
+            else:
+                unknown_label = self.unknown_label
+            if unknown_label is not None and unknown_label in cats:
+                cats.remove(unknown_label)
+                encoder[unknown_label] = -1
+            encoder.update({cat: i for i, cat in enumerate(cats)})
+            self.encoders[label] = encoder
+
+    def _read_vars(self):
+        self.var_list = []
+        self.n_vars_list = []
+        for storage in self.storages:
+            with _Connect(storage) as store:
+                vars = _safer_read_index(store["var"])
+                self.var_list.append(vars)
+                self.n_vars_list.append(len(vars))
+
+    def _make_join_vars(self):
+        if self.var_list is None:
+            self._read_vars()
+        vars_eq = all(self.var_list[0].equals(vrs) for vrs in self.var_list[1:])
+        if vars_eq:
+            self.join_vars = None
+            self.var_joint = self.var_list[0]
+            return
+
+        if self.join_vars == "inner":
+            self.var_joint = reduce(pd.Index.intersection, self.var_list)
+            if len(self.var_joint) == 0:
+                raise ValueError(
+                    "The provided AnnData objects don't have shared variables.\n"
+                    "Use join='outer'."
+                )
+            self.var_indices = [
+                vrs.get_indexer(self.var_joint) for vrs in self.var_list
+            ]
+        elif self.join_vars == "outer":
+            self.var_joint = reduce(pd.Index.union, self.var_list)
+            self.var_indices = [
+                self.var_joint.get_indexer(vrs) for vrs in self.var_list
+            ]
+
+    def check_vars_sorted(self, ascending: bool = True) -> bool:
+        """Returns `True` if all variables are sorted in all objects."""
+        if self.var_list is None:
+            self._read_vars()
+        if ascending:
+            vrs_sort_status = (vrs.is_monotonic_increasing for vrs in self.var_list)
+        else:
+            vrs_sort_status = (vrs.is_monotonic_decreasing for vrs in self.var_list)
+        return all(vrs_sort_status)
+
+    def check_vars_non_aligned(self, vars: pd.Index | list) -> list[int]:
+        """Returns indices of objects with non-aligned variables.
+
+        Args:
+            vars: Check alignment against these variables.
+        """
+        if self.var_list is None:
+            self._read_vars()
+        vars = pd.Index(vars)
+        return [i for i, vrs in enumerate(self.var_list) if not vrs.equals(vars)]
+
+    def _check_csc_raise_error(
+        self, elem: GroupType | ArrayType, key: str, path: UPathStr
+    ):
+        if isinstance(elem, ArrayTypes):  # type: ignore
+            return
+        if get_spec(elem).encoding_type == "csc_matrix":
+            if not self.parallel:
+                self.close()
+            raise ValueError(
+                f"{key} in {path} is a csc matrix, `MappedCollection` doesn't support this format yet."
+            )
+
+    def __len__(self):
+        return self.n_obs
+
+    @property
+    def shape(self) -> tuple[int, int]:
+        """Shape of the (virtually aligned) dataset."""
+        return (self.n_obs, self.n_vars)
+
+    @property
+    def original_shapes(self) -> list[tuple[int, int]]:
+        """Shapes of the underlying AnnData objects."""
+        if self.n_vars_list is None:
+            n_vars_list = [None] * len(self.n_obs_list)
+        else:
+            n_vars_list = self.n_vars_list
+        return list(zip(self.n_obs_list, n_vars_list))
+
+    def __getitem__(self, idx: int):
+        obs_idx = self.indices[idx]
+        storage_idx = self.storage_idx[idx]
+        if self.var_indices is not None:
+            var_idxs_join = self.var_indices[storage_idx]
+        else:
+            var_idxs_join = None
+
+        with _Connect(self.storages[storage_idx]) as store:
+            out = {}
+            for layers_key in self.layers_keys:
+                lazy_data = (
+                    store["X"] if layers_key == "X" else store["layers"][layers_key]
+                )
+                out[layers_key] = self._get_data_idx(
+                    lazy_data, obs_idx, self.join_vars, var_idxs_join, self.n_vars
+                )
+            if self.obsm_keys is not None:
+                for obsm_key in self.obsm_keys:
+                    lazy_data = store["obsm"][obsm_key]
+                    out[f"obsm_{obsm_key}"] = self._get_data_idx(lazy_data, obs_idx)
+            out["_store_idx"] = storage_idx
+            if self.obs_keys is not None:
+                for label in self.obs_keys:
+                    if label in self._cache_cats:
+                        cats = self._cache_cats[label][storage_idx]
+                        if cats is None:
+                            cats = []
+                    else:
+                        cats = None
+                    label_idx = self._get_obs_idx(store, obs_idx, label, cats)
+                    if label in self.encoders:
+                        label_idx = self.encoders[label][label_idx]
+                    out[label] = label_idx
+        return out
+
+    def _get_data_idx(
+        self,
+        lazy_data: ArrayType | GroupType,
+        idx: int,
+        join_vars: Literal["inner", "outer"] | None = None,
+        var_idxs_join: list | None = None,
+        n_vars_out: int | None = None,
+    ):
+        """Get the index for the data."""
+        if isinstance(lazy_data, ArrayTypes):  # type: ignore
+            lazy_data_idx = lazy_data[idx]  # type: ignore
+            if join_vars is None:
+                result = lazy_data_idx
+                if self._dtype is not None:
+                    result = result.astype(self._dtype, copy=False)
+            elif join_vars == "outer":
+                dtype = lazy_data_idx.dtype if self._dtype is None else self._dtype
+                result = np.zeros(n_vars_out, dtype=dtype)
+                result[var_idxs_join] = lazy_data_idx
+            else:  # inner join
+                result = lazy_data_idx[var_idxs_join]
+                if self._dtype is not None:
+                    result = result.astype(self._dtype, copy=False)
+            return result
+        else:  # assume csr_matrix here
+            data = lazy_data["data"]  # type: ignore
+            indices = lazy_data["indices"]  # type: ignore
+            indptr = lazy_data["indptr"]  # type: ignore
+            s = slice(*(indptr[idx : idx + 2]))
+            data_s = data[s]
+            dtype = data_s.dtype if self._dtype is None else self._dtype
+            if join_vars == "outer":
+                lazy_data_idx = np.zeros(n_vars_out, dtype=dtype)
+                lazy_data_idx[var_idxs_join[indices[s]]] = data_s
+            else:
+                lazy_data_idx = np.zeros(lazy_data.attrs["shape"][1], dtype=dtype)  # type: ignore
+                lazy_data_idx[indices[s]] = data_s
+                if join_vars == "inner":
+                    lazy_data_idx = lazy_data_idx[var_idxs_join]
+            return lazy_data_idx
+
+    def _get_obs_idx(
+        self,
+        storage: StorageType,
+        idx: int,
+        label_key: str,
+        categories: list | None = None,
+    ):
+        """Get the index for the label by key."""
+        obs = storage["obs"]  # type: ignore
+        # how backwards compatible do we want to be here actually?
+        if isinstance(obs, ArrayTypes):  # type: ignore
+            label = obs[idx][obs.dtype.names.index(label_key)]
+        else:
+            labels = obs[label_key]
+            if isinstance(labels, ArrayTypes):  # type: ignore
+                label = labels[idx]
+            else:
+                label = labels["codes"][idx]
+        if categories is not None:
+            cats = categories
+        else:
+            cats = self._get_categories(storage, label_key)
+        if cats is not None and len(cats) > 0:
+            label = cats[label]
+        if isinstance(label, bytes):
+            label = label.decode("utf-8")
+        return label
+
+    def get_label_weights(
+        self,
+        obs_keys: str | list[str],
+        scaler: float | None = None,
+        return_categories: bool = False,
+    ):
+        """Get all weights for the given label keys.
+
+        This counts the occurrences of each label and returns
+        weights for each obs label according to the formula `1 / num of this label in the data`.
+        If `scaler` is provided, then `scaler / (scaler + num of this label in the data)`.
+
+        Args:
+            obs_keys: A key in the ``.obs`` slots or a list of keys. If a list is provided,
+                the labels from the obs keys will be concatenated with the ``"__"`` delimiter.
+            scaler: Use this number to scale the provided weights.
+            return_categories: If `False`, returns weights for each observation,
+                can be directly passed to a sampler. If `True`, returns a dictionary with
+                unique categories for labels (concatenated if `obs_keys` is a list)
+                and their weights.
+        """
+        if isinstance(obs_keys, str):
+            obs_keys = [obs_keys]
+        labels_list = []
+        for label_key in obs_keys:
+            labels_to_str = self.get_merged_labels(label_key).astype(str).astype("O")
+            labels_list.append(labels_to_str)
+        if len(labels_list) > 1:
+            labels = ["__".join(labels_obs) for labels_obs in zip(*labels_list)]
+        else:
+            labels = labels_list[0]
+        counter = Counter(labels)
+        if return_categories:
+            return {
+                k: 1.0 / v if scaler is None else scaler / (v + scaler)
+                for k, v in counter.items()
+            }
+        counts = np.array([counter[label] for label in labels])
+        if scaler is None:
+            weights = 1.0 / counts
+        else:
+            weights = scaler / (counts + scaler)
+        return weights
+
+    def get_merged_labels(self, label_key: str):
+        """Get merged labels for `label_key` from all `.obs`."""
+        labels_merge = []
+        for i, storage in enumerate(self.storages):
+            with _Connect(storage) as store:
+                labels = self._get_labels(store, label_key, storage_idx=i)
+                if self.filtered:
+                    labels = labels[self.indices_list[i]]
+                labels_merge.append(labels)
+        return np.hstack(labels_merge)
+
+    def get_merged_categories(self, label_key: str):
+        """Get merged categories for `label_key` from all `.obs`."""
+        cats_merge = set()
+        for i, storage in enumerate(self.storages):
+            with _Connect(storage) as store:
+                if label_key in self._cache_cats:
+                    cats = self._cache_cats[label_key][i]
+                else:
+                    cats = self._get_categories(store, label_key)
+                if cats is not None:
+                    cats = _decode(cats) if isinstance(cats[0], bytes) else cats
+                    cats_merge.update(cats)
+                else:
+                    codes = self._get_codes(store, label_key)
+                    codes = _decode(codes) if isinstance(codes[0], bytes) else codes
+                    cats_merge.update(codes)
+        return sorted(cats_merge)
+
+    def _get_categories(self, storage: StorageType, label_key: str):
+        """Get categories."""
+        obs = storage["obs"]  # type: ignore
+        if isinstance(obs, ArrayTypes):  # type: ignore
+            cat_key_uns = f"{label_key}_categories"
+            if cat_key_uns in storage["uns"]:  # type: ignore
+                return storage["uns"][cat_key_uns]  # type: ignore
+            else:
+                return None
+        else:
+            if "__categories" in obs:
+                cats = obs["__categories"]
+                if label_key in cats:
+                    return cats[label_key]
+                else:
+                    return None
+            labels = obs[label_key]
+            if isinstance(labels, GroupTypes):  # type: ignore
+                if "categories" in labels:
+                    return labels["categories"]
+                else:
+                    return None
+            else:
+                if "categories" in labels.attrs:
+                    return labels.attrs["categories"]
+                else:
+                    return None
+        return None
+
+    def _get_codes(self, storage: StorageType, label_key: str):
+        """Get codes."""
+        obs = storage["obs"]  # type: ignore
+        if isinstance(obs, ArrayTypes):  # type: ignore
+            label = obs[label_key]
+        else:
+            label = obs[label_key]
+            if isinstance(label, ArrayTypes):  # type: ignore
+                return label[...]
+            else:
+                return label["codes"][...]
+
+    def _get_labels(
+        self, storage: StorageType, label_key: str, storage_idx: int | None = None
+    ):
+        """Get labels."""
+        codes = self._get_codes(storage, label_key)
+        labels = _decode(codes) if isinstance(codes[0], bytes) else codes
+        if storage_idx is not None and label_key in self._cache_cats:
+            cats = self._cache_cats[label_key][storage_idx]
+        else:
+            cats = self._get_categories(storage, label_key)
+        if cats is not None:
+            cats = _decode(cats) if isinstance(cats[0], bytes) else cats
+            labels = cats[labels]
+        return labels
+
+    def close(self):
+        """Close connections to array streaming backend.
+
+        No effect if `parallel=True`.
+        """
+        for storage in self.storages:
+            if hasattr(storage, "close"):
+                storage.close()
+        for conn in self.conns:
+            if hasattr(conn, "close"):
+                conn.close()
+        self._closed = True
+
+    @property
+    def closed(self) -> bool:
+        """Check if connections to array streaming backend are closed.
+
+        Irrelevant if `parallel=True`.
+        """
+        return self._closed
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.close()
+
+    @staticmethod
+    def torch_worker_init_fn(worker_id):
+        """`worker_init_fn` for `torch.utils.data.DataLoader`.
+
+        Improves performance for `num_workers > 1`.
+        """
+        from torch.utils.data import get_worker_info
+
+        mapped = get_worker_info().dataset
+        mapped.parallel = False
+        mapped.storages = []
+        mapped.conns = []
+        mapped._make_connections(mapped.path_list, parallel=False)
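
The functional change in this diff is the new `obs_filter` parameter of `MappedCollection`, along with the supporting `filtered` flag, per-store `indices_list`, and `_get_labels` helper. A minimal sketch of how the parameter could be used, assuming two locally cached `.h5ad` files and an `.obs` column named "cell_type" (the paths, column name, and values are hypothetical placeholders):

    from lamindb.core import MappedCollection

    # Hypothetical paths to locally cached AnnData files.
    paths = ["store1.h5ad", "store2.h5ad"]

    # obs_filter (new in 0.76.9) keeps only observations whose obs column
    # "cell_type" matches one of the given values.
    mc = MappedCollection(
        path_list=paths,
        obs_keys="cell_type",
        obs_filter=("cell_type", ("T cell", "B cell")),
    )

    sample = mc[0]   # dict with keys "X", "cell_type", "_store_idx"
    print(mc.shape)  # (number of observations kept by the filter, n_vars)
    mc.close()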
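
For the sampling utilities visible in the diff, `get_label_weights` and `torch_worker_init_fn`, here is a hedged sketch of the intended PyTorch integration. Per the docstring formula, a label occurring 4 times gets weight 1/4 = 0.25, or 1/(1+4) = 0.2 with `scaler=1`, so rare labels are drawn more often. The sketch assumes the `mc` object from the previous example; for `num_workers > 1`, constructing the collection with `parallel=True` is the usual pairing:

    import torch
    from torch.utils.data import DataLoader, WeightedRandomSampler

    # Per-observation weights, roughly 1 / count(label): rare labels weigh more.
    weights = mc.get_label_weights("cell_type")
    sampler = WeightedRandomSampler(
        weights=torch.as_tensor(weights, dtype=torch.double),
        num_samples=len(mc),
    )

    # torch_worker_init_fn re-opens the storage connections inside each
    # worker process; the docstring recommends it for num_workers > 1.
    loader = DataLoader(
        mc,
        batch_size=128,
        sampler=sampler,
        num_workers=2,
        worker_init_fn=mc.torch_worker_init_fn,
    )

    for batch in loader:
        x = batch["X"]          # (batch_size, n_vars) tensor
        y = batch["cell_type"]  # integer-encoded labels if encode_labels=True
        break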