scdataloader 1.9.2__py3-none-any.whl → 2.0.2__py3-none-any.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registry.
scdataloader/data.py CHANGED
@@ -2,7 +2,7 @@ import warnings
 from collections import Counter
 from dataclasses import dataclass, field
 from functools import reduce
-from typing import Literal, Optional, Union
+from typing import List, Literal, Optional, Union
 
 # ln.connect("scprint")
 import bionty as bt
@@ -16,7 +16,6 @@ from torch.utils.data import Dataset as torchDataset
 
 from scdataloader.utils import get_ancestry_mapping, load_genes
 
-from .config import LABELS_TOADD
 from .mapped import MappedCollection, _Connect
 
 
@@ -39,28 +38,30 @@ class Dataset(torchDataset):
     ----
         lamin_dataset (lamindb.Dataset): lamin dataset to load
         genedf (pd.Dataframe): dataframe containing the genes to load
-        organisms (list[str]): list of organisms to load
-            (for now only validates the the genes map to this organism)
-        obs (list[str]): list of observations to load from the Collection
-        clss_to_predict (list[str]): list of observations to encode
+        obs (List[str]): list of observations to load from the Collection
+        clss_to_predict (List[str]): list of observations to encode
         join_vars (flag): join variables @see :meth:`~lamindb.Dataset.mapped`.
         hierarchical_clss: list of observations to map to a hierarchy using lamin's bionty
