pertpy 0.7.0__py3-none-any.whl → 0.9.1__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- pertpy/__init__.py +2 -1
- pertpy/data/__init__.py +61 -0
- pertpy/data/_dataloader.py +27 -23
- pertpy/data/_datasets.py +58 -0
- pertpy/metadata/__init__.py +2 -0
- pertpy/metadata/_cell_line.py +39 -70
- pertpy/metadata/_compound.py +3 -4
- pertpy/metadata/_drug.py +2 -6
- pertpy/metadata/_look_up.py +38 -51
- pertpy/metadata/_metadata.py +7 -10
- pertpy/metadata/_moa.py +2 -6
- pertpy/plot/__init__.py +0 -5
- pertpy/preprocessing/__init__.py +2 -0
- pertpy/preprocessing/_guide_rna.py +6 -7
- pertpy/tools/__init__.py +67 -6
- pertpy/tools/_augur.py +14 -15
- pertpy/tools/_cinemaot.py +2 -2
- pertpy/tools/_coda/_base_coda.py +118 -142
- pertpy/tools/_coda/_sccoda.py +16 -15
- pertpy/tools/_coda/_tasccoda.py +21 -22
- pertpy/tools/_dialogue.py +18 -23
- pertpy/tools/_differential_gene_expression/__init__.py +20 -0
- pertpy/tools/_differential_gene_expression/_base.py +657 -0
- pertpy/tools/_differential_gene_expression/_checks.py +41 -0
- pertpy/tools/_differential_gene_expression/_dge_comparison.py +86 -0
- pertpy/tools/_differential_gene_expression/_edger.py +125 -0
- pertpy/tools/_differential_gene_expression/_formulaic.py +189 -0
- pertpy/tools/_differential_gene_expression/_pydeseq2.py +95 -0
- pertpy/tools/_differential_gene_expression/_simple_tests.py +162 -0
- pertpy/tools/_differential_gene_expression/_statsmodels.py +72 -0
- pertpy/tools/_distances/_distance_tests.py +21 -16
- pertpy/tools/_distances/_distances.py +406 -70
- pertpy/tools/_enrichment.py +10 -15
- pertpy/tools/_kernel_pca.py +1 -1
- pertpy/tools/_milo.py +77 -54
- pertpy/tools/_mixscape.py +15 -11
- pertpy/tools/_perturbation_space/_clustering.py +5 -2
- pertpy/tools/_perturbation_space/_comparison.py +112 -0
- pertpy/tools/_perturbation_space/_discriminator_classifiers.py +21 -23
- pertpy/tools/_perturbation_space/_perturbation_space.py +23 -21
- pertpy/tools/_perturbation_space/_simple.py +3 -3
- pertpy/tools/_scgen/__init__.py +1 -1
- pertpy/tools/_scgen/_base_components.py +2 -3
- pertpy/tools/_scgen/_scgen.py +33 -28
- pertpy/tools/_scgen/_utils.py +2 -2
- {pertpy-0.7.0.dist-info → pertpy-0.9.1.dist-info}/METADATA +32 -14
- pertpy-0.9.1.dist-info/RECORD +57 -0
- {pertpy-0.7.0.dist-info → pertpy-0.9.1.dist-info}/WHEEL +1 -1
- pertpy/plot/_augur.py +0 -171
- pertpy/plot/_coda.py +0 -601
- pertpy/plot/_guide_rna.py +0 -64
- pertpy/plot/_milopy.py +0 -209
- pertpy/plot/_mixscape.py +0 -355
- pertpy/tools/_differential_gene_expression.py +0 -325
- pertpy-0.7.0.dist-info/RECORD +0 -53
- {pertpy-0.7.0.dist-info → pertpy-0.9.1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,657 @@
|
|
1
|
+
import os
|
2
|
+
from abc import ABC, abstractmethod
|
3
|
+
from dataclasses import dataclass
|
4
|
+
from itertools import chain
|
5
|
+
from types import MappingProxyType
|
6
|
+
|
7
|
+
import adjustText
|
8
|
+
import anndata as ad
|
9
|
+
import matplotlib.patheffects as PathEffects
|
10
|
+
import matplotlib.pyplot as plt
|
11
|
+
import numpy as np
|
12
|
+
import pandas as pd
|
13
|
+
import seaborn as sns
|
14
|
+
from matplotlib.ticker import MaxNLocator
|
15
|
+
|
16
|
+
from pertpy.tools._differential_gene_expression._checks import check_is_numeric_matrix
|
17
|
+
from pertpy.tools._differential_gene_expression._formulaic import (
|
18
|
+
AmbiguousAttributeError,
|
19
|
+
Factor,
|
20
|
+
get_factor_storage_and_materializer,
|
21
|
+
resolve_ambiguous,
|
22
|
+
)
|
23
|
+
|
24
|
+
|
25
|
+
@dataclass
|
26
|
+
class Contrast:
|
27
|
+
"""Simple contrast for comparison between groups"""
|
28
|
+
|
29
|
+
column: str
|
30
|
+
baseline: str
|
31
|
+
group_to_compare: str
|
32
|
+
|
33
|
+
|
34
|
+
ContrastType = Contrast | tuple[str, str, str]
|
35
|
+
|
36
|
+
|
37
|
+
class MethodBase(ABC):
|
38
|
+
def __init__(self, adata, *, mask=None, layer=None, **kwargs):
    """
    Set up the method on an AnnData object.

    Args:
        adata: AnnData object, usually pseudobulked.
        mask: Name of a column in `adata.var` holding a boolean mask; only the
            features selected by it are kept. If None, all features are used.
        layer: Layer to use in fit(). If None, use the X array.
        **kwargs: Keyword arguments specific to the method implementation
            (not used by this base initializer).
    """
    # Subset the variables (second axis) only when a mask column was given.
    self.adata = adata if mask is None else adata[:, adata.var[mask]]
    self.layer = layer

    # Validate the matrix that `self.data` resolves to (X or the chosen layer).
    check_is_numeric_matrix(self.data)
|
54
|
+
|
55
|
+
@property
def data(self):
    """Get the data matrix from anndata this object was initialized with (X or layer).

    Returns:
        `self.adata.X` if no layer was selected, otherwise the matrix stored
        under `self.layer` in `self.adata.layers`.
    """
    if self.layer is None:
        return self.adata.X
    # AnnData stores named matrices in the `.layers` mapping; it has no
    # `.layer` attribute, so the previous lookup raised AttributeError.
    return self.adata.layers[self.layer]
|
62
|
+
|
63
|
+
@classmethod
@abstractmethod
def compare_groups(
    cls,
    adata,
    column,
    baseline,
    groups_to_compare,
    *,
    paired_by=None,
    mask=None,
    layer=None,
    fit_kwargs=MappingProxyType({}),
    test_kwargs=MappingProxyType({}),
):
    """
    Compare between groups in a specified column.

    Abstract: concrete methods must provide the implementation.

    Args:
        adata: AnnData object.
        column: Column in obs that contains the grouping information.
        baseline: Baseline value (one category from variable).
        groups_to_compare: One or multiple categories from variable to compare against baseline.
        paired_by: Column from `obs` that contains information about paired sample (e.g. subject_id).
        mask: Subset anndata by a boolean mask stored in this column in `.var` before making any tests.
        layer: Use this layer instead of `.X`.
        fit_kwargs: Additional fit options.
        test_kwargs: Additional test options.

    Returns:
        Pandas dataframe with results ordered by significance. If multiple comparisons
        were performed this is indicated in an additional column.
    """
|
96
|
+
|
97
|
+
def plot_volcano(
    self,
    data: pd.DataFrame | ad.AnnData,
    *,
    log2fc_col: str = "log_fc",
    pvalue_col: str = "adj_p_value",
    symbol_col: str = "variable",
    pval_thresh: float = 0.05,
    log2fc_thresh: float = 0.75,
    to_label: int | list[str] = 5,
    s_curve: bool | None = False,
    colors: list[str] = None,
    varm_key: str | None = None,
    color_dict: dict[str, list[str]] | None = None,
    shape_dict: dict[str, list[str]] | None = None,
    size_col: str | None = None,
    fontsize: int = 10,
    top_right_frame: bool = False,
    figsize: tuple[int, int] = (5, 5),
    legend_pos: tuple[float, float] = (1.6, 1),
    point_sizes: tuple[int, int] = (15, 150),
    save: bool | str | None = None,
    shapes: list[str] | None = None,
    shape_order: list[str] | None = None,
    x_label: str | None = None,
    y_label: str | None = None,
    **kwargs: int,
) -> None:
    """Creates a volcano plot from a pandas DataFrame or Anndata.

    Args:
        data: DataFrame or Anndata to plot.
        log2fc_col: Column name of log2 Fold-Change values.
        pvalue_col: Column name of the p values.
        symbol_col: Column name of gene IDs.
        varm_key: Key in Anndata.varm slot to use for plotting if an Anndata object was passed.
        size_col: Column name to size points by.
        point_sizes: Lower and upper bounds of point sizes.
        pval_thresh: Threshold p value for significance.
        log2fc_thresh: Threshold for log2 fold change significance.
        to_label: Number of top genes or list of genes to label.
        s_curve: Whether to use a reciprocal threshold for up and down gene determination.
        color_dict: Dictionary for coloring dots by categories.
        shape_dict: Dictionary for shaping dots by categories.
        fontsize: Size of gene labels.
        colors: Colors for [non-DE, up, down] genes. Defaults to ['gray', '#D62728', '#1F77B4'].
            Ignored when `color_dict` is passed (a fixed highlight palette is used instead).
        top_right_frame: Whether to show the top and right frame of the plot.
        figsize: Size of the figure.
        legend_pos: Position of the legend as determined by matplotlib.
        save: If a string, the figure is written to `<save>.png` and `<save>.svg`.
            If True, it is written to the first unused `volcano_XX` prefix (00-99)
            in the current working directory.
        shapes: List of matplotlib marker ids.
        shape_order: Order of categories for shapes. Overridden when `shape_dict` is passed.
        x_label: Label for the x-axis.
        y_label: Label for the y-axis.
        **kwargs: Additional arguments for seaborn.scatterplot.

    Raises:
        ValueError: If an AnnData object is passed without `varm_key`.
        NotImplementedError: If an AnnData object is passed (not supported yet).
    """
    if colors is None:
        colors = ["gray", "#D62728", "#1F77B4"]

    if isinstance(data, ad.AnnData):
        if varm_key is None:
            raise ValueError("Please pass a .varm key to use for plotting")
        raise NotImplementedError("Anndata not implemented yet")

    df = data.copy(deep=True)

    # clean and replace 0s as they would lead to -inf
    if df[[log2fc_col, pvalue_col]].isnull().values.any():
        print("NaNs encountered, dropping rows with NaNs")
        df = df.dropna(subset=[log2fc_col, pvalue_col])

    if df[pvalue_col].min() == 0:
        print("0s encountered for p value, replacing with 1e-323")
        df.loc[df[pvalue_col] == 0, pvalue_col] = 1e-323

    # Work on the -log10 scale from here on (threshold and per-gene values).
    nlog10_thresh = -np.log10(pval_thresh)
    df["nlog10"] = -np.log10(df[pvalue_col])
    y_max = df["nlog10"].max() + 1
    # Score used to pick the most extreme genes in both directions.
    df["top_genes"] = df["nlog10"] * df[log2fc_col]

    def _pval_reciprocal(lfc):
        """Relate -log10(pvalue) and log fold change in a reciprocal (the S-curve)."""
        return nlog10_thresh / (lfc - log2fc_thresh)

    def _map_shape(symbol: str) -> str:
        """Return the `shape_dict` category containing `symbol`, or "other"."""
        if shape_dict is not None:
            for category, genes in shape_dict.items():
                if genes is not None and symbol in genes:
                    return category
        return "other"

    # TODO join the two mapping functions
    def _map_genes_categories(row: pd.Series) -> str:
        """Classify a gene as "Up", "Down" or "not DE" (used when no color_dict is passed)."""
        log2fc = row[log2fc_col]
        nlog10 = row["nlog10"]

        if s_curve:
            # The fold-change test comes first so the reciprocal is never
            # evaluated at |log2fc| == log2fc_thresh (division by zero).
            if log2fc > log2fc_thresh and nlog10 > _pval_reciprocal(abs(log2fc)):
                return "Up"
            if log2fc < -log2fc_thresh and nlog10 > _pval_reciprocal(abs(log2fc)):
                return "Down"
            return "not DE"

        if log2fc > log2fc_thresh and nlog10 > nlog10_thresh:
            return "Up"
        if log2fc < -log2fc_thresh and nlog10 > nlog10_thresh:
            return "Down"
        return "not DE"

    def _map_genes_categories_highlight(row: pd.Series) -> str:
        """Classify a gene as a user category from color_dict, or "DE" / "not DE" for the background."""
        log2fc = row[log2fc_col]
        nlog10 = row["nlog10"]
        symbol = row[symbol_col]

        if color_dict is not None:
            for category, genes in color_dict.items():
                if symbol in genes:
                    return category

        if s_curve:
            # Same ordering rationale as above: guard the reciprocal with the fold-change test.
            if abs(log2fc) > log2fc_thresh and nlog10 > _pval_reciprocal(abs(log2fc)):
                return "DE"
            return "not DE"

        if abs(log2fc) < log2fc_thresh or nlog10 < nlog10_thresh:
            return "not DE"
        return "DE"

    # Label everything with assigned color / shape
    if shape_dict or color_dict:
        combined_labels = []
        for mapping in (shape_dict, color_dict):
            if isinstance(mapping, dict):
                # Skip None entries, which `_map_shape` tolerates as well.
                combined_labels.extend(gene for genes in mapping.values() if genes is not None for gene in genes)
        label_df = df[df[symbol_col].isin(combined_labels)]

    # Label top n genes
    elif isinstance(to_label, int):
        ranked = df.sort_values("top_genes")
        # head/tail instead of slicing: `ranked[-0:]` would select every row for to_label=0.
        label_df = pd.concat((ranked.tail(to_label), ranked.head(to_label)))

    # assume that a list of genes was passed to label
    else:
        label_df = df[df[symbol_col].isin(to_label)]

    # By default mode colors by up/down if no dict is passed
    if color_dict is None:
        df["color"] = df.apply(_map_genes_categories, axis=1)
        hue_levels = ["not DE", "Up", "Down"]
        palette = colors
    else:
        # coloring if dictionary passed, subtle background + highlight
        df["color"] = df.apply(_map_genes_categories_highlight, axis=1)
        user_added_cats = [x for x in df["color"].unique() if x not in ["DE", "not DE"]]
        hue_levels = ["DE", "not DE"] + user_added_cats
        palette = [
            "dimgrey",
            "lightgrey",
            "tab:blue",
            "tab:orange",
            "tab:green",
            "tab:red",
            "tab:purple",
            "tab:brown",
            "tab:pink",
            "tab:olive",
            "tab:cyan",
        ]

    # Keep only the levels that actually occur, together with *their* color.
    # Truncating both lists by count dropped levels whenever an earlier level
    # was absent (e.g. "Down" when no gene is "Up").
    present_levels = set(df["color"].unique())
    kept = [i for i, level in enumerate(hue_levels) if level in present_levels]
    hues = [hue_levels[i] for i in kept]
    # Cycle through the palette if there are more levels than colors.
    colors = [palette[i % len(palette)] for i in kept]

    # map shapes if dictionary exists
    if shape_dict is not None:
        df["shape"] = df[symbol_col].map(_map_shape)
        user_added_cats = [x for x in df["shape"].unique() if x != "other"]
        shape_order = ["other"] + user_added_cats
        if shapes is None:
            shapes = ["o", "^", "s", "X", "*", "d"]
        # One marker per entry of shape_order ("other" is always part of it).
        shapes = shapes[: len(shape_order)]
        shape_col = "shape"
    else:
        shape_col = None

    # We want plot highlighted genes on top + at bigger size, split dataframe
    df_highlight = None
    if shape_dict or color_dict:
        label_genes = label_df[symbol_col].unique()
        is_highlighted = df[symbol_col].isin(label_genes)
        df_highlight = df[is_highlighted]
        df = df[~is_highlighted]

    plt.figure(figsize=figsize)
    # Plot non-highlighted genes
    ax = sns.scatterplot(
        data=df,
        x=log2fc_col,
        y="nlog10",
        hue="color",
        hue_order=hues,
        palette=colors,
        size=size_col,
        sizes=point_sizes,
        style=shape_col,
        style_order=shape_order,
        markers=shapes,
        **kwargs,
    )
    # Plot highlighted genes
    if df_highlight is not None:
        ax = sns.scatterplot(
            data=df_highlight,
            x=log2fc_col,
            y="nlog10",
            hue="color",
            hue_order=hues,
            palette=colors,
            size=size_col,
            sizes=point_sizes,
            style=shape_col,
            style_order=shape_order,
            markers=shapes,
            legend=False,
            edgecolor="black",
            linewidth=1,
            **kwargs,
        )

    # plot vertical and horizontal lines
    if s_curve:
        x = np.arange((log2fc_thresh + 0.000001), y_max, 0.01)
        y = _pval_reciprocal(x)
        ax.plot(x, y, zorder=1, c="k", lw=2, ls="--")
        ax.plot(-x, y, zorder=1, c="k", lw=2, ls="--")
    else:
        ax.axhline(nlog10_thresh, zorder=1, c="k", lw=2, ls="--")
        ax.axvline(log2fc_thresh, zorder=1, c="k", lw=2, ls="--")
        ax.axvline(log2fc_thresh * -1, zorder=1, c="k", lw=2, ls="--")
    plt.ylim(0, y_max)
    ax.xaxis.set_major_locator(MaxNLocator(integer=True))

    # make labels
    texts = []
    for x_pos, y_pos, label in zip(label_df[log2fc_col], label_df["nlog10"], label_df[symbol_col]):
        txt = plt.text(x=x_pos, y=y_pos, s=label, fontsize=fontsize)
        # White outline keeps labels readable on top of the points.
        txt.set_path_effects([PathEffects.withStroke(linewidth=3, foreground="w")])
        texts.append(txt)

    adjustText.adjust_text(texts, arrowprops={"arrowstyle": "-", "color": "k", "zorder": 5})

    # make things pretty
    for axis in ["bottom", "left", "top", "right"]:
        ax.spines[axis].set_linewidth(2)

    if not top_right_frame:
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)

    ax.tick_params(width=2)
    plt.xticks(size=11, fontsize=10)
    plt.yticks(size=11)

    # Set default axis titles
    if x_label is None:
        x_label = log2fc_col
    if y_label is None:
        y_label = f"-$log_{{10}}$ {pvalue_col}"

    plt.xlabel(x_label, size=15)
    plt.ylabel(y_label, size=15)

    plt.legend(loc=1, bbox_to_anchor=legend_pos, frameon=False)

    # TODO replace with scanpy save style
    # The string case must be tested first: a non-empty path is truthy and
    # would otherwise be treated like `save=True`.
    if isinstance(save, str):
        plt.savefig(save + ".png", dpi=300, bbox_inches="tight")
        plt.savefig(save + ".svg", bbox_inches="tight")
    elif save:
        files = os.listdir()
        for number in range(100):
            file_pref = f"volcano_{number:02d}"
            # Use the first prefix that no file in the working directory starts with.
            if not any(name.startswith(file_pref) for name in files):
                plt.savefig(file_pref + ".png", dpi=300, bbox_inches="tight")
                plt.savefig(file_pref + ".svg", bbox_inches="tight")
                break

    plt.show()
