DeConveil 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- DeConveil/__init__.py +7 -0
- DeConveil/dds.py +1279 -0
- DeConveil/default_inference.py +284 -0
- DeConveil/ds.py +758 -0
- DeConveil/grid_search.py +195 -0
- DeConveil/inference.py +373 -0
- DeConveil/utils_CNaware.py +809 -0
- DeConveil-0.1.0.dist-info/LICENSE +21 -0
- DeConveil-0.1.0.dist-info/METADATA +35 -0
- DeConveil-0.1.0.dist-info/RECORD +12 -0
- DeConveil-0.1.0.dist-info/WHEEL +5 -0
- DeConveil-0.1.0.dist-info/top_level.txt +1 -0
DeConveil/utils_CNaware.py
@@ -0,0 +1,809 @@
import os
import multiprocessing
import warnings
from math import ceil
from math import floor
from pathlib import Path
from typing import List
from typing import Literal
from typing import Optional
from typing import Tuple
from typing import Union
from typing import cast

import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from scipy.linalg import solve  # type: ignore
from scipy.optimize import minimize  # type: ignore
from scipy.special import gammaln  # type: ignore
from scipy.special import polygamma  # type: ignore
from scipy.stats import norm  # type: ignore
from sklearn.linear_model import LinearRegression  # type: ignore
import seaborn as sns

from deconveil.grid_search import grid_fit_beta

from pydeseq2.utils import fit_alpha_mle
from pydeseq2.utils import get_num_processes
from pydeseq2.grid_search import grid_fit_alpha
from pydeseq2.grid_search import grid_fit_shrink_beta

def irls_glm(
    counts: np.ndarray,
    cnv: np.ndarray,
    size_factors: np.ndarray,
    design_matrix: np.ndarray,
    disp: float,
    min_mu: float = 0.5,
    beta_tol: float = 1e-8,
    min_beta: float = -30,
    max_beta: float = 30,
    optimizer: Literal["BFGS", "L-BFGS-B"] = "L-BFGS-B",
    maxiter: int = 250,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, bool]:
    """Fit a single-gene NB GLM with a multiplicative copy-number term.

    Runs iteratively reweighted least squares (IRLS) with a small ridge
    penalty; if IRLS diverges or exceeds ``maxiter``, falls back to
    ``scipy.optimize.minimize``. Returns the fitted coefficients, the
    (unthresholded) means, the hat-matrix diagonal, and a convergence flag.
    """
    assert optimizer in ["BFGS", "L-BFGS-B"]

    num_vars = design_matrix.shape[1]
    X = design_matrix

    # if full rank, estimate initial betas for IRLS below
    if np.linalg.matrix_rank(X) == num_vars:
        Q, R = np.linalg.qr(X)
        y = np.log((counts / cnv) / size_factors + 0.1)
        beta_init = solve(R, Q.T @ y)
        beta = beta_init
    else:  # Initialise intercept with log base mean
        beta_init = np.zeros(num_vars)
        beta_init[0] = np.log((counts / cnv) / size_factors).mean()
        beta = beta_init

    dev = 1000.0
    dev_ratio = 1.0

    ridge_factor = np.diag(np.repeat(1e-6, num_vars))
    mu = np.maximum(cnv * size_factors * np.exp(np.clip(X @ beta, -30, 30)), min_mu)

    converged = True
    i = 0
    while dev_ratio > beta_tol:
        W = mu / (1.0 + mu * disp)
        z = np.log((mu / cnv) / size_factors) + (counts - mu) / mu
        H = (X.T * W) @ X + ridge_factor
        beta_hat = solve(H, X.T @ (W * z), assume_a="pos")
        i += 1

        if sum(np.abs(beta_hat) > max_beta) > 0 or i >= maxiter:
            # If IRLS starts diverging, use L-BFGS-B
            def f(beta: np.ndarray) -> float:
                # closure to minimize
                mu_ = np.maximum(
                    cnv * size_factors * np.exp(np.clip(X @ beta, -30, 30)), min_mu
                )
                return nb_nll(counts, mu_, disp) + 0.5 * (ridge_factor @ beta**2).sum()

            def df(beta: np.ndarray) -> np.ndarray:
                mu_ = np.maximum(
                    cnv * size_factors * np.exp(np.clip(X @ beta, -30, 30)), min_mu
                )
                return (
                    -X.T @ counts
                    + ((1 / disp + counts) * mu_ / (1 / disp + mu_)) @ X
                    + ridge_factor @ beta
                )

            res = minimize(
                f,
                beta_init,
                jac=df,
                method=optimizer,
                bounds=(
                    [(min_beta, max_beta)] * num_vars
                    if optimizer == "L-BFGS-B"
                    else None
                ),
            )

            beta = res.x
            mu = np.maximum(cnv * size_factors * np.exp(np.clip(X @ beta, -30, 30)), min_mu)
            converged = res.success
            # Leave the loop: without this break, the IRLS update below would
            # overwrite the optimizer's result and the loop could never exit.
            break

        beta = beta_hat
        mu = np.maximum(cnv * size_factors * np.exp(np.clip(X @ beta, -30, 30)), min_mu)

        # Compute deviance
        old_dev = dev
        # Replaced deviance with -2 * nll, as in the R code
        dev = -2 * nb_nll(counts, mu, disp)
        dev_ratio = np.abs(dev - old_dev) / (np.abs(dev) + 0.1)

    # Compute H diagonal (useful for Cook distance outlier filtering)
    W = mu / (1.0 + mu * disp)
    W_sq = np.sqrt(W)
    XtWX = (X.T * W) @ X + ridge_factor
    H = W_sq * np.diag(X @ np.linalg.inv(XtWX) @ X.T) * W_sq

    # Return an UNthresholded mu
    # Previous quantities are estimated with a threshold though
    mu = np.maximum(cnv * size_factors * np.exp(np.clip(X @ beta, -30, 30)), min_mu)

    return beta, mu, H, converged
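
# Illustrative usage sketch, not part of the released file: a minimal
# single-gene call to irls_glm, assuming a two-level condition design and a
# copy-number-neutral profile (cnv = 1.0). All values below are hypothetical.
def _example_irls_glm() -> None:
    rng = np.random.default_rng(0)
    n_samples = 8
    # Intercept column plus a binary condition column
    design = np.column_stack(
        [np.ones(n_samples), np.repeat([0.0, 1.0], n_samples // 2)]
    )
    counts = rng.poisson(100, size=n_samples).astype(float)
    cnv = np.ones(n_samples)  # multiplicative CN ratio, 1.0 = diploid
    size_factors = np.ones(n_samples)
    beta, mu, h_diag, converged = irls_glm(
        counts, cnv, size_factors, design, disp=0.1
    )
    print(beta, converged)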

def fit_lin_mu(
    counts: np.ndarray,
    size_factors: np.ndarray,
    design_matrix: np.ndarray,
    min_mu: float = 0.5,
) -> np.ndarray:
    """Estimate mean of negative binomial model using a linear regression.

    Used to initialize genewise dispersion models.

    Parameters
    ----------
    counts : ndarray
        Raw counts for a given gene.

    size_factors : ndarray
        Sample-wise scaling factors (obtained from median-of-ratios).

    design_matrix : ndarray
        Design matrix.

    min_mu : float
        Lower threshold for fitted means, for numerical stability.
        (default: ``0.5``).

    Returns
    -------
    ndarray
        Estimated mean.
    """
    reg = LinearRegression(fit_intercept=False)
    reg.fit(design_matrix, counts / size_factors)
    mu_hat = size_factors * reg.predict(design_matrix)
    # Threshold mu_hat as 1/mu_hat will be used later on.
    return np.maximum(mu_hat, min_mu)
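
# Sketch with hypothetical inputs (not part of the released file): fitting
# per-condition means for one gene with unit size factors.
def _example_fit_lin_mu() -> None:
    design = np.array([[1.0, 0.0], [1.0, 0.0], [1.0, 1.0], [1.0, 1.0]])
    counts = np.array([10.0, 12.0, 30.0, 28.0])
    mu_hat = fit_lin_mu(counts, np.ones(4), design)
    print(mu_hat)  # group means, floored at min_mu=0.5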

def fit_rough_dispersions(
    normed_counts: np.ndarray, design_matrix: pd.DataFrame
) -> np.ndarray:
    """Rough dispersion estimates from linear model, as per the R code.

    Used as initial estimates in :meth:`DeseqDataSet.fit_genewise_dispersions()
    <pydeseq2.dds.DeseqDataSet.fit_genewise_dispersions>`.

    Parameters
    ----------
    normed_counts : ndarray
        Array of deseq2-normalized read counts. Rows: samples, columns: genes.

    design_matrix : pandas.DataFrame
        A DataFrame with experiment design information (to split cohorts).
        Indexed by sample barcodes. Unexpanded, *with* intercept.

    Returns
    -------
    ndarray
        Estimated dispersion parameter for each gene.
    """
    num_samples, num_vars = design_matrix.shape
    # This method is only possible when num_samples > num_vars.
    # If this is not the case, throw an error.
    if num_samples == num_vars:
        raise ValueError(
            "The number of samples and the number of design variables are "
            "equal, i.e., there are no replicates to estimate the "
            "dispersion. Please use a design with fewer variables."
        )

    reg = LinearRegression(fit_intercept=False)
    reg.fit(design_matrix, normed_counts)
    y_hat = reg.predict(design_matrix)
    y_hat = np.maximum(y_hat, 1)
    alpha_rde = (
        ((normed_counts - y_hat) ** 2 - y_hat) / ((num_samples - num_vars) * y_hat**2)
    ).sum(0)
    return np.maximum(alpha_rde, 0)
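
# Hypothetical sketch (not part of the released file): rough dispersions for
# a 6-sample, 3-gene normalized count matrix; the number of samples must
# exceed the number of design variables.
def _example_fit_rough_dispersions() -> None:
    rng = np.random.default_rng(1)
    normed = rng.poisson(50, size=(6, 3)).astype(float)
    design = pd.DataFrame(
        {"intercept": np.ones(6), "condition": [0, 0, 0, 1, 1, 1]},
        index=[f"sample{i}" for i in range(6)],
    )
    alpha_hat = fit_rough_dispersions(normed, design)  # one value per gene
    print(alpha_hat)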

def fit_moments_dispersions2(
    normed_counts: np.ndarray, size_factors: np.ndarray
) -> np.ndarray:
    """Dispersion estimates based on moments, as per the R code.

    Used as initial estimates in :meth:`DeseqDataSet.fit_genewise_dispersions()
    <pydeseq2.dds.DeseqDataSet.fit_genewise_dispersions>`.

    Parameters
    ----------
    normed_counts : ndarray
        Array of deseq2-normalized read counts. Rows: samples, columns: genes.

    size_factors : ndarray
        DESeq2 normalization factors.

    Returns
    -------
    ndarray
        Estimated dispersion parameter for each gene.
    """
    # Exclude genes with all zeroes
    # normed_counts = normed_counts[:, ~(normed_counts == 0).all(axis=0)]
    # mean inverse size factor
    s_mean_inv = (1 / size_factors).mean()
    mu = normed_counts.mean(0)
    sigma = normed_counts.var(0, ddof=1)
    # ddof=1 is to use an unbiased estimator, as in R
    # NaN (variance = 0) are replaced with 0s
    return np.nan_to_num((sigma - s_mean_inv * mu) / mu**2)
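
# Sketch with hypothetical values (not part of the released file):
# method-of-moments dispersion, one estimate per gene (column).
def _example_fit_moments_dispersions2() -> None:
    normed = np.array(
        [[10.0, 100.0], [14.0, 90.0], [8.0, 120.0], [12.0, 110.0]]
    )
    size_factors = np.array([0.9, 1.1, 1.0, 1.0])
    print(fit_moments_dispersions2(normed, size_factors))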

def nb_nll(
    counts: np.ndarray, mu: np.ndarray, alpha: Union[float, np.ndarray]
) -> Union[float, np.ndarray]:
    r"""Neg log-likelihood of a negative binomial of parameters ``mu`` and ``alpha``.

    Mathematically, if ``counts`` is a vector of counting entries :math:`y_i`
    then the likelihood of each entry :math:`y_i` to be drawn from a negative
    binomial :math:`NB(\mu, \alpha)` is [1]

    .. math::
        p(y_i | \mu, \alpha) = \frac{\Gamma(y_i + \alpha^{-1})}{
            \Gamma(y_i + 1)\Gamma(\alpha^{-1})
        }
        \left(\frac{1}{1 + \alpha \mu} \right)^{1/\alpha}
        \left(\frac{\mu}{\alpha^{-1} + \mu} \right)^{y_i}

    As a consequence, assuming there are :math:`n` entries,
    the total negative log-likelihood for ``counts`` is

    .. math::
        \ell(\mu, \alpha) = \frac{n}{\alpha} \log(\alpha) +
            \sum_i \left \lbrace
            - \log \left( \frac{\Gamma(y_i + \alpha^{-1})}{
            \Gamma(y_i + 1)\Gamma(\alpha^{-1})
            } \right)
            + (\alpha^{-1} + y_i) \log (\alpha^{-1} + \mu)
            - y_i \log \mu
            \right \rbrace

    This is implemented in this function.

    Parameters
    ----------
    counts : ndarray
        Observations.

    mu : ndarray
        Mean of the distribution :math:`\mu`.

    alpha : float or ndarray
        Dispersion of the distribution :math:`\alpha`,
        s.t. the variance is :math:`\mu + \alpha \mu^2`.

    Returns
    -------
    float or ndarray
        Negative log likelihood of the observations counts
        following :math:`NB(\mu, \alpha)`.

    Notes
    -----
    [1] https://en.wikipedia.org/wiki/Negative_binomial_distribution
    """
    n = len(counts)
    alpha_neg1 = 1 / alpha
    logbinom = gammaln(counts + alpha_neg1) - gammaln(counts + 1) - gammaln(alpha_neg1)
    if hasattr(alpha, "__len__") and len(alpha) > 1:
        return (
            alpha_neg1 * np.log(alpha)
            - logbinom
            + (counts + alpha_neg1) * np.log(mu + alpha_neg1)
            - (counts * np.log(mu))
        ).sum(0)
    else:
        return (
            n * alpha_neg1 * np.log(alpha)
            + (
                -logbinom
                + (counts + alpha_neg1) * np.log(alpha_neg1 + mu)
                - counts * np.log(mu)
            ).sum()
        )
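
# Hypothetical consistency check (not part of the released file): nb_nll
# should agree with scipy.stats.nbinom under the parametrization
# r = 1 / alpha, p = r / (r + mu).
def _example_nb_nll() -> None:
    from scipy.stats import nbinom  # local import for the sketch
    counts = np.array([3.0, 10.0, 0.0, 7.0])
    mu = np.array([4.0, 8.0, 2.0, 6.0])
    alpha = 0.2
    r = 1 / alpha
    ref = -nbinom.logpmf(counts, r, r / (r + mu)).sum()
    assert np.isclose(nb_nll(counts, mu, alpha), ref)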

def nbinomGLM(
    design_matrix: np.ndarray,
    counts: np.ndarray,
    cnv: np.ndarray,
    size: np.ndarray,
    offset: np.ndarray,
    prior_no_shrink_scale: float,
    prior_scale: float,
    optimizer="L-BFGS-B",
    shrink_index: int = 1,
) -> Tuple[np.ndarray, np.ndarray, bool]:
    """Fit a negative binomial MAP LFC using an apeGLM prior.

    Only the LFC is shrunk, not the intercept.

    Parameters
    ----------
    design_matrix : ndarray
        Design matrix.

    counts : ndarray
        Raw counts.

    cnv : ndarray
        Copy-number term added to the linear predictor (log scale).

    size : ndarray
        Size parameter of NB family (inverse of dispersion).

    offset : ndarray
        Natural logarithm of size factor.

    prior_no_shrink_scale : float
        Prior variance for the intercept.

    prior_scale : float
        Prior variance for the LFC parameter.

    optimizer : str
        Optimizing method to use in case IRLS starts diverging.
        Accepted values: 'L-BFGS-B', 'BFGS' or 'Newton-CG'.
        (default: ``'L-BFGS-B'``).

    shrink_index : int
        Index of the LFC coordinate to shrink. (default: ``1``).

    Returns
    -------
    beta: ndarray
        2-element array, containing the intercept (first) and the LFC (second).

    inv_hessian: ndarray
        Inverse of the Hessian of the objective at the estimated MAP LFC.

    converged: bool
        Whether L-BFGS-B converged.
    """
    num_vars = design_matrix.shape[-1]

    shrink_mask = np.zeros(num_vars)
    shrink_mask[shrink_index] = 1
    no_shrink_mask = np.ones(num_vars) - shrink_mask

    beta_init = np.ones(num_vars) * 0.1 * (-1) ** (np.arange(num_vars))

    # Set optimization scale
    scale_cnst = nbinomFn(
        np.zeros(num_vars),
        design_matrix,
        counts,
        cnv,
        size,
        offset,
        prior_no_shrink_scale,
        prior_scale,
        shrink_index,
    )
    scale_cnst = np.maximum(scale_cnst, 1)

    def f(beta: np.ndarray, cnst: float = scale_cnst) -> float:
        # Function to optimize
        return (
            nbinomFn(
                beta,
                design_matrix,
                counts,
                cnv,
                size,
                offset,
                prior_no_shrink_scale,
                prior_scale,
                shrink_index,
            )
            / cnst
        )

    def df(beta: np.ndarray, cnst: float = scale_cnst) -> np.ndarray:
        # Gradient of the function to optimize
        xbeta = design_matrix @ beta
        d_neg_prior = (
            beta * no_shrink_mask / prior_no_shrink_scale**2
            + 2 * beta * shrink_mask / (prior_scale**2 + beta[shrink_index] ** 2)
        )
        d_nll = (
            counts - (counts + size) / (1 + size * np.exp(-xbeta - offset - cnv))
        ) @ design_matrix

        return (d_neg_prior - d_nll) / cnst

    def ddf(beta: np.ndarray, cnst: float = scale_cnst) -> np.ndarray:
        # Hessian of the function to optimize
        # Note: will only work if there is a single shrink index
        xbeta = design_matrix @ beta
        exp_xbeta_off = np.exp(xbeta + offset + cnv)
        frac = (counts + size) * size * exp_xbeta_off / (size + exp_xbeta_off) ** 2
        # Build diagonal
        h11 = 1 / prior_no_shrink_scale**2
        h22 = (
            2
            * (prior_scale**2 - beta[shrink_index] ** 2)
            / (prior_scale**2 + beta[shrink_index] ** 2) ** 2
        )

        h = np.diag(no_shrink_mask * h11 + shrink_mask * h22)

        return 1 / cnst * ((design_matrix.T * frac) @ design_matrix + h)

    res = minimize(
        f,
        beta_init,
        jac=df,
        hess=ddf if optimizer == "Newton-CG" else None,
        method=optimizer,
    )

    beta = res.x
    converged = res.success

    if not converged and num_vars == 2:
        # If the solver failed, fit using grid search (slow)
        # Only for single-factor analysis
        beta = grid_fit_shrink_beta(
            counts,
            cnv,
            offset,
            design_matrix,
            size,
            prior_no_shrink_scale,
            prior_scale,
            scale_cnst,
            grid_length=60,
            min_beta=-30,
            max_beta=30,
        )

    inv_hessian = np.linalg.inv(ddf(beta, 1))

    return beta, inv_hessian, converged
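
# Illustrative single-gene shrinkage call with hypothetical values (not part
# of the released file). Note that ``cnv`` and ``offset`` enter the linear
# predictor additively here, i.e. on the log scale.
def _example_nbinomGLM() -> None:
    n = 6
    design = np.column_stack([np.ones(n), np.repeat([0.0, 1.0], n // 2)])
    counts = np.array([12.0, 9.0, 11.0, 25.0, 30.0, 27.0])
    log_cnv = np.zeros(n)  # log(1.0): copy-number neutral
    size = np.full(n, 10.0)  # 1 / dispersion
    offset = np.zeros(n)  # log size factors
    beta, inv_hessian, converged = nbinomGLM(
        design, counts, log_cnv, size, offset,
        prior_no_shrink_scale=15.0, prior_scale=1.0,
    )
    print(beta, converged)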

def nbinomFn(
    beta: np.ndarray,
    design_matrix: np.ndarray,
    counts: np.ndarray,
    cnv: np.ndarray,
    size: np.ndarray,
    offset: np.ndarray,
    prior_no_shrink_scale: float,
    prior_scale: float,
    shrink_index: int = 1,
) -> float:
    """Return the NB negative log-likelihood with apeGLM prior.

    Used for LFC shrinkage.

    Parameters
    ----------
    beta : ndarray
        2-element array: intercept and LFC coefficients.

    design_matrix : ndarray
        Design matrix.

    counts : ndarray
        Raw counts.

    cnv : ndarray
        Copy-number term added to the linear predictor (log scale).

    size : ndarray
        Size parameter of NB family (inverse of dispersion).

    offset : ndarray
        Natural logarithm of size factor.

    prior_no_shrink_scale : float
        Prior variance for the intercept.

    prior_scale : float
        Prior variance for the LFC parameter.

    shrink_index : int
        Index of the LFC coordinate to shrink. (default: ``1``).

    Returns
    -------
    float
        Sum of the NB negative log-likelihood and apeGLM prior.
    """
    num_vars = design_matrix.shape[-1]

    shrink_mask = np.zeros(num_vars)
    shrink_mask[shrink_index] = 1
    no_shrink_mask = np.ones(num_vars) - shrink_mask

    xbeta = design_matrix @ beta
    prior = (
        (beta * no_shrink_mask) ** 2 / (2 * prior_no_shrink_scale**2)
    ).sum() + np.log1p((beta[shrink_index] / prior_scale) ** 2)

    nll = (
        counts * xbeta - (counts + size) * np.logaddexp(xbeta + offset + cnv, np.log(size))
    ).sum(0)

    return prior - nll
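
# Sketch (not part of the released file): evaluating the penalized objective
# that nbinomGLM minimizes, at beta = 0 and with hypothetical inputs.
def _example_nbinomFn() -> None:
    design = np.array([[1.0, 0.0], [1.0, 1.0]])
    val = nbinomFn(
        np.zeros(2),
        design,
        counts=np.array([5.0, 20.0]),
        cnv=np.zeros(2),
        size=np.full(2, 10.0),
        offset=np.zeros(2),
        prior_no_shrink_scale=15.0,
        prior_scale=1.0,
    )
    print(val)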

def process_results(file_path, method, lfc_cut=1.0, pval_cut=0.05):
    # Read a DE results table (indexed by gene), flag genes passing both the
    # LFC and adjusted p-value cutoffs, label the direction of change, and
    # tag each row with the method name.
    df = pd.read_csv(file_path, index_col=0)
    df['isDE'] = (np.abs(df['log2FoldChange']) >= lfc_cut) & (df['padj'] <= pval_cut)
    df['DEtype'] = np.where(
        ~df['isDE'],
        "n.s.",
        np.where(df['log2FoldChange'] > 0, "Up-reg", "Down-reg")
    )
    df['method'] = method
    return df[['log2FoldChange', 'padj', 'isDE', 'DEtype', 'method']]
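
# Sketch (not part of the released file): the CSV path and method label are
# hypothetical; the file is expected to carry DESeq2-style columns
# ('log2FoldChange', 'padj') indexed by gene.
def _example_process_results() -> None:
    res = process_results("results_naive.csv", method="CN naive")
    print(res["DEtype"].value_counts())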

def define_gene_groups(res_joint):
    # Split genes by how their DE call changes between the CN-naive and
    # CN-aware results (columns 'DEtype_naive' and 'DEtype_aware').
    # DSGs: significant only in the CN-naive analysis.
    DSGs = res_joint[
        ((res_joint['DEtype_naive'] == "Up-reg") & (res_joint['DEtype_aware'] == "n.s.")) |
        ((res_joint['DEtype_naive'] == "Down-reg") & (res_joint['DEtype_aware'] == "n.s."))
    ].assign(gene_category='DSGs')

    # DIGs: significant in both analyses, with the same direction.
    DIGs = res_joint[
        ((res_joint['DEtype_naive'] == "Up-reg") & (res_joint['DEtype_aware'] == "Up-reg")) |
        ((res_joint['DEtype_naive'] == "Down-reg") & (res_joint['DEtype_aware'] == "Down-reg"))
    ].assign(gene_category='DIGs')

    # DCGs: significant only in the CN-aware analysis.
    DCGs = res_joint[
        ((res_joint['DEtype_naive'] == "n.s.") & (res_joint['DEtype_aware'] == "Up-reg")) |
        ((res_joint['DEtype_naive'] == "n.s.") & (res_joint['DEtype_aware'] == "Down-reg"))
    ].assign(gene_category='DCGs')

    # non-DEGs: significant in neither analysis.
    non_DEGs = res_joint[
        (res_joint['DEtype_naive'] == "n.s.") & (res_joint['DEtype_aware'] == "n.s.")
    ].assign(gene_category='non-DEGs')

    return {
        "DSGs": DSGs,
        "DIGs": DIGs,
        "DCGs": DCGs,
        "non_DEGs": non_DEGs
    }
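
# Sketch (not part of the released file): a hypothetical joint table, e.g.
# obtained by joining two process_results outputs with _naive/_aware suffixes.
def _example_define_gene_groups() -> None:
    res_joint = pd.DataFrame({
        "DEtype_naive": ["Up-reg", "n.s.", "Up-reg", "n.s."],
        "DEtype_aware": ["n.s.", "Up-reg", "Up-reg", "n.s."],
    })
    groups = define_gene_groups(res_joint)
    print({k: len(v) for k, v in groups.items()})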

def generate_volcano_plot(plot_data, lfc_cut=1.0, pval_cut=0.05, xlim=None, ylim=None):
    plot_data['gene_group'] = plot_data['gene_group'].astype('category')

    # Define gene group colors
    gene_group_colors = {
        "DIGs": "#8F3931FF",
        "DSGs": "#FFB977",
        "DCGs": "#FFC300"
    }

    # Create a FacetGrid for faceted plots
    g = sns.FacetGrid(
        plot_data,
        col="method",
        margin_titles=True,
        hue="gene_group",
        palette=gene_group_colors,
        sharey=False,
        sharex=True
    )

    # Add points for "DIGs"
    g.map_dataframe(
        sns.scatterplot,
        x="log2FC",
        y="-log10(padj)",
        alpha=0.1,
        size=0.5,
        legend=False,
        data=plot_data[plot_data['gene_group'].isin(["DIGs"])]
    )

    # Add points for "DSGs" and "DCGs"
    g.map_dataframe(
        sns.scatterplot,
        x="log2FC",
        y="-log10(padj)",
        alpha=1.0,
        size=3.0,
        legend=False,
        data=plot_data[plot_data['gene_group'].isin(["DSGs", "DCGs"])]
    )

    # Add vertical and horizontal dashed lines at the significance cutoffs
    for ax in g.axes.flat:
        ax.axvline(x=-lfc_cut, color="gray", linestyle="dashed")
        ax.axvline(x=lfc_cut, color="gray", linestyle="dashed")
        ax.axhline(y=-np.log10(pval_cut), color="gray", linestyle="dashed")

        if xlim:
            ax.set_xlim(xlim)
        if ylim:
            ax.set_ylim(ylim)

    # Set axis labels
    g.set_axis_labels("Log2 FC", "-Log10 P-value")

    # Add titles, legends, and customize
    g.add_legend(title="Gene category")
    g.set_titles(row_template="{row_name}", col_template="{col_name}")
    g.tight_layout()

    # Adjust font sizes for better readability
    for ax in g.axes.flat:
        ax.tick_params(axis='both', labelsize=12)
        ax.set_xlabel("Log2 FC", fontsize=14)
        ax.set_ylabel("-Log10 P-value", fontsize=14)

    # Display the plot
    plt.show()
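
# Sketch (not part of the released file): generate_volcano_plot expects one
# row per gene with 'log2FC', '-log10(padj)', 'gene_group', and 'method'
# columns; all values below are hypothetical.
def _example_generate_volcano_plot() -> None:
    rng = np.random.default_rng(3)
    n = 200
    plot_data = pd.DataFrame({
        "log2FC": rng.normal(0, 2, n),
        "-log10(padj)": -np.log10(rng.uniform(1e-6, 1, n)),
        "gene_group": rng.choice(["DIGs", "DSGs", "DCGs"], n),
        "method": rng.choice(["CN naive", "CN aware"], n),
    })
    generate_volcano_plot(plot_data, lfc_cut=1.0, pval_cut=0.05)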

def plot_cnv_hist(cnv_mean, binwidth=0.2):
    """
    Plots a histogram of the CNV mean distribution.

    Parameters:
        cnv_mean (pd.Series or list): The CNV mean values to plot.
        binwidth (float): The bin width for the histogram.
    """
    # Convert to a DataFrame if it's not already
    if isinstance(cnv_mean, list):
        cnv_mean = pd.DataFrame({'cnv_mean': cnv_mean})
    elif isinstance(cnv_mean, pd.Series):
        cnv_mean = cnv_mean.to_frame(name='cnv_mean')

    # Create the histogram plot
    plt.figure(figsize=(5, 5))
    sns.histplot(
        cnv_mean['cnv_mean'],
        bins=int((cnv_mean['cnv_mean'].max() - cnv_mean['cnv_mean'].min()) / binwidth),
        kde=False,
        color="#F39B7F",
        edgecolor="black",
        alpha=0.7
    )

    # Add labels and titles
    plt.title("", fontsize=14)
    plt.xlabel("CN state", fontsize=14, labelpad=8)
    plt.ylabel("Frequency", fontsize=14, labelpad=8)

    # Customize the appearance of axes
    plt.xticks(fontsize=12, color="black", rotation=45, ha="right")
    plt.yticks(fontsize=12, color="black")
    plt.gca().spines["top"].set_visible(False)
    plt.gca().spines["right"].set_visible(False)
    plt.gca().spines["left"].set_linewidth(1)
    plt.gca().spines["bottom"].set_linewidth(1)

    # Disable the grid
    plt.grid(visible=False)

    # Show the plot
    plt.tight_layout()
    plt.show()
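
# Sketch (not part of the released file): hypothetical per-gene mean CN
# states centred around the diploid value 2.0.
def _example_plot_cnv_hist() -> None:
    cnv_mean = pd.Series(np.random.default_rng(2).normal(2.0, 0.4, size=500))
    plot_cnv_hist(cnv_mean, binwidth=0.2)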

def plot_stacked_bar(combined_data):
    """
    Creates a stacked bar plot of gene counts by CNV group for each tumor type.

    Parameters:
    - combined_data: DataFrame containing the data to plot.
    """
    # Define CNV colors inside the function
    cnv_colors = {
        "loss": "#0000FF",
        "neutral": "#808080",
        "gain": "#00FF00",
        "amplification": "#FF0000"
    }

    tumor_types = combined_data['tumor_type'].unique()

    # Create subplots for each tumor type
    fig, axes = plt.subplots(1, len(tumor_types), figsize=(5, 5), sharey=True)

    # If there's only one tumor type, axes is not an array, so wrap it in a list
    if len(tumor_types) == 1:
        axes = [axes]

    for idx, tumor_type in enumerate(tumor_types):
        ax = axes[idx]
        tumor_data = combined_data[combined_data['tumor_type'] == tumor_type]

        # Create a table of counts for CNV group vs gene group
        counts = pd.crosstab(tumor_data['gene_group'], tumor_data['cnv_group'])

        # Plot stacked bars
        counts.plot(
            kind='bar', stacked=True, ax=ax,
            color=[cnv_colors[group] for group in counts.columns], width=0.6
        )

        ax.set_title(tumor_type, fontsize=16)
        ax.set_xlabel("")
        ax.set_ylabel("Gene Counts", fontsize=16)

        # Customize axis labels and tick marks
        ax.tick_params(axis='x', labelsize=16, labelcolor="black")
        ax.tick_params(axis='y', labelsize=16, labelcolor="black")

    # Overall settings for layout and labels
    plt.xticks(fontsize=12, color="black", rotation=45, ha="right")
    plt.tight_layout()
    plt.show()
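
# Sketch (not part of the released file): a hypothetical long-format table
# with one row per gene, carrying its tumor type, gene group, and CNV group.
def _example_plot_stacked_bar() -> None:
    rng = np.random.default_rng(4)
    n = 300
    combined = pd.DataFrame({
        "tumor_type": rng.choice(["LUAD", "BRCA"], n),
        "gene_group": rng.choice(["DIGs", "DSGs", "DCGs"], n),
        "cnv_group": rng.choice(["loss", "neutral", "gain", "amplification"], n),
    })
    plot_stacked_bar(combined)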

def plot_percentage_bar(barplot_data):
    """
    Creates a bar plot showing the percentage of genes for each gene group across tumor types.

    Parameters:
    - barplot_data: DataFrame containing 'gene_group', 'percentage', and 'Count' columns.
    """
    # Define the gene group colors inside the function
    gene_group_colors = {
        "DIGs": "#8F3931FF",
        "DSGs": "#FFB977",
        "DCGs": "#FFC300"
    }

    tumor_types = barplot_data['tumor_type'].unique()

    sns.set(style="whitegrid")

    # Create subplots for each tumor type
    fig, axes = plt.subplots(1, len(tumor_types), figsize=(5, 5), sharey=True)

    # If only one tumor type, ensure axes is a list
    if len(tumor_types) == 1:
        axes = [axes]

    for idx, tumor_type in enumerate(tumor_types):
        ax = axes[idx]
        tumor_data = barplot_data[barplot_data['tumor_type'] == tumor_type]

        # Plot the percentage bar plot
        sns.barplot(data=tumor_data, x="gene_group", y="percentage", hue="gene_group",
                    palette=gene_group_colors, ax=ax, width=0.6)

        # Add counts and percentages as labels
        for p in ax.patches:
            height = p.get_height()
            bar_center = p.get_x() + p.get_width() / 2  # x position of the bar

            # Find the gene group in the data based on the bar's position
            group_name = tumor_data.iloc[int(bar_center)]['gene_group']
            count = tumor_data.loc[tumor_data['gene_group'] == group_name, 'Count'].values[0]
            percentage = tumor_data.loc[tumor_data['gene_group'] == group_name, 'percentage'].values[0]

            # Position the labels slightly above the bars
            ax.text(p.get_x() + p.get_width() / 2, height + 0.5, f'{count} ({round(percentage, 1)}%)',
                    ha='center', va='bottom', fontsize=12, color="black")

        ax.set_title(tumor_type, fontsize=16)
        ax.set_xlabel("")
        ax.set_ylabel("Percentage of Genes", fontsize=16)

        # Customize axis labels and tick marks
        ax.tick_params(axis='x', labelsize=16, labelcolor="black", rotation=45)
        ax.tick_params(axis='y', labelsize=16, labelcolor="black")

        # Explicitly set the x-tick labels with proper rotation and alignment
        for tick in ax.get_xticklabels():
            tick.set_horizontalalignment('right')  # ensures proper alignment for x-ticks
            tick.set_rotation(45)

    # Overall settings for layout and labels
    plt.tight_layout()
    plt.show()
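
# Sketch (not part of the released file): a hypothetical summary table with
# one row per gene group per tumor type.
def _example_plot_percentage_bar() -> None:
    barplot = pd.DataFrame({
        "tumor_type": ["LUAD"] * 3,
        "gene_group": ["DIGs", "DSGs", "DCGs"],
        "Count": [120, 30, 15],
        "percentage": [72.7, 18.2, 9.1],
    })
    plot_percentage_bar(barplot)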