+        metacell_mode (float, optional): The mode to use for metacell sampling. Defaults to 0.0.
+        get_knn_cells (bool, optional): Whether to get the k-nearest neighbors of each cell. Defaults to False.
+        store_location (str, optional): The location to store the sampler indices. Defaults to None.
+        force_recompute_indices (bool, optional): Whether to force recompute the sampler indices. Defaults to False.
     """
 
     lamin_dataset: ln.Collection
     genedf: Optional[pd.DataFrame] = None
-    organisms: Optional[Union[list[str], str]] = field(
-        default_factory=["NCBITaxon:9606", "NCBITaxon:10090"]
-    )
     # set of obs to prepare for prediction (encode)
-    clss_to_predict: Optional[list[str]] = field(default_factory=list)
+    clss_to_predict: Optional[List[str]] = field(default_factory=list)
     # set of obs that need to be hierarchically prepared
-    hierarchical_clss: Optional[list[str]] = field(default_factory=list)
+    hierarchical_clss: Optional[List[str]] = field(default_factory=list)
     join_vars: Literal["inner", "outer"] | None = None
     metacell_mode: float = 0.0
     get_knn_cells: bool = False
+    store_location: str | None = None
+    force_recompute_indices: bool = False
 
     def __post_init__(self):
+        # see at the end of the file for the mapped function
         self.mapped_dataset = mapped(
             self.lamin_dataset,
             obs_keys=list(set(self.hierarchical_clss + self.clss_to_predict)),
@@ -71,6 +72,8 @@ class Dataset(torchDataset):
             parallel=True,
             metacell_mode=self.metacell_mode,
             get_knn_cells=self.get_knn_cells,
+            store_location=self.store_location,
+            force_recompute_indices=self.force_recompute_indices,
         )
         print(
             "won't do any check but we recommend to have your dataset coming from local storage"
@@ -85,7 +88,7 @@ class Dataset(torchDataset):
             if clss not in self.hierarchical_clss:
                 # otherwise it's already been done
                 self.class_topred[clss] = set(
-                    self.mapped_dataset.get_merged_categories(clss)
+                    self.mapped_dataset.encoders[clss].keys()
                 )
                 if (
                     self.mapped_dataset.unknown_label
@@ -94,12 +97,19 @@ class Dataset(torchDataset):
                     self.class_topred[clss] -= set(
                         [self.mapped_dataset.unknown_label]
                     )
-
         if self.genedf is None:
+            if "organism_ontology_term_id" not in self.clss_to_predict:
+                raise ValueError(
+                    "need 'organism_ontology_term_id' in the set of classes if you don't provide a genedf"
+                )
+            self.organisms = list(self.class_topred["organism_ontology_term_id"])
             self.genedf = load_genes(self.organisms)
+        else:
+            self.organisms = self.genedf["organism"].unique().tolist()
+            self.organisms.sort()
 
         self.genedf.columns = self.genedf.columns.astype(str)
-        self.check_aligned_vars()
+        # self.check_aligned_vars()
 
     def check_aligned_vars(self):
         vars = self.genedf.index.tolist()
@@ -117,6 +127,10 @@ class Dataset(torchDataset):
     def encoder(self):
         return self.mapped_dataset.encoders
 
+    @encoder.setter
+    def encoder(self, encoder):
+        self.mapped_dataset.encoders = encoder
+
     def __getitem__(self, *args, **kwargs):
         item = self.mapped_dataset.__getitem__(*args, **kwargs)
         return item
@@ -132,7 +146,11 @@ class Dataset(torchDataset):
             + " {} genes\n".format(self.genedf.shape[0])
             + " {} clss_to_predict\n".format(len(self.clss_to_predict))
             + " {} hierarchical_clss\n".format(len(self.hierarchical_clss))
-            + " {} organisms\n".format(len(self.organisms))
+            + (
+                " {} organisms\n".format(len(self.organisms))
+                if self.organisms is not None
+                else ""
+            )
             + (
                 "dataset contains {} classes to predict\n".format(
                     sum([len(self.class_topred[i]) for i in self.class_topred])
@@ -143,41 +161,21 @@ class Dataset(torchDataset):
             + " {} metacell_mode\n".format(self.metacell_mode)
         )
 
-    def get_label_weights(
+    def get_label_cats(
         self,
-        obs_keys: str | list[str],
-        scaler: int = 10,
-        return_categories=False,
-        bypass_label=["neuron"],
+        obs_keys: Union[str, List[str]],
     ):
-        """Get all weights for the given label keys."""
+        """Get all categories for the given label keys."""
         if isinstance(obs_keys, str):
             obs_keys = [obs_keys]
-        labels_list = []
+        labels = None
         for label_key in obs_keys:
-            labels_to_str = (
-                self.mapped_dataset.get_merged_labels(label_key).astype(str).astype("O")
-            )
-            labels_list.append(labels_to_str)
-        if len(labels_list) > 1:
-            labels = ["___".join(labels_obs) for labels_obs in zip(*labels_list)]
-        else:
-            labels = labels_list[0]
-
-        counter = Counter(labels)  # type: ignore
-        if return_categories:
-            rn = {n: i for i, n in enumerate(counter.keys())}
-            labels = np.array([rn[label] for label in labels])
-            counter = np.array(list(counter.values()))
-            weights = scaler / (counter + scaler)
-            return weights, labels
-        else:
-            counts = np.array([counter[label] for label in labels])
-            if scaler is None:
-                weights = 1.0 / counts
+            labels_to_str = self.mapped_dataset.get_merged_labels(label_key)
+            if labels is None:
+                labels = labels_to_str
             else:
-                weights = scaler / (counts + scaler)
-            return weights
+                labels = concat_categorical_codes([labels, labels_to_str])
+        return np.array(labels.codes)
 
     def get_unseen_mapped_dataset_elements(self, idx: int):
         """
@@ -187,16 +185,16 @@ class Dataset(torchDataset):
             idx (int): index of the element to get
 
         Returns:
-            list[str]: list of unseen genes
+            List[str]: list of unseen genes
         """
         return [str(i)[2:-1] for i in self.mapped_dataset.uns(idx, "unseen_genes")]
 
-    def define_hierarchies(self, clsses: list[str]):
+    def define_hierarchies(self, clsses: List[str]):
         """
        define_hierarchies is a method to define the hierarchies for the classes to predict

        Args:
-            clsses (list[str]): list of classes to predict
+            clsses (List[str]): list of classes to predict

        Raises:
            ValueError: if the class is not in the accepted classes