|
466
|
+
|
467
|
+
|
468
|
+
class LinearModelBase(MethodBase):
|
469
|
+
def __init__(self, adata, design, *, mask=None, layer=None, **kwargs):
    """
    Set up a linear model on an AnnData object.

    Args:
        adata: AnnData object, usually pseudobulked.
        design: Model design. Can be either a design matrix, a formulaic formula.Formulaic formula in the format 'x + z' or '~x+z'.
        mask: A column in adata.var that contains a boolean mask with selected features.
        layer: Layer to use in fit(). If None, use the X array.
        **kwargs: Keyword arguments specific to the method implementation
            (not used by this initializer).
    """
    super().__init__(adata, mask=mask, layer=layer)
    self._check_counts()

    # Factor metadata only exists when the design is given as a formula string.
    self.factor_storage = self.variable_to_factors = None

    if not isinstance(design, str):
        # Anything that is not a string is stored unchanged as the design.
        self.design = design
        return

    storage, variable_map, materializer_class = get_factor_storage_and_materializer()
    self.factor_storage = storage
    self.variable_to_factors = variable_map
    materializer = materializer_class(adata.obs, record_factor_metadata=True)
    self.design = materializer.get_model_matrix(design)
|
491
|
+
|
492
|
+
@classmethod
|
493
|
+
def compare_groups(
|
494
|
+
cls,
|
495
|
+
adata,
|
496
|
+
column,
|
497
|
+
baseline,
|
498
|
+
groups_to_compare,
|
499
|
+
*,
|
500
|
+
paired_by=None,
|
501
|
+
mask=None,
|
502
|
+
layer=None,
|
503
|
+
fit_kwargs=MappingProxyType({}),
|
504
|
+
test_kwargs=MappingProxyType({}),
|
505
|
+
):
|
506
|
+
design = f"~{column}"
|
507
|
+
if paired_by is not None:
|
508
|
+
design += f"+{paired_by}"
|
509
|
+
if isinstance(groups_to_compare, str):
|
510
|
+
groups_to_compare = [groups_to_compare]
|
511
|
+
model = cls(adata, design=design, mask=mask, layer=layer)
|
512
|
+
|
513
|
+
model.fit(**fit_kwargs)
|
514
|
+
|
515
|
+
de_res = model.test_contrasts(
|
516
|
+
{
|
517
|
+
group_to_compare: model.contrast(column=column, baseline=baseline, group_to_compare=group_to_compare)
|
518
|
+
for group_to_compare in groups_to_compare
|
519
|
+
},
|
520
|
+
**test_kwargs,
|
521
|
+
)
|
522
|
+
|
523
|
+
return de_res
|
524
|
+
|
525
|
+
@property
def variables(self):
    """Names of the variables used in the model definition.

    Raises:
        ValueError: If the design does not expose `model_spec.variables_by_source`,
            i.e. the lookup fails with an AttributeError.
    """
    try:
        model_spec = self.design.model_spec
        return model_spec.variables_by_source["data"]
    except AttributeError:
        # Suppress the AttributeError context; the message below is the relevant one.
        message = "Retrieving variables is only possible if the model was initialized using a formula."
        raise ValueError(message) from None
