scdataloader 1.9.1__py3-none-any.whl → 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
scdataloader/mapped.py CHANGED
@@ -1,11 +1,20 @@
+# Portions of this file are derived from Lamin Labs
+# Copyright 2024 Lamin Labs
+# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+# The rest of this file is licensed under MIT
+# Please see https://github.com/laminlabs/lamindb/blob/main/lamindb/core/_mapped_collection.py
+# for the original implementation
+
 from __future__ import annotations
 
+import os
 from collections import Counter
 from functools import reduce
 from typing import TYPE_CHECKING, Literal
 
 import numpy as np
 import pandas as pd
+import torch
 from lamindb.core.storage._anndata_accessor import (
     ArrayType,
     ArrayTypes,
@@ -17,10 +26,13 @@ from lamindb.core.storage._anndata_accessor import (
     registry,
 )
 from lamindb_setup.core.upath import UPath
+from tqdm import tqdm
 
 if TYPE_CHECKING:
     from lamindb_setup.core.types import UPathStr
 
+from pandas.api.types import union_categoricals
+
 
 class _Connect:
     def __init__(self, storage):
@@ -99,6 +111,8 @@ class MappedCollection:
         meta_assays: Assays that are already defined as metacells.
         metacell_mode: frequency at which to sample a metacell (an average of k-nearest neighbors).
         get_knn_cells: Whether to also dataload the k-nearest neighbors of each queried cells.
+        store_location: Path to a directory where klass_indices can be cached, or full path to the cache file.
+        force_recompute_indices: If True, recompute indices even if a cache file exists.
     """
 
     def __init__(
@@ -117,6 +131,8 @@ class MappedCollection:
         metacell_mode: float = 0.0,
         get_knn_cells: bool = False,
         meta_assays: list[str] = ["EFO:0022857", "EFO:0010961"],
+        store_location: str | None = None,
+        force_recompute_indices: bool = False,
     ):
         if join not in {None, "inner", "outer"}: # pragma: nocover
             raise ValueError(
@@ -174,14 +190,26 @@ class MappedCollection:
         self._cache_cats: dict = {}
         if self.obs_keys is not None:
             if cache_categories:
-                self._cache_categories(self.obs_keys)
+                if store_location is not None:
+                    os.makedirs(store_location, exist_ok=True)
+                    self.store_location = os.path.join(store_location, "categories")
+                    if (
+                        not os.path.exists(self.store_location)
+                        or force_recompute_indices
+                    ):
+                        self._cache_categories(self.obs_keys)
+                        torch.save(self._cache_cats, self.store_location)
+                    else:
+                        self._cache_cats = torch.load(self.store_location)
+                        print(f"Loaded categories from {self.store_location}")
         self.encoders: dict = {}
         if self.encode_labels:
             self._make_encoders(self.encode_labels) # type: ignore
-
         self.n_obs_list = []
         self.indices_list = []
-        for i, storage in enumerate(self.storages):
+        for i, storage in tqdm(
+            enumerate(self.storages), desc="Checking datasets", total=len(self.storages)
+        ):
             with _Connect(storage) as store:
                 X = store["X"]
                 store_path = self.path_list[i]
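Note on the new category cache (a sketch, not part of the diff): the added store_location / force_recompute_indices arguments follow a compute-once, torch.save / torch.load pattern. A minimal standalone illustration, with a hypothetical compute_fn standing in for MappedCollection._cache_categories:

    import os
    import torch

    def load_or_compute_categories(store_location, compute_fn, force_recompute=False):
        # Mirror the diff: cache the per-label category dict at <store_location>/categories.
        os.makedirs(store_location, exist_ok=True)
        cache_file = os.path.join(store_location, "categories")
        if force_recompute or not os.path.exists(cache_file):
            cats = compute_fn()  # e.g. whatever _cache_categories builds
            torch.save(cats, cache_file)
            return cats
        print(f"Loaded categories from {cache_file}")
        return torch.load(cache_file)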
@@ -256,13 +284,10 @@ class MappedCollection:
         self._cache_cats = {}
         for label in obs_keys:
             self._cache_cats[label] = []
-            for storage in self.storages:
+            for storage in tqdm(self.storages, f"caching categories, {label}"):
                 with _Connect(storage) as store:
                     cats = self._get_categories(store, label)
-                    if cats is not None:
-                        cats = (
-                            _decode(cats) if isinstance(cats[0], bytes) else cats[...]
-                        )
+                    cats = _decode(cats) if isinstance(cats[0], bytes) else cats[...]
                     self._cache_cats[label].append(cats)
 
     def _make_encoders(self, encode_labels: list):
@@ -396,8 +421,21 @@ class MappedCollection:
                     cats = None
                 label_idx = self._get_obs_idx(store, obs_idx, label, cats)
                 if label in self.encoders:
-                    label_idx = self.encoders[label][label_idx]
-                out[label] = label_idx
+                    try:
+                        label_idx = self.encoders[label][label_idx]
+                    except:
+                        print(self.storages[storage_idx])
+                        print(label, label_idx)
+                        print(idx)
+                        print(cats)
+                        raise
+                try:
+                    out[label] = label_idx
+                except:
+                    print(self.storages[storage_idx])
+                    print(label, label_idx)
+                    print(out)
+                    raise
 
         if self.metacell_mode > 0:
             if (
@@ -548,21 +586,41 @@ class MappedCollection:
         weights = (MAX / scaler) / ((1 + counts - MIN) + MAX / scaler)
         return weights
 
-    def get_merged_labels(self, label_key: str):
+    def get_merged_labels(self, label_key: str, is_cat: bool = True):
         """Get merged labels for `label_key` from all `.obs`."""
         labels_merge = []
-        for i, storage in enumerate(self.storages):
+        for i, storage in tqdm(
+            enumerate(self.storages), label_key, total=len(self.storages)
+        ):
             with _Connect(storage) as store:
-                labels = self._get_labels(store, label_key, storage_idx=i)
+                labels = self._get_labels(
+                    store, label_key, storage_idx=i, is_cat=is_cat
+                )
                 if self.filtered:
                     labels = labels[self.indices_list[i]]
                 labels_merge.append(labels)
-        return np.hstack(labels_merge)
+        if is_cat:
+            try:
+                return union_categoricals(labels_merge)
+            except TypeError:
+                typ = type(int)
+                for i in range(len(labels_merge)):
+                    if typ != type(labels_merge[i][0]):
+                        self.storages[i]
+                        typ = type(labels_merge[i][0])
+                return []
+        else:
+            print("concatenating labels")
+            return np.concatenate(labels_merge)
 
     def get_merged_categories(self, label_key: str):
         """Get merged categories for `label_key` from all `.obs`."""
         cats_merge = set()
-        for i, storage in enumerate(self.storages):
+        for i, storage in tqdm(
+            enumerate(self.storages),
+            total=len(self.storages),
+            desc="merging all " + label_key + " categories",
+        ):
             with _Connect(storage) as store:
                 if label_key in self._cache_cats:
                     cats = self._cache_cats[label_key][i]
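Why union_categoricals here (a note on the new is_cat path, with toy labels not taken from the package): concatenating pandas Categoricals through numpy drops the categorical dtype whenever the stores expose different category sets, whereas union_categoricals keeps a single categorical with the union of categories:

    import pandas as pd
    from pandas.api.types import union_categoricals

    a = pd.Categorical(["T cell", "B cell"])
    b = pd.Categorical(["B cell", "NK cell"])  # different category set

    merged = union_categoricals([a, b])  # stays categorical, categories are the union
    print(list(merged))                  # ['T cell', 'B cell', 'B cell', 'NK cell']
    print(list(merged.categories))       # ['T cell', 'B cell', 'NK cell']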
@@ -602,8 +660,8 @@ class MappedCollection:
         else:
             if "categories" in labels.attrs:
                 return labels.attrs["categories"]
-            else:
-                return None
+            elif labels.dtype == "bool":
+                return np.array(["True", "False"])
         return None
 
     def _get_codes(self, storage: StorageType, label_key: str):
@@ -619,11 +677,17 @@ class MappedCollection:
             return label["codes"][...]
 
     def _get_labels(
-        self, storage: StorageType, label_key: str, storage_idx: int | None = None
+        self,
+        storage: StorageType,
+        label_key: str,
+        storage_idx: int | None = None,
+        is_cat: bool = True,
     ):
         """Get labels."""
         codes = self._get_codes(storage, label_key)
         labels = _decode(codes) if isinstance(codes[0], bytes) else codes
+        if labels.dtype == bool:
+            labels = labels.astype(int)
         if storage_idx is not None and label_key in self._cache_cats:
             cats = self._cache_cats[label_key][storage_idx]
         else:
@@ -631,6 +695,8 @@ class MappedCollection:
         if cats is not None:
             cats = _decode(cats) if isinstance(cats[0], bytes) else cats
             labels = cats[labels]
+        if is_cat:
+            labels = pd.Categorical(labels.astype(str))
         return labels
 
     def close(self):
@@ -1,3 +1,5 @@
+import gc
+import time
 from typing import Callable, Optional, Union
 from uuid import uuid4
 
@@ -7,9 +9,10 @@ import numpy as np
 import pandas as pd
 import scanpy as sc
 from anndata import AnnData, read_h5ad
+from django.db.utils import OperationalError
 from scipy.sparse import csr_matrix
 from upath import UPath
-import gc
+
 from scdataloader import utils as data_utils
 
 FULL_LENGTH_ASSAYS = [
@@ -60,6 +63,7 @@ class Preprocessor:
         organisms: list[str] = ["NCBITaxon:9606", "NCBITaxon:10090"],
         use_raw: bool = True,
         keepdata: bool = False,
+        drop_non_primary: bool = False,
     ) -> None:
         """
        Initializes the preprocessor and configures the workflow steps.
@@ -107,6 +111,8 @@ class Preprocessor:
                 Defaults to False.
             keepdata (bool, optional): Determines whether to keep the data in the AnnData object.
                 Defaults to False.
+            drop_non_primary (bool, optional): Determines whether to drop non-primary cells.
+                Defaults to False.
         """
         self.filter_gene_by_counts = filter_gene_by_counts
         self.filter_cell_by_counts = filter_cell_by_counts
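Usage sketch for the new flag (the import path and the other keyword defaults are assumptions; only drop_non_primary comes from this diff):

    from scdataloader.preprocess import Preprocessor  # module path assumed

    preprocessor = Preprocessor(
        organisms=["NCBITaxon:9606", "NCBITaxon:10090"],
        use_raw=True,
        keepdata=False,
        drop_non_primary=True,  # new in 2.0.0: drop cells where is_primary_data is False
    )
    # adata = preprocessor(adata)  # __call__ then runs the count checks, gene mapping and QC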
@@ -122,6 +128,7 @@ class Preprocessor:
         self.min_valid_genes_id = min_valid_genes_id
         self.min_nnz_genes = min_nnz_genes
         self.maxdropamount = maxdropamount
+        self.drop_non_primary = drop_non_primary
         self.madoutlier = madoutlier
         self.n_hvg_for_postp = n_hvg_for_postp
         self.pct_mt_outlier = pct_mt_outlier
@@ -141,10 +148,6 @@ class Preprocessor:
             raise ValueError(
                 "organism_ontology_term_id not found in adata.obs, you need to add an ontology term id for the organism of your anndata"
             )
-        if not adata[0].var.index.str.contains("ENS").any() and not self.is_symbol:
-            raise ValueError(
-                "gene names in the `var.index` field of your anndata should map to the ensembl_gene nomenclature else set `is_symbol` to True if using hugo symbols"
-            )
         if adata.obs["organism_ontology_term_id"].iloc[0] not in self.organisms:
             raise ValueError(
                 "we cannot work with this organism",
@@ -160,8 +163,8 @@ class Preprocessor:
         if np.abs(adata[:50_000].X.astype(int) - adata[:50_000].X).sum():
             print("X was not raw counts, using 'counts' layer")
             adata.X = adata.layers["counts"].copy()
-        print("Dropping layers: ", adata.layers.keys())
         if not self.keepdata:
+            print("Dropping layers: ", adata.layers.keys())
             del adata.layers
         if len(adata.varm.keys()) > 0 and not self.keepdata:
             del adata.varm
@@ -169,6 +172,8 @@ class Preprocessor:
             del adata.obsm
         if len(adata.obsp.keys()) > 0 and not self.keepdata:
             del adata.obsp
+        if len(adata.varp.keys()) > 0 and not self.keepdata:
+            del adata.varp
         # check that it is a count
 
         print("checking raw counts")
@@ -187,7 +192,7 @@ class Preprocessor:
         # if not available count drop
         prevsize = adata.shape[0]
         # dropping non primary
-        if "is_primary_data" in adata.obs.columns:
+        if "is_primary_data" in adata.obs.columns and self.drop_non_primary:
             adata = adata[adata.obs.is_primary_data]
             if adata.shape[0] < self.min_dataset_size:
                 raise Exception("Dataset dropped due to too many secondary cells")
@@ -212,13 +217,10 @@ class Preprocessor:
             min_genes=self.min_nnz_genes,
         )
         # if lost > 50% of the dataset, drop dataset
-        # load the genes
-        genesdf = data_utils.load_genes(adata.obs.organism_ontology_term_id.iloc[0])
-
-        if prevsize / adata.shape[0] > self.maxdropamount:
+        if prevsize / (adata.shape[0] + 1) > self.maxdropamount:
             raise Exception(
                 "Dataset dropped due to low expressed genes and unexpressed cells: factor of "
-                + str(prevsize / adata.shape[0])
+                + str(prevsize / (adata.shape[0] + 1))
             )
         if adata.shape[0] < self.min_dataset_size:
             raise Exception(
@@ -231,58 +233,39 @@ class Preprocessor:
                 )
             )
 
-        # Check if we have a mix of gene names and ensembl IDs
-        has_ens = adata.var.index.str.match(r"ENS.*\d{6,}$").any()
-        all_ens = adata.var.index.str.match(r"ENS.*\d{6,}$").all()
-
-        if not has_ens:
-            print("No ENS genes found, assuming gene symbols...")
-        elif not all_ens:
-            print("Mix of ENS and gene symbols found, converting all to ENS IDs...")
-
+        # load the genes
+        genesdf = data_utils.load_genes(adata.obs.organism_ontology_term_id.iloc[0])
         genesdf["ensembl_gene_id"] = genesdf.index
 
         # For genes that are already ENS IDs, use them directly
-        ens_mask = adata.var.index.str.match(r"ENS.*\d{6,}$")
-        symbol_mask = ~ens_mask
-
+        prev_size = adata.shape[1]
         # Handle symbol genes
-        if symbol_mask.any():
-            symbol_var = adata.var[symbol_mask].merge(
+        if self.is_symbol:
+            new_var = adata.var.merge(
                 genesdf.drop_duplicates("symbol").set_index("symbol", drop=False),
                 left_index=True,
                 right_index=True,
                 how="inner",
             )
-
-        # Handle ENS genes
-        if ens_mask.any():
-            ens_var = adata.var[ens_mask].merge(
+            new_var["symbol"] = new_var.index
+            adata = adata[:, new_var.index]
+            new_var.index = new_var["ensembl_gene_id"]
+        else:
+            new_var = adata.var.merge(
                 genesdf, left_index=True, right_index=True, how="inner"
             )
+        adata = adata[:, new_var.index]
+        print(f"Removed {prev_size - adata.shape[1]} genes not known to the ontology")
+        prev_size = adata.shape[1]
 
-        # Combine and sort
-        if symbol_mask.any() and ens_mask.any():
-            var = pd.concat([symbol_var, ens_var])
-        elif symbol_mask.any():
-            var = symbol_var
-        else:
-            var = ens_var
-
-        adata = adata[:, var.index]
-        var = var.sort_values(by="ensembl_gene_id").set_index("ensembl_gene_id")
-        # Update adata with combined genes
-        adata.var = var
-        genesdf = genesdf.set_index("ensembl_gene_id")
+        adata.var = new_var
         # Drop duplicate genes, keeping first occurrence
         adata = adata[:, ~adata.var.index.duplicated(keep="first")]
+        print(f"Removed {prev_size - adata.shape[1]} duplicate genes")
 
-        intersect_genes = set(adata.var.index).intersection(set(genesdf.index))
-        print(f"Removed {len(adata.var.index) - len(intersect_genes)} genes.")
-        if len(intersect_genes) < self.min_valid_genes_id:
+        if adata.shape[1] < self.min_valid_genes_id:
             raise Exception("Dataset dropped due to too many genes not mapping to it")
-        adata = adata[:, list(intersect_genes)]
-        # marking unseen genes
+
         unseen = set(genesdf.index) - set(adata.var.index)
         # adding them to adata
         emptyda = ad.AnnData(
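The rewritten gene-mapping block keys the behaviour on self.is_symbol instead of sniffing ENS-prefixed indices. A toy, self-contained illustration of the symbol branch (made-up data; the variable names mirror the diff):

    import pandas as pd

    # stand-in for data_utils.load_genes(): ontology table indexed by ensembl_gene_id
    genesdf = pd.DataFrame(
        {"symbol": ["CD4", "MS4A1"]},
        index=["ENSG00000010610", "ENSG00000156738"],
    )
    genesdf["ensembl_gene_id"] = genesdf.index

    var = pd.DataFrame(index=["CD4", "MS4A1", "NOT_A_GENE"])  # adata.var indexed by symbols
    new_var = var.merge(
        genesdf.drop_duplicates("symbol").set_index("symbol", drop=False),
        left_index=True,
        right_index=True,
        how="inner",
    )
    new_var["symbol"] = new_var.index
    new_var.index = new_var["ensembl_gene_id"]
    print(list(new_var.index))  # unknown symbols are dropped by the inner join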
@@ -290,6 +273,9 @@ class Preprocessor:
             var=pd.DataFrame(index=list(unseen)),
             obs=pd.DataFrame(index=adata.obs.index),
         )
+        print(
+            f"Added {len(unseen)} genes in the ontology but not present in the dataset"
+        )
         adata = ad.concat([adata, emptyda], axis=1, join="outer", merge="only")
         # do a validation function
         adata.uns["unseen_genes"] = list(unseen)
@@ -327,7 +313,7 @@ class Preprocessor:
         # QC
 
         adata.var[genesdf.columns] = genesdf.loc[adata.var.index]
-        print("startin QC")
+        print("starting QC")
         sc.pp.calculate_qc_metrics(
             adata, qc_vars=["mt", "ribo", "hb"], inplace=True, percent_top=[20]
         )
@@ -345,7 +331,7 @@ class Preprocessor:
         )
         total_outliers = (adata.obs["outlier"] | adata.obs["mt_outlier"]).sum()
         total_cells = adata.shape[0]
-        percentage_outliers = (total_outliers / total_cells) * 100
+        percentage_outliers = (total_outliers / (total_cells + 1)) * 100
         print(
             f"Seeing {total_outliers} outliers ({percentage_outliers:.2f}% of total dataset):"
         )
@@ -392,7 +378,7 @@ class Preprocessor:
             subset=False,
             layer="norm",
         )
-
+        print("starting PCA")
         adata.obsm["X_pca"] = sc.pp.pca(
             adata.layers["norm"][:, adata.var.highly_variable]
             if "highly_variable" in adata.var.columns
@@ -461,13 +447,13 @@ class LaminPreprocessor(Preprocessor):
         *args,
         cache: bool = True,
         keep_files: bool = True,
-        force_preloaded: bool = False,
+        force_lamin_cache: bool = False,
         **kwargs,
     ):
         super().__init__(*args, **kwargs)
         self.cache = cache
         self.keep_files = keep_files
-        self.force_preloaded = force_preloaded
+        self.force_lamin_cache = force_lamin_cache
 
     def __call__(
         self,
@@ -502,10 +488,13 @@ class LaminPreprocessor(Preprocessor):
                 print(f"{file.stem_uid} is already processed... not preprocessing")
                 continue
             print(file)
+            if self.force_lamin_cache:
+                path = cache_path(file)
+                backed = read_h5ad(path, backed="r")
+            else:
+                # file.cache()
+                backed = file.open()
 
-            path = cache_path(file) if self.force_preloaded else file.cache()
-            backed = file.open()
-            # backed = read_h5ad(path, backed="r")
             if "is_primary_data" in backed.obs.columns:
                 if backed.obs.is_primary_data.sum() == 0:
                     print(f"{file.key} only contains non primary cells.. dropping")
@@ -553,37 +542,52 @@ class LaminPreprocessor(Preprocessor):
                             block,
                             dataset_id=file.stem_uid + "_p" + str(j),
                         )
-                        myfile = ln.Artifact.from_anndata(
-                            block,
-                            description=description
-                            + " n"
-                            + str(i)
-                            + " p"
-                            + str(j)
-                            + " ( revises file "
-                            + str(file.stem_uid)
-                            + " )",
-                            version=version,
-                        )
-                        myfile.save()
-
+                        saved = False
+                        while not saved:
+                            try:
+                                myfile = ln.Artifact.from_anndata(
+                                    block,
+                                    description=description
+                                    + " n"
+                                    + str(i)
+                                    + " p"
+                                    + str(j)
+                                    + " ( revises file "
+                                    + str(file.stem_uid)
+                                    + " )",
+                                    version=version,
+                                )
+                                myfile.save()
+                                saved = True
+                            except OperationalError:
+                                print(
+                                    "Database locked, waiting 30 seconds and retrying..."
+                                )
+                                time.sleep(10)
                         if self.keep_files:
                             files.append(myfile)
                             del block
                         else:
                             del myfile
                             del block
-                            gc.collect()
-
                 else:
                     adata = super().__call__(adata, dataset_id=file.stem_uid)
-                    myfile = ln.Artifact.from_anndata(
-                        adata,
-                        revises=file,
-                        description=description + " p" + str(i),
-                        version=version,
-                    )
-                    myfile.save()
+                    saved = False
+                    while not saved:
+                        try:
+                            myfile = ln.Artifact.from_anndata(
+                                adata,
+                                revises=file,
+                                description=description + " p" + str(i),
+                                version=version,
+                            )
+                            myfile.save()
+                            saved = True
+                        except OperationalError:
+                            print(
+                                "Database locked, waiting 10 seconds and retrying..."
+                            )
+                            time.sleep(10)
                 if self.keep_files:
                     files.append(myfile)
                     del adata
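The two new save loops implement the same retry-on-lock pattern. Factored out as a standalone sketch (not in the package; unlike the diff's loops it takes an optional retry cap):

    import time

    from django.db.utils import OperationalError

    def save_with_retry(artifact, wait_seconds=10, max_tries=None):
        tries = 0
        while True:
            try:
                artifact.save()
                return artifact
            except OperationalError:  # raised e.g. when a SQLite-backed instance is locked
                tries += 1
                if max_tries is not None and tries >= max_tries:
                    raise
                print(f"Database locked, waiting {wait_seconds} seconds and retrying...")
                time.sleep(wait_seconds)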
@@ -603,7 +607,7 @@ class LaminPreprocessor(Preprocessor):
                     continue
                 else:
                     raise e
-
+            gc.collect()
         # issues with KLlggfw6I6lvmbqiZm46
         if self.keep_files:
             # Reconstruct collection using keys
@@ -713,7 +717,7 @@ def additional_preprocess(adata):
             }
         }
     ) # multi ethnic will have to get renamed
-    adata.obs["cell_culture"] = False
+    adata.obs["cell_culture"] = "False"
     # if cell_type contains the word "(cell culture)" then it is a cell culture and we mark it as so and remove this from the cell type
     loc = adata.obs["cell_type_ontology_term_id"].str.contains(
         "(cell culture)", regex=False
@@ -722,7 +726,7 @@ def additional_preprocess(adata):
         adata.obs["cell_type_ontology_term_id"] = adata.obs[
             "cell_type_ontology_term_id"
         ].astype(str)
-        adata.obs.loc[loc, "cell_culture"] = True
+        adata.obs.loc[loc, "cell_culture"] = "True"
         adata.obs.loc[loc, "cell_type_ontology_term_id"] = adata.obs.loc[
             loc, "cell_type_ontology_term_id"
         ].str.replace(" (cell culture)", "")
@@ -731,7 +735,7 @@ def additional_preprocess(adata):
         "(cell culture)", regex=False
     )
     if loc.sum() > 0:
-        adata.obs.loc[loc, "cell_culture"] = True
+        adata.obs.loc[loc, "cell_culture"] = "True"
         adata.obs["tissue_ontology_term_id"] = adata.obs[
             "tissue_ontology_term_id"
         ].astype(str)
@@ -741,7 +745,7 @@ def additional_preprocess(adata):
 
     loc = adata.obs["tissue_ontology_term_id"].str.contains("(organoid)", regex=False)
     if loc.sum() > 0:
-        adata.obs.loc[loc, "cell_culture"] = True
+        adata.obs.loc[loc, "cell_culture"] = "True"
         adata.obs["tissue_ontology_term_id"] = adata.obs[
             "tissue_ontology_term_id"
         ].astype(str)
@@ -770,6 +774,7 @@ def additional_postprocess(adata):
     # sc.external.pp.harmony_integrate(adata, key="batches")
     # sc.pp.neighbors(adata, use_rep="X_pca_harmony")
     # else:
+    print("starting post processing")
     sc.pp.neighbors(adata, use_rep="X_pca")
     sc.tl.leiden(adata, key_added="leiden_2", resolution=2.0)
     sc.tl.leiden(adata, key_added="leiden_1", resolution=1.0)
@@ -788,8 +793,12 @@ def additional_postprocess(adata):
     MAXSIM = 0.94
     from collections import Counter
 
+    import bionty as bt
+
     from .config import MAIN_HUMAN_MOUSE_DEV_STAGE_MAP
 
+    remap_stages = {u: k for k, v in MAIN_HUMAN_MOUSE_DEV_STAGE_MAP.items() for u in v}
+
     adata.obs[NEWOBS] = (
         adata.obs[COL].astype(str) + "_" + adata.obs["leiden_1"].astype(str)
     )
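remap_stages inverts the {coarse group: [fine stage ids]} mapping so that each fine-grained development stage can be sent back to its coarse group when the new age_group column is built below. A tiny illustration with placeholder ids (not from the package):

    MAIN_HUMAN_MOUSE_DEV_STAGE_MAP = {
        "group_A": ["stage_1", "stage_2"],  # placeholder keys/values
        "group_B": ["stage_3"],
    }
    remap_stages = {u: k for k, v in MAIN_HUMAN_MOUSE_DEV_STAGE_MAP.items() for u in v}
    print(remap_stages)  # {'stage_1': 'group_A', 'stage_2': 'group_A', 'stage_3': 'group_B'}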
@@ -857,18 +866,17 @@ def additional_postprocess(adata):
             num += 1
     adata.obs[NEWOBS] = adata.obs[NEWOBS].map(merge_mapping).fillna(adata.obs[NEWOBS])
 
-    import bionty as bt
-
     stages = adata.obs["development_stage_ontology_term_id"].unique()
     if adata.obs.organism_ontology_term_id.unique() == ["NCBITaxon:9606"]:
         relabel = {i: i for i in stages}
         for stage in stages:
+            if stage in MAIN_HUMAN_MOUSE_DEV_STAGE_MAP.keys():
+                continue
             stage_obj = bt.DevelopmentalStage.filter(ontology_id=stage).first()
             parents = set([i.ontology_id for i in stage_obj.parents.filter()])
             parents = parents - set(
                 [
                     "HsapDv:0010000",
-                    "HsapDv:0000204",
                     "HsapDv:0000227",
                 ]
             )
@@ -876,9 +884,14 @@ def additional_postprocess(adata):
             for p in parents:
                 if p in MAIN_HUMAN_MOUSE_DEV_STAGE_MAP:
                     relabel[stage] = p
-        adata.obs["simplified_dev_stage"] = adata.obs[
-            "development_stage_ontology_term_id"
-        ].map(relabel)
+        adata.obs["age_group"] = adata.obs["development_stage_ontology_term_id"].map(
+            relabel
+        )
+        for stage in adata.obs["age_group"].unique():
+            if stage in remap_stages.keys():
+                adata.obs["age_group"] = adata.obs["age_group"].map(
+                    lambda x: remap_stages[x] if x == stage else x
+                )
     elif adata.obs.organism_ontology_term_id.unique() == ["NCBITaxon:10090"]:
         rename_mapping = {
             k: v for v, j in MAIN_HUMAN_MOUSE_DEV_STAGE_MAP.items() for k in j
@@ -887,11 +900,12 @@ def additional_postprocess(adata):
         for stage in stages:
             if stage in rename_mapping:
                 relabel[stage] = rename_mapping[stage]
-        adata.obs["simplified_dev_stage"] = adata.obs[
-            "development_stage_ontology_term_id"
-        ].map(relabel)
+        adata.obs["age_group"] = adata.obs["development_stage_ontology_term_id"].map(
+            relabel
+        )
     else:
-        raise ValueError("organism not supported")
+        # raise ValueError("organism not supported")
+        print("organism not supported for age labels")
     # palantir.utils.run_diffusion_maps(adata, n_components=20)
     # palantir.utils.determine_multiscale_space(adata)
    # terminal_states = palantir.utils.find_terminal_states(