@@ -223,19 +221,19 @@ class Dataset(torchDataset):
         elif clss == "cell_type_ontology_term_id":
             parentdf = (
                 bt.CellType.filter()
-                .df(include=["parents__ontology_id"])
+                .df(include=["parents__ontology_id", "ontology_id"])
                 .set_index("ontology_id")
             )
         elif clss == "tissue_ontology_term_id":
             parentdf = (
                 bt.Tissue.filter()
-                .df(include=["parents__ontology_id"])
+                .df(include=["parents__ontology_id", "ontology_id"])
                 .set_index("ontology_id")
             )
         elif clss == "disease_ontology_term_id":
             parentdf = (
                 bt.Disease.filter()
-                .df(include=["parents__ontology_id"])
+                .df(include=["parents__ontology_id", "ontology_id"])
                 .set_index("ontology_id")
             )
         elif clss in [
@@ -245,19 +243,19 @@ class Dataset(torchDataset):
         ]:
             parentdf = (
                 bt.DevelopmentalStage.filter()
-                .df(include=["parents__ontology_id"])
+                .df(include=["parents__ontology_id", "ontology_id"])
                 .set_index("ontology_id")
             )
         elif clss == "assay_ontology_term_id":
             parentdf = (
                 bt.ExperimentalFactor.filter()
-                .df(include=["parents__ontology_id"])
+                .df(include=["parents__ontology_id", "ontology_id"])
                 .set_index("ontology_id")
             )
         elif clss == "self_reported_ethnicity_ontology_term_id":
             parentdf = (
                 bt.Ethnicity.filter()
-                .df(include=["parents__ontology_id"])
+                .df(include=["parents__ontology_id", "ontology_id"])
                 .set_index("ontology_id")
             )
 
@@ -267,13 +265,17 @@ class Dataset(torchDataset):
                     clss
                 )
             )
-        cats = set(self.mapped_dataset.get_merged_categories(clss))
-        addition = set(LABELS_TOADD.get(clss, {}).values())
-        cats |= addition
+        cats = set(self.mapped_dataset.encoders[clss].keys())
         groupings, _, leaf_labels = get_ancestry_mapping(cats, parentdf)
+        groupings.pop(None, None)
         for i, j in groupings.items():
             if len(j) == 0:
+                # that should not happen
+                import pdb
+
+                pdb.set_trace()
                 groupings.pop(i)
+
         self.labels_groupings[clss] = groupings
         if clss in self.clss_to_predict:
             # if we have added new clss, we need to update the encoder with them too.
@@ -287,11 +289,12 @@ class Dataset(torchDataset):
                )
 
                for i, v in enumerate(
-                    addition - set(self.mapped_dataset.encoders[clss].keys())
+                    set(groupings.keys())
+                    - set(self.mapped_dataset.encoders[clss].keys())
                ):
                    self.mapped_dataset.encoders[clss].update({v: mlength + i})
-                # we need to change the ordering so that the things that can't be predicted appear afterward
 
+            # we need to change the ordering so that the things that can't be predicted appear afterward
            self.class_topred[clss] = leaf_labels
            c = 0
            update = {}
@@ -318,8 +321,10 @@ class SimpleAnnDataset(torchDataset):
     def __init__(
         self,
         adata: AnnData,
-        obs_to_output: Optional[list[str]] = [],
+        obs_to_output: Optional[List[str]] = [],
         layer: Optional[str] = None,
+        get_knn_cells: bool = False,
+        encoder: Optional[dict[str, dict]] = None,
     ):
         """
         SimpleAnnDataset is a simple dataloader for an AnnData dataset. this is to interface nicely with the rest of
@@ -328,43 +333,53 @@ class SimpleAnnDataset(torchDataset):
         Args:
         ----
             adata (anndata.AnnData): anndata object to use
-            obs_to_output (list[str]): list of observations to output from anndata.obs
+            obs_to_output (List[str]): list of observations to output from anndata.obs
             layer (str): layer of the anndata to use
+            get_knn_cells (bool): whether to get the knn cells
+            encoder (dict[str, dict]): dictionary of encoders for the observations.
         """
         self.adataX = adata.layers[layer] if layer is not None else adata.X
         self.adataX = self.adataX.toarray() if issparse(self.adataX) else self.adataX
+        self.encoder = encoder if encoder is not None else {}
+
         self.obs_to_output = adata.obs[obs_to_output]
+        self.get_knn_cells = get_knn_cells
+        if get_knn_cells and "connectivities" not in adata.obsp:
+            raise ValueError("neighbors key not found in adata.obsm")
+        if get_knn_cells:
+            self.distances = adata.obsp["distances"]
 
     def __len__(self):
        return self.adataX.shape[0]

    def __iter__(self):