|
534
|
+
|
535
|
+
@abstractmethod
|
536
|
+
def _check_counts(self):
|
537
|
+
"""
|
538
|
+
Check that counts are valid for the specific method.
|
539
|
+
|
540
|
+
Raises:
|
541
|
+
ValueError: if the data matrix does not comply with the expectations.
|
542
|
+
"""
|
543
|
+
...
|
544
|
+
|
545
|
+
@abstractmethod
def fit(self, **kwargs):
    """
    Fit the model.

    Abstract: each concrete method supplies its own fitting routine.

    Args:
        **kwargs: Additional arguments for fitting the specific method.
    """
|
554
|
+
|
555
|
+
@abstractmethod
|
556
|
+
def _test_single_contrast(self, contrast, **kwargs): ...
|
557
|
+
|
558
|
+
def test_contrasts(self, contrasts, **kwargs):
    """
    Run the comparison for one contrast vector or for several named ones.

    Args:
        contrasts: Either a numeric contrast vector, or a dictionary of numeric contrast vectors.
            A single vector is treated as a dictionary with the key None.
        **kwargs: passed to the respective implementation.

    Returns:
        A dataframe with the concatenated results; the column `contrast` holds
        the dictionary key each row belongs to.
    """
    named_contrasts = contrasts if isinstance(contrasts, dict) else {None: contrasts}
    per_contrast = [
        self._test_single_contrast(vector, **kwargs).assign(contrast=label)
        for label, vector in named_contrasts.items()
    ]
    return pd.concat(per_contrast)
