scdataloader 1.9.2__py3-none-any.whl → 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- scdataloader/collator.py +13 -24
- scdataloader/config.py +25 -9
- scdataloader/data.json +384 -0
- scdataloader/data.py +116 -43
- scdataloader/datamodule.py +551 -199
- scdataloader/mapped.py +77 -18
- scdataloader/preprocess.py +106 -95
- scdataloader/utils.py +39 -33
- {scdataloader-1.9.2.dist-info → scdataloader-2.0.0.dist-info}/METADATA +3 -4
- scdataloader-2.0.0.dist-info/RECORD +16 -0
- scdataloader-2.0.0.dist-info/licenses/LICENSE +21 -0
- scdataloader/VERSION +0 -1
- scdataloader-1.9.2.dist-info/RECORD +0 -16
- scdataloader-1.9.2.dist-info/licenses/LICENSE +0 -674
- {scdataloader-1.9.2.dist-info → scdataloader-2.0.0.dist-info}/WHEEL +0 -0
- {scdataloader-1.9.2.dist-info → scdataloader-2.0.0.dist-info}/entry_points.txt +0 -0
scdataloader/mapped.py
CHANGED

@@ -7,12 +7,14 @@

 from __future__ import annotations

+import os
 from collections import Counter
 from functools import reduce
 from typing import TYPE_CHECKING, Literal

 import numpy as np
 import pandas as pd
+import torch
 from lamindb.core.storage._anndata_accessor import (
     ArrayType,
     ArrayTypes,
@@ -24,10 +26,13 @@ from lamindb.core.storage._anndata_accessor import (
     registry,
 )
 from lamindb_setup.core.upath import UPath
+from tqdm import tqdm

 if TYPE_CHECKING:
     from lamindb_setup.core.types import UPathStr

+from pandas.api.types import union_categoricals
+

 class _Connect:
     def __init__(self, storage):
@@ -106,6 +111,8 @@ class MappedCollection:
         meta_assays: Assays that are already defined as metacells.
         metacell_mode: frequency at which to sample a metacell (an average of k-nearest neighbors).
         get_knn_cells: Whether to also dataload the k-nearest neighbors of each queried cells.
+        store_location: Path to a directory where klass_indices can be cached, or full path to the cache file.
+        force_recompute_indices: If True, recompute indices even if a cache file exists.
     """

     def __init__(
@@ -124,6 +131,8 @@ class MappedCollection:
         metacell_mode: float = 0.0,
         get_knn_cells: bool = False,
         meta_assays: list[str] = ["EFO:0022857", "EFO:0010961"],
+        store_location: str | None = None,
+        force_recompute_indices: bool = False,
     ):
         if join not in {None, "inner", "outer"}:  # pragma: nocover
             raise ValueError(
@@ -181,14 +190,26 @@ class MappedCollection:
         self._cache_cats: dict = {}
         if self.obs_keys is not None:
             if cache_categories:
-                self._cache_categories(self.obs_keys)
+                if store_location is not None:
+                    os.makedirs(store_location, exist_ok=True)
+                    self.store_location = os.path.join(store_location, "categories")
+                    if (
+                        not os.path.exists(self.store_location)
+                        or force_recompute_indices
+                    ):
+                        self._cache_categories(self.obs_keys)
+                        torch.save(self._cache_cats, self.store_location)
+                    else:
+                        self._cache_cats = torch.load(self.store_location)
+                        print(f"Loaded categories from {self.store_location}")
             self.encoders: dict = {}
             if self.encode_labels:
                 self._make_encoders(self.encode_labels)  # type: ignore
-
         self.n_obs_list = []
         self.indices_list = []
-        for i, storage in enumerate(self.storages):
+        for i, storage in tqdm(
+            enumerate(self.storages), desc="Checking datasets", total=len(self.storages)
+        ):
             with _Connect(storage) as store:
                 X = store["X"]
                 store_path = self.path_list[i]
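The new caching knobs can cut startup time considerably on large collections: categories are computed once, written with `torch.save`, and reloaded on subsequent runs. A minimal usage sketch, assuming hypothetical store paths and obs columns (`path_list`, `obs_keys`, and `cache_categories` are existing constructor arguments referenced elsewhere in this diff):

```python
from scdataloader.mapped import MappedCollection

# hypothetical .h5ad stores and label columns; cache_categories must be on
mc = MappedCollection(
    path_list=["store0.h5ad", "store1.h5ad"],
    obs_keys=["cell_type_ontology_term_id"],
    cache_categories=True,
    # categories are persisted to <store_location>/categories via torch.save
    store_location="/tmp/scdataloader_cache",
    # set True to ignore an existing cache file and rebuild it
    force_recompute_indices=False,
)
```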
@@ -263,13 +284,10 @@ class MappedCollection:
         self._cache_cats = {}
         for label in obs_keys:
             self._cache_cats[label] = []
-            for storage in self.storages:
+            for storage in tqdm(self.storages, f"caching categories, {label}"):
                 with _Connect(storage) as store:
                     cats = self._get_categories(store, label)
-                    if cats is not None:
-                        cats = (
-                            _decode(cats) if isinstance(cats[0], bytes) else cats[...]
-                        )
+                    cats = _decode(cats) if isinstance(cats[0], bytes) else cats[...]
                     self._cache_cats[label].append(cats)

     def _make_encoders(self, encode_labels: list):
@@ -403,8 +421,21 @@ class MappedCollection:
                 cats = None
             label_idx = self._get_obs_idx(store, obs_idx, label, cats)
             if label in self.encoders:
-                label_idx = self.encoders[label][label_idx]
-            out[label] = label_idx
+                try:
+                    label_idx = self.encoders[label][label_idx]
+                except:
+                    print(self.storages[storage_idx])
+                    print(label, label_idx)
+                    print(idx)
+                    print(cats)
+                    raise
+            try:
+                out[label] = label_idx
+            except:
+                print(self.storages[storage_idx])
+                print(label, label_idx)
+                print(out)
+                raise

         if self.metacell_mode > 0:
             if (
@@ -555,21 +586,41 @@ class MappedCollection:
         weights = (MAX / scaler) / ((1 + counts - MIN) + MAX / scaler)
         return weights

-    def get_merged_labels(self, label_key: str):
+    def get_merged_labels(self, label_key: str, is_cat: bool = True):
         """Get merged labels for `label_key` from all `.obs`."""
         labels_merge = []
-        for i, storage in enumerate(self.storages):
+        for i, storage in tqdm(
+            enumerate(self.storages), label_key, total=len(self.storages)
+        ):
             with _Connect(storage) as store:
-                labels = self._get_labels(store, label_key, storage_idx=i)
+                labels = self._get_labels(
+                    store, label_key, storage_idx=i, is_cat=is_cat
+                )
                 if self.filtered:
                     labels = labels[self.indices_list[i]]
                 labels_merge.append(labels)
-        return np.concatenate(labels_merge)
+        if is_cat:
+            try:
+                return union_categoricals(labels_merge)
+            except TypeError:
+                typ = type(int)
+                for i in range(len(labels_merge)):
+                    if typ != type(labels_merge[i][0]):
+                        self.storages[i]
+                        typ = type(labels_merge[i][0])
+                return []
+        else:
+            print("concatenating labels")
+            return np.concatenate(labels_merge)

     def get_merged_categories(self, label_key: str):
         """Get merged categories for `label_key` from all `.obs`."""
         cats_merge = set()
-        for i, storage in enumerate(self.storages):
+        for i, storage in tqdm(
+            enumerate(self.storages),
+            total=len(self.storages),
+            desc="merging all " + label_key + " categories",
+        ):
             with _Connect(storage) as store:
                 if label_key in self._cache_cats:
                     cats = self._cache_cats[label_key][i]
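With `is_cat=True` (the new default), `get_merged_labels` merges per-store labels as pandas categoricals via `union_categoricals`, which unions the category sets instead of concatenating raw arrays. The merge step in isolation:

```python
import pandas as pd
from pandas.api.types import union_categoricals

# two stores with overlapping but non-identical category sets
a = pd.Categorical(["T cell", "B cell"])
b = pd.Categorical(["B cell", "NK cell"])

merged = union_categoricals([a, b])
print(list(merged))                # ['T cell', 'B cell', 'B cell', 'NK cell']
print(merged.categories.tolist())  # ['T cell', 'B cell', 'NK cell']
```

`union_categoricals` raises `TypeError` when the inputs' dtypes disagree, which is what the new except-branch probes for by walking the per-store label types.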
@@ -609,8 +660,8 @@ class MappedCollection:
         else:
             if "categories" in labels.attrs:
                 return labels.attrs["categories"]
-
-            return
+            elif labels.dtype == "bool":
+                return np.array(["True", "False"])
         return None

     def _get_codes(self, storage: StorageType, label_key: str):
@@ -626,11 +677,17 @@ class MappedCollection:
             return label["codes"][...]

     def _get_labels(
-        self, storage: StorageType, label_key: str, storage_idx: int | None = None
+        self,
+        storage: StorageType,
+        label_key: str,
+        storage_idx: int | None = None,
+        is_cat: bool = True,
     ):
         """Get labels."""
         codes = self._get_codes(storage, label_key)
         labels = _decode(codes) if isinstance(codes[0], bytes) else codes
+        if labels.dtype == bool:
+            labels = labels.astype(int)
         if storage_idx is not None and label_key in self._cache_cats:
             cats = self._cache_cats[label_key][storage_idx]
         else:
@@ -638,6 +695,8 @@ class MappedCollection:
         if cats is not None:
             cats = _decode(cats) if isinstance(cats[0], bytes) else cats
             labels = cats[labels]
+        if is_cat:
+            labels = pd.Categorical(labels.astype(str))
         return labels

     def close(self):
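`_get_labels` now normalizes boolean columns (cast to int before the category lookup) and, with `is_cat=True`, hands back a uniform string categorical, which is what `get_merged_labels` feeds to `union_categoricals`. The conversion chain on its own, with a toy label array:

```python
import numpy as np
import pandas as pd

labels = np.array([True, False, True])
labels = labels.astype(int)                  # bool -> int codes, as in the new branch
# ... the per-store category lookup would happen here ...
labels = pd.Categorical(labels.astype(str))  # is_cat=True: uniform string categorical
print(labels.categories.tolist())            # ['0', '1']
```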
scdataloader/preprocess.py
CHANGED

@@ -1,4 +1,5 @@
 import gc
+import time
 from typing import Callable, Optional, Union
 from uuid import uuid4

@@ -8,6 +9,7 @@ import numpy as np
 import pandas as pd
 import scanpy as sc
 from anndata import AnnData, read_h5ad
+from django.db.utils import OperationalError
 from scipy.sparse import csr_matrix
 from upath import UPath

@@ -61,6 +63,7 @@ class Preprocessor:
         organisms: list[str] = ["NCBITaxon:9606", "NCBITaxon:10090"],
         use_raw: bool = True,
         keepdata: bool = False,
+        drop_non_primary: bool = False,
     ) -> None:
         """
         Initializes the preprocessor and configures the workflow steps.
@@ -108,6 +111,8 @@ class Preprocessor:
                 Defaults to False.
             keepdata (bool, optional): Determines whether to keep the data in the AnnData object.
                 Defaults to False.
+            drop_non_primary (bool, optional): Determines whether to drop non-primary cells.
+                Defaults to False.
         """
         self.filter_gene_by_counts = filter_gene_by_counts
         self.filter_cell_by_counts = filter_cell_by_counts
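Dropping non-primary cells is now opt-in rather than automatic (see the `is_primary_data` hunk further down). A hypothetical configuration, assuming defaults for every other argument:

```python
from scdataloader.preprocess import Preprocessor

preprocessor = Preprocessor(
    drop_non_primary=True,  # new in 2.0.0: filter on obs["is_primary_data"]
)
adata = preprocessor(adata)  # assumes an AnnData with raw counts and ontology ids
```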
@@ -123,6 +128,7 @@ class Preprocessor:
         self.min_valid_genes_id = min_valid_genes_id
         self.min_nnz_genes = min_nnz_genes
         self.maxdropamount = maxdropamount
+        self.drop_non_primary = drop_non_primary
         self.madoutlier = madoutlier
         self.n_hvg_for_postp = n_hvg_for_postp
         self.pct_mt_outlier = pct_mt_outlier
@@ -142,10 +148,6 @@ class Preprocessor:
             raise ValueError(
                 "organism_ontology_term_id not found in adata.obs, you need to add an ontology term id for the organism of your anndata"
             )
-        if not adata[0].var.index.str.contains("ENS").any() and not self.is_symbol:
-            raise ValueError(
-                "gene names in the `var.index` field of your anndata should map to the ensembl_gene nomenclature else set `is_symbol` to True if using hugo symbols"
-            )
         if adata.obs["organism_ontology_term_id"].iloc[0] not in self.organisms:
             raise ValueError(
                 "we cannot work with this organism",
@@ -161,8 +163,8 @@ class Preprocessor:
         if np.abs(adata[:50_000].X.astype(int) - adata[:50_000].X).sum():
             print("X was not raw counts, using 'counts' layer")
             adata.X = adata.layers["counts"].copy()
-        print("Dropping layers: ", adata.layers.keys())
         if not self.keepdata:
+            print("Dropping layers: ", adata.layers.keys())
             del adata.layers
         if len(adata.varm.keys()) > 0 and not self.keepdata:
             del adata.varm
@@ -170,6 +172,8 @@ class Preprocessor:
             del adata.obsm
         if len(adata.obsp.keys()) > 0 and not self.keepdata:
             del adata.obsp
+        if len(adata.varp.keys()) > 0 and not self.keepdata:
+            del adata.varp
         # check that it is a count

         print("checking raw counts")
@@ -188,7 +192,7 @@ class Preprocessor:
         # if not available count drop
         prevsize = adata.shape[0]
         # dropping non primary
-        if "is_primary_data" in adata.obs.columns:
+        if "is_primary_data" in adata.obs.columns and self.drop_non_primary:
             adata = adata[adata.obs.is_primary_data]
             if adata.shape[0] < self.min_dataset_size:
                 raise Exception("Dataset dropped due to too many secondary cells")
@@ -213,13 +217,10 @@ class Preprocessor:
             min_genes=self.min_nnz_genes,
         )
         # if lost > 50% of the dataset, drop dataset
-
-        genesdf = data_utils.load_genes(adata.obs.organism_ontology_term_id.iloc[0])
-
-        if prevsize / adata.shape[0] > self.maxdropamount:
+        if prevsize / (adata.shape[0] + 1) > self.maxdropamount:
             raise Exception(
                 "Dataset dropped due to low expressed genes and unexpressed cells: factor of "
-                + str(prevsize / adata.shape[0])
+                + str(prevsize / (adata.shape[0] + 1))
             )
         if adata.shape[0] < self.min_dataset_size:
             raise Exception(
@@ -232,60 +233,39 @@ class Preprocessor:
             )
         )

-        #
-
-        all_ens = adata.var.index.str.match(r"ENS.*\d{6,}$").all()
-
-        if not has_ens:
-            print("No ENS genes found, assuming gene symbols...")
-        elif not all_ens:
-            print("Mix of ENS and gene symbols found, converting all to ENS IDs...")
-
+        # load the genes
+        genesdf = data_utils.load_genes(adata.obs.organism_ontology_term_id.iloc[0])
         genesdf["ensembl_gene_id"] = genesdf.index

         # For genes that are already ENS IDs, use them directly
-
-        symbol_mask = ~ens_mask
-
+        prev_size = adata.shape[1]
         # Handle symbol genes
-        if
-
+        if self.is_symbol:
+            new_var = adata.var.merge(
                 genesdf.drop_duplicates("symbol").set_index("symbol", drop=False),
                 left_index=True,
                 right_index=True,
                 how="inner",
             )
-
-
-
-
+            new_var["symbol"] = new_var.index
+            adata = adata[:, new_var.index]
+            new_var.index = new_var["ensembl_gene_id"]
+        else:
+            new_var = adata.var.merge(
                 genesdf, left_index=True, right_index=True, how="inner"
             )
+            adata = adata[:, new_var.index]
+        print(f"Removed {prev_size - adata.shape[1]} genes not known to the ontology")
+        prev_size = adata.shape[1]

-
-        if symbol_mask.any() and ens_mask.any():
-            var = pd.concat([symbol_var, ens_var])
-        elif symbol_mask.any():
-            var = symbol_var
-        else:
-            var = ens_var
-
-        adata = adata[:, var.index]
-        # var = var.sort_values(by="ensembl_gene_id").set_index("ensembl_gene_id")
-        # Update adata with combined genes
-        if "ensembl_gene_id" in var.columns:
-            adata.var = var.set_index("ensembl_gene_id")
-        else:
-            adata.var = var
+        adata.var = new_var
         # Drop duplicate genes, keeping first occurrence
         adata = adata[:, ~adata.var.index.duplicated(keep="first")]
+        print(f"Removed {prev_size - adata.shape[1]} duplicate genes")

-
-        print(f"Removed {len(adata.var.index) - len(intersect_genes)} genes.")
-        if len(intersect_genes) < self.min_valid_genes_id:
+        if adata.shape[1] < self.min_valid_genes_id:
             raise Exception("Dataset dropped due to too many genes not mapping to it")
-
-        # marking unseen genes
+
         unseen = set(genesdf.index) - set(adata.var.index)
         # adding them to adata
         emptyda = ad.AnnData(
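The rewritten block replaces the old mixed symbol/Ensembl masking with a single branch on `self.is_symbol`: `var` is inner-merged with the organism's gene table, and for symbol-indexed data the index is then swapped over to `ensembl_gene_id`. A toy version of the symbol branch, with hypothetical gene names:

```python
import pandas as pd

# stand-in for the table returned by data_utils.load_genes
genesdf = pd.DataFrame(
    {"symbol": ["TP53", "BRCA1"]},
    index=["ENSG00000141510", "ENSG00000012048"],
)
genesdf["ensembl_gene_id"] = genesdf.index

# var indexed by symbols, as in an is_symbol=True dataset
var = pd.DataFrame(index=["TP53", "BRCA1", "NOT_A_GENE"])

new_var = var.merge(
    genesdf.drop_duplicates("symbol").set_index("symbol", drop=False),
    left_index=True,
    right_index=True,
    how="inner",
)
new_var["symbol"] = new_var.index
new_var.index = new_var["ensembl_gene_id"]
print(new_var.index.tolist())  # ['ENSG00000141510', 'ENSG00000012048']; unknown gene dropped
```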
@@ -293,6 +273,9 @@ class Preprocessor:
             var=pd.DataFrame(index=list(unseen)),
             obs=pd.DataFrame(index=adata.obs.index),
         )
+        print(
+            f"Added {len(unseen)} genes in the ontology but not present in the dataset"
+        )
         adata = ad.concat([adata, emptyda], axis=1, join="outer", merge="only")
         # do a validation function
         adata.uns["unseen_genes"] = list(unseen)
@@ -330,7 +313,7 @@ class Preprocessor:
         # QC

         adata.var[genesdf.columns] = genesdf.loc[adata.var.index]
-        print("
+        print("starting QC")
         sc.pp.calculate_qc_metrics(
             adata, qc_vars=["mt", "ribo", "hb"], inplace=True, percent_top=[20]
         )
@@ -348,7 +331,7 @@ class Preprocessor:
         )
         total_outliers = (adata.obs["outlier"] | adata.obs["mt_outlier"]).sum()
         total_cells = adata.shape[0]
-        percentage_outliers = (total_outliers / total_cells) * 100
+        percentage_outliers = (total_outliers / (total_cells + 1)) * 100
         print(
             f"Seeing {total_outliers} outliers ({percentage_outliers:.2f}% of total dataset):"
         )
@@ -395,7 +378,7 @@ class Preprocessor:
             subset=False,
             layer="norm",
         )
-
+        print("starting PCA")
         adata.obsm["X_pca"] = sc.pp.pca(
             adata.layers["norm"][:, adata.var.highly_variable]
             if "highly_variable" in adata.var.columns
@@ -464,13 +447,13 @@ class LaminPreprocessor(Preprocessor):
         *args,
         cache: bool = True,
         keep_files: bool = True,
-        force_preloaded: bool = False,
+        force_lamin_cache: bool = False,
         **kwargs,
     ):
         super().__init__(*args, **kwargs)
         self.cache = cache
         self.keep_files = keep_files
-        self.force_preloaded = force_preloaded
+        self.force_lamin_cache = force_lamin_cache

     def __call__(
         self,
@@ -505,10 +488,13 @@ class LaminPreprocessor(Preprocessor):
                 print(f"{file.stem_uid} is already processed... not preprocessing")
                 continue
             print(file)
+            if self.force_lamin_cache:
+                path = cache_path(file)
+                backed = read_h5ad(path, backed="r")
+            else:
+                # file.cache()
+                backed = file.open()

-            _ = cache_path(file) if self.force_preloaded else file.cache()
-            backed = file.open()
-            # backed = read_h5ad(path, backed="r")
             if "is_primary_data" in backed.obs.columns:
                 if backed.obs.is_primary_data.sum() == 0:
                     print(f"{file.key} only contains non primary cells.. dropping")
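`force_lamin_cache` replaces the old `force_preloaded` flag: rather than opening the artifact through `file.open()`, the preprocessor resolves the artifact's local cache path and reads it in backed mode. A hypothetical invocation (any arguments beyond the flag follow the constructor above):

```python
from scdataloader.preprocess import LaminPreprocessor

pp = LaminPreprocessor(
    force_lamin_cache=True,  # read artifacts via read_h5ad(cache_path(file), backed="r")
)
```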
@@ -556,37 +542,52 @@ class LaminPreprocessor(Preprocessor):
                         block,
                         dataset_id=file.stem_uid + "_p" + str(j),
                     )
-                    myfile = ln.Artifact.from_anndata(
-                        block,
-                        description=description
-                        + " n"
-                        + str(i)
-                        + " p"
-                        + str(j)
-                        + " ( revises file "
-                        + str(file.stem_uid)
-                        + " )",
-                        version=version,
-                    )
-                    myfile.save()
-
+                    saved = False
+                    while not saved:
+                        try:
+                            myfile = ln.Artifact.from_anndata(
+                                block,
+                                description=description
+                                + " n"
+                                + str(i)
+                                + " p"
+                                + str(j)
+                                + " ( revises file "
+                                + str(file.stem_uid)
+                                + " )",
+                                version=version,
+                            )
+                            myfile.save()
+                            saved = True
+                        except OperationalError:
+                            print(
+                                "Database locked, waiting 30 seconds and retrying..."
+                            )
+                            time.sleep(10)
                     if self.keep_files:
                         files.append(myfile)
                         del block
                     else:
                         del myfile
                         del block
-                    gc.collect()
-
             else:
                 adata = super().__call__(adata, dataset_id=file.stem_uid)
-                myfile = ln.Artifact.from_anndata(
-                    adata,
-                    revises=file,
-                    description=description + " p" + str(i),
-                    version=version,
-                )
-                myfile.save()
+                saved = False
+                while not saved:
+                    try:
+                        myfile = ln.Artifact.from_anndata(
+                            adata,
+                            revises=file,
+                            description=description + " p" + str(i),
+                            version=version,
+                        )
+                        myfile.save()
+                        saved = True
+                    except OperationalError:
+                        print(
+                            "Database locked, waiting 10 seconds and retrying..."
+                        )
+                        time.sleep(10)
                 if self.keep_files:
                     files.append(myfile)
                 del adata
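Both save paths now use the same retry-on-lock idiom: loop until `Artifact.save()` succeeds, sleeping whenever SQLite raises `OperationalError` ("database is locked"). The pattern reduced to a skeleton; `save_with_retry` and its `save_artifact` argument are hypothetical stand-ins for the `from_anndata(...).save()` sequence:

```python
import time

from django.db.utils import OperationalError


def save_with_retry(save_artifact, wait_seconds: int = 10):
    """Retry a lamindb save until the SQLite lock clears."""
    while True:
        try:
            return save_artifact()
        except OperationalError:
            # another process holds the lock; back off and try again
            print(f"Database locked, waiting {wait_seconds} seconds and retrying...")
            time.sleep(wait_seconds)
```

Note that in the released code the block-level branch prints "waiting 30 seconds" while calling `time.sleep(10)`; the message and the sleep disagree.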
@@ -606,7 +607,7 @@ class LaminPreprocessor(Preprocessor):
                     continue
                 else:
                     raise e
-
+            gc.collect()
         # issues with KLlggfw6I6lvmbqiZm46
         if self.keep_files:
             # Reconstruct collection using keys
@@ -716,7 +717,7 @@ def additional_preprocess(adata):
             }
         }
     )  # multi ethnic will have to get renamed
-    adata.obs["cell_culture"] = False
+    adata.obs["cell_culture"] = "False"
     # if cell_type contains the word "(cell culture)" then it is a cell culture and we mark it as so and remove this from the cell type
     loc = adata.obs["cell_type_ontology_term_id"].str.contains(
         "(cell culture)", regex=False
@@ -725,7 +726,7 @@ def additional_preprocess(adata):
         adata.obs["cell_type_ontology_term_id"] = adata.obs[
             "cell_type_ontology_term_id"
         ].astype(str)
-        adata.obs.loc[loc, "cell_culture"] = True
+        adata.obs.loc[loc, "cell_culture"] = "True"
         adata.obs.loc[loc, "cell_type_ontology_term_id"] = adata.obs.loc[
             loc, "cell_type_ontology_term_id"
         ].str.replace(" (cell culture)", "")
@@ -734,7 +735,7 @@ def additional_preprocess(adata):
         "(cell culture)", regex=False
     )
     if loc.sum() > 0:
-        adata.obs.loc[loc, "cell_culture"] = True
+        adata.obs.loc[loc, "cell_culture"] = "True"
         adata.obs["tissue_ontology_term_id"] = adata.obs[
             "tissue_ontology_term_id"
         ].astype(str)
@@ -744,7 +745,7 @@ def additional_preprocess(adata):

     loc = adata.obs["tissue_ontology_term_id"].str.contains("(organoid)", regex=False)
     if loc.sum() > 0:
-        adata.obs.loc[loc, "cell_culture"] = True
+        adata.obs.loc[loc, "cell_culture"] = "True"
         adata.obs["tissue_ontology_term_id"] = adata.obs[
             "tissue_ontology_term_id"
         ].astype(str)
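All three `cell_culture` assignments now store the strings `"True"`/`"False"` instead of booleans, presumably so the column keeps a single string dtype when only some rows are flagged (assigning `True` into a column that later mixes types would upcast it to `object`). The resulting column then behaves like any other string label:

```python
import pandas as pd

obs = pd.DataFrame({"cell_culture": ["False"] * 3})
obs.loc[[0, 2], "cell_culture"] = "True"
print(obs["cell_culture"].tolist())  # ['True', 'False', 'True']
```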
@@ -773,6 +774,7 @@ def additional_postprocess(adata):
     # sc.external.pp.harmony_integrate(adata, key="batches")
     # sc.pp.neighbors(adata, use_rep="X_pca_harmony")
     # else:
+    print("starting post processing")
     sc.pp.neighbors(adata, use_rep="X_pca")
     sc.tl.leiden(adata, key_added="leiden_2", resolution=2.0)
     sc.tl.leiden(adata, key_added="leiden_1", resolution=1.0)
@@ -791,8 +793,12 @@ def additional_postprocess(adata):
     MAXSIM = 0.94
     from collections import Counter

+    import bionty as bt
+
     from .config import MAIN_HUMAN_MOUSE_DEV_STAGE_MAP

+    remap_stages = {u: k for k, v in MAIN_HUMAN_MOUSE_DEV_STAGE_MAP.items() for u in v}
+
     adata.obs[NEWOBS] = (
         adata.obs[COL].astype(str) + "_" + adata.obs["leiden_1"].astype(str)
     )
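`remap_stages` inverts `MAIN_HUMAN_MOUSE_DEV_STAGE_MAP` (coarse stage mapped to a list of member stages) so each member stage points back to its coarse group. The comprehension's behavior on a toy mapping with hypothetical ontology ids:

```python
# toy stand-in for MAIN_HUMAN_MOUSE_DEV_STAGE_MAP
MAIN_MAP = {
    "HsapDv:adult": ["HsapDv:young_adult", "HsapDv:middle_aged"],
    "HsapDv:child": ["HsapDv:toddler"],
}
remap_stages = {u: k for k, v in MAIN_MAP.items() for u in v}
# {'HsapDv:young_adult': 'HsapDv:adult',
#  'HsapDv:middle_aged': 'HsapDv:adult',
#  'HsapDv:toddler': 'HsapDv:child'}
```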
@@ -860,18 +866,17 @@ def additional_postprocess(adata):
             num += 1
     adata.obs[NEWOBS] = adata.obs[NEWOBS].map(merge_mapping).fillna(adata.obs[NEWOBS])

-    import bionty as bt
-
     stages = adata.obs["development_stage_ontology_term_id"].unique()
     if adata.obs.organism_ontology_term_id.unique() == ["NCBITaxon:9606"]:
         relabel = {i: i for i in stages}
         for stage in stages:
+            if stage in MAIN_HUMAN_MOUSE_DEV_STAGE_MAP.keys():
+                continue
             stage_obj = bt.DevelopmentalStage.filter(ontology_id=stage).first()
             parents = set([i.ontology_id for i in stage_obj.parents.filter()])
             parents = parents - set(
                 [
                     "HsapDv:0010000",
-                    "HsapDv:0000204",
                     "HsapDv:0000227",
                 ]
             )
@@ -879,9 +884,14 @@ def additional_postprocess(adata):
             for p in parents:
                 if p in MAIN_HUMAN_MOUSE_DEV_STAGE_MAP:
                     relabel[stage] = p
-        adata.obs["
-
-
+        adata.obs["age_group"] = adata.obs["development_stage_ontology_term_id"].map(
+            relabel
+        )
+        for stage in adata.obs["age_group"].unique():
+            if stage in remap_stages.keys():
+                adata.obs["age_group"] = adata.obs["age_group"].map(
+                    lambda x: remap_stages[x] if x == stage else x
+                )
     elif adata.obs.organism_ontology_term_id.unique() == ["NCBITaxon:10090"]:
         rename_mapping = {
             k: v for v, j in MAIN_HUMAN_MOUSE_DEV_STAGE_MAP.items() for k in j
@@ -890,11 +900,12 @@ def additional_postprocess(adata):
         for stage in stages:
             if stage in rename_mapping:
                 relabel[stage] = rename_mapping[stage]
-        adata.obs["
-
-
+        adata.obs["age_group"] = adata.obs["development_stage_ontology_term_id"].map(
+            relabel
+        )
     else:
-        raise ValueError("organism not supported")
+        # raise ValueError("organism not supported")
+        print("organism not supported for age labels")
     # palantir.utils.run_diffusion_maps(adata, n_components=20)
     # palantir.utils.determine_multiscale_space(adata)
     # terminal_states = palantir.utils.find_terminal_states(