pertpy 0.9.5__py3-none-any.whl → 0.10.0__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- pertpy/__init__.py +1 -1
- pertpy/_doc.py +1 -2
- pertpy/metadata/_cell_line.py +3 -5
- pertpy/preprocessing/_guide_rna.py +98 -10
- pertpy/preprocessing/_guide_rna_mixture.py +179 -0
- pertpy/tools/_augur.py +32 -44
- pertpy/tools/_cinemaot.py +1 -3
- pertpy/tools/_coda/_base_coda.py +21 -29
- pertpy/tools/_dialogue.py +17 -21
- pertpy/tools/_differential_gene_expression/_base.py +4 -12
- pertpy/tools/_distances/_distances.py +56 -48
- pertpy/tools/_enrichment.py +1 -3
- pertpy/tools/_milo.py +4 -12
- pertpy/tools/_mixscape.py +215 -127
- pertpy/tools/_perturbation_space/_simple.py +1 -3
- pertpy/tools/_scgen/_scgen.py +1 -3
- {pertpy-0.9.5.dist-info → pertpy-0.10.0.dist-info}/METADATA +2 -2
- {pertpy-0.9.5.dist-info → pertpy-0.10.0.dist-info}/RECORD +20 -19
- {pertpy-0.9.5.dist-info → pertpy-0.10.0.dist-info}/WHEEL +0 -0
- {pertpy-0.9.5.dist-info → pertpy-0.10.0.dist-info}/licenses/LICENSE +0 -0
pertpy/tools/_coda/_base_coda.py
CHANGED
@@ -850,7 +850,7 @@ class CompositionalModel2(ABC):
|
|
850
850
|
table = Table(title="Compositional Analysis summary", box=box.SQUARE, expand=True, highlight=True)
|
851
851
|
table.add_column("Name", justify="left", style="cyan")
|
852
852
|
table.add_column("Value", justify="left")
|
853
|
-
table.add_row("Data", "Data:
|
853
|
+
table.add_row("Data", f"Data: {data_dims[0]} samples, {data_dims[1]} cell types")
|
854
854
|
table.add_row("Reference cell type", "{}".format(str(sample_adata.uns["scCODA_params"]["reference_cell_type"])))
|
855
855
|
table.add_row("Formula", "{}".format(sample_adata.uns["scCODA_params"]["formula"]))
|
856
856
|
if extended:
|
@@ -1199,7 +1199,6 @@ class CompositionalModel2(ABC):
|
|
1199
1199
|
level_order: list[str] = None,
|
1200
1200
|
figsize: tuple[float, float] | None = None,
|
1201
1201
|
dpi: int | None = 100,
|
1202
|
-
show: bool = True,
|
1203
1202
|
return_fig: bool = False,
|
1204
1203
|
) -> Figure | None:
|
1205
1204
|
"""Plots a stacked barplot for all levels of a covariate or all samples (if feature_name=="samples").
|
@@ -1278,10 +1277,9 @@ class CompositionalModel2(ABC):
|
|
1278
1277
|
show_legend=show_legend,
|
1279
1278
|
)
|
1280
1279
|
|
1281
|
-
if show:
|
1282
|
-
plt.show()
|
1283
1280
|
if return_fig:
|
1284
1281
|
return plt.gcf()
|
1282
|
+
plt.show()
|
1285
1283
|
return None
|
1286
1284
|
|
1287
1285
|
@_doc_params(common_plot_args=doc_common_plot_args)
|
@@ -1300,7 +1298,6 @@ class CompositionalModel2(ABC):
|
|
1300
1298
|
args_barplot: dict | None = None,
|
1301
1299
|
figsize: tuple[float, float] | None = None,
|
1302
1300
|
dpi: int | None = 100,
|
1303
|
-
show: bool = True,
|
1304
1301
|
return_fig: bool = False,
|
1305
1302
|
) -> Figure | None:
|
1306
1303
|
"""Barplot visualization for effects.
|
@@ -1465,10 +1462,11 @@ class CompositionalModel2(ABC):
|
|
1465
1462
|
cell_types = pd.unique(plot_df["Cell Type"])
|
1466
1463
|
ax.set_xticklabels(cell_types, rotation=90)
|
1467
1464
|
|
1468
|
-
if
|
1469
|
-
|
1470
|
-
if return_fig:
|
1465
|
+
if return_fig and plot_facets:
|
1466
|
+
return g
|
1467
|
+
if return_fig and not plot_facets:
|
1471
1468
|
return plt.gcf()
|
1469
|
+
plt.show()
|
1472
1470
|
return None
|
1473
1471
|
|
1474
1472
|
@_doc_params(common_plot_args=doc_common_plot_args)
|
@@ -1489,7 +1487,6 @@ class CompositionalModel2(ABC):
|
|
1489
1487
|
level_order: list[str] = None,
|
1490
1488
|
figsize: tuple[float, float] | None = None,
|
1491
1489
|
dpi: int | None = 100,
|
1492
|
-
show: bool = True,
|
1493
1490
|
return_fig: bool = False,
|
1494
1491
|
) -> Figure | None:
|
1495
1492
|
"""Grouped boxplot visualization.
|
@@ -1697,10 +1694,11 @@ class CompositionalModel2(ABC):
|
|
1697
1694
|
title=feature_name,
|
1698
1695
|
)
|
1699
1696
|
|
1700
|
-
if
|
1701
|
-
|
1702
|
-
if return_fig:
|
1697
|
+
if return_fig and plot_facets:
|
1698
|
+
return g
|
1699
|
+
if return_fig and not plot_facets:
|
1703
1700
|
return plt.gcf()
|
1701
|
+
plt.show()
|
1704
1702
|
return None
|
1705
1703
|
|
1706
1704
|
@_doc_params(common_plot_args=doc_common_plot_args)
|
@@ -1716,7 +1714,6 @@ class CompositionalModel2(ABC):
|
|
1716
1714
|
figsize: tuple[float, float] | None = None,
|
1717
1715
|
dpi: int | None = 100,
|
1718
1716
|
ax: plt.Axes | None = None,
|
1719
|
-
show: bool = True,
|
1720
1717
|
return_fig: bool = False,
|
1721
1718
|
) -> Figure | None:
|
1722
1719
|
"""Plots total variance of relative abundance versus minimum relative abundance of all cell types for determination of a reference cell type.
|
@@ -1820,10 +1817,9 @@ class CompositionalModel2(ABC):
|
|
1820
1817
|
|
1821
1818
|
ax.legend(loc="upper left", bbox_to_anchor=(1, 1), ncol=1, title="Is abundant")
|
1822
1819
|
|
1823
|
-
if show:
|
1824
|
-
plt.show()
|
1825
1820
|
if return_fig:
|
1826
1821
|
return plt.gcf()
|
1822
|
+
plt.show()
|
1827
1823
|
return None
|
1828
1824
|
|
1829
1825
|
@_doc_params(common_plot_args=doc_common_plot_args)
|
@@ -1839,7 +1835,6 @@ class CompositionalModel2(ABC):
|
|
1839
1835
|
figsize: tuple[float, float] | None = (None, None),
|
1840
1836
|
dpi: int | None = 100,
|
1841
1837
|
save: str | bool = False,
|
1842
|
-
show: bool = True,
|
1843
1838
|
return_fig: bool = False,
|
1844
1839
|
) -> Tree | None:
|
1845
1840
|
"""Plot a tree using input ete3 tree object.
|
@@ -1903,10 +1898,9 @@ class CompositionalModel2(ABC):
|
|
1903
1898
|
|
1904
1899
|
if save is not None:
|
1905
1900
|
tree.render(save, tree_style=tree_style, units=units, w=figsize[0], h=figsize[1], dpi=dpi) # type: ignore
|
1906
|
-
if show:
|
1907
|
-
return tree.render("%%inline", tree_style=tree_style, units=units, w=figsize[0], h=figsize[1], dpi=dpi) # type: ignore
|
1908
1901
|
if return_fig:
|
1909
1902
|
return tree, tree_style
|
1903
|
+
return tree.render("%%inline", tree_style=tree_style, units=units, w=figsize[0], h=figsize[1], dpi=dpi) # type: ignore
|
1910
1904
|
return None
|
1911
1905
|
|
1912
1906
|
@_doc_params(common_plot_args=doc_common_plot_args)
|
@@ -1925,7 +1919,6 @@ class CompositionalModel2(ABC):
|
|
1925
1919
|
figsize: tuple[float, float] | None = (None, None),
|
1926
1920
|
dpi: int | None = 100,
|
1927
1921
|
save: str | bool = False,
|
1928
|
-
show: bool = True,
|
1929
1922
|
return_fig: bool = False,
|
1930
1923
|
) -> Tree | None:
|
1931
1924
|
"""Plot a tree with colored circles on the nodes indicating significant effects with bar plots which indicate leave-level significant effects.
|
@@ -2092,15 +2085,16 @@ class CompositionalModel2(ABC):
|
|
2092
2085
|
|
2093
2086
|
if save:
|
2094
2087
|
plt.savefig(save)
|
2088
|
+
if return_fig:
|
2089
|
+
return plt.gcf()
|
2095
2090
|
|
2096
|
-
|
2097
|
-
|
2098
|
-
|
2099
|
-
if
|
2100
|
-
return tree2.render("%%inline", tree_style=tree_style, units=units, w=figsize[0], h=figsize[1], dpi=dpi)
|
2101
|
-
if return_fig:
|
2102
|
-
if not show_leaf_effects:
|
2091
|
+
else:
|
2092
|
+
if save:
|
2093
|
+
tree2.render(save, tree_style=tree_style, units=units)
|
2094
|
+
if return_fig:
|
2103
2095
|
return tree2, tree_style
|
2096
|
+
return tree2.render("%%inline", tree_style=tree_style, units=units, w=figsize[0], h=figsize[1], dpi=dpi)
|
2097
|
+
|
2104
2098
|
return None
|
2105
2099
|
|
2106
2100
|
@_doc_params(common_plot_args=doc_common_plot_args)
|
@@ -2115,7 +2109,6 @@ class CompositionalModel2(ABC):
|
|
2115
2109
|
color_map: Colormap | str | None = None,
|
2116
2110
|
palette: str | Sequence[str] | None = None,
|
2117
2111
|
ax: Axes = None,
|
2118
|
-
show: bool = True,
|
2119
2112
|
return_fig: bool = False,
|
2120
2113
|
**kwargs,
|
2121
2114
|
) -> Figure | None:
|
@@ -2209,10 +2202,9 @@ class CompositionalModel2(ABC):
|
|
2209
2202
|
**kwargs,
|
2210
2203
|
)
|
2211
2204
|
|
2212
|
-
if show:
|
2213
|
-
plt.show()
|
2214
2205
|
if return_fig:
|
2215
2206
|
return fig
|
2207
|
+
plt.show()
|
2216
2208
|
return None
|
2217
2209
|
|
2218
2210
|
|
pertpy/tools/_dialogue.py
CHANGED
@@ -82,27 +82,27 @@ class Dialogue:
|
|
82
82
|
|
83
83
|
return pseudobulk
|
84
84
|
|
85
|
-
def
|
86
|
-
|
85
|
+
def _pseudobulk_feature_space(
|
86
|
+
self, adata: AnnData, groupby: str, n_components: int = 50, feature_space_key: str = "X_pca"
|
87
|
+
) -> pd.DataFrame:
|
88
|
+
"""Return Cell-averaged components from a passed feature space.
|
87
89
|
|
88
90
|
TODO: consider merging with `get_pseudobulks`
|
89
91
|
TODO: DIALOGUE recommends running PCA on each cell type separately before running PMD - this should be implemented as an option here.
|
90
92
|
|
91
93
|
Args:
|
92
|
-
groupby: The key to groupby for pseudobulks
|
93
|
-
n_components: The number of
|
94
|
+
groupby: The key to groupby for pseudobulks.
|
95
|
+
n_components: The number of components to use.
|
96
|
+
feature_space_key: The key in adata.obsm for the feature space (e.g., "X_pca", "X_umap").
|
94
97
|
|
95
98
|
Returns:
|
96
|
-
A pseudobulk of
|
99
|
+
A pseudobulk DataFrame of the averaged components.
|
97
100
|
"""
|
98
101
|
aggr = {}
|
99
|
-
|
100
102
|
for category in adata.obs.loc[:, groupby].cat.categories:
|
101
103
|
temp = adata.obs.loc[:, groupby] == category
|
102
|
-
aggr[category] = adata[temp].obsm[
|
103
|
-
|
104
|
+
aggr[category] = adata[temp].obsm[feature_space_key][:, :n_components].mean(axis=0)
|
104
105
|
aggr = pd.DataFrame(aggr)
|
105
|
-
|
106
106
|
return aggr
|
107
107
|
|
108
108
|
def _scale_data(self, pseudobulks: pd.DataFrame, normalize: bool = True) -> np.ndarray:
|
@@ -558,7 +558,7 @@ class Dialogue:
|
|
558
558
|
self,
|
559
559
|
adata: AnnData,
|
560
560
|
ct_order: list[str],
|
561
|
-
|
561
|
+
agg_feature: bool = True,
|
562
562
|
normalize: bool = True,
|
563
563
|
) -> tuple[list, dict]:
|
564
564
|
"""Separates cell into AnnDatas by celltype_key and creates the multifactor PMD input.
|
@@ -568,14 +568,14 @@ class Dialogue:
|
|
568
568
|
Args:
|
569
569
|
adata: AnnData object to generate celltype objects for
|
570
570
|
ct_order: The order of cell types
|
571
|
-
|
571
|
+
agg_feature: Whether to aggregate pseudobulks with some embeddings or not.
|
572
572
|
normalize: Whether to mimic DIALOGUE behavior or not.
|
573
573
|
|
574
574
|
Returns:
|
575
575
|
A celltype_label:array dictionary.
|
576
576
|
"""
|
577
577
|
ct_subs = {ct: adata[adata.obs[self.celltype_key] == ct].copy() for ct in ct_order}
|
578
|
-
fn = self.
|
578
|
+
fn = self._pseudobulk_feature_space if agg_feature else self._get_pseudobulks
|
579
579
|
ct_aggr = {ct: fn(ad, self.sample_id) for ct, ad in ct_subs.items()} # type: ignore
|
580
580
|
|
581
581
|
# TODO: implement check (as in https://github.com/livnatje/DIALOGUE/blob/55da9be0a9bf2fcd360d9e11f63e30d041ec4318/R/DIALOGUE.main.R#L114-L119)
|
@@ -593,7 +593,7 @@ class Dialogue:
|
|
593
593
|
adata: AnnData,
|
594
594
|
penalties: list[int] = None,
|
595
595
|
ct_order: list[str] = None,
|
596
|
-
|
596
|
+
agg_feature: bool = True,
|
597
597
|
solver: Literal["lp", "bs"] = "bs",
|
598
598
|
normalize: bool = True,
|
599
599
|
) -> tuple[AnnData, dict[str, np.ndarray], dict[Any, Any], dict[Any, Any]]:
|
@@ -606,7 +606,7 @@ class Dialogue:
|
|
606
606
|
sample_id: Key to use for pseudobulk determination.
|
607
607
|
penalties: PMD penalties.
|
608
608
|
ct_order: The order of cell types.
|
609
|
-
|
609
|
+
agg_feature: Whether to calculate cell-averaged principal components.
|
610
610
|
solver: Which solver to use for PMD. Must be one of "lp" (linear programming) or "bs" (binary search).
|
611
611
|
For differences between these two please refer to https://github.com/theislab/sparsecca/blob/main/examples/linear_programming_multicca.ipynb
|
612
612
|
normalize: Whether to mimic DIALOGUE as close as possible
|
@@ -631,7 +631,7 @@ class Dialogue:
|
|
631
631
|
else:
|
632
632
|
ct_order = cell_types = adata.obs[self.celltype_key].astype("category").cat.categories
|
633
633
|
|
634
|
-
mcca_in, ct_subs = self._load(adata, ct_order=cell_types,
|
634
|
+
mcca_in, ct_subs = self._load(adata, ct_order=cell_types, agg_feature=agg_feature, normalize=normalize)
|
635
635
|
|
636
636
|
n_samples = mcca_in[0].shape[1]
|
637
637
|
if penalties is None:
|
@@ -1070,7 +1070,6 @@ class Dialogue:
|
|
1070
1070
|
*,
|
1071
1071
|
split_which: tuple[str, str] = None,
|
1072
1072
|
mcp: str = "mcp_0",
|
1073
|
-
show: bool = True,
|
1074
1073
|
return_fig: bool = False,
|
1075
1074
|
) -> Figure | None:
|
1076
1075
|
"""Plots split violin plots for a given MCP and split variable.
|
@@ -1110,10 +1109,9 @@ class Dialogue:
|
|
1110
1109
|
ax = sns.violinplot(data=df, x=celltype_key, y=mcp, hue=split_key, split=True)
|
1111
1110
|
ax.set_xticklabels(ax.get_xticklabels(), rotation=90)
|
1112
1111
|
|
1113
|
-
if show:
|
1114
|
-
plt.show()
|
1115
1112
|
if return_fig:
|
1116
1113
|
return plt.gcf()
|
1114
|
+
plt.show()
|
1117
1115
|
return None
|
1118
1116
|
|
1119
1117
|
@_doc_params(common_plot_args=doc_common_plot_args)
|
@@ -1125,7 +1123,6 @@ class Dialogue:
|
|
1125
1123
|
sample_id: str,
|
1126
1124
|
*,
|
1127
1125
|
mcp: str = "mcp_0",
|
1128
|
-
show: bool = True,
|
1129
1126
|
return_fig: bool = False,
|
1130
1127
|
) -> Figure | None:
|
1131
1128
|
"""Generate a pairplot visualization for multi-cell perturbation (MCP) data.
|
@@ -1167,8 +1164,7 @@ class Dialogue:
|
|
1167
1164
|
mcp_pivot = pd.concat([mcp_pivot, aggstats[color]], axis=1)
|
1168
1165
|
sns.pairplot(mcp_pivot, hue=color, corner=True)
|
1169
1166
|
|
1170
|
-
if show:
|
1171
|
-
plt.show()
|
1172
1167
|
if return_fig:
|
1173
1168
|
return plt.gcf()
|
1169
|
+
plt.show()
|
1174
1170
|
return None
|
@@ -125,7 +125,6 @@ class MethodBase(ABC):
|
|
125
125
|
shape_order: list[str] | None = None,
|
126
126
|
x_label: str | None = None,
|
127
127
|
y_label: str | None = None,
|
128
|
-
show: bool = True,
|
129
128
|
return_fig: bool = False,
|
130
129
|
**kwargs: int,
|
131
130
|
) -> Figure | None:
|
@@ -484,10 +483,9 @@ class MethodBase(ABC):
|
|
484
483
|
|
485
484
|
plt.legend(loc=1, bbox_to_anchor=legend_pos, frameon=False)
|
486
485
|
|
487
|
-
if show:
|
488
|
-
plt.show()
|
489
486
|
if return_fig:
|
490
487
|
return plt.gcf()
|
488
|
+
plt.show()
|
491
489
|
return None
|
492
490
|
|
493
491
|
@_doc_params(common_plot_args=doc_common_plot_args)
|
@@ -511,7 +509,6 @@ class MethodBase(ABC):
|
|
511
509
|
pvalue_template=lambda x: f"p={x:.2e}",
|
512
510
|
boxplot_properties=None,
|
513
511
|
palette=None,
|
514
|
-
show: bool = True,
|
515
512
|
return_fig: bool = False,
|
516
513
|
) -> Figure | None:
|
517
514
|
"""Creates a pairwise expression plot from a Pandas DataFrame or Anndata.
|
@@ -679,10 +676,9 @@ class MethodBase(ABC):
|
|
679
676
|
)
|
680
677
|
|
681
678
|
plt.tight_layout()
|
682
|
-
if show:
|
683
|
-
plt.show()
|
684
679
|
if return_fig:
|
685
680
|
return plt.gcf()
|
681
|
+
plt.show()
|
686
682
|
return None
|
687
683
|
|
688
684
|
@_doc_params(common_plot_args=doc_common_plot_args)
|
@@ -696,7 +692,6 @@ class MethodBase(ABC):
|
|
696
692
|
symbol_col: str = "variable",
|
697
693
|
y_label: str = "Log2 fold change",
|
698
694
|
figsize: tuple[int, int] = (10, 5),
|
699
|
-
show: bool = True,
|
700
695
|
return_fig: bool = False,
|
701
696
|
**barplot_kwargs,
|
702
697
|
) -> Figure | None:
|
@@ -762,10 +757,9 @@ class MethodBase(ABC):
|
|
762
757
|
plt.xlabel("")
|
763
758
|
plt.ylabel(y_label)
|
764
759
|
|
765
|
-
if show:
|
766
|
-
plt.show()
|
767
760
|
if return_fig:
|
768
761
|
return plt.gcf()
|
762
|
+
plt.show()
|
769
763
|
return None
|
770
764
|
|
771
765
|
@_doc_params(common_plot_args=doc_common_plot_args)
|
@@ -782,7 +776,6 @@ class MethodBase(ABC):
|
|
782
776
|
figsize: tuple[int, int] = (10, 2),
|
783
777
|
x_label: str = "Contrast",
|
784
778
|
y_label: str = "Gene",
|
785
|
-
show: bool = True,
|
786
779
|
return_fig: bool = False,
|
787
780
|
**heatmap_kwargs,
|
788
781
|
) -> Figure | None:
|
@@ -880,10 +873,9 @@ class MethodBase(ABC):
|
|
880
873
|
plt.xlabel(x_label)
|
881
874
|
plt.ylabel(y_label)
|
882
875
|
|
883
|
-
if show:
|
884
|
-
plt.show()
|
885
876
|
if return_fig:
|
886
877
|
return plt.gcf()
|
878
|
+
plt.show()
|
887
879
|
return None
|
888
880
|
|
889
881
|
|
@@ -1117,67 +1117,75 @@ class MeanVarDistributionDistance(AbstractDistance):
|
|
1117
1117
|
super().__init__()
|
1118
1118
|
self.accepts_precomputed = False
|
1119
1119
|
|
1120
|
+
@staticmethod
|
1121
|
+
def _mean_var(x, log: bool = False):
|
1122
|
+
mean = np.mean(x, axis=0)
|
1123
|
+
var = np.var(x, axis=0)
|
1124
|
+
positive = mean > 0
|
1125
|
+
mean = mean[positive]
|
1126
|
+
var = var[positive]
|
1127
|
+
if log:
|
1128
|
+
mean = np.log(mean)
|
1129
|
+
var = np.log(var)
|
1130
|
+
return mean, var
|
1131
|
+
|
1132
|
+
@staticmethod
|
1133
|
+
def _prep_kde_data(x, y):
|
1134
|
+
return np.concatenate([x.reshape(-1, 1), y.reshape(-1, 1)], axis=1)
|
1135
|
+
|
1136
|
+
@staticmethod
|
1137
|
+
def _grid_points(d, n_points=100):
|
1138
|
+
# Make grid, add 1 bin on lower/upper end to get final n_points
|
1139
|
+
d_min = d.min()
|
1140
|
+
d_max = d.max()
|
1141
|
+
# Compute bin size
|
1142
|
+
d_bin = (d_max - d_min) / (n_points - 2)
|
1143
|
+
d_min = d_min - d_bin
|
1144
|
+
d_max = d_max + d_bin
|
1145
|
+
return np.arange(start=d_min + 0.5 * d_bin, stop=d_max, step=d_bin)
|
1146
|
+
|
1147
|
+
@staticmethod
|
1148
|
+
def _kde_eval_both(x_kde, y_kde, grid):
|
1149
|
+
n_points = len(grid)
|
1150
|
+
chunk_size = 10000
|
1151
|
+
|
1152
|
+
result_x = np.zeros(n_points)
|
1153
|
+
result_y = np.zeros(n_points)
|
1154
|
+
|
1155
|
+
# Process same chunks for both KDEs
|
1156
|
+
for start in range(0, n_points, chunk_size):
|
1157
|
+
end = min(start + chunk_size, n_points)
|
1158
|
+
chunk = grid[start:end]
|
1159
|
+
result_x[start:end] = x_kde.score_samples(chunk)
|
1160
|
+
result_y[start:end] = y_kde.score_samples(chunk)
|
1161
|
+
|
1162
|
+
return result_x, result_y
|
1163
|
+
|
1120
1164
|
def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
|
1121
1165
|
"""Difference of mean-var distributions in 2 matrices.
|
1122
|
-
|
1123
1166
|
Args:
|
1124
1167
|
X: Normalized and log transformed cells x genes count matrix.
|
1125
1168
|
Y: Normalized and log transformed cells x genes count matrix.
|
1126
1169
|
"""
|
1170
|
+
mean_x, var_x = self._mean_var(X, log=True)
|
1171
|
+
mean_y, var_y = self._mean_var(Y, log=True)
|
1127
1172
|
|
1128
|
-
|
1129
|
-
|
1130
|
-
var = np.var(x, axis=0)
|
1131
|
-
positive = mean > 0
|
1132
|
-
mean = mean[positive]
|
1133
|
-
var = var[positive]
|
1134
|
-
if log:
|
1135
|
-
mean = np.log(mean)
|
1136
|
-
var = np.log(var)
|
1137
|
-
return mean, var
|
1138
|
-
|
1139
|
-
def _prep_kde_data(x, y):
|
1140
|
-
return np.concatenate([x.reshape(-1, 1), y.reshape(-1, 1)], axis=1)
|
1141
|
-
|
1142
|
-
def _grid_points(d, n_points=100):
|
1143
|
-
# Make grid, add 1 bin on lower/upper end to get final n_points
|
1144
|
-
d_min = d.min()
|
1145
|
-
d_max = d.max()
|
1146
|
-
# Compute bin size
|
1147
|
-
d_bin = (d_max - d_min) / (n_points - 2)
|
1148
|
-
d_min = d_min - d_bin
|
1149
|
-
d_max = d_max + d_bin
|
1150
|
-
return np.arange(start=d_min + 0.5 * d_bin, stop=d_max, step=d_bin)
|
1151
|
-
|
1152
|
-
def _parallel_score_samples(kde, samples, thread_count=int(0.875 * multiprocessing.cpu_count())):
|
1153
|
-
# the thread_count is determined using the factor 0.875 as recommended here:
|
1154
|
-
# https://stackoverflow.com/questions/32625094/scipy-parallel-computing-in-ipython-notebook
|
1155
|
-
with multiprocessing.Pool(thread_count) as p:
|
1156
|
-
return np.concatenate(p.map(kde.score_samples, np.array_split(samples, thread_count)))
|
1157
|
-
|
1158
|
-
def _kde_eval(d, grid):
|
1159
|
-
# Kernel choice: Gaussian is too smoothing and cosine or other kernels that do not stretch out
|
1160
|
-
# can not be compared well on regions further away from the data as they are -inf
|
1161
|
-
kde = KernelDensity(bandwidth="silverman", kernel="exponential").fit(d)
|
1162
|
-
return _parallel_score_samples(kde, grid)
|
1163
|
-
|
1164
|
-
mean_x, var_x = _mean_var(X, log=True)
|
1165
|
-
mean_y, var_y = _mean_var(Y, log=True)
|
1166
|
-
|
1167
|
-
x = _prep_kde_data(mean_x, var_x)
|
1168
|
-
y = _prep_kde_data(mean_y, var_y)
|
1173
|
+
x = self._prep_kde_data(mean_x, var_x)
|
1174
|
+
y = self._prep_kde_data(mean_y, var_y)
|
1169
1175
|
|
1170
1176
|
# Gridpoints to eval KDE on
|
1171
|
-
mean_grid = _grid_points(np.concatenate([mean_x, mean_y]))
|
1172
|
-
var_grid = _grid_points(np.concatenate([var_x, var_y]))
|
1177
|
+
mean_grid = self._grid_points(np.concatenate([mean_x, mean_y]))
|
1178
|
+
var_grid = self._grid_points(np.concatenate([var_x, var_y]))
|
1173
1179
|
grid = np.array(np.meshgrid(mean_grid, var_grid)).T.reshape(-1, 2)
|
1174
1180
|
|
1175
|
-
|
1176
|
-
|
1181
|
+
# Fit both KDEs first
|
1182
|
+
x_kde = KernelDensity(bandwidth="silverman", kernel="exponential").fit(x)
|
1183
|
+
y_kde = KernelDensity(bandwidth="silverman", kernel="exponential").fit(y)
|
1177
1184
|
|
1178
|
-
|
1185
|
+
# Evaluate both KDEs on same grid chunks
|
1186
|
+
kde_x, kde_y = self._kde_eval_both(x_kde, y_kde, grid)
|
1179
1187
|
|
1180
|
-
return
|
1188
|
+
return ((np.exp(kde_x) - np.exp(kde_y)) ** 2).mean()
|
1181
1189
|
|
1182
1190
|
def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float:
|
1183
1191
|
raise NotImplementedError("MeanVarDistributionDistance cannot be called on a pairwise distance matrix.")
|
pertpy/tools/_enrichment.py
CHANGED
@@ -304,7 +304,6 @@ class Enrichment:
|
|
304
304
|
groupby: str = None,
|
305
305
|
key: str = "pertpy_enrichment",
|
306
306
|
ax: Axes | None = None,
|
307
|
-
show: bool = True,
|
308
307
|
return_fig: bool = False,
|
309
308
|
**kwargs,
|
310
309
|
) -> DotPlot | None:
|
@@ -417,10 +416,9 @@ class Enrichment:
|
|
417
416
|
**kwargs,
|
418
417
|
)
|
419
418
|
|
420
|
-
if show:
|
421
|
-
plt.show()
|
422
419
|
if return_fig:
|
423
420
|
return fig
|
421
|
+
plt.show()
|
424
422
|
return None
|
425
423
|
|
426
424
|
def plot_gsea(
|
pertpy/tools/_milo.py
CHANGED
@@ -727,7 +727,6 @@ class Milo:
|
|
727
727
|
color_map: Colormap | str | None = None,
|
728
728
|
palette: str | Sequence[str] | None = None,
|
729
729
|
ax: Axes | None = None,
|
730
|
-
show: bool = True,
|
731
730
|
return_fig: bool = False,
|
732
731
|
**kwargs,
|
733
732
|
) -> Figure | None:
|
@@ -803,10 +802,9 @@ class Milo:
|
|
803
802
|
**kwargs,
|
804
803
|
)
|
805
804
|
|
806
|
-
if show:
|
807
|
-
plt.show()
|
808
805
|
if return_fig:
|
809
806
|
return fig
|
807
|
+
plt.show()
|
810
808
|
return None
|
811
809
|
|
812
810
|
@_doc_params(common_plot_args=doc_common_plot_args)
|
@@ -820,7 +818,6 @@ class Milo:
|
|
820
818
|
color_map: Colormap | str | None = None,
|
821
819
|
palette: str | Sequence[str] | None = None,
|
822
820
|
ax: Axes | None = None,
|
823
|
-
show: bool = True,
|
824
821
|
return_fig: bool = False,
|
825
822
|
**kwargs,
|
826
823
|
) -> Figure | None:
|
@@ -866,10 +863,9 @@ class Milo:
|
|
866
863
|
**kwargs,
|
867
864
|
)
|
868
865
|
|
869
|
-
if show:
|
870
|
-
plt.show()
|
871
866
|
if return_fig:
|
872
867
|
return fig
|
868
|
+
plt.show()
|
873
869
|
return None
|
874
870
|
|
875
871
|
@_doc_params(common_plot_args=doc_common_plot_args)
|
@@ -882,7 +878,6 @@ class Milo:
|
|
882
878
|
alpha: float = 0.1,
|
883
879
|
subset_nhoods: list[str] = None,
|
884
880
|
palette: str | Sequence[str] | dict[str, str] | None = None,
|
885
|
-
show: bool = True,
|
886
881
|
return_fig: bool = False,
|
887
882
|
) -> Figure | None:
|
888
883
|
"""Plot beeswarm plot of logFC against nhood labels
|
@@ -994,10 +989,9 @@ class Milo:
|
|
994
989
|
plt.legend(loc="upper left", title=f"< {int(alpha * 100)}% SpatialFDR", bbox_to_anchor=(1, 1), frameon=False)
|
995
990
|
plt.axvline(x=0, ymin=0, ymax=1, color="black", linestyle="--")
|
996
991
|
|
997
|
-
if show:
|
998
|
-
plt.show()
|
999
992
|
if return_fig:
|
1000
993
|
return plt.gcf()
|
994
|
+
plt.show()
|
1001
995
|
return None
|
1002
996
|
|
1003
997
|
@_doc_params(common_plot_args=doc_common_plot_args)
|
@@ -1008,7 +1002,6 @@ class Milo:
|
|
1008
1002
|
*,
|
1009
1003
|
subset_nhoods: list[str] = None,
|
1010
1004
|
log_counts: bool = False,
|
1011
|
-
show: bool = True,
|
1012
1005
|
return_fig: bool = False,
|
1013
1006
|
) -> Figure | None:
|
1014
1007
|
"""Plot boxplot of cell numbers vs condition of interest.
|
@@ -1050,8 +1043,7 @@ class Milo:
|
|
1050
1043
|
plt.xticks(rotation=90)
|
1051
1044
|
plt.xlabel(test_var)
|
1052
1045
|
|
1053
|
-
if show:
|
1054
|
-
plt.show()
|
1055
1046
|
if return_fig:
|
1056
1047
|
return plt.gcf()
|
1048
|
+
plt.show()
|
1057
1049
|
return None
|