edgepython 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,409 @@
1
+ # This code was written by Claude (Anthropic). The project was directed by Lior Pachter.
2
+ """
3
+ Visualization functions for edgePython.
4
+
5
+ Port of edgeR's plotMD, plotBCV, plotMDS, plotSmear, plotQLDisp, maPlot, gof.
6
+ """
7
+
8
+ import numpy as np
9
+ import warnings
10
+
11
+
12
def plot_md(obj, column=None, coef=None, xlab='Average log CPM',
            ylab='log-fold-change', main=None, status=None,
            values=None, col=None, hl_col=None, **kwargs):
    """Mean-difference plot (MD plot / MA plot).

    Port of edgeR's plotMD.

    Parameters
    ----------
    obj : dict (DGEGLM, DGELRT, DGEExact, or DGEList-like)
        Source of the plotted values. The first matching key decides how
        they are obtained:

        * ``'table'``: x is the ``logCPM`` column (zeros when absent) and
          y is the ``logFC`` column (the first column when absent).
        * ``'coefficients'``: x is ``obj['AveLogCPM']`` (computed with
          ``ave_log_cpm`` when missing) and y is the selected coefficient
          column divided by ``ln 2``.
        * ``'counts'``: x is ``ave_log_cpm(obj)`` and y is the selected
          sample's log-CPM minus the per-gene mean log-CPM.
    column : int, optional
        Column of coefficients (or sample) to plot. Defaults to 0.
    coef : int, optional
        Alias for column; takes precedence when both are given.
    xlab, ylab, main : str
        Plot labels. The title is only set when ``main`` is truthy.
    status : array-like, optional
        One status value per point. Points are drawn group by group
        (groups ordered as returned by ``np.unique``) with a legend.
    col : color or list of colors, optional
        A list supplies one color per status group and is cycled when it
        is shorter than the number of groups; any other value is used as
        a single color for every group. Defaults to grey/red/blue.
    values, hl_col, **kwargs
        Accepted for signature compatibility; currently ignored.

    Returns
    -------
    (fig, ax) : the matplotlib figure and axes that were created.

    Raises
    ------
    ValueError
        If ``obj`` is not a supported dict, or ``status`` does not have
        one entry per plotted point.
    """
    import matplotlib.pyplot as plt

    if coef is not None:
        column = coef

    if not isinstance(obj, dict):
        raise ValueError("Unsupported object type for plotMD")

    # Extract x (logCPM) and y (logFC) values
    if 'table' in obj:
        tab = obj['table']
        has_cpm = 'logCPM' in tab.columns
        x = tab['logCPM'].values if has_cpm else np.zeros(len(tab))
        if 'logFC' in tab.columns:
            y = tab['logFC'].values
        else:
            y = tab.iloc[:, 0].values
    elif 'coefficients' in obj:
        from .expression import ave_log_cpm
        x = obj.get('AveLogCPM')
        if x is None:
            x = ave_log_cpm(obj)
        if column is None:
            column = 0
        coefs = np.asarray(obj['coefficients'])
        if coefs.ndim == 1:
            # A single-coefficient fit may be stored flat; treat it as one column.
            coefs = coefs[:, None]
        y = coefs[:, column] / np.log(2)
    elif 'counts' in obj:
        from .expression import ave_log_cpm, cpm
        x = ave_log_cpm(obj)
        cpm_vals = np.asarray(cpm(obj, log=True))
        if column is None:
            column = 0
        y = cpm_vals[:, column] - np.mean(cpm_vals, axis=1)
    else:
        raise ValueError("Unsupported object type for plotMD")

    # Boolean masking below needs real arrays, not lists.
    x = np.asarray(x)
    y = np.asarray(y)

    fig, ax = plt.subplots(figsize=(8, 6))

    if status is not None:
        status = np.asarray(status)
        if status.shape[0] != x.shape[0]:
            raise ValueError("status must have one entry per plotted point")
        if isinstance(col, list) and not col:
            col = None  # an empty list carries no colors; use the defaults
        default_colors = ['grey', 'red', 'blue']
        for i, level in enumerate(np.unique(status)):
            if col is None:
                color = default_colors[i % len(default_colors)]
            elif isinstance(col, list):
                color = col[i % len(col)]
            else:
                color = col
            mask = status == level
            ax.scatter(x[mask], y[mask], s=2, alpha=0.5, c=color,
                       label=str(level))
        ax.legend()
    else:
        ax.scatter(x, y, s=2, alpha=0.5, c='black')

    ax.axhline(y=0, color='red', linestyle='--', linewidth=0.5)
    ax.set_xlabel(xlab)
    ax.set_ylabel(ylab)
    if main:
        ax.set_title(main)

    plt.tight_layout()
    return fig, ax
