scdataloader 1.9.1__py3-none-any.whl → 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
scdataloader/data.py CHANGED
@@ -16,7 +16,6 @@ from torch.utils.data import Dataset as torchDataset
 
 from scdataloader.utils import get_ancestry_mapping, load_genes
 
-from .config import LABELS_TOADD
 from .mapped import MappedCollection, _Connect
 
 
@@ -39,19 +38,18 @@ class Dataset(torchDataset):
     ----
         lamin_dataset (lamindb.Dataset): lamin dataset to load
         genedf (pd.Dataframe): dataframe containing the genes to load
-        organisms (list[str]): list of organisms to load
-            (for now only validates that the genes map to this organism)
         obs (list[str]): list of observations to load from the Collection
         clss_to_predict (list[str]): list of observations to encode
         join_vars (flag): join variables @see :meth:`~lamindb.Dataset.mapped`.
         hierarchical_clss: list of observations to map to a hierarchy using lamin's bionty
+        metacell_mode (float, optional): The mode to use for metacell sampling. Defaults to 0.0.
+        get_knn_cells (bool, optional): Whether to get the k-nearest neighbors of each cell. Defaults to False.
+        store_location (str, optional): The location to store the sampler indices. Defaults to None.
+        force_recompute_indices (bool, optional): Whether to force recompute the sampler indices. Defaults to False.
     """
 
     lamin_dataset: ln.Collection
     genedf: Optional[pd.DataFrame] = None
-    organisms: Optional[Union[list[str], str]] = field(
-        default_factory=["NCBITaxon:9606", "NCBITaxon:10090"]
-    )
     # set of obs to prepare for prediction (encode)
     clss_to_predict: Optional[list[str]] = field(default_factory=list)
     # set of obs that need to be hierarchically prepared
@@ -59,6 +57,8 @@ class Dataset(torchDataset):
     join_vars: Literal["inner", "outer"] | None = None
     metacell_mode: float = 0.0
     get_knn_cells: bool = False
+    store_location: str | None = None
+    force_recompute_indices: bool = False
 
     def __post_init__(self):
         self.mapped_dataset = mapped(
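
The two new fields let the sampler indices computed at startup be cached on disk and reused across runs. A minimal usage sketch, assuming a local lamindb instance with an existing Collection; the collection name and cache path are hypothetical:

import lamindb as ln
from scdataloader import Dataset

# hypothetical collection; any lamindb Collection of AnnData artifacts works
col = ln.Collection.filter(name="my-collection").first()
dataset = Dataset(
    lamin_dataset=col,
    clss_to_predict=["organism_ontology_term_id", "cell_type_ontology_term_id"],
    store_location="~/.cache/scdataloader_indices",  # persist sampler indices here
    force_recompute_indices=False,  # set True after the collection changes
)
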
@@ -71,6 +71,8 @@ class Dataset(torchDataset):
             parallel=True,
             metacell_mode=self.metacell_mode,
             get_knn_cells=self.get_knn_cells,
+            store_location=self.store_location,
+            force_recompute_indices=self.force_recompute_indices,
         )
         print(
             "won't do any check but we recommend to have your dataset coming from local storage"
@@ -85,7 +87,7 @@ class Dataset(torchDataset):
             if clss not in self.hierarchical_clss:
                 # otherwise it's already been done
                 self.class_topred[clss] = set(
-                    self.mapped_dataset.get_merged_categories(clss)
+                    self.mapped_dataset.encoders[clss].keys()
                 )
                 if (
                     self.mapped_dataset.unknown_label
@@ -94,12 +96,19 @@ class Dataset(torchDataset):
                     self.class_topred[clss] -= set(
                         [self.mapped_dataset.unknown_label]
                     )
-
         if self.genedf is None:
+            if "organism_ontology_term_id" not in self.clss_to_predict:
+                raise ValueError(
+                    "need 'organism_ontology_term_id' in the set of classes if you don't provide a genedf"
+                )
+            self.organisms = list(self.class_topred["organism_ontology_term_id"])
+            self.organisms.sort()
             self.genedf = load_genes(self.organisms)
+        else:
+            self.organisms = None
 
         self.genedf.columns = self.genedf.columns.astype(str)
-        self.check_aligned_vars()
+        # self.check_aligned_vars()
 
     def check_aligned_vars(self):
         vars = self.genedf.index.tolist()
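
With this change, organisms are no longer a constructor default (previously the human and mouse NCBITaxon IDs) but are inferred from the data: the sorted set of organism_ontology_term_id values known to the encoders is passed to load_genes. A sketch of the equivalent standalone call; the two NCBITaxon IDs shown are the old defaults, used here only for illustration:

from scdataloader.utils import load_genes

# what __post_init__ now does internally when genedf is None:
# collect organisms from the encoded classes, sort them, then load genes
organisms = sorted(["NCBITaxon:9606", "NCBITaxon:10090"])  # human, mouse
genedf = load_genes(organisms)  # gene metadata dataframe covering both organisms
print(genedf.shape)
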
@@ -117,6 +126,10 @@ class Dataset(torchDataset):
     def encoder(self):
         return self.mapped_dataset.encoders
 
+    @encoder.setter
+    def encoder(self, encoder):
+        self.mapped_dataset.encoders = encoder
+
     def __getitem__(self, *args, **kwargs):
         item = self.mapped_dataset.__getitem__(*args, **kwargs)
         return item
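
The new setter turns encoder into a read/write property, so one label-to-code mapping can be pushed into several datasets, for example to keep a train and a validation split encoding classes identically. A toy stand-in for the pattern (the real property forwards to self.mapped_dataset.encoders):

class _EncoderDemo:
    """Toy class mirroring the new read/write encoder property."""

    def __init__(self):
        self._encoders = {"cell_type_ontology_term_id": {"CL:0000540": 0}}

    @property
    def encoder(self):
        return self._encoders

    @encoder.setter
    def encoder(self, encoder):
        self._encoders = encoder


train, val = _EncoderDemo(), _EncoderDemo()
val.encoder = train.encoder  # both splits now share one label -> code mapping
assert val.encoder is train.encoder
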
@@ -132,7 +145,11 @@ class Dataset(torchDataset):
             + " {} genes\n".format(self.genedf.shape[0])
             + " {} clss_to_predict\n".format(len(self.clss_to_predict))
             + " {} hierarchical_clss\n".format(len(self.hierarchical_clss))
-            + " {} organisms\n".format(len(self.organisms))
+            + (
+                " {} organisms\n".format(len(self.organisms))
+                if self.organisms is not None
+                else ""
+            )
             + (
                 "dataset contains {} classes to predict\n".format(
                     sum([len(self.class_topred[i]) for i in self.class_topred])
@@ -148,31 +165,24 @@ class Dataset(torchDataset):
         obs_keys: str | list[str],
         scaler: int = 10,
         return_categories=False,
-        bypass_label=["neuron"],
     ):
         """Get all weights for the given label keys."""
         if isinstance(obs_keys, str):
             obs_keys = [obs_keys]
-        labels_list = []
+        labels = None
         for label_key in obs_keys:
-            labels_to_str = (
-                self.mapped_dataset.get_merged_labels(label_key).astype(str).astype("O")
-            )
-            labels_list.append(labels_to_str)
-        if len(labels_list) > 1:
-            labels = ["___".join(labels_obs) for labels_obs in zip(*labels_list)]
-        else:
-            labels = labels_list[0]
-
-        counter = Counter(labels)  # type: ignore
+            labels_to_str = self.mapped_dataset.get_merged_labels(label_key)
+            if labels is None:
+                labels = labels_to_str
+            else:
+                labels = concat_categorical_codes([labels, labels_to_str])
+        counter = Counter(labels.codes)  # type: ignore
         if return_categories:
-            rn = {n: i for i, n in enumerate(counter.keys())}
-            labels = np.array([rn[label] for label in labels])
             counter = np.array(list(counter.values()))
             weights = scaler / (counter + scaler)
-            return weights, labels
+            return weights, np.array(labels.codes)
         else:
-            counts = np.array([counter[label] for label in labels])
+            counts = np.array([counter[label] for label in labels.codes])
             if scaler is None:
                 weights = 1.0 / counts
             else:
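
get_label_weights now stays in integer-code space instead of concatenating labels as "___"-joined strings: each obs column's categorical codes are fused into one code per cell via concat_categorical_codes (added at the bottom of this file), and weights are derived from code frequencies. A self-contained sketch of the else branch on toy data:

from collections import Counter

import numpy as np
import pandas as pd

# two obs columns, fused the same way concat_categorical_codes does:
# code = code_a * n_categories_b + code_b, then compressed to existing combos
cell_type = pd.Categorical(["T", "T", "B", "T"])
tissue = pd.Categorical(["blood", "blood", "blood", "lung"])
fused = cell_type.codes.astype(np.int32) * len(tissue.categories) + tissue.codes
_, fused = np.unique(fused, return_inverse=True)

counter = Counter(fused)
counts = np.array([counter[c] for c in fused])  # per-cell combination frequency
scaler = 10
weights = scaler / (counts + scaler)  # rarer combinations get larger weights
print(weights)  # approx. [0.83 0.83 0.91 0.91]
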
@@ -267,12 +277,14 @@ class Dataset(torchDataset):
                     clss
                 )
             )
-            cats = set(self.mapped_dataset.get_merged_categories(clss))
-            addition = set(LABELS_TOADD.get(clss, {}).values())
-            cats |= addition
+            cats = set(self.mapped_dataset.encoders[clss].keys())
             groupings, _, leaf_labels = get_ancestry_mapping(cats, parentdf)
             for i, j in groupings.items():
                 if len(j) == 0:
+                    # that should not happen
+                    import pdb
+
+                    pdb.set_trace()
                     groupings.pop(i)
             self.labels_groupings[clss] = groupings
             if clss in self.clss_to_predict:
@@ -287,11 +299,12 @@ class Dataset(torchDataset):
                 )
 
                 for i, v in enumerate(
-                    addition - set(self.mapped_dataset.encoders[clss].keys())
+                    set(groupings.keys())
+                    - set(self.mapped_dataset.encoders[clss].keys())
                 ):
                     self.mapped_dataset.encoders[clss].update({v: mlength + i})
-                # we need to change the ordering so that the things that can't be predicted appear afterward
 
+                # we need to change the ordering so that the things that can't be predicted appear afterward
                 self.class_topred[clss] = leaf_labels
                 c = 0
                 update = {}
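
The hunk ends just as the re-indexing begins (c = 0; update = {}); the comment states the goal: leaf labels that can actually be predicted keep the low codes, while grouping-only ancestors appended above get pushed to the end. A plausible toy reconstruction of that reordering, not the package's exact code:

# encoder where a non-leaf grouping label ("cell") sits between two leaves
encoder = {"CL:0000084": 0, "CL:0000000": 1, "CL:0000540": 2}
leaf_labels = {"CL:0000084", "CL:0000540"}  # only these are predicted

c = 0
tail = len(leaf_labels)
update = {}
for label in encoder:
    if label in leaf_labels:
        update[label] = c  # predictable leaves first
        c += 1
    else:
        update[label] = tail  # non-predictable groupings afterward
        tail += 1
print(update)  # {'CL:0000084': 0, 'CL:0000000': 2, 'CL:0000540': 1}
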
@@ -320,6 +333,7 @@ class SimpleAnnDataset(torchDataset):
         adata: AnnData,
         obs_to_output: Optional[list[str]] = [],
         layer: Optional[str] = None,
+        get_knn_cells: bool = False,
     ):
         """
         SimpleAnnDataset is a simple dataloader for an AnnData dataset. this is to interface nicely with the rest of
@@ -330,31 +344,48 @@ class SimpleAnnDataset(torchDataset):
             adata (anndata.AnnData): anndata object to use
             obs_to_output (list[str]): list of observations to output from anndata.obs
             layer (str): layer of the anndata to use
+            get_knn_cells (bool): whether to get the knn cells
         """
         self.adataX = adata.layers[layer] if layer is not None else adata.X
         self.adataX = self.adataX.toarray() if issparse(self.adataX) else self.adataX
+
         self.obs_to_output = adata.obs[obs_to_output]
+        self.get_knn_cells = get_knn_cells
+        if get_knn_cells and "connectivities" not in adata.obsp:
+            raise ValueError("neighbors key not found in adata.obsm")
+        if get_knn_cells:
+            self.distances = adata.obsp["distances"]
 
     def __len__(self):
         return self.adataX.shape[0]
 
     def __iter__(self):
-        for idx, obs in enumerate(self.adata.obs.itertuples(index=False)):
-            with warnings.catch_warnings():
-                warnings.filterwarnings("ignore", category=DeprecationWarning)
-                out = {"X": self.adataX[idx].reshape(-1)}
-                out.update(
-                    {name: val for name, val in self.obs_to_output.iloc[idx].items()}
-                )
-                yield out
-
-    def __getitem__(self, idx):
-        with warnings.catch_warnings():
-            warnings.filterwarnings("ignore", category=DeprecationWarning)
+        for idx in range(self.adataX.shape[0]):
             out = {"X": self.adataX[idx].reshape(-1)}
             out.update(
                 {name: val for name, val in self.obs_to_output.iloc[idx].items()}
             )
+            if self.get_knn_cells:
+                distances = self.distances[idx].toarray()[0]
+                nn_idx = np.argsort(-1 / (distances - 1e-6))[:6]
+                out["knn_cells"] = np.array(
+                    [self.adataX[i].reshape(-1) for i in nn_idx],
+                    dtype=int,
+                )
+                out["distances"] = distances[nn_idx]
+            yield out
+
+    def __getitem__(self, idx):
+        out = {"X": self.adataX[idx].reshape(-1)}
+        out.update({name: val for name, val in self.obs_to_output.iloc[idx].items()})
+        if self.get_knn_cells:
+            distances = self.distances[idx].toarray()[0]
+            nn_idx = np.argsort(-1 / (distances - 1e-6))[:6]
+            out["knn_cells"] = np.array(
+                [self.adataX[i].reshape(-1) for i in nn_idx],
+                dtype=int,
+            )
+            out["distances"] = distances[nn_idx]
        return out
 
 
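When get_knn_cells=True, SimpleAnnDataset now requires a precomputed neighbor graph in adata.obsp (note the check reads "connectivities" from obsp, while the error message mentions obsm). A usage sketch, assuming scanpy for the neighbor computation and its bundled pbmc68k_reduced demo data:

import scanpy as sc
from scdataloader.data import SimpleAnnDataset

adata = sc.datasets.pbmc68k_reduced()
sc.pp.neighbors(adata, n_neighbors=15)  # fills obsp["distances"] / ["connectivities"]

ds = SimpleAnnDataset(adata, obs_to_output=["bulk_labels"], get_knn_cells=True)
item = ds[0]
print(item["X"].shape)          # expression vector of cell 0
print(item["knn_cells"].shape)  # 6 nearest-neighbor expression profiles
print(item["distances"])        # their distances to cell 0
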
@@ -374,6 +405,8 @@ def mapped(
     metacell_mode: bool = False,
     meta_assays: list[str] = ["EFO:0022857", "EFO:0010961"],
     get_knn_cells: bool = False,
+    store_location: str | None = None,
+    force_recompute_indices: bool = False,
 ) -> MappedCollection:
     path_list = []
     for artifact in dataset.artifacts.all():
@@ -401,5 +434,45 @@ def mapped(
         meta_assays=meta_assays,
         metacell_mode=metacell_mode,
         get_knn_cells=get_knn_cells,
+        store_location=store_location,
+        force_recompute_indices=force_recompute_indices,
     )
     return ds
+
+
+def concat_categorical_codes(series_list: list[pd.Categorical]) -> pd.Categorical:
+    """Efficiently combine multiple categorical data using their codes,
+    only creating categories for combinations that exist in the data.
+
+    Args:
+        series_list: List of pandas Categorical data
+
+    Returns:
+        Combined Categorical with only existing combinations
+    """
+    # Get the codes for each categorical
+    codes_list = [s.codes.astype(np.int32) for s in series_list]
+    n_cats = [len(s.categories) for s in series_list]
+
+    # Calculate combined codes
+    combined_codes = codes_list[0]
+    multiplier = n_cats[0]
+    for codes, n_cat in zip(codes_list[1:], n_cats[1:]):
+        combined_codes = (combined_codes * n_cat) + codes
+        multiplier *= n_cat
+
+    # Find unique combinations that actually exist in the data
+    unique_existing_codes = np.unique(combined_codes)
+
+    # Create a mapping from old codes to new compressed codes
+    code_mapping = {old: new for new, old in enumerate(unique_existing_codes)}
+
+    # Map the combined codes to their new compressed values
+    combined_codes = np.array([code_mapping[code] for code in combined_codes])
+
+    # Create final categorical with only existing combinations
+    return pd.Categorical.from_codes(
+        codes=combined_codes,
+        categories=np.arange(len(unique_existing_codes)),
+        ordered=False,
+    )
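
A usage sketch for the helper above, assuming the function is in scope (it needs numpy as np and pandas as pd): two categoricals over four cells are fused into a single categorical whose codes enumerate only the (organism, cell type) pairs that actually occur:

import numpy as np
import pandas as pd

a = pd.Categorical(["human", "mouse", "human", "human"])
b = pd.Categorical(["T", "T", "B", "T"])
combined = concat_categorical_codes([a, b])
print(combined.codes)            # [1 2 0 1]: one code per existing pair
print(len(combined.categories))  # 3, not 2 * 2: ("mouse", "B") never occurs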