|
577
|
+
|
578
|
+
def test_reduced(self, modelB):
|
579
|
+
"""
|
580
|
+
Test against a reduced model.
|
581
|
+
|
582
|
+
Args:
|
583
|
+
modelB: the reduced model against which to test.
|
584
|
+
|
585
|
+
Example:
|
586
|
+
modelA = Model().fit()
|
587
|
+
modelB = Model().fit()
|
588
|
+
modelA.test_reduced(modelB)
|
589
|
+
"""
|
590
|
+
raise NotImplementedError
|
591
|
+
|
592
|
+
def cond(self, **kwargs):
    """
    Build the contrast vector that represents one specific condition.

    Variables of the model that are not given are filled in with the value
    returned by `_get_default_value`.

    Args:
        **kwargs: column/value pairs.

    Returns:
        A contrast vector that aligns to the columns of the design matrix.

    Raises:
        RuntimeError: If no factor metadata is available (`factor_storage` is None).
        ValueError: If a given variable is not part of the model.
    """
    if self.factor_storage is None:
        raise RuntimeError(
            "Building contrasts with `cond` only works if you specified the model using a formulaic formula. Please manually provide a contrast vector."
        )

    unknown = set(kwargs).difference(self.variables)
    if unknown:
        raise ValueError(
            "You specified a variable that is not part of the model. Available variables: "
            + ",".join(self.variables)
        )

    # Validate the given values and complete the rest with defaults.
    for name in self.variables:
        if name not in kwargs:
            kwargs[name] = self._get_default_value(name)
        else:
            self._check_category(name, kwargs[name])

    # A one-row frame run through the model spec yields the row of interest.
    single_row = pd.DataFrame([kwargs])
    model_matrix = self.design.model_spec.get_model_matrix(single_row)
    return model_matrix.iloc[0]
