pertpy 0.9.5__py3-none-any.whl → 0.10.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pertpy/__init__.py +1 -1
- pertpy/_doc.py +1 -2
- pertpy/metadata/_cell_line.py +3 -5
- pertpy/preprocessing/_guide_rna.py +98 -10
- pertpy/preprocessing/_guide_rna_mixture.py +179 -0
- pertpy/tools/_augur.py +32 -44
- pertpy/tools/_cinemaot.py +1 -3
- pertpy/tools/_coda/_base_coda.py +21 -29
- pertpy/tools/_dialogue.py +17 -21
- pertpy/tools/_differential_gene_expression/_base.py +4 -12
- pertpy/tools/_distances/_distances.py +56 -48
- pertpy/tools/_enrichment.py +1 -3
- pertpy/tools/_milo.py +4 -12
- pertpy/tools/_mixscape.py +215 -127
- pertpy/tools/_perturbation_space/_simple.py +1 -3
- pertpy/tools/_scgen/_scgen.py +1 -3
- {pertpy-0.9.5.dist-info → pertpy-0.10.0.dist-info}/METADATA +2 -2
- {pertpy-0.9.5.dist-info → pertpy-0.10.0.dist-info}/RECORD +20 -19
- {pertpy-0.9.5.dist-info → pertpy-0.10.0.dist-info}/WHEEL +0 -0
- {pertpy-0.9.5.dist-info → pertpy-0.10.0.dist-info}/licenses/LICENSE +0 -0
pertpy/tools/_coda/_base_coda.py
CHANGED
@@ -850,7 +850,7 @@ class CompositionalModel2(ABC):
         table = Table(title="Compositional Analysis summary", box=box.SQUARE, expand=True, highlight=True)
         table.add_column("Name", justify="left", style="cyan")
         table.add_column("Value", justify="left")
-        table.add_row("Data", "Data:
+        table.add_row("Data", f"Data: {data_dims[0]} samples, {data_dims[1]} cell types")
         table.add_row("Reference cell type", "{}".format(str(sample_adata.uns["scCODA_params"]["reference_cell_type"])))
         table.add_row("Formula", "{}".format(sample_adata.uns["scCODA_params"]["formula"]))
         if extended:
@@ -1199,7 +1199,6 @@
         level_order: list[str] = None,
         figsize: tuple[float, float] | None = None,
         dpi: int | None = 100,
-        show: bool = True,
         return_fig: bool = False,
     ) -> Figure | None:
         """Plots a stacked barplot for all levels of a covariate or all samples (if feature_name=="samples").
@@ -1278,10 +1277,9 @@
             show_legend=show_legend,
         )

-        if show:
-            plt.show()
         if return_fig:
             return plt.gcf()
+        plt.show()
         return None

     @_doc_params(common_plot_args=doc_common_plot_args)
@@ -1300,7 +1298,6 @@
         args_barplot: dict | None = None,
         figsize: tuple[float, float] | None = None,
         dpi: int | None = 100,
-        show: bool = True,
         return_fig: bool = False,
     ) -> Figure | None:
         """Barplot visualization for effects.
@@ -1465,10 +1462,11 @@
         cell_types = pd.unique(plot_df["Cell Type"])
         ax.set_xticklabels(cell_types, rotation=90)

-        if
-
-        if return_fig:
+        if return_fig and plot_facets:
+            return g
+        if return_fig and not plot_facets:
             return plt.gcf()
+        plt.show()
         return None

     @_doc_params(common_plot_args=doc_common_plot_args)
@@ -1489,7 +1487,6 @@
         level_order: list[str] = None,
         figsize: tuple[float, float] | None = None,
         dpi: int | None = 100,
-        show: bool = True,
         return_fig: bool = False,
     ) -> Figure | None:
         """Grouped boxplot visualization.
@@ -1697,10 +1694,11 @@
             title=feature_name,
         )

-        if
-
-        if return_fig:
+        if return_fig and plot_facets:
+            return g
+        if return_fig and not plot_facets:
             return plt.gcf()
+        plt.show()
         return None

     @_doc_params(common_plot_args=doc_common_plot_args)
@@ -1716,7 +1714,6 @@
         figsize: tuple[float, float] | None = None,
         dpi: int | None = 100,
         ax: plt.Axes | None = None,
-        show: bool = True,
         return_fig: bool = False,
     ) -> Figure | None:
         """Plots total variance of relative abundance versus minimum relative abundance of all cell types for determination of a reference cell type.
@@ -1820,10 +1817,9 @@

         ax.legend(loc="upper left", bbox_to_anchor=(1, 1), ncol=1, title="Is abundant")

-        if show:
-            plt.show()
         if return_fig:
             return plt.gcf()
+        plt.show()
         return None

     @_doc_params(common_plot_args=doc_common_plot_args)
@@ -1839,7 +1835,6 @@
         figsize: tuple[float, float] | None = (None, None),
         dpi: int | None = 100,
         save: str | bool = False,
-        show: bool = True,
         return_fig: bool = False,
     ) -> Tree | None:
         """Plot a tree using input ete3 tree object.
@@ -1903,10 +1898,9 @@

         if save is not None:
             tree.render(save, tree_style=tree_style, units=units, w=figsize[0], h=figsize[1], dpi=dpi)  # type: ignore
-        if show:
-            return tree.render("%%inline", tree_style=tree_style, units=units, w=figsize[0], h=figsize[1], dpi=dpi)  # type: ignore
         if return_fig:
             return tree, tree_style
+        return tree.render("%%inline", tree_style=tree_style, units=units, w=figsize[0], h=figsize[1], dpi=dpi)  # type: ignore
         return None

     @_doc_params(common_plot_args=doc_common_plot_args)
@@ -1925,7 +1919,6 @@
         figsize: tuple[float, float] | None = (None, None),
         dpi: int | None = 100,
         save: str | bool = False,
-        show: bool = True,
         return_fig: bool = False,
     ) -> Tree | None:
         """Plot a tree with colored circles on the nodes indicating significant effects with bar plots which indicate leave-level significant effects.
@@ -2092,15 +2085,16 @@

             if save:
                 plt.savefig(save)
+            if return_fig:
+                return plt.gcf()

-
-
-
-        if
-            return tree2.render("%%inline", tree_style=tree_style, units=units, w=figsize[0], h=figsize[1], dpi=dpi)
-        if return_fig:
-            if not show_leaf_effects:
+        else:
+            if save:
+                tree2.render(save, tree_style=tree_style, units=units)
+            if return_fig:
                 return tree2, tree_style
+            return tree2.render("%%inline", tree_style=tree_style, units=units, w=figsize[0], h=figsize[1], dpi=dpi)
+
         return None

     @_doc_params(common_plot_args=doc_common_plot_args)
@@ -2115,7 +2109,6 @@
         color_map: Colormap | str | None = None,
         palette: str | Sequence[str] | None = None,
         ax: Axes = None,
-        show: bool = True,
         return_fig: bool = False,
         **kwargs,
     ) -> Figure | None:
@@ -2209,10 +2202,9 @@
             **kwargs,
         )

-        if show:
-            plt.show()
         if return_fig:
             return fig
+        plt.show()
         return None
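The recurring change in the hunks above (and in every plotting module below) is the removal of the show keyword: a 0.10.0 plotting function returns the figure when return_fig=True and otherwise falls through to plt.show(). A minimal, self-contained sketch of that contract; plot_example and its data are hypothetical and only illustrate the calling pattern, they are not part of pertpy's API:

import matplotlib.pyplot as plt
from matplotlib.figure import Figure


def plot_example(values: list[float], *, return_fig: bool = False) -> Figure | None:
    # Build the plot as usual.
    fig, ax = plt.subplots()
    ax.bar(range(len(values)), values)
    # New contract: hand the Figure back on request, otherwise display it.
    if return_fig:
        return fig
    plt.show()
    return None


# Callers that previously passed show=False now request the figure instead:
fig = plot_example([1.0, 2.0, 3.0], return_fig=True)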
pertpy/tools/_dialogue.py
CHANGED
@@ -82,27 +82,27 @@ class Dialogue:

         return pseudobulk

-    def
-
+    def _pseudobulk_feature_space(
+        self, adata: AnnData, groupby: str, n_components: int = 50, feature_space_key: str = "X_pca"
+    ) -> pd.DataFrame:
+        """Return Cell-averaged components from a passed feature space.

         TODO: consider merging with `get_pseudobulks`
         TODO: DIALOGUE recommends running PCA on each cell type separately before running PMD - this should be implemented as an option here.

         Args:
-            groupby: The key to groupby for pseudobulks
-            n_components: The number of
+            groupby: The key to groupby for pseudobulks.
+            n_components: The number of components to use.
+            feature_key: The key in adata.obsm for the feature space (e.g., "X_pca", "X_umap").

         Returns:
-            A pseudobulk of
+            A pseudobulk DataFrame of the averaged components.
         """
         aggr = {}
-
         for category in adata.obs.loc[:, groupby].cat.categories:
             temp = adata.obs.loc[:, groupby] == category
-            aggr[category] = adata[temp].obsm[
-
+            aggr[category] = adata[temp].obsm[feature_space_key][:, :n_components].mean(axis=0)
         aggr = pd.DataFrame(aggr)
-
         return aggr

     def _scale_data(self, pseudobulks: pd.DataFrame, normalize: bool = True) -> np.ndarray:
@@ -558,7 +558,7 @@
         self,
         adata: AnnData,
         ct_order: list[str],
-
+        agg_feature: bool = True,
         normalize: bool = True,
     ) -> tuple[list, dict]:
         """Separates cell into AnnDatas by celltype_key and creates the multifactor PMD input.
@@ -568,14 +568,14 @@
         Args:
             adata: AnnData object generate celltype objects for
             ct_order: The order of cell types
-
+            agg_feature: Whether to aggregate pseudobulks with some embeddings or not.
             normalize: Whether to mimic DIALOGUE behavior or not.

         Returns:
             A celltype_label:array dictionary.
         """
         ct_subs = {ct: adata[adata.obs[self.celltype_key] == ct].copy() for ct in ct_order}
-        fn = self.
+        fn = self._pseudobulk_feature_space if agg_feature else self._get_pseudobulks
         ct_aggr = {ct: fn(ad, self.sample_id) for ct, ad in ct_subs.items()}  # type: ignore

         # TODO: implement check (as in https://github.com/livnatje/DIALOGUE/blob/55da9be0a9bf2fcd360d9e11f63e30d041ec4318/R/DIALOGUE.main.R#L114-L119)
@@ -593,7 +593,7 @@
         adata: AnnData,
         penalties: list[int] = None,
         ct_order: list[str] = None,
-
+        agg_feature: bool = True,
         solver: Literal["lp", "bs"] = "bs",
         normalize: bool = True,
     ) -> tuple[AnnData, dict[str, np.ndarray], dict[Any, Any], dict[Any, Any]]:
@@ -606,7 +606,7 @@
             sample_id: Key to use for pseudobulk determination.
             penalties: PMD penalties.
             ct_order: The order of cell types.
-
+            agg_features: Whether to calculate cell-averaged principal components.
             solver: Which solver to use for PMD. Must be one of "lp" (linear programming) or "bs" (binary search).
                 For differences between these to please refer to https://github.com/theislab/sparsecca/blob/main/examples/linear_programming_multicca.ipynb
             normalize: Whether to mimic DIALOGUE as close as possible
@@ -631,7 +631,7 @@
         else:
             ct_order = cell_types = adata.obs[self.celltype_key].astype("category").cat.categories

-        mcca_in, ct_subs = self._load(adata, ct_order=cell_types,
+        mcca_in, ct_subs = self._load(adata, ct_order=cell_types, agg_feature=agg_feature, normalize=normalize)

         n_samples = mcca_in[0].shape[1]
         if penalties is None:
@@ -1070,7 +1070,6 @@
         *,
         split_which: tuple[str, str] = None,
         mcp: str = "mcp_0",
-        show: bool = True,
         return_fig: bool = False,
     ) -> Figure | None:
         """Plots split violin plots for a given MCP and split variable.
@@ -1110,10 +1109,9 @@
         ax = sns.violinplot(data=df, x=celltype_key, y=mcp, hue=split_key, split=True)
         ax.set_xticklabels(ax.get_xticklabels(), rotation=90)

-        if show:
-            plt.show()
         if return_fig:
             return plt.gcf()
+        plt.show()
         return None

     @_doc_params(common_plot_args=doc_common_plot_args)
@@ -1125,7 +1123,6 @@
         sample_id: str,
         *,
         mcp: str = "mcp_0",
-        show: bool = True,
         return_fig: bool = False,
     ) -> Figure | None:
         """Generate a pairplot visualization for multi-cell perturbation (MCP) data.
@@ -1167,8 +1164,7 @@
         mcp_pivot = pd.concat([mcp_pivot, aggstats[color]], axis=1)
         sns.pairplot(mcp_pivot, hue=color, corner=True)

-        if show:
-            plt.show()
         if return_fig:
             return plt.gcf()
+        plt.show()
         return None
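The first hunk above introduces _pseudobulk_feature_space, which averages the first n_components columns of an adata.obsm embedding for each category of a grouping column. A standalone sketch of that aggregation on a toy AnnData; the free function and the toy data are illustrative only (the real helper is a Dialogue method and subsets the AnnData rather than indexing the array directly):

import numpy as np
import pandas as pd
from anndata import AnnData


def pseudobulk_feature_space(
    adata: AnnData, groupby: str, n_components: int = 50, feature_space_key: str = "X_pca"
) -> pd.DataFrame:
    # Average the first n_components of the chosen obsm embedding per group.
    aggr = {}
    for category in adata.obs[groupby].cat.categories:
        mask = (adata.obs[groupby] == category).to_numpy()
        aggr[category] = adata.obsm[feature_space_key][mask, :n_components].mean(axis=0)
    return pd.DataFrame(aggr)


rng = np.random.default_rng(0)
adata = AnnData(X=rng.poisson(1.0, size=(20, 5)).astype(np.float32))
adata.obs["sample"] = pd.Categorical(["s1"] * 10 + ["s2"] * 10)
adata.obsm["X_pca"] = rng.normal(size=(20, 10))
print(pseudobulk_feature_space(adata, "sample", n_components=3))  # 3 components x 2 samples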
pertpy/tools/_differential_gene_expression/_base.py
CHANGED
@@ -125,7 +125,6 @@ class MethodBase(ABC):
         shape_order: list[str] | None = None,
         x_label: str | None = None,
         y_label: str | None = None,
-        show: bool = True,
         return_fig: bool = False,
         **kwargs: int,
     ) -> Figure | None:
@@ -484,10 +483,9 @@

         plt.legend(loc=1, bbox_to_anchor=legend_pos, frameon=False)

-        if show:
-            plt.show()
         if return_fig:
             return plt.gcf()
+        plt.show()
         return None

     @_doc_params(common_plot_args=doc_common_plot_args)
@@ -511,7 +509,6 @@
         pvalue_template=lambda x: f"p={x:.2e}",
         boxplot_properties=None,
         palette=None,
-        show: bool = True,
         return_fig: bool = False,
     ) -> Figure | None:
         """Creates a pairwise expression plot from a Pandas DataFrame or Anndata.
@@ -679,10 +676,9 @@
         )

         plt.tight_layout()
-        if show:
-            plt.show()
         if return_fig:
             return plt.gcf()
+        plt.show()
         return None

     @_doc_params(common_plot_args=doc_common_plot_args)
@@ -696,7 +692,6 @@
         symbol_col: str = "variable",
         y_label: str = "Log2 fold change",
         figsize: tuple[int, int] = (10, 5),
-        show: bool = True,
         return_fig: bool = False,
         **barplot_kwargs,
     ) -> Figure | None:
@@ -762,10 +757,9 @@
         plt.xlabel("")
         plt.ylabel(y_label)

-        if show:
-            plt.show()
         if return_fig:
             return plt.gcf()
+        plt.show()
         return None

     @_doc_params(common_plot_args=doc_common_plot_args)
@@ -782,7 +776,6 @@
         figsize: tuple[int, int] = (10, 2),
         x_label: str = "Contrast",
         y_label: str = "Gene",
-        show: bool = True,
         return_fig: bool = False,
         **heatmap_kwargs,
     ) -> Figure | None:
@@ -880,10 +873,9 @@
         plt.xlabel(x_label)
         plt.ylabel(y_label)

-        if show:
-            plt.show()
         if return_fig:
             return plt.gcf()
+        plt.show()
         return None
pertpy/tools/_distances/_distances.py
CHANGED
@@ -1117,67 +1117,75 @@ class MeanVarDistributionDistance(AbstractDistance):
         super().__init__()
         self.accepts_precomputed = False

+    @staticmethod
+    def _mean_var(x, log: bool = False):
+        mean = np.mean(x, axis=0)
+        var = np.var(x, axis=0)
+        positive = mean > 0
+        mean = mean[positive]
+        var = var[positive]
+        if log:
+            mean = np.log(mean)
+            var = np.log(var)
+        return mean, var
+
+    @staticmethod
+    def _prep_kde_data(x, y):
+        return np.concatenate([x.reshape(-1, 1), y.reshape(-1, 1)], axis=1)
+
+    @staticmethod
+    def _grid_points(d, n_points=100):
+        # Make grid, add 1 bin on lower/upper end to get final n_points
+        d_min = d.min()
+        d_max = d.max()
+        # Compute bin size
+        d_bin = (d_max - d_min) / (n_points - 2)
+        d_min = d_min - d_bin
+        d_max = d_max + d_bin
+        return np.arange(start=d_min + 0.5 * d_bin, stop=d_max, step=d_bin)
+
+    @staticmethod
+    def _kde_eval_both(x_kde, y_kde, grid):
+        n_points = len(grid)
+        chunk_size = 10000
+
+        result_x = np.zeros(n_points)
+        result_y = np.zeros(n_points)
+
+        # Process same chunks for both KDEs
+        for start in range(0, n_points, chunk_size):
+            end = min(start + chunk_size, n_points)
+            chunk = grid[start:end]
+            result_x[start:end] = x_kde.score_samples(chunk)
+            result_y[start:end] = y_kde.score_samples(chunk)
+
+        return result_x, result_y
+
     def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
         """Difference of mean-var distributions in 2 matrices.
-
         Args:
             X: Normalized and log transformed cells x genes count matrix.
             Y: Normalized and log transformed cells x genes count matrix.
         """
+        mean_x, var_x = self._mean_var(X, log=True)
+        mean_y, var_y = self._mean_var(Y, log=True)

-
-
-            var = np.var(x, axis=0)
-            positive = mean > 0
-            mean = mean[positive]
-            var = var[positive]
-            if log:
-                mean = np.log(mean)
-                var = np.log(var)
-            return mean, var
-
-        def _prep_kde_data(x, y):
-            return np.concatenate([x.reshape(-1, 1), y.reshape(-1, 1)], axis=1)
-
-        def _grid_points(d, n_points=100):
-            # Make grid, add 1 bin on lower/upper end to get final n_points
-            d_min = d.min()
-            d_max = d.max()
-            # Compute bin size
-            d_bin = (d_max - d_min) / (n_points - 2)
-            d_min = d_min - d_bin
-            d_max = d_max + d_bin
-            return np.arange(start=d_min + 0.5 * d_bin, stop=d_max, step=d_bin)
-
-        def _parallel_score_samples(kde, samples, thread_count=int(0.875 * multiprocessing.cpu_count())):
-            # the thread_count is determined using the factor 0.875 as recommended here:
-            # https://stackoverflow.com/questions/32625094/scipy-parallel-computing-in-ipython-notebook
-            with multiprocessing.Pool(thread_count) as p:
-                return np.concatenate(p.map(kde.score_samples, np.array_split(samples, thread_count)))
-
-        def _kde_eval(d, grid):
-            # Kernel choice: Gaussian is too smoothing and cosine or other kernels that do not stretch out
-            # can not be compared well on regions further away from the data as they are -inf
-            kde = KernelDensity(bandwidth="silverman", kernel="exponential").fit(d)
-            return _parallel_score_samples(kde, grid)
-
-        mean_x, var_x = _mean_var(X, log=True)
-        mean_y, var_y = _mean_var(Y, log=True)
-
-        x = _prep_kde_data(mean_x, var_x)
-        y = _prep_kde_data(mean_y, var_y)
+        x = self._prep_kde_data(mean_x, var_x)
+        y = self._prep_kde_data(mean_y, var_y)

         # Gridpoints to eval KDE on
-        mean_grid = _grid_points(np.concatenate([mean_x, mean_y]))
-        var_grid = _grid_points(np.concatenate([var_x, var_y]))
+        mean_grid = self._grid_points(np.concatenate([mean_x, mean_y]))
+        var_grid = self._grid_points(np.concatenate([var_x, var_y]))
         grid = np.array(np.meshgrid(mean_grid, var_grid)).T.reshape(-1, 2)

-
-
+        # Fit both KDEs first
+        x_kde = KernelDensity(bandwidth="silverman", kernel="exponential").fit(x)
+        y_kde = KernelDensity(bandwidth="silverman", kernel="exponential").fit(y)

-
+        # Evaluate both KDEs on same grid chunks
+        kde_x, kde_y = self._kde_eval_both(x_kde, y_kde, grid)

-        return
+        return ((np.exp(kde_x) - np.exp(kde_y)) ** 2).mean()

     def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float:
         raise NotImplementedError("MeanVarDistributionDistance cannot be called on a pairwise distance matrix.")
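The hunk above promotes the previously nested helpers to staticmethods and swaps the multiprocessing-based KDE scoring for chunked, single-process evaluation of both KDEs on a shared (mean, variance) grid. A self-contained sketch of that evaluation path on toy data; the free functions mirror the logic shown in the diff, and bandwidth="silverman" assumes a scikit-learn release (1.3 or later) that accepts string bandwidths:

import numpy as np
from sklearn.neighbors import KernelDensity


def mean_var(x: np.ndarray, log: bool = False):
    # Per-gene mean and variance, dropping zero-mean genes before the log transform.
    mean, var = np.mean(x, axis=0), np.var(x, axis=0)
    keep = mean > 0
    mean, var = mean[keep], var[keep]
    return (np.log(mean), np.log(var)) if log else (mean, var)


def grid_points(d: np.ndarray, n_points: int = 100) -> np.ndarray:
    # Evaluation grid padded by one bin on each side, as in the diff.
    d_bin = (d.max() - d.min()) / (n_points - 2)
    return np.arange(d.min() - 0.5 * d_bin, d.max() + d_bin, d_bin)


def kde_eval_both(x_kde, y_kde, grid, chunk_size=10_000):
    # Score both KDEs on identical chunks of the grid: one process, no pool.
    out_x, out_y = np.zeros(len(grid)), np.zeros(len(grid))
    for start in range(0, len(grid), chunk_size):
        chunk = grid[start : start + chunk_size]
        out_x[start : start + len(chunk)] = x_kde.score_samples(chunk)
        out_y[start : start + len(chunk)] = y_kde.score_samples(chunk)
    return out_x, out_y


rng = np.random.default_rng(0)
X = np.abs(rng.normal(1.0, 0.5, size=(200, 50)))  # stand-in for normalized counts
Y = np.abs(rng.normal(1.2, 0.5, size=(200, 50)))
mean_x, var_x = mean_var(X, log=True)
mean_y, var_y = mean_var(Y, log=True)
x = np.column_stack([mean_x, var_x])
y = np.column_stack([mean_y, var_y])
mean_grid = grid_points(np.concatenate([mean_x, mean_y]))
var_grid = grid_points(np.concatenate([var_x, var_y]))
grid = np.array(np.meshgrid(mean_grid, var_grid)).T.reshape(-1, 2)
x_kde = KernelDensity(bandwidth="silverman", kernel="exponential").fit(x)
y_kde = KernelDensity(bandwidth="silverman", kernel="exponential").fit(y)
kde_x, kde_y = kde_eval_both(x_kde, y_kde, grid)
print(((np.exp(kde_x) - np.exp(kde_y)) ** 2).mean())  # mean squared density difference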
pertpy/tools/_enrichment.py
CHANGED
@@ -304,7 +304,6 @@ class Enrichment:
         groupby: str = None,
         key: str = "pertpy_enrichment",
         ax: Axes | None = None,
-        show: bool = True,
         return_fig: bool = False,
         **kwargs,
     ) -> DotPlot | None:
@@ -417,10 +416,9 @@
             **kwargs,
         )

-        if show:
-            plt.show()
         if return_fig:
             return fig
+        plt.show()
         return None

     def plot_gsea(
pertpy/tools/_milo.py
CHANGED
@@ -727,7 +727,6 @@ class Milo:
         color_map: Colormap | str | None = None,
         palette: str | Sequence[str] | None = None,
         ax: Axes | None = None,
-        show: bool = True,
         return_fig: bool = False,
         **kwargs,
     ) -> Figure | None:
@@ -803,10 +802,9 @@
             **kwargs,
         )

-        if show:
-            plt.show()
         if return_fig:
             return fig
+        plt.show()
         return None

     @_doc_params(common_plot_args=doc_common_plot_args)
@@ -820,7 +818,6 @@
         color_map: Colormap | str | None = None,
         palette: str | Sequence[str] | None = None,
         ax: Axes | None = None,
-        show: bool = True,
         return_fig: bool = False,
         **kwargs,
     ) -> Figure | None:
@@ -866,10 +863,9 @@
             **kwargs,
        )

-        if show:
-            plt.show()
         if return_fig:
             return fig
+        plt.show()
         return None

     @_doc_params(common_plot_args=doc_common_plot_args)
@@ -882,7 +878,6 @@
         alpha: float = 0.1,
         subset_nhoods: list[str] = None,
         palette: str | Sequence[str] | dict[str, str] | None = None,
-        show: bool = True,
         return_fig: bool = False,
     ) -> Figure | None:
         """Plot beeswarm plot of logFC against nhood labels
@@ -994,10 +989,9 @@
         plt.legend(loc="upper left", title=f"< {int(alpha * 100)}% SpatialFDR", bbox_to_anchor=(1, 1), frameon=False)
         plt.axvline(x=0, ymin=0, ymax=1, color="black", linestyle="--")

-        if show:
-            plt.show()
         if return_fig:
             return plt.gcf()
+        plt.show()
         return None

     @_doc_params(common_plot_args=doc_common_plot_args)
@@ -1008,7 +1002,6 @@
         *,
         subset_nhoods: list[str] = None,
         log_counts: bool = False,
-        show: bool = True,
         return_fig: bool = False,
     ) -> Figure | None:
         """Plot boxplot of cell numbers vs condition of interest.
@@ -1050,8 +1043,7 @@
         plt.xticks(rotation=90)
         plt.xlabel(test_var)

-        if show:
-            plt.show()
         if return_fig:
             return plt.gcf()
+        plt.show()
         return None