scdataloader 0.0.4__py3-none-any.whl → 1.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
scdataloader/mapped.py CHANGED
@@ -1,20 +1,26 @@
+ from __future__ import annotations
+
  from collections import Counter
  from functools import reduce
- from os import PathLike
- from typing import List, Literal, Optional, Union
+ from pathlib import Path
+ from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Union

  import numpy as np
  import pandas as pd
  from lamin_utils import logger
- from lamindb.dev._data import _track_run_input
- from lamindb.dev.storage._backed_access import (
+ from lamindb_setup.core.upath import UPath
+ from lamindb.core._data import _track_run_input
+
+ from lamindb.core.storage._backed_access import (
      ArrayTypes,
      GroupTypes,
      StorageType,
      _safer_read_index,
      registry,
  )
- from lamindb_setup.dev.upath import UPath
+
+ if TYPE_CHECKING:
+     from lamindb_setup.core.types import UPathStr


  class _Connect:
@@ -43,50 +49,118 @@ class _Connect:

  def mapped(
      dataset,
+     label_keys: str | list[str] | None = None,
+     join: Literal["inner", "outer"] | None = "inner",
+     encode_labels: bool | list[str] = True,
+     unknown_label: str | dict[str, str] | None = None,
+     cache_categories: bool = True,
+     parallel: bool = False,
+     dtype: str | None = None,
      stream: bool = False,
-     is_run_input: Optional[bool] = None,
-     **kwargs,
- ) -> "MappedDataset":
-     _track_run_input(dataset, is_run_input)
+     is_run_input: bool | None = None,
+ ) -> MappedCollection:
      path_list = []
-     for file in dataset.artifacts.all():
-         if file.suffix not in {".h5ad", ".zrad", ".zarr"}:
-             logger.warning(f"Ignoring file with suffix {file.suffix}")
+     for artifact in dataset.artifacts.all():
+         if artifact.suffix not in {".h5ad", ".zrad", ".zarr"}:
+             logger.warning(f"Ignoring artifact with suffix {artifact.suffix}")
              continue
-         elif not stream and file.suffix == ".h5ad":
-             path_list.append(file.stage())
+         elif not stream:
+             path_list.append(artifact.stage())
          else:
-             path_list.append(file.path)
-     return MappedDataset(path_list, **kwargs)
+             path_list.append(artifact.path)
+     ds = MappedCollection(
+         path_list,
+         label_keys,
+         join,
+         encode_labels,
+         unknown_label,
+         cache_categories,
+         parallel,
+         dtype,
+     )
+     # track only if successful
+     _track_run_input(dataset, is_run_input)
+     return ds


- class MappedDataset:
-     """Map-style dataset for use in data loaders.
+ class MappedCollection:
+     """Map-style collection for use in data loaders.

-     This currently only works for collections of `AnnData` objects.
+     This class virtually concatenates `AnnData` arrays as a `pytorch map-style dataset
+     <https://pytorch.org/docs/stable/data.html#map-style-datasets>`__.

-     For an example, see :meth:`~lamindb.Dataset.mapped`.
+     If your `AnnData` collection is in the cloud, move them into a local cache
+     first for faster access.

      .. note::

-         A similar data loader exists `here
-         <https://github.com/Genentech/scimilarity>`__.
+         For a guide, see :doc:`docs:scrna5`.
+
+         For more convenient use within :class:`~lamindb.core.MappedCollection`,
+         see :meth:`~lamindb.Collection.mapped`.
+
+         This currently only works for collections of `AnnData` objects.
+
+         The implementation was influenced by the `SCimilarity
+         <https://github.com/Genentech/scimilarity>`__ data loader.
+
+
+     Args:
+         path_list: A list of paths to `AnnData` objects stored in `.h5ad` or `.zarr` formats.
+         label_keys: Columns of the ``.obs`` slot that store labels.
+         join: `"inner"` or `"outer"` virtual joins. If ``None`` is passed,
+             does not join.
+         encode_labels: Encode labels into integers.
+             Can be a list with elements from ``label_keys```.
+         unknown_label: Encode this label to -1.
+             Can be a dictionary with keys from ``label_keys`` if ``encode_labels=True```
+             or from ``encode_labels`` if it is a list.
+         cache_categories: Enable caching categories of ``label_keys`` for faster access.
+         parallel: Enable sampling with multiple processes.
+         dtype: Convert numpy arrays from ``.X`` to this dtype on selection.
      """

      def __init__(
          self,
-         path_list: List[Union[str, PathLike]],
-         label_keys: Optional[Union[str, List[str]]] = None,
-         join_vars: Optional[Literal["auto", "inner", "None"]] = "auto",
-         encode_labels: Optional[Union[bool, List[str]]] = False,
+         path_list: list[UPathStr],
+         label_keys: str | list[str] | None = None,
+         join: Literal["inner", "outer", "auto"] | None = "inner",
+         encode_labels: bool | list[str] = True,
+         unknown_label: str | dict[str, str] | None = None,
+         cache_categories: bool = True,
          parallel: bool = False,
-         unknown_class: str = "unknown",
+         dtype: str | None = None,
      ):
-         self.storages = []
-         self.conns = []
+         assert join in {None, "inner", "outer", "auto"}
+
+         label_keys = [label_keys] if isinstance(label_keys, str) else label_keys
+         self.label_keys = label_keys
+
+         if isinstance(encode_labels, list):
+             if len(encode_labels) == 0:
+                 encode_labels = False
+             elif label_keys is None or not all(
+                 enc_label in label_keys for enc_label in encode_labels
+             ):
+                 raise ValueError(
+                     "All elements of `encode_labels` should be in `label_keys`."
+                 )
+         else:
+             if encode_labels:
+                 encode_labels = label_keys if label_keys is not None else False
+         self.encode_labels = encode_labels
+
+         if encode_labels and isinstance(unknown_label, dict):
+             if not all(unkey in encode_labels for unkey in unknown_label):  # type: ignore
+                 raise ValueError(
+                     "All keys of `unknown_label` should be in `encode_labels` and `label_keys`."
+                 )
+         self.unknown_label = unknown_label
+
+         self.storages = []  # type: ignore
+         self.conns = []  # type: ignore
          self.parallel = parallel
-         self.unknown_class = unknown_class
-         self.path_list = path_list
+         self._path_list = path_list
          self._make_connections(path_list, parallel)

          self.n_obs_list = []
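
The reworked `mapped()` helper and the `MappedCollection` constructor above share the same parameter set. A minimal usage sketch, assuming two local `.h5ad` files and a `cell_type` column in `.obs` (both names are hypothetical, chosen only to illustrate the documented arguments):

    # build the map-style collection directly from local AnnData files
    ds = MappedCollection(
        path_list=["part_0.h5ad", "part_1.h5ad"],  # hypothetical local files
        label_keys=["cell_type"],                  # hypothetical .obs column
        join="inner",                              # intersect variables across files
        encode_labels=True,                        # map categories to integers
        unknown_label="unknown",                   # this category is encoded as -1
        dtype="float32",                           # cast .X rows on access
    )
    print(ds.encoders["cell_type"])  # per-label {category: integer} mapping

`dataset.mapped(...)` forwards the same arguments and, per the reordered call above, only tracks the run input after the collection has been constructed successfully.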
@@ -98,7 +172,7 @@ class MappedDataset:
                      if "ensembl_gene_id" in store["var"]
                      else store["var"]["_index"]
                  )
-                 if join_vars == "None":
+                 if join is None:
                      if not all(
                          [
                              i <= j
@@ -118,27 +192,21 @@ class MappedDataset:
          self.indices = np.hstack([np.arange(n_obs) for n_obs in self.n_obs_list])
          self.storage_idx = np.repeat(np.arange(len(self.storages)), self.n_obs_list)

-         self.join_vars = join_vars if len(path_list) > 1 else None
+         self.join_vars = join
          self.var_indices = None
-         if self.join_vars != "None":
+         if self.join_vars is not None:
              self._make_join_vars()

-         self.encode_labels = encode_labels
-         self.label_keys = [label_keys] if isinstance(label_keys, str) else label_keys
-         if isinstance(encode_labels, bool):
-             if encode_labels:
-                 encode_labels = label_keys
+         if self.label_keys is not None:
+             if cache_categories:
+                 self._cache_categories(self.label_keys)
              else:
-                 encode_labels = []
-         if isinstance(encode_labels, list):
-             self.encoders = {}
-             for label in encode_labels:
-                 cats = self.get_merged_categories(label)
-                 self.encoders[label] = {cat: i for i, cat in enumerate(cats)}
-                 if unknown_class in self.encoders[label]:
-                     self.encoders[label][unknown_class] = -1
-         else:
-             self.encoders = {}
+                 self._cache_cats: dict = {}
+         self.encoders: dict = {}
+         if self.encode_labels:
+             self._make_encoders(self.encode_labels)  # type: ignore
+
+         self._dtype = dtype
          self._closed = False

      def _make_connections(self, path_list: list, parallel: bool):
@@ -154,31 +222,63 @@ class MappedDataset:
              self.conns.append(conn)
              self.storages.append(storage)

+     def _cache_categories(self, label_keys: list):
+         self._cache_cats = {}
+         decode = np.frompyfunc(lambda x: x.decode("utf-8"), 1, 1)
+         for label in label_keys:
+             self._cache_cats[label] = []
+             for storage in self.storages:
+                 with _Connect(storage) as store:
+                     cats = self._get_categories(store, label)
+                     if cats is not None:
+                         cats = decode(cats) if isinstance(cats[0], bytes) else cats[...]
+                     self._cache_cats[label].append(cats)
+
+     def _make_encoders(self, encode_labels: list):
+         for label in encode_labels:
+             cats = self.get_merged_categories(label)
+             encoder = {}
+             if isinstance(self.unknown_label, dict):
+                 unknown_label = self.unknown_label.get(label, None)
+             else:
+                 unknown_label = self.unknown_label
+             if unknown_label is not None and unknown_label in cats:
+                 cats.remove(unknown_label)
+                 encoder[unknown_label] = -1
+             cats = list(cats)
+             cats.sort()
+             encoder.update({cat: i for i, cat in enumerate(cats)})
+             self.encoders[label] = encoder
+
      def _make_join_vars(self):
          var_list = []
          for storage in self.storages:
              with _Connect(storage) as store:
                  var_list.append(_safer_read_index(store["var"]))
-         if self.join_vars == "auto":
-             vars_eq = all([var_list[0].equals(vrs) for vrs in var_list[1:]])
-             if vars_eq:
-                 self.join_vars = None
-                 return
-             else:
-                 self.join_vars = "inner"
+
+         self.var_joint = None
+         vars_eq = all(var_list[0].equals(vrs) for vrs in var_list[1:])
+         if vars_eq:
+             self.join_vars = None
+             self.var_joint = var_list[0]
+             return
          if self.join_vars == "inner":
              self.var_joint = reduce(pd.Index.intersection, var_list)
              if len(self.var_joint) == 0:
                  raise ValueError(
-                     "The provided AnnData objects don't have shared varibales."
+                     "The provided AnnData objects don't have shared varibales.\n"
+                     "Use join='outer'."
                  )
              self.var_indices = [vrs.get_indexer(self.var_joint) for vrs in var_list]
+         elif self.join_vars == "outer":
+             self.var_joint = reduce(pd.Index.union, var_list)
+             self.var_indices = [self.var_joint.get_indexer(vrs) for vrs in var_list]

      def _check_aligned_vars(self, vars: list):
          i = 0
          for storage in self.storages:
              with _Connect(storage) as store:
-                 if vars == _safer_read_index(store["var"]).tolist():
+                 if len(set(_safer_read_index(store["var"]).tolist()) - set(vars)) == 0:
                      i += 1
          print("{}% are aligned".format(i * 100 / len(self.storages)))

@@ -189,46 +289,75 @@ class MappedDataset:
          obs_idx = self.indices[idx]
          storage_idx = self.storage_idx[idx]
          if self.var_indices is not None:
-             var_idxs = self.var_indices[storage_idx]
+             var_idxs_join = self.var_indices[storage_idx]
          else:
-             var_idxs = None
+             var_idxs_join = None
+
          with _Connect(self.storages[storage_idx]) as store:
-             out = {"x": self.get_data_idx(store, obs_idx, var_idxs)}
+             out = {"x": self._get_data_idx(store, obs_idx, var_idxs_join)}
+             out["_storage_idx"] = storage_idx
              if self.label_keys is not None:
-                 for _, label in enumerate(self.label_keys):
-                     label_idx = self.get_label_idx(store, obs_idx, label)
-                     if label in self.encoders:
-                         out.update({label: self.encoders[label][label_idx]})
+                 for label in self.label_keys:
+                     if label in self._cache_cats:
+                         cats = self._cache_cats[label][storage_idx]
+                         if cats is None:
+                             cats = []
                      else:
-                         out.update({label: label_idx})
-             out.update({"dataset": storage_idx})
+                         cats = None
+                     label_idx = self._get_label_idx(store, obs_idx, label, cats)
+                     if label in self.encoders:
+                         label_idx = self.encoders[label][label_idx]
+                     out[label] = label_idx
          return out

-     def get_data_idx(
+     def _get_data_idx(
          self,
-         storage: StorageType,
+         storage: StorageType,  # type: ignore
          idx: int,
-         var_idxs: Optional[list] = None,
-         layer_key: Optional[str] = None,  # type: ignore # noqa
+         var_idxs_join: list | None = None,
+         layer_key: str | None = None,
      ):
          """Get the index for the data."""
-         layer = storage["X"] if layer_key is None else storage["layers"][layer_key]  # type: ignore # noqa
+         layer = storage["X"] if layer_key is None else storage["layers"][layer_key]  # type: ignore
          if isinstance(layer, ArrayTypes):  # type: ignore
-             # todo: better way to select variables
-
-             return layer[idx] if var_idxs is None else layer[idx][var_idxs]
+             layer_idx = layer[idx]
+             if self.join_vars is None:
+                 result = layer_idx
+                 if self._dtype is not None:
+                     result = result.astype(self._dtype, copy=False)
+             elif self.join_vars == "outer":
+                 dtype = layer_idx.dtype if self._dtype is None else self._dtype
+                 result = np.zeros(len(self.var_joint), dtype=dtype)
+                 result[var_idxs_join] = layer_idx
+             else:  # inner join
+                 result = layer_idx[var_idxs_join]
+                 if self._dtype is not None:
+                     result = result.astype(self._dtype, copy=False)
+             return result
          else:  # assume csr_matrix here
              data = layer["data"]
              indices = layer["indices"]
              indptr = layer["indptr"]
              s = slice(*(indptr[idx : idx + 2]))
-             # this requires more memory than csr_matrix when var_idxs is not None
-             # but it is faster
-             layer_idx = np.zeros(layer.attrs["shape"][1])
-             layer_idx[indices[s]] = data[s]
-             return layer_idx if var_idxs is None else layer_idx[var_idxs]
+             data_s = data[s]
+             dtype = data_s.dtype if self._dtype is None else self._dtype
+             if self.join_vars == "outer":
+                 layer_idx = np.zeros(len(self.var_joint), dtype=dtype)
+                 layer_idx[var_idxs_join[indices[s]]] = data_s
+             else:
+                 layer_idx = np.zeros(layer.attrs["shape"][1], dtype=dtype)
+                 layer_idx[indices[s]] = data_s
+                 if self.join_vars == "inner":
+                     layer_idx = layer_idx[var_idxs_join]
+             return layer_idx

-     def get_label_idx(self, storage: StorageType, idx: int, label_key: str):  # type: ignore # noqa
+     def _get_label_idx(
+         self,
+         storage: StorageType,
+         idx: int,
+         label_key: str,
+         categories: list | None = None,
+     ):
          """Get the index for the label by key."""
          obs = storage["obs"]  # type: ignore
          # how backwards compatible do we want to be here actually?
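
As the rewritten `__getitem__` above shows, each item is now a plain dict keyed by `"x"`, `"_storage_idx"`, and one entry per label key. A sketch of what a sampled observation might look like; the `cell_type` key is hypothetical and depends on `label_keys`:

    item = ds[42]
    item["x"]             # 1-D numpy array over the (joined) variable space
    item["_storage_idx"]  # integer index of the backing AnnData file
    item["cell_type"]     # encoded integer if the label is in `encoders`, else the raw label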
@@ -240,25 +369,29 @@ class MappedDataset:
              label = labels[idx]
          else:
              label = labels["codes"][idx]
-
-         cats = self.get_categories(storage, label_key)
-         if cats is not None:
+         if categories is not None:
+             cats = categories
+         else:
+             cats = self._get_categories(storage, label_key)
+         if cats is not None and len(cats) > 0:
              label = cats[label]
          if isinstance(label, bytes):
              label = label.decode("utf-8")
          return label

-     def get_label_weights(self, label_keys: Union[str, List[str]], scaler=10):
-         """Get all weights for a given label key."""
-         if type(label_keys) is not list:
+     def get_label_weights(self, label_keys: str | list[str], scaler: int = 10):
+         """Get all weights for the given label keys."""
+         if isinstance(label_keys, str):
              label_keys = [label_keys]
-         for i, val in enumerate(label_keys):
-             if val not in self.label_keys:
-                 raise ValueError(f"{val} is not a valid label key.")
-             if i == 0:
-                 labels = self.get_merged_labels(val)
-             else:
-                 labels += "_" + self.get_merged_labels(val).astype(str).astype("O")
+         labels_list = []
+         for label_key in label_keys:
+             labels_to_str = self.get_merged_labels(label_key).astype(str).astype("O")
+             labels_list.append(labels_to_str)
+         if len(labels_list) > 1:
+             labels = reduce(lambda a, b: a + b, labels_list)
+         else:
+             labels = labels_list[0]
+
          counter = Counter(labels)  # type: ignore
          rn = {n: i for i, n in enumerate(counter.keys())}
          labels = np.array([rn[label] for label in labels])
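
Note that the rewritten `get_label_weights` concatenates the per-key label strings directly via `reduce`, without the `"_"` separator the old loop inserted, and no longer validates the keys against `self.label_keys`. A short sketch of calling it; the column names are hypothetical, and the returned `weights` are assumed to be per-observation (their computation continues past this hunk):

    # joint weighting over two hypothetical .obs columns
    weights, labels = ds.get_label_weights(["cell_type", "tissue"], scaler=10)

Such per-observation weights are commonly passed to a `torch.utils.data.WeightedRandomSampler` to oversample rare classes.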
@@ -267,14 +400,17 @@ class MappedDataset:
          return weights, labels

      def get_merged_labels(self, label_key: str):
-         """Get merged labels."""
+         """Get merged labels for `label_key` from all `.obs`."""
          labels_merge = []
          decode = np.frompyfunc(lambda x: x.decode("utf-8"), 1, 1)
-         for storage in self.storages:
+         for i, storage in enumerate(self.storages):
              with _Connect(storage) as store:
-                 codes = self.get_codes(store, label_key)
+                 codes = self._get_codes(store, label_key)
                  labels = decode(codes) if isinstance(codes[0], bytes) else codes
-                 cats = self.get_categories(store, label_key)
+                 if label_key in self._cache_cats:
+                     cats = self._cache_cats[label_key][i]
+                 else:
+                     cats = self._get_categories(store, label_key)
                  if cats is not None:
                      cats = decode(cats) if isinstance(cats[0], bytes) else cats
                      labels = cats[labels]
@@ -282,22 +418,25 @@ class MappedDataset:
          return np.hstack(labels_merge)

      def get_merged_categories(self, label_key: str):
-         """Get merged categories."""
+         """Get merged categories for `label_key` from all `.obs`."""
          cats_merge = set()
          decode = np.frompyfunc(lambda x: x.decode("utf-8"), 1, 1)
-         for storage in self.storages:
+         for i, storage in enumerate(self.storages):
              with _Connect(storage) as store:
-                 cats = self.get_categories(store, label_key)
+                 if label_key in self._cache_cats:
+                     cats = self._cache_cats[label_key][i]
+                 else:
+                     cats = self._get_categories(store, label_key)
                  if cats is not None:
                      cats = decode(cats) if isinstance(cats[0], bytes) else cats
                      cats_merge.update(cats)
                  else:
-                     codes = self.get_codes(store, label_key)
+                     codes = self._get_codes(store, label_key)
                      codes = decode(codes) if isinstance(codes[0], bytes) else codes
                      cats_merge.update(codes)
          return cats_merge

-     def get_categories(self, storage: StorageType, label_key: str):  # type: ignore
+     def _get_categories(self, storage: StorageType, label_key: str):  # type: ignore
          """Get categories."""
          obs = storage["obs"]  # type: ignore
          if isinstance(obs, ArrayTypes):  # type: ignore
@@ -324,8 +463,9 @@ class MappedDataset:
                      return labels.attrs["categories"]
                  else:
                      return None
+         return None

-     def get_codes(self, storage: StorageType, label_key: str):  # type: ignore
+     def _get_codes(self, storage: StorageType, label_key: str):  # type: ignore
          """Get codes."""
          obs = storage["obs"]  # type: ignore
          if isinstance(obs, ArrayTypes):  # type: ignore
@@ -338,7 +478,10 @@ class MappedDataset:
                  return label["codes"][...]

      def close(self):
-         """Close connection to array streaming backend."""
+         """Close connections to array streaming backend.
+
+         No effect if `parallel=True`.
+         """
          for storage in self.storages:
              if hasattr(storage, "close"):
                  storage.close()
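
The updated docstrings note that `close()` and the `closed` property only matter when `parallel=False`. For that case, the `__enter__`/`__exit__` pair further down supports the context-manager form; a minimal sketch with a hypothetical file name:

    with MappedCollection(["part_0.h5ad"], label_keys="cell_type") as ds:
        first = ds[0]
    # on exit, close() has run and ds.closed is True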
@@ -349,6 +492,10 @@ class MappedDataset:

      @property
      def closed(self):
+         """Check if connections to array streaming backend are closed.
+
+         Does not matter if `parallel=True`.
+         """
          return self._closed

      def __enter__(self):
@@ -356,3 +503,17 @@ class MappedDataset:

      def __exit__(self, exc_type, exc_val, exc_tb):
          self.close()
+
+     @staticmethod
+     def torch_worker_init_fn(worker_id):
+         """`worker_init_fn` for `torch.utils.data.DataLoader`.
+
+         Improves performance for `num_workers > 1`.
+         """
+         from torch.utils.data import get_worker_info
+
+         mapped = get_worker_info().dataset
+         mapped.parallel = False
+         mapped.storages = []
+         mapped.conns = []
+         mapped._make_connections(mapped._path_list, parallel=False)
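
The new `torch_worker_init_fn` re-opens connections inside each worker process instead of sharing handles with the parent. A minimal sketch of wiring it into a `torch.utils.data.DataLoader`; batch size and worker count are arbitrary, and `ds` is a `MappedCollection` as constructed earlier:

    from torch.utils.data import DataLoader

    loader = DataLoader(
        ds,
        batch_size=128,
        num_workers=4,
        worker_init_fn=MappedCollection.torch_worker_init_fn,
    )
    for batch in loader:
        batch["x"]  # default collation stacks the per-item arrays into a tensor
        break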