82
+
83
+
84
def plot_bcv(y, xlab='Average log CPM', ylab='Biological coefficient of variation',
             pch=16, cex=0.2, col_common='red', col_trend='blue',
             col_tagwise='black', **kwargs):
    """Plot biological coefficient of variation.

    Port of edgeR's plotBCV. The BCV is the square root of each
    dispersion estimate found in ``y``; estimates that are missing
    (``None``) are simply not drawn.

    Parameters
    ----------
    y : DGEList-like dict
        May provide ``'tagwise.dispersion'`` (scatter),
        ``'trended.dispersion'`` (line, sorted by abundance) and
        ``'common.dispersion'`` (horizontal line). ``'AveLogCPM'`` is
        used for the x axis and computed with ``ave_log_cpm`` if missing.
    xlab, ylab : str
        Axis labels.
    cex : float
        Point-size factor; the scatter marker size is ``cex * 10``.
    col_common, col_trend, col_tagwise : color
        Colors for the three dispersion layers.
    pch, **kwargs
        Accepted for signature compatibility; currently ignored.

    Returns
    -------
    (fig, ax) : the matplotlib figure and axes that were created.
    """
    import matplotlib.pyplot as plt
    from .expression import ave_log_cpm

    alc = y.get('AveLogCPM')
    if alc is None:
        alc = ave_log_cpm(y)
    # Fancy indexing (alc[o]) below requires an array, not a list.
    alc = np.asarray(alc)

    fig, ax = plt.subplots(figsize=(8, 6))
    drew_layer = False

    # Tagwise dispersions
    tagwise = y.get('tagwise.dispersion')
    if tagwise is not None:
        ax.scatter(alc, np.sqrt(tagwise), s=cex * 10, alpha=0.3,
                   c=col_tagwise, label='Tagwise')
        drew_layer = True

    # Trended dispersion
    trended = y.get('trended.dispersion')
    if trended is not None:
        bcv_trend = np.sqrt(np.asarray(trended))
        o = np.argsort(alc)  # draw the trend left-to-right
        ax.plot(alc[o], bcv_trend[o], c=col_trend, linewidth=2, label='Trend')
        drew_layer = True

    # Common dispersion
    common = y.get('common.dispersion')
    if common is not None:
        ax.axhline(y=np.sqrt(common), color=col_common, linewidth=2,
                   label='Common')
        drew_layer = True

    ax.set_xlabel(xlab)
    ax.set_ylabel(ylab)
    if drew_layer:
        # An empty legend only produces a matplotlib warning.
        ax.legend()

    plt.tight_layout()
    return fig, ax
127
+
128
+
129
def _mds_distance_matrix(log_expr, top, gene_selection):
    """Return the symmetric sample-by-sample root-mean-square distance matrix.

    With ``'common'`` the ``top`` highest-variance genes are used for all
    pairs; with ``'pairwise'`` each pair uses its own ``top`` genes with
    the largest squared log-expression difference.
    """
    ngenes, nsamples = log_expr.shape
    keep = max(min(int(top), ngenes), 0)
    dist = np.zeros((nsamples, nsamples))
    if keep == 0:
        return dist

    if gene_selection == 'common':
        variances = np.var(log_expr, axis=1)
        chosen = np.argsort(variances)[::-1][:keep]
        log_expr = log_expr[chosen]

    for i in range(nsamples):
        for j in range(i + 1, nsamples):
            sq_diff = (log_expr[:, i] - log_expr[:, j]) ** 2
            if gene_selection == 'pairwise' and keep < ngenes:
                # Keep only the `keep` largest squared differences.
                sq_diff = np.partition(sq_diff, ngenes - keep)[ngenes - keep:]
            d = np.sqrt(np.mean(sq_diff))
            dist[i, j] = d
            dist[j, i] = d
    return dist


