scdataloader 0.0.4__py3-none-any.whl → 1.0.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
scdataloader/mapped.py CHANGED
@@ -1,20 +1,26 @@
+ from __future__ import annotations
+
  from collections import Counter
  from functools import reduce
- from os import PathLike
- from typing import List, Literal, Optional, Union
+ from pathlib import Path
+ from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Union

  import numpy as np
  import pandas as pd
- from lamin_utils import logger
- from lamindb.dev._data import _track_run_input
- from lamindb.dev.storage._backed_access import (
+ from lamindb_setup.core.upath import UPath
+
+ from lamindb.core.storage._anndata_accessor import (
+     ArrayType,
      ArrayTypes,
+     GroupType,
      GroupTypes,
      StorageType,
      _safer_read_index,
      registry,
  )
- from lamindb_setup.dev.upath import UPath
+
+ if TYPE_CHECKING:
+     from lamindb_setup.core.types import UPathStr


  class _Connect:
@@ -41,52 +47,108 @@ class _Connect:
              self.conn.close()


- def mapped(
-     dataset,
-     stream: bool = False,
-     is_run_input: Optional[bool] = None,
-     **kwargs,
- ) -> "MappedDataset":
-     _track_run_input(dataset, is_run_input)
-     path_list = []
-     for file in dataset.artifacts.all():
-         if file.suffix not in {".h5ad", ".zrad", ".zarr"}:
-             logger.warning(f"Ignoring file with suffix {file.suffix}")
-             continue
-         elif not stream and file.suffix == ".h5ad":
-             path_list.append(file.stage())
-         else:
-             path_list.append(file.path)
-     return MappedDataset(path_list, **kwargs)
-
+ class MappedCollection:
+     """Map-style collection for use in data loaders.

- class MappedDataset:
-     """Map-style dataset for use in data loaders.
+     This class virtually concatenates `AnnData` arrays as a `pytorch map-style dataset
+     <https://pytorch.org/docs/stable/data.html#map-style-datasets>`__.

-     This currently only works for collections of `AnnData` objects.
+     If your `AnnData` collection is in the cloud, move them into a local cache
+     first for faster access.

-     For an example, see :meth:`~lamindb.Dataset.mapped`.
+     `__getitem__` of the `MappedCollection` object takes a single integer index
+     and returns a dictionary with the observation data sample for this index from
+     the `AnnData` objects in `path_list`. The dictionary has keys for `layers_keys`
+     (`.X` is in `"X"`), `obs_keys`, `obsm_keys` (under `f"obsm_{key}"`) and also `"_store_idx"`
+     for the index of the `AnnData` object containing this observation sample.

      .. note::

-         A similar data loader exists `here
-         <https://github.com/Genentech/scimilarity>`__.
+         For a guide, see :doc:`docs:scrna5`.
+
+         For more convenient use within :class:`~lamindb.core.MappedCollection`,
+         see :meth:`~lamindb.Collection.mapped`.
+
+         This currently only works for collections of `AnnData` objects.
+
+         The implementation was influenced by the `SCimilarity
+         <https://github.com/Genentech/scimilarity>`__ data loader.
+
+
+     Args:
+         path_list: A list of paths to `AnnData` objects stored in `.h5ad` or `.zarr` formats.
+         layers_keys: Keys from the ``.layers`` slot. ``layers_keys=None`` or ``"X"`` in the list
+             retrieves ``.X``.
+         obsm_keys: Keys from the ``.obsm`` slots.
+         obs_keys: Keys from the ``.obs`` slots.
+         join: `"inner"` or `"outer"` virtual joins. If ``None`` is passed,
+             does not join.
+         encode_labels: Encode labels into integers.
+             Can be a list with elements from ``obs_keys``.
+         unknown_label: Encode this label to -1.
+             Can be a dictionary with keys from ``obs_keys`` if ``encode_labels=True``
+             or from ``encode_labels`` if it is a list.
+         cache_categories: Enable caching categories of ``obs_keys`` for faster access.
+         parallel: Enable sampling with multiple processes.
+         dtype: Convert numpy arrays from ``.X``, ``.layers`` and ``.obsm`` to this dtype.
      """

      def __init__(
          self,
-         path_list: List[Union[str, PathLike]],
-         label_keys: Optional[Union[str, List[str]]] = None,
-         join_vars: Optional[Literal["auto", "inner", "None"]] = "auto",
-         encode_labels: Optional[Union[bool, List[str]]] = False,
+         path_list: list[UPathStr],
+         layers_keys: str | list[str] | None = None,
+         obs_keys: str | list[str] | None = None,
+         obsm_keys: str | list[str] | None = None,
+         join: Literal["inner", "outer"] | None = "inner",
+         encode_labels: bool | list[str] = True,
+         unknown_label: str | dict[str, str] | None = None,
+         cache_categories: bool = True,
          parallel: bool = False,
-         unknown_class: str = "unknown",
+         dtype: str | None = None,
      ):
-         self.storages = []
-         self.conns = []
+         if join not in {None, "inner", "outer"}:  # pragma: nocover
+             raise ValueError(
+                 f"join must be one of None, 'inner', or 'outer' but was {type(join)}"
+             )
+
+         if layers_keys is None:
+             self.layers_keys = ["X"]
+         else:
+             self.layers_keys = (
+                 [layers_keys] if isinstance(layers_keys, str) else layers_keys
+             )
+
+         obsm_keys = [obsm_keys] if isinstance(obsm_keys, str) else obsm_keys
+         self.obsm_keys = obsm_keys
+
+         obs_keys = [obs_keys] if isinstance(obs_keys, str) else obs_keys
+         self.obs_keys = obs_keys
+
+         if isinstance(encode_labels, list):
+             if len(encode_labels) == 0:
+                 encode_labels = False
+             elif obs_keys is None or not all(
+                 enc_label in obs_keys for enc_label in encode_labels
+             ):
+                 raise ValueError(
+                     "All elements of `encode_labels` should be in `obs_keys`."
+                 )
+         else:
+             if encode_labels:
+                 encode_labels = obs_keys if obs_keys is not None else False
+         self.encode_labels = encode_labels
+
+         if encode_labels and isinstance(unknown_label, dict):
+             if not all(unkey in encode_labels for unkey in unknown_label):  # type: ignore
+                 raise ValueError(
+                     "All keys of `unknown_label` should be in `encode_labels` and `obs_keys`."
+                 )
+         self.unknown_label = unknown_label
+
+         self.storages = []  # type: ignore
+         self.conns = []  # type: ignore
          self.parallel = parallel
-         self.unknown_class = unknown_class
-         self.path_list = path_list
+         self._path_list = path_list
          self._make_connections(path_list, parallel)

          self.n_obs_list = []
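
A minimal usage sketch of the new `MappedCollection` API introduced above (the file names and the `cell_type` column are placeholder assumptions, not part of the package):

    from scdataloader.mapped import MappedCollection

    ds = MappedCollection(
        path_list=["adata1.h5ad", "adata2.h5ad"],  # hypothetical local files
        obs_keys="cell_type",        # labels returned with each sample
        join="inner",                # intersect variables across the AnnData objects
        encode_labels=True,          # map categories to integers
        unknown_label="unknown",     # this category is encoded as -1
    )
    sample = ds[0]
    # sample["X"] is the expression row, sample["cell_type"] the encoded label,
    # sample["_store_idx"] the index of the AnnData object it came from
    print(len(ds), sample["X"].shape, sample["cell_type"], sample["_store_idx"])
    ds.close()
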
@@ -98,7 +160,7 @@ class MappedDataset:
                      if "ensembl_gene_id" in store["var"]
                      else store["var"]["_index"]
                  )
-                 if join_vars == "None":
+                 if join is None:
                      if not all(
                          [
                              i <= j
@@ -118,27 +180,25 @@ class MappedDataset:
          self.indices = np.hstack([np.arange(n_obs) for n_obs in self.n_obs_list])
          self.storage_idx = np.repeat(np.arange(len(self.storages)), self.n_obs_list)

-         self.join_vars = join_vars if len(path_list) > 1 else None
+         self.join_vars = join
          self.var_indices = None
-         if self.join_vars != "None":
+         self.var_joint = None
+         self.n_vars_list = None
+         self.n_vars = None
+         if self.join_vars is not None:
              self._make_join_vars()
+             self.n_vars = len(self.var_joint)

-         self.encode_labels = encode_labels
-         self.label_keys = [label_keys] if isinstance(label_keys, str) else label_keys
-         if isinstance(encode_labels, bool):
-             if encode_labels:
-                 encode_labels = label_keys
+         if self.obs_keys is not None:
+             if cache_categories:
+                 self._cache_categories(self.obs_keys)
              else:
-                 encode_labels = []
-         if isinstance(encode_labels, list):
-             self.encoders = {}
-             for label in encode_labels:
-                 cats = self.get_merged_categories(label)
-                 self.encoders[label] = {cat: i for i, cat in enumerate(cats)}
-                 if unknown_class in self.encoders[label]:
-                     self.encoders[label][unknown_class] = -1
-         else:
-             self.encoders = {}
+                 self._cache_cats: dict = {}
+             self.encoders: dict = {}
+             if self.encode_labels:
+                 self._make_encoders(self.encode_labels)  # type: ignore
+
+         self._dtype = dtype
          self._closed = False

      def _make_connections(self, path_list: list, parallel: bool):
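
The `np.hstack`/`np.repeat` bookkeeping kept in this hunk is what maps a flat sample index onto a (store, row) pair; a small standalone illustration with made-up sizes:

    import numpy as np

    n_obs_list = [3, 2]  # two stores with 3 and 2 observations
    indices = np.hstack([np.arange(n) for n in n_obs_list])          # [0 1 2 0 1]
    storage_idx = np.repeat(np.arange(len(n_obs_list)), n_obs_list)  # [0 0 0 1 1]

    idx = 4  # global index into the virtual collection
    print(storage_idx[idx], indices[idx])  # -> 1 1: row 1 of the second store
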
@@ -154,81 +214,171 @@ class MappedDataset:
              self.conns.append(conn)
              self.storages.append(storage)

+     def _cache_categories(self, obs_keys: list):
+         self._cache_cats = {}
+         decode = np.frompyfunc(lambda x: x.decode("utf-8"), 1, 1)
+         for label in obs_keys:
+             self._cache_cats[label] = []
+             for storage in self.storages:
+                 with _Connect(storage) as store:
+                     cats = self._get_categories(store, label)
+                     if cats is not None:
+                         cats = decode(cats) if isinstance(cats[0], bytes) else cats[...]
+                     self._cache_cats[label].append(cats)
+
+     def _make_encoders(self, encode_labels: list):
+         for label in encode_labels:
+             cats = self.get_merged_categories(label)
+             encoder = {}
+             if isinstance(self.unknown_label, dict):
+                 unknown_label = self.unknown_label.get(label, None)
+             else:
+                 unknown_label = self.unknown_label
+             if unknown_label is not None and unknown_label in cats:
+                 cats.remove(unknown_label)
+                 encoder[unknown_label] = -1
+             cats = list(cats)
+             cats.sort()
+             encoder.update({cat: i for i, cat in enumerate(cats)})
+             self.encoders[label] = encoder
+
      def _make_join_vars(self):
          var_list = []
+         self.n_vars_list = []
          for storage in self.storages:
              with _Connect(storage) as store:
-                 var_list.append(_safer_read_index(store["var"]))
-         if self.join_vars == "auto":
-             vars_eq = all([var_list[0].equals(vrs) for vrs in var_list[1:]])
-             if vars_eq:
-                 self.join_vars = None
-                 return
-             else:
-                 self.join_vars = "inner"
+                 vars = _safer_read_index(store["var"])
+                 var_list.append(vars)
+                 self.n_vars_list.append(len(vars))
+
+         self.var_joint = None
+         vars_eq = all(var_list[0].equals(vrs) for vrs in var_list[1:])
+         if vars_eq:
+             self.join_vars = None
+             self.var_joint = var_list[0]
+             return
+
          if self.join_vars == "inner":
              self.var_joint = reduce(pd.Index.intersection, var_list)
              if len(self.var_joint) == 0:
                  raise ValueError(
-                     "The provided AnnData objects don't have shared variables."
+                     "The provided AnnData objects don't have shared variables.\n"
+                     "Use join='outer'."
                  )
              self.var_indices = [vrs.get_indexer(self.var_joint) for vrs in var_list]
+         elif self.join_vars == "outer":
+             self.var_joint = reduce(pd.Index.union, var_list)
+             self.var_indices = [self.var_joint.get_indexer(vrs) for vrs in var_list]

      def _check_aligned_vars(self, vars: list):
          i = 0
          for storage in self.storages:
              with _Connect(storage) as store:
-                 if vars == _safer_read_index(store["var"]).tolist():
+                 if len(set(_safer_read_index(store["var"]).tolist()) - set(vars)) == 0:
                      i += 1
          print("{}% are aligned".format(i * 100 / len(self.storages)))

      def __len__(self):
          return self.n_obs

+     @property
+     def shape(self):
+         """Shape of the (virtually aligned) dataset."""
+         return (self.n_obs, self.n_vars)
+
+     @property
+     def original_shapes(self):
+         """Shapes of the underlying AnnData objects."""
+         if self.n_vars_list is None:
+             n_vars_list = [None] * len(self.n_obs_list)
+         else:
+             n_vars_list = self.n_vars_list
+         return list(zip(self.n_obs_list, n_vars_list))
+
      def __getitem__(self, idx: int):
          obs_idx = self.indices[idx]
          storage_idx = self.storage_idx[idx]
          if self.var_indices is not None:
-             var_idxs = self.var_indices[storage_idx]
+             var_idxs_join = self.var_indices[storage_idx]
          else:
-             var_idxs = None
+             var_idxs_join = None
+
          with _Connect(self.storages[storage_idx]) as store:
-             out = {"x": self.get_data_idx(store, obs_idx, var_idxs)}
-             if self.label_keys is not None:
-                 for _, label in enumerate(self.label_keys):
-                     label_idx = self.get_label_idx(store, obs_idx, label)
-                     if label in self.encoders:
-                         out.update({label: self.encoders[label][label_idx]})
+             out = {}
+             for layers_key in self.layers_keys:
+                 lazy_data = (
+                     store["X"] if layers_key == "X" else store["layers"][layers_key]
+                 )
+                 out[layers_key] = self._get_data_idx(
+                     lazy_data, obs_idx, self.join_vars, var_idxs_join, self.n_vars
+                 )
+             if self.obsm_keys is not None:
+                 for obsm_key in self.obsm_keys:
+                     lazy_data = store["obsm"][obsm_key]
+                     out[f"obsm_{obsm_key}"] = self._get_data_idx(lazy_data, obs_idx)
+             out["_store_idx"] = storage_idx
+             if self.obs_keys is not None:
+                 for label in self.obs_keys:
+                     if label in self._cache_cats:
+                         cats = self._cache_cats[label][storage_idx]
+                         if cats is None:
+                             cats = []
                      else:
-                         out.update({label: label_idx})
-             out.update({"dataset": storage_idx})
+                         cats = None
+                     label_idx = self._get_obs_idx(store, obs_idx, label, cats)
+                     if label in self.encoders:
+                         label_idx = self.encoders[label][label_idx]
+                     out[label] = label_idx
          return out

-     def get_data_idx(
+     def _get_data_idx(
          self,
-         storage: StorageType,
+         lazy_data: ArrayType | GroupType,  # type: ignore
          idx: int,
-         var_idxs: Optional[list] = None,
-         layer_key: Optional[str] = None,  # type: ignore # noqa
+         join_vars: Literal["inner", "outer"] | None = None,
+         var_idxs_join: list | None = None,
+         n_vars_out: int | None = None,
      ):
          """Get the index for the data."""
-         layer = storage["X"] if layer_key is None else storage["layers"][layer_key]  # type: ignore # noqa
-         if isinstance(layer, ArrayTypes):  # type: ignore
-             # todo: better way to select variables
-
-             return layer[idx] if var_idxs is None else layer[idx][var_idxs]
+         if isinstance(lazy_data, ArrayTypes):  # type: ignore
+             lazy_data_idx = lazy_data[idx]  # type: ignore
+             if join_vars is None:
+                 result = lazy_data_idx
+                 if self._dtype is not None:
+                     result = result.astype(self._dtype, copy=False)
+             elif join_vars == "outer":
+                 dtype = lazy_data_idx.dtype if self._dtype is None else self._dtype
+                 result = np.zeros(n_vars_out, dtype=dtype)
+                 result[var_idxs_join] = lazy_data_idx
+             else:  # inner join
+                 result = lazy_data_idx[var_idxs_join]
+                 if self._dtype is not None:
+                     result = result.astype(self._dtype, copy=False)
+             return result
          else:  # assume csr_matrix here
-             data = layer["data"]
-             indices = layer["indices"]
-             indptr = layer["indptr"]
+             data = lazy_data["data"]  # type: ignore
+             indices = lazy_data["indices"]  # type: ignore
+             indptr = lazy_data["indptr"]  # type: ignore
              s = slice(*(indptr[idx : idx + 2]))
-             # this requires more memory than csr_matrix when var_idxs is not None
-             # but it is faster
-             layer_idx = np.zeros(layer.attrs["shape"][1])
-             layer_idx[indices[s]] = data[s]
-             return layer_idx if var_idxs is None else layer_idx[var_idxs]
+             data_s = data[s]
+             dtype = data_s.dtype if self._dtype is None else self._dtype
+             if join_vars == "outer":
+                 lazy_data_idx = np.zeros(n_vars_out, dtype=dtype)
+                 lazy_data_idx[var_idxs_join[indices[s]]] = data_s
+             else:
+                 lazy_data_idx = np.zeros(lazy_data.attrs["shape"][1], dtype=dtype)  # type: ignore
+                 lazy_data_idx[indices[s]] = data_s
+                 if join_vars == "inner":
+                     lazy_data_idx = lazy_data_idx[var_idxs_join]
+             return lazy_data_idx

-     def get_label_idx(self, storage: StorageType, idx: int, label_key: str):  # type: ignore # noqa
+     def _get_obs_idx(
+         self,
+         storage: StorageType,
+         idx: int,
+         label_key: str,
+         categories: list | None = None,
+     ):
          """Get the index for the label by key."""
          obs = storage["obs"]  # type: ignore
          # how backwards compatible do we want to be here actually?
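
The sparse branch of `_get_data_idx` above reconstructs one dense row from the CSR components stored on disk; a standalone sketch of that access pattern, using made-up data and assuming the usual `data`/`indices`/`indptr` CSR layout:

    import numpy as np

    # CSR pieces of a 2 x 5 matrix: row 0 = [0, 7, 0, 0, 3], row 1 = [0, 0, 5, 0, 0]
    data = np.array([7.0, 3.0, 5.0])
    indices = np.array([1, 4, 2])
    indptr = np.array([0, 2, 3])
    n_vars = 5

    idx = 0
    s = slice(*indptr[idx : idx + 2])   # slice of stored values for this row
    row = np.zeros(n_vars, dtype=data.dtype)
    row[indices[s]] = data[s]           # scatter the non-zeros into a dense row
    print(row)                          # [0. 7. 0. 0. 3.]
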
@@ -240,25 +390,29 @@ class MappedDataset:
                  label = labels[idx]
              else:
                  label = labels["codes"][idx]
-
-         cats = self.get_categories(storage, label_key)
-         if cats is not None:
+         if categories is not None:
+             cats = categories
+         else:
+             cats = self._get_categories(storage, label_key)
+         if cats is not None and len(cats) > 0:
              label = cats[label]
          if isinstance(label, bytes):
              label = label.decode("utf-8")
          return label

-     def get_label_weights(self, label_keys: Union[str, List[str]], scaler=10):
-         """Get all weights for a given label key."""
-         if type(label_keys) is not list:
-             label_keys = [label_keys]
-         for i, val in enumerate(label_keys):
-             if val not in self.label_keys:
-                 raise ValueError(f"{val} is not a valid label key.")
-             if i == 0:
-                 labels = self.get_merged_labels(val)
-             else:
-                 labels += "_" + self.get_merged_labels(val).astype(str).astype("O")
+     def get_label_weights(self, obs_keys: str | list[str], scaler: int = 10):
+         """Get all weights for the given label keys."""
+         if isinstance(obs_keys, str):
+             obs_keys = [obs_keys]
+         labels_list = []
+         for label_key in obs_keys:
+             labels_to_str = self.get_merged_labels(label_key).astype(str).astype("O")
+             labels_list.append(labels_to_str)
+         if len(labels_list) > 1:
+             labels = reduce(lambda a, b: a + b, labels_list)
+         else:
+             labels = labels_list[0]
+
          counter = Counter(labels)  # type: ignore
          rn = {n: i for i, n in enumerate(counter.keys())}
          labels = np.array([rn[label] for label in labels])
@@ -267,14 +421,17 @@ class MappedDataset:
          return weights, labels

      def get_merged_labels(self, label_key: str):
-         """Get merged labels."""
+         """Get merged labels for `label_key` from all `.obs`."""
          labels_merge = []
          decode = np.frompyfunc(lambda x: x.decode("utf-8"), 1, 1)
-         for storage in self.storages:
+         for i, storage in enumerate(self.storages):
              with _Connect(storage) as store:
-                 codes = self.get_codes(store, label_key)
+                 codes = self._get_codes(store, label_key)
                  labels = decode(codes) if isinstance(codes[0], bytes) else codes
-                 cats = self.get_categories(store, label_key)
+                 if label_key in self._cache_cats:
+                     cats = self._cache_cats[label_key][i]
+                 else:
+                     cats = self._get_categories(store, label_key)
                  if cats is not None:
                      cats = decode(cats) if isinstance(cats[0], bytes) else cats
                      labels = cats[labels]
@@ -282,22 +439,25 @@ class MappedDataset:
          return np.hstack(labels_merge)

      def get_merged_categories(self, label_key: str):
-         """Get merged categories."""
+         """Get merged categories for `label_key` from all `.obs`."""
          cats_merge = set()
          decode = np.frompyfunc(lambda x: x.decode("utf-8"), 1, 1)
-         for storage in self.storages:
+         for i, storage in enumerate(self.storages):
              with _Connect(storage) as store:
-                 cats = self.get_categories(store, label_key)
+                 if label_key in self._cache_cats:
+                     cats = self._cache_cats[label_key][i]
+                 else:
+                     cats = self._get_categories(store, label_key)
                  if cats is not None:
                      cats = decode(cats) if isinstance(cats[0], bytes) else cats
                      cats_merge.update(cats)
                  else:
-                     codes = self.get_codes(store, label_key)
+                     codes = self._get_codes(store, label_key)
                      codes = decode(codes) if isinstance(codes[0], bytes) else codes
                      cats_merge.update(codes)
          return cats_merge

-     def get_categories(self, storage: StorageType, label_key: str):  # type: ignore
+     def _get_categories(self, storage: StorageType, label_key: str):  # type: ignore
          """Get categories."""
          obs = storage["obs"]  # type: ignore
          if isinstance(obs, ArrayTypes):  # type: ignore
@@ -324,8 +484,9 @@ class MappedDataset:
                  return labels.attrs["categories"]
          else:
              return None
+         return None

-     def get_codes(self, storage: StorageType, label_key: str):  # type: ignore
+     def _get_codes(self, storage: StorageType, label_key: str):  # type: ignore
          """Get codes."""
          obs = storage["obs"]  # type: ignore
          if isinstance(obs, ArrayTypes):  # type: ignore
@@ -338,7 +499,10 @@ class MappedDataset:
                  return label["codes"][...]

      def close(self):
-         """Close connection to array streaming backend."""
+         """Close connections to array streaming backend.
+
+         No effect if `parallel=True`.
+         """
          for storage in self.storages:
              if hasattr(storage, "close"):
                  storage.close()
@@ -349,6 +513,10 @@ class MappedDataset:

      @property
      def closed(self):
+         """Check if connections to array streaming backend are closed.
+
+         Does not matter if `parallel=True`.
+         """
          return self._closed

      def __enter__(self):
@@ -356,3 +524,17 @@ class MappedDataset:
          return self

      def __exit__(self, exc_type, exc_val, exc_tb):
          self.close()
+
+     @staticmethod
+     def torch_worker_init_fn(worker_id):
+         """`worker_init_fn` for `torch.utils.data.DataLoader`.
+
+         Improves performance for `num_workers > 1`.
+         """
+         from torch.utils.data import get_worker_info
+
+         mapped = get_worker_info().dataset
+         mapped.parallel = False
+         mapped.storages = []
+         mapped.conns = []
+         mapped._make_connections(mapped._path_list, parallel=False)
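
A sketch of how the new `parallel` flag and `torch_worker_init_fn` are meant to be combined with a PyTorch `DataLoader` (the file paths and the `cell_type` key are placeholders; the weighted-sampling lines assume the first value returned by `get_label_weights` is a per-class weight array aligned with the integer codes in the second value):

    from torch.utils.data import DataLoader, WeightedRandomSampler
    from scdataloader.mapped import MappedCollection

    ds = MappedCollection(
        path_list=["adata1.h5ad", "adata2.h5ad"],  # hypothetical local files
        obs_keys="cell_type",
        parallel=True,  # allow sampling from multiple worker processes
    )

    # assumption: per-sample weights come from indexing the per-class weights
    # with the integer label codes returned as the second value
    weights, labels = ds.get_label_weights("cell_type")
    sampler = WeightedRandomSampler(weights[labels], num_samples=len(ds))

    loader = DataLoader(
        ds,
        batch_size=128,
        sampler=sampler,
        num_workers=2,
        worker_init_fn=MappedCollection.torch_worker_init_fn,  # re-open connections per worker
    )
    for batch in loader:
        x, y = batch["X"], batch["cell_type"]
        break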