pertpy 0.9.3__py3-none-any.whl → 0.9.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pertpy/__init__.py +1 -1
- pertpy/_doc.py +20 -0
- pertpy/data/_dataloader.py +4 -4
- pertpy/data/_datasets.py +3 -3
- pertpy/metadata/_cell_line.py +19 -7
- pertpy/metadata/_compound.py +3 -4
- pertpy/metadata/_metadata.py +1 -1
- pertpy/preprocessing/_guide_rna.py +19 -6
- pertpy/tools/__init__.py +12 -15
- pertpy/tools/_augur.py +36 -46
- pertpy/tools/_cinemaot.py +24 -18
- pertpy/tools/_coda/_base_coda.py +87 -106
- pertpy/tools/_dialogue.py +17 -21
- pertpy/tools/_differential_gene_expression/__init__.py +1 -2
- pertpy/tools/_differential_gene_expression/_base.py +495 -113
- pertpy/tools/_differential_gene_expression/_edger.py +30 -21
- pertpy/tools/_differential_gene_expression/_pydeseq2.py +15 -29
- pertpy/tools/_differential_gene_expression/_statsmodels.py +0 -11
- pertpy/tools/_distances/_distances.py +15 -8
- pertpy/tools/_enrichment.py +18 -8
- pertpy/tools/_milo.py +58 -46
- pertpy/tools/_mixscape.py +111 -100
- pertpy/tools/_perturbation_space/_perturbation_space.py +40 -31
- pertpy/tools/_perturbation_space/_simple.py +50 -0
- pertpy/tools/_scgen/_scgen.py +35 -25
- {pertpy-0.9.3.dist-info → pertpy-0.9.5.dist-info}/METADATA +5 -4
- {pertpy-0.9.3.dist-info → pertpy-0.9.5.dist-info}/RECORD +29 -29
- {pertpy-0.9.3.dist-info → pertpy-0.9.5.dist-info}/WHEEL +1 -1
- pertpy/tools/_differential_gene_expression/_formulaic.py +0 -189
- {pertpy-0.9.3.dist-info → pertpy-0.9.5.dist-info}/licenses/LICENSE +0 -0
@@ -1,7 +1,7 @@
|
|
1
|
-
import
|
1
|
+
import math
|
2
2
|
from abc import ABC, abstractmethod
|
3
|
-
from
|
4
|
-
from itertools import
|
3
|
+
from collections.abc import Iterable, Mapping, Sequence
|
4
|
+
from itertools import zip_longest
|
5
5
|
from types import MappingProxyType
|
6
6
|
|
7
7
|
import adjustText
|
@@ -11,27 +11,14 @@ import matplotlib.pyplot as plt
|
|
11
11
|
import numpy as np
|
12
12
|
import pandas as pd
|
13
13
|
import seaborn as sns
|
14
|
+
from formulaic_contrasts import FormulaicContrasts
|
15
|
+
from lamin_utils import logger
|
16
|
+
from matplotlib.pyplot import Figure
|
14
17
|
from matplotlib.ticker import MaxNLocator
|
15
18
|
|
19
|
+
from pertpy._doc import _doc_params, doc_common_plot_args
|
20
|
+
from pertpy.tools import PseudobulkSpace
|
16
21
|
from pertpy.tools._differential_gene_expression._checks import check_is_numeric_matrix
|
17
|
-
from pertpy.tools._differential_gene_expression._formulaic import (
|
18
|
-
AmbiguousAttributeError,
|
19
|
-
Factor,
|
20
|
-
get_factor_storage_and_materializer,
|
21
|
-
resolve_ambiguous,
|
22
|
-
)
|
23
|
-
|
24
|
-
|
25
|
-
@dataclass
|
26
|
-
class Contrast:
|
27
|
-
"""Simple contrast for comparison between groups"""
|
28
|
-
|
29
|
-
column: str
|
30
|
-
baseline: str
|
31
|
-
group_to_compare: str
|
32
|
-
|
33
|
-
|
34
|
-
ContrastType = Contrast | tuple[str, str, str]
|
35
22
|
|
36
23
|
|
37
24
|
class MethodBase(ABC):
|
@@ -58,7 +45,7 @@ class MethodBase(ABC):
|
|
58
45
|
if self.layer is None:
|
59
46
|
return self.adata.X
|
60
47
|
else:
|
61
|
-
return self.adata.
|
48
|
+
return self.adata.layers[self.layer]
|
62
49
|
|
63
50
|
@classmethod
|
64
51
|
@abstractmethod
|
@@ -91,9 +78,28 @@ class MethodBase(ABC):
|
|
91
78
|
|
92
79
|
Returns:
|
93
80
|
Pandas dataframe with results ordered by significance. If multiple comparisons were performed this is indicated in an additional column.
|
81
|
+
|
82
|
+
Examples:
|
83
|
+
>>> # Example with EdgeR
|
84
|
+
>>> import pertpy as pt
|
85
|
+
>>> adata = pt.dt.zhang_2021()
|
86
|
+
>>> adata.layers["counts"] = adata.X.copy()
|
87
|
+
>>> ps = pt.tl.PseudobulkSpace()
|
88
|
+
>>> pdata = ps.compute(
|
89
|
+
... adata,
|
90
|
+
... target_col="Patient",
|
91
|
+
... groups_col="Cluster",
|
92
|
+
... layer_key="counts",
|
93
|
+
... mode="sum",
|
94
|
+
... min_cells=10,
|
95
|
+
... min_counts=1000,
|
96
|
+
... )
|
97
|
+
>>> edgr = pt.tl.EdgeR(pdata, design="~Efficacy+Treatment")
|
98
|
+
>>> res_df = edgr.compare_groups(pdata, column="Efficacy", baseline="SD", groups_to_compare=["PR", "PD"])
|
94
99
|
"""
|
95
100
|
...
|
96
101
|
|
102
|
+
@_doc_params(common_plot_args=doc_common_plot_args)
|
97
103
|
def plot_volcano(
|
98
104
|
self,
|
99
105
|
data: pd.DataFrame | ad.AnnData,
|
@@ -115,13 +121,14 @@ class MethodBase(ABC):
|
|
115
121
|
figsize: tuple[int, int] = (5, 5),
|
116
122
|
legend_pos: tuple[float, float] = (1.6, 1),
|
117
123
|
point_sizes: tuple[int, int] = (15, 150),
|
118
|
-
save: bool | str | None = None,
|
119
124
|
shapes: list[str] | None = None,
|
120
125
|
shape_order: list[str] | None = None,
|
121
126
|
x_label: str | None = None,
|
122
127
|
y_label: str | None = None,
|
128
|
+
show: bool = True,
|
129
|
+
return_fig: bool = False,
|
123
130
|
**kwargs: int,
|
124
|
-
) -> None:
|
131
|
+
) -> Figure | None:
|
125
132
|
"""Creates a volcano plot from a pandas DataFrame or Anndata.
|
126
133
|
|
127
134
|
Args:
|
@@ -143,12 +150,40 @@ class MethodBase(ABC):
|
|
143
150
|
top_right_frame: Whether to show the top and right frame of the plot.
|
144
151
|
figsize: Size of the figure.
|
145
152
|
legend_pos: Position of the legend as determined by matplotlib.
|
146
|
-
save: Saves the plot if True or to the path provided.
|
147
153
|
shapes: List of matplotlib marker ids.
|
148
154
|
shape_order: Order of categories for shapes.
|
149
155
|
x_label: Label for the x-axis.
|
150
156
|
y_label: Label for the y-axis.
|
157
|
+
{common_plot_args}
|
151
158
|
**kwargs: Additional arguments for seaborn.scatterplot.
|
159
|
+
|
160
|
+
Returns:
|
161
|
+
If `return_fig` is `True`, returns the figure, otherwise `None`.
|
162
|
+
|
163
|
+
Examples:
|
164
|
+
>>> # Example with EdgeR
|
165
|
+
>>> import pertpy as pt
|
166
|
+
>>> adata = pt.dt.zhang_2021()
|
167
|
+
>>> adata.layers["counts"] = adata.X.copy()
|
168
|
+
>>> ps = pt.tl.PseudobulkSpace()
|
169
|
+
>>> pdata = ps.compute(
|
170
|
+
... adata,
|
171
|
+
... target_col="Patient",
|
172
|
+
... groups_col="Cluster",
|
173
|
+
... layer_key="counts",
|
174
|
+
... mode="sum",
|
175
|
+
... min_cells=10,
|
176
|
+
... min_counts=1000,
|
177
|
+
... )
|
178
|
+
>>> edgr = pt.tl.EdgeR(pdata, design="~Efficacy+Treatment")
|
179
|
+
>>> edgr.fit()
|
180
|
+
>>> res_df = edgr.test_contrasts(
|
181
|
+
... edgr.contrast(column="Treatment", baseline="Chemo", group_to_compare="Anti-PD-L1+Chemo")
|
182
|
+
... )
|
183
|
+
>>> edgr.plot_volcano(res_df, log2fc_thresh=0)
|
184
|
+
|
185
|
+
Preview:
|
186
|
+
.. image:: /_static/docstring_previews/de_volcano.png
|
152
187
|
"""
|
153
188
|
if colors is None:
|
154
189
|
colors = ["gray", "#D62728", "#1F77B4"]
|
@@ -243,7 +278,7 @@ class MethodBase(ABC):
|
|
243
278
|
if varm_key is None:
|
244
279
|
raise ValueError("Please pass a .varm key to use for plotting")
|
245
280
|
|
246
|
-
raise NotImplementedError("Anndata not implemented yet")
|
281
|
+
raise NotImplementedError("Anndata not implemented yet") # TODO: Implement this
|
247
282
|
df = data.varm[varm_key].copy()
|
248
283
|
|
249
284
|
df = data.copy(deep=True)
|
@@ -449,26 +484,412 @@ class MethodBase(ABC):
|
|
449
484
|
|
450
485
|
plt.legend(loc=1, bbox_to_anchor=legend_pos, frameon=False)
|
451
486
|
|
452
|
-
|
453
|
-
|
454
|
-
|
455
|
-
|
456
|
-
|
457
|
-
|
458
|
-
|
459
|
-
|
460
|
-
|
461
|
-
|
462
|
-
|
463
|
-
|
487
|
+
if show:
|
488
|
+
plt.show()
|
489
|
+
if return_fig:
|
490
|
+
return plt.gcf()
|
491
|
+
return None
|
492
|
+
|
493
|
+
@_doc_params(common_plot_args=doc_common_plot_args)
def plot_paired(
    self,
    adata: ad.AnnData,
    results_df: pd.DataFrame,
    groupby: str,
    pairedby: str,
    *,
    var_names: Sequence[str] | None = None,
    n_top_vars: int = 15,
    layer: str | None = None,
    pvalue_col: str = "adj_p_value",
    symbol_col: str = "variable",
    n_cols: int = 4,
    panel_size: tuple[int, int] = (5, 5),
    show_legend: bool = True,
    size: int = 10,
    y_label: str = "expression",
    pvalue_template=lambda x: f"p={x:.2e}",
    boxplot_properties=None,
    palette=None,
    show: bool = True,
    return_fig: bool = False,
) -> Figure | None:
    """Creates a pairwise expression plot from a Pandas DataFrame or Anndata.

    Visualizes a panel of paired scatterplots per variable.

    Args:
        adata: AnnData object, can be pseudobulked.
        results_df: DataFrame with results from a differential expression test.
        groupby: .obs column containing the grouping. Must contain exactly two different values.
        pairedby: .obs column containing the pairing (e.g. "patient_id").
            Samples that do not occur in both groups are dropped before plotting.
        var_names: Variables to plot. If None, the first `n_top_vars` entries of `results_df` are used.
        n_top_vars: Number of top variables to plot.
        layer: Layer to use for plotting.
        pvalue_col: Column name of the p values.
        symbol_col: Column name of gene IDs.
        n_cols: Number of columns in the plot.
        panel_size: Size of each panel.
        show_legend: Whether to show the legend.
        size: Size of the points.
        y_label: Label for the y-axis.
        pvalue_template: Template for the p-value string displayed in the title of each panel.
        boxplot_properties: Additional properties for the boxplot, passed to seaborn.boxplot.
        palette: Color palette for the line- and stripplot.
        {common_plot_args}

    Returns:
        If `return_fig` is `True`, returns the figure, otherwise `None`.

    Raises:
        ValueError: If the `groupby` column does not contain exactly two groups.

    Examples:
        >>> # Example with EdgeR
        >>> import pertpy as pt
        >>> adata = pt.dt.zhang_2021()
        >>> adata.layers["counts"] = adata.X.copy()
        >>> ps = pt.tl.PseudobulkSpace()
        >>> pdata = ps.compute(
        ...     adata,
        ...     target_col="Patient",
        ...     groups_col="Cluster",
        ...     layer_key="counts",
        ...     mode="sum",
        ...     min_cells=10,
        ...     min_counts=1000,
        ... )
        >>> edgr = pt.tl.EdgeR(pdata, design="~Efficacy+Treatment")
        >>> edgr.fit()
        >>> res_df = edgr.test_contrasts(
        ...     edgr.contrast(column="Treatment", baseline="Chemo", group_to_compare="Anti-PD-L1+Chemo")
        ... )
        >>> edgr.plot_paired(pdata, results_df=res_df, n_top_vars=8, groupby="Treatment", pairedby="Efficacy")

    Preview:
        .. image:: /_static/docstring_previews/de_paired_expression.png
    """
    if boxplot_properties is None:
        boxplot_properties = {}
    groups = adata.obs[groupby].unique()
    if len(groups) != 2:
        raise ValueError("The number of groups in the group_by column must be exactly 2 to enable paired testing")

    if var_names is None:
        var_names = results_df.head(n_top_vars)[symbol_col].tolist()

    adata = adata[:, var_names]

    # More than one observation per (group, pair) combination: aggregate first
    # so that every pair contributes exactly one point per group.
    if any(adata.obs[[groupby, pairedby]].value_counts() > 1):
        logger.info("Performing pseudobulk for paired samples")
        ps = PseudobulkSpace()
        adata = ps.compute(
            adata, target_col=groupby, groups_col=pairedby, layer_key=layer, mode="sum", min_cells=1, min_counts=1
        )

    X = adata.layers[layer] if layer is not None else adata.X
    try:
        # Sparse matrices need densifying; dense arrays have no `toarray`.
        X = X.toarray()
    except AttributeError:
        pass

    groupby_cols = [pairedby, groupby]
    df = adata.obs.loc[:, groupby_cols].join(pd.DataFrame(X, index=adata.obs_names, columns=var_names))

    # remove unpaired samples
    in_first_group = set(df[df[groupby] == groups[0]][pairedby])
    in_second_group = set(df[df[groupby] == groups[1]][pairedby])
    paired_samples = in_first_group & in_second_group
    df = df[df[pairedby].isin(paired_samples)]
    removed_samples = adata.obs[pairedby].nunique() - len(df[pairedby].unique())
    if removed_samples > 0:
        logger.warning(f"{removed_samples} unpaired samples removed")

    # NOTE(review): assumes `symbol_col` is unique in `results_df`; with several
    # contrasts per gene the p-values would no longer align with `var_names` - confirm.
    pvalues = results_df.set_index(symbol_col).loc[var_names, pvalue_col].values

    # transform data for seaborn
    # `value_vars` keeps the melt restricted to the expression columns, so that
    # "val" stays numeric instead of also collecting the observation names.
    df_melt = df.reset_index(drop=True).melt(
        id_vars=groupby_cols,
        value_vars=list(var_names),
        var_name="var",
        value_name="val",
    )

    n_panels = len(var_names)
    nrows = math.ceil(n_panels / n_cols)
    ncols = min(n_cols, n_panels)

    fig, axes = plt.subplots(
        nrows,
        ncols,
        figsize=(ncols * panel_size[0], nrows * panel_size[1]),
        tight_layout=True,
        squeeze=False,
    )
    axes = axes.flatten()
    # zip_longest pads `var_names` with None for the unused axes of the last row.
    for i, (var, ax) in enumerate(zip_longest(var_names, axes)):
        if var is None:
            ax.set_visible(False)
            continue

        panel_df = df_melt.loc[df_melt["var"] == var]
        sns.boxplot(
            x=groupby,
            data=panel_df,
            y="val",
            ax=ax,
            color="white",
            fliersize=0,
            **boxplot_properties,
        )
        if pairedby is not None:
            sns.lineplot(
                x=groupby,
                data=panel_df,
                y="val",
                ax=ax,
                hue=pairedby,
                legend=False,
                errorbar=None,
                palette=palette,
            )
        # No jitter for paired data so that the points sit on the line ends.
        jitter = 0 if pairedby else True
        sns.stripplot(
            x=groupby,
            data=panel_df,
            y="val",
            ax=ax,
            hue=pairedby,
            jitter=jitter,
            size=size,
            linewidth=1,
            palette=palette,
        )

        ax.set_xlabel("")
        ax.tick_params(
            axis="x",
            labelsize=15,
        )
        ax.legend().set_visible(False)
        ax.set_ylabel(y_label)
        ax.set_title(f"{var}\n{pvalue_template(pvalues[i])}")
    fig.tight_layout()

    if show_legend:
        # A single shared legend below the last populated panel.
        axes[n_panels - 1].legend(
            bbox_to_anchor=(0.5, -0.1), loc="upper center", ncol=adata.obs[pairedby].nunique()
        )

    plt.tight_layout()
    if show:
        plt.show()
    if return_fig:
        # Return the handle created above: after `plt.show()` the current
        # figure may already be closed and `plt.gcf()` would create an empty one.
        return fig
    return None
|
687
|
+
|
688
|
+
@_doc_params(common_plot_args=doc_common_plot_args)
def plot_fold_change(
    self,
    results_df: pd.DataFrame,
    *,
    var_names: Sequence[str] | None = None,
    n_top_vars: int = 15,
    log2fc_col: str = "log_fc",
    symbol_col: str = "variable",
    y_label: str = "Log2 fold change",
    figsize: tuple[int, int] = (10, 5),
    show: bool = True,
    return_fig: bool = False,
    **barplot_kwargs,
) -> Figure | None:
    """Plot the log2 fold changes from the results as a bar chart.

    Args:
        results_df: DataFrame with results from DE analysis.
        var_names: Variables to plot. If None, the top n_top_vars variables based on the log2 fold change are plotted.
        n_top_vars: Number of top variables to plot. The top and bottom n_top_vars variables are plotted, respectively.
            If `results_df` has fewer rows than that, all of them are plotted.
        log2fc_col: Column name of log2 Fold-Change values.
        symbol_col: Column name of gene IDs.
        y_label: Label for the y-axis.
        figsize: Size of the figure.
        {common_plot_args}
        **barplot_kwargs: Additional arguments for seaborn.barplot. These take precedence over the
            defaults (`hue`, `palette`, `legend`) set by this method.

    Returns:
        If `return_fig` is `True`, returns the figure, otherwise `None`.

    Examples:
        >>> # Example with EdgeR
        >>> import pertpy as pt
        >>> adata = pt.dt.zhang_2021()
        >>> adata.layers["counts"] = adata.X.copy()
        >>> ps = pt.tl.PseudobulkSpace()
        >>> pdata = ps.compute(
        ...     adata,
        ...     target_col="Patient",
        ...     groups_col="Cluster",
        ...     layer_key="counts",
        ...     mode="sum",
        ...     min_cells=10,
        ...     min_counts=1000,
        ... )
        >>> edgr = pt.tl.EdgeR(pdata, design="~Efficacy+Treatment")
        >>> edgr.fit()
        >>> res_df = edgr.test_contrasts(
        ...     edgr.contrast(column="Treatment", baseline="Chemo", group_to_compare="Anti-PD-L1+Chemo")
        ... )
        >>> edgr.plot_fold_change(res_df)

    Preview:
        .. image:: /_static/docstring_previews/de_fold_change.png
    """
    if var_names is None:
        # Strongest up- and down-regulated variables. The two selections may
        # overlap (or be shorter than requested) for small result tables, which
        # is harmless because they are only used for the membership test below.
        strongest_up = results_df.sort_values(log2fc_col, ascending=False).head(n_top_vars)[symbol_col].tolist()
        strongest_down = results_df.sort_values(log2fc_col, ascending=True).head(n_top_vars)[symbol_col].tolist()
        var_names = strongest_up + strongest_down

    # Non-inplace sort: the boolean filter returns a derived frame, and sorting
    # that in place triggers pandas' chained-assignment warning.
    df = results_df[results_df[symbol_col].isin(var_names)].sort_values(log2fc_col, ascending=False)

    fig = plt.figure(figsize=figsize)
    # seaborn deprecates `palette` without `hue`; assigning the x variable to
    # `hue` with `legend=False` is the documented equivalent. User-supplied
    # keyword arguments override these defaults instead of colliding with them.
    plot_kwargs = {"hue": symbol_col, "palette": "RdBu", "legend": False, **barplot_kwargs}
    sns.barplot(
        x=symbol_col,
        y=log2fc_col,
        data=df,
        **plot_kwargs,
    )
    plt.xticks(rotation=90)
    plt.xlabel("")
    plt.ylabel(y_label)

    if show:
        plt.show()
    if return_fig:
        # Return the handle created above: after `plt.show()` the current
        # figure may already be closed and `plt.gcf()` would create an empty one.
        return fig
    return None
|
770
|
+
|
771
|
+
@_doc_params(common_plot_args=doc_common_plot_args)
def plot_multicomparison_fc(
    self,
    results_df: pd.DataFrame,
    *,
    n_top_vars: int = 15,
    contrast_col: str = "contrast",
    log2fc_col: str = "log_fc",
    pvalue_col: str = "adj_p_value",
    symbol_col: str = "variable",
    marker_size: int = 100,
    figsize: tuple[int, int] = (10, 2),
    x_label: str = "Contrast",
    y_label: str = "Gene",
    show: bool = True,
    return_fig: bool = False,
    **heatmap_kwargs,
) -> Figure | None:
    """Plot a matrix of log2 fold changes from the results.

    Contrasts are drawn as heatmap rows and variables as heatmap columns. Cells whose p-value is
    below 0.1 are marked with a star whose size reflects the significance level.

    Args:
        results_df: DataFrame with results from DE analysis. It is not modified.
        n_top_vars: Number of top variables to plot per group.
        contrast_col: Column in results_df containing information about the contrast.
        log2fc_col: Column in results_df containing the log2 fold change.
        pvalue_col: Column in results_df containing the p-value. Can be used to switch between adjusted and unadjusted p-values.
        symbol_col: Column in results_df containing the gene symbol.
        marker_size: Size of the biggest marker for significant variables.
        figsize: Size of the figure.
        x_label: Label for the x-axis.
        y_label: Label for the y-axis.
        {common_plot_args}
        **heatmap_kwargs: Additional arguments for seaborn.heatmap. These take precedence over the
            defaults (`cmap`, `center`, `cbar_kws`) set by this method.

    Returns:
        If `return_fig` is `True`, returns the figure, otherwise `None`.

    Examples:
        >>> # Example with EdgeR
        >>> import pertpy as pt
        >>> adata = pt.dt.zhang_2021()
        >>> adata.layers["counts"] = adata.X.copy()
        >>> ps = pt.tl.PseudobulkSpace()
        >>> pdata = ps.compute(
        ...     adata,
        ...     target_col="Patient",
        ...     groups_col="Cluster",
        ...     layer_key="counts",
        ...     mode="sum",
        ...     min_cells=10,
        ...     min_counts=1000,
        ... )
        >>> edgr = pt.tl.EdgeR(pdata, design="~Efficacy+Treatment")
        >>> res_df = edgr.compare_groups(pdata, column="Efficacy", baseline="SD", groups_to_compare=["PR", "PD"])
        >>> edgr.plot_multicomparison_fc(res_df)

    Preview:
        .. image:: /_static/docstring_previews/de_multicomparison_fc.png
    """

    def _get_significance(p_val):
        """Bin a p-value into the significance label used for the star markers."""
        if p_val < 0.001:
            return "< 0.001"
        elif p_val < 0.01:
            return "< 0.01"
        elif p_val < 0.1:
            return "< 0.1"
        else:
            return "n.s."

    # Star size per significance bin; insertion order is also the legend order.
    _size = {"< 0.001": marker_size, "< 0.01": math.floor(marker_size / 2), "< 0.1": math.floor(marker_size / 4)}

    # Work on a copy so the helper columns do not leak into the caller's frame.
    results_df = results_df.copy()
    results_df["abs_log_fc"] = results_df[log2fc_col].abs()
    results_df["significance"] = results_df[pvalue_col].apply(_get_significance)

    groups = results_df[contrast_col].unique().tolist()
    var_names = []
    for group in groups:
        var_names += (
            results_df[results_df[contrast_col] == group]
            .sort_values("abs_log_fc", ascending=False)
            .head(n_top_vars)[symbol_col]
            .tolist()
        )
    # A variable can rank among the top ones of several contrasts; keep only its
    # first occurrence so it is not drawn as several identical heatmap columns.
    var_names = list(dict.fromkeys(var_names))

    results_df = results_df[results_df[symbol_col].isin(var_names)]
    df = results_df.pivot(index=contrast_col, columns=symbol_col, values=log2fc_col)[var_names]

    fig = plt.figure(figsize=figsize)
    # User-supplied keyword arguments override the defaults instead of
    # colliding with them.
    heatmap_options = {
        "cmap": "coolwarm",
        "center": 0,
        "cbar_kws": {"label": "Log2 fold change"},
        **heatmap_kwargs,
    }
    sns.heatmap(df, **heatmap_options)

    # Heatmap cell (row i, column j) is centred at (j + 0.5, i + 0.5). Computing
    # the positions directly avoids depending on tick labels, which seaborn may
    # thin out when there are many columns.
    col_center = {name: pos + 0.5 for pos, name in enumerate(df.columns)}
    row_center = {name: pos + 0.5 for pos, name in enumerate(df.index)}

    significant = results_df[results_df["significance"] != "n.s."]
    plt.scatter(
        x=significant[symbol_col].map(col_center).tolist(),
        y=significant[contrast_col].map(row_center).tolist(),
        s=significant["significance"].map(_size).tolist(),
        marker="*",
        c="white",
    )

    # Empty scatters only serve as legend handles for the star sizes.
    for significance_label, star_size in _size.items():
        plt.scatter([], [], s=star_size, marker="*", c="black", label=significance_label)
    plt.legend(title="Significance", bbox_to_anchor=(1.2, -0.05))

    # NOTE(review): variables are on the x-axis and contrasts on the y-axis, so
    # the default labels ("Contrast" for x, "Gene" for y) look swapped - confirm
    # before changing the defaults, as that would alter the public signature.
    plt.xlabel(x_label)
    plt.ylabel(y_label)

    if show:
        plt.show()
    if return_fig:
        # Return the handle created above: after `plt.show()` the current
        # figure may already be closed and `plt.gcf()` would create an empty one.
        return fig
    return None
|
466
888
|
|
467
889
|
|
468
890
|
class LinearModelBase(MethodBase):
|
469
891
|
def __init__(self, adata, design, *, mask=None, layer=None, **kwargs):
|
470
|
-
"""
|
471
|
-
Initialize the method.
|
892
|
+
"""Initialize the method.
|
472
893
|
|
473
894
|
Args:
|
474
895
|
adata: AnnData object, usually pseudobulked.
|
@@ -480,26 +901,24 @@ class LinearModelBase(MethodBase):
|
|
480
901
|
super().__init__(adata, mask=mask, layer=layer)
|
481
902
|
self._check_counts()
|
482
903
|
|
483
|
-
self.
|
484
|
-
self.variable_to_factors = None
|
485
|
-
|
904
|
+
self.formulaic_contrasts = None
|
486
905
|
if isinstance(design, str):
|
487
|
-
self.
|
488
|
-
self.design =
|
906
|
+
self.formulaic_contrasts = FormulaicContrasts(adata.obs, design)
|
907
|
+
self.design = self.formulaic_contrasts.design_matrix
|
489
908
|
else:
|
490
909
|
self.design = design
|
491
910
|
|
492
911
|
@classmethod
|
493
912
|
def compare_groups(
|
494
913
|
cls,
|
495
|
-
adata,
|
496
|
-
column,
|
497
|
-
baseline,
|
498
|
-
groups_to_compare,
|
914
|
+
adata: ad.AnnData,
|
915
|
+
column: str,
|
916
|
+
baseline: str,
|
917
|
+
groups_to_compare: str | Iterable[str],
|
499
918
|
*,
|
500
|
-
paired_by=None,
|
501
|
-
mask=None,
|
502
|
-
layer=None,
|
919
|
+
paired_by: str | None = None,
|
920
|
+
mask: pd.Series | None = None,
|
921
|
+
layer: str | None = None,
|
503
922
|
fit_kwargs=MappingProxyType({}),
|
504
923
|
test_kwargs=MappingProxyType({}),
|
505
924
|
):
|
@@ -525,17 +944,16 @@ class LinearModelBase(MethodBase):
|
|
525
944
|
@property
|
526
945
|
def variables(self):
|
527
946
|
"""Get the names of the variables used in the model definition."""
|
528
|
-
|
529
|
-
return self.design.model_spec.variables_by_source["data"]
|
530
|
-
except AttributeError:
|
947
|
+
if self.formulaic_contrasts is None:
|
531
948
|
raise ValueError(
|
532
949
|
"Retrieving variables is only possible if the model was initialized using a formula."
|
533
950
|
) from None
|
951
|
+
else:
|
952
|
+
return self.formulaic_contrasts.variables
|
534
953
|
|
535
954
|
@abstractmethod
|
536
955
|
def _check_counts(self):
|
537
|
-
"""
|
538
|
-
Check that counts are valid for the specific method.
|
956
|
+
"""Check that counts are valid for the specific method.
|
539
957
|
|
540
958
|
Raises:
|
541
959
|
ValueError: if the data matrix does not comply with the expectations.
|
@@ -544,8 +962,7 @@ class LinearModelBase(MethodBase):
|
|
544
962
|
|
545
963
|
@abstractmethod
|
546
964
|
def fit(self, **kwargs):
|
547
|
-
"""
|
548
|
-
Fit the model.
|
965
|
+
"""Fit the model.
|
549
966
|
|
550
967
|
Args:
|
551
968
|
**kwargs: Additional arguments for fitting the specific method.
|
@@ -555,9 +972,8 @@ class LinearModelBase(MethodBase):
|
|
555
972
|
@abstractmethod
|
556
973
|
def _test_single_contrast(self, contrast, **kwargs): ...
|
557
974
|
|
558
|
-
def test_contrasts(self, contrasts, **kwargs):
|
559
|
-
"""
|
560
|
-
Perform a comparison as specified in a contrast vector.
|
975
|
+
def test_contrasts(self, contrasts: np.ndarray | Mapping[str | None, np.ndarray], **kwargs):
|
976
|
+
"""Perform a comparison as specified in a contrast vector.
|
561
977
|
|
562
978
|
Args:
|
563
979
|
contrasts: Either a numeric contrast vector, or a dictionary of numeric contrast vectors.
|
@@ -573,25 +989,25 @@ class LinearModelBase(MethodBase):
|
|
573
989
|
results.append(self._test_single_contrast(contrast, **kwargs).assign(contrast=name))
|
574
990
|
|
575
991
|
results_df = pd.concat(results)
|
992
|
+
|
576
993
|
return results_df
|
577
994
|
|
578
995
|
def test_reduced(self, modelB):
|
579
|
-
"""
|
580
|
-
Test against a reduced model.
|
996
|
+
"""Test against a reduced model.
|
581
997
|
|
582
998
|
Args:
|
583
999
|
modelB: the reduced model against which to test.
|
584
1000
|
|
585
1001
|
Example:
|
586
|
-
|
587
|
-
|
588
|
-
|
1002
|
+
>>> import pertpy as pt
|
1003
|
+
>>> modelA = Model().fit()
|
1004
|
+
>>> modelB = Model().fit()
|
1005
|
+
>>> modelA.test_reduced(modelB)
|
589
1006
|
"""
|
590
1007
|
raise NotImplementedError
|
591
1008
|
|
592
1009
|
def cond(self, **kwargs):
|
593
|
-
"""
|
594
|
-
Get a contrast vector representing a specific condition.
|
1010
|
+
"""Get a contrast vector representing a specific condition.
|
595
1011
|
|
596
1012
|
Args:
|
597
1013
|
**kwargs: column/value pairs.
|
@@ -599,52 +1015,14 @@ class LinearModelBase(MethodBase):
|
|
599
1015
|
Returns:
|
600
1016
|
A contrast vector that aligns to the columns of the design matrix.
|
601
1017
|
"""
|
602
|
-
if self.
|
1018
|
+
if self.formulaic_contrasts is None:
|
603
1019
|
raise RuntimeError(
|
604
1020
|
"Building contrasts with `cond` only works if you specified the model using a formulaic formula. Please manually provide a contrast vector."
|
605
1021
|
)
|
606
|
-
|
607
|
-
if not set(cond_dict.keys()).issubset(self.variables):
|
608
|
-
raise ValueError(
|
609
|
-
"You specified a variable that is not part of the model. Available variables: "
|
610
|
-
+ ",".join(self.variables)
|
611
|
-
)
|
612
|
-
for var in self.variables:
|
613
|
-
if var in cond_dict:
|
614
|
-
self._check_category(var, cond_dict[var])
|
615
|
-
else:
|
616
|
-
cond_dict[var] = self._get_default_value(var)
|
617
|
-
df = pd.DataFrame([kwargs])
|
618
|
-
return self.design.model_spec.get_model_matrix(df).iloc[0]
|
619
|
-
|
620
|
-
def _get_factor_metadata_for_variable(self, var):
|
621
|
-
factors = self.variable_to_factors[var]
|
622
|
-
return list(chain.from_iterable(self.factor_storage[f] for f in factors))
|
623
|
-
|
624
|
-
def _get_default_value(self, var):
|
625
|
-
factor_metadata = self._get_factor_metadata_for_variable(var)
|
626
|
-
if resolve_ambiguous(factor_metadata, "kind") == Factor.Kind.CATEGORICAL:
|
627
|
-
try:
|
628
|
-
tmp_base = resolve_ambiguous(factor_metadata, "base")
|
629
|
-
except AmbiguousAttributeError as e:
|
630
|
-
raise ValueError(
|
631
|
-
f"Could not automatically resolve base category for variable {var}. Please specify it explicity in `model.cond`."
|
632
|
-
) from e
|
633
|
-
return tmp_base if tmp_base is not None else "\0"
|
634
|
-
else:
|
635
|
-
return 0
|
636
|
-
|
637
|
-
def _check_category(self, var, value):
|
638
|
-
factor_metadata = self._get_factor_metadata_for_variable(var)
|
639
|
-
tmp_categories = resolve_ambiguous(factor_metadata, "categories")
|
640
|
-
if resolve_ambiguous(factor_metadata, "kind") == Factor.Kind.CATEGORICAL and value not in tmp_categories:
|
641
|
-
raise ValueError(
|
642
|
-
f"You specified a non-existant category for {var}. Possible categories: {', '.join(tmp_categories)}"
|
643
|
-
)
|
1022
|
+
return self.formulaic_contrasts.cond(**kwargs)
|
644
1023
|
|
645
|
-
def contrast(self,
|
646
|
-
"""
|
647
|
-
Build a simple contrast for pairwise comparisons.
|
1024
|
+
def contrast(self, *args, **kwargs):
|
1025
|
+
"""Build a simple contrast for pairwise comparisons.
|
648
1026
|
|
649
1027
|
Args:
|
650
1028
|
column: column in adata.obs to test on.
|
@@ -654,4 +1032,8 @@ class LinearModelBase(MethodBase):
|
|
654
1032
|
Returns:
|
655
1033
|
Numeric contrast vector.
|
656
1034
|
"""
|
657
|
-
|
1035
|
+
if self.formulaic_contrasts is None:
|
1036
|
+
raise RuntimeError(
|
1037
|
+
"Building contrasts with `cond` only works if you specified the model using a formulaic formula. Please manually provide a contrast vector."
|
1038
|
+
)
|
1039
|
+
return self.formulaic_contrasts.contrast(*args, **kwargs)
|