def plot_mds(y, top=500, labels=None, pch=None, cex=1, dim_plot=(1, 2),
             gene_selection='pairwise', xlab=None, ylab=None, main=None,
             **kwargs):
    """Multi-dimensional scaling plot.

    Port of edgeR's plotMDS. Log2 counts-per-million are computed from
    the raw column sums (``log2(count / libsize * 1e6 + 0.5)``), pairwise
    sample distances are derived from the selected genes, and classical
    MDS (eigen-decomposition of the double-centred squared distances)
    gives the plotted coordinates.

    Parameters
    ----------
    y : DGEList-like dict or ndarray
        Count data (genes x samples). For a dict, ``y['counts']`` is used
        and, when ``labels`` is None, labels come from
        ``y['samples'].index`` if that is available.
    top : int
        Number of top genes to use.
    labels : list, optional
        Sample labels; defaults to ``S1, S2, ...``.
    dim_plot : tuple
        1-based dimensions to plot.
    gene_selection : str
        ``'pairwise'`` selects the top genes separately for every pair of
        samples; ``'common'`` uses the same top-variance genes for all.
    xlab, ylab, main : str, optional
        Axis labels (default: dimension number with percent variance
        explained) and title.
    pch, cex, **kwargs
        Accepted for signature compatibility; currently ignored.

    Returns
    -------
    (fig, ax) : the matplotlib figure and axes that were created.

    Raises
    ------
    ValueError
        For an unknown ``gene_selection``, non-2-D counts, a sample with
        zero total count, a label count that differs from the number of
        samples, or a ``dim_plot`` entry outside ``1..nsamples``.
    """
    import matplotlib.pyplot as plt

    if gene_selection not in ('pairwise', 'common'):
        raise ValueError("gene_selection must be 'pairwise' or 'common'")

    if isinstance(y, dict) and 'counts' in y:
        counts = np.asarray(y['counts'], dtype=np.float64)
        if labels is None:
            # Only a pandas-like index has tolist(); a plain list's
            # .index is a method and must not be mistaken for one.
            sample_index = getattr(y.get('samples'), 'index', None)
            if hasattr(sample_index, 'tolist'):
                labels = sample_index.tolist()
    else:
        counts = np.asarray(y, dtype=np.float64)

    if counts.ndim != 2:
        raise ValueError("counts must be a 2-D genes x samples matrix")

    nsamples = counts.shape[1]
    if labels is None:
        labels = [f'S{i+1}' for i in range(nsamples)]
    labels = list(labels)
    if len(labels) != nsamples:
        raise ValueError("labels must have one entry per sample")

    d1 = dim_plot[0] - 1
    d2 = dim_plot[1] - 1
    if not (0 <= d1 < nsamples and 0 <= d2 < nsamples):
        raise ValueError("dim_plot entries must lie between 1 and the number of samples")

    # Log-CPM values
    lib_size = counts.sum(axis=0)
    if np.any(lib_size <= 0):
        raise ValueError("every sample must have a positive total count")
    log_cpm = np.log2(counts / lib_size[None, :] * 1e6 + 0.5)

    dist_mat = _mds_distance_matrix(log_cpm, top, gene_selection)

    # Classical MDS
    H = np.eye(nsamples) - np.ones((nsamples, nsamples)) / nsamples
    B = -0.5 * H @ (dist_mat ** 2) @ H
    eigvals, eigvecs = np.linalg.eigh(B)
    order = np.argsort(eigvals)[::-1]  # largest eigenvalue first
    eigvals = eigvals[order]
    eigvecs = eigvecs[:, order]

    # Negative eigenvalues are clamped to zero before taking square roots.
    lam1 = max(eigvals[d1], 0)
    lam2 = max(eigvals[d2], 0)
    x_coord = eigvecs[:, d1] * np.sqrt(lam1)
    y_coord = eigvecs[:, d2] * np.sqrt(lam2)

    # Variance explained
    total_var = np.sum(np.maximum(eigvals, 0))
    var_exp1 = lam1 / total_var * 100 if total_var > 0 else 0
    var_exp2 = lam2 / total_var * 100 if total_var > 0 else 0

    fig, ax = plt.subplots(figsize=(8, 6))
    ax.scatter(x_coord, y_coord, s=50)

    for i, label in enumerate(labels):
        ax.annotate(label, (x_coord[i], y_coord[i]),
                    textcoords="offset points", xytext=(5, 5), fontsize=8)

    if xlab is None:
        xlab = f'Dimension {dim_plot[0]} ({var_exp1:.1f}%)'
    if ylab is None:
        ylab = f'Dimension {dim_plot[1]} ({var_exp2:.1f}%)'

    ax.set_xlabel(xlab)
    ax.set_ylabel(ylab)
    if main:
        ax.set_title(main)

    plt.tight_layout()
    return fig, ax
