pertpy 0.7.0__py3-none-any.whl → 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pertpy/__init__.py +2 -1
- pertpy/data/__init__.py +61 -0
- pertpy/data/_dataloader.py +27 -23
- pertpy/data/_datasets.py +58 -0
- pertpy/metadata/__init__.py +2 -0
- pertpy/metadata/_cell_line.py +39 -70
- pertpy/metadata/_compound.py +3 -4
- pertpy/metadata/_drug.py +2 -6
- pertpy/metadata/_look_up.py +38 -51
- pertpy/metadata/_metadata.py +7 -10
- pertpy/metadata/_moa.py +2 -6
- pertpy/plot/__init__.py +0 -5
- pertpy/preprocessing/__init__.py +2 -0
- pertpy/preprocessing/_guide_rna.py +2 -3
- pertpy/tools/__init__.py +42 -4
- pertpy/tools/_augur.py +14 -15
- pertpy/tools/_cinemaot.py +2 -2
- pertpy/tools/_coda/_base_coda.py +118 -142
- pertpy/tools/_coda/_sccoda.py +16 -15
- pertpy/tools/_coda/_tasccoda.py +21 -22
- pertpy/tools/_dialogue.py +18 -23
- pertpy/tools/_differential_gene_expression/__init__.py +20 -0
- pertpy/tools/_differential_gene_expression/_base.py +657 -0
- pertpy/tools/_differential_gene_expression/_checks.py +41 -0
- pertpy/tools/_differential_gene_expression/_dge_comparison.py +86 -0
- pertpy/tools/_differential_gene_expression/_edger.py +125 -0
- pertpy/tools/_differential_gene_expression/_formulaic.py +189 -0
- pertpy/tools/_differential_gene_expression/_pydeseq2.py +95 -0
- pertpy/tools/_differential_gene_expression/_simple_tests.py +162 -0
- pertpy/tools/_differential_gene_expression/_statsmodels.py +72 -0
- pertpy/tools/_distances/_distance_tests.py +21 -16
- pertpy/tools/_distances/_distances.py +406 -70
- pertpy/tools/_enrichment.py +10 -15
- pertpy/tools/_kernel_pca.py +1 -1
- pertpy/tools/_milo.py +76 -53
- pertpy/tools/_mixscape.py +15 -11
- pertpy/tools/_perturbation_space/_clustering.py +5 -2
- pertpy/tools/_perturbation_space/_comparison.py +112 -0
- pertpy/tools/_perturbation_space/_discriminator_classifiers.py +20 -22
- pertpy/tools/_perturbation_space/_perturbation_space.py +23 -21
- pertpy/tools/_perturbation_space/_simple.py +3 -3
- pertpy/tools/_scgen/__init__.py +1 -1
- pertpy/tools/_scgen/_base_components.py +2 -3
- pertpy/tools/_scgen/_scgen.py +33 -28
- pertpy/tools/_scgen/_utils.py +2 -2
- {pertpy-0.7.0.dist-info → pertpy-0.8.0.dist-info}/METADATA +22 -13
- pertpy-0.8.0.dist-info/RECORD +57 -0
- {pertpy-0.7.0.dist-info → pertpy-0.8.0.dist-info}/WHEEL +1 -1
- pertpy/plot/_augur.py +0 -171
- pertpy/plot/_coda.py +0 -601
- pertpy/plot/_guide_rna.py +0 -64
- pertpy/plot/_milopy.py +0 -209
- pertpy/plot/_mixscape.py +0 -355
- pertpy/tools/_differential_gene_expression.py +0 -325
- pertpy-0.7.0.dist-info/RECORD +0 -53
- {pertpy-0.7.0.dist-info → pertpy-0.8.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,657 @@
|
|
1
|
+
import os
|
2
|
+
from abc import ABC, abstractmethod
|
3
|
+
from dataclasses import dataclass
|
4
|
+
from itertools import chain
|
5
|
+
from types import MappingProxyType
|
6
|
+
|
7
|
+
import adjustText
|
8
|
+
import anndata as ad
|
9
|
+
import matplotlib.patheffects as PathEffects
|
10
|
+
import matplotlib.pyplot as plt
|
11
|
+
import numpy as np
|
12
|
+
import pandas as pd
|
13
|
+
import seaborn as sns
|
14
|
+
from matplotlib.ticker import MaxNLocator
|
15
|
+
|
16
|
+
from pertpy.tools._differential_gene_expression._checks import check_is_numeric_matrix
|
17
|
+
from pertpy.tools._differential_gene_expression._formulaic import (
|
18
|
+
AmbiguousAttributeError,
|
19
|
+
Factor,
|
20
|
+
get_factor_storage_and_materializer,
|
21
|
+
resolve_ambiguous,
|
22
|
+
)
|
23
|
+
|
24
|
+
|
25
|
+
@dataclass
|
26
|
+
class Contrast:
|
27
|
+
"""Simple contrast for comparison between groups"""
|
28
|
+
|
29
|
+
column: str
|
30
|
+
baseline: str
|
31
|
+
group_to_compare: str
|
32
|
+
|
33
|
+
|
34
|
+
ContrastType = Contrast | tuple[str, str, str]
|
35
|
+
|
36
|
+
|
37
|
+
class MethodBase(ABC):
    """Abstract base for differential-expression methods that operate on an AnnData object."""

    def __init__(self, adata, *, mask=None, layer=None, **kwargs):
        """
        Initialize the method.

        Args:
            adata: AnnData object, usually pseudobulked.
            mask: A column in `adata.var` that contains a boolean mask with selected features.
            layer: Layer to use in fit(). If None, use the X array.
            **kwargs: Keyword arguments specific to the method implementation (not used by this base class).

        Raises:
            ValueError: If the selected data matrix is not numeric or contains NaN/Inf values.
        """
        self.adata = adata
        if mask is not None:
            # Keep only the features (columns) flagged in `adata.var[mask]`.
            self.adata = self.adata[:, self.adata.var[mask]]

        self.layer = layer
        check_is_numeric_matrix(self.data)

    @property
    def data(self):
        """Get the data matrix from anndata this object was initialized with (X or layer)."""
        if self.layer is None:
            return self.adata.X
        # AnnData exposes layers through the `.layers` mapping; there is no `.layer` attribute.
        return self.adata.layers[self.layer]

    @classmethod
    @abstractmethod
    def compare_groups(
        cls,
        adata,
        column,
        baseline,
        groups_to_compare,
        *,
        paired_by=None,
        mask=None,
        layer=None,
        fit_kwargs=MappingProxyType({}),
        test_kwargs=MappingProxyType({}),
    ):
        """
        Compare between groups in a specified column.

        Args:
            adata: AnnData object.
            column: column in obs that contains the grouping information.
            baseline: baseline value (one category from variable).
            groups_to_compare: One or multiple categories from variable to compare against baseline.
            paired_by: Column from `obs` that contains information about paired sample (e.g. subject_id).
            mask: Subset anndata by a boolean mask stored in this column in `.obs` before making any tests.
            layer: Use this layer instead of `.X`.
            fit_kwargs: Additional fit options.
            test_kwargs: Additional test options.

        Returns:
            Pandas dataframe with results ordered by significance. If multiple comparisons were performed this is indicated in an additional column.
        """
        ...

    def plot_volcano(
        self,
        data: pd.DataFrame | ad.AnnData,
        *,
        log2fc_col: str = "log_fc",
        pvalue_col: str = "adj_p_value",
        symbol_col: str = "variable",
        pval_thresh: float = 0.05,
        log2fc_thresh: float = 0.75,
        to_label: int | list[str] = 5,
        s_curve: bool | None = False,
        colors: list[str] | None = None,
        varm_key: str | None = None,
        color_dict: dict[str, list[str]] | None = None,
        shape_dict: dict[str, list[str]] | None = None,
        size_col: str | None = None,
        fontsize: int = 10,
        top_right_frame: bool = False,
        figsize: tuple[int, int] = (5, 5),
        legend_pos: tuple[float, float] = (1.6, 1),
        point_sizes: tuple[int, int] = (15, 150),
        save: bool | str | None = None,
        shapes: list[str] | None = None,
        shape_order: list[str] | None = None,
        x_label: str | None = None,
        y_label: str | None = None,
        **kwargs,
    ) -> None:
        """Creates a volcano plot from a pandas DataFrame or Anndata.

        Args:
            data: DataFrame or Anndata to plot (AnnData input is not implemented yet).
            log2fc_col: Column name of log2 Fold-Change values.
            pvalue_col: Column name of the p values.
            symbol_col: Column name of gene IDs.
            varm_key: Key in Anndata.varm slot to use for plotting if an Anndata object was passed.
            size_col: Column name to size points by.
            point_sizes: Lower and upper bounds of point sizes.
            pval_thresh: Threshold p value for significance.
            log2fc_thresh: Threshold for log2 fold change significance.
            to_label: Number of top genes or list of genes to label.
            s_curve: Whether to use a reciprocal threshold for up and down gene determination.
            color_dict: Dictionary for coloring dots by categories.
            shape_dict: Dictionary for shaping dots by categories.
            fontsize: Size of gene labels.
            colors: Colors for [non-DE, up, down] genes. Defaults to ['gray', '#D62728', '#1F77B4'].
                Only used when `color_dict` is None; cycled if fewer than three colors are given.
            top_right_frame: Whether to show the top and right frame of the plot.
            figsize: Size of the figure.
            legend_pos: Position of the legend as determined by matplotlib.
            save: If a string, saves the plot to `<save>.png` and `<save>.svg`. If True, saves it to the
                first unused `volcano_XX` prefix (XX = 00..99) in the current working directory.
            shapes: List of matplotlib marker ids, assigned in order to "other" and then to the
                `shape_dict` categories (cycled if there are more categories than markers).
            shape_order: Order of categories for shapes. Overridden when `shape_dict` is given.
            x_label: Label for the x-axis.
            y_label: Label for the y-axis.
            **kwargs: Additional arguments for seaborn.scatterplot.

        Raises:
            ValueError: If an AnnData object is passed without `varm_key`.
            NotImplementedError: If an AnnData object is passed.
        """
        if colors is None:
            colors = ["gray", "#D62728", "#1F77B4"]

        def _pval_reciprocal(lfc: float) -> float:
            """Reciprocal relation between -log10(p value) and log fold change, used for the S-curve.

            Reads `pval_thresh` from the enclosing scope at call time, i.e. after it has been
            converted to the -log10 scale below.
            """
            return pval_thresh / (lfc - log2fc_thresh)

        def _map_shape(symbol: str) -> str:
            """Return the first `shape_dict` category that lists `symbol`, or "other"."""
            if shape_dict is not None:
                for category, genes in shape_dict.items():
                    if genes is not None and symbol in genes:
                        return category
            return "other"

        # TODO join the two mapping functions
        def _map_genes_categories(
            row: pd.Series,
            log2fc_col: str,
            nlog10_col: str,
            log2fc_thresh: float,
            pval_thresh: float = None,
            s_curve: bool = False,
        ) -> str:
            """
            Map genes to categorize based on log2fc and pvalue.

            These categories are used for coloring the dots.
            Used when no color_dict is passed, sets up/down/nonsignificant.
            """
            log2fc = row[log2fc_col]
            nlog10 = row[nlog10_col]

            if s_curve:
                # Points inside the fold-change band can never be Up/Down; checking this first avoids
                # evaluating the reciprocal at |log2fc| == log2fc_thresh (division by zero).
                if abs(log2fc) <= log2fc_thresh:
                    return "not DE"
                reciprocal_thresh = _pval_reciprocal(abs(log2fc))
                if log2fc > log2fc_thresh and nlog10 > reciprocal_thresh:
                    return "Up"
                if log2fc < -log2fc_thresh and nlog10 > reciprocal_thresh:
                    return "Down"
                return "not DE"

            # Standard condition for Up or Down categorization
            if log2fc > log2fc_thresh and nlog10 > pval_thresh:
                return "Up"
            if log2fc < -log2fc_thresh and nlog10 > pval_thresh:
                return "Down"
            return "not DE"

        def _map_genes_categories_highlight(
            row: pd.Series,
            log2fc_col: str,
            nlog10_col: str,
            log2fc_thresh: float,
            pval_thresh: float = None,
            s_curve: bool = False,
            symbol_col: str = None,
        ) -> str:
            """
            Map genes to categorize based on log2fc and pvalue.

            These categories are used for coloring the dots.
            Used when color_dict is passed, sets DE / not DE for background and user supplied highlight genes.
            """
            log2fc = row[log2fc_col]
            nlog10 = row[nlog10_col]
            symbol = row[symbol_col]

            if color_dict is not None:
                for category, genes in color_dict.items():
                    if symbol in genes:
                        return category

            if s_curve:
                # Fold-change check first so the reciprocal is only evaluated strictly outside the band.
                if abs(log2fc) > log2fc_thresh and nlog10 > _pval_reciprocal(abs(log2fc)):
                    return "DE"
                return "not DE"

            # Use standard condition for filtering DE
            if abs(log2fc) < log2fc_thresh or nlog10 < pval_thresh:
                return "not DE"
            return "DE"

        if isinstance(data, ad.AnnData):
            if varm_key is None:
                raise ValueError("Please pass a .varm key to use for plotting")
            # TODO: support AnnData input, e.g. by plotting `data.varm[varm_key]`.
            raise NotImplementedError("Anndata not implemented yet")

        df = data.copy(deep=True)

        # Drop incomplete rows and replace p = 0, which would otherwise yield infinite -log10 values.
        if df[[log2fc_col, pvalue_col]].isnull().values.any():
            print("NaNs encountered, dropping rows with NaNs")
            df = df.dropna(subset=[log2fc_col, pvalue_col])

        if df[pvalue_col].min() == 0:
            print("0s encountered for p value, replacing with 1e-323")
            df.loc[df[pvalue_col] == 0, pvalue_col] = 1e-323

        # convert p value threshold to nlog10
        pval_thresh = -np.log10(pval_thresh)
        # make nlog10 column
        df["nlog10"] = -np.log10(df[pvalue_col])
        y_max = df["nlog10"].max() + 1
        # make a column to pick top genes
        df["top_genes"] = df["nlog10"] * df[log2fc_col]

        # Label everything with assigned color / shape
        if shape_dict or color_dict:
            combined_labels = []
            if isinstance(shape_dict, dict):
                combined_labels.extend([item for sublist in shape_dict.values() for item in sublist])
            if isinstance(color_dict, dict):
                combined_labels.extend([item for sublist in color_dict.values() for item in sublist])
            label_df = df[df[symbol_col].isin(combined_labels)]

        # Label top n_gens
        elif isinstance(to_label, int):
            ranked = df.sort_values("top_genes")
            label_df = pd.concat((ranked[-to_label:], ranked[0:to_label]))

        # assume that a list of genes was passed to label
        else:
            label_df = df[df[symbol_col].isin(to_label)]

        # By default color by up/down; with a color_dict use a subtle DE / not DE background plus highlights.
        if color_dict is None:
            df["color"] = df.apply(
                lambda row: _map_genes_categories(
                    row,
                    log2fc_col=log2fc_col,
                    nlog10_col="nlog10",
                    log2fc_thresh=log2fc_thresh,
                    pval_thresh=pval_thresh,
                    s_curve=s_curve,
                ),
                axis=1,
            )
            color_levels = ["not DE", "Up", "Down"]
            level_colors = colors
        else:
            df["color"] = df.apply(
                lambda row: _map_genes_categories_highlight(
                    row,
                    log2fc_col=log2fc_col,
                    nlog10_col="nlog10",
                    log2fc_thresh=log2fc_thresh,
                    pval_thresh=pval_thresh,
                    symbol_col=symbol_col,
                    s_curve=s_curve,
                ),
                axis=1,
            )
            user_added_cats = [x for x in df["color"].unique() if x not in ["DE", "not DE"]]
            color_levels = ["DE", "not DE"] + user_added_cats
            level_colors = [
                "dimgrey",
                "lightgrey",
                "tab:blue",
                "tab:orange",
                "tab:green",
                "tab:red",
                "tab:purple",
                "tab:brown",
                "tab:pink",
                "tab:olive",
                "tab:cyan",
            ]

        # Colors are assigned by position in the full level order (so e.g. "Down" stays blue even when no
        # gene is "Up"); hue order and palette are then restricted to the levels actually present.
        # Truncating by the *number* of present levels instead dropped whole categories from the plot.
        present_colors = set(df["color"].unique())
        hues = [lvl for lvl in color_levels if lvl in present_colors]
        palette = {
            lvl: level_colors[i % len(level_colors)] for i, lvl in enumerate(color_levels) if lvl in present_colors
        }

        # map shapes if dictionary exists (same positional assignment as for colors)
        if shape_dict is not None:
            df["shape"] = df[symbol_col].map(_map_shape)
            present_shapes = set(df["shape"].unique())
            shape_levels = ["other"] + [x for x in df["shape"].unique() if x != "other"]
            if shapes is None:
                shapes = ["o", "^", "s", "X", "*", "d"]
            shape_order = [lvl for lvl in shape_levels if lvl in present_shapes]
            markers = {
                lvl: shapes[i % len(shapes)] for i, lvl in enumerate(shape_levels) if lvl in present_shapes
            }
            shape_col = "shape"
        else:
            markers = shapes
            shape_col = None

        # We want plot highlighted genes on top + at bigger size, split dataframe
        df_highlight = None
        if shape_dict or color_dict:
            label_genes = label_df[symbol_col].unique()
            is_highlighted = df[symbol_col].isin(label_genes)
            df_highlight = df[is_highlighted]
            df = df[~is_highlighted]

        scatter_kwargs = {
            "x": log2fc_col,
            "y": "nlog10",
            "hue": "color",
            "hue_order": hues,
            "palette": palette,
            "size": size_col,
            "sizes": point_sizes,
            "style": shape_col,
            "style_order": shape_order,
            "markers": markers,
        }

        plt.figure(figsize=figsize)
        # Plot non-highlighted genes
        ax = sns.scatterplot(data=df, **scatter_kwargs, **kwargs)
        # Plot highlighted genes
        if df_highlight is not None:
            ax = sns.scatterplot(
                data=df_highlight,
                **scatter_kwargs,
                legend=False,
                edgecolor="black",
                linewidth=1,
                **kwargs,
            )

        # plot vertical and horizontal lines
        if s_curve:
            x = np.arange((log2fc_thresh + 0.000001), y_max, 0.01)
            y = _pval_reciprocal(x)
            ax.plot(x, y, zorder=1, c="k", lw=2, ls="--")
            ax.plot(-x, y, zorder=1, c="k", lw=2, ls="--")
        else:
            ax.axhline(pval_thresh, zorder=1, c="k", lw=2, ls="--")
            ax.axvline(log2fc_thresh, zorder=1, c="k", lw=2, ls="--")
            ax.axvline(log2fc_thresh * -1, zorder=1, c="k", lw=2, ls="--")
        plt.ylim(0, y_max)
        ax.xaxis.set_major_locator(MaxNLocator(integer=True))

        # make labels
        texts = []
        for _, gene in label_df.iterrows():
            txt = plt.text(
                x=gene[log2fc_col],
                y=gene["nlog10"],
                s=gene[symbol_col],
                fontsize=fontsize,
            )
            txt.set_path_effects([PathEffects.withStroke(linewidth=3, foreground="w")])
            texts.append(txt)

        adjustText.adjust_text(texts, arrowprops={"arrowstyle": "-", "color": "k", "zorder": 5})

        # make things pretty
        for spine in ["bottom", "left", "top", "right"]:
            ax.spines[spine].set_linewidth(2)

        if not top_right_frame:
            ax.spines["top"].set_visible(False)
            ax.spines["right"].set_visible(False)

        ax.tick_params(width=2)
        # `size` and `fontsize` are aliases of the same text property; pass only one.
        plt.xticks(fontsize=10)
        plt.yticks(size=11)

        # Set default axis titles
        if x_label is None:
            x_label = log2fc_col
        if y_label is None:
            y_label = f"-$log_{{10}}$ {pvalue_col}"

        plt.xlabel(x_label, size=15)
        plt.ylabel(y_label, size=15)

        plt.legend(loc=1, bbox_to_anchor=legend_pos, frameon=False)

        # TODO replace with scanpy save style
        if save:
            # A path string is truthy too, so it must be handled before the auto-numbering branch.
            if isinstance(save, str):
                plt.savefig(save + ".png", dpi=300, bbox_inches="tight")
                plt.savefig(save + ".svg", bbox_inches="tight")
            else:
                existing = os.listdir()
                for idx in range(100):
                    file_pref = f"volcano_{idx:02d}"
                    if not any(name.startswith(file_pref) for name in existing):
                        plt.savefig(file_pref + ".png", dpi=300, bbox_inches="tight")
                        plt.savefig(file_pref + ".svg", bbox_inches="tight")
                        break

        plt.show()
|
466
|
+
|
467
|
+
|
468
|
+
class LinearModelBase(MethodBase):
    """Base for DE methods that fit a linear model given as a design matrix or a formulaic formula."""

    def __init__(self, adata, design, *, mask=None, layer=None, **kwargs):
        """
        Initialize the method.

        Args:
            adata: AnnData object, usually pseudobulked.
            design: Model design. Either a ready-made design matrix or a formulaic formula string
                in the format 'x + z' or '~x+z'.
            mask: A column in adata.var that contains a boolean mask with selected features.
            layer: Layer to use in fit(). If None, use the X array.
            **kwargs: Keyword arguments specific to the method implementation.
        """
        super().__init__(adata, mask=mask, layer=layer)
        self._check_counts()

        self.factor_storage = None
        self.variable_to_factors = None

        if not isinstance(design, str):
            # A pre-built design matrix is used as-is; no factor metadata is available then.
            self.design = design
            return

        # Formula strings are materialized against `adata.obs`; factor metadata is recorded so that
        # `cond` / `contrast` can later build contrast vectors from category names.
        storage, var_to_factors, materializer_class = get_factor_storage_and_materializer()
        self.factor_storage = storage
        self.variable_to_factors = var_to_factors
        materializer = materializer_class(adata.obs, record_factor_metadata=True)
        self.design = materializer.get_model_matrix(design)

    @classmethod
    def compare_groups(
        cls,
        adata,
        column,
        baseline,
        groups_to_compare,
        *,
        paired_by=None,
        mask=None,
        layer=None,
        fit_kwargs=MappingProxyType({}),
        test_kwargs=MappingProxyType({}),
    ):
        formula = f"~{column}" if paired_by is None else f"~{column}+{paired_by}"
        groups = [groups_to_compare] if isinstance(groups_to_compare, str) else groups_to_compare
        model = cls(adata, design=formula, mask=mask, layer=layer)

        model.fit(**fit_kwargs)

        # One contrast per requested group, each against the same baseline.
        contrasts = {
            group: model.contrast(column=column, baseline=baseline, group_to_compare=group) for group in groups
        }
        return model.test_contrasts(contrasts, **test_kwargs)

    @property
    def variables(self):
        """Names of the data variables referenced by the model formula."""
        try:
            return self.design.model_spec.variables_by_source["data"]
        except AttributeError:
            raise ValueError(
                "Retrieving variables is only possible if the model was initialized using a formula."
            ) from None

    @abstractmethod
    def _check_counts(self):
        """
        Validate that the data matrix meets the expectations of the specific method.

        Raises:
            ValueError: if the data matrix does not comply with the expectations.
        """
        ...

    @abstractmethod
    def fit(self, **kwargs):
        """
        Fit the model.

        Args:
            **kwargs: Additional arguments for fitting the specific method.
        """
        ...

    @abstractmethod
    def _test_single_contrast(self, contrast, **kwargs): ...

    def test_contrasts(self, contrasts, **kwargs):
        """
        Perform a comparison as specified in a contrast vector.

        Args:
            contrasts: Either a numeric contrast vector, or a dictionary of numeric contrast vectors.
            **kwargs: passed to the respective implementation.

        Returns:
            A dataframe with the results; a `contrast` column holds the dictionary key
            (None when a single vector was passed).
        """
        named = contrasts if isinstance(contrasts, dict) else {None: contrasts}
        return pd.concat(
            [self._test_single_contrast(vector, **kwargs).assign(contrast=label) for label, vector in named.items()]
        )

    def test_reduced(self, modelB):
        """
        Test against a reduced model.

        Args:
            modelB: the reduced model against which to test.

        Example:
            modelA = Model().fit()
            modelB = Model().fit()
            modelA.test_reduced(modelB)
        """
        raise NotImplementedError

    def cond(self, **kwargs):
        """
        Get a contrast vector representing a specific condition.

        Variables of the model that are not given fall back to a default value
        (the base category for categorical variables, 0 otherwise).

        Args:
            **kwargs: column/value pairs.

        Returns:
            A contrast vector that aligns to the columns of the design matrix.

        Raises:
            RuntimeError: If the model was not specified via a formula.
            ValueError: If an unknown variable or category is given.
        """
        if self.factor_storage is None:
            raise RuntimeError(
                "Building contrasts with `cond` only works if you specified the model using a formulaic formula. Please manually provide a contrast vector."
            )
        known = self.variables
        if not set(kwargs.keys()).issubset(known):
            raise ValueError(
                "You specified a variable that is not part of the model. Available variables: " + ",".join(known)
            )
        for name in known:
            if name in kwargs:
                self._check_category(name, kwargs[name])
            else:
                kwargs[name] = self._get_default_value(name)
        one_row = pd.DataFrame([kwargs])
        return self.design.model_spec.get_model_matrix(one_row).iloc[0]

    def _get_factor_metadata_for_variable(self, var):
        # Flatten the metadata entries of every factor that `var` contributes to.
        return [meta for factor in self.variable_to_factors[var] for meta in self.factor_storage[factor]]

    def _get_default_value(self, var):
        metadata = self._get_factor_metadata_for_variable(var)
        if resolve_ambiguous(metadata, "kind") != Factor.Kind.CATEGORICAL:
            return 0
        try:
            base_category = resolve_ambiguous(metadata, "base")
        except AmbiguousAttributeError as e:
            raise ValueError(
                f"Could not automatically resolve base category for variable {var}. Please specify it explicity in `model.cond`."
            ) from e
        # NOTE(review): "\0" looks like a placeholder used when no base category was recorded — confirm.
        return base_category if base_category is not None else "\0"

    def _check_category(self, var, value):
        metadata = self._get_factor_metadata_for_variable(var)
        known_categories = resolve_ambiguous(metadata, "categories")
        is_categorical = resolve_ambiguous(metadata, "kind") == Factor.Kind.CATEGORICAL
        if is_categorical and value not in known_categories:
            raise ValueError(
                f"You specified a non-existant category for {var}. Possible categories: {', '.join(known_categories)}"
            )

    def contrast(self, column, baseline, group_to_compare):
        """
        Build a simple contrast for pairwise comparisons.

        Args:
            column: column in adata.obs to test on.
            baseline: baseline category (denominator).
            group_to_compare: category to compare against baseline (numerator).

        Returns:
            Numeric contrast vector: cond(group_to_compare) minus cond(baseline).
        """
        compared = self.cond(**{column: group_to_compare})
        reference = self.cond(**{column: baseline})
        return compared - reference
|
@@ -0,0 +1,41 @@
|
|
1
|
+
import numpy as np
|
2
|
+
from scipy.sparse import issparse, spmatrix
|
3
|
+
|
4
|
+
|
5
|
+
def check_is_numeric_matrix(array: np.ndarray | spmatrix) -> None:
|
6
|
+
"""Check if a matrix is numeric and only contains finite/non-NA values.
|
7
|
+
|
8
|
+
Args:
|
9
|
+
array: Dense or sparse matrix to check.
|
10
|
+
|
11
|
+
Raises:
|
12
|
+
ValueError: If the matrix is not numeric or contains NaNs or infinite values.
|
13
|
+
"""
|
14
|
+
if not np.issubdtype(array.dtype, np.number):
|
15
|
+
raise ValueError("Counts must be numeric.")
|
16
|
+
if issparse(array):
|
17
|
+
if np.any(~np.isfinite(array.data)):
|
18
|
+
raise ValueError("Counts cannot contain negative, NaN or Inf values.")
|
19
|
+
else:
|
20
|
+
if np.any(~np.isfinite(array)):
|
21
|
+
raise ValueError("Counts cannot contain negative, NaN or Inf values.")
|
22
|
+
|
23
|
+
|
24
|
+
def check_is_integer_matrix(array: np.ndarray | spmatrix, tolerance: float = 1e-6) -> None:
|
25
|
+
"""Check if a matrix container integers, or floats that are close to integers.
|
26
|
+
|
27
|
+
Args:
|
28
|
+
array: Dense or sparse matrix to check.
|
29
|
+
tolerance: Values must be this close to integers.
|
30
|
+
|
31
|
+
Raises:
|
32
|
+
ValueError: If the matrix contains values that are not close to integers.
|
33
|
+
"""
|
34
|
+
if issparse(array):
|
35
|
+
if not array.data.dtype.kind == "i" and not np.all(np.abs(array.data - np.round(array.data)) < tolerance):
|
36
|
+
raise ValueError("Non-zero elements of the matrix must be close to integer values.")
|
37
|
+
else:
|
38
|
+
if not array.dtype.kind == "i" and not np.all(np.abs(array - np.round(array)) < tolerance):
|
39
|
+
raise ValueError("Matrix must be a count matrix.")
|
40
|
+
if (array < 0).sum() > 0:
|
41
|
+
raise ValueError("Non-zero elements of the matrix must be positive.")
|