pertpy 0.9.3__py3-none-any.whl → 0.9.5__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- pertpy/__init__.py +1 -1
- pertpy/_doc.py +20 -0
- pertpy/data/_dataloader.py +4 -4
- pertpy/data/_datasets.py +3 -3
- pertpy/metadata/_cell_line.py +19 -7
- pertpy/metadata/_compound.py +3 -4
- pertpy/metadata/_metadata.py +1 -1
- pertpy/preprocessing/_guide_rna.py +19 -6
- pertpy/tools/__init__.py +12 -15
- pertpy/tools/_augur.py +36 -46
- pertpy/tools/_cinemaot.py +24 -18
- pertpy/tools/_coda/_base_coda.py +87 -106
- pertpy/tools/_dialogue.py +17 -21
- pertpy/tools/_differential_gene_expression/__init__.py +1 -2
- pertpy/tools/_differential_gene_expression/_base.py +495 -113
- pertpy/tools/_differential_gene_expression/_edger.py +30 -21
- pertpy/tools/_differential_gene_expression/_pydeseq2.py +15 -29
- pertpy/tools/_differential_gene_expression/_statsmodels.py +0 -11
- pertpy/tools/_distances/_distances.py +15 -8
- pertpy/tools/_enrichment.py +18 -8
- pertpy/tools/_milo.py +58 -46
- pertpy/tools/_mixscape.py +111 -100
- pertpy/tools/_perturbation_space/_perturbation_space.py +40 -31
- pertpy/tools/_perturbation_space/_simple.py +50 -0
- pertpy/tools/_scgen/_scgen.py +35 -25
- {pertpy-0.9.3.dist-info → pertpy-0.9.5.dist-info}/METADATA +5 -4
- {pertpy-0.9.3.dist-info → pertpy-0.9.5.dist-info}/RECORD +29 -29
- {pertpy-0.9.3.dist-info → pertpy-0.9.5.dist-info}/WHEEL +1 -1
- pertpy/tools/_differential_gene_expression/_formulaic.py +0 -189
- {pertpy-0.9.3.dist-info → pertpy-0.9.5.dist-info}/licenses/LICENSE +0 -0
@@ -1,7 +1,7 @@
|
|
1
|
-
import
|
1
|
+
import math
|
2
2
|
from abc import ABC, abstractmethod
|
3
|
-
from
|
4
|
-
from itertools import
|
3
|
+
from collections.abc import Iterable, Mapping, Sequence
|
4
|
+
from itertools import zip_longest
|
5
5
|
from types import MappingProxyType
|
6
6
|
|
7
7
|
import adjustText
|
@@ -11,27 +11,14 @@ import matplotlib.pyplot as plt
|
|
11
11
|
import numpy as np
|
12
12
|
import pandas as pd
|
13
13
|
import seaborn as sns
|
14
|
+
from formulaic_contrasts import FormulaicContrasts
|
15
|
+
from lamin_utils import logger
|
16
|
+
from matplotlib.pyplot import Figure
|
14
17
|
from matplotlib.ticker import MaxNLocator
|
15
18
|
|
19
|
+
from pertpy._doc import _doc_params, doc_common_plot_args
|
20
|
+
from pertpy.tools import PseudobulkSpace
|
16
21
|
from pertpy.tools._differential_gene_expression._checks import check_is_numeric_matrix
|
17
|
-
from pertpy.tools._differential_gene_expression._formulaic import (
|
18
|
-
AmbiguousAttributeError,
|
19
|
-
Factor,
|
20
|
-
get_factor_storage_and_materializer,
|
21
|
-
resolve_ambiguous,
|
22
|
-
)
|
23
|
-
|
24
|
-
|
25
|
-
@dataclass
|
26
|
-
class Contrast:
|
27
|
-
"""Simple contrast for comparison between groups"""
|
28
|
-
|
29
|
-
column: str
|
30
|
-
baseline: str
|
31
|
-
group_to_compare: str
|
32
|
-
|
33
|
-
|
34
|
-
ContrastType = Contrast | tuple[str, str, str]
|
35
22
|
|
36
23
|
|
37
24
|
class MethodBase(ABC):
|
@@ -58,7 +45,7 @@ class MethodBase(ABC):
|
|
58
45
|
if self.layer is None:
|
59
46
|
return self.adata.X
|
60
47
|
else:
|
61
|
-
return self.adata.
|
48
|
+
return self.adata.layers[self.layer]
|
62
49
|
|
63
50
|
@classmethod
|
64
51
|
@abstractmethod
|
@@ -91,9 +78,28 @@ class MethodBase(ABC):
|
|
91
78
|
|
92
79
|
Returns:
|
93
80
|
Pandas dataframe with results ordered by significance. If multiple comparisons were performed this is indicated in an additional column.
|
81
|
+
|
82
|
+
Examples:
|
83
|
+
>>> # Example with EdgeR
|
84
|
+
>>> import pertpy as pt
|
85
|
+
>>> adata = pt.dt.zhang_2021()
|
86
|
+
>>> adata.layers["counts"] = adata.X.copy()
|
87
|
+
>>> ps = pt.tl.PseudobulkSpace()
|
88
|
+
>>> pdata = ps.compute(
|
89
|
+
... adata,
|
90
|
+
... target_col="Patient",
|
91
|
+
... groups_col="Cluster",
|
92
|
+
... layer_key="counts",
|
93
|
+
... mode="sum",
|
94
|
+
... min_cells=10,
|
95
|
+
... min_counts=1000,
|
96
|
+
... )
|
97
|
+
>>> edgr = pt.tl.EdgeR(pdata, design="~Efficacy+Treatment")
|
98
|
+
>>> res_df = edgr.compare_groups(pdata, column="Efficacy", baseline="SD", groups_to_compare=["PR", "PD"])
|
94
99
|
"""
|
95
100
|
...
|
96
101
|
|
102
|
+
@_doc_params(common_plot_args=doc_common_plot_args)
|
97
103
|
def plot_volcano(
|
98
104
|
self,
|
99
105
|
data: pd.DataFrame | ad.AnnData,
|
@@ -115,13 +121,14 @@ class MethodBase(ABC):
|
|
115
121
|
figsize: tuple[int, int] = (5, 5),
|
116
122
|
legend_pos: tuple[float, float] = (1.6, 1),
|
117
123
|
point_sizes: tuple[int, int] = (15, 150),
|
118
|
-
save: bool | str | None = None,
|
119
124
|
shapes: list[str] | None = None,
|
120
125
|
shape_order: list[str] | None = None,
|
121
126
|
x_label: str | None = None,
|
122
127
|
y_label: str | None = None,
|
128
|
+
show: bool = True,
|
129
|
+
return_fig: bool = False,
|
123
130
|
**kwargs: int,
|
124
|
-
) -> None:
|
131
|
+
) -> Figure | None:
|
125
132
|
"""Creates a volcano plot from a pandas DataFrame or Anndata.
|
126
133
|
|
127
134
|
Args:
|
@@ -143,12 +150,40 @@ class MethodBase(ABC):
|
|
143
150
|
top_right_frame: Whether to show the top and right frame of the plot.
|
144
151
|
figsize: Size of the figure.
|
145
152
|
legend_pos: Position of the legend as determined by matplotlib.
|
146
|
-
save: Saves the plot if True or to the path provided.
|
147
153
|
shapes: List of matplotlib marker ids.
|
148
154
|
shape_order: Order of categories for shapes.
|
149
155
|
x_label: Label for the x-axis.
|
150
156
|
y_label: Label for the y-axis.
|
157
|
+
{common_plot_args}
|
151
158
|
**kwargs: Additional arguments for seaborn.scatterplot.
|
159
|
+
|
160
|
+
Returns:
|
161
|
+
If `return_fig` is `True`, returns the figure, otherwise `None`.
|
162
|
+
|
163
|
+
Examples:
|
164
|
+
>>> # Example with EdgeR
|
165
|
+
>>> import pertpy as pt
|
166
|
+
>>> adata = pt.dt.zhang_2021()
|
167
|
+
>>> adata.layers["counts"] = adata.X.copy()
|
168
|
+
>>> ps = pt.tl.PseudobulkSpace()
|
169
|
+
>>> pdata = ps.compute(
|
170
|
+
... adata,
|
171
|
+
... target_col="Patient",
|
172
|
+
... groups_col="Cluster",
|
173
|
+
... layer_key="counts",
|
174
|
+
... mode="sum",
|
175
|
+
... min_cells=10,
|
176
|
+
... min_counts=1000,
|
177
|
+
... )
|
178
|
+
>>> edgr = pt.tl.EdgeR(pdata, design="~Efficacy+Treatment")
|
179
|
+
>>> edgr.fit()
|
180
|
+
>>> res_df = edgr.test_contrasts(
|
181
|
+
... edgr.contrast(column="Treatment", baseline="Chemo", group_to_compare="Anti-PD-L1+Chemo")
|
182
|
+
... )
|
183
|
+
>>> edgr.plot_volcano(res_df, log2fc_thresh=0)
|
184
|
+
|
185
|
+
Preview:
|
186
|
+
.. image:: /_static/docstring_previews/de_volcano.png
|
152
187
|
"""
|
153
188
|
if colors is None:
|
154
189
|
colors = ["gray", "#D62728", "#1F77B4"]
|
@@ -243,7 +278,7 @@ class MethodBase(ABC):
|
|
243
278
|
if varm_key is None:
|
244
279
|
raise ValueError("Please pass a .varm key to use for plotting")
|
245
280
|
|
246
|
-
raise NotImplementedError("Anndata not implemented yet")
|
281
|
+
raise NotImplementedError("Anndata not implemented yet") # TODO: Implement this
|
247
282
|
df = data.varm[varm_key].copy()
|
248
283
|
|
249
284
|
df = data.copy(deep=True)
|
@@ -449,26 +484,412 @@ class MethodBase(ABC):
|
|
449
484
|
|
450
485
|
plt.legend(loc=1, bbox_to_anchor=legend_pos, frameon=False)
|
451
486
|
|
452
|
-
|
453
|
-
|
454
|
-
|
455
|
-
|
456
|
-
|
457
|
-
|
458
|
-
|
459
|
-
|
460
|
-
|
461
|
-
|
462
|
-
|
463
|
-
|
487
|
+
if show:
|
488
|
+
plt.show()
|
489
|
+
if return_fig:
|
490
|
+
return plt.gcf()
|
491
|
+
return None
|
492
|
+
|
493
|
+
@_doc_params(common_plot_args=doc_common_plot_args)
|
494
|
+
def plot_paired(
|
495
|
+
self,
|
496
|
+
adata: ad.AnnData,
|
497
|
+
results_df: pd.DataFrame,
|
498
|
+
groupby: str,
|
499
|
+
pairedby: str,
|
500
|
+
*,
|
501
|
+
var_names: Sequence[str] = None,
|
502
|
+
n_top_vars: int = 15,
|
503
|
+
layer: str = None,
|
504
|
+
pvalue_col: str = "adj_p_value",
|
505
|
+
symbol_col: str = "variable",
|
506
|
+
n_cols: int = 4,
|
507
|
+
panel_size: tuple[int, int] = (5, 5),
|
508
|
+
show_legend: bool = True,
|
509
|
+
size: int = 10,
|
510
|
+
y_label: str = "expression",
|
511
|
+
pvalue_template=lambda x: f"p={x:.2e}",
|
512
|
+
boxplot_properties=None,
|
513
|
+
palette=None,
|
514
|
+
show: bool = True,
|
515
|
+
return_fig: bool = False,
|
516
|
+
) -> Figure | None:
|
517
|
+
"""Creates a pairwise expression plot from a Pandas DataFrame or Anndata.
|
518
|
+
|
519
|
+
Visualizes a panel of paired scatterplots per variable.
|
520
|
+
|
521
|
+
Args:
|
522
|
+
adata: AnnData object, can be pseudobulked.
|
523
|
+
results_df: DataFrame with results from a differential expression test.
|
524
|
+
groupby: .obs column containing the grouping. Must contain exactly two different values.
|
525
|
+
pairedby: .obs column containing the pairing (e.g. "patient_id"). If None, an independent t-test is performed.
|
526
|
+
var_names: Variables to plot.
|
527
|
+
n_top_vars: Number of top variables to plot.
|
528
|
+
layer: Layer to use for plotting.
|
529
|
+
pvalue_col: Column name of the p values.
|
530
|
+
symbol_col: Column name of gene IDs.
|
531
|
+
n_cols: Number of columns in the plot.
|
532
|
+
panel_size: Size of each panel.
|
533
|
+
show_legend: Whether to show the legend.
|
534
|
+
size: Size of the points.
|
535
|
+
y_label: Label for the y-axis.
|
536
|
+
pvalue_template: Template for the p-value string displayed in the title of each panel.
|
537
|
+
boxplot_properties: Additional properties for the boxplot, passed to seaborn.boxplot.
|
538
|
+
palette: Color palette for the line- and stripplot.
|
539
|
+
{common_plot_args}
|
540
|
+
|
541
|
+
Returns:
|
542
|
+
If `return_fig` is `True`, returns the figure, otherwise `None`.
|
543
|
+
|
544
|
+
Examples:
|
545
|
+
>>> # Example with EdgeR
|
546
|
+
>>> import pertpy as pt
|
547
|
+
>>> adata = pt.dt.zhang_2021()
|
548
|
+
>>> adata.layers["counts"] = adata.X.copy()
|
549
|
+
>>> ps = pt.tl.PseudobulkSpace()
|
550
|
+
>>> pdata = ps.compute(
|
551
|
+
... adata,
|
552
|
+
... target_col="Patient",
|
553
|
+
... groups_col="Cluster",
|
554
|
+
... layer_key="counts",
|
555
|
+
... mode="sum",
|
556
|
+
... min_cells=10,
|
557
|
+
... min_counts=1000,
|
558
|
+
... )
|
559
|
+
>>> edgr = pt.tl.EdgeR(pdata, design="~Efficacy+Treatment")
|
560
|
+
>>> edgr.fit()
|
561
|
+
>>> res_df = edgr.test_contrasts(
|
562
|
+
... edgr.contrast(column="Treatment", baseline="Chemo", group_to_compare="Anti-PD-L1+Chemo")
|
563
|
+
... )
|
564
|
+
>>> edgr.plot_paired(pdata, results_df=res_df, n_top_vars=8, groupby="Treatment", pairedby="Efficacy")
|
565
|
+
|
566
|
+
Preview:
|
567
|
+
.. image:: /_static/docstring_previews/de_paired_expression.png
|
568
|
+
"""
|
569
|
+
if boxplot_properties is None:
|
570
|
+
boxplot_properties = {}
|
571
|
+
groups = adata.obs[groupby].unique()
|
572
|
+
if len(groups) != 2:
|
573
|
+
raise ValueError("The number of groups in the group_by column must be exactly 2 to enable paired testing")
|
574
|
+
|
575
|
+
if var_names is None:
|
576
|
+
var_names = results_df.head(n_top_vars)[symbol_col].tolist()
|
577
|
+
|
578
|
+
adata = adata[:, var_names]
|
579
|
+
|
580
|
+
if any(adata.obs[[groupby, pairedby]].value_counts() > 1):
|
581
|
+
logger.info("Performing pseudobulk for paired samples")
|
582
|
+
ps = PseudobulkSpace()
|
583
|
+
adata = ps.compute(
|
584
|
+
adata, target_col=groupby, groups_col=pairedby, layer_key=layer, mode="sum", min_cells=1, min_counts=1
|
585
|
+
)
|
586
|
+
|
587
|
+
if layer is not None:
|
588
|
+
X = adata.layers[layer]
|
589
|
+
else:
|
590
|
+
X = adata.X
|
591
|
+
try:
|
592
|
+
X = X.toarray()
|
593
|
+
except AttributeError:
|
594
|
+
pass
|
595
|
+
|
596
|
+
groupby_cols = [pairedby, groupby]
|
597
|
+
df = adata.obs.loc[:, groupby_cols].join(pd.DataFrame(X, index=adata.obs_names, columns=var_names))
|
598
|
+
|
599
|
+
# remove unpaired samples
|
600
|
+
paired_samples = set(df[df[groupby] == groups[0]][pairedby]) & set(df[df[groupby] == groups[1]][pairedby])
|
601
|
+
df = df[df[pairedby].isin(paired_samples)]
|
602
|
+
removed_samples = adata.obs[pairedby].nunique() - len(df[pairedby].unique())
|
603
|
+
if removed_samples > 0:
|
604
|
+
logger.warning(f"{removed_samples} unpaired samples removed")
|
605
|
+
|
606
|
+
pvalues = results_df.set_index(symbol_col).loc[var_names, pvalue_col].values
|
607
|
+
df.reset_index(drop=False, inplace=True)
|
608
|
+
|
609
|
+
# transform data for seaborn
|
610
|
+
df_melt = df.melt(
|
611
|
+
id_vars=groupby_cols,
|
612
|
+
var_name="var",
|
613
|
+
value_name="val",
|
614
|
+
)
|
615
|
+
|
616
|
+
n_panels = len(var_names)
|
617
|
+
nrows = math.ceil(n_panels / n_cols)
|
618
|
+
ncols = min(n_cols, n_panels)
|
619
|
+
|
620
|
+
fig, axes = plt.subplots(
|
621
|
+
nrows,
|
622
|
+
ncols,
|
623
|
+
figsize=(ncols * panel_size[0], nrows * panel_size[1]),
|
624
|
+
tight_layout=True,
|
625
|
+
squeeze=False,
|
626
|
+
)
|
627
|
+
axes = axes.flatten()
|
628
|
+
for i, (var, ax) in enumerate(zip_longest(var_names, axes)):
|
629
|
+
if var is not None:
|
630
|
+
sns.boxplot(
|
631
|
+
x=groupby,
|
632
|
+
data=df_melt.loc[df_melt["var"] == var],
|
633
|
+
y="val",
|
634
|
+
ax=ax,
|
635
|
+
color="white",
|
636
|
+
fliersize=0,
|
637
|
+
**boxplot_properties,
|
638
|
+
)
|
639
|
+
if pairedby is not None:
|
640
|
+
sns.lineplot(
|
641
|
+
x=groupby,
|
642
|
+
data=df_melt.loc[df_melt["var"] == var],
|
643
|
+
y="val",
|
644
|
+
ax=ax,
|
645
|
+
hue=pairedby,
|
646
|
+
legend=False,
|
647
|
+
errorbar=None,
|
648
|
+
palette=palette,
|
649
|
+
)
|
650
|
+
jitter = 0 if pairedby else True
|
651
|
+
sns.stripplot(
|
652
|
+
x=groupby,
|
653
|
+
data=df_melt.loc[df_melt["var"] == var],
|
654
|
+
y="val",
|
655
|
+
ax=ax,
|
656
|
+
hue=pairedby,
|
657
|
+
jitter=jitter,
|
658
|
+
size=size,
|
659
|
+
linewidth=1,
|
660
|
+
palette=palette,
|
661
|
+
)
|
662
|
+
|
663
|
+
ax.set_xlabel("")
|
664
|
+
ax.tick_params(
|
665
|
+
axis="x",
|
666
|
+
labelsize=15,
|
667
|
+
)
|
668
|
+
ax.legend().set_visible(False)
|
669
|
+
ax.set_ylabel(y_label)
|
670
|
+
ax.set_title(f"{var}\n{pvalue_template(pvalues[i])}")
|
671
|
+
else:
|
672
|
+
ax.set_visible(False)
|
673
|
+
fig.tight_layout()
|
674
|
+
|
675
|
+
if show_legend is True:
|
676
|
+
axes[n_panels - 1].legend().set_visible(True)
|
677
|
+
axes[n_panels - 1].legend(
|
678
|
+
bbox_to_anchor=(0.5, -0.1), loc="upper center", ncol=adata.obs[pairedby].nunique()
|
679
|
+
)
|
680
|
+
|
681
|
+
plt.tight_layout()
|
682
|
+
if show:
|
683
|
+
plt.show()
|
684
|
+
if return_fig:
|
685
|
+
return plt.gcf()
|
686
|
+
return None
|
687
|
+
|
688
|
+
@_doc_params(common_plot_args=doc_common_plot_args)
|
689
|
+
def plot_fold_change(
|
690
|
+
self,
|
691
|
+
results_df: pd.DataFrame,
|
692
|
+
*,
|
693
|
+
var_names: Sequence[str] = None,
|
694
|
+
n_top_vars: int = 15,
|
695
|
+
log2fc_col: str = "log_fc",
|
696
|
+
symbol_col: str = "variable",
|
697
|
+
y_label: str = "Log2 fold change",
|
698
|
+
figsize: tuple[int, int] = (10, 5),
|
699
|
+
show: bool = True,
|
700
|
+
return_fig: bool = False,
|
701
|
+
**barplot_kwargs,
|
702
|
+
) -> Figure | None:
|
703
|
+
"""Plot a metric from the results as a bar chart, optionally with additional information about paired samples in a scatter plot.
|
704
|
+
|
705
|
+
Args:
|
706
|
+
results_df: DataFrame with results from DE analysis.
|
707
|
+
var_names: Variables to plot. If None, the top n_top_vars variables based on the log2 fold change are plotted.
|
708
|
+
n_top_vars: Number of top variables to plot. The top and bottom n_top_vars variables are plotted, respectively.
|
709
|
+
log2fc_col: Column name of log2 Fold-Change values.
|
710
|
+
symbol_col: Column name of gene IDs.
|
711
|
+
y_label: Label for the y-axis.
|
712
|
+
figsize: Size of the figure.
|
713
|
+
{common_plot_args}
|
714
|
+
**barplot_kwargs: Additional arguments for seaborn.barplot.
|
715
|
+
|
716
|
+
Returns:
|
717
|
+
If `return_fig` is `True`, returns the figure, otherwise `None`.
|
718
|
+
|
719
|
+
Examples:
|
720
|
+
>>> # Example with EdgeR
|
721
|
+
>>> import pertpy as pt
|
722
|
+
>>> adata = pt.dt.zhang_2021()
|
723
|
+
>>> adata.layers["counts"] = adata.X.copy()
|
724
|
+
>>> ps = pt.tl.PseudobulkSpace()
|
725
|
+
>>> pdata = ps.compute(
|
726
|
+
... adata,
|
727
|
+
... target_col="Patient",
|
728
|
+
... groups_col="Cluster",
|
729
|
+
... layer_key="counts",
|
730
|
+
... mode="sum",
|
731
|
+
... min_cells=10,
|
732
|
+
... min_counts=1000,
|
733
|
+
... )
|
734
|
+
>>> edgr = pt.tl.EdgeR(pdata, design="~Efficacy+Treatment")
|
735
|
+
>>> edgr.fit()
|
736
|
+
>>> res_df = edgr.test_contrasts(
|
737
|
+
... edgr.contrast(column="Treatment", baseline="Chemo", group_to_compare="Anti-PD-L1+Chemo")
|
738
|
+
... )
|
739
|
+
>>> edgr.plot_fold_change(res_df)
|
740
|
+
|
741
|
+
Preview:
|
742
|
+
.. image:: /_static/docstring_previews/de_fold_change.png
|
743
|
+
"""
|
744
|
+
if var_names is None:
|
745
|
+
var_names = results_df.sort_values(log2fc_col, ascending=False).head(n_top_vars)[symbol_col].tolist()
|
746
|
+
var_names += results_df.sort_values(log2fc_col, ascending=True).head(n_top_vars)[symbol_col].tolist()
|
747
|
+
assert len(var_names) == 2 * n_top_vars
|
748
|
+
|
749
|
+
df = results_df[results_df[symbol_col].isin(var_names)]
|
750
|
+
df.sort_values(log2fc_col, ascending=False, inplace=True)
|
751
|
+
|
752
|
+
plt.figure(figsize=figsize)
|
753
|
+
sns.barplot(
|
754
|
+
x=symbol_col,
|
755
|
+
y=log2fc_col,
|
756
|
+
data=df,
|
757
|
+
palette="RdBu",
|
758
|
+
legend=False,
|
759
|
+
**barplot_kwargs,
|
760
|
+
)
|
761
|
+
plt.xticks(rotation=90)
|
762
|
+
plt.xlabel("")
|
763
|
+
plt.ylabel(y_label)
|
764
|
+
|
765
|
+
if show:
|
766
|
+
plt.show()
|
767
|
+
if return_fig:
|
768
|
+
return plt.gcf()
|
769
|
+
return None
|
770
|
+
|
771
|
+
@_doc_params(common_plot_args=doc_common_plot_args)
|
772
|
+
def plot_multicomparison_fc(
|
773
|
+
self,
|
774
|
+
results_df: pd.DataFrame,
|
775
|
+
*,
|
776
|
+
n_top_vars=15,
|
777
|
+
contrast_col: str = "contrast",
|
778
|
+
log2fc_col: str = "log_fc",
|
779
|
+
pvalue_col: str = "adj_p_value",
|
780
|
+
symbol_col: str = "variable",
|
781
|
+
marker_size: int = 100,
|
782
|
+
figsize: tuple[int, int] = (10, 2),
|
783
|
+
x_label: str = "Contrast",
|
784
|
+
y_label: str = "Gene",
|
785
|
+
show: bool = True,
|
786
|
+
return_fig: bool = False,
|
787
|
+
**heatmap_kwargs,
|
788
|
+
) -> Figure | None:
|
789
|
+
"""Plot a matrix of log2 fold changes from the results.
|
790
|
+
|
791
|
+
Args:
|
792
|
+
results_df: DataFrame with results from DE analysis.
|
793
|
+
n_top_vars: Number of top variables to plot per group.
|
794
|
+
contrast_col: Column in results_df containing information about the contrast.
|
795
|
+
log2fc_col: Column in results_df containing the log2 fold change.
|
796
|
+
pvalue_col: Column in results_df containing the p-value. Can be used to switch between adjusted and unadjusted p-values.
|
797
|
+
symbol_col: Column in results_df containing the gene symbol.
|
798
|
+
marker_size: Size of the biggest marker for significant variables.
|
799
|
+
figsize: Size of the figure.
|
800
|
+
x_label: Label for the x-axis.
|
801
|
+
y_label: Label for the y-axis.
|
802
|
+
{common_plot_args}
|
803
|
+
**heatmap_kwargs: Additional arguments for seaborn.heatmap.
|
804
|
+
|
805
|
+
Returns:
|
806
|
+
If `return_fig` is `True`, returns the figure, otherwise `None`.
|
807
|
+
|
808
|
+
Examples:
|
809
|
+
>>> # Example with EdgeR
|
810
|
+
>>> import pertpy as pt
|
811
|
+
>>> adata = pt.dt.zhang_2021()
|
812
|
+
>>> adata.layers["counts"] = adata.X.copy()
|
813
|
+
>>> ps = pt.tl.PseudobulkSpace()
|
814
|
+
>>> pdata = ps.compute(
|
815
|
+
... adata,
|
816
|
+
... target_col="Patient",
|
817
|
+
... groups_col="Cluster",
|
818
|
+
... layer_key="counts",
|
819
|
+
... mode="sum",
|
820
|
+
... min_cells=10,
|
821
|
+
... min_counts=1000,
|
822
|
+
... )
|
823
|
+
>>> edgr = pt.tl.EdgeR(pdata, design="~Efficacy+Treatment")
|
824
|
+
>>> res_df = edgr.compare_groups(pdata, column="Efficacy", baseline="SD", groups_to_compare=["PR", "PD"])
|
825
|
+
>>> edgr.plot_multicomparison_fc(res_df)
|
826
|
+
|
827
|
+
Preview:
|
828
|
+
.. image:: /_static/docstring_previews/de_multicomparison_fc.png
|
829
|
+
"""
|
830
|
+
groups = results_df[contrast_col].unique().tolist()
|
831
|
+
|
832
|
+
results_df["abs_log_fc"] = results_df[log2fc_col].abs()
|
833
|
+
|
834
|
+
def _get_significance(p_val):
|
835
|
+
if p_val < 0.001:
|
836
|
+
return "< 0.001"
|
837
|
+
elif p_val < 0.01:
|
838
|
+
return "< 0.01"
|
839
|
+
elif p_val < 0.1:
|
840
|
+
return "< 0.1"
|
841
|
+
else:
|
842
|
+
return "n.s."
|
843
|
+
|
844
|
+
results_df["significance"] = results_df[pvalue_col].apply(_get_significance)
|
845
|
+
|
846
|
+
var_names = []
|
847
|
+
for group in groups:
|
848
|
+
var_names += (
|
849
|
+
results_df[results_df[contrast_col] == group]
|
850
|
+
.sort_values("abs_log_fc", ascending=False)
|
851
|
+
.head(n_top_vars)[symbol_col]
|
852
|
+
.tolist()
|
853
|
+
)
|
854
|
+
|
855
|
+
results_df = results_df[results_df[symbol_col].isin(var_names)]
|
856
|
+
df = results_df.pivot(index=contrast_col, columns=symbol_col, values=log2fc_col)[var_names]
|
464
857
|
|
465
|
-
plt.
|
858
|
+
plt.figure(figsize=figsize)
|
859
|
+
sns.heatmap(df, **heatmap_kwargs, cmap="coolwarm", center=0, cbar_kws={"label": "Log2 fold change"})
|
860
|
+
|
861
|
+
_size = {"< 0.001": marker_size, "< 0.01": math.floor(marker_size / 2), "< 0.1": math.floor(marker_size / 4)}
|
862
|
+
x_locs, x_labels = plt.xticks()[0], [label.get_text() for label in plt.xticks()[1]]
|
863
|
+
y_locs, y_labels = plt.yticks()[0], [label.get_text() for label in plt.yticks()[1]]
|
864
|
+
|
865
|
+
for _i, row in results_df.iterrows():
|
866
|
+
if row["significance"] != "n.s.":
|
867
|
+
plt.scatter(
|
868
|
+
x=x_locs[x_labels.index(row[symbol_col])],
|
869
|
+
y=y_locs[y_labels.index(row[contrast_col])],
|
870
|
+
s=_size[row["significance"]],
|
871
|
+
marker="*",
|
872
|
+
c="white",
|
873
|
+
)
|
874
|
+
|
875
|
+
plt.scatter([], [], s=marker_size, marker="*", c="black", label="< 0.001")
|
876
|
+
plt.scatter([], [], s=math.floor(marker_size / 2), marker="*", c="black", label="< 0.01")
|
877
|
+
plt.scatter([], [], s=math.floor(marker_size / 4), marker="*", c="black", label="< 0.1")
|
878
|
+
plt.legend(title="Significance", bbox_to_anchor=(1.2, -0.05))
|
879
|
+
|
880
|
+
plt.xlabel(x_label)
|
881
|
+
plt.ylabel(y_label)
|
882
|
+
|
883
|
+
if show:
|
884
|
+
plt.show()
|
885
|
+
if return_fig:
|
886
|
+
return plt.gcf()
|
887
|
+
return None
|
466
888
|
|
467
889
|
|
468
890
|
class LinearModelBase(MethodBase):
|
469
891
|
def __init__(self, adata, design, *, mask=None, layer=None, **kwargs):
|
470
|
-
"""
|
471
|
-
Initialize the method.
|
892
|
+
"""Initialize the method.
|
472
893
|
|
473
894
|
Args:
|
474
895
|
adata: AnnData object, usually pseudobulked.
|
@@ -480,26 +901,24 @@ class LinearModelBase(MethodBase):
|
|
480
901
|
super().__init__(adata, mask=mask, layer=layer)
|
481
902
|
self._check_counts()
|
482
903
|
|
483
|
-
self.
|
484
|
-
self.variable_to_factors = None
|
485
|
-
|
904
|
+
self.formulaic_contrasts = None
|
486
905
|
if isinstance(design, str):
|
487
|
-
self.
|
488
|
-
self.design =
|
906
|
+
self.formulaic_contrasts = FormulaicContrasts(adata.obs, design)
|
907
|
+
self.design = self.formulaic_contrasts.design_matrix
|
489
908
|
else:
|
490
909
|
self.design = design
|
491
910
|
|
492
911
|
@classmethod
|
493
912
|
def compare_groups(
|
494
913
|
cls,
|
495
|
-
adata,
|
496
|
-
column,
|
497
|
-
baseline,
|
498
|
-
groups_to_compare,
|
914
|
+
adata: ad.AnnData,
|
915
|
+
column: str,
|
916
|
+
baseline: str,
|
917
|
+
groups_to_compare: str | Iterable[str],
|
499
918
|
*,
|
500
|
-
paired_by=None,
|
501
|
-
mask=None,
|
502
|
-
layer=None,
|
919
|
+
paired_by: str | None = None,
|
920
|
+
mask: pd.Series | None = None,
|
921
|
+
layer: str | None = None,
|
503
922
|
fit_kwargs=MappingProxyType({}),
|
504
923
|
test_kwargs=MappingProxyType({}),
|
505
924
|
):
|
@@ -525,17 +944,16 @@ class LinearModelBase(MethodBase):
|
|
525
944
|
@property
|
526
945
|
def variables(self):
|
527
946
|
"""Get the names of the variables used in the model definition."""
|
528
|
-
|
529
|
-
return self.design.model_spec.variables_by_source["data"]
|
530
|
-
except AttributeError:
|
947
|
+
if self.formulaic_contrasts is None:
|
531
948
|
raise ValueError(
|
532
949
|
"Retrieving variables is only possible if the model was initialized using a formula."
|
533
950
|
) from None
|
951
|
+
else:
|
952
|
+
return self.formulaic_contrasts.variables
|
534
953
|
|
535
954
|
@abstractmethod
|
536
955
|
def _check_counts(self):
|
537
|
-
"""
|
538
|
-
Check that counts are valid for the specific method.
|
956
|
+
"""Check that counts are valid for the specific method.
|
539
957
|
|
540
958
|
Raises:
|
541
959
|
ValueError: if the data matrix does not comply with the expectations.
|
@@ -544,8 +962,7 @@ class LinearModelBase(MethodBase):
|
|
544
962
|
|
545
963
|
@abstractmethod
|
546
964
|
def fit(self, **kwargs):
|
547
|
-
"""
|
548
|
-
Fit the model.
|
965
|
+
"""Fit the model.
|
549
966
|
|
550
967
|
Args:
|
551
968
|
**kwargs: Additional arguments for fitting the specific method.
|
@@ -555,9 +972,8 @@ class LinearModelBase(MethodBase):
|
|
555
972
|
@abstractmethod
|
556
973
|
def _test_single_contrast(self, contrast, **kwargs): ...
|
557
974
|
|
558
|
-
def test_contrasts(self, contrasts, **kwargs):
|
559
|
-
"""
|
560
|
-
Perform a comparison as specified in a contrast vector.
|
975
|
+
def test_contrasts(self, contrasts: np.ndarray | Mapping[str | None, np.ndarray], **kwargs):
|
976
|
+
"""Perform a comparison as specified in a contrast vector.
|
561
977
|
|
562
978
|
Args:
|
563
979
|
contrasts: Either a numeric contrast vector, or a dictionary of numeric contrast vectors.
|
@@ -573,25 +989,25 @@ class LinearModelBase(MethodBase):
|
|
573
989
|
results.append(self._test_single_contrast(contrast, **kwargs).assign(contrast=name))
|
574
990
|
|
575
991
|
results_df = pd.concat(results)
|
992
|
+
|
576
993
|
return results_df
|
577
994
|
|
578
995
|
def test_reduced(self, modelB):
|
579
|
-
"""
|
580
|
-
Test against a reduced model.
|
996
|
+
"""Test against a reduced model.
|
581
997
|
|
582
998
|
Args:
|
583
999
|
modelB: the reduced model against which to test.
|
584
1000
|
|
585
1001
|
Example:
|
586
|
-
|
587
|
-
|
588
|
-
|
1002
|
+
>>> import pertpy as pt
|
1003
|
+
>>> modelA = Model().fit()
|
1004
|
+
>>> modelB = Model().fit()
|
1005
|
+
>>> modelA.test_reduced(modelB)
|
589
1006
|
"""
|
590
1007
|
raise NotImplementedError
|
591
1008
|
|
592
1009
|
def cond(self, **kwargs):
|
593
|
-
"""
|
594
|
-
Get a contrast vector representing a specific condition.
|
1010
|
+
"""Get a contrast vector representing a specific condition.
|
595
1011
|
|
596
1012
|
Args:
|
597
1013
|
**kwargs: column/value pairs.
|
@@ -599,52 +1015,14 @@ class LinearModelBase(MethodBase):
|
|
599
1015
|
Returns:
|
600
1016
|
A contrast vector that aligns to the columns of the design matrix.
|
601
1017
|
"""
|
602
|
-
if self.
|
1018
|
+
if self.formulaic_contrasts is None:
|
603
1019
|
raise RuntimeError(
|
604
1020
|
"Building contrasts with `cond` only works if you specified the model using a formulaic formula. Please manually provide a contrast vector."
|
605
1021
|
)
|
606
|
-
|
607
|
-
if not set(cond_dict.keys()).issubset(self.variables):
|
608
|
-
raise ValueError(
|
609
|
-
"You specified a variable that is not part of the model. Available variables: "
|
610
|
-
+ ",".join(self.variables)
|
611
|
-
)
|
612
|
-
for var in self.variables:
|
613
|
-
if var in cond_dict:
|
614
|
-
self._check_category(var, cond_dict[var])
|
615
|
-
else:
|
616
|
-
cond_dict[var] = self._get_default_value(var)
|
617
|
-
df = pd.DataFrame([kwargs])
|
618
|
-
return self.design.model_spec.get_model_matrix(df).iloc[0]
|
619
|
-
|
620
|
-
def _get_factor_metadata_for_variable(self, var):
|
621
|
-
factors = self.variable_to_factors[var]
|
622
|
-
return list(chain.from_iterable(self.factor_storage[f] for f in factors))
|
623
|
-
|
624
|
-
def _get_default_value(self, var):
|
625
|
-
factor_metadata = self._get_factor_metadata_for_variable(var)
|
626
|
-
if resolve_ambiguous(factor_metadata, "kind") == Factor.Kind.CATEGORICAL:
|
627
|
-
try:
|
628
|
-
tmp_base = resolve_ambiguous(factor_metadata, "base")
|
629
|
-
except AmbiguousAttributeError as e:
|
630
|
-
raise ValueError(
|
631
|
-
f"Could not automatically resolve base category for variable {var}. Please specify it explicity in `model.cond`."
|
632
|
-
) from e
|
633
|
-
return tmp_base if tmp_base is not None else "\0"
|
634
|
-
else:
|
635
|
-
return 0
|
636
|
-
|
637
|
-
def _check_category(self, var, value):
|
638
|
-
factor_metadata = self._get_factor_metadata_for_variable(var)
|
639
|
-
tmp_categories = resolve_ambiguous(factor_metadata, "categories")
|
640
|
-
if resolve_ambiguous(factor_metadata, "kind") == Factor.Kind.CATEGORICAL and value not in tmp_categories:
|
641
|
-
raise ValueError(
|
642
|
-
f"You specified a non-existant category for {var}. Possible categories: {', '.join(tmp_categories)}"
|
643
|
-
)
|
1022
|
+
return self.formulaic_contrasts.cond(**kwargs)
|
644
1023
|
|
645
|
-
def contrast(self,
|
646
|
-
"""
|
647
|
-
Build a simple contrast for pairwise comparisons.
|
1024
|
+
def contrast(self, *args, **kwargs):
|
1025
|
+
"""Build a simple contrast for pairwise comparisons.
|
648
1026
|
|
649
1027
|
Args:
|
650
1028
|
column: column in adata.obs to test on.
|
@@ -654,4 +1032,8 @@ class LinearModelBase(MethodBase):
|
|
654
1032
|
Returns:
|
655
1033
|
Numeric contrast vector.
|
656
1034
|
"""
|
657
|
-
|
1035
|
+
if self.formulaic_contrasts is None:
|
1036
|
+
raise RuntimeError(
|
1037
|
+
"Building contrasts with `cond` only works if you specified the model using a formulaic formula. Please manually provide a contrast vector."
|
1038
|
+
)
|
1039
|
+
return self.formulaic_contrasts.contrast(*args, **kwargs)
|