scdataloader 1.9.2__py3-none-any.whl → 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
scdataloader/mapped.py CHANGED
@@ -7,12 +7,14 @@
 
 from __future__ import annotations
 
+import os
 from collections import Counter
 from functools import reduce
 from typing import TYPE_CHECKING, Literal
 
 import numpy as np
 import pandas as pd
+import torch
 from lamindb.core.storage._anndata_accessor import (
     ArrayType,
     ArrayTypes,
@@ -24,10 +26,13 @@ from lamindb.core.storage._anndata_accessor import (
     registry,
 )
 from lamindb_setup.core.upath import UPath
+from tqdm import tqdm
 
 if TYPE_CHECKING:
     from lamindb_setup.core.types import UPathStr
 
+from pandas.api.types import union_categoricals
+
 
 class _Connect:
     def __init__(self, storage):
@@ -106,6 +111,8 @@ class MappedCollection:
         meta_assays: Assays that are already defined as metacells.
         metacell_mode: frequency at which to sample a metacell (an average of k-nearest neighbors).
         get_knn_cells: Whether to also dataload the k-nearest neighbors of each queried cells.
+        store_location: Path to a directory where klass_indices can be cached, or full path to the cache file.
+        force_recompute_indices: If True, recompute indices even if a cache file exists.
     """
 
     def __init__(
@@ -124,6 +131,8 @@ class MappedCollection:
         metacell_mode: float = 0.0,
         get_knn_cells: bool = False,
         meta_assays: list[str] = ["EFO:0022857", "EFO:0010961"],
+        store_location: str | None = None,
+        force_recompute_indices: bool = False,
     ):
         if join not in {None, "inner", "outer"}:  # pragma: nocover
             raise ValueError(
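
For orientation, the new `store_location` / `force_recompute_indices` options cache the per-dataset category arrays with `torch.save` and reload them on later runs. A minimal usage sketch — only the two new arguments come from this diff; the paths and the other arguments shown are illustrative:

    from scdataloader.mapped import MappedCollection

    # Illustrative only: paths and obs_keys are placeholders.
    collection = MappedCollection(
        path_list=["data/a.h5ad", "data/b.h5ad"],
        obs_keys=["cell_type_ontology_term_id"],
        cache_categories=True,
        store_location="cache/mapped",    # categories cached at cache/mapped/categories
        force_recompute_indices=False,    # set True to ignore an existing cache file
    )
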
@@ -181,14 +190,26 @@ class MappedCollection:
         self._cache_cats: dict = {}
         if self.obs_keys is not None:
             if cache_categories:
-                self._cache_categories(self.obs_keys)
+                if store_location is not None:
+                    os.makedirs(store_location, exist_ok=True)
+                    self.store_location = os.path.join(store_location, "categories")
+                    if (
+                        not os.path.exists(self.store_location)
+                        or force_recompute_indices
+                    ):
+                        self._cache_categories(self.obs_keys)
+                        torch.save(self._cache_cats, self.store_location)
+                    else:
+                        self._cache_cats = torch.load(self.store_location)
+                        print(f"Loaded categories from {self.store_location}")
         self.encoders: dict = {}
         if self.encode_labels:
             self._make_encoders(self.encode_labels)  # type: ignore
-
         self.n_obs_list = []
         self.indices_list = []
-        for i, storage in enumerate(self.storages):
+        for i, storage in tqdm(
+            enumerate(self.storages), desc="Checking datasets", total=len(self.storages)
+        ):
             with _Connect(storage) as store:
                 X = store["X"]
                 store_path = self.path_list[i]
@@ -263,13 +284,10 @@ class MappedCollection:
         self._cache_cats = {}
         for label in obs_keys:
             self._cache_cats[label] = []
-            for storage in self.storages:
+            for storage in tqdm(self.storages, f"caching categories, {label}"):
                 with _Connect(storage) as store:
                     cats = self._get_categories(store, label)
-                    if cats is not None:
-                        cats = (
-                            _decode(cats) if isinstance(cats[0], bytes) else cats[...]
-                        )
+                    cats = _decode(cats) if isinstance(cats[0], bytes) else cats[...]
                     self._cache_cats[label].append(cats)
 
     def _make_encoders(self, encode_labels: list):
@@ -403,8 +421,21 @@ class MappedCollection:
                        cats = None
                    label_idx = self._get_obs_idx(store, obs_idx, label, cats)
                    if label in self.encoders:
-                        label_idx = self.encoders[label][label_idx]
-                    out[label] = label_idx
+                        try:
+                            label_idx = self.encoders[label][label_idx]
+                        except:
+                            print(self.storages[storage_idx])
+                            print(label, label_idx)
+                            print(idx)
+                            print(cats)
+                            raise
+                    try:
+                        out[label] = label_idx
+                    except:
+                        print(self.storages[storage_idx])
+                        print(label, label_idx)
+                        print(out)
+                        raise
 
            if self.metacell_mode > 0:
                if (
@@ -555,21 +586,41 @@ class MappedCollection:
         weights = (MAX / scaler) / ((1 + counts - MIN) + MAX / scaler)
         return weights
 
-    def get_merged_labels(self, label_key: str):
+    def get_merged_labels(self, label_key: str, is_cat: bool = True):
         """Get merged labels for `label_key` from all `.obs`."""
         labels_merge = []
-        for i, storage in enumerate(self.storages):
+        for i, storage in tqdm(
+            enumerate(self.storages), label_key, total=len(self.storages)
+        ):
             with _Connect(storage) as store:
-                labels = self._get_labels(store, label_key, storage_idx=i)
+                labels = self._get_labels(
+                    store, label_key, storage_idx=i, is_cat=is_cat
+                )
                 if self.filtered:
                     labels = labels[self.indices_list[i]]
                 labels_merge.append(labels)
-        return np.hstack(labels_merge)
+        if is_cat:
+            try:
+                return union_categoricals(labels_merge)
+            except TypeError:
+                typ = type(int)
+                for i in range(len(labels_merge)):
+                    if typ != type(labels_merge[i][0]):
+                        self.storages[i]
+                        typ = type(labels_merge[i][0])
+                return []
+        else:
+            print("concatenating labels")
+            return np.concatenate(labels_merge)
 
     def get_merged_categories(self, label_key: str):
         """Get merged categories for `label_key` from all `.obs`."""
         cats_merge = set()
-        for i, storage in enumerate(self.storages):
+        for i, storage in tqdm(
+            enumerate(self.storages),
+            total=len(self.storages),
+            desc="merging all " + label_key + " categories",
+        ):
             with _Connect(storage) as store:
                 if label_key in self._cache_cats:
                     cats = self._cache_cats[label_key][i]
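
`get_merged_labels` now builds one pandas Categorical per dataset and merges them with `union_categoricals`, which unions the category sets instead of materialising one dense object array. A standalone toy illustration (data made up, not from the package):

    import pandas as pd
    from pandas.api.types import union_categoricals

    # Each per-dataset label array becomes a Categorical, then the pieces are unioned.
    a = pd.Categorical(["T cell", "B cell"])
    b = pd.Categorical(["B cell", "NK cell"])
    merged = union_categoricals([a, b])
    print(merged.categories.tolist())  # ['T cell', 'B cell', 'NK cell']
    print(list(merged))                # ['T cell', 'B cell', 'B cell', 'NK cell']
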
@@ -609,8 +660,8 @@ class MappedCollection:
            else:
                if "categories" in labels.attrs:
                    return labels.attrs["categories"]
-                else:
-                    return None
+                elif labels.dtype == "bool":
+                    return np.array(["True", "False"])
        return None
 
    def _get_codes(self, storage: StorageType, label_key: str):
@@ -626,11 +677,17 @@ class MappedCollection:
            return label["codes"][...]
 
    def _get_labels(
-        self, storage: StorageType, label_key: str, storage_idx: int | None = None
+        self,
+        storage: StorageType,
+        label_key: str,
+        storage_idx: int | None = None,
+        is_cat: bool = True,
    ):
        """Get labels."""
        codes = self._get_codes(storage, label_key)
        labels = _decode(codes) if isinstance(codes[0], bytes) else codes
+        if labels.dtype == bool:
+            labels = labels.astype(int)
        if storage_idx is not None and label_key in self._cache_cats:
            cats = self._cache_cats[label_key][storage_idx]
        else:
@@ -638,6 +695,8 @@ class MappedCollection:
        if cats is not None:
            cats = _decode(cats) if isinstance(cats[0], bytes) else cats
            labels = cats[labels]
+        if is_cat:
+            labels = pd.Categorical(labels.astype(str))
        return labels
 
    def close(self):
@@ -1,4 +1,5 @@
 import gc
+import time
 from typing import Callable, Optional, Union
 from uuid import uuid4
 
@@ -8,6 +9,7 @@ import numpy as np
 import pandas as pd
 import scanpy as sc
 from anndata import AnnData, read_h5ad
+from django.db.utils import OperationalError
 from scipy.sparse import csr_matrix
 from upath import UPath
 
@@ -61,6 +63,7 @@ class Preprocessor:
        organisms: list[str] = ["NCBITaxon:9606", "NCBITaxon:10090"],
        use_raw: bool = True,
        keepdata: bool = False,
+        drop_non_primary: bool = False,
    ) -> None:
        """
        Initializes the preprocessor and configures the workflow steps.
@@ -108,6 +111,8 @@ class Preprocessor:
                Defaults to False.
            keepdata (bool, optional): Determines whether to keep the data in the AnnData object.
                Defaults to False.
+            drop_non_primary (bool, optional): Determines whether to drop non-primary cells.
+                Defaults to False.
        """
        self.filter_gene_by_counts = filter_gene_by_counts
        self.filter_cell_by_counts = filter_cell_by_counts
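
A hypothetical call using the new flag; the import path and the other arguments are assumptions, only `drop_non_primary` and the `dataset_id` keyword appear in this diff:

    import scanpy as sc

    from scdataloader.preprocess import Preprocessor  # assumed import path

    pp = Preprocessor(drop_non_primary=True)  # drop cells with is_primary_data == False
    adata = sc.read_h5ad("my_dataset.h5ad")   # placeholder file
    adata = pp(adata, dataset_id="my_dataset")
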
@@ -123,6 +128,7 @@ class Preprocessor:
        self.min_valid_genes_id = min_valid_genes_id
        self.min_nnz_genes = min_nnz_genes
        self.maxdropamount = maxdropamount
+        self.drop_non_primary = drop_non_primary
        self.madoutlier = madoutlier
        self.n_hvg_for_postp = n_hvg_for_postp
        self.pct_mt_outlier = pct_mt_outlier
@@ -142,10 +148,6 @@ class Preprocessor:
            raise ValueError(
                "organism_ontology_term_id not found in adata.obs, you need to add an ontology term id for the organism of your anndata"
            )
-        if not adata[0].var.index.str.contains("ENS").any() and not self.is_symbol:
-            raise ValueError(
-                "gene names in the `var.index` field of your anndata should map to the ensembl_gene nomenclature else set `is_symbol` to True if using hugo symbols"
-            )
        if adata.obs["organism_ontology_term_id"].iloc[0] not in self.organisms:
            raise ValueError(
                "we cannot work with this organism",
@@ -161,8 +163,8 @@ class Preprocessor:
            if np.abs(adata[:50_000].X.astype(int) - adata[:50_000].X).sum():
                print("X was not raw counts, using 'counts' layer")
                adata.X = adata.layers["counts"].copy()
-        print("Dropping layers: ", adata.layers.keys())
        if not self.keepdata:
+            print("Dropping layers: ", adata.layers.keys())
            del adata.layers
        if len(adata.varm.keys()) > 0 and not self.keepdata:
            del adata.varm
@@ -170,6 +172,8 @@ class Preprocessor:
            del adata.obsm
        if len(adata.obsp.keys()) > 0 and not self.keepdata:
            del adata.obsp
+        if len(adata.varp.keys()) > 0 and not self.keepdata:
+            del adata.varp
        # check that it is a count
 
        print("checking raw counts")
@@ -188,7 +192,7 @@ class Preprocessor:
        # if not available count drop
        prevsize = adata.shape[0]
        # dropping non primary
-        if "is_primary_data" in adata.obs.columns:
+        if "is_primary_data" in adata.obs.columns and self.drop_non_primary:
            adata = adata[adata.obs.is_primary_data]
            if adata.shape[0] < self.min_dataset_size:
                raise Exception("Dataset dropped due to too many secondary cells")
@@ -213,13 +217,10 @@ class Preprocessor:
            min_genes=self.min_nnz_genes,
        )
        # if lost > 50% of the dataset, drop dataset
-        # load the genes
-        genesdf = data_utils.load_genes(adata.obs.organism_ontology_term_id.iloc[0])
-
-        if prevsize / adata.shape[0] > self.maxdropamount:
+        if prevsize / (adata.shape[0] + 1) > self.maxdropamount:
            raise Exception(
                "Dataset dropped due to low expressed genes and unexpressed cells: factor of "
-                + str(prevsize / adata.shape[0])
+                + str(prevsize / (adata.shape[0] + 1))
            )
        if adata.shape[0] < self.min_dataset_size:
            raise Exception(
@@ -232,60 +233,39 @@ class Preprocessor:
                )
            )
 
-        # Check if we have a mix of gene names and ensembl IDs
-        has_ens = adata.var.index.str.match(r"ENS.*\d{6,}$").any()
-        all_ens = adata.var.index.str.match(r"ENS.*\d{6,}$").all()
-
-        if not has_ens:
-            print("No ENS genes found, assuming gene symbols...")
-        elif not all_ens:
-            print("Mix of ENS and gene symbols found, converting all to ENS IDs...")
-
+        # load the genes
+        genesdf = data_utils.load_genes(adata.obs.organism_ontology_term_id.iloc[0])
        genesdf["ensembl_gene_id"] = genesdf.index
 
        # For genes that are already ENS IDs, use them directly
-        ens_mask = adata.var.index.str.match(r"ENS.*\d{6,}$")
-        symbol_mask = ~ens_mask
-
+        prev_size = adata.shape[1]
        # Handle symbol genes
-        if symbol_mask.any():
-            symbol_var = adata.var[symbol_mask].merge(
+        if self.is_symbol:
+            new_var = adata.var.merge(
                genesdf.drop_duplicates("symbol").set_index("symbol", drop=False),
                left_index=True,
                right_index=True,
                how="inner",
            )
-
-        # Handle ENS genes
-        if ens_mask.any():
-            ens_var = adata.var[ens_mask].merge(
+            new_var["symbol"] = new_var.index
+            adata = adata[:, new_var.index]
+            new_var.index = new_var["ensembl_gene_id"]
+        else:
+            new_var = adata.var.merge(
                genesdf, left_index=True, right_index=True, how="inner"
            )
+            adata = adata[:, new_var.index]
+        print(f"Removed {prev_size - adata.shape[1]} genes not known to the ontology")
+        prev_size = adata.shape[1]
 
-        # Combine and sort
-        if symbol_mask.any() and ens_mask.any():
-            var = pd.concat([symbol_var, ens_var])
-        elif symbol_mask.any():
-            var = symbol_var
-        else:
-            var = ens_var
-
-        adata = adata[:, var.index]
-        # var = var.sort_values(by="ensembl_gene_id").set_index("ensembl_gene_id")
-        # Update adata with combined genes
-        if "ensembl_gene_id" in var.columns:
-            adata.var = var.set_index("ensembl_gene_id")
-        else:
-            adata.var = var
+        adata.var = new_var
        # Drop duplicate genes, keeping first occurrence
        adata = adata[:, ~adata.var.index.duplicated(keep="first")]
+        print(f"Removed {prev_size - adata.shape[1]} duplicate genes")
 
-        intersect_genes = set(adata.var.index).intersection(set(genesdf.index))
-        print(f"Removed {len(adata.var.index) - len(intersect_genes)} genes.")
-        if len(intersect_genes) < self.min_valid_genes_id:
+        if adata.shape[1] < self.min_valid_genes_id:
            raise Exception("Dataset dropped due to too many genes not mapping to it")
-        adata = adata[:, list(intersect_genes)]
-        # marking unseen genes
+
        unseen = set(genesdf.index) - set(adata.var.index)
        # adding them to adata
        emptyda = ad.AnnData(
@@ -293,6 +273,9 @@ class Preprocessor:
            var=pd.DataFrame(index=list(unseen)),
            obs=pd.DataFrame(index=adata.obs.index),
        )
+        print(
+            f"Added {len(unseen)} genes in the ontology but not present in the dataset"
+        )
        adata = ad.concat([adata, emptyda], axis=1, join="outer", merge="only")
        # do a validation function
        adata.uns["unseen_genes"] = list(unseen)
@@ -330,7 +313,7 @@ class Preprocessor:
        # QC
 
        adata.var[genesdf.columns] = genesdf.loc[adata.var.index]
-        print("startin QC")
+        print("starting QC")
        sc.pp.calculate_qc_metrics(
            adata, qc_vars=["mt", "ribo", "hb"], inplace=True, percent_top=[20]
        )
@@ -348,7 +331,7 @@ class Preprocessor:
        )
        total_outliers = (adata.obs["outlier"] | adata.obs["mt_outlier"]).sum()
        total_cells = adata.shape[0]
-        percentage_outliers = (total_outliers / total_cells) * 100
+        percentage_outliers = (total_outliers / (total_cells + 1)) * 100
        print(
            f"Seeing {total_outliers} outliers ({percentage_outliers:.2f}% of total dataset):"
        )
@@ -395,7 +378,7 @@ class Preprocessor:
            subset=False,
            layer="norm",
        )
-
+        print("starting PCA")
        adata.obsm["X_pca"] = sc.pp.pca(
            adata.layers["norm"][:, adata.var.highly_variable]
            if "highly_variable" in adata.var.columns
@@ -464,13 +447,13 @@ class LaminPreprocessor(Preprocessor):
        *args,
        cache: bool = True,
        keep_files: bool = True,
-        force_preloaded: bool = False,
+        force_lamin_cache: bool = False,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.cache = cache
        self.keep_files = keep_files
-        self.force_preloaded = force_preloaded
+        self.force_lamin_cache = force_lamin_cache
 
    def __call__(
        self,
@@ -505,10 +488,13 @@ class LaminPreprocessor(Preprocessor):
                print(f"{file.stem_uid} is already processed... not preprocessing")
                continue
            print(file)
+            if self.force_lamin_cache:
+                path = cache_path(file)
+                backed = read_h5ad(path, backed="r")
+            else:
+                # file.cache()
+                backed = file.open()
 
-            _ = cache_path(file) if self.force_preloaded else file.cache()
-            backed = file.open()
-            # backed = read_h5ad(path, backed="r")
            if "is_primary_data" in backed.obs.columns:
                if backed.obs.is_primary_data.sum() == 0:
                    print(f"{file.key} only contains non primary cells.. dropping")
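
A hypothetical instantiation with the renamed flag; only `cache`, `keep_files`, and `force_lamin_cache` are taken from this diff, and the import path is assumed:

    from scdataloader.preprocess import LaminPreprocessor  # assumed import path

    # force_lamin_cache (formerly force_preloaded) reads each artifact from its
    # local lamin cache with read_h5ad(path, backed="r") instead of file.open().
    preprocessor = LaminPreprocessor(cache=True, keep_files=True, force_lamin_cache=True)
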
@@ -556,37 +542,52 @@ class LaminPreprocessor(Preprocessor):
                            block,
                            dataset_id=file.stem_uid + "_p" + str(j),
                        )
-                        myfile = ln.Artifact.from_anndata(
-                            block,
-                            description=description
-                            + " n"
-                            + str(i)
-                            + " p"
-                            + str(j)
-                            + " ( revises file "
-                            + str(file.stem_uid)
-                            + " )",
-                            version=version,
-                        )
-                        myfile.save()
-
+                        saved = False
+                        while not saved:
+                            try:
+                                myfile = ln.Artifact.from_anndata(
+                                    block,
+                                    description=description
+                                    + " n"
+                                    + str(i)
+                                    + " p"
+                                    + str(j)
+                                    + " ( revises file "
+                                    + str(file.stem_uid)
+                                    + " )",
+                                    version=version,
+                                )
+                                myfile.save()
+                                saved = True
+                            except OperationalError:
+                                print(
+                                    "Database locked, waiting 30 seconds and retrying..."
+                                )
+                                time.sleep(10)
                        if self.keep_files:
                            files.append(myfile)
                            del block
                        else:
                            del myfile
                            del block
-                        gc.collect()
-
                else:
                    adata = super().__call__(adata, dataset_id=file.stem_uid)
-                    myfile = ln.Artifact.from_anndata(
-                        adata,
-                        revises=file,
-                        description=description + " p" + str(i),
-                        version=version,
-                    )
-                    myfile.save()
+                    saved = False
+                    while not saved:
+                        try:
+                            myfile = ln.Artifact.from_anndata(
+                                adata,
+                                revises=file,
+                                description=description + " p" + str(i),
+                                version=version,
+                            )
+                            myfile.save()
+                            saved = True
+                        except OperationalError:
+                            print(
+                                "Database locked, waiting 10 seconds and retrying..."
+                            )
+                            time.sleep(10)
                    if self.keep_files:
                        files.append(myfile)
                        del adata
@@ -606,7 +607,7 @@ class LaminPreprocessor(Preprocessor):
                    continue
                else:
                    raise e
-
+            gc.collect()
            # issues with KLlggfw6I6lvmbqiZm46
        if self.keep_files:
            # Reconstruct collection using keys
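
The save-with-retry loops above guard against "database is locked" errors when several workers write to the lamin instance at once. A generic sketch of the same pattern — helper name and retry budget are illustrative, not from the diff, which inlines an unbounded `while not saved:` loop instead:

    import time

    from django.db.utils import OperationalError


    def save_with_retry(artifact, wait_seconds: int = 10, max_tries: int = 30):
        """Retry artifact.save() while the backing database reports it is locked."""
        for attempt in range(max_tries):
            try:
                artifact.save()
                return artifact
            except OperationalError:
                print(f"Database locked, retrying in {wait_seconds}s ({attempt + 1}/{max_tries})")
                time.sleep(wait_seconds)
        raise RuntimeError("database still locked after retries")
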
@@ -716,7 +717,7 @@ def additional_preprocess(adata):
            }
        }
    )  # multi ethnic will have to get renamed
-    adata.obs["cell_culture"] = False
+    adata.obs["cell_culture"] = "False"
    # if cell_type contains the word "(cell culture)" then it is a cell culture and we mark it as so and remove this from the cell type
    loc = adata.obs["cell_type_ontology_term_id"].str.contains(
        "(cell culture)", regex=False
@@ -725,7 +726,7 @@ def additional_preprocess(adata):
        adata.obs["cell_type_ontology_term_id"] = adata.obs[
            "cell_type_ontology_term_id"
        ].astype(str)
-        adata.obs.loc[loc, "cell_culture"] = True
+        adata.obs.loc[loc, "cell_culture"] = "True"
        adata.obs.loc[loc, "cell_type_ontology_term_id"] = adata.obs.loc[
            loc, "cell_type_ontology_term_id"
        ].str.replace(" (cell culture)", "")
@@ -734,7 +735,7 @@ def additional_preprocess(adata):
        "(cell culture)", regex=False
    )
    if loc.sum() > 0:
-        adata.obs.loc[loc, "cell_culture"] = True
+        adata.obs.loc[loc, "cell_culture"] = "True"
        adata.obs["tissue_ontology_term_id"] = adata.obs[
            "tissue_ontology_term_id"
        ].astype(str)
@@ -744,7 +745,7 @@ def additional_preprocess(adata):
 
    loc = adata.obs["tissue_ontology_term_id"].str.contains("(organoid)", regex=False)
    if loc.sum() > 0:
-        adata.obs.loc[loc, "cell_culture"] = True
+        adata.obs.loc[loc, "cell_culture"] = "True"
        adata.obs["tissue_ontology_term_id"] = adata.obs[
            "tissue_ontology_term_id"
        ].astype(str)
@@ -773,6 +774,7 @@ def additional_postprocess(adata):
    # sc.external.pp.harmony_integrate(adata, key="batches")
    # sc.pp.neighbors(adata, use_rep="X_pca_harmony")
    # else:
+    print("starting post processing")
    sc.pp.neighbors(adata, use_rep="X_pca")
    sc.tl.leiden(adata, key_added="leiden_2", resolution=2.0)
    sc.tl.leiden(adata, key_added="leiden_1", resolution=1.0)
@@ -791,8 +793,12 @@ def additional_postprocess(adata):
    MAXSIM = 0.94
    from collections import Counter
 
+    import bionty as bt
+
    from .config import MAIN_HUMAN_MOUSE_DEV_STAGE_MAP
 
+    remap_stages = {u: k for k, v in MAIN_HUMAN_MOUSE_DEV_STAGE_MAP.items() for u in v}
+
    adata.obs[NEWOBS] = (
        adata.obs[COL].astype(str) + "_" + adata.obs["leiden_1"].astype(str)
    )
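
`remap_stages` inverts the stage map so every fine-grained stage points back to its coarse age group. A toy illustration of the comprehension (the ontology IDs below are made up):

    # MAIN_HUMAN_MOUSE_DEV_STAGE_MAP maps a coarse group to the stages it covers;
    # remap_stages flips that so each individual stage looks up its group.
    MAIN_HUMAN_MOUSE_DEV_STAGE_MAP = {
        "HsapDv:0000080": ["MmusDv:0000050", "MmusDv:0000051"],  # made-up IDs
        "HsapDv:0000090": ["MmusDv:0000060"],
    }
    remap_stages = {u: k for k, v in MAIN_HUMAN_MOUSE_DEV_STAGE_MAP.items() for u in v}
    # {'MmusDv:0000050': 'HsapDv:0000080', 'MmusDv:0000051': 'HsapDv:0000080',
    #  'MmusDv:0000060': 'HsapDv:0000090'}
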
@@ -860,18 +866,17 @@ def additional_postprocess(adata):
            num += 1
    adata.obs[NEWOBS] = adata.obs[NEWOBS].map(merge_mapping).fillna(adata.obs[NEWOBS])
 
-    import bionty as bt
-
    stages = adata.obs["development_stage_ontology_term_id"].unique()
    if adata.obs.organism_ontology_term_id.unique() == ["NCBITaxon:9606"]:
        relabel = {i: i for i in stages}
        for stage in stages:
+            if stage in MAIN_HUMAN_MOUSE_DEV_STAGE_MAP.keys():
+                continue
            stage_obj = bt.DevelopmentalStage.filter(ontology_id=stage).first()
            parents = set([i.ontology_id for i in stage_obj.parents.filter()])
            parents = parents - set(
                [
                    "HsapDv:0010000",
-                    "HsapDv:0000204",
                    "HsapDv:0000227",
                ]
            )
@@ -879,9 +884,14 @@ def additional_postprocess(adata):
            for p in parents:
                if p in MAIN_HUMAN_MOUSE_DEV_STAGE_MAP:
                    relabel[stage] = p
-        adata.obs["simplified_dev_stage"] = adata.obs[
-            "development_stage_ontology_term_id"
-        ].map(relabel)
+        adata.obs["age_group"] = adata.obs["development_stage_ontology_term_id"].map(
+            relabel
+        )
+        for stage in adata.obs["age_group"].unique():
+            if stage in remap_stages.keys():
+                adata.obs["age_group"] = adata.obs["age_group"].map(
+                    lambda x: remap_stages[x] if x == stage else x
+                )
    elif adata.obs.organism_ontology_term_id.unique() == ["NCBITaxon:10090"]:
        rename_mapping = {
            k: v for v, j in MAIN_HUMAN_MOUSE_DEV_STAGE_MAP.items() for k in j
@@ -890,11 +900,12 @@ def additional_postprocess(adata):
        for stage in stages:
            if stage in rename_mapping:
                relabel[stage] = rename_mapping[stage]
-        adata.obs["simplified_dev_stage"] = adata.obs[
-            "development_stage_ontology_term_id"
-        ].map(relabel)
+        adata.obs["age_group"] = adata.obs["development_stage_ontology_term_id"].map(
+            relabel
+        )
    else:
-        raise ValueError("organism not supported")
+        # raise ValueError("organism not supported")
+        print("organism not supported for age labels")
    # palantir.utils.run_diffusion_maps(adata, n_components=20)
    # palantir.utils.determine_multiscale_space(adata)
    # terminal_states = palantir.utils.find_terminal_states(