|
619
|
+
|
620
|
+
def _get_factor_metadata_for_variable(self, var):
|
621
|
+
factors = self.variable_to_factors[var]
|
622
|
+
return list(chain.from_iterable(self.factor_storage[f] for f in factors))
|
623
|
+
|
624
|
+
def _get_default_value(self, var):
    """Return the value used for `var` when a condition does not specify it.

    Non-categorical variables default to 0. Categorical variables default to
    their resolved base category, or to the string "\\0" if that base is None.

    Raises:
        ValueError: If the base category of a categorical variable is ambiguous.
    """
    factor_metadata = self._get_factor_metadata_for_variable(var)

    if resolve_ambiguous(factor_metadata, "kind") != Factor.Kind.CATEGORICAL:
        return 0

    try:
        tmp_base = resolve_ambiguous(factor_metadata, "base")
    except AmbiguousAttributeError as e:
        raise ValueError(
            f"Could not automatically resolve base category for variable {var}. Please specify it explicity in `model.cond`."
        ) from e

    if tmp_base is None:
        return "\0"
    return tmp_base
|
636
|
+
|
637
|
+
def _check_category(self, var, value):
    """Check that `value` is one of the categories of the categorical variable `var`.

    Nothing is raised for variables whose resolved kind is not categorical.

    Args:
        var: Name of the model variable.
        value: Value specified for that variable.

    Raises:
        ValueError: If `var` is categorical and `value` is not among its categories.
    """
    factor_metadata = self._get_factor_metadata_for_variable(var)
    tmp_categories = resolve_ambiguous(factor_metadata, "categories")
    if resolve_ambiguous(factor_metadata, "kind") == Factor.Kind.CATEGORICAL and value not in tmp_categories:
        # Categories are not necessarily strings (e.g. integer-coded groups);
        # convert them so building the message cannot fail with a TypeError.
        raise ValueError(
            f"You specified a non-existant category for {var}. Possible categories: {', '.join(map(str, tmp_categories))}"
        )
