pertpy-0.9.4-py3-none-any.whl → pertpy-0.10.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pertpy/__init__.py +1 -1
- pertpy/_doc.py +19 -0
- pertpy/data/_datasets.py +1 -1
- pertpy/metadata/_cell_line.py +18 -8
- pertpy/metadata/_compound.py +3 -4
- pertpy/metadata/_metadata.py +1 -1
- pertpy/preprocessing/_guide_rna.py +114 -13
- pertpy/preprocessing/_guide_rna_mixture.py +179 -0
- pertpy/tools/__init__.py +1 -1
- pertpy/tools/_augur.py +64 -86
- pertpy/tools/_cinemaot.py +21 -17
- pertpy/tools/_coda/_base_coda.py +90 -117
- pertpy/tools/_dialogue.py +32 -40
- pertpy/tools/_differential_gene_expression/__init__.py +1 -2
- pertpy/tools/_differential_gene_expression/_base.py +486 -112
- pertpy/tools/_differential_gene_expression/_edger.py +30 -21
- pertpy/tools/_differential_gene_expression/_pydeseq2.py +15 -29
- pertpy/tools/_differential_gene_expression/_statsmodels.py +0 -11
- pertpy/tools/_distances/_distances.py +71 -56
- pertpy/tools/_enrichment.py +16 -8
- pertpy/tools/_milo.py +54 -50
- pertpy/tools/_mixscape.py +307 -208
- pertpy/tools/_perturbation_space/_perturbation_space.py +40 -31
- pertpy/tools/_perturbation_space/_simple.py +48 -0
- pertpy/tools/_scgen/_scgen.py +35 -27
- {pertpy-0.9.4.dist-info → pertpy-0.10.0.dist-info}/METADATA +6 -6
- {pertpy-0.9.4.dist-info → pertpy-0.10.0.dist-info}/RECORD +29 -28
- {pertpy-0.9.4.dist-info → pertpy-0.10.0.dist-info}/WHEEL +1 -1
- pertpy/tools/_differential_gene_expression/_formulaic.py +0 -189
- {pertpy-0.9.4.dist-info → pertpy-0.10.0.dist-info}/licenses/LICENSE +0 -0
@@ -1,7 +1,7 @@
|
|
1
|
-
import
|
1
|
+
import math
|
2
2
|
from abc import ABC, abstractmethod
|
3
|
-
from
|
4
|
-
from itertools import
|
3
|
+
from collections.abc import Iterable, Mapping, Sequence
|
4
|
+
from itertools import zip_longest
|
5
5
|
from types import MappingProxyType
|
6
6
|
|
7
7
|
import adjustText
|
@@ -11,27 +11,14 @@ import matplotlib.pyplot as plt
|
|
11
11
|
import numpy as np
|
12
12
|
import pandas as pd
|
13
13
|
import seaborn as sns
|
14
|
+
from formulaic_contrasts import FormulaicContrasts
|
15
|
+
from lamin_utils import logger
|
16
|
+
from matplotlib.pyplot import Figure
|
14
17
|
from matplotlib.ticker import MaxNLocator
|
15
18
|
|
19
|
+
from pertpy._doc import _doc_params, doc_common_plot_args
|
20
|
+
from pertpy.tools import PseudobulkSpace
|
16
21
|
from pertpy.tools._differential_gene_expression._checks import check_is_numeric_matrix
|
17
|
-
from pertpy.tools._differential_gene_expression._formulaic import (
|
18
|
-
AmbiguousAttributeError,
|
19
|
-
Factor,
|
20
|
-
get_factor_storage_and_materializer,
|
21
|
-
resolve_ambiguous,
|
22
|
-
)
|
23
|
-
|
24
|
-
|
25
|
-
@dataclass
|
26
|
-
class Contrast:
|
27
|
-
"""Simple contrast for comparison between groups"""
|
28
|
-
|
29
|
-
column: str
|
30
|
-
baseline: str
|
31
|
-
group_to_compare: str
|
32
|
-
|
33
|
-
|
34
|
-
ContrastType = Contrast | tuple[str, str, str]
|
35
22
|
|
36
23
|
|
37
24
|
class MethodBase(ABC):
|
@@ -58,7 +45,7 @@ class MethodBase(ABC):
|
|
58
45
|
if self.layer is None:
|
59
46
|
return self.adata.X
|
60
47
|
else:
|
61
|
-
return self.adata.
|
48
|
+
return self.adata.layers[self.layer]
|
62
49
|
|
63
50
|
@classmethod
|
64
51
|
@abstractmethod
|
@@ -91,9 +78,28 @@ class MethodBase(ABC):
|
|
91
78
|
|
92
79
|
Returns:
|
93
80
|
Pandas dataframe with results ordered by significance. If multiple comparisons were performed this is indicated in an additional column.
|
81
|
+
|
82
|
+
Examples:
|
83
|
+
>>> # Example with EdgeR
|
84
|
+
>>> import pertpy as pt
|
85
|
+
>>> adata = pt.dt.zhang_2021()
|
86
|
+
>>> adata.layers["counts"] = adata.X.copy()
|
87
|
+
>>> ps = pt.tl.PseudobulkSpace()
|
88
|
+
>>> pdata = ps.compute(
|
89
|
+
... adata,
|
90
|
+
... target_col="Patient",
|
91
|
+
... groups_col="Cluster",
|
92
|
+
... layer_key="counts",
|
93
|
+
... mode="sum",
|
94
|
+
... min_cells=10,
|
95
|
+
... min_counts=1000,
|
96
|
+
... )
|
97
|
+
>>> edgr = pt.tl.EdgeR(pdata, design="~Efficacy+Treatment")
|
98
|
+
>>> res_df = edgr.compare_groups(pdata, column="Efficacy", baseline="SD", groups_to_compare=["PR", "PD"])
|
94
99
|
"""
|
95
100
|
...
|
96
101
|
|
102
|
+
@_doc_params(common_plot_args=doc_common_plot_args)
|
97
103
|
def plot_volcano(
|
98
104
|
self,
|
99
105
|
data: pd.DataFrame | ad.AnnData,
|
@@ -115,13 +121,13 @@ class MethodBase(ABC):
|
|
115
121
|
figsize: tuple[int, int] = (5, 5),
|
116
122
|
legend_pos: tuple[float, float] = (1.6, 1),
|
117
123
|
point_sizes: tuple[int, int] = (15, 150),
|
118
|
-
save: bool | str | None = None,
|
119
124
|
shapes: list[str] | None = None,
|
120
125
|
shape_order: list[str] | None = None,
|
121
126
|
x_label: str | None = None,
|
122
127
|
y_label: str | None = None,
|
128
|
+
return_fig: bool = False,
|
123
129
|
**kwargs: int,
|
124
|
-
) -> None:
|
130
|
+
) -> Figure | None:
|
125
131
|
"""Creates a volcano plot from a pandas DataFrame or Anndata.
|
126
132
|
|
127
133
|
Args:
|
@@ -143,12 +149,40 @@ class MethodBase(ABC):
|
|
143
149
|
top_right_frame: Whether to show the top and right frame of the plot.
|
144
150
|
figsize: Size of the figure.
|
145
151
|
legend_pos: Position of the legend as determined by matplotlib.
|
146
|
-
save: Saves the plot if True or to the path provided.
|
147
152
|
shapes: List of matplotlib marker ids.
|
148
153
|
shape_order: Order of categories for shapes.
|
149
154
|
x_label: Label for the x-axis.
|
150
155
|
y_label: Label for the y-axis.
|
156
|
+
{common_plot_args}
|
151
157
|
**kwargs: Additional arguments for seaborn.scatterplot.
|
158
|
+
|
159
|
+
Returns:
|
160
|
+
If `return_fig` is `True`, returns the figure, otherwise `None`.
|
161
|
+
|
162
|
+
Examples:
|
163
|
+
>>> # Example with EdgeR
|
164
|
+
>>> import pertpy as pt
|
165
|
+
>>> adata = pt.dt.zhang_2021()
|
166
|
+
>>> adata.layers["counts"] = adata.X.copy()
|
167
|
+
>>> ps = pt.tl.PseudobulkSpace()
|
168
|
+
>>> pdata = ps.compute(
|
169
|
+
... adata,
|
170
|
+
... target_col="Patient",
|
171
|
+
... groups_col="Cluster",
|
172
|
+
... layer_key="counts",
|
173
|
+
... mode="sum",
|
174
|
+
... min_cells=10,
|
175
|
+
... min_counts=1000,
|
176
|
+
... )
|
177
|
+
>>> edgr = pt.tl.EdgeR(pdata, design="~Efficacy+Treatment")
|
178
|
+
>>> edgr.fit()
|
179
|
+
>>> res_df = edgr.test_contrasts(
|
180
|
+
... edgr.contrast(column="Treatment", baseline="Chemo", group_to_compare="Anti-PD-L1+Chemo")
|
181
|
+
... )
|
182
|
+
>>> edgr.plot_volcano(res_df, log2fc_thresh=0)
|
183
|
+
|
184
|
+
Preview:
|
185
|
+
.. image:: /_static/docstring_previews/de_volcano.png
|
152
186
|
"""
|
153
187
|
if colors is None:
|
154
188
|
colors = ["gray", "#D62728", "#1F77B4"]
|
@@ -243,7 +277,7 @@ class MethodBase(ABC):
|
|
243
277
|
if varm_key is None:
|
244
278
|
raise ValueError("Please pass a .varm key to use for plotting")
|
245
279
|
|
246
|
-
raise NotImplementedError("Anndata not implemented yet")
|
280
|
+
raise NotImplementedError("Anndata not implemented yet") # TODO: Implement this
|
247
281
|
df = data.varm[varm_key].copy()
|
248
282
|
|
249
283
|
df = data.copy(deep=True)
|
@@ -449,26 +483,405 @@ class MethodBase(ABC):
|
|
449
483
|
|
450
484
|
plt.legend(loc=1, bbox_to_anchor=legend_pos, frameon=False)
|
451
485
|
|
452
|
-
|
453
|
-
|
454
|
-
|
455
|
-
|
456
|
-
|
457
|
-
|
458
|
-
|
459
|
-
|
460
|
-
|
461
|
-
|
462
|
-
|
463
|
-
|
486
|
+
if return_fig:
|
487
|
+
return plt.gcf()
|
488
|
+
plt.show()
|
489
|
+
return None
|
490
|
+
|
491
|
+
@_doc_params(common_plot_args=doc_common_plot_args)
|
492
|
+
def plot_paired(
|
493
|
+
self,
|
494
|
+
adata: ad.AnnData,
|
495
|
+
results_df: pd.DataFrame,
|
496
|
+
groupby: str,
|
497
|
+
pairedby: str,
|
498
|
+
*,
|
499
|
+
var_names: Sequence[str] = None,
|
500
|
+
n_top_vars: int = 15,
|
501
|
+
layer: str = None,
|
502
|
+
pvalue_col: str = "adj_p_value",
|
503
|
+
symbol_col: str = "variable",
|
504
|
+
n_cols: int = 4,
|
505
|
+
panel_size: tuple[int, int] = (5, 5),
|
506
|
+
show_legend: bool = True,
|
507
|
+
size: int = 10,
|
508
|
+
y_label: str = "expression",
|
509
|
+
pvalue_template=lambda x: f"p={x:.2e}",
|
510
|
+
boxplot_properties=None,
|
511
|
+
palette=None,
|
512
|
+
return_fig: bool = False,
|
513
|
+
) -> Figure | None:
|
514
|
+
"""Creates a pairwise expression plot from a Pandas DataFrame or Anndata.
|
515
|
+
|
516
|
+
Visualizes a panel of paired scatterplots per variable.
|
517
|
+
|
518
|
+
Args:
|
519
|
+
adata: AnnData object, can be pseudobulked.
|
520
|
+
results_df: DataFrame with results from a differential expression test.
|
521
|
+
groupby: .obs column containing the grouping. Must contain exactly two different values.
|
522
|
+
pairedby: .obs column containing the pairing (e.g. "patient_id"). If None, an independent t-test is performed.
|
523
|
+
var_names: Variables to plot.
|
524
|
+
n_top_vars: Number of top variables to plot.
|
525
|
+
layer: Layer to use for plotting.
|
526
|
+
pvalue_col: Column name of the p values.
|
527
|
+
symbol_col: Column name of gene IDs.
|
528
|
+
n_cols: Number of columns in the plot.
|
529
|
+
panel_size: Size of each panel.
|
530
|
+
show_legend: Whether to show the legend.
|
531
|
+
size: Size of the points.
|
532
|
+
y_label: Label for the y-axis.
|
533
|
+
pvalue_template: Template for the p-value string displayed in the title of each panel.
|
534
|
+
boxplot_properties: Additional properties for the boxplot, passed to seaborn.boxplot.
|
535
|
+
palette: Color palette for the line- and stripplot.
|
536
|
+
{common_plot_args}
|
537
|
+
|
538
|
+
Returns:
|
539
|
+
If `return_fig` is `True`, returns the figure, otherwise `None`.
|
540
|
+
|
541
|
+
Examples:
|
542
|
+
>>> # Example with EdgeR
|
543
|
+
>>> import pertpy as pt
|
544
|
+
>>> adata = pt.dt.zhang_2021()
|
545
|
+
>>> adata.layers["counts"] = adata.X.copy()
|
546
|
+
>>> ps = pt.tl.PseudobulkSpace()
|
547
|
+
>>> pdata = ps.compute(
|
548
|
+
... adata,
|
549
|
+
... target_col="Patient",
|
550
|
+
... groups_col="Cluster",
|
551
|
+
... layer_key="counts",
|
552
|
+
... mode="sum",
|
553
|
+
... min_cells=10,
|
554
|
+
... min_counts=1000,
|
555
|
+
... )
|
556
|
+
>>> edgr = pt.tl.EdgeR(pdata, design="~Efficacy+Treatment")
|
557
|
+
>>> edgr.fit()
|
558
|
+
>>> res_df = edgr.test_contrasts(
|
559
|
+
... edgr.contrast(column="Treatment", baseline="Chemo", group_to_compare="Anti-PD-L1+Chemo")
|
560
|
+
... )
|
561
|
+
>>> edgr.plot_paired(pdata, results_df=res_df, n_top_vars=8, groupby="Treatment", pairedby="Efficacy")
|
562
|
+
|
563
|
+
Preview:
|
564
|
+
.. image:: /_static/docstring_previews/de_paired_expression.png
|
565
|
+
"""
|
566
|
+
if boxplot_properties is None:
|
567
|
+
boxplot_properties = {}
|
568
|
+
groups = adata.obs[groupby].unique()
|
569
|
+
if len(groups) != 2:
|
570
|
+
raise ValueError("The number of groups in the group_by column must be exactly 2 to enable paired testing")
|
571
|
+
|
572
|
+
if var_names is None:
|
573
|
+
var_names = results_df.head(n_top_vars)[symbol_col].tolist()
|
574
|
+
|
575
|
+
adata = adata[:, var_names]
|
576
|
+
|
577
|
+
if any(adata.obs[[groupby, pairedby]].value_counts() > 1):
|
578
|
+
logger.info("Performing pseudobulk for paired samples")
|
579
|
+
ps = PseudobulkSpace()
|
580
|
+
adata = ps.compute(
|
581
|
+
adata, target_col=groupby, groups_col=pairedby, layer_key=layer, mode="sum", min_cells=1, min_counts=1
|
582
|
+
)
|
464
583
|
|
584
|
+
if layer is not None:
|
585
|
+
X = adata.layers[layer]
|
586
|
+
else:
|
587
|
+
X = adata.X
|
588
|
+
try:
|
589
|
+
X = X.toarray()
|
590
|
+
except AttributeError:
|
591
|
+
pass
|
592
|
+
|
593
|
+
groupby_cols = [pairedby, groupby]
|
594
|
+
df = adata.obs.loc[:, groupby_cols].join(pd.DataFrame(X, index=adata.obs_names, columns=var_names))
|
595
|
+
|
596
|
+
# remove unpaired samples
|
597
|
+
paired_samples = set(df[df[groupby] == groups[0]][pairedby]) & set(df[df[groupby] == groups[1]][pairedby])
|
598
|
+
df = df[df[pairedby].isin(paired_samples)]
|
599
|
+
removed_samples = adata.obs[pairedby].nunique() - len(df[pairedby].unique())
|
600
|
+
if removed_samples > 0:
|
601
|
+
logger.warning(f"{removed_samples} unpaired samples removed")
|
602
|
+
|
603
|
+
pvalues = results_df.set_index(symbol_col).loc[var_names, pvalue_col].values
|
604
|
+
df.reset_index(drop=False, inplace=True)
|
605
|
+
|
606
|
+
# transform data for seaborn
|
607
|
+
df_melt = df.melt(
|
608
|
+
id_vars=groupby_cols,
|
609
|
+
var_name="var",
|
610
|
+
value_name="val",
|
611
|
+
)
|
612
|
+
|
613
|
+
n_panels = len(var_names)
|
614
|
+
nrows = math.ceil(n_panels / n_cols)
|
615
|
+
ncols = min(n_cols, n_panels)
|
616
|
+
|
617
|
+
fig, axes = plt.subplots(
|
618
|
+
nrows,
|
619
|
+
ncols,
|
620
|
+
figsize=(ncols * panel_size[0], nrows * panel_size[1]),
|
621
|
+
tight_layout=True,
|
622
|
+
squeeze=False,
|
623
|
+
)
|
624
|
+
axes = axes.flatten()
|
625
|
+
for i, (var, ax) in enumerate(zip_longest(var_names, axes)):
|
626
|
+
if var is not None:
|
627
|
+
sns.boxplot(
|
628
|
+
x=groupby,
|
629
|
+
data=df_melt.loc[df_melt["var"] == var],
|
630
|
+
y="val",
|
631
|
+
ax=ax,
|
632
|
+
color="white",
|
633
|
+
fliersize=0,
|
634
|
+
**boxplot_properties,
|
635
|
+
)
|
636
|
+
if pairedby is not None:
|
637
|
+
sns.lineplot(
|
638
|
+
x=groupby,
|
639
|
+
data=df_melt.loc[df_melt["var"] == var],
|
640
|
+
y="val",
|
641
|
+
ax=ax,
|
642
|
+
hue=pairedby,
|
643
|
+
legend=False,
|
644
|
+
errorbar=None,
|
645
|
+
palette=palette,
|
646
|
+
)
|
647
|
+
jitter = 0 if pairedby else True
|
648
|
+
sns.stripplot(
|
649
|
+
x=groupby,
|
650
|
+
data=df_melt.loc[df_melt["var"] == var],
|
651
|
+
y="val",
|
652
|
+
ax=ax,
|
653
|
+
hue=pairedby,
|
654
|
+
jitter=jitter,
|
655
|
+
size=size,
|
656
|
+
linewidth=1,
|
657
|
+
palette=palette,
|
658
|
+
)
|
659
|
+
|
660
|
+
ax.set_xlabel("")
|
661
|
+
ax.tick_params(
|
662
|
+
axis="x",
|
663
|
+
labelsize=15,
|
664
|
+
)
|
665
|
+
ax.legend().set_visible(False)
|
666
|
+
ax.set_ylabel(y_label)
|
667
|
+
ax.set_title(f"{var}\n{pvalue_template(pvalues[i])}")
|
668
|
+
else:
|
669
|
+
ax.set_visible(False)
|
670
|
+
fig.tight_layout()
|
671
|
+
|
672
|
+
if show_legend is True:
|
673
|
+
axes[n_panels - 1].legend().set_visible(True)
|
674
|
+
axes[n_panels - 1].legend(
|
675
|
+
bbox_to_anchor=(0.5, -0.1), loc="upper center", ncol=adata.obs[pairedby].nunique()
|
676
|
+
)
|
677
|
+
|
678
|
+
plt.tight_layout()
|
679
|
+
if return_fig:
|
680
|
+
return plt.gcf()
|
465
681
|
plt.show()
|
682
|
+
return None
|
683
|
+
|
684
|
+
@_doc_params(common_plot_args=doc_common_plot_args)
|
685
|
+
def plot_fold_change(
|
686
|
+
self,
|
687
|
+
results_df: pd.DataFrame,
|
688
|
+
*,
|
689
|
+
var_names: Sequence[str] = None,
|
690
|
+
n_top_vars: int = 15,
|
691
|
+
log2fc_col: str = "log_fc",
|
692
|
+
symbol_col: str = "variable",
|
693
|
+
y_label: str = "Log2 fold change",
|
694
|
+
figsize: tuple[int, int] = (10, 5),
|
695
|
+
return_fig: bool = False,
|
696
|
+
**barplot_kwargs,
|
697
|
+
) -> Figure | None:
|
698
|
+
"""Plot a metric from the results as a bar chart, optionally with additional information about paired samples in a scatter plot.
|
699
|
+
|
700
|
+
Args:
|
701
|
+
results_df: DataFrame with results from DE analysis.
|
702
|
+
var_names: Variables to plot. If None, the top n_top_vars variables based on the log2 fold change are plotted.
|
703
|
+
n_top_vars: Number of top variables to plot. The top and bottom n_top_vars variables are plotted, respectively.
|
704
|
+
log2fc_col: Column name of log2 Fold-Change values.
|
705
|
+
symbol_col: Column name of gene IDs.
|
706
|
+
y_label: Label for the y-axis.
|
707
|
+
figsize: Size of the figure.
|
708
|
+
{common_plot_args}
|
709
|
+
**barplot_kwargs: Additional arguments for seaborn.barplot.
|
710
|
+
|
711
|
+
Returns:
|
712
|
+
If `return_fig` is `True`, returns the figure, otherwise `None`.
|
713
|
+
|
714
|
+
Examples:
|
715
|
+
>>> # Example with EdgeR
|
716
|
+
>>> import pertpy as pt
|
717
|
+
>>> adata = pt.dt.zhang_2021()
|
718
|
+
>>> adata.layers["counts"] = adata.X.copy()
|
719
|
+
>>> ps = pt.tl.PseudobulkSpace()
|
720
|
+
>>> pdata = ps.compute(
|
721
|
+
... adata,
|
722
|
+
... target_col="Patient",
|
723
|
+
... groups_col="Cluster",
|
724
|
+
... layer_key="counts",
|
725
|
+
... mode="sum",
|
726
|
+
... min_cells=10,
|
727
|
+
... min_counts=1000,
|
728
|
+
... )
|
729
|
+
>>> edgr = pt.tl.EdgeR(pdata, design="~Efficacy+Treatment")
|
730
|
+
>>> edgr.fit()
|
731
|
+
>>> res_df = edgr.test_contrasts(
|
732
|
+
... edgr.contrast(column="Treatment", baseline="Chemo", group_to_compare="Anti-PD-L1+Chemo")
|
733
|
+
... )
|
734
|
+
>>> edgr.plot_fold_change(res_df)
|
735
|
+
|
736
|
+
Preview:
|
737
|
+
.. image:: /_static/docstring_previews/de_fold_change.png
|
738
|
+
"""
|
739
|
+
if var_names is None:
|
740
|
+
var_names = results_df.sort_values(log2fc_col, ascending=False).head(n_top_vars)[symbol_col].tolist()
|
741
|
+
var_names += results_df.sort_values(log2fc_col, ascending=True).head(n_top_vars)[symbol_col].tolist()
|
742
|
+
assert len(var_names) == 2 * n_top_vars
|
743
|
+
|
744
|
+
df = results_df[results_df[symbol_col].isin(var_names)]
|
745
|
+
df.sort_values(log2fc_col, ascending=False, inplace=True)
|
746
|
+
|
747
|
+
plt.figure(figsize=figsize)
|
748
|
+
sns.barplot(
|
749
|
+
x=symbol_col,
|
750
|
+
y=log2fc_col,
|
751
|
+
data=df,
|
752
|
+
palette="RdBu",
|
753
|
+
legend=False,
|
754
|
+
**barplot_kwargs,
|
755
|
+
)
|
756
|
+
plt.xticks(rotation=90)
|
757
|
+
plt.xlabel("")
|
758
|
+
plt.ylabel(y_label)
|
759
|
+
|
760
|
+
if return_fig:
|
761
|
+
return plt.gcf()
|
762
|
+
plt.show()
|
763
|
+
return None
|
764
|
+
|
765
|
+
@_doc_params(common_plot_args=doc_common_plot_args)
|
766
|
+
def plot_multicomparison_fc(
|
767
|
+
self,
|
768
|
+
results_df: pd.DataFrame,
|
769
|
+
*,
|
770
|
+
n_top_vars=15,
|
771
|
+
contrast_col: str = "contrast",
|
772
|
+
log2fc_col: str = "log_fc",
|
773
|
+
pvalue_col: str = "adj_p_value",
|
774
|
+
symbol_col: str = "variable",
|
775
|
+
marker_size: int = 100,
|
776
|
+
figsize: tuple[int, int] = (10, 2),
|
777
|
+
x_label: str = "Contrast",
|
778
|
+
y_label: str = "Gene",
|
779
|
+
return_fig: bool = False,
|
780
|
+
**heatmap_kwargs,
|
781
|
+
) -> Figure | None:
|
782
|
+
"""Plot a matrix of log2 fold changes from the results.
|
783
|
+
|
784
|
+
Args:
|
785
|
+
results_df: DataFrame with results from DE analysis.
|
786
|
+
n_top_vars: Number of top variables to plot per group.
|
787
|
+
contrast_col: Column in results_df containing information about the contrast.
|
788
|
+
log2fc_col: Column in results_df containing the log2 fold change.
|
789
|
+
pvalue_col: Column in results_df containing the p-value. Can be used to switch between adjusted and unadjusted p-values.
|
790
|
+
symbol_col: Column in results_df containing the gene symbol.
|
791
|
+
marker_size: Size of the biggest marker for significant variables.
|
792
|
+
figsize: Size of the figure.
|
793
|
+
x_label: Label for the x-axis.
|
794
|
+
y_label: Label for the y-axis.
|
795
|
+
{common_plot_args}
|
796
|
+
**heatmap_kwargs: Additional arguments for seaborn.heatmap.
|
797
|
+
|
798
|
+
Returns:
|
799
|
+
If `return_fig` is `True`, returns the figure, otherwise `None`.
|
800
|
+
|
801
|
+
Examples:
|
802
|
+
>>> # Example with EdgeR
|
803
|
+
>>> import pertpy as pt
|
804
|
+
>>> adata = pt.dt.zhang_2021()
|
805
|
+
>>> adata.layers["counts"] = adata.X.copy()
|
806
|
+
>>> ps = pt.tl.PseudobulkSpace()
|
807
|
+
>>> pdata = ps.compute(
|
808
|
+
... adata,
|
809
|
+
... target_col="Patient",
|
810
|
+
... groups_col="Cluster",
|
811
|
+
... layer_key="counts",
|
812
|
+
... mode="sum",
|
813
|
+
... min_cells=10,
|
814
|
+
... min_counts=1000,
|
815
|
+
... )
|
816
|
+
>>> edgr = pt.tl.EdgeR(pdata, design="~Efficacy+Treatment")
|
817
|
+
>>> res_df = edgr.compare_groups(pdata, column="Efficacy", baseline="SD", groups_to_compare=["PR", "PD"])
|
818
|
+
>>> edgr.plot_multicomparison_fc(res_df)
|
819
|
+
|
820
|
+
Preview:
|
821
|
+
.. image:: /_static/docstring_previews/de_multicomparison_fc.png
|
822
|
+
"""
|
823
|
+
groups = results_df[contrast_col].unique().tolist()
|
824
|
+
|
825
|
+
results_df["abs_log_fc"] = results_df[log2fc_col].abs()
|
826
|
+
|
827
|
+
def _get_significance(p_val):
|
828
|
+
if p_val < 0.001:
|
829
|
+
return "< 0.001"
|
830
|
+
elif p_val < 0.01:
|
831
|
+
return "< 0.01"
|
832
|
+
elif p_val < 0.1:
|
833
|
+
return "< 0.1"
|
834
|
+
else:
|
835
|
+
return "n.s."
|
836
|
+
|
837
|
+
results_df["significance"] = results_df[pvalue_col].apply(_get_significance)
|
838
|
+
|
839
|
+
var_names = []
|
840
|
+
for group in groups:
|
841
|
+
var_names += (
|
842
|
+
results_df[results_df[contrast_col] == group]
|
843
|
+
.sort_values("abs_log_fc", ascending=False)
|
844
|
+
.head(n_top_vars)[symbol_col]
|
845
|
+
.tolist()
|
846
|
+
)
|
847
|
+
|
848
|
+
results_df = results_df[results_df[symbol_col].isin(var_names)]
|
849
|
+
df = results_df.pivot(index=contrast_col, columns=symbol_col, values=log2fc_col)[var_names]
|
850
|
+
|
851
|
+
plt.figure(figsize=figsize)
|
852
|
+
sns.heatmap(df, **heatmap_kwargs, cmap="coolwarm", center=0, cbar_kws={"label": "Log2 fold change"})
|
853
|
+
|
854
|
+
_size = {"< 0.001": marker_size, "< 0.01": math.floor(marker_size / 2), "< 0.1": math.floor(marker_size / 4)}
|
855
|
+
x_locs, x_labels = plt.xticks()[0], [label.get_text() for label in plt.xticks()[1]]
|
856
|
+
y_locs, y_labels = plt.yticks()[0], [label.get_text() for label in plt.yticks()[1]]
|
857
|
+
|
858
|
+
for _i, row in results_df.iterrows():
|
859
|
+
if row["significance"] != "n.s.":
|
860
|
+
plt.scatter(
|
861
|
+
x=x_locs[x_labels.index(row[symbol_col])],
|
862
|
+
y=y_locs[y_labels.index(row[contrast_col])],
|
863
|
+
s=_size[row["significance"]],
|
864
|
+
marker="*",
|
865
|
+
c="white",
|
866
|
+
)
|
867
|
+
|
868
|
+
plt.scatter([], [], s=marker_size, marker="*", c="black", label="< 0.001")
|
869
|
+
plt.scatter([], [], s=math.floor(marker_size / 2), marker="*", c="black", label="< 0.01")
|
870
|
+
plt.scatter([], [], s=math.floor(marker_size / 4), marker="*", c="black", label="< 0.1")
|
871
|
+
plt.legend(title="Significance", bbox_to_anchor=(1.2, -0.05))
|
872
|
+
|
873
|
+
plt.xlabel(x_label)
|
874
|
+
plt.ylabel(y_label)
|
875
|
+
|
876
|
+
if return_fig:
|
877
|
+
return plt.gcf()
|
878
|
+
plt.show()
|
879
|
+
return None
|
466
880
|
|
467
881
|
|
468
882
|
class LinearModelBase(MethodBase):
|
469
883
|
def __init__(self, adata, design, *, mask=None, layer=None, **kwargs):
|
470
|
-
"""
|
471
|
-
Initialize the method.
|
884
|
+
"""Initialize the method.
|
472
885
|
|
473
886
|
Args:
|
474
887
|
adata: AnnData object, usually pseudobulked.
|
@@ -480,26 +893,24 @@ class LinearModelBase(MethodBase):
|
|
480
893
|
super().__init__(adata, mask=mask, layer=layer)
|
481
894
|
self._check_counts()
|
482
895
|
|
483
|
-
self.
|
484
|
-
self.variable_to_factors = None
|
485
|
-
|
896
|
+
self.formulaic_contrasts = None
|
486
897
|
if isinstance(design, str):
|
487
|
-
self.
|
488
|
-
self.design =
|
898
|
+
self.formulaic_contrasts = FormulaicContrasts(adata.obs, design)
|
899
|
+
self.design = self.formulaic_contrasts.design_matrix
|
489
900
|
else:
|
490
901
|
self.design = design
|
491
902
|
|
492
903
|
@classmethod
|
493
904
|
def compare_groups(
|
494
905
|
cls,
|
495
|
-
adata,
|
496
|
-
column,
|
497
|
-
baseline,
|
498
|
-
groups_to_compare,
|
906
|
+
adata: ad.AnnData,
|
907
|
+
column: str,
|
908
|
+
baseline: str,
|
909
|
+
groups_to_compare: str | Iterable[str],
|
499
910
|
*,
|
500
|
-
paired_by=None,
|
501
|
-
mask=None,
|
502
|
-
layer=None,
|
911
|
+
paired_by: str | None = None,
|
912
|
+
mask: pd.Series | None = None,
|
913
|
+
layer: str | None = None,
|
503
914
|
fit_kwargs=MappingProxyType({}),
|
504
915
|
test_kwargs=MappingProxyType({}),
|
505
916
|
):
|
@@ -525,17 +936,16 @@ class LinearModelBase(MethodBase):
|
|
525
936
|
@property
|
526
937
|
def variables(self):
|
527
938
|
"""Get the names of the variables used in the model definition."""
|
528
|
-
|
529
|
-
return self.design.model_spec.variables_by_source["data"]
|
530
|
-
except AttributeError:
|
939
|
+
if self.formulaic_contrasts is None:
|
531
940
|
raise ValueError(
|
532
941
|
"Retrieving variables is only possible if the model was initialized using a formula."
|
533
942
|
) from None
|
943
|
+
else:
|
944
|
+
return self.formulaic_contrasts.variables
|
534
945
|
|
535
946
|
@abstractmethod
|
536
947
|
def _check_counts(self):
|
537
|
-
"""
|
538
|
-
Check that counts are valid for the specific method.
|
948
|
+
"""Check that counts are valid for the specific method.
|
539
949
|
|
540
950
|
Raises:
|
541
951
|
ValueError: if the data matrix does not comply with the expectations.
|
@@ -544,8 +954,7 @@ class LinearModelBase(MethodBase):
|
|
544
954
|
|
545
955
|
@abstractmethod
|
546
956
|
def fit(self, **kwargs):
|
547
|
-
"""
|
548
|
-
Fit the model.
|
957
|
+
"""Fit the model.
|
549
958
|
|
550
959
|
Args:
|
551
960
|
**kwargs: Additional arguments for fitting the specific method.
|
@@ -555,9 +964,8 @@ class LinearModelBase(MethodBase):
|
|
555
964
|
@abstractmethod
|
556
965
|
def _test_single_contrast(self, contrast, **kwargs): ...
|
557
966
|
|
558
|
-
def test_contrasts(self, contrasts, **kwargs):
|
559
|
-
"""
|
560
|
-
Perform a comparison as specified in a contrast vector.
|
967
|
+
def test_contrasts(self, contrasts: np.ndarray | Mapping[str | None, np.ndarray], **kwargs):
|
968
|
+
"""Perform a comparison as specified in a contrast vector.
|
561
969
|
|
562
970
|
Args:
|
563
971
|
contrasts: Either a numeric contrast vector, or a dictionary of numeric contrast vectors.
|
@@ -573,25 +981,25 @@ class LinearModelBase(MethodBase):
|
|
573
981
|
results.append(self._test_single_contrast(contrast, **kwargs).assign(contrast=name))
|
574
982
|
|
575
983
|
results_df = pd.concat(results)
|
984
|
+
|
576
985
|
return results_df
|
577
986
|
|
578
987
|
def test_reduced(self, modelB):
|
579
|
-
"""
|
580
|
-
Test against a reduced model.
|
988
|
+
"""Test against a reduced model.
|
581
989
|
|
582
990
|
Args:
|
583
991
|
modelB: the reduced model against which to test.
|
584
992
|
|
585
993
|
Example:
|
586
|
-
|
587
|
-
|
588
|
-
|
994
|
+
>>> import pertpy as pt
|
995
|
+
>>> modelA = Model().fit()
|
996
|
+
>>> modelB = Model().fit()
|
997
|
+
>>> modelA.test_reduced(modelB)
|
589
998
|
"""
|
590
999
|
raise NotImplementedError
|
591
1000
|
|
592
1001
|
def cond(self, **kwargs):
|
593
|
-
"""
|
594
|
-
Get a contrast vector representing a specific condition.
|
1002
|
+
"""Get a contrast vector representing a specific condition.
|
595
1003
|
|
596
1004
|
Args:
|
597
1005
|
**kwargs: column/value pairs.
|
@@ -599,52 +1007,14 @@ class LinearModelBase(MethodBase):
|
|
599
1007
|
Returns:
|
600
1008
|
A contrast vector that aligns to the columns of the design matrix.
|
601
1009
|
"""
|
602
|
-
if self.
|
1010
|
+
if self.formulaic_contrasts is None:
|
603
1011
|
raise RuntimeError(
|
604
1012
|
"Building contrasts with `cond` only works if you specified the model using a formulaic formula. Please manually provide a contrast vector."
|
605
1013
|
)
|
606
|
-
|
607
|
-
if not set(cond_dict.keys()).issubset(self.variables):
|
608
|
-
raise ValueError(
|
609
|
-
"You specified a variable that is not part of the model. Available variables: "
|
610
|
-
+ ",".join(self.variables)
|
611
|
-
)
|
612
|
-
for var in self.variables:
|
613
|
-
if var in cond_dict:
|
614
|
-
self._check_category(var, cond_dict[var])
|
615
|
-
else:
|
616
|
-
cond_dict[var] = self._get_default_value(var)
|
617
|
-
df = pd.DataFrame([kwargs])
|
618
|
-
return self.design.model_spec.get_model_matrix(df).iloc[0]
|
619
|
-
|
620
|
-
def _get_factor_metadata_for_variable(self, var):
|
621
|
-
factors = self.variable_to_factors[var]
|
622
|
-
return list(chain.from_iterable(self.factor_storage[f] for f in factors))
|
623
|
-
|
624
|
-
def _get_default_value(self, var):
|
625
|
-
factor_metadata = self._get_factor_metadata_for_variable(var)
|
626
|
-
if resolve_ambiguous(factor_metadata, "kind") == Factor.Kind.CATEGORICAL:
|
627
|
-
try:
|
628
|
-
tmp_base = resolve_ambiguous(factor_metadata, "base")
|
629
|
-
except AmbiguousAttributeError as e:
|
630
|
-
raise ValueError(
|
631
|
-
f"Could not automatically resolve base category for variable {var}. Please specify it explicity in `model.cond`."
|
632
|
-
) from e
|
633
|
-
return tmp_base if tmp_base is not None else "\0"
|
634
|
-
else:
|
635
|
-
return 0
|
636
|
-
|
637
|
-
def _check_category(self, var, value):
|
638
|
-
factor_metadata = self._get_factor_metadata_for_variable(var)
|
639
|
-
tmp_categories = resolve_ambiguous(factor_metadata, "categories")
|
640
|
-
if resolve_ambiguous(factor_metadata, "kind") == Factor.Kind.CATEGORICAL and value not in tmp_categories:
|
641
|
-
raise ValueError(
|
642
|
-
f"You specified a non-existant category for {var}. Possible categories: {', '.join(tmp_categories)}"
|
643
|
-
)
|
1014
|
+
return self.formulaic_contrasts.cond(**kwargs)
|
644
1015
|
|
645
|
-
def contrast(self,
|
646
|
-
"""
|
647
|
-
Build a simple contrast for pairwise comparisons.
|
1016
|
+
def contrast(self, *args, **kwargs):
|
1017
|
+
"""Build a simple contrast for pairwise comparisons.
|
648
1018
|
|
649
1019
|
Args:
|
650
1020
|
column: column in adata.obs to test on.
|
@@ -654,4 +1024,8 @@ class LinearModelBase(MethodBase):
|
|
654
1024
|
Returns:
|
655
1025
|
Numeric contrast vector.
|
656
1026
|
"""
|
657
|
-
|
1027
|
+
if self.formulaic_contrasts is None:
|
1028
|
+
raise RuntimeError(
|
1029
|
+
"Building contrasts with `cond` only works if you specified the model using a formulaic formula. Please manually provide a contrast vector."
|
1030
|
+
)
|
1031
|
+
return self.formulaic_contrasts.contrast(*args, **kwargs)
|