pertpy 0.9.4__py3-none-any.whl → 0.9.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pertpy/__init__.py +1 -1
- pertpy/_doc.py +20 -0
- pertpy/data/_datasets.py +1 -1
- pertpy/metadata/_cell_line.py +19 -7
- pertpy/metadata/_compound.py +3 -4
- pertpy/metadata/_metadata.py +1 -1
- pertpy/preprocessing/_guide_rna.py +19 -6
- pertpy/tools/__init__.py +1 -1
- pertpy/tools/_augur.py +36 -46
- pertpy/tools/_cinemaot.py +23 -17
- pertpy/tools/_coda/_base_coda.py +87 -106
- pertpy/tools/_dialogue.py +17 -21
- pertpy/tools/_differential_gene_expression/__init__.py +1 -2
- pertpy/tools/_differential_gene_expression/_base.py +495 -113
- pertpy/tools/_differential_gene_expression/_edger.py +30 -21
- pertpy/tools/_differential_gene_expression/_pydeseq2.py +15 -29
- pertpy/tools/_differential_gene_expression/_statsmodels.py +0 -11
- pertpy/tools/_distances/_distances.py +15 -8
- pertpy/tools/_enrichment.py +18 -8
- pertpy/tools/_milo.py +58 -46
- pertpy/tools/_mixscape.py +111 -100
- pertpy/tools/_perturbation_space/_perturbation_space.py +40 -31
- pertpy/tools/_perturbation_space/_simple.py +50 -0
- pertpy/tools/_scgen/_scgen.py +35 -25
- {pertpy-0.9.4.dist-info → pertpy-0.9.5.dist-info}/METADATA +5 -5
- {pertpy-0.9.4.dist-info → pertpy-0.9.5.dist-info}/RECORD +28 -28
- {pertpy-0.9.4.dist-info → pertpy-0.9.5.dist-info}/WHEEL +1 -1
- pertpy/tools/_differential_gene_expression/_formulaic.py +0 -189
- {pertpy-0.9.4.dist-info → pertpy-0.9.5.dist-info}/licenses/LICENSE +0 -0
pertpy/tools/_mixscape.py
CHANGED
@@ -18,6 +18,7 @@ from scipy.sparse import csr_matrix, issparse, spmatrix
|
|
18
18
|
from sklearn.mixture import GaussianMixture
|
19
19
|
|
20
20
|
import pertpy as pt
|
21
|
+
from pertpy._doc import _doc_params, doc_common_plot_args
|
21
22
|
|
22
23
|
if TYPE_CHECKING:
|
23
24
|
from collections.abc import Sequence
|
@@ -25,6 +26,7 @@ if TYPE_CHECKING:
|
|
25
26
|
from anndata import AnnData
|
26
27
|
from matplotlib.axes import Axes
|
27
28
|
from matplotlib.colors import Colormap
|
29
|
+
from matplotlib.pyplot import Figure
|
28
30
|
from scipy import sparse
|
29
31
|
|
30
32
|
|
@@ -102,7 +104,7 @@ class Mixscape:
|
|
102
104
|
control_mask_split = control_mask & split_mask
|
103
105
|
|
104
106
|
R_split = representation[split_mask]
|
105
|
-
R_control = representation[control_mask_split]
|
107
|
+
R_control = representation[np.asarray(control_mask_split)]
|
106
108
|
|
107
109
|
from pynndescent import NNDescent
|
108
110
|
|
@@ -110,7 +112,7 @@ class Mixscape:
|
|
110
112
|
nn_index = NNDescent(R_control, **kwargs)
|
111
113
|
indices, _ = nn_index.query(R_split, k=n_neighbors, epsilon=eps)
|
112
114
|
|
113
|
-
X_control = np.expm1(adata.X[control_mask_split])
|
115
|
+
X_control = np.expm1(adata.X[np.asarray(control_mask_split)])
|
114
116
|
|
115
117
|
n_split = split_mask.sum()
|
116
118
|
n_control = X_control.shape[0]
|
@@ -254,7 +256,7 @@ class Mixscape:
|
|
254
256
|
else:
|
255
257
|
de_genes = perturbation_markers[(category, gene)]
|
256
258
|
de_genes_indices = self._get_column_indices(adata, list(de_genes))
|
257
|
-
dat = X[all_cells][:, de_genes_indices]
|
259
|
+
dat = X[np.asarray(all_cells)][:, de_genes_indices]
|
258
260
|
converged = False
|
259
261
|
n_iter = 0
|
260
262
|
old_classes = adata.obs[labels][all_cells]
|
@@ -264,8 +266,8 @@ class Mixscape:
|
|
264
266
|
# get average value for each gene over all selected cells
|
265
267
|
# all cells in current split&Gene minus all NT cells in current split
|
266
268
|
# Each row is for each cell, each column is for each gene, get mean for each column
|
267
|
-
vec = np.mean(X[guide_cells][:, de_genes_indices], axis=0) - np.mean(
|
268
|
-
X[nt_cells][:, de_genes_indices], axis=0
|
269
|
+
vec = np.mean(X[np.asarray(guide_cells)][:, de_genes_indices], axis=0) - np.mean(
|
270
|
+
X[np.asarray(nt_cells)][:, de_genes_indices], axis=0
|
269
271
|
)
|
270
272
|
# project cells onto the perturbation vector
|
271
273
|
if isinstance(dat, spmatrix):
|
@@ -506,21 +508,23 @@ class Mixscape:
|
|
506
508
|
|
507
509
|
return [mu, sd]
|
508
510
|
|
511
|
+
@_doc_params(common_plot_args=doc_common_plot_args)
|
509
512
|
def plot_barplot( # pragma: no cover
|
510
513
|
self,
|
511
514
|
adata: AnnData,
|
512
515
|
guide_rna_column: str,
|
516
|
+
*,
|
513
517
|
mixscape_class_global: str = "mixscape_class_global",
|
514
518
|
axis_text_x_size: int = 8,
|
515
519
|
axis_text_y_size: int = 6,
|
516
520
|
axis_title_size: int = 8,
|
517
521
|
legend_title_size: int = 8,
|
518
522
|
legend_text_size: int = 8,
|
519
|
-
|
520
|
-
|
521
|
-
show: bool
|
522
|
-
|
523
|
-
):
|
523
|
+
legend_bbox_to_anchor: tuple[float, float] = None,
|
524
|
+
figsize: tuple[float, float] = (25, 25),
|
525
|
+
show: bool = True,
|
526
|
+
return_fig: bool = False,
|
527
|
+
) -> Figure | None:
|
524
528
|
"""Barplot to visualize perturbation scores calculated by the `mixscape` function.
|
525
529
|
|
526
530
|
Args:
|
@@ -528,12 +532,17 @@ class Mixscape:
|
|
528
532
|
guide_rna_column: The column of `.obs` with guide RNA labels. The target gene labels.
|
529
533
|
The format must be <gene_target>g<#>. Examples are 'STAT2g1' and 'ATF2g1'.
|
530
534
|
mixscape_class_global: The column of `.obs` with mixscape global classification result (perturbed, NP or NT).
|
531
|
-
|
532
|
-
|
533
|
-
|
535
|
+
axis_text_x_size: Size of the x-axis text.
|
536
|
+
axis_text_y_size: Size of the y-axis text.
|
537
|
+
axis_title_size: Size of the axis title.
|
538
|
+
legend_title_size: Size of the legend title.
|
539
|
+
legend_text_size: Size of the legend text.
|
540
|
+
legend_bbox_to_anchor: The bbox that the legend will be anchored.
|
541
|
+
figsize: The size of the figure.
|
542
|
+
{common_plot_args}
|
534
543
|
|
535
544
|
Returns:
|
536
|
-
If `
|
545
|
+
If `return_fig` is `True`, returns the figure, otherwise `None`.
|
537
546
|
|
538
547
|
Examples:
|
539
548
|
>>> import pertpy as pt
|
@@ -565,63 +574,66 @@ class Mixscape:
|
|
565
574
|
all_cells_percentage["guide_number"] = "g" + all_cells_percentage["guide_number"]
|
566
575
|
NP_KO_cells = all_cells_percentage[all_cells_percentage["gene"] != "NT"]
|
567
576
|
|
568
|
-
|
569
|
-
|
570
|
-
|
571
|
-
|
572
|
-
|
573
|
-
|
574
|
-
|
575
|
-
|
576
|
-
|
577
|
-
|
578
|
-
|
579
|
-
|
580
|
-
|
581
|
-
|
582
|
-
|
583
|
-
|
584
|
-
|
585
|
-
|
586
|
-
|
587
|
-
|
588
|
-
|
589
|
-
|
590
|
-
|
591
|
-
|
592
|
-
sns.despine(ax=ax, top=True, right=True, left=False, bottom=False)
|
593
|
-
ax.set_xticklabels(ax.get_xticklabels(), rotation=0, ha="right", fontsize=axis_text_x_size)
|
594
|
-
ax.set_yticklabels(ax.get_yticklabels(), rotation=0, fontsize=axis_text_y_size)
|
595
|
-
fig.subplots_adjust(right=0.8)
|
596
|
-
fig.subplots_adjust(hspace=0.5, wspace=0.5)
|
577
|
+
color_mapping = {"KO": "salmon", "NP": "lightgray", "NT": "grey"}
|
578
|
+
unique_genes = NP_KO_cells["gene"].unique()
|
579
|
+
fig, axs = plt.subplots(int(len(unique_genes) / 5), 5, figsize=figsize, sharey=True)
|
580
|
+
for i, gene in enumerate(unique_genes):
|
581
|
+
ax = axs[int(i / 5), i % 5]
|
582
|
+
grouped_df = (
|
583
|
+
NP_KO_cells[NP_KO_cells["gene"] == gene]
|
584
|
+
.groupby(["guide_number", "mixscape_class_global"], observed=False)["value"]
|
585
|
+
.sum()
|
586
|
+
.unstack()
|
587
|
+
)
|
588
|
+
grouped_df.plot(
|
589
|
+
kind="bar",
|
590
|
+
stacked=True,
|
591
|
+
color=[color_mapping[col] for col in grouped_df.columns],
|
592
|
+
ax=ax,
|
593
|
+
width=0.8,
|
594
|
+
legend=False,
|
595
|
+
)
|
596
|
+
ax.set_title(gene, bbox={"facecolor": "white", "edgecolor": "black", "pad": 1}, fontsize=axis_title_size)
|
597
|
+
ax.set(xlabel="sgRNA", ylabel="% of cells")
|
598
|
+
sns.despine(ax=ax, top=True, right=True, left=False, bottom=False)
|
599
|
+
ax.set_xticks(ax.get_xticks(), ax.get_xticklabels(), rotation=0, ha="right", fontsize=axis_text_x_size)
|
600
|
+
ax.set_yticks(ax.get_yticks(), ax.get_yticklabels(), rotation=0, fontsize=axis_text_y_size)
|
597
601
|
ax.legend(
|
598
|
-
title="
|
602
|
+
title="Mixscape Class",
|
599
603
|
loc="center right",
|
600
|
-
bbox_to_anchor=
|
604
|
+
bbox_to_anchor=legend_bbox_to_anchor,
|
601
605
|
frameon=True,
|
602
606
|
fontsize=legend_text_size,
|
603
607
|
title_fontsize=legend_title_size,
|
604
608
|
)
|
605
609
|
|
610
|
+
fig.subplots_adjust(right=0.8)
|
611
|
+
fig.subplots_adjust(hspace=0.5, wspace=0.5)
|
606
612
|
plt.tight_layout()
|
607
|
-
_utils.savefig_or_show("mixscape_barplot", show=show, save=save)
|
608
613
|
|
614
|
+
if show:
|
615
|
+
plt.show()
|
616
|
+
if return_fig:
|
617
|
+
return fig
|
618
|
+
return None
|
619
|
+
|
620
|
+
@_doc_params(common_plot_args=doc_common_plot_args)
|
609
621
|
def plot_heatmap( # pragma: no cover
|
610
622
|
self,
|
611
623
|
adata: AnnData,
|
612
624
|
labels: str,
|
613
625
|
target_gene: str,
|
614
626
|
control: str,
|
627
|
+
*,
|
615
628
|
layer: str | None = None,
|
616
629
|
method: str | None = "wilcoxon",
|
617
630
|
subsample_number: int | None = 900,
|
618
631
|
vmin: float | None = -2,
|
619
632
|
vmax: float | None = 2,
|
620
|
-
|
621
|
-
|
622
|
-
save: bool | str | None = None,
|
633
|
+
show: bool = True,
|
634
|
+
return_fig: bool = False,
|
623
635
|
**kwds,
|
624
|
-
) ->
|
636
|
+
) -> Figure | None:
|
625
637
|
"""Heatmap plot using mixscape results. Requires `pt.tl.mixscape()` to be run first.
|
626
638
|
|
627
639
|
Args:
|
@@ -634,14 +646,11 @@ class Mixscape:
|
|
634
646
|
subsample_number: Subsample to this number of observations.
|
635
647
|
vmin: The value representing the lower limit of the color scale. Values smaller than vmin are plotted with the same color as vmin.
|
636
648
|
vmax: The value representing the upper limit of the color scale. Values larger than vmax are plotted with the same color as vmax.
|
637
|
-
|
638
|
-
save: If `True` or a `str`, save the figure. A string is appended to the default filename.
|
639
|
-
Infer the filetype if ending on {`'.pdf'`, `'.png'`, `'.svg'`}.
|
640
|
-
ax: A matplotlib axes object. Only works if plotting a single component.
|
649
|
+
{common_plot_args}
|
641
650
|
**kwds: Additional arguments to `scanpy.pl.rank_genes_groups_heatmap`.
|
642
651
|
|
643
652
|
Returns:
|
644
|
-
If `
|
653
|
+
If `return_fig` is `True`, returns the figure, otherwise `None`.
|
645
654
|
|
646
655
|
Examples:
|
647
656
|
>>> import pertpy as pt
|
@@ -663,35 +672,39 @@ class Mixscape:
|
|
663
672
|
sc.pp.scale(adata_subset, max_value=vmax)
|
664
673
|
sc.pp.subsample(adata_subset, n_obs=subsample_number)
|
665
674
|
|
666
|
-
|
675
|
+
fig = sc.pl.rank_genes_groups_heatmap(
|
667
676
|
adata_subset,
|
668
677
|
groupby="mixscape_class",
|
669
678
|
vmin=vmin,
|
670
679
|
vmax=vmax,
|
671
680
|
n_genes=20,
|
672
681
|
groups=["NT"],
|
673
|
-
|
674
|
-
show=show,
|
675
|
-
save=save,
|
682
|
+
show=False,
|
676
683
|
**kwds,
|
677
684
|
)
|
678
685
|
|
686
|
+
if show:
|
687
|
+
plt.show()
|
688
|
+
if return_fig:
|
689
|
+
return fig
|
690
|
+
return None
|
691
|
+
|
692
|
+
@_doc_params(common_plot_args=doc_common_plot_args)
|
679
693
|
def plot_perturbscore( # pragma: no cover
|
680
694
|
self,
|
681
695
|
adata: AnnData,
|
682
696
|
labels: str,
|
683
697
|
target_gene: str,
|
698
|
+
*,
|
684
699
|
mixscape_class: str = "mixscape_class",
|
685
700
|
color: str = "orange",
|
686
701
|
palette: dict[str, str] = None,
|
687
702
|
split_by: str = None,
|
688
703
|
before_mixscape: bool = False,
|
689
704
|
perturbation_type: str = "KO",
|
690
|
-
|
691
|
-
|
692
|
-
|
693
|
-
save: bool | str | None = None,
|
694
|
-
) -> None:
|
705
|
+
show: bool = True,
|
706
|
+
return_fig: bool = False,
|
707
|
+
) -> Figure | None:
|
695
708
|
"""Density plots to visualize perturbation scores calculated by the `pt.tl.mixscape` function.
|
696
709
|
|
697
710
|
Requires `pt.tl.mixscape` to be run first.
|
@@ -710,6 +723,10 @@ class Mixscape:
|
|
710
723
|
before_mixscape: Option to split densities based on mixscape classification (default) or original target gene classification.
|
711
724
|
Default is set to NULL and plots cells by original class ID.
|
712
725
|
perturbation_type: Specify type of CRISPR perturbation expected for labeling mixscape classifications.
|
726
|
+
{common_plot_args}
|
727
|
+
|
728
|
+
Returns:
|
729
|
+
If `return_fig` is `True`, returns the figure, otherwise `None`.
|
713
730
|
|
714
731
|
Examples:
|
715
732
|
Visualizing the perturbation scores for the cells in a dataset:
|
@@ -778,15 +795,6 @@ class Mixscape:
|
|
778
795
|
plt.legend(title="gene_target", title_fontsize=14, fontsize=12)
|
779
796
|
sns.despine()
|
780
797
|
|
781
|
-
if save:
|
782
|
-
plt.savefig(save, bbox_inches="tight")
|
783
|
-
if show:
|
784
|
-
plt.show()
|
785
|
-
if return_fig:
|
786
|
-
return plt.gcf()
|
787
|
-
if not (show or save):
|
788
|
-
return plt.gca()
|
789
|
-
|
790
798
|
# If before_mixscape is False, split densities based on mixscape classifications
|
791
799
|
else:
|
792
800
|
if palette is None:
|
@@ -843,19 +851,18 @@ class Mixscape:
|
|
843
851
|
plt.legend(title="mixscape class", title_fontsize=14, fontsize=12)
|
844
852
|
sns.despine()
|
845
853
|
|
846
|
-
|
847
|
-
|
848
|
-
|
849
|
-
|
850
|
-
|
851
|
-
return plt.gcf()
|
852
|
-
if not (show or save):
|
853
|
-
return plt.gca()
|
854
|
+
if show:
|
855
|
+
plt.show()
|
856
|
+
if return_fig:
|
857
|
+
return plt.gcf()
|
858
|
+
return None
|
854
859
|
|
860
|
+
@_doc_params(common_plot_args=doc_common_plot_args)
|
855
861
|
def plot_violin( # pragma: no cover
|
856
862
|
self,
|
857
863
|
adata: AnnData,
|
858
864
|
target_gene_idents: str | list[str],
|
865
|
+
*,
|
859
866
|
keys: str | Sequence[str] = "mixscape_class_p_ko",
|
860
867
|
groupby: str | None = "mixscape_class",
|
861
868
|
log: bool = False,
|
@@ -872,10 +879,10 @@ class Mixscape:
|
|
872
879
|
ylabel: str | Sequence[str] | None = None,
|
873
880
|
rotation: float | None = None,
|
874
881
|
ax: Axes | None = None,
|
875
|
-
show: bool
|
876
|
-
|
882
|
+
show: bool = True,
|
883
|
+
return_fig: bool = False,
|
877
884
|
**kwargs,
|
878
|
-
):
|
885
|
+
) -> Axes | Figure | None:
|
879
886
|
"""Violin plot using mixscape results.
|
880
887
|
|
881
888
|
Requires `pt.tl.mixscape` to be run first.
|
@@ -892,14 +899,12 @@ class Mixscape:
|
|
892
899
|
xlabel: Label of the x-axis. Defaults to `groupby` if `rotation` is `None`, otherwise, no label is shown.
|
893
900
|
ylabel: Label of the y-axis. If `None` and `groupby` is `None`, defaults to `'value'`.
|
894
901
|
If `None` and `groubpy` is not `None`, defaults to `keys`.
|
895
|
-
show: Show the plot, do not return axis.
|
896
|
-
save: If `True` or a `str`, save the figure. A string is appended to the default filename.
|
897
|
-
Infer the filetype if ending on {`'.pdf'`, `'.png'`, `'.svg'`}.
|
898
902
|
ax: A matplotlib axes object. Only works if plotting a single component.
|
903
|
+
{common_plot_args}
|
899
904
|
**kwargs: Additional arguments to `seaborn.violinplot`.
|
900
905
|
|
901
906
|
Returns:
|
902
|
-
|
907
|
+
If `return_fig` is `True`, returns the figure (as Axes list if it's a multi-panel plot), otherwise `None`.
|
903
908
|
|
904
909
|
Examples:
|
905
910
|
>>> import pertpy as pt
|
@@ -1045,20 +1050,24 @@ class Mixscape:
|
|
1045
1050
|
show = settings.autoshow if show is None else show
|
1046
1051
|
if hue is not None and stripplot is True:
|
1047
1052
|
plt.legend(handles, labels)
|
1048
|
-
_utils.savefig_or_show("mixscape_violin", show=show, save=save)
|
1049
1053
|
|
1050
|
-
if
|
1054
|
+
if show:
|
1055
|
+
plt.show()
|
1056
|
+
if return_fig:
|
1051
1057
|
if multi_panel and groupby is None and len(ys) == 1:
|
1052
1058
|
return g
|
1053
1059
|
elif len(axs) == 1:
|
1054
1060
|
return axs[0]
|
1055
1061
|
else:
|
1056
1062
|
return axs
|
1063
|
+
return None
|
1057
1064
|
|
1065
|
+
@_doc_params(common_plot_args=doc_common_plot_args)
|
1058
1066
|
def plot_lda( # pragma: no cover
|
1059
1067
|
self,
|
1060
1068
|
adata: AnnData,
|
1061
1069
|
control: str,
|
1070
|
+
*,
|
1062
1071
|
mixscape_class: str = "mixscape_class",
|
1063
1072
|
mixscape_class_global: str = "mixscape_class_global",
|
1064
1073
|
perturbation_type: str | None = "KO",
|
@@ -1066,12 +1075,11 @@ class Mixscape:
|
|
1066
1075
|
n_components: int | None = None,
|
1067
1076
|
color_map: Colormap | str | None = None,
|
1068
1077
|
palette: str | Sequence[str] | None = None,
|
1069
|
-
return_fig: bool | None = None,
|
1070
1078
|
ax: Axes | None = None,
|
1071
|
-
show: bool
|
1072
|
-
|
1079
|
+
show: bool = True,
|
1080
|
+
return_fig: bool = False,
|
1073
1081
|
**kwds,
|
1074
|
-
) -> None:
|
1082
|
+
) -> Figure | None:
|
1075
1083
|
"""Visualizing perturbation responses with Linear Discriminant Analysis. Requires `pt.tl.mixscape()` to be run first.
|
1076
1084
|
|
1077
1085
|
Args:
|
@@ -1082,9 +1090,7 @@ class Mixscape:
|
|
1082
1090
|
perturbation_type: Specify type of CRISPR perturbation expected for labeling mixscape classifications.
|
1083
1091
|
lda_key: If not specified, lda looks .uns["mixscape_lda"] for the LDA results.
|
1084
1092
|
n_components: The number of dimensions of the embedding.
|
1085
|
-
|
1086
|
-
save: If `True` or a `str`, save the figure. A string is appended to the default filename.
|
1087
|
-
Infer the filetype if ending on {`'.pdf'`, `'.png'`, `'.svg'`}.
|
1093
|
+
{common_plot_args}
|
1088
1094
|
**kwds: Additional arguments to `scanpy.pl.umap`.
|
1089
1095
|
|
1090
1096
|
Examples:
|
@@ -1112,14 +1118,19 @@ class Mixscape:
|
|
1112
1118
|
n_components = adata_subset.uns[lda_key].shape[1]
|
1113
1119
|
sc.pp.neighbors(adata_subset, use_rep=lda_key)
|
1114
1120
|
sc.tl.umap(adata_subset, n_components=n_components)
|
1115
|
-
sc.pl.umap(
|
1121
|
+
fig = sc.pl.umap(
|
1116
1122
|
adata_subset,
|
1117
1123
|
color=mixscape_class,
|
1118
1124
|
palette=palette,
|
1119
1125
|
color_map=color_map,
|
1120
1126
|
return_fig=return_fig,
|
1121
|
-
show=
|
1122
|
-
save=save,
|
1127
|
+
show=False,
|
1123
1128
|
ax=ax,
|
1124
1129
|
**kwds,
|
1125
1130
|
)
|
1131
|
+
|
1132
|
+
if show:
|
1133
|
+
plt.show()
|
1134
|
+
if return_fig:
|
1135
|
+
return fig
|
1136
|
+
return None
|
@@ -7,6 +7,7 @@ import pandas as pd
|
|
7
7
|
from anndata import AnnData
|
8
8
|
from lamin_utils import logger
|
9
9
|
from rich import print
|
10
|
+
from scipy.stats import entropy
|
10
11
|
|
11
12
|
if TYPE_CHECKING:
|
12
13
|
from collections.abc import Iterable
|
@@ -41,7 +42,7 @@ class PerturbationSpace:
|
|
41
42
|
Args:
|
42
43
|
adata: Anndata object of size cells x genes.
|
43
44
|
target_col: .obs column name that stores the label of the perturbation applied to each cell.
|
44
|
-
group_col: .obs column name that stores the label of the group of
|
45
|
+
group_col: .obs column name that stores the label of the group of each cell. If None, ignore groups.
|
45
46
|
reference_key: The key of the control values.
|
46
47
|
layer_key: Key of the AnnData layer to use for computation.
|
47
48
|
new_layer_key: the results are stored in the given layer.
|
@@ -364,50 +365,58 @@ class PerturbationSpace:
|
|
364
365
|
self,
|
365
366
|
adata: AnnData,
|
366
367
|
column: str = "perturbation",
|
368
|
+
column_uncertainty_score_key: str = "perturbation_transfer_uncertainty",
|
367
369
|
target_val: str = "unknown",
|
368
|
-
|
369
|
-
use_rep: str = "X_umap",
|
370
|
+
neighbors_key: str = "neighbors",
|
370
371
|
) -> None:
|
371
372
|
"""Impute missing values in the specified column using KNN imputation in the space defined by `use_rep`.
|
372
373
|
|
374
|
+
Uncertainty is calculated as the entropy of the label distribution in the neighborhood of the target cell.
|
375
|
+
In other words, a cell where all neighbors have the same set of labels will have an uncertainty of 0, whereas a cell
|
376
|
+
where all neighbors have many different labels will have high uncertainty.
|
377
|
+
|
373
378
|
Args:
|
374
379
|
adata: The AnnData object containing single-cell data.
|
375
|
-
column: The column name in
|
380
|
+
column: The column name in adata.obs to perform imputation on.
|
381
|
+
column_uncertainty_score_key: The column name in adata.obs to store the uncertainty score of the label transfer.
|
376
382
|
target_val: The target value to impute.
|
377
|
-
|
378
|
-
use_rep: The key in `adata.obsm` where the embedding (UMAP, PCA, etc.) is stored.
|
383
|
+
neighbors_key: The key in adata.uns where the neighbors are stored.
|
379
384
|
|
380
385
|
Examples:
|
381
386
|
>>> import pertpy as pt
|
382
387
|
>>> import scanpy as sc
|
383
388
|
>>> import numpy as np
|
384
389
|
>>> adata = sc.datasets.pbmc68k_reduced()
|
385
|
-
>>>
|
386
|
-
>>> adata.obs["perturbation"] =
|
387
|
-
|
388
|
-
|
390
|
+
>>> # randomly dropout 10% of the data annotations
|
391
|
+
>>> adata.obs["perturbation"] = adata.obs["louvain"].astype(str).copy()
|
392
|
+
>>> random_cells = np.random.choice(adata.obs.index, int(adata.obs.shape[0] * 0.1), replace=False)
|
393
|
+
>>> adata.obs.loc[random_cells, "perturbation"] = "unknown"
|
389
394
|
>>> sc.pp.neighbors(adata)
|
390
395
|
>>> sc.tl.umap(adata)
|
391
396
|
>>> ps = pt.tl.PseudobulkSpace()
|
392
|
-
>>> ps.label_transfer(adata
|
397
|
+
>>> ps.label_transfer(adata)
|
393
398
|
"""
|
394
|
-
if
|
395
|
-
raise ValueError(f"
|
396
|
-
|
397
|
-
|
398
|
-
|
399
|
-
|
400
|
-
|
401
|
-
|
402
|
-
|
403
|
-
|
404
|
-
|
405
|
-
|
406
|
-
|
407
|
-
|
408
|
-
|
409
|
-
|
410
|
-
|
411
|
-
|
412
|
-
|
413
|
-
|
399
|
+
if neighbors_key not in adata.uns:
|
400
|
+
raise ValueError(f"Key {neighbors_key} not found in adata.uns. Please run `sc.pp.neighbors` first.")
|
401
|
+
|
402
|
+
labels = adata.obs[column].astype(str)
|
403
|
+
target_cells = labels == target_val
|
404
|
+
|
405
|
+
connectivities = adata.obsp[adata.uns[neighbors_key]["connectivities_key"]]
|
406
|
+
# convert labels to an incidence matrix
|
407
|
+
one_hot_encoded_labels = adata.obs[column].astype(str).str.get_dummies()
|
408
|
+
# convert to distance-weighted neighborhood incidence matrix
|
409
|
+
weighted_label_occurence = pd.DataFrame(
|
410
|
+
(one_hot_encoded_labels.values.T * connectivities).T,
|
411
|
+
index=adata.obs_names,
|
412
|
+
columns=one_hot_encoded_labels.columns,
|
413
|
+
)
|
414
|
+
# choose best label for each target cell
|
415
|
+
best_labels = weighted_label_occurence.drop(target_val, axis=1)[target_cells].idxmax(axis=1)
|
416
|
+
adata.obs[column] = labels
|
417
|
+
adata.obs.loc[target_cells, column] = best_labels
|
418
|
+
|
419
|
+
# calculate uncertainty
|
420
|
+
uncertainty = np.zeros(adata.n_obs)
|
421
|
+
uncertainty[target_cells] = entropy(weighted_label_occurence.drop(target_val, axis=1)[target_cells], axis=1)
|
422
|
+
adata.obs[column_uncertainty_score_key] = uncertainty
|
@@ -1,13 +1,20 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
|
+
from typing import TYPE_CHECKING
|
4
|
+
|
3
5
|
import decoupler as dc
|
6
|
+
import matplotlib.pyplot as plt
|
4
7
|
import numpy as np
|
5
8
|
from anndata import AnnData
|
6
9
|
from sklearn.cluster import DBSCAN, KMeans
|
7
10
|
|
11
|
+
from pertpy._doc import _doc_params, doc_common_plot_args
|
8
12
|
from pertpy.tools._perturbation_space._clustering import ClusteringSpace
|
9
13
|
from pertpy.tools._perturbation_space._perturbation_space import PerturbationSpace
|
10
14
|
|
15
|
+
if TYPE_CHECKING:
|
16
|
+
from matplotlib.pyplot import Figure
|
17
|
+
|
11
18
|
|
12
19
|
class CentroidSpace(PerturbationSpace):
|
13
20
|
"""Computes the centroids per perturbation of a pre-computed embedding."""
|
@@ -168,6 +175,49 @@ class PseudobulkSpace(PerturbationSpace):
|
|
168
175
|
|
169
176
|
return ps_adata
|
170
177
|
|
178
|
+
@_doc_params(common_plot_args=doc_common_plot_args)
|
179
|
+
def plot_psbulk_samples(
|
180
|
+
self,
|
181
|
+
adata: AnnData,
|
182
|
+
groupby: str,
|
183
|
+
*,
|
184
|
+
show: bool = True,
|
185
|
+
return_fig: bool = False,
|
186
|
+
**kwargs,
|
187
|
+
) -> Figure | None:
|
188
|
+
"""Plot the pseudobulk samples of an AnnData object.
|
189
|
+
|
190
|
+
Plot the count number vs. the number of cells per pseudobulk sample.
|
191
|
+
|
192
|
+
Args:
|
193
|
+
adata: Anndata containing pseudobulk samples.
|
194
|
+
groupby: `.obs` column to color the samples by.
|
195
|
+
{common_plot_args}
|
196
|
+
**kwargs: Are passed to decoupler's plot_psbulk_samples.
|
197
|
+
|
198
|
+
Returns:
|
199
|
+
If `return_fig` is `True`, returns the figure, otherwise `None`.
|
200
|
+
|
201
|
+
Examples:
|
202
|
+
>>> import pertpy as pt
|
203
|
+
>>> adata = pt.dt.zhang_2021()
|
204
|
+
>>> ps = pt.tl.PseudobulkSpace()
|
205
|
+
>>> pdata = ps.compute(
|
206
|
+
... adata, target_col="Patient", groups_col="Cluster", mode="sum", min_cells=10, min_counts=1000
|
207
|
+
... )
|
208
|
+
>>> ps.plot_psbulk_samples(pdata, groupby=["Patient", "Major celltype"], figsize=(12, 4))
|
209
|
+
|
210
|
+
Preview:
|
211
|
+
.. image:: /_static/docstring_previews/pseudobulk_samples.png
|
212
|
+
"""
|
213
|
+
fig = dc.plot_psbulk_samples(adata, groupby, return_fig=True, **kwargs)
|
214
|
+
|
215
|
+
if show:
|
216
|
+
plt.show()
|
217
|
+
if return_fig:
|
218
|
+
return fig
|
219
|
+
return None
|
220
|
+
|
171
221
|
|
172
222
|
class KMeansSpace(ClusteringSpace):
|
173
223
|
"""Computes K-Means clustering of the expression values."""
|