-        for idx, obs in enumerate(self.adata.obs.itertuples(index=False)):
-            with warnings.catch_warnings():
-                warnings.filterwarnings("ignore", category=DeprecationWarning)
-                out = {"X": self.adataX[idx].reshape(-1)}
-                out.update(
-                    {name: val for name, val in self.obs_to_output.iloc[idx].items()}
-                )
-                yield out
+        for idx in range(self.adataX.shape[0]):
+            out = self.__getitem__(idx)
+            yield out

    def __getitem__(self, idx):
-        with warnings.catch_warnings():
-            warnings.filterwarnings("ignore", category=DeprecationWarning)
-            out = {"X": self.adataX[idx].reshape(-1)}
-            out.update(
-                {name: val for name, val in self.obs_to_output.iloc[idx].items()}
+        out = {"X": self.adataX[idx].reshape(-1)}
+        # put the observation into the output and encode if needed
+        for name, val in self.obs_to_output.iloc[idx].items():
+            out.update({name: self.encoder[name][val] if name in self.encoder else val})
+        if self.get_knn_cells:
+            distances = self.distances[idx].toarray()[0]
+            nn_idx = np.argsort(-1 / (distances - 1e-6))[:6]
+            out["knn_cells"] = np.array(
+                [self.adataX[i].reshape(-1) for i in nn_idx],
+                dtype=int,
             )
+            out["knn_cells_info"] = distances[nn_idx]
         return out
 
 
 def mapped(
     dataset,
-    obs_keys: list[str] | None = None,
-    obsm_keys: list[str] | None = None,
+    obs_keys: List[str] | None = None,
+    obsm_keys: List[str] | None = None,
     obs_filter: dict[str, str | tuple[str, ...]] | None = None,
     join: Literal["inner", "outer"] | None = "inner",
-    encode_labels: bool | list[str] = True,
+    encode_labels: bool | List[str] = True,
     unknown_label: str | dict[str, str] | None = None,
     cache_categories: bool = True,
     parallel: bool = False,
@@ -372,8 +387,10 @@ def mapped(
     stream: bool = False,
     is_run_input: bool | None = None,
     metacell_mode: bool = False,
-    meta_assays: list[str] = ["EFO:0022857", "EFO:0010961"],
+    meta_assays: List[str] = ["EFO:0022857", "EFO:0010961"],
     get_knn_cells: bool = False,
+    store_location: str | None = None,
+    force_recompute_indices: bool = False,
 ) -> MappedCollection:
     path_list = []
     for artifact in dataset.artifacts.all():
@@ -401,5 +418,45 @@ def mapped(
         meta_assays=meta_assays,
         metacell_mode=metacell_mode,
         get_knn_cells=get_knn_cells,
+        store_location=store_location,
+        force_recompute_indices=force_recompute_indices,
     )
     return ds
+
+
+def concat_categorical_codes(series_list: List[pd.Categorical]) -> pd.Categorical:
+    """Efficiently combine multiple categorical data using their codes,
+    only creating categories for combinations that exist in the data.
+
+    Args:
+        series_list: List of pandas Categorical data
+
+    Returns:
+        Combined Categorical with only existing combinations
+    """
+    # Get the codes for each categorical
+    codes_list = [s.codes.astype(np.int32) for s in series_list]
+    n_cats = [len(s.categories) for s in series_list]
+
+    # Calculate combined codes
+    combined_codes = codes_list[0]
+    multiplier = n_cats[0]
+    for codes, n_cat in zip(codes_list[1:], n_cats[1:]):
+        combined_codes = (combined_codes * n_cat) + codes
+        multiplier *= n_cat
+
+    # Find unique combinations that actually exist in the data
+    unique_existing_codes = np.unique(combined_codes)
+
+    # Create a mapping from old codes to new compressed codes
+    code_mapping = {old: new for new, old in enumerate(unique_existing_codes)}
+
+    # Map the combined codes to their new compressed values
+    combined_codes = np.array([code_mapping[code] for code in combined_codes])
+
+    # Create final categorical with only existing combinations
+    return pd.Categorical.from_codes(
+        codes=combined_codes,
+        categories=np.arange(len(unique_existing_codes)),
+        ordered=False,
+    )