scdataloader 1.6.4__py3-none-any.whl → 1.8.0__py3-none-any.whl

scdataloader/mapped.py ADDED
@@ -0,0 +1,656 @@
+ from __future__ import annotations
+
+ from collections import Counter
+ from functools import reduce
+ from typing import TYPE_CHECKING, Literal
+
+ import numpy as np
+ import pandas as pd
+ from lamindb.core.storage._anndata_accessor import (
+     ArrayType,
+     ArrayTypes,
+     GroupType,
+     GroupTypes,
+     StorageType,
+     _safer_read_index,
+     get_spec,
+     registry,
+ )
+ from lamindb_setup.core.upath import UPath
+
+ if TYPE_CHECKING:
+     from lamindb_setup.core.types import UPathStr
+
+
+ class _Connect:
+     def __init__(self, storage):
+         if isinstance(storage, UPath):
+             self.conn, self.store = registry.open("h5py", storage)
+             self.to_close = True
+         else:
+             self.conn, self.store = None, storage
+             self.to_close = False
+
+     def __enter__(self):
+         return self.store
+
+     def __exit__(self, exc_type, exc_val, exc_tb):
+         self.close()
+
+     def close(self):
+         if not self.to_close:
+             return
+         if hasattr(self.store, "close"):
+             self.store.close()
+         if hasattr(self.conn, "close"):
+             self.conn.close()
+
+
+ _decode = np.frompyfunc(lambda x: x.decode("utf-8"), 1, 1)
+
+
+ class MappedCollection:
+     """Map-style collection for use in data loaders.
+
+     This class virtually concatenates `AnnData` arrays as a `PyTorch map-style dataset
+     <https://pytorch.org/docs/stable/data.html#map-style-datasets>`__.
+
+     If your `AnnData` collection is in the cloud, move it into a local cache
+     first for faster access.
+
+     `__getitem__` of the `MappedCollection` object takes a single integer index
+     and returns a dictionary with the observation data sample for this index from
+     the `AnnData` objects in `path_list`. The dictionary has keys for `layers_keys`
+     (`.X` is in `"X"`), `obs_keys`, `obsm_keys` (under `f"obsm_{key}"`) and also `"_store_idx"`
+     for the index of the `AnnData` object containing this observation sample.
+
+     .. note::
+
+         For a guide, see :doc:`docs:scrna-mappedcollection`.
+
+         For more convenient construction from a :class:`~lamindb.Collection`,
+         see :meth:`~lamindb.Collection.mapped`.
+
+         This currently only works for collections of `AnnData` objects.
+
+         The implementation was influenced by the `SCimilarity
+         <https://github.com/Genentech/scimilarity>`__ data loader.
+
+     Args:
+         path_list: A list of paths to `AnnData` objects stored in `.h5ad` or `.zarr` formats.
+         layers_keys: Keys from the ``.layers`` slot. ``layers_keys=None`` or ``"X"`` in the list
+             retrieves ``.X``.
+         obs_keys: Keys from the ``.obs`` slots.
+         obsm_keys: Keys from the ``.obsm`` slots.
+         obs_filter: Select only observations with these values for the given obs column.
+             Should be a dictionary with an obs column name as the key
+             and filtering values (a string or a tuple of strings) as the value.
+         join: `"inner"` or `"outer"` virtual joins. If ``None`` is passed,
+             does not join.
+         encode_labels: Encode labels into integers.
+             Can be a list with elements from ``obs_keys``.
+         unknown_label: Encode this label to -1.
+             Can be a dictionary with keys from ``obs_keys`` if ``encode_labels=True``
+             or from ``encode_labels`` if it is a list.
+         cache_categories: Enable caching categories of ``obs_keys`` for faster access.
+         parallel: Enable sampling with multiple processes.
+         dtype: Convert numpy arrays from ``.X``, ``.layers`` and ``.obsm`` to this dtype.
+         metacell_mode: Probability of returning a metacell: with this probability, an
+             observation's expression is summed with that of its nearest neighbors
+             from ``.obsp["distances"]``.
+         meta_assays: Assay ontology term IDs whose observations are treated as metacells.
+     """
+
+     def __init__(
+         self,
+         path_list: list[UPathStr],
+         layers_keys: str | list[str] | None = None,
+         obs_keys: str | list[str] | None = None,
+         obsm_keys: str | list[str] | None = None,
+         obs_filter: dict[str, str | tuple[str, ...]] | None = None,
+         join: Literal["inner", "outer"] | None = "inner",
+         encode_labels: bool | list[str] = True,
+         unknown_label: str | dict[str, str] | None = None,
+         cache_categories: bool = True,
+         parallel: bool = False,
+         dtype: str | None = None,
+         metacell_mode: float = 0.0,
+         meta_assays: list[str] = ["EFO:0022857", "EFO:0010961"],
+     ):
+         if join not in {None, "inner", "outer"}:  # pragma: nocover
+             raise ValueError(
+                 f"join must be one of None, 'inner', or 'outer' but was {join}"
+             )
+
+         self.filtered = obs_filter is not None
+         if self.filtered and not isinstance(obs_filter, dict):
+             print("Passing a tuple to `obs_filter` is deprecated, use a dictionary")
+             obs_filter = {obs_filter[0]: obs_filter[1]}
+
+         if layers_keys is None:
+             self.layers_keys = ["X"]
+         else:
+             self.layers_keys = (
+                 [layers_keys] if isinstance(layers_keys, str) else layers_keys
+             )
+
+         obsm_keys = [obsm_keys] if isinstance(obsm_keys, str) else obsm_keys
+         self.obsm_keys = obsm_keys
+
+         obs_keys = [obs_keys] if isinstance(obs_keys, str) else obs_keys
+         self.obs_keys = obs_keys
+
+         if isinstance(encode_labels, list):
+             if len(encode_labels) == 0:
+                 encode_labels = False
+             elif obs_keys is None or not all(
+                 enc_label in obs_keys for enc_label in encode_labels
+             ):
+                 raise ValueError(
+                     "All elements of `encode_labels` should be in `obs_keys`."
+                 )
+         else:
+             if encode_labels:
+                 encode_labels = obs_keys if obs_keys is not None else False
+         self.encode_labels = encode_labels
+
+         if encode_labels and isinstance(unknown_label, dict):
+             if not all(unkey in encode_labels for unkey in unknown_label):  # type: ignore
+                 raise ValueError(
+                     "All keys of `unknown_label` should be in `encode_labels` and `obs_keys`."
+                 )
+         self.unknown_label = unknown_label
+
+         self.storages = []  # type: ignore
+         self.conns = []  # type: ignore
+         self.parallel = parallel
+         self.metacell_mode = metacell_mode
+         self.path_list = path_list
+         self.meta_assays = meta_assays
+         self._make_connections(path_list, parallel)
+
+         self._cache_cats: dict = {}
+         if self.obs_keys is not None:
+             if cache_categories:
+                 self._cache_categories(self.obs_keys)
+             self.encoders: dict = {}
+             if self.encode_labels:
+                 self._make_encoders(self.encode_labels)  # type: ignore
+
+         self.n_obs_list = []
+         self.indices_list = []
+         for i, storage in enumerate(self.storages):
+             with _Connect(storage) as store:
+                 X = store["X"]
+                 store_path = self.path_list[i]
+                 self._check_csc_raise_error(X, "X", store_path)
+                 if self.filtered:
+                     indices_storage_mask = None
+                     for obs_filter_key, obs_filter_values in obs_filter.items():
+                         obs_filter_mask = np.isin(
+                             self._get_labels(store, obs_filter_key), obs_filter_values
+                         )
+                         if indices_storage_mask is None:
+                             indices_storage_mask = obs_filter_mask
+                         else:
+                             indices_storage_mask &= obs_filter_mask
+                     indices_storage = np.where(indices_storage_mask)[0]
+                     n_obs_storage = len(indices_storage)
+                 else:
+                     if isinstance(X, ArrayTypes):  # type: ignore
+                         n_obs_storage = X.shape[0]
+                     else:
+                         n_obs_storage = X.attrs["shape"][0]
+                     indices_storage = np.arange(n_obs_storage)
+                 self.n_obs_list.append(n_obs_storage)
+                 self.indices_list.append(indices_storage)
+                 for layer_key in self.layers_keys:
+                     if layer_key == "X":
+                         continue
+                     self._check_csc_raise_error(
+                         store["layers"][layer_key],
+                         f"layers/{layer_key}",
+                         store_path,
+                     )
+                 if self.obsm_keys is not None:
+                     for obsm_key in self.obsm_keys:
+                         self._check_csc_raise_error(
+                             store["obsm"][obsm_key],
+                             f"obsm/{obsm_key}",
+                             store_path,
+                         )
+         self.n_obs = sum(self.n_obs_list)
+
+         self.indices = np.hstack(self.indices_list)
+         self.storage_idx = np.repeat(np.arange(len(self.storages)), self.n_obs_list)
+
+         self.join_vars: Literal["inner", "outer"] | None = join
+         self.var_indices: list | None = None
+         self.var_joint: pd.Index | None = None
+         self.n_vars_list: list | None = None
+         self.var_list: list | None = None
+         self.n_vars: int | None = None
+         if self.join_vars is not None:
+             self._make_join_vars()
+             self.n_vars = len(self.var_joint)
+
+         self._dtype = dtype
+         self._closed = False
+
+     def _make_connections(self, path_list: list, parallel: bool):
+         for path in path_list:
+             path = UPath(path)
+             if path.exists() and path.is_file():  # type: ignore
+                 if parallel:
+                     conn, storage = None, path
+                 else:
+                     conn, storage = registry.open("h5py", path)
+             else:
+                 conn, storage = registry.open("zarr", path)
+             self.conns.append(conn)
+             self.storages.append(storage)
+
+     def _cache_categories(self, obs_keys: list):
+         self._cache_cats = {}
+         for label in obs_keys:
+             self._cache_cats[label] = []
+             for storage in self.storages:
+                 with _Connect(storage) as store:
+                     cats = self._get_categories(store, label)
+                     if cats is not None:
+                         cats = (
+                             _decode(cats) if isinstance(cats[0], bytes) else cats[...]
+                         )
+                     self._cache_cats[label].append(cats)
+
+     def _make_encoders(self, encode_labels: list):
+         for label in encode_labels:
+             cats = self.get_merged_categories(label)
+             encoder = {}
+             if isinstance(self.unknown_label, dict):
+                 unknown_label = self.unknown_label.get(label, None)
+             else:
+                 unknown_label = self.unknown_label
+             if unknown_label is not None and unknown_label in cats:
+                 cats.remove(unknown_label)
+                 encoder[unknown_label] = -1
+             encoder.update({cat: i for i, cat in enumerate(cats)})
+             self.encoders[label] = encoder
+
+     def _read_vars(self):
+         self.var_list = []
+         self.n_vars_list = []
+         for storage in self.storages:
+             with _Connect(storage) as store:
+                 vars = _safer_read_index(store["var"])
+                 self.var_list.append(vars)
+                 self.n_vars_list.append(len(vars))
+
+     def _make_join_vars(self):
+         if self.var_list is None:
+             self._read_vars()
+         vars_eq = all(self.var_list[0].equals(vrs) for vrs in self.var_list[1:])
+         if vars_eq:
+             self.join_vars = None
+             self.var_joint = self.var_list[0]
+             return
+
+         if self.join_vars == "inner":
+             self.var_joint = reduce(pd.Index.intersection, self.var_list)
+             if len(self.var_joint) == 0:
+                 raise ValueError(
+                     "The provided AnnData objects don't have shared variables.\n"
+                     "Use join='outer'."
+                 )
+             self.var_indices = [
+                 vrs.get_indexer(self.var_joint) for vrs in self.var_list
+             ]
+         elif self.join_vars == "outer":
+             self.var_joint = reduce(pd.Index.union, self.var_list)
+             self.var_indices = [
+                 self.var_joint.get_indexer(vrs) for vrs in self.var_list
+             ]
+
+     def check_vars_sorted(self, ascending: bool = True) -> bool:
+         """Returns `True` if all variables are sorted in all objects."""
+         if self.var_list is None:
+             self._read_vars()
+         if ascending:
+             vrs_sort_status = (vrs.is_monotonic_increasing for vrs in self.var_list)
+         else:
+             vrs_sort_status = (vrs.is_monotonic_decreasing for vrs in self.var_list)
+         return all(vrs_sort_status)
+
+     def check_vars_non_aligned(self, vars: pd.Index | list) -> list[int]:
+         """Returns indices of objects with non-aligned variables.
+
+         Args:
+             vars: Check alignment against these variables.
+         """
+         if self.var_list is None:
+             self._read_vars()
+         vars = pd.Index(vars)
+         return [i for i, vrs in enumerate(self.var_list) if not vrs.equals(vars)]
+
+     def _check_csc_raise_error(
+         self, elem: GroupType | ArrayType, key: str, path: UPathStr
+     ):
+         if isinstance(elem, ArrayTypes):  # type: ignore
+             return
+         if get_spec(elem).encoding_type == "csc_matrix":
+             if not self.parallel:
+                 self.close()
+             raise ValueError(
+                 f"{key} in {path} is a csc matrix, `MappedCollection` doesn't support this format yet."
+             )
+
+     def __len__(self):
+         return self.n_obs
+
+     @property
+     def shape(self) -> tuple[int, int]:
+         """Shape of the (virtually aligned) dataset."""
+         return (self.n_obs, self.n_vars)
+
+     @property
+     def original_shapes(self) -> list[tuple[int, int]]:
+         """Shapes of the underlying AnnData objects (with `obs_filter` applied)."""
+         if self.n_vars_list is None:
+             n_vars_list = [None] * len(self.n_obs_list)
+         else:
+             n_vars_list = self.n_vars_list
+         return list(zip(self.n_obs_list, n_vars_list))
+
+     def __getitem__(self, idx: int):
+         obs_idx = self.indices[idx]
+         storage_idx = self.storage_idx[idx]
+         if self.var_indices is not None:
+             var_idxs_join = self.var_indices[storage_idx]
+         else:
+             var_idxs_join = None
+
+         with _Connect(self.storages[storage_idx]) as store:
+             out = {}
+             for layers_key in self.layers_keys:
+                 lazy_data = (
+                     store["X"] if layers_key == "X" else store["layers"][layers_key]
+                 )
+                 out[layers_key] = self._get_data_idx(
+                     lazy_data, obs_idx, self.join_vars, var_idxs_join, self.n_vars
+                 )
+             if self.obsm_keys is not None:
+                 for obsm_key in self.obsm_keys:
+                     lazy_data = store["obsm"][obsm_key]
+                     out[f"obsm_{obsm_key}"] = self._get_data_idx(lazy_data, obs_idx)
+             out["_store_idx"] = storage_idx
+             if self.obs_keys is not None:
+                 for label in self.obs_keys:
+                     if label in self._cache_cats:
+                         cats = self._cache_cats[label][storage_idx]
+                         if cats is None:
+                             cats = []
+                     else:
+                         cats = None
+                     label_idx = self._get_obs_idx(store, obs_idx, label, cats)
+                     if label in self.encoders:
+                         label_idx = self.encoders[label][label_idx]
+                     out[label] = label_idx
+
+             out["is_meta"] = False
+             if (
+                 self.obs_keys is not None
+                 and len(self.meta_assays) > 0
+                 and "assay_ontology_term_id" in self.obs_keys
+             ):
+                 if out["assay_ontology_term_id"] in self.meta_assays:
+                     out["is_meta"] = True
+                     return out
+             if self.metacell_mode > 0:
+                 if np.random.random() < self.metacell_mode:
+                     out["is_meta"] = True
+                     # re-resolve the layer: `lazy_data` may still point at an
+                     # `.obsm` array after the loop above
+                     lazy_data = (
+                         store["X"] if layers_key == "X" else store["layers"][layers_key]
+                     )
+                     distances = self._get_data_idx(store["obsp"]["distances"], obs_idx)
+                     # zero entries (non-neighbors) map to large positive values and
+                     # sort last, so this picks the six nearest neighbors
+                     nn_idx = np.argsort(-1 / (distances - 1e-6))[:6]
+                     for i in nn_idx:
+                         out[layers_key] += self._get_data_idx(
+                             lazy_data, i, self.join_vars, var_idxs_join, self.n_vars
+                         )
+
+             return out
+
+     def _get_data_idx(
+         self,
+         lazy_data: ArrayType | GroupType,
+         idx: int,
+         join_vars: Literal["inner", "outer"] | None = None,
+         var_idxs_join: list | None = None,
+         n_vars_out: int | None = None,
+     ):
+         """Get the data for the given observation index."""
+         if isinstance(lazy_data, ArrayTypes):  # type: ignore
+             lazy_data_idx = lazy_data[idx]  # type: ignore
+             if join_vars is None:
+                 result = lazy_data_idx
+                 if self._dtype is not None:
+                     result = result.astype(self._dtype, copy=False)
+             elif join_vars == "outer":
+                 dtype = lazy_data_idx.dtype if self._dtype is None else self._dtype
+                 result = np.zeros(n_vars_out, dtype=dtype)
+                 result[var_idxs_join] = lazy_data_idx
+             else:  # inner join
+                 result = lazy_data_idx[var_idxs_join]
+                 if self._dtype is not None:
+                     result = result.astype(self._dtype, copy=False)
+             return result
+         else:  # assume csr_matrix here
+             data = lazy_data["data"]  # type: ignore
+             indices = lazy_data["indices"]  # type: ignore
+             indptr = lazy_data["indptr"]  # type: ignore
+             s = slice(*(indptr[idx : idx + 2]))
+             data_s = data[s]
+             dtype = data_s.dtype if self._dtype is None else self._dtype
+             if join_vars == "outer":
+                 lazy_data_idx = np.zeros(n_vars_out, dtype=dtype)
+                 lazy_data_idx[var_idxs_join[indices[s]]] = data_s
+             else:
+                 lazy_data_idx = np.zeros(lazy_data.attrs["shape"][1], dtype=dtype)  # type: ignore
+                 lazy_data_idx[indices[s]] = data_s
+                 if join_vars == "inner":
+                     lazy_data_idx = lazy_data_idx[var_idxs_join]
+             return lazy_data_idx
+
+     def _get_obs_idx(
+         self,
+         storage: StorageType,
+         idx: int,
+         label_key: str,
+         categories: list | None = None,
+     ):
+         """Get the label value at an observation index by key."""
+         obs = storage["obs"]  # type: ignore
+         # how backwards compatible do we want to be here actually?
+         if isinstance(obs, ArrayTypes):  # type: ignore
+             label = obs[idx][obs.dtype.names.index(label_key)]
+         else:
+             labels = obs[label_key]
+             if isinstance(labels, ArrayTypes):  # type: ignore
+                 label = labels[idx]
+             else:
+                 label = labels["codes"][idx]
+         if categories is not None:
+             cats = categories
+         else:
+             cats = self._get_categories(storage, label_key)
+         if cats is not None and len(cats) > 0:
+             label = cats[label]
+         if isinstance(label, bytes):
+             label = label.decode("utf-8")
+         return label
+
+     def get_label_weights(
+         self,
+         obs_keys: str | list[str],
+         scaler: float | None = None,
+         return_categories: bool = False,
+     ):
+         """Get all weights for the given label keys.
+
+         This counts the occurrences of each label and returns a weight for each
+         observation according to the formula `1 / count of this label in the data`.
+         If `scaler` is provided, the weight becomes
+         `(MAX / scaler) / ((1 + count - MIN) + MAX / scaler)`, where `MIN` and `MAX`
+         are the smallest and largest label counts.
+
+         Args:
+             obs_keys: A key in the ``.obs`` slots or a list of keys. If a list is provided,
+                 the labels from the obs keys will be concatenated with the ``"__"`` delimiter.
+             scaler: Use this number to scale the provided weights.
+             return_categories: If `False`, returns weights for each observation,
+                 can be directly passed to a sampler. If `True`, returns a dictionary with
+                 unique categories for labels (concatenated if `obs_keys` is a list)
+                 and their weights.
+         """
+         if isinstance(obs_keys, str):
+             obs_keys = [obs_keys]
+         labels_list = []
+         for label_key in obs_keys:
+             labels_to_str = self.get_merged_labels(label_key).astype(str).astype("O")
+             labels_list.append(labels_to_str)
+         if len(labels_list) > 1:
+             labels = ["__".join(labels_obs) for labels_obs in zip(*labels_list)]
+         else:
+             labels = labels_list[0]
+         counter = Counter(labels)
+         # `Counter.values()` is a plain view without `.min()`/`.max()`, use the builtins
+         MIN, MAX = min(counter.values()), max(counter.values())
+         if return_categories:
+             return {
+                 k: 1.0 / v
+                 if scaler is None
+                 else (MAX / scaler) / ((1 + v - MIN) + MAX / scaler)
+                 for k, v in counter.items()
+             }
+         counts = np.array([counter[label] for label in labels])
+         if scaler is None:
+             weights = 1.0 / counts
+         else:
+             weights = (MAX / scaler) / ((1 + counts - MIN) + MAX / scaler)
+         return weights
+
+     def get_merged_labels(self, label_key: str):
+         """Get merged labels for `label_key` from all `.obs`."""
+         labels_merge = []
+         for i, storage in enumerate(self.storages):
+             with _Connect(storage) as store:
+                 labels = self._get_labels(store, label_key, storage_idx=i)
+                 if self.filtered:
+                     labels = labels[self.indices_list[i]]
+                 labels_merge.append(labels)
+         return np.hstack(labels_merge)
+
+     def get_merged_categories(self, label_key: str):
+         """Get merged categories for `label_key` from all `.obs`."""
+         cats_merge = set()
+         for i, storage in enumerate(self.storages):
+             with _Connect(storage) as store:
+                 if label_key in self._cache_cats:
+                     cats = self._cache_cats[label_key][i]
+                 else:
+                     cats = self._get_categories(store, label_key)
+                 if cats is not None:
+                     cats = _decode(cats) if isinstance(cats[0], bytes) else cats
+                     cats_merge.update(cats)
+                 else:
+                     codes = self._get_codes(store, label_key)
+                     codes = _decode(codes) if isinstance(codes[0], bytes) else codes
+                     cats_merge.update(codes)
+         return sorted(cats_merge)
+
+     def _get_categories(self, storage: StorageType, label_key: str):
+         """Get categories."""
+         obs = storage["obs"]  # type: ignore
+         if isinstance(obs, ArrayTypes):  # type: ignore
+             cat_key_uns = f"{label_key}_categories"
+             if cat_key_uns in storage["uns"]:  # type: ignore
+                 return storage["uns"][cat_key_uns]  # type: ignore
+             else:
+                 return None
+         else:
+             if "__categories" in obs:
+                 cats = obs["__categories"]
+                 if label_key in cats:
+                     return cats[label_key]
+                 else:
+                     return None
+             labels = obs[label_key]
+             if isinstance(labels, GroupTypes):  # type: ignore
+                 if "categories" in labels:
+                     return labels["categories"]
+                 else:
+                     return None
+             else:
+                 if "categories" in labels.attrs:
+                     return labels.attrs["categories"]
+                 else:
+                     return None
+         return None
+
+     def _get_codes(self, storage: StorageType, label_key: str):
+         """Get codes."""
+         obs = storage["obs"]  # type: ignore
+         if isinstance(obs, ArrayTypes):  # type: ignore
+             label = obs[label_key]
+         else:
+             label = obs[label_key]
+         if isinstance(label, ArrayTypes):  # type: ignore
+             return label[...]
+         else:
+             return label["codes"][...]
+
+     def _get_labels(
+         self, storage: StorageType, label_key: str, storage_idx: int | None = None
+     ):
+         """Get labels."""
+         codes = self._get_codes(storage, label_key)
+         labels = _decode(codes) if isinstance(codes[0], bytes) else codes
+         if storage_idx is not None and label_key in self._cache_cats:
+             cats = self._cache_cats[label_key][storage_idx]
+         else:
+             cats = self._get_categories(storage, label_key)
+         if cats is not None:
+             cats = _decode(cats) if isinstance(cats[0], bytes) else cats
+             labels = cats[labels]
+         return labels
+
+     def close(self):
+         """Close connections to array streaming backend.
+
+         No effect if `parallel=True`.
+         """
+         for storage in self.storages:
+             if hasattr(storage, "close"):
+                 storage.close()
+         for conn in self.conns:
+             if hasattr(conn, "close"):
+                 conn.close()
+         self._closed = True
+
+     @property
+     def closed(self) -> bool:
+         """Check if connections to array streaming backend are closed.
+
+         Does not matter if `parallel=True`.
+         """
+         return self._closed
+
+     def __enter__(self):
+         return self
+
+     def __exit__(self, exc_type, exc_val, exc_tb):
+         self.close()
+
+     @staticmethod
+     def torch_worker_init_fn(worker_id):
+         """`worker_init_fn` for `torch.utils.data.DataLoader`.
+
+         Improves performance for `num_workers > 1`.
+         """
+         from torch.utils.data import get_worker_info
+
+         mapped = get_worker_info().dataset
+         mapped.parallel = False
+         mapped.storages = []
+         mapped.conns = []
+         mapped._make_connections(mapped.path_list, parallel=False)
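
For orientation, here is a minimal sketch of how the added `MappedCollection` might be consumed with a PyTorch `DataLoader`, combining `get_label_weights` with a `WeightedRandomSampler` and the class's own `torch_worker_init_fn`. The file paths and obs column names below are hypothetical placeholders, not part of the diff; only the `MappedCollection` API shown above is used.

import torch
from torch.utils.data import DataLoader, WeightedRandomSampler

from scdataloader.mapped import MappedCollection  # module added in this diff

# Hypothetical local .h5ad files and obs columns; adjust to your data.
dataset = MappedCollection(
    path_list=["data/part1.h5ad", "data/part2.h5ad"],
    obs_keys=["cell_type", "assay_ontology_term_id"],
    join="inner",   # virtually align the shared variables
    parallel=True,  # let each DataLoader worker open its own file handles
)

# Inverse-frequency weights over `cell_type`, one weight per observation.
weights = dataset.get_label_weights("cell_type")
sampler = WeightedRandomSampler(torch.as_tensor(weights), num_samples=len(dataset))

loader = DataLoader(
    dataset,
    batch_size=64,
    sampler=sampler,
    num_workers=4,
    worker_init_fn=MappedCollection.torch_worker_init_fn,  # reopen stores per worker
)

batch = next(iter(loader))
batch["X"]           # dense expression rows over the joint variables
batch["cell_type"]   # integer-encoded labels (encode_labels=True by default)
batch["_store_idx"]  # which AnnData object each sample came from

dataset.close()  # release backend connections when done

Passing `parallel=True` together with `torch_worker_init_fn` is the pattern the class itself suggests: connections are deferred in the main process and re-opened inside each worker, avoiding shared h5py handles across processes.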