scdataloader-1.1.3-py3-none-any.whl → scdataloader-1.2.2-py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry; it is provided for informational purposes only.
- scdataloader/VERSION +1 -1
- scdataloader/__init__.py +1 -1
- scdataloader/__main__.py +16 -7
- scdataloader/collator.py +4 -2
- scdataloader/data.py +41 -17
- scdataloader/datamodule.py +13 -13
- scdataloader/preprocess.py +71 -56
- scdataloader/utils.py +87 -61
- scdataloader-1.2.2.dist-info/METADATA +299 -0
- scdataloader-1.2.2.dist-info/RECORD +14 -0
- {scdataloader-1.1.3.dist-info → scdataloader-1.2.2.dist-info}/WHEEL +1 -1
- scdataloader/mapped.py +0 -540
- scdataloader-1.1.3.dist-info/METADATA +0 -899
- scdataloader-1.1.3.dist-info/RECORD +0 -16
- scdataloader-1.1.3.dist-info/entry_points.txt +0 -3
- {scdataloader-1.1.3.dist-info → scdataloader-1.2.2.dist-info/licenses}/LICENSE +0 -0
scdataloader/VERSION
CHANGED
@@ -1 +1 @@
-1.1.3
+1.2.2
scdataloader/__init__.py
CHANGED
scdataloader/__main__.py
CHANGED
@@ -1,11 +1,13 @@
 import argparse
+from typing import Optional, Union
+
+import lamindb as ln
+
 from scdataloader.preprocess import (
     LaminPreprocessor,
-    additional_preprocess,
     additional_postprocess,
+    additional_preprocess,
 )
-import lamindb as ln
-from typing import Optional, Union
 
 
 # scdataloader --instance="laminlabs/cellxgene" --name="cellxgene-census" --version="2023-12-15" --description="preprocessed for scprint" --new_name="scprint main" --start_at=39
@@ -51,14 +53,14 @@ def main():
     )
     parser.add_argument(
         "--filter_gene_by_counts",
-        type=
-        default=
+        type=int,
+        default=0,
         help="Determines whether to filter genes by counts.",
     )
     parser.add_argument(
         "--filter_cell_by_counts",
-        type=
-        default=
+        type=int,
+        default=0,
         help="Determines whether to filter cells by counts.",
     )
     parser.add_argument(
@@ -151,6 +153,12 @@ def main():
         default=False,
         help="Determines whether to do postprocessing.",
     )
+    parser.add_argument(
+        "--cache",
+        type=bool,
+        default=True,
+        help="Determines whether to cache the dataset.",
+    )
     args = parser.parse_args()
 
     # Load the collection
@@ -176,6 +184,7 @@ def main():
         normalize_sum=args.normalize_sum,
         subset_hvg=args.subset_hvg,
         hvg_flavor=args.hvg_flavor,
+        cache=args.cache,
         binning=args.binning,
         result_binned_key=args.result_binned_key,
         length_normalize=args.length_normalize,
scdataloader/collator.py
CHANGED
scdataloader/data.py
CHANGED
@@ -1,18 +1,20 @@
+import warnings
+from collections import Counter
 from dataclasses import dataclass, field
-
-import
+from functools import reduce
+from typing import Literal, Optional, Union
 
 # ln.connect("scprint")
-
 import bionty as bt
+import lamindb as ln
+import numpy as np
 import pandas as pd
-from torch.utils.data import Dataset as torchDataset
-from typing import Union, Optional, Literal
-from scdataloader.mapped import MappedCollection
-import warnings
-
 from anndata import AnnData
+from lamindb.core import MappedCollection
+from lamindb.core._mapped_collection import _Connect
+from lamindb.core.storage._anndata_accessor import _safer_read_index
 from scipy.sparse import issparse
+from torch.utils.data import Dataset as torchDataset
 
 from scdataloader.utils import get_ancestry_mapping, load_genes
 
@@ -110,7 +112,16 @@ class Dataset(torchDataset):
         self.genedf = load_genes(self.organisms)
 
         self.genedf.columns = self.genedf.columns.astype(str)
-        self.
+        self.check_aligned_vars()
+
+    def check_aligned_vars(self):
+        vars = self.genedf.index.tolist()
+        i = 0
+        for storage in self.mapped_dataset.storages:
+            with _Connect(storage) as store:
+                if len(set(_safer_read_index(store["var"]).tolist()) - set(vars)) == 0:
+                    i += 1
+        print("{}% are aligned".format(i * 100 / len(self.mapped_dataset.storages)))
 
     def __len__(self, **kwargs):
         return self.mapped_dataset.__len__(**kwargs)
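As a side note on this hunk: the new check_aligned_vars reports the share of backing storages whose var index is fully contained in the dataset's gene table. A minimal sketch of that percentage logic with plain Python sets (toy gene names, not package code):

# Toy sketch of the alignment check: a storage counts as aligned when its
# var index is a subset of the dataset's gene index.
genedf_index = {"ENSG1", "ENSG2", "ENSG3"}
storage_var_indices = [["ENSG1", "ENSG2"], ["ENSG1", "ENSGX"]]  # hypothetical stores
aligned = sum(1 for var in storage_var_indices if len(set(var) - genedf_index) == 0)
print("{}% are aligned".format(aligned * 100 / len(storage_var_indices)))  # 50.0% are aligned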
@@ -145,14 +156,27 @@ class Dataset(torchDataset):
             )
         )
 
-    def get_label_weights(self,
-        """
-
+    def get_label_weights(self, obs_keys: str | list[str], scaler: int = 10):
+        """Get all weights for the given label keys."""
+        if isinstance(obs_keys, str):
+            obs_keys = [obs_keys]
+        labels_list = []
+        for label_key in obs_keys:
+            labels_to_str = (
+                self.mapped_dataset.get_merged_labels(label_key).astype(str).astype("O")
+            )
+            labels_list.append(labels_to_str)
+        if len(labels_list) > 1:
+            labels = reduce(lambda a, b: a + b, labels_list)
+        else:
+            labels = labels_list[0]
 
-
-
-
-
+        counter = Counter(labels)  # type: ignore
+        rn = {n: i for i, n in enumerate(counter.keys())}
+        labels = np.array([rn[label] for label in labels])
+        counter = np.array(list(counter.values()))
+        weights = scaler / (counter + scaler)
+        return weights, labels
 
     def get_unseen_mapped_dataset_elements(self, idx: int):
         """
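The rewritten get_label_weights now returns per-class weights together with per-sample class indices. A hedged usage sketch (not part of this diff), assuming `dataset` is an instance of scdataloader.data.Dataset, showing how the pair could drive a PyTorch WeightedRandomSampler:

import torch
from torch.utils.data import WeightedRandomSampler

# weights: one weight per class; labels: one class index per cell
weights, labels = dataset.get_label_weights("cell_type_ontology_term_id", scaler=10)
sample_weights = torch.as_tensor(weights[labels], dtype=torch.double)
sampler = WeightedRandomSampler(sample_weights, num_samples=len(sample_weights))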
@@ -236,7 +260,7 @@ class Dataset(torchDataset):
                        clss
                    )
                )
-            cats = self.mapped_dataset.get_merged_categories(clss)
+            cats = set(self.mapped_dataset.get_merged_categories(clss))
             addition = set(LABELS_TOADD.get(clss, {}).values())
             cats |= addition
             groupings, _, leaf_labels = get_ancestry_mapping(cats, parentdf)
scdataloader/datamodule.py
CHANGED
@@ -1,21 +1,20 @@
+from typing import Optional, Sequence, Union
+
+import lamindb as ln
+import lightning as L
 import numpy as np
 import pandas as pd
-import
-
+import torch
+from torch.utils.data import DataLoader, Sampler
 from torch.utils.data.sampler import (
-    WeightedRandomSampler,
-    SubsetRandomSampler,
-    SequentialSampler,
     RandomSampler,
+    SequentialSampler,
+    SubsetRandomSampler,
+    WeightedRandomSampler,
 )
-import torch
-from torch.utils.data import DataLoader, Sampler
-import lightning as L
-
-from typing import Optional, Union, Sequence
 
-from .data import Dataset
 from .collator import Collator
+from .data import Dataset
 from .utils import getBiomartTable
 
 
@@ -110,7 +109,8 @@ class DataModule(L.LightningDataModule):
                "need to provide your own table as this automated function only works for humans for now"
            )
            biomart = getBiomartTable(
-               attributes=["start_position", "chromosome_name"]
+               attributes=["start_position", "chromosome_name"],
+               useCache=True,
            ).set_index("ensembl_gene_id")
            biomart = biomart.loc[~biomart.index.duplicated(keep="first")]
            biomart = biomart.sort_values(by=["chromosome_name", "start_position"])
@@ -129,7 +129,7 @@ class DataModule(L.LightningDataModule):
                prev_chromosome = r["chromosome_name"]
            print(f"reduced the size to {len(set(c))/len(biomart)}")
            biomart["pos"] = c
-           mdataset.genedf =
+           mdataset.genedf = mdataset.genedf.join(biomart, how="inner")
            self.gene_pos = mdataset.genedf["pos"].astype(int).tolist()
 
        if gene_embeddings != "":
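The datamodule change joins the biomart position table onto the gene dataframe with an inner join, so genes without a known position are dropped. A toy pandas illustration of that behavior (example frames, not package code):

import pandas as pd

genedf = pd.DataFrame({"symbol": ["A", "B", "C"]}, index=["ENSG1", "ENSG2", "ENSG3"])
biomart = pd.DataFrame(
    {"pos": [0, 1], "chromosome_name": ["1", "1"]}, index=["ENSG1", "ENSG3"]
)
genedf = genedf.join(biomart, how="inner")  # ENSG2 has no position and is dropped
print(genedf.index.tolist())  # ['ENSG1', 'ENSG3']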
scdataloader/preprocess.py
CHANGED
@@ -177,11 +177,18 @@ class Preprocessor:
         # # cleanup and dropping low expressed genes and unexpressed cells
         prevsize = adata.shape[0]
         adata.obs["nnz"] = np.array(np.sum(adata.X != 0, axis=1).flatten())[0]
-        adata = adata[(adata.obs["nnz"] > self.min_nnz_genes)]
         if self.filter_gene_by_counts:
             sc.pp.filter_genes(adata, min_counts=self.filter_gene_by_counts)
         if self.filter_cell_by_counts:
-            sc.pp.filter_cells(
+            sc.pp.filter_cells(
+                adata,
+                min_counts=self.filter_cell_by_counts,
+            )
+        if self.min_nnz_genes:
+            sc.pp.filter_cells(
+                adata,
+                min_genes=self.min_nnz_genes,
+            )
         # if lost > 50% of the dataset, drop dataset
         # load the genes
         genesdf = data_utils.load_genes(adata.obs.organism_ontology_term_id.iloc[0])
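This hunk replaces the boolean nnz mask with scanpy's filter_cells, called once per criterion. A minimal self-contained illustration on a toy matrix (not package code):

import numpy as np
import scanpy as sc
from anndata import AnnData

adata = AnnData(np.array([[0, 5, 3], [1, 0, 0], [2, 2, 2]], dtype=np.float32))
sc.pp.filter_cells(adata, min_counts=3)  # drops the cell with total count 1
sc.pp.filter_cells(adata, min_genes=2)   # keeps cells expressing at least 2 genes
print(adata.shape)  # (2, 3)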
@@ -297,7 +304,7 @@ class Preprocessor:
         # https://rapids-singlecell.readthedocs.io/en/latest/api/generated/rapids_singlecell.pp.pca.html#rapids_singlecell.pp.pca
         if self.do_postp:
             print("normalize")
-            adata.layers["
+            adata.layers["norm"] = sc.pp.log1p(
                 sc.pp.normalize_total(
                     adata, target_sum=self.normalize_sum, inplace=False
                 )["X"]
@@ -306,20 +313,34 @@ class Preprocessor:
             if self.subset_hvg:
                 sc.pp.highly_variable_genes(
                     adata,
-                    layer="clean",
                     n_top_genes=self.subset_hvg,
                     batch_key=self.batch_key,
                     flavor=self.hvg_flavor,
                     subset=False,
                 )
-
-
-
+            sc.pp.log1p(adata, layer="norm")
+            sc.pp.pca(
+                adata,
+                layer="norm",
+                n_comps=200 if adata.shape[0] > 200 else adata.shape[0] - 2,
             )
-            sc.pp.neighbors(adata, use_rep="
-            sc.tl.leiden(adata, key_added="leiden_3", resolution=3.0)
+            sc.pp.neighbors(adata, use_rep="X_pca")
             sc.tl.leiden(adata, key_added="leiden_2", resolution=2.0)
             sc.tl.leiden(adata, key_added="leiden_1", resolution=1.0)
+            sc.tl.leiden(adata, key_added="leiden_0.5", resolution=0.5)
+            batches = [
+                "assay_ontology_term_id",
+                "self_reported_ethnicity_ontology_term_id",
+                "sex_ontology_term_id",
+                "development_stage_ontology_term_id",
+            ]
+            if "donor_id" in adata.obs.columns:
+                batches.append("donor_id")
+            if "suspension_type" in adata.obs.columns:
+                batches.append("suspension_type")
+            adata.obs["batches"] = adata.obs[batches].apply(
+                lambda x: ",".join(x.dropna().astype(str)), axis=1
+            )
             sc.tl.umap(adata)
         # additional
         if self.additional_postprocess is not None:
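The new "batches" column concatenates the available batch covariates per cell, skipping missing values. A toy pandas sketch of that apply (illustrative values only, not package code):

import pandas as pd

obs = pd.DataFrame(
    {
        "assay_ontology_term_id": ["EFO:0009922", "EFO:0009922"],
        "sex_ontology_term_id": ["PATO:0000384", None],
        "donor_id": ["d1", "d2"],
    }
)
batches = ["assay_ontology_term_id", "sex_ontology_term_id", "donor_id"]
obs["batches"] = obs[batches].apply(lambda x: ",".join(x.dropna().astype(str)), axis=1)
print(obs["batches"].tolist())  # ['EFO:0009922,PATO:0000384,d1', 'EFO:0009922,d2']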
@@ -379,14 +400,12 @@ class LaminPreprocessor(Preprocessor):
     def __init__(
         self,
         *args,
-        erase_prev_dataset: bool = False,
         cache: bool = True,
         stream: bool = False,
         keep_files: bool = True,
         **kwargs,
     ):
         super().__init__(*args, **kwargs)
-        self.erase_prev_dataset = erase_prev_dataset
         self.cache = cache
         self.stream = stream
         self.keep_files = keep_files
@@ -418,14 +437,17 @@ class LaminPreprocessor(Preprocessor):
         elif isinstance(data, ln.Collection):
             for i, file in enumerate(data.artifacts.all()[start_at:]):
                 # use the counts matrix
-                print(i)
+                print(i + start_at)
                 if file.stem_uid in all_ready_processed_keys:
                     print(f"{file.stem_uid} is already processed... not preprocessing")
                     continue
                 print(file)
-                backed = file.
+                backed = file.open()
                 if backed.obs.is_primary_data.sum() == 0:
                     print(f"{file.key} only contains non primary cells.. dropping")
+                    # Save the stem_uid to a file to avoid loading it again
+                    with open("nonprimary.txt", "a") as f:
+                        f.write(f"{file.stem_uid}\n")
                     continue
                 if backed.shape[1] < 1000:
                     print(
@@ -449,17 +471,17 @@ class LaminPreprocessor(Preprocessor):
                        (np.ceil(badata.shape[0] / 30_000) * 30_000) // num_blocks
                    )
                    print("num blocks ", num_blocks)
-                    for
-                        start_index =
-                        end_index = min((
+                    for j in range(num_blocks):
+                        start_index = j * block_size
+                        end_index = min((j + 1) * block_size, badata.shape[0])
                         block = badata[start_index:end_index].to_memory()
                         print(block)
                         block = super().__call__(block)
-                        myfile = ln.
+                        myfile = ln.from_anndata(
                             block,
-
+                            revises=file,
                             description=description,
-                            version=str(version) + "_s" + str(
+                            version=str(version) + "_s" + str(j),
                         )
                         myfile.save()
                         if self.keep_files:
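The restored loop processes a large backed AnnData in num_blocks slices defined by start/end indices. A small sketch of just the index arithmetic (toy sizes, not package code):

import numpy as np

n_cells, num_blocks = 70_000, 3
block_size = int((np.ceil(n_cells / 30_000) * 30_000) // num_blocks)
for j in range(num_blocks):
    start_index = j * block_size
    end_index = min((j + 1) * block_size, n_cells)
    print((start_index, end_index))  # (0, 30000), (30000, 60000), (60000, 70000)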
@@ -470,9 +492,13 @@ class LaminPreprocessor(Preprocessor):
 
                 else:
                     adata = super().__call__(adata)
-
+                    try:
+                        sc.pl.umap(adata, color=["cell_type"])
+                    except Exception:
+                        sc.pl.umap(adata, color=["cell_type_ontology_term_id"])
+                    myfile = ln.from_anndata(
                         adata,
-
+                        revises=file,
                         description=description,
                         version=str(version),
                     )
@@ -646,46 +672,35 @@ def additional_preprocess(adata):
 
 
 def additional_postprocess(adata):
+    import palantir
+
     # define the "up to" 10 neighbors for each cells and add to obs
     # compute neighbors
     # need to be connectivities and same labels [cell type, assay, dataset, disease]
     # define the "neighbor" up to 10(N) cells and add to obs
     # define the "next time point" up to 5(M) cells and add to obs # step 1: filter genes
+    del adata.obsp["connectivities"]
+    del adata.obsp["distances"]
+    sc.external.pp.harmony_integrate(adata, key="batches")
+    sc.pp.neighbors(adata, use_rep="X_pca_harmony")
+    sc.tl.umap(adata)
+    sc.pl.umap(
+        adata,
+        color=["cell_type", "batches"],
+    )
+    palantir.utils.run_diffusion_maps(adata, n_components=20)
+    palantir.utils.determine_multiscale_space(adata)
+    terminal_states = palantir.utils.find_terminal_states(
+        adata,
+        celltypes=adata.obs.cell_type_ontology_term_id.unique(),
+        celltype_column="cell_type_ontology_term_id",
+    )
     sc.tl.diffmap(adata)
-
-
-    adata.
-        + "_"
-        + adata.obs["disease_ontology_term_id"].astype(str)
-        + "_"
-        + adata.obs["cell_type_ontology_term_id"].astype(str)
-        + "_"
-        + adata.obs["tissue_ontology_term_id"].astype(str)
-    )  # + "_" + adata.obs['dataset_id'].astype(str)
-
-    # if group is too small
-    okgroup = [i for i, j in adata.obs["dpt_group"].value_counts().items() if j >= 10]
-    not_okgroup = [i for i, j in adata.obs["dpt_group"].value_counts().items() if j < 3]
-    # set the group to empty
-    adata.obs.loc[adata.obs["dpt_group"].isin(not_okgroup), "dpt_group"] = ""
-    adata.obs["heat_diff"] = np.nan
-    # for each group
-    for val in set(okgroup):
-        if val == "":
-            continue
-        # get the best root cell
-        eq = adata.obs.dpt_group == val
-        loc = np.where(eq)[0]
-
-        root_ixs = loc[adata.obsm["X_diffmap"][eq, 0].argmin()]
-        adata.uns["iroot"] = root_ixs
-        # compute the diffusion pseudo time from it
+    adata.obs["heat_diff"] = 1
+    for terminal_state in terminal_states.index.tolist():
+        adata.uns["iroot"] = np.where(adata.obs.index == terminal_state)[0][0]
         sc.tl.dpt(adata)
-        adata.obs
-
-
-    # sort so that the next time points are aligned for all groups
-    adata = adata[adata.obs.sort_values(["dpt_group", "heat_diff"]).index]
-    # to query N next time points we just get the N elements below and check they are in the group
-    # to query the N nearest neighbors we just get the N elements above and N below and check they are in the group
+        adata.obs["heat_diff"] = np.minimum(
+            adata.obs["heat_diff"], adata.obs["dpt_pseudotime"]
+        )
     return adata
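In the rewritten additional_postprocess, heat_diff starts at 1 and, for each terminal state used as root, keeps the element-wise minimum of its previous value and the new dpt_pseudotime. A toy numpy/pandas sketch of that aggregation (illustrative values, not package code):

import numpy as np
import pandas as pd

obs = pd.DataFrame({"heat_diff": [1.0, 1.0, 1.0]})
for dpt in ([0.2, 0.9, 0.5], [0.7, 0.1, 0.6]):  # pseudotime from two hypothetical roots
    obs["dpt_pseudotime"] = dpt
    obs["heat_diff"] = np.minimum(obs["heat_diff"], obs["dpt_pseudotime"])
print(obs["heat_diff"].tolist())  # [0.2, 0.1, 0.5]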