218
+
219
+
220
def plot_smear(obj, pair=None, de_tags=None, xlab='Average logCPM',
               ylab='logFC', main='MA Plot', smooth_scatter=False,
               lowess=False, **kwargs):
    """Smear plot (MA plot for DGE).

    Port of edgeR's plotSmear. Plots ``logFC`` against ``logCPM`` from
    ``obj['table']`` and optionally highlights selected genes in red.

    Parameters
    ----------
    obj : DGEList or DGEExact-like dict
        Must contain a ``'table'`` with ``logCPM`` and ``logFC`` columns.
    de_tags : list or ndarray, optional
        Genes to highlight: integer positions, a boolean mask, or gene
        names matched against the table index. An empty selection
        highlights nothing.
    xlab, ylab, main : str
        Plot labels.
    pair, smooth_scatter, lowess, **kwargs
        Accepted for signature compatibility; currently ignored.

    Returns
    -------
    (fig, ax) : the matplotlib figure and axes that were created.

    Raises
    ------
    ValueError
        If ``obj`` is not a dict containing ``'table'``.
    """
    import matplotlib.pyplot as plt

    if not (isinstance(obj, dict) and 'table' in obj):
        raise ValueError("Object must have a 'table' attribute")

    tab = obj['table']
    x = np.asarray(tab['logCPM'].values)
    y_vals = np.asarray(tab['logFC'].values)

    fig, ax = plt.subplots(figsize=(8, 6))
    ax.scatter(x, y_vals, s=2, alpha=0.3, c='black')

    if de_tags is not None:
        selector = np.asarray(de_tags)
        if selector.size and selector.dtype.kind in 'OUS':
            # Names rather than positions: translate to a boolean mask.
            # NOTE(review): assumes the table is pandas-like with an
            # index holding the gene names - confirm against callers.
            selector = np.asarray(tab.index.isin(selector.tolist()))
        # An empty list becomes a float array, which cannot index; skip it.
        if selector.size:
            ax.scatter(x[selector], y_vals[selector], s=4, c='red', alpha=0.5)

    ax.axhline(y=0, color='blue', linestyle='--', linewidth=0.5)
    ax.set_xlabel(xlab)
    ax.set_ylabel(ylab)
    ax.set_title(main)

    plt.tight_layout()
    return fig, ax
