scdataloader 0.0.3__py3-none-any.whl → 1.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
scdataloader/mapped.py CHANGED
@@ -1,20 +1,26 @@
+ from __future__ import annotations
+
  from collections import Counter
  from functools import reduce
- from os import PathLike
- from typing import List, Literal, Optional, Union
+ from pathlib import Path
+ from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Union

  import numpy as np
  import pandas as pd
  from lamin_utils import logger
- from lamindb.dev._data import _track_run_input
- from lamindb.dev.storage._backed_access import (
+ from lamindb_setup.core.upath import UPath
+ from lamindb.core._data import _track_run_input
+
+ from lamindb.core.storage._backed_access import (
      ArrayTypes,
      GroupTypes,
      StorageType,
      _safer_read_index,
      registry,
  )
- from lamindb_setup.dev.upath import UPath
+
+ if TYPE_CHECKING:
+     from lamindb_setup.core.types import UPathStr


  class _Connect:
@@ -43,56 +49,140 @@ class _Connect:

  def mapped(
      dataset,
+     label_keys: str | list[str] | None = None,
+     join: Literal["inner", "outer"] | None = "inner",
+     encode_labels: bool | list[str] = True,
+     unknown_label: str | dict[str, str] | None = None,
+     cache_categories: bool = True,
+     parallel: bool = False,
+     dtype: str | None = None,
      stream: bool = False,
-     is_run_input: Optional[bool] = None,
-     **kwargs,
- ) -> "MappedDataset":
-     _track_run_input(dataset, is_run_input)
+     is_run_input: bool | None = None,
+ ) -> MappedCollection:
      path_list = []
-     for file in dataset.artifacts.all():
-         if file.suffix not in {".h5ad", ".zrad", ".zarr"}:
-             logger.warning(f"Ignoring file with suffix {file.suffix}")
+     for artifact in dataset.artifacts.all():
+         if artifact.suffix not in {".h5ad", ".zrad", ".zarr"}:
+             logger.warning(f"Ignoring artifact with suffix {artifact.suffix}")
              continue
-         elif not stream and file.suffix == ".h5ad":
-             path_list.append(file.stage())
+         elif not stream:
+             path_list.append(artifact.stage())
          else:
-             path_list.append(file.path)
-     return MappedDataset(path_list, **kwargs)
+             path_list.append(artifact.path)
+     ds = MappedCollection(
+         path_list,
+         label_keys,
+         join,
+         encode_labels,
+         unknown_label,
+         cache_categories,
+         parallel,
+         dtype,
+     )
+     # track only if successful
+     _track_run_input(dataset, is_run_input)
+     return ds


- class MappedDataset:
-     """Map-style dataset for use in data loaders.
+ class MappedCollection:
+     """Map-style collection for use in data loaders.

-     This currently only works for collections of `AnnData` objects.
+     This class virtually concatenates `AnnData` arrays as a `pytorch map-style dataset
+     <https://pytorch.org/docs/stable/data.html#map-style-datasets>`__.

-     For an example, see :meth:`~lamindb.Dataset.mapped`.
+     If your `AnnData` collection is in the cloud, move them into a local cache
+     first for faster access.

      .. note::

-         A similar data loader exists `here
-         <https://github.com/Genentech/scimilarity>`__.
+         For a guide, see :doc:`docs:scrna5`.
+
+         For more convenient use within :class:`~lamindb.core.MappedCollection`,
+         see :meth:`~lamindb.Collection.mapped`.
+
+         This currently only works for collections of `AnnData` objects.
+
+         The implementation was influenced by the `SCimilarity
+         <https://github.com/Genentech/scimilarity>`__ data loader.
+
+
+     Args:
+         path_list: A list of paths to `AnnData` objects stored in `.h5ad` or `.zarr` formats.
+         label_keys: Columns of the ``.obs`` slot that store labels.
+         join: `"inner"` or `"outer"` virtual joins. If ``None`` is passed,
+             does not join.
+         encode_labels: Encode labels into integers.
+             Can be a list with elements from ``label_keys```.
+         unknown_label: Encode this label to -1.
+             Can be a dictionary with keys from ``label_keys`` if ``encode_labels=True```
+             or from ``encode_labels`` if it is a list.
+         cache_categories: Enable caching categories of ``label_keys`` for faster access.
+         parallel: Enable sampling with multiple processes.
+         dtype: Convert numpy arrays from ``.X`` to this dtype on selection.
      """

      def __init__(
          self,
-         path_list: List[Union[str, PathLike]],
-         label_keys: Optional[Union[str, List[str]]] = None,
-         join_vars: Optional[Literal["auto", "inner", "None"]] = "auto",
-         encode_labels: Optional[Union[bool, List[str]]] = False,
+         path_list: list[UPathStr],
+         label_keys: str | list[str] | None = None,
+         join: Literal["inner", "outer", "auto"] | None = "inner",
+         encode_labels: bool | list[str] = True,
+         unknown_label: str | dict[str, str] | None = None,
+         cache_categories: bool = True,
          parallel: bool = False,
-         unknown_class: str = "unknown",
+         dtype: str | None = None,
      ):
-         self.storages = []
-         self.conns = []
+         assert join in {None, "inner", "outer", "auto"}
+
+         label_keys = [label_keys] if isinstance(label_keys, str) else label_keys
+         self.label_keys = label_keys
+
+         if isinstance(encode_labels, list):
+             if len(encode_labels) == 0:
+                 encode_labels = False
+             elif label_keys is None or not all(
+                 enc_label in label_keys for enc_label in encode_labels
+             ):
+                 raise ValueError(
+                     "All elements of `encode_labels` should be in `label_keys`."
+                 )
+         else:
+             if encode_labels:
+                 encode_labels = label_keys if label_keys is not None else False
+         self.encode_labels = encode_labels
+
+         if encode_labels and isinstance(unknown_label, dict):
+             if not all(unkey in encode_labels for unkey in unknown_label):  # type: ignore
+                 raise ValueError(
+                     "All keys of `unknown_label` should be in `encode_labels` and `label_keys`."
+                 )
+         self.unknown_label = unknown_label
+
+         self.storages = []  # type: ignore
+         self.conns = []  # type: ignore
          self.parallel = parallel
-         self.unknown_class = unknown_class
-         self.path_list = path_list
+         self._path_list = path_list
          self._make_connections(path_list, parallel)

          self.n_obs_list = []
          for storage in self.storages:
              with _Connect(storage) as store:
                  X = store["X"]
+                 index = (
+                     store["var"]["ensembl_gene_id"]
+                     if "ensembl_gene_id" in store["var"]
+                     else store["var"]["_index"]
+                 )
+                 if join is None:
+                     if not all(
+                         [
+                             i <= j
+                             for i, j in zip(
+                                 index[:99],
+                                 index[1:100],
+                             )
+                         ]
+                     ):
+                         raise ValueError("The variables are not sorted.")
                  if isinstance(X, ArrayTypes):  # type: ignore
                      self.n_obs_list.append(X.shape[0])
                  else:
@@ -102,27 +192,21 @@ class MappedDataset:
          self.indices = np.hstack([np.arange(n_obs) for n_obs in self.n_obs_list])
          self.storage_idx = np.repeat(np.arange(len(self.storages)), self.n_obs_list)

-         self.join_vars = join_vars if len(path_list) > 1 else None
+         self.join_vars = join
          self.var_indices = None
-         if self.join_vars != "None":
+         if self.join_vars is not None:
              self._make_join_vars()

-         self.encode_labels = encode_labels
-         self.label_keys = [label_keys] if isinstance(label_keys, str) else label_keys
-         if isinstance(encode_labels, bool):
-             if encode_labels:
-                 encode_labels = label_keys
+         if self.label_keys is not None:
+             if cache_categories:
+                 self._cache_categories(self.label_keys)
              else:
-                 encode_labels = []
-         if isinstance(encode_labels, list):
-             self.encoders = {}
-             for label in encode_labels:
-                 cats = self.get_merged_categories(label)
-                 self.encoders[label] = {cat: i for i, cat in enumerate(cats)}
-                 if unknown_class in self.encoders[label]:
-                     self.encoders[label][unknown_class] = -1
-         else:
-             self.encoders = {}
+                 self._cache_cats: dict = {}
+         self.encoders: dict = {}
+         if self.encode_labels:
+             self._make_encoders(self.encode_labels)  # type: ignore
+
+         self._dtype = dtype
          self._closed = False

      def _make_connections(self, path_list: list, parallel: bool):
@@ -138,31 +222,63 @@ class MappedDataset:
              self.conns.append(conn)
              self.storages.append(storage)

+     def _cache_categories(self, label_keys: list):
+         self._cache_cats = {}
+         decode = np.frompyfunc(lambda x: x.decode("utf-8"), 1, 1)
+         for label in label_keys:
+             self._cache_cats[label] = []
+             for storage in self.storages:
+                 with _Connect(storage) as store:
+                     cats = self._get_categories(store, label)
+                     if cats is not None:
+                         cats = decode(cats) if isinstance(cats[0], bytes) else cats[...]
+                     self._cache_cats[label].append(cats)
+
+     def _make_encoders(self, encode_labels: list):
+         for label in encode_labels:
+             cats = self.get_merged_categories(label)
+             encoder = {}
+             if isinstance(self.unknown_label, dict):
+                 unknown_label = self.unknown_label.get(label, None)
+             else:
+                 unknown_label = self.unknown_label
+             if unknown_label is not None and unknown_label in cats:
+                 cats.remove(unknown_label)
+                 encoder[unknown_label] = -1
+             cats = list(cats)
+             cats.sort()
+             encoder.update({cat: i for i, cat in enumerate(cats)})
+             self.encoders[label] = encoder
+
      def _make_join_vars(self):
          var_list = []
          for storage in self.storages:
              with _Connect(storage) as store:
                  var_list.append(_safer_read_index(store["var"]))
-         if self.join_vars == "auto":
-             vars_eq = all([var_list[0].equals(vrs) for vrs in var_list[1:]])
-             if vars_eq:
-                 self.join_vars = None
-                 return
-             else:
-                 self.join_vars = "inner"
+
+         self.var_joint = None
+         vars_eq = all(var_list[0].equals(vrs) for vrs in var_list[1:])
+         if vars_eq:
+             self.join_vars = None
+             self.var_joint = var_list[0]
+             return
          if self.join_vars == "inner":
              self.var_joint = reduce(pd.Index.intersection, var_list)
              if len(self.var_joint) == 0:
                  raise ValueError(
-                     "The provided AnnData objects don't have shared varibales."
+                     "The provided AnnData objects don't have shared varibales.\n"
+                     "Use join='outer'."
                  )
              self.var_indices = [vrs.get_indexer(self.var_joint) for vrs in var_list]
+         elif self.join_vars == "outer":
+             self.var_joint = reduce(pd.Index.union, var_list)
+             self.var_indices = [self.var_joint.get_indexer(vrs) for vrs in var_list]

      def _check_aligned_vars(self, vars: list):
          i = 0
          for storage in self.storages:
              with _Connect(storage) as store:
-                 if vars == _safer_read_index(store["var"]).tolist():
+                 if len(set(_safer_read_index(store["var"]).tolist()) - set(vars)) == 0:
                      i += 1
          print("{}% are aligned".format(i * 100 / len(self.storages)))

@@ -173,49 +289,75 @@ class MappedDataset:
          obs_idx = self.indices[idx]
          storage_idx = self.storage_idx[idx]
          if self.var_indices is not None:
-             var_idxs = self.var_indices[storage_idx]
+             var_idxs_join = self.var_indices[storage_idx]
          else:
-             var_idxs = None
+             var_idxs_join = None
+
          with _Connect(self.storages[storage_idx]) as store:
-             out = {"x": self.get_data_idx(store, obs_idx, var_idxs)}
+             out = {"x": self._get_data_idx(store, obs_idx, var_idxs_join)}
+             out["_storage_idx"] = storage_idx
              if self.label_keys is not None:
-                 for i, label in enumerate(self.label_keys):
-                     label_idx = self.get_label_idx(store, obs_idx, label)
-                     if label in self.encoders:
-                         out.update({label: self.encoders[label][label_idx]})
+                 for label in self.label_keys:
+                     if label in self._cache_cats:
+                         cats = self._cache_cats[label][storage_idx]
+                         if cats is None:
+                             cats = []
                      else:
-                         out.update({label: label_idx})
+                         cats = None
+                     label_idx = self._get_label_idx(store, obs_idx, label, cats)
+                     if label in self.encoders:
+                         label_idx = self.encoders[label][label_idx]
+                     out[label] = label_idx
          return out

-     def uns(self, idx, key):
-         storage = self.storages[self.storage_idx[idx]]
-         return storage["uns"][key]
-
-     def get_data_idx(
+     def _get_data_idx(
          self,
-         storage: StorageType,
+         storage: StorageType,  # type: ignore
          idx: int,
-         var_idxs: Optional[list] = None,
-         layer_key: Optional[str] = None,  # type: ignore # noqa
+         var_idxs_join: list | None = None,
+         layer_key: str | None = None,
      ):
          """Get the index for the data."""
-         layer = storage["X"] if layer_key is None else storage["layers"][layer_key]  # type: ignore # noqa
+         layer = storage["X"] if layer_key is None else storage["layers"][layer_key]  # type: ignore
          if isinstance(layer, ArrayTypes):  # type: ignore
-             # todo: better way to select variables
-
-             return layer[idx] if var_idxs is None else layer[idx][var_idxs]
+             layer_idx = layer[idx]
+             if self.join_vars is None:
+                 result = layer_idx
+                 if self._dtype is not None:
+                     result = result.astype(self._dtype, copy=False)
+             elif self.join_vars == "outer":
+                 dtype = layer_idx.dtype if self._dtype is None else self._dtype
+                 result = np.zeros(len(self.var_joint), dtype=dtype)
+                 result[var_idxs_join] = layer_idx
+             else:  # inner join
+                 result = layer_idx[var_idxs_join]
+                 if self._dtype is not None:
+                     result = result.astype(self._dtype, copy=False)
+             return result
          else:  # assume csr_matrix here
              data = layer["data"]
              indices = layer["indices"]
              indptr = layer["indptr"]
              s = slice(*(indptr[idx : idx + 2]))
-             # this requires more memory than csr_matrix when var_idxs is not None
-             # but it is faster
-             layer_idx = np.zeros(layer.attrs["shape"][1])
-             layer_idx[indices[s]] = data[s]
-             return layer_idx if var_idxs is None else layer_idx[var_idxs]
+             data_s = data[s]
+             dtype = data_s.dtype if self._dtype is None else self._dtype
+             if self.join_vars == "outer":
+                 layer_idx = np.zeros(len(self.var_joint), dtype=dtype)
+                 layer_idx[var_idxs_join[indices[s]]] = data_s
+             else:
+                 layer_idx = np.zeros(layer.attrs["shape"][1], dtype=dtype)
+                 layer_idx[indices[s]] = data_s
+                 if self.join_vars == "inner":
+                     layer_idx = layer_idx[var_idxs_join]
+             return layer_idx

-     def get_label_idx(self, storage: StorageType, idx: int, label_key: str):  # type: ignore # noqa
+     def _get_label_idx(
+         self,
+         storage: StorageType,
+         idx: int,
+         label_key: str,
+         categories: list | None = None,
+     ):
          """Get the index for the label by key."""
          obs = storage["obs"]  # type: ignore
          # how backwards compatible do we want to be here actually?
227
369
  label = labels[idx]
228
370
  else:
229
371
  label = labels["codes"][idx]
230
-
231
- cats = self.get_categories(storage, label_key)
232
- if cats is not None:
372
+ if categories is not None:
373
+ cats = categories
374
+ else:
375
+ cats = self._get_categories(storage, label_key)
376
+ if cats is not None and len(cats) > 0:
233
377
  label = cats[label]
234
378
  if isinstance(label, bytes):
235
379
  label = label.decode("utf-8")
236
380
  return label
237
381
 
238
- def get_label_weights(self, label_keys: Union[str, List[str]], scaler=10):
239
- """Get all weights for a given label key."""
240
- if type(label_keys) is not list:
382
+ def get_label_weights(self, label_keys: str | list[str], scaler: int = 10):
383
+ """Get all weights for the given label keys."""
384
+ if isinstance(label_keys, str):
241
385
  label_keys = [label_keys]
242
- for i, val in enumerate(label_keys):
243
- if val not in self.label_keys:
244
- raise ValueError(f"{val} is not a valid label key.")
245
- if i == 0:
246
- labels = self.get_merged_labels(val)
247
- else:
248
- labels += "_" + self.get_merged_labels(val).astype(str).astype("O")
386
+ labels_list = []
387
+ for label_key in label_keys:
388
+ labels_to_str = self.get_merged_labels(label_key).astype(str).astype("O")
389
+ labels_list.append(labels_to_str)
390
+ if len(labels_list) > 1:
391
+ labels = reduce(lambda a, b: a + b, labels_list)
392
+ else:
393
+ labels = labels_list[0]
394
+
249
395
  counter = Counter(labels) # type: ignore
250
- counter = np.array([counter[label] for label in labels])
396
+ rn = {n: i for i, n in enumerate(counter.keys())}
397
+ labels = np.array([rn[label] for label in labels])
398
+ counter = np.array(list(counter.values()))
251
399
  weights = scaler / (counter + scaler)
252
- return weights
400
+ return weights, labels
253
401
 
254
402
  def get_merged_labels(self, label_key: str):
255
- """Get merged labels."""
403
+ """Get merged labels for `label_key` from all `.obs`."""
256
404
  labels_merge = []
257
405
  decode = np.frompyfunc(lambda x: x.decode("utf-8"), 1, 1)
258
- for storage in self.storages:
406
+ for i, storage in enumerate(self.storages):
259
407
  with _Connect(storage) as store:
260
- codes = self.get_codes(store, label_key)
408
+ codes = self._get_codes(store, label_key)
261
409
  labels = decode(codes) if isinstance(codes[0], bytes) else codes
262
- cats = self.get_categories(store, label_key)
410
+ if label_key in self._cache_cats:
411
+ cats = self._cache_cats[label_key][i]
412
+ else:
413
+ cats = self._get_categories(store, label_key)
263
414
  if cats is not None:
264
415
  cats = decode(cats) if isinstance(cats[0], bytes) else cats
265
416
  labels = cats[labels]
@@ -267,22 +418,25 @@ class MappedDataset:
          return np.hstack(labels_merge)

      def get_merged_categories(self, label_key: str):
-         """Get merged categories."""
+         """Get merged categories for `label_key` from all `.obs`."""
          cats_merge = set()
          decode = np.frompyfunc(lambda x: x.decode("utf-8"), 1, 1)
-         for storage in self.storages:
+         for i, storage in enumerate(self.storages):
              with _Connect(storage) as store:
-                 cats = self.get_categories(store, label_key)
+                 if label_key in self._cache_cats:
+                     cats = self._cache_cats[label_key][i]
+                 else:
+                     cats = self._get_categories(store, label_key)
                  if cats is not None:
                      cats = decode(cats) if isinstance(cats[0], bytes) else cats
                      cats_merge.update(cats)
                  else:
-                     codes = self.get_codes(store, label_key)
+                     codes = self._get_codes(store, label_key)
                      codes = decode(codes) if isinstance(codes[0], bytes) else codes
                      cats_merge.update(codes)
          return cats_merge

-     def get_categories(self, storage: StorageType, label_key: str):  # type: ignore
+     def _get_categories(self, storage: StorageType, label_key: str):  # type: ignore
          """Get categories."""
          obs = storage["obs"]  # type: ignore
          if isinstance(obs, ArrayTypes):  # type: ignore
@@ -309,8 +463,9 @@ class MappedDataset:
                      return labels.attrs["categories"]
                  else:
                      return None
+         return None

-     def get_codes(self, storage: StorageType, label_key: str):  # type: ignore
+     def _get_codes(self, storage: StorageType, label_key: str):  # type: ignore
          """Get codes."""
          obs = storage["obs"]  # type: ignore
          if isinstance(obs, ArrayTypes):  # type: ignore
@@ -323,7 +478,10 @@ class MappedDataset:
                  return label["codes"][...]

      def close(self):
-         """Close connection to array streaming backend."""
+         """Close connections to array streaming backend.
+
+         No effect if `parallel=True`.
+         """
          for storage in self.storages:
              if hasattr(storage, "close"):
                  storage.close()
@@ -334,6 +492,10 @@ class MappedDataset:

      @property
      def closed(self):
+         """Check if connections to array streaming backend are closed.
+
+         Does not matter if `parallel=True`.
+         """
          return self._closed

      def __enter__(self):
@@ -341,3 +503,17 @@ class MappedDataset:

      def __exit__(self, exc_type, exc_val, exc_tb):
          self.close()
+
+     @staticmethod
+     def torch_worker_init_fn(worker_id):
+         """`worker_init_fn` for `torch.utils.data.DataLoader`.
+
+         Improves performance for `num_workers > 1`.
+         """
+         from torch.utils.data import get_worker_info
+
+         mapped = get_worker_info().dataset
+         mapped.parallel = False
+         mapped.storages = []
+         mapped.conns = []
+         mapped._make_connections(mapped._path_list, parallel=False)
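
The renamed `MappedCollection` replaces `MappedDataset` and swaps the old `join_vars`/`unknown_class` arguments for `join`, `encode_labels`, `unknown_label`, `cache_categories` and `dtype`. A minimal sketch of the new constructor shown above; the two `.h5ad` paths and the `cell_type` column in `.obs` are placeholders, not files shipped with the package:

    from scdataloader.mapped import MappedCollection

    # "part1.h5ad", "part2.h5ad" and the "cell_type" obs column are illustrative only
    mc = MappedCollection(
        path_list=["part1.h5ad", "part2.h5ad"],
        label_keys="cell_type",   # encoded to integers because encode_labels defaults to True
        join="inner",             # virtual inner join on the shared variables
        unknown_label="unknown",  # this category, if present, is encoded as -1
        dtype="float32",          # rows of .X are cast to this dtype on selection
    )

    item = mc[0]  # dict with "x", "_storage_idx" and one entry per label key
    mc.close()    # close the streaming connections when done

The `mapped()` helper builds the same object from a lamindb collection's artifacts and, per the comment in the diff, only calls `_track_run_input` after construction succeeds.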
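
`get_label_weights` now returns a pair rather than a bare array: per-category weights `scaler / (count + scaler)` plus the integer-encoded label of every observation, so per-observation weights are obtained by indexing the first array with the second. A sketch of feeding that into a sampler, reusing the placeholder `mc` and `cell_type` from the previous snippet:

    import torch
    from torch.utils.data import DataLoader, WeightedRandomSampler

    weights, labels = mc.get_label_weights("cell_type", scaler=10)
    sample_weights = weights[labels]  # one weight per observation
    sampler = WeightedRandomSampler(
        torch.as_tensor(sample_weights),
        num_samples=len(sample_weights),
        replacement=True,
    )
    loader = DataLoader(mc, batch_size=128, sampler=sampler)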
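
For multi-process loading, the new `torch_worker_init_fn` static method rebuilds the storage connections inside each `DataLoader` worker from `_path_list`. A sketch, assuming the collection is built with `parallel=True` (the argument documented above for sampling with multiple processes):

    from torch.utils.data import DataLoader
    from scdataloader.mapped import MappedCollection

    # same placeholder paths and obs column as above
    mc_par = MappedCollection(
        path_list=["part1.h5ad", "part2.h5ad"],
        label_keys="cell_type",
        parallel=True,
    )
    loader = DataLoader(
        mc_par,
        batch_size=128,
        num_workers=2,
        worker_init_fn=mc_par.torch_worker_init_fn,
    )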