|
644
|
+
|
645
|
+
def contrast(self, column, baseline, group_to_compare):
    """
    Build a simple contrast for pairwise comparisons.

    The result is the condition vector of `group_to_compare` minus the
    condition vector of `baseline`, both obtained from `cond`.

    Args:
        column: column in adata.obs to test on.
        baseline: baseline category (denominator).
        group_to_compare: category to compare against baseline (nominator).

    Returns:
        Numeric contrast vector.
    """
    compared = self.cond(**{column: group_to_compare})
    reference = self.cond(**{column: baseline})
    return compared - reference
|
@@ -0,0 +1,41 @@
|
|
1
|
+
import numpy as np
|
2
|
+
from scipy.sparse import issparse, spmatrix
|
3
|
+
|
4
|
+
|
5
|
+
def check_is_numeric_matrix(array: np.ndarray | spmatrix) -> None:
|
6
|
+
"""Check if a matrix is numeric and only contains finite/non-NA values.
|
7
|
+
|
8
|
+
Args:
|
9
|
+
array: Dense or sparse matrix to check.
|
10
|
+
|
11
|
+
Raises:
|
12
|
+
ValueError: If the matrix is not numeric or contains NaNs or infinite values.
|
13
|
+
"""
|
14
|
+
if not np.issubdtype(array.dtype, np.number):
|
15
|
+
raise ValueError("Counts must be numeric.")
|
16
|
+
if issparse(array):
|
17
|
+
if np.any(~np.isfinite(array.data)):
|
18
|
+
raise ValueError("Counts cannot contain negative, NaN or Inf values.")
|
19
|
+
else:
|
20
|
+
if np.any(~np.isfinite(array)):
|
21
|
+
raise ValueError("Counts cannot contain negative, NaN or Inf values.")
|
22
|
+
|
23
|
+
|
24
|
+
def check_is_integer_matrix(array: np.ndarray | spmatrix, tolerance: float = 1e-6) -> None:
|
25
|
+
"""Check if a matrix container integers, or floats that are close to integers.
|
26
|
+
|
27
|
+
Args:
|
28
|
+
array: Dense or sparse matrix to check.
|
29
|
+
tolerance: Values must be this close to integers.
|
30
|
+
|
31
|
+
Raises:
|
32
|
+
ValueError: If the matrix contains values that are not close to integers.
|
33
|
+
"""
|
34
|
+
if issparse(array):
|
35
|
+
if not array.data.dtype.kind == "i" and not np.all(np.abs(array.data - np.round(array.data)) < tolerance):
|
36
|
+
raise ValueError("Non-zero elements of the matrix must be close to integer values.")
|
37
|
+
else:
|
38
|
+
if not array.dtype.kind == "i" and not np.all(np.abs(array - np.round(array)) < tolerance):
|
39
|
+
raise ValueError("Matrix must be a count matrix.")
|
40
|
+
if (array < 0).sum() > 0:
|
41
|
+
raise ValueError("Non-zero elements of the matrix must be positive.")
|