259
+
260
+
261
def plot_ql_disp(glmfit, xlab='Average Log2 CPM',
                 ylab='Quarter-Root Mean Deviance',
                 pch=16, cex=0.2, col_shrunk='red', col_trend='blue',
                 col_raw='black', **kwargs):
    """Plot quasi-likelihood dispersions.

    Port of edgeR's plotQLDisp. Three layers are drawn on a quarter-root
    scale against average log-CPM: the raw mean deviance
    (deviance / residual df), the squeezed values ``s2.post`` and the
    prior ``s2.prior`` (a horizontal line when it has a single value,
    otherwise a curve sorted by abundance).

    Parameters
    ----------
    glmfit : dict (DGEGLM-like)
        Fitted QL GLM from glm_ql_fit(). Reads ``'s2.post'``,
        ``'s2.prior'`` and ``'AveLogCPM'`` (computed with ``ave_log_cpm``
        when missing). When ``'df.residual.zeros'`` is present it is used
        with ``'deviance'``; otherwise ``'df.residual.adj'`` /
        ``'deviance.adj'`` are preferred over ``'df.residual'`` /
        ``'deviance'``.
    xlab, ylab : str
        Axis labels.
    cex : float
        Point-size factor; the scatter marker size is ``cex * 10``.
    col_shrunk, col_trend, col_raw : color
        Colors for the squeezed, prior and raw layers.
    pch, **kwargs
        Accepted for signature compatibility; currently ignored.

    Returns
    -------
    (fig, ax) : the matplotlib figure and axes that were created.

    Raises
    ------
    ValueError
        If ``glmfit`` has no ``'s2.post'`` entry.
    """
    import matplotlib.pyplot as plt
    from .expression import ave_log_cpm

    if glmfit.get('s2.post') is None:
        raise ValueError("need to run glm_ql_fit before plot_ql_disp")

    A = glmfit.get('AveLogCPM')
    if A is None:
        A = ave_log_cpm(glmfit)
    # Fancy indexing (A[o]) below requires an array, not a list.
    A = np.asarray(A, dtype=np.float64)

    if glmfit.get('df.residual.zeros') is None:
        # Fall back lazily so a fit carrying only the '.adj' entries
        # does not fail on the unadjusted keys.
        df_residual = glmfit.get('df.residual.adj')
        if df_residual is None:
            df_residual = glmfit['df.residual']
        deviance = glmfit.get('deviance.adj')
        if deviance is None:
            deviance = glmfit['deviance']
    else:
        df_residual = glmfit['df.residual.zeros']
        deviance = glmfit['deviance']

    df_residual = np.asarray(df_residual, dtype=np.float64)
    deviance = np.asarray(deviance, dtype=np.float64)
    # Mean deviance per gene; genes without residual df are set to zero.
    s2 = np.where(df_residual < 1e-8, 0.0,
                  deviance / np.maximum(df_residual, 1e-8))

    fig, ax = plt.subplots(figsize=(8, 6))

    # Raw
    ax.scatter(A, s2 ** 0.25, s=cex * 10, alpha=0.3, c=col_raw, label='Raw')

    # Squeezed
    s2_post = np.asarray(glmfit['s2.post'])
    ax.scatter(A, s2_post ** 0.25, s=cex * 10,
               alpha=0.3, c=col_shrunk, label='Squeezed')

    # Trend
    s2_prior = np.atleast_1d(glmfit['s2.prior'])
    if len(s2_prior) == 1:
        ax.axhline(y=s2_prior[0] ** 0.25, color=col_trend, linewidth=2,
                   label='Trend')
    else:
        o = np.argsort(A)  # draw the trend left-to-right
        ax.plot(A[o], s2_prior[o] ** 0.25, c=col_trend, linewidth=2,
                label='Trend')

    ax.set_xlabel(xlab)
    ax.set_ylabel(ylab)
    ax.legend()

    plt.tight_layout()
    return fig, ax
318
+
319
+
320
def ma_plot(x, y, logFC=None, de_tags=None, smooth_scatter=False,
            xlab='A', ylab='M', main='MA Plot', **kwargs):
    """Simple MA plot.

    Scatter of ``y`` against ``x`` with a dashed zero line; selected
    points are drawn again in red.

    Parameters
    ----------
    x : array-like
        Average expression (A values).
    y : array-like
        Log fold change (M values).
    de_tags : list or ndarray, optional
        Integer positions or a boolean mask of points to highlight. An
        empty selection highlights nothing.
    xlab, ylab, main : str
        Plot labels.
    logFC, smooth_scatter, **kwargs
        Accepted for signature compatibility; currently ignored.

    Returns
    -------
    (fig, ax) : the matplotlib figure and axes that were created.
    """
    import matplotlib.pyplot as plt

    # Lists cannot be indexed by an index array; work on arrays throughout.
    x = np.asarray(x)
    y = np.asarray(y)

    fig, ax = plt.subplots(figsize=(8, 6))
    ax.scatter(x, y, s=2, alpha=0.3, c='black')

    if de_tags is not None:
        selector = np.asarray(de_tags)
        # An empty list becomes a float array, which cannot index; skip it.
        if selector.size:
            ax.scatter(x[selector], y[selector], s=4, c='red', alpha=0.5)

    ax.axhline(y=0, color='blue', linestyle='--', linewidth=0.5)
    ax.set_xlabel(xlab)
    ax.set_ylabel(ylab)
    ax.set_title(main)

    plt.tight_layout()
    return fig, ax
