pertpy 1.0.1__py3-none-any.whl → 1.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pertpy/__init__.py +1 -1
- pertpy/data/_dataloader.py +2 -2
- pertpy/data/_datasets.py +62 -62
- pertpy/metadata/_cell_line.py +9 -3
- pertpy/metadata/_drug.py +4 -2
- pertpy/preprocessing/_guide_rna.py +17 -10
- pertpy/preprocessing/_guide_rna_mixture.py +9 -3
- pertpy/tools/__init__.py +12 -2
- pertpy/tools/_augur.py +37 -14
- pertpy/tools/_coda/_sccoda.py +68 -101
- pertpy/tools/_coda/_tasccoda.py +103 -85
- pertpy/tools/_mixscape.py +48 -39
- pertpy/tools/_perturbation_space/_comparison.py +3 -3
- pertpy/tools/_perturbation_space/_discriminator_classifiers.py +261 -353
- pertpy/tools/_perturbation_space/_perturbation_space.py +22 -14
- pertpy/tools/_perturbation_space/_simple.py +12 -6
- pertpy/tools/_scgen/_scgenvae.py +2 -1
- pertpy/tools/core.py +18 -0
- {pertpy-1.0.1.dist-info → pertpy-1.0.3.dist-info}/METADATA +14 -2
- {pertpy-1.0.1.dist-info → pertpy-1.0.3.dist-info}/RECORD +22 -21
- {pertpy-1.0.1.dist-info → pertpy-1.0.3.dist-info}/WHEEL +0 -0
- {pertpy-1.0.1.dist-info → pertpy-1.0.3.dist-info}/licenses/LICENSE +0 -0
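Before reading the per-file hunks below, it can help to confirm which of the two versions is actually installed. A minimal check, assuming pertpy is installed in the active environment (the printed value should be either 1.0.1 or 1.0.3):

import importlib.metadata

# Report the installed pertpy version so it can be matched against this diff.
print(importlib.metadata.version("pertpy"))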
pertpy/tools/_coda/_tasccoda.py
CHANGED
@@ -33,24 +33,6 @@ config.update("jax_enable_x64", True)
class Tasccoda(CompositionalModel2):
r"""Statistical model for tree-aggregated differential composition analysis (tascCODA, Ostner et al., 2021).

- The hierarchical formulation of the model for one sample is:
-
- .. math::
- \\begin{align*}
- Y_i &\\sim \\textrm{DirMult}(\\bar{Y}_i, \\textbf{a}(\\textbf{x})_i)\\\\
- \\log(\\textbf{a}(X))_i &= \\alpha + X_{i, \\cdot} \\beta\\\\
- \\alpha_j &\\sim \\mathcal{N}(0, 10) & \\forall j\\in[p]\\\\
- \\beta &= \\hat{\\beta} A^T \\\\
- \\hat{\\beta}_{l, k} &= 0 & \\forall k \\in \\hat{v}, l \\in [d]\\\\
- \\hat{\\beta}_{l, k} &= \\theta \\tilde{\\beta}_{1, l, k} + (1- \\theta) \\tilde{\\beta}_{0, l, k} \\quad & \\forall k\\in\\{[v] \\smallsetminus \\hat{v}\\}, l \\in [d]\\\\
- \\tilde{\\beta}_{m, l, k} &= \\sigma_{m, l, k} * b_{m, l, k} \\quad & \\forall k\\in\\{[v] \\smallsetminus \\hat{v}\\}, m \\in \\{0, 1\\}, l \\in [d]\\\\
- \\sigma_{m, l, k} &\\sim \\textrm{Exp}(\\lambda_{m, l, k}^2/2) \\quad & \\forall k\\in\\{[v] \\smallsetminus \\hat{v}\\}, l \\in \\{0, 1\\}, l \\in [d]\\\\
- b_{m, l, k} &\\sim N(0,1) \\quad & \\forall k\\in\\{[v] \\smallsetminus \\hat{v}\\}, l \\in \\{0, 1\\}, l \\in [d]\\\\
- \\theta &\\sim \\textrm{Beta}(1, \\frac{1}{|\\{[v] \\smallsetminus \\hat{v}\\}|})
- \\end{align*}
-
- with Y being the cell counts, X the covariates, and v the set of nodes of the underlying tree structure.
-
For further information, see `tascCODA: Bayesian Tree-Aggregated Analysis of Compositional Amplicon and Single-Cell Data`
(Ostner et al., 2021)
"""
@@ -75,11 +57,14 @@ class Tasccoda(CompositionalModel2):
modality_key_1: str = "rna",
modality_key_2: str = "coda",
) -> MuData:
- """Prepare a MuData object for subsequent processing.
+ """Prepare a MuData object for subsequent processing.
+
+ If type is "cell_level", then create a compositional analysis dataset from the input adata.
+ If type is "sample_level", generate ete tree for tascCODA models from dendrogram information or cell-level observations.

- When using
+ When using `type="cell_level"`, `adata` needs to have a column in `adata.obs` that contains the cell type assignment.
Further, it must contain one column or a set of columns (e.g. subject id, treatment, disease status) that uniquely identify each (statistical) sample.
- Further covariates (e.g. subject age) can either be specified via addidional column names in
+ Further covariates (e.g. subject age) can either be specified via addidional column names in `adata.obs`, a key in `adata.uns`, or as a separate DataFrame.

Args:
adata: AnnData object.
@@ -90,10 +75,13 @@ class Tasccoda(CompositionalModel2):
covariate_obs: If type is "cell_level", specify list of keys for adata.obs, where covariate values are stored.
covariate_df: If type is "cell_level", specify dataFrame with covariates.
dendrogram_key: Key to the scanpy.tl.dendrogram result in `.uns` of original cell level anndata object.
- levels_orig: List that indicates which columns in `.obs` of the original data correspond to tree levels.
-
+ levels_orig: List that indicates which columns in `.obs` of the original data correspond to tree levels.
+ The list must begin with the root level, and end with the leaf level.
+ levels_agg: List that indicates which columns in `.var` of the aggregated data correspond to tree levels.
+ The list must begin with the root level, and end with the leaf level.
add_level_name: If True, internal nodes in the tree will be named as "{level_name}_{node_name}" instead of just {level_name}.
- key_added: If not specified, the tree is stored in
+ key_added: If not specified, the tree is stored in `.uns['tree']`.
+ If `data` is AnnData, save tree in `data`. If `data` is MuData, save tree in data[modality_2].
modality_key_1: Key to the cell-level AnnData in the MuData object.
modality_key_2: Key to the aggregated sample-level AnnData object in the MuData object.

@@ -120,8 +108,10 @@ class Tasccoda(CompositionalModel2):
covariate_df=covariate_df,
)
mdata = MuData({modality_key_1: adata, modality_key_2: adata_coda})
-
+ elif type == "sample_level":
mdata = MuData({modality_key_1: AnnData(), modality_key_2: adata})
+ else:
+ raise ValueError(f'{type} is not a supported type, expected "cell_level" or "sample_level".')
import_tree(
data=mdata,
modality_1=modality_key_1,
@@ -464,7 +454,7 @@ class Tasccoda(CompositionalModel2):
self,
data: AnnData | MuData,
modality_key: str = "coda",
- rng_key=None,
+ rng_key: int | None = None,
num_prior_samples: int = 500,
use_posterior_predictive: bool = True,
) -> az.InferenceData:
@@ -547,6 +537,8 @@ class Tasccoda(CompositionalModel2):
if rng_key is None:
rng = np.random.default_rng()
rng_key = random.key(rng.integers(0, 10000))
+ else:
+ rng_key = random.key(rng_key)

if use_posterior_predictive:
posterior_predictive = Predictive(self.model, self.mcmc.get_samples())(
@@ -557,6 +549,15 @@ class Tasccoda(CompositionalModel2):
ref_index=ref_index,
sample_adata=sample_adata,
)
+ # Remove problematic posterior predictive arrays with wrong dimensions
+ if posterior_predictive and "counts" in posterior_predictive:
+ counts_shape = posterior_predictive["counts"].shape
+ expected_dims = 2 # ['sample', 'cell_type']
+ if len(counts_shape) != expected_dims:
+ posterior_predictive = {k: v for k, v in posterior_predictive.items() if k != "counts"}
+ logger.warning(
+ f"Removed 'counts' from posterior_predictive due to dimension mismatch: got {len(counts_shape)}D, expected {expected_dims}D"
+ )
else:
posterior_predictive = None

@@ -569,6 +570,15 @@ class Tasccoda(CompositionalModel2):
ref_index=ref_index,
sample_adata=sample_adata,
)
+ # Remove problematic prior arrays with wrong dimensions
+ if prior and "counts" in prior:
+ counts_shape = prior["counts"].shape
+ expected_dims = 2 # ['sample', 'cell_type']
+ if len(counts_shape) != expected_dims:
+ prior = {k: v for k, v in prior.items() if k != "counts"}
+ logger.warning(
+ f"Removed 'counts' from prior due to dimension mismatch: got {len(counts_shape)}D, expected {expected_dims}D"
+ )
else:
prior = None

@@ -592,80 +602,88 @@ class Tasccoda(CompositionalModel2):
*args,
**kwargs,
):
- """
-
-
-
-
-
-
-
-
-
-
-
-
-
+ """
+
+ Examples:
+ >>> import pertpy as pt
+ >>> adata = pt.dt.tasccoda_example()
+ >>> tasccoda = pt.tl.Tasccoda()
+ >>> mdata = tasccoda.load(
+ >>> adata, type="sample_level",
+ >>> levels_agg=["Major_l1", "Major_l2", "Major_l3", "Major_l4", "Cluster"],
+ >>> key_added="lineage", add_level_name=True
+ >>> )
+ >>> mdata = tasccoda.prepare(
+ >>> mdata, formula="Health", reference_cell_type="automatic", tree_key="lineage", pen_args={"phi": 0}
+ >>> )
+ >>> tasccoda.run_nuts(mdata, num_samples=1000, num_warmup=100, rng_key=42).
+ """ # noqa: D205, D212
return super().run_nuts(data, modality_key, num_samples, num_warmup, rng_key, copy, *args, **kwargs)

run_nuts.__doc__ = CompositionalModel2.run_nuts.__doc__ + run_nuts.__doc__

def summary(self, data: AnnData | MuData, extended: bool = False, modality_key: str = "coda", *args, **kwargs):
- """
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+ """
+
+ Examples:
+ >>> import pertpy as pt
+ >>> adata = pt.dt.tasccoda_example()
+ >>> tasccoda = pt.tl.Tasccoda()
+ >>> mdata = tasccoda.load(
+ >>> adata, type="sample_level",
+ >>> levels_agg=["Major_l1", "Major_l2", "Major_l3", "Major_l4", "Cluster"],
+ >>> key_added="lineage", add_level_name=True
+ >>> )
+ >>> mdata = tasccoda.prepare(
+ >>> mdata, formula="Health", reference_cell_type="automatic", tree_key="lineage", pen_args={"phi": 0}
+ >>> )
+ >>> tasccoda.run_nuts(mdata, num_samples=1000, num_warmup=100, rng_key=42)
+ >>> tasccoda.summary(mdata).
+ """ # noqa: D205, D212
return super().summary(data, extended, modality_key, *args, **kwargs)

summary.__doc__ = CompositionalModel2.summary.__doc__ + summary.__doc__

def credible_effects(self, data: AnnData | MuData, modality_key: str = "coda", est_fdr: float = None) -> pd.Series:
- """
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+ """
+
+ Examples:
+ >>> import pertpy as pt
+ >>> adata = pt.dt.tasccoda_example()
+ >>> tasccoda = pt.tl.Tasccoda()
+ >>> mdata = tasccoda.load(
+ >>> adata, type="sample_level",
+ >>> levels_agg=["Major_l1", "Major_l2", "Major_l3", "Major_l4", "Cluster"],
+ >>> key_added="lineage", add_level_name=True
+ >>> )
+ >>> mdata = tasccoda.prepare(
+ >>> mdata, formula="Health", reference_cell_type="automatic", tree_key="lineage", pen_args={"phi": 0}
+ >>> )
+ >>> tasccoda.run_nuts(mdata, num_samples=1000, num_warmup=100, rng_key=42)
+ >>> tasccoda.credible_effects(mdata).
+ """ # noqa: D205, D212
return super().credible_effects(data, modality_key, est_fdr)

credible_effects.__doc__ = CompositionalModel2.credible_effects.__doc__ + credible_effects.__doc__

def set_fdr(self, data: AnnData | MuData, est_fdr: float, modality_key: str = "coda", *args, **kwargs):
- """
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+ """
+
+ Examples:
+ >>> import pertpy as pt
+ >>> adata = pt.dt.tasccoda_example()
+ >>> tasccoda = pt.tl.Tasccoda()
+ >>> mdata = tasccoda.load(
+ >>> adata, type="sample_level",
+ >>> levels_agg=["Major_l1", "Major_l2", "Major_l3", "Major_l4", "Cluster"],
+ >>> key_added="lineage", add_level_name=True
+ >>> )
+ >>> mdata = tasccoda.prepare(
+ >>> mdata, formula="Health", reference_cell_type="automatic", tree_key="lineage", pen_args={"phi": 0}
+ >>> )
+ >>> tasccoda.run_nuts(mdata, num_samples=1000, num_warmup=100, rng_key=42)
+ >>> tasccoda.set_fdr(mdata, est_fdr=0.4).
+ """ # noqa: D205, D212
return super().set_fdr(data, est_fdr, modality_key, *args, **kwargs)

set_fdr.__doc__ = CompositionalModel2.set_fdr.__doc__ + set_fdr.__doc__
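Taken together, the hunks above tighten `Tasccoda.load` (an explicit error for unsupported `type` values), let an integer be passed as `rng_key` (wrapped internally with `random.key`), drop mis-shaped `counts` arrays before building the `az.InferenceData`, and add worked Examples to the docstrings. The documented workflow as a plain sketch, using the `tasccoda_example` dataset and the column names from the added Examples blocks:

import pertpy as pt

adata = pt.dt.tasccoda_example()
tasccoda = pt.tl.Tasccoda()
# Build the MuData with the aggregated tree levels named in the Examples above.
mdata = tasccoda.load(
    adata,
    type="sample_level",
    levels_agg=["Major_l1", "Major_l2", "Major_l3", "Major_l4", "Cluster"],
    key_added="lineage",
    add_level_name=True,
)
mdata = tasccoda.prepare(
    mdata, formula="Health", reference_cell_type="automatic", tree_key="lineage", pen_args={"phi": 0}
)
# A plain integer seed is accepted for rng_key, as in the added docstring Examples.
tasccoda.run_nuts(mdata, num_samples=1000, num_warmup=100, rng_key=42)
tasccoda.summary(mdata)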
pertpy/tools/_mixscape.py
CHANGED
@@ -177,7 +177,7 @@ class Mixscape:
def mixscape(
self,
adata: AnnData,
-
+ pert_key: str,
control: str,
*,
new_class_name: str | None = "mixscape_class",
@@ -201,12 +201,12 @@ class Mixscape:

Args:
adata: The annotated data object.
-
+ pert_key: The column of `.obs` with target gene labels.
control: Control category from the `labels` column.
new_class_name: Name of mixscape classification to be stored in `.obs`.
layer: Key from adata.layers whose value will be used to perform tests on. Default is using `.layers["X_pert"]`.
min_de_genes: Required number of genes that are differentially expressed for method to separate perturbed and non-perturbed cells.
- logfc_threshold: Limit testing to genes which show, on average, at least X-fold difference (log-scale) between the two groups of cells
+ logfc_threshold: Limit testing to genes which show, on average, at least X-fold difference (log-scale) between the two groups of cells.
de_layer: Layer to use for identifying differentially expressed genes. If `None`, adata.X is used.
test_method: Method to use for differential expression testing.
iter_num: Number of normalmixEM iterations to run if convergence does not occur.
@@ -256,7 +256,7 @@ class Mixscape:
adata=adata,
split_masks=split_masks,
categories=categories,
-
+ pert_key=pert_key,
control=control,
layer=de_layer,
pval_cutoff=pval_cutoff,
@@ -278,7 +278,7 @@ class Mixscape:

# initialize return variables
adata.obs[f"{new_class_name}_p_{perturbation_type.lower()}"] = 0
- adata.obs[new_class_name] = adata.obs[
+ adata.obs[new_class_name] = adata.obs[pert_key].astype(str)
adata.obs[f"{new_class_name}_global"] = np.empty(
[
adata.n_obs,
@@ -290,12 +290,12 @@ class Mixscape:
adata.obs[f"{new_class_name}_p_{perturbation_type.lower()}"] = 0.0
for split, split_mask in enumerate(split_masks):
category = categories[split]
- gene_targets = list(set(adata[split_mask].obs[
+ gene_targets = list(set(adata[split_mask].obs[pert_key]).difference([control]))
for gene in gene_targets:
post_prob = 0
- orig_guide_cells = (adata.obs[
+ orig_guide_cells = (adata.obs[pert_key] == gene) & split_mask
orig_guide_cells_index = list(orig_guide_cells.index[orig_guide_cells])
- nt_cells = (adata.obs[
+ nt_cells = (adata.obs[pert_key] == control) & split_mask
all_cells = orig_guide_cells | nt_cells

if len(perturbation_markers[(category, gene)]) == 0:
@@ -307,7 +307,11 @@ class Mixscape:

dat = X[np.asarray(all_cells)][:, de_genes_indices]
if scale:
-
+ with warnings.catch_warnings():
+ warnings.filterwarnings(
+ "ignore", message="zero-centering a sparse array/matrix densifies it."
+ )
+ dat = sc.pp.scale(dat)

converged = False
n_iter = 0
@@ -335,10 +339,10 @@ class Mixscape:
pvec = pd.Series(np.asarray(pvec).flatten(), index=list(all_cells.index[all_cells]))

if n_iter == 0:
- gv = pd.DataFrame(columns=["pvec",
+ gv = pd.DataFrame(columns=["pvec", pert_key])
gv["pvec"] = pvec
- gv[
- gv.loc[guide_cells,
+ gv[pert_key] = control
+ gv.loc[guide_cells, pert_key] = gene
if gene not in gv_list:
gv_list[gene] = {}
gv_list[gene][category] = gv
@@ -389,7 +393,7 @@ class Mixscape:
def lda(
self,
adata: AnnData,
-
+ pert_key: str,
control: str,
*,
mixscape_class_global: str | None = "mixscape_class_global",
@@ -407,7 +411,7 @@ class Mixscape:

Args:
adata: The annotated data object.
-
+ pert_key: The column of `.obs` with target gene labels.
control: Control category from the `pert_key` column.
mixscape_class_global: The column of `.obs` with mixscape global classification result (perturbed, NP or NT).
layer: Layer to use for identifying differentially expressed genes. If `None`, adata.X is used.
@@ -456,7 +460,7 @@ class Mixscape:
adata=adata,
split_masks=split_masks,
categories=categories,
-
+ pert_key=pert_key,
control=control,
layer=layer,
pval_cutoff=pval_cutoff,
@@ -475,17 +479,19 @@ class Mixscape:
continue
else:
gene_subset = adata_subset[
- (adata_subset.obs[
+ (adata_subset.obs[pert_key] == key[1]) | (adata_subset.obs[pert_key] == control)
].copy()
-
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore", UserWarning)
+ sc.pp.scale(gene_subset)
sc.tl.pca(gene_subset, n_comps=n_comps)
# project cells into PCA space of gene_subset
projected_pcs[key[1]] = np.asarray(np.dot(X, gene_subset.varm["PCs"]))
# concatenate all pcs into a single matrix.
projected_pcs_array = np.concatenate(list(projected_pcs.values()), axis=1)

- clf = LinearDiscriminantAnalysis(n_components=len(np.unique(adata_subset.obs[
- clf.fit(projected_pcs_array, adata_subset.obs[
+ clf = LinearDiscriminantAnalysis(n_components=len(np.unique(adata_subset.obs[pert_key])) - 1)
+ clf.fit(projected_pcs_array, adata_subset.obs[pert_key])
cell_embeddings = clf.transform(projected_pcs_array)
adata.uns["mixscape_lda"] = cell_embeddings

@@ -495,9 +501,10 @@ class Mixscape:
def _get_perturbation_markers(
self,
adata: AnnData,
+ *,
split_masks: list[np.ndarray],
categories: list[str],
-
+ pert_key: str,
control: str,
layer: str,
pval_cutoff: float,
@@ -511,7 +518,7 @@ class Mixscape:
adata: :class:`~anndata.AnnData` object
split_masks: List of boolean masks for each split/group.
categories: List of split/group names.
-
+ pert_key: The column of `.obs` with target gene labels.
control: Control category from the `labels` column.
layer: Key from adata.layers whose value will be used to compare gene expression.
pval_cutoff: P-value cut-off for selection of significantly DE genes.
@@ -526,7 +533,7 @@ class Mixscape:
for split, split_mask in enumerate(split_masks):
category = categories[split]
# get gene sets for each split
- gene_targets = list(set(adata[split_mask].obs[
+ gene_targets = list(set(adata[split_mask].obs[pert_key]).difference([control]))
adata_split = adata[split_mask].copy()
# find top DE genes between cells with targeting and non-targeting gRNAs
with warnings.catch_warnings():
@@ -535,7 +542,7 @@ class Mixscape:
sc.tl.rank_genes_groups(
adata_split,
layer=layer,
- groupby=
+ groupby=pert_key,
groups=gene_targets,
reference=control,
method=test_method,
@@ -666,7 +673,7 @@ class Mixscape:
def plot_heatmap( # pragma: no cover # noqa: D417
self,
adata: AnnData,
-
+ pert_key: str,
target_gene: str,
control: str,
*,
@@ -682,7 +689,7 @@ class Mixscape:

Args:
adata: The annotated data object.
-
+ pert_key: The column of `.obs` with target gene labels.
target_gene: Target gene name to visualize heatmap for.
control: Control category from the `pert_key` column.
layer: Key from `adata.layers` whose value will be used to perform tests on.
@@ -711,12 +718,13 @@ class Mixscape:
"""
if "mixscape_class" not in adata.obs:
raise ValueError("Please run `pt.tl.mixscape` first.")
- adata_subset = adata[(adata.obs[
+ adata_subset = adata[(adata.obs[pert_key] == target_gene) | (adata.obs[pert_key] == control)].copy()
with warnings.catch_warnings():
warnings.simplefilter("ignore", RuntimeWarning)
warnings.simplefilter("ignore", PerformanceWarning)
-
-
+ warnings.simplefilter("ignore", UserWarning)
+ sc.tl.rank_genes_groups(adata_subset, layer=layer, groupby=pert_key, method=method)
+ sc.pp.scale(adata_subset, max_value=vmax)
sc.pp.subsample(adata_subset, n_obs=subsample_number)

fig = sc.pl.rank_genes_groups_heatmap(
@@ -739,7 +747,7 @@ class Mixscape:
def plot_perturbscore( # pragma: no cover # noqa: D417
self,
adata: AnnData,
-
+ pert_key: str,
target_gene: str,
*,
mixscape_class: str = "mixscape_class",
@@ -758,7 +766,7 @@ class Mixscape:

Args:
adata: The annotated data object.
-
+ pert_key: The column of `.obs` with target gene labels.
target_gene: Target gene name to visualize perturbation scores for.
mixscape_class: The column of `.obs` with mixscape classifications.
color: Specify color of target gene class or knockout cell class. For control non-targeting and non-perturbed cells, colors are set to different shades of grey.
@@ -797,21 +805,21 @@ class Mixscape:
else:
perturbation_score = pd.concat([perturbation_score, perturbation_score_temp])
perturbation_score["mix"] = adata.obs[mixscape_class][perturbation_score.index]
- gd = list(set(perturbation_score[
+ gd = list(set(perturbation_score[pert_key]).difference({target_gene}))[0]

# If before_mixscape is True, split densities based on original target gene classification
if before_mixscape is True:
palette = {gd: "#7d7d7d", target_gene: color}
- plot_dens = sns.kdeplot(data=perturbation_score, x="pvec", hue=
+ plot_dens = sns.kdeplot(data=perturbation_score, x="pvec", hue=pert_key, fill=False, common_norm=False)
top_r = max(plot_dens.get_lines()[cond].get_data()[1].max() for cond in range(len(plot_dens.get_lines())))
plt.close()
perturbation_score["y_jitter"] = perturbation_score["pvec"]
rng = np.random.default_rng()
- perturbation_score.loc[perturbation_score[
- low=0.001, high=top_r / 10, size=sum(perturbation_score[
+ perturbation_score.loc[perturbation_score[pert_key] == gd, "y_jitter"] = rng.uniform(
+ low=0.001, high=top_r / 10, size=sum(perturbation_score[pert_key] == gd)
)
- perturbation_score.loc[perturbation_score[
- low=-top_r / 10, high=0, size=sum(perturbation_score[
+ perturbation_score.loc[perturbation_score[pert_key] == target_gene, "y_jitter"] = rng.uniform(
+ low=-top_r / 10, high=0, size=sum(perturbation_score[pert_key] == target_gene)
)
# If split_by is provided, split densities based on the split_by
if split_by is not None:
@@ -844,7 +852,7 @@ class Mixscape:
else:
if palette is None:
palette = {gd: "#7d7d7d", f"{target_gene} NP": "#c9c9c9", f"{target_gene} {perturbation_type}": color}
- plot_dens = sns.kdeplot(data=perturbation_score, x="pvec", hue=
+ plot_dens = sns.kdeplot(data=perturbation_score, x="pvec", hue=pert_key, fill=False, common_norm=False)
top_r = max(plot_dens.get_lines()[i].get_data()[1].max() for i in range(len(plot_dens.get_lines())))
plt.close()
perturbation_score["y_jitter"] = perturbation_score["pvec"]
@@ -899,6 +907,7 @@ class Mixscape:
if return_fig:
return plt.gcf()
plt.show()
+
return None

@_doc_params(common_plot_args=doc_common_plot_args)
@@ -1058,7 +1067,7 @@ class Mixscape:
data=obs_tidy,
order=order,
orient="vertical",
-
+ density_norm=scale,
ax=ax,
hue=hue,
**kwargs,
@@ -1072,7 +1081,7 @@ class Mixscape:
data=obs_tidy,
order=order,
jitter=jitter,
-
+ palette="dark:black",
size=size,
ax=ax,
hue=hue,
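Across this file the change is mechanical: the public entry points (`mixscape`, `lda`, `plot_heatmap`, `plot_perturbscore`) and the private `_get_perturbation_markers` helper now take the label column explicitly as `pert_key`, and noisy scanpy scaling warnings are silenced. A hedged call sketch against 1.0.3; the file path, the `.obs` column name "perturbation", the control label "NT", and the target gene "IFNGR1" are placeholders for your own data, and `perturbation_signature` (not shown in this diff) is assumed to keep its existing signature:

import pertpy as pt
import scanpy as sc

adata = sc.read_h5ad("perturb_seq.h5ad")  # placeholder dataset

ms = pt.tl.Mixscape()
# The .obs column holding target-gene labels is passed explicitly as pert_key.
ms.perturbation_signature(adata, pert_key="perturbation", control="NT")
ms.mixscape(adata, pert_key="perturbation", control="NT")
ms.lda(adata, pert_key="perturbation", control="NT")
ms.plot_heatmap(adata, pert_key="perturbation", target_gene="IFNGR1", control="NT")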
pertpy/tools/_perturbation_space/_comparison.py
CHANGED
@@ -22,7 +22,7 @@ class PerturbationComparison:
) -> float:
"""Compare classification accuracy between real and simulated perturbations.

- Trains a classifier on the real perturbation data
+ Trains a classifier on the real perturbation data & the control data and reports a normalized
classification accuracy on the simulated perturbation.

Args:
@@ -64,8 +64,8 @@ class PerturbationComparison:
real: Real perturbed data.
simulated: Simulated perturbed data.
control: Control data
- use_simulated_for_knn: Include simulted perturbed data (`simulated`) into the knn graph.
- control (`control`) is provided.
+ use_simulated_for_knn: Include simulted perturbed data (`simulated`) into the knn graph.
+ Only valid when control (`control`) is provided.
n_neighbors: Number of neighbors to use in k-neighbor graph.
random_state: Random state used for k-neighbor graph construction.
n_jobs: Number of cores to use. Defaults to -1 (all).