347
+
348
+
349
def _p_adjust(pvalues, method):
    """Adjust p-values for multiple testing, leaving non-finite entries as NaN.

    Only finite p-values are counted and adjusted. Holm, Bonferroni,
    Benjamini-Hochberg and 'none' are computed directly; any other
    method name is forwarded to statsmodels' ``multipletests``.
    """
    p = np.atleast_1d(np.asarray(pvalues, dtype=np.float64))
    adjusted = np.full(p.shape, np.nan)
    finite = np.isfinite(p)
    n = int(finite.sum())
    if n == 0:
        return adjusted

    # R's p.adjust names mapped onto the names used below / by statsmodels.
    aliases = {'BH': 'fdr_bh', 'fdr': 'fdr_bh', 'BY': 'fdr_by',
               'hochberg': 'simes-hochberg'}
    name = aliases.get(method, method)
    pf = p[finite]

    if name == 'none':
        result = pf.copy()
    elif name == 'bonferroni':
        result = np.minimum(pf * float(n), 1.0)
    elif name == 'holm':
        order = np.argsort(pf)
        stepped = np.maximum.accumulate(pf[order] * np.arange(n, 0, -1))
        result = np.empty(n)
        result[order] = np.minimum(stepped, 1.0)
    elif name == 'fdr_bh':
        order = np.argsort(pf)
        scaled = pf[order] / (np.arange(1, n + 1) / float(n))
        stepped = np.minimum.accumulate(scaled[::-1])[::-1]
        result = np.empty(n)
        result[order] = np.minimum(stepped, 1.0)
    else:
        from statsmodels.stats.multitest import multipletests
        result = multipletests(pf, method=name)[1]

    adjusted[finite] = result
    return adjusted


def gof(glmfit, pcutoff=0.1, adjust='holm', plot=True, main='Goodness of Fit',
        **kwargs):
    """Goodness of fit test for each gene.

    Port of edgeR's gof. Each gene's deviance is compared with a
    chi-square distribution on its residual degrees of freedom; genes
    whose adjusted p-value falls below ``pcutoff`` are flagged.

    Parameters
    ----------
    glmfit : dict (DGEGLM-like)
        Fitted GLM providing ``'deviance'`` and ``'df.residual'`` (the
        latter may be a scalar or one value per gene).
    pcutoff : float
        P-value cutoff.
    adjust : str
        P-value adjustment method. ``'holm'``, ``'bonferroni'``,
        ``'BH'`` / ``'fdr'`` and ``'none'`` are handled directly;
        ``'BY'`` and ``'hochberg'`` are translated to their statsmodels
        names, and any other name is passed to statsmodels unchanged.
        Genes with a non-finite p-value are excluded from the adjustment
        and are never flagged.
    plot : bool
        Whether to draw a QQ plot of the sorted deviances against
        chi-square quantiles. The figure is created through pyplot and
        is not part of the return value.
    main : str
        Title of the QQ plot.
    **kwargs
        Accepted for signature compatibility; currently ignored.

    Returns
    -------
    dict with 'gof.statistics', 'gof.pvalues', 'outlier', 'df'.
    """
    from scipy.stats import chi2

    deviance = glmfit['deviance']
    df = np.asarray(glmfit['df.residual'], dtype=np.float64)

    gof_pvalues = chi2.sf(deviance, df)

    # Adjust p-values
    adj_p = _p_adjust(gof_pvalues, adjust)
    with np.errstate(invalid='ignore'):
        # NaN never compares below the cutoff, so such genes stay unflagged.
        outlier = adj_p < pcutoff

    if plot:
        import matplotlib.pyplot as plt

        fig, ax = plt.subplots(figsize=(8, 6))
        # QQ plot
        observed = np.sort(np.asarray(deviance, dtype=np.float64).ravel())
        n = observed.size
        if n > 0:
            # NOTE: the reference quantiles use the first residual df only,
            # so the plot is exact only when all genes share one df.
            ref_df = df.ravel()[0]
            probs = np.arange(1, n + 1) / (n + 1)
            theoretical = chi2.ppf(probs, ref_df)

            ax.scatter(theoretical, observed, s=2, alpha=0.5)
            max_val = max(np.nanmax(theoretical), np.nanmax(observed))
            ax.plot([0, max_val], [0, max_val], 'r--', linewidth=0.5)
        ax.set_xlabel('Theoretical quantiles')
        ax.set_ylabel('Observed deviance')
        ax.set_title(main)
        plt.tight_layout()

    return {
        'gof.statistics': deviance,
        'gof.pvalues': gof_pvalues,
        'outlier': outlier,
        'df': df
    }