gengeneeval 0.3.0__py3-none-any.whl → 0.4.0__py3-none-any.whl
- geneval/__init__.py +43 -1
- geneval/deg/__init__.py +65 -0
- geneval/deg/context.py +271 -0
- geneval/deg/detection.py +578 -0
- geneval/deg/evaluator.py +538 -0
- geneval/deg/visualization.py +376 -0
- {gengeneeval-0.3.0.dist-info → gengeneeval-0.4.0.dist-info}/METADATA +90 -3
- {gengeneeval-0.3.0.dist-info → gengeneeval-0.4.0.dist-info}/RECORD +11 -6
- {gengeneeval-0.3.0.dist-info → gengeneeval-0.4.0.dist-info}/WHEEL +0 -0
- {gengeneeval-0.3.0.dist-info → gengeneeval-0.4.0.dist-info}/entry_points.txt +0 -0
- {gengeneeval-0.3.0.dist-info → gengeneeval-0.4.0.dist-info}/licenses/LICENSE +0 -0
geneval/deg/detection.py
ADDED
@@ -0,0 +1,578 @@
"""
Fast DEG detection with CPU/GPU acceleration.

This module provides vectorized statistical tests for DEG detection:
- Welch's t-test (default, robust to unequal variance)
- Student's t-test
- Wilcoxon rank-sum test
- Log-fold change thresholding

All methods are accelerated using:
- Vectorized NumPy operations
- Optional GPU acceleration via PyTorch
- Parallel computation via joblib
"""
from __future__ import annotations

from typing import Optional, Literal, Dict, Union, Tuple, List
from dataclasses import dataclass, field
import numpy as np
import warnings

# Optional dependencies
try:
    import torch
    HAS_TORCH = True
except ImportError:
    HAS_TORCH = False

try:
    from joblib import Parallel, delayed
    HAS_JOBLIB = True
except ImportError:
    HAS_JOBLIB = False

try:
    from scipy import stats
    from scipy.stats import ttest_ind, mannwhitneyu
    HAS_SCIPY = True
except ImportError:
    HAS_SCIPY = False


# Type alias for DEG methods
DEGMethod = Literal["welch", "student", "wilcoxon", "logfc"]


@dataclass
class DEGResult:
    """Results from DEG detection.

    Attributes
    ----------
    gene_names : np.ndarray
        Names of all genes
    pvalues : np.ndarray
        P-values for each gene (NaN for logfc method)
    pvalues_adj : np.ndarray
        Adjusted p-values (Benjamini-Hochberg)
    log_fold_changes : np.ndarray
        Log2 fold changes (mean_perturbed / mean_control)
    mean_control : np.ndarray
        Mean expression in control
    mean_perturbed : np.ndarray
        Mean expression in perturbed
    is_deg : np.ndarray
        Boolean mask of significant DEGs
    n_degs : int
        Number of significant DEGs
    method : str
        Method used for detection
    pval_threshold : float
        P-value threshold used
    lfc_threshold : float
        Log fold change threshold used
    """
    gene_names: np.ndarray
    pvalues: np.ndarray
    pvalues_adj: np.ndarray
    log_fold_changes: np.ndarray
    mean_control: np.ndarray
    mean_perturbed: np.ndarray
    is_deg: np.ndarray
    n_degs: int
    method: str
    pval_threshold: float
    lfc_threshold: float

    # Optional: indices of DEGs for fast slicing
    deg_indices: np.ndarray = field(default_factory=lambda: np.array([], dtype=int))

    def __post_init__(self):
        """Compute DEG indices after initialization."""
        if len(self.deg_indices) == 0:
            self.deg_indices = np.where(self.is_deg)[0]

    def get_deg_names(self) -> np.ndarray:
        """Get names of significant DEGs."""
        return self.gene_names[self.is_deg]

    def to_dataframe(self):
        """Convert to pandas DataFrame."""
        import pandas as pd
        return pd.DataFrame({
            "gene": self.gene_names,
            "pvalue": self.pvalues,
            "pvalue_adj": self.pvalues_adj,
            "log2fc": self.log_fold_changes,
            "mean_control": self.mean_control,
            "mean_perturbed": self.mean_perturbed,
            "is_deg": self.is_deg,
        }).set_index("gene")

    def __repr__(self) -> str:
        return (
            f"DEGResult(n_genes={len(self.gene_names)}, n_degs={self.n_degs}, "
            f"method='{self.method}', pval<{self.pval_threshold}, |lfc|>{self.lfc_threshold})"
        )
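
# --- Editorial note: illustrative usage, not part of the packaged file ---
# A DEGResult is normally produced by the compute_degs_* functions defined
# below; its helpers give quick access to the significant genes or a tidy
# table, for example:
#
#     result = compute_degs_fast(control, perturbed)
#     table = result.to_dataframe().sort_values("pvalue_adj")
#     hits = result.get_deg_names()
#
# (`control` / `perturbed` are assumed (n_samples, n_genes) NumPy arrays;
# sort_values is the standard pandas method.)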


def _benjamini_hochberg(pvalues: np.ndarray) -> np.ndarray:
    """Apply Benjamini-Hochberg correction for multiple testing.

    Parameters
    ----------
    pvalues : np.ndarray
        Raw p-values

    Returns
    -------
    np.ndarray
        Adjusted p-values (FDR)
    """
    n = len(pvalues)
    if n == 0:
        return pvalues

    # Handle NaN values
    valid_mask = ~np.isnan(pvalues)
    pvalues_adj = np.full_like(pvalues, np.nan)

    if not np.any(valid_mask):
        return pvalues_adj

    valid_pvals = pvalues[valid_mask]

    # Sort p-values
    sorted_idx = np.argsort(valid_pvals)
    sorted_pvals = valid_pvals[sorted_idx]

    # BH correction
    n_valid = len(sorted_pvals)
    rank = np.arange(1, n_valid + 1)
    adjusted = sorted_pvals * n_valid / rank

    # Ensure monotonicity (cumulative minimum from right)
    adjusted = np.minimum.accumulate(adjusted[::-1])[::-1]

    # Clip to [0, 1]
    adjusted = np.clip(adjusted, 0, 1)

    # Restore original order
    unsorted_adj = np.empty_like(adjusted)
    unsorted_adj[sorted_idx] = adjusted

    pvalues_adj[valid_mask] = unsorted_adj

    return pvalues_adj
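
# Editorial note: a small worked example (illustrative, not part of the
# packaged file). For raw p-values [0.01, 0.04, 0.03, 0.20], the function
# sorts them, scales by n/rank to get [0.04, 0.06, 0.0533, 0.20], enforces
# monotonicity from the right, and restores the original order:
#
#     _benjamini_hochberg(np.array([0.01, 0.04, 0.03, 0.20]))
#     # -> approximately [0.04, 0.0533, 0.0533, 0.20]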


def compute_degs_fast(
    control: np.ndarray,
    perturbed: np.ndarray,
    gene_names: Optional[np.ndarray] = None,
    method: DEGMethod = "welch",
    pval_threshold: float = 0.05,
    lfc_threshold: float = 0.5,
    use_adjusted_pval: bool = True,
    n_jobs: int = 1,
) -> DEGResult:
    """
    Fast DEG detection using vectorized statistical tests.

    Parameters
    ----------
    control : np.ndarray
        Control expression matrix (n_samples_control, n_genes)
    perturbed : np.ndarray
        Perturbed expression matrix (n_samples_perturbed, n_genes)
    gene_names : np.ndarray, optional
        Gene names. If None, uses indices.
    method : str
        Statistical test: "welch", "student", "wilcoxon", "logfc"
    pval_threshold : float
        P-value threshold for significance
    lfc_threshold : float
        Absolute log2 fold change threshold
    use_adjusted_pval : bool
        If True, use adjusted p-values (BH correction)
    n_jobs : int
        Number of parallel jobs (only for wilcoxon)

    Returns
    -------
    DEGResult
        DEG detection results

    Examples
    --------
    >>> control = np.random.randn(100, 1000)  # 100 control cells, 1000 genes
    >>> perturbed = control + np.random.randn(100, 1000) * 0.5  # Add noise
    >>> perturbed[:, :50] += 2  # Make first 50 genes differentially expressed
    >>> result = compute_degs_fast(control, perturbed, method="welch")
    >>> print(f"Found {result.n_degs} DEGs")
    """
    n_genes = control.shape[1]

    # Gene names
    if gene_names is None:
        gene_names = np.array([f"Gene_{i}" for i in range(n_genes)])

    # Compute means
    mean_control = np.mean(control, axis=0)
    mean_perturbed = np.mean(perturbed, axis=0)

    # Compute log fold change (add pseudocount for stability)
    # Use pseudocount of 1 for log normalization (common in RNA-seq)
    eps = 1.0  # pseudocount
    log_fold_changes = np.log2((mean_perturbed + eps) / (mean_control + eps))

    # Compute p-values based on method
    if method == "logfc":
        # No statistical test, just fold change thresholding
        pvalues = np.full(n_genes, np.nan)
        pvalues_adj = pvalues.copy()
    elif method == "welch":
        pvalues = _welch_ttest_vectorized(control, perturbed)
        pvalues_adj = _benjamini_hochberg(pvalues)
    elif method == "student":
        pvalues = _student_ttest_vectorized(control, perturbed)
        pvalues_adj = _benjamini_hochberg(pvalues)
    elif method == "wilcoxon":
        pvalues = _wilcoxon_vectorized(control, perturbed, n_jobs=n_jobs)
        pvalues_adj = _benjamini_hochberg(pvalues)
    else:
        raise ValueError(f"Unknown method: {method}. Use 'welch', 'student', 'wilcoxon', or 'logfc'")

    # Determine significant DEGs
    if method == "logfc":
        is_deg = np.abs(log_fold_changes) > lfc_threshold
    else:
        pval_test = pvalues_adj if use_adjusted_pval else pvalues
        is_deg = (pval_test < pval_threshold) & (np.abs(log_fold_changes) > lfc_threshold)

    return DEGResult(
        gene_names=gene_names,
        pvalues=pvalues,
        pvalues_adj=pvalues_adj,
        log_fold_changes=log_fold_changes,
        mean_control=mean_control,
        mean_perturbed=mean_perturbed,
        is_deg=is_deg,
        n_degs=int(np.sum(is_deg)),
        method=method,
        pval_threshold=pval_threshold,
        lfc_threshold=lfc_threshold,
    )
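
# --- Editorial note: illustrative usage, not part of the packaged file ---
# A minimal sketch comparing two of the supported tests on synthetic data
# (shapes follow the docstring above: samples in rows, genes in columns):
#
#     rng = np.random.default_rng(0)
#     control = rng.normal(size=(200, 2000))
#     perturbed = rng.normal(size=(200, 2000))
#     perturbed[:, :100] += 2.0  # shift the first 100 genes
#     res_welch = compute_degs_fast(control, perturbed, method="welch")
#     res_wilcox = compute_degs_fast(control, perturbed, method="wilcoxon", n_jobs=4)
#     shared = np.intersect1d(res_welch.get_deg_names(), res_wilcox.get_deg_names())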


def _welch_ttest_vectorized(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    """
    Vectorized Welch's t-test across all genes simultaneously.

    Much faster than scipy.stats.ttest_ind for many genes.
    """
    n1, n2 = x.shape[0], y.shape[0]

    # Sample means
    mean1 = np.mean(x, axis=0)
    mean2 = np.mean(y, axis=0)

    # Sample variances (unbiased)
    var1 = np.var(x, axis=0, ddof=1)
    var2 = np.var(y, axis=0, ddof=1)

    # Standard error
    se = np.sqrt(var1 / n1 + var2 / n2)

    # T-statistic
    with np.errstate(divide='ignore', invalid='ignore'):
        t_stat = (mean1 - mean2) / se

    # Welch-Satterthwaite degrees of freedom
    with np.errstate(divide='ignore', invalid='ignore'):
        num = (var1 / n1 + var2 / n2) ** 2
        denom = (var1 / n1) ** 2 / (n1 - 1) + (var2 / n2) ** 2 / (n2 - 1)
        df = num / denom

    # Handle edge cases
    df = np.clip(df, 1, np.inf)
    df = np.nan_to_num(df, nan=1.0)

    # Two-tailed p-value using scipy (still fast for vectorized computation)
    if HAS_SCIPY:
        pvalues = 2 * stats.t.sf(np.abs(t_stat), df)
    else:
        # Fallback: approximate p-value using normal distribution for large df
        pvalues = 2 * (1 - _normal_cdf(np.abs(t_stat)))

    return np.nan_to_num(pvalues, nan=1.0)
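
# Editorial cross-check (illustrative, not part of the packaged file): when
# SciPy is installed, the vectorized test above should match ttest_ind applied
# column-wise with equal_var=False, up to the edge-case handling of
# zero-variance genes, which this function maps to a p-value of 1.0:
#
#     p_ref = stats.ttest_ind(x, y, axis=0, equal_var=False).pvalue
#     np.testing.assert_allclose(_welch_ttest_vectorized(x, y), p_ref)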


def _student_ttest_vectorized(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    """
    Vectorized Student's t-test (equal variance assumption).
    """
    n1, n2 = x.shape[0], y.shape[0]

    # Sample means
    mean1 = np.mean(x, axis=0)
    mean2 = np.mean(y, axis=0)

    # Pooled variance
    var1 = np.var(x, axis=0, ddof=1)
    var2 = np.var(y, axis=0, ddof=1)
    pooled_var = ((n1 - 1) * var1 + (n2 - 1) * var2) / (n1 + n2 - 2)

    # Standard error
    se = np.sqrt(pooled_var * (1/n1 + 1/n2))

    # T-statistic
    with np.errstate(divide='ignore', invalid='ignore'):
        t_stat = (mean1 - mean2) / se

    # Degrees of freedom
    df = n1 + n2 - 2

    # Two-tailed p-value
    if HAS_SCIPY:
        pvalues = 2 * stats.t.sf(np.abs(t_stat), df)
    else:
        pvalues = 2 * (1 - _normal_cdf(np.abs(t_stat)))

    return np.nan_to_num(pvalues, nan=1.0)


def _wilcoxon_vectorized(
    x: np.ndarray,
    y: np.ndarray,
    n_jobs: int = 1
) -> np.ndarray:
    """
    Wilcoxon rank-sum test with optional parallelization.

    Note: This is slower than t-tests but more robust for non-normal data.
    """
    if not HAS_SCIPY:
        raise ImportError("scipy is required for Wilcoxon test")

    n_genes = x.shape[1]

    if HAS_JOBLIB and n_jobs != 1:
        # Parallel computation
        def _compute_pval(i):
            try:
                _, pval = mannwhitneyu(x[:, i], y[:, i], alternative='two-sided')
                return pval
            except Exception:
                return 1.0

        pvalues = Parallel(n_jobs=n_jobs)(
            delayed(_compute_pval)(i) for i in range(n_genes)
        )
        return np.array(pvalues)
    else:
        # Sequential computation
        pvalues = np.zeros(n_genes)
        for i in range(n_genes):
            try:
                _, pvalues[i] = mannwhitneyu(x[:, i], y[:, i], alternative='two-sided')
            except Exception:
                pvalues[i] = 1.0
        return pvalues
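
# Editorial note (illustrative, not part of the packaged file): newer SciPy
# releases expose an `axis` argument on mannwhitneyu, so the per-gene loop
# above could in principle be replaced by a single vectorized call such as
#
#     pvalues = mannwhitneyu(x, y, alternative="two-sided", axis=0).pvalue
#
# This is a possible simplification, not what the released code does.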


def _normal_cdf(x: np.ndarray) -> np.ndarray:
    """Approximate normal CDF without scipy."""
    return 0.5 * (1 + np.tanh(np.sqrt(2/np.pi) * (x + 0.044715 * x**3)))
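
# Editorial note (illustrative, not part of the packaged file): this is the
# tanh-based approximation of the standard normal CDF (the same form used in
# the tanh approximation of GELU). With SciPy available it can be compared
# against the exact CDF:
#
#     grid = np.linspace(-5, 5, 201)
#     max_err = np.max(np.abs(_normal_cdf(grid) - stats.norm.cdf(grid)))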


def compute_degs_gpu(
    control: np.ndarray,
    perturbed: np.ndarray,
    gene_names: Optional[np.ndarray] = None,
    method: DEGMethod = "welch",
    pval_threshold: float = 0.05,
    lfc_threshold: float = 0.5,
    use_adjusted_pval: bool = True,
    device: str = "cuda",
) -> DEGResult:
    """
    GPU-accelerated DEG detection using PyTorch.

    Parameters
    ----------
    control : np.ndarray
        Control expression matrix (n_samples_control, n_genes)
    perturbed : np.ndarray
        Perturbed expression matrix (n_samples_perturbed, n_genes)
    gene_names : np.ndarray, optional
        Gene names. If None, uses indices.
    method : str
        Statistical test: "welch" or "student" (wilcoxon not supported on GPU)
    pval_threshold : float
        P-value threshold for significance
    lfc_threshold : float
        Absolute log2 fold change threshold
    use_adjusted_pval : bool
        If True, use adjusted p-values (BH correction)
    device : str
        GPU device: "cuda", "cuda:0", "mps", etc.

    Returns
    -------
    DEGResult
        DEG detection results
    """
    if not HAS_TORCH:
        warnings.warn("PyTorch not available, falling back to CPU")
        return compute_degs_fast(
            control, perturbed, gene_names, method,
            pval_threshold, lfc_threshold, use_adjusted_pval
        )

    if method == "wilcoxon":
        warnings.warn("Wilcoxon test not supported on GPU, falling back to CPU")
        return compute_degs_fast(
            control, perturbed, gene_names, method,
            pval_threshold, lfc_threshold, use_adjusted_pval
        )

    # Get device
    if device == "auto":
        if torch.cuda.is_available():
            device = "cuda"
        elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            device = "mps"
        else:
            device = "cpu"

    torch_device = torch.device(device)
    n_genes = control.shape[1]

    # Gene names
    if gene_names is None:
        gene_names = np.array([f"Gene_{i}" for i in range(n_genes)])

    # Move data to GPU
    x = torch.tensor(control, dtype=torch.float32, device=torch_device)
    y = torch.tensor(perturbed, dtype=torch.float32, device=torch_device)

    n1, n2 = x.shape[0], y.shape[0]

    # Compute means
    mean_control = x.mean(dim=0)
    mean_perturbed = y.mean(dim=0)

    # Log fold change (use pseudocount of 1 for stability)
    eps = 1.0
    log_fold_changes = torch.log2((mean_perturbed + eps) / (mean_control + eps))

    if method == "logfc":
        pvalues = torch.full((n_genes,), float('nan'), device=torch_device)
        pvalues_adj = pvalues.clone()
    else:
        # Compute variances
        var1 = x.var(dim=0, unbiased=True)
        var2 = y.var(dim=0, unbiased=True)

        if method == "welch":
            # Welch's t-test
            se = torch.sqrt(var1 / n1 + var2 / n2)
            t_stat = (mean_control - mean_perturbed) / (se + 1e-10)

            # Welch-Satterthwaite degrees of freedom
            num = (var1 / n1 + var2 / n2) ** 2
            denom = (var1 / n1) ** 2 / (n1 - 1) + (var2 / n2) ** 2 / (n2 - 1)
            df = num / (denom + 1e-10)
            df = torch.clamp(df, min=1.0)

        else:  # student
            # Student's t-test
            pooled_var = ((n1 - 1) * var1 + (n2 - 1) * var2) / (n1 + n2 - 2)
            se = torch.sqrt(pooled_var * (1/n1 + 1/n2))
            t_stat = (mean_control - mean_perturbed) / (se + 1e-10)
            df = torch.full((n_genes,), n1 + n2 - 2, device=torch_device, dtype=torch.float32)

        # Move to CPU for p-value computation (scipy needed for t-distribution)
        t_stat_np = torch.abs(t_stat).cpu().numpy()
        df_np = df.cpu().numpy()

        if HAS_SCIPY:
            pvalues_np = 2 * stats.t.sf(t_stat_np, df_np)
        else:
            pvalues_np = 2 * (1 - _normal_cdf(t_stat_np))

        pvalues_np = np.nan_to_num(pvalues_np, nan=1.0).astype(np.float32)
        pvalues_adj_np = _benjamini_hochberg(pvalues_np).astype(np.float32)

        pvalues = torch.tensor(pvalues_np, device=torch_device, dtype=torch.float32)
        pvalues_adj = torch.tensor(pvalues_adj_np, device=torch_device, dtype=torch.float32)

    # Determine significant DEGs
    lfc_abs = torch.abs(log_fold_changes)
    if method == "logfc":
        is_deg = lfc_abs > lfc_threshold
    else:
        pval_test = pvalues_adj if use_adjusted_pval else pvalues
        is_deg = (pval_test < pval_threshold) & (lfc_abs > lfc_threshold)

    # Move results to CPU
    return DEGResult(
        gene_names=gene_names,
        pvalues=pvalues.cpu().numpy(),
        pvalues_adj=pvalues_adj.cpu().numpy(),
        log_fold_changes=log_fold_changes.cpu().numpy(),
        mean_control=mean_control.cpu().numpy(),
        mean_perturbed=mean_perturbed.cpu().numpy(),
        is_deg=is_deg.cpu().numpy(),
        n_degs=int(is_deg.sum().item()),
        method=method,
        pval_threshold=pval_threshold,
        lfc_threshold=lfc_threshold,
    )
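
# --- Editorial note: illustrative usage, not part of the packaged file ---
# The GPU path mirrors compute_degs_fast; only the device handling differs.
# On a CUDA machine one might call, for example:
#
#     res = compute_degs_gpu(control, perturbed, method="welch", device="cuda")
#
# On Apple Silicon device="mps" is used instead. If PyTorch is missing, or if
# method="wilcoxon" is requested, the function warns and falls back to the CPU
# implementation above.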


def compute_degs_auto(
    control: np.ndarray,
    perturbed: np.ndarray,
    gene_names: Optional[np.ndarray] = None,
    method: DEGMethod = "welch",
    pval_threshold: float = 0.05,
    lfc_threshold: float = 0.5,
    use_adjusted_pval: bool = True,
    n_jobs: int = 1,
    device: str = "auto",
) -> DEGResult:
    """
    Automatically select the fastest DEG computation method.

    Chooses GPU if available and data is large enough to benefit,
    otherwise uses CPU with optional parallelization.
    """
    n_genes = control.shape[1]
    n_samples = control.shape[0] + perturbed.shape[0]

    # Use GPU for large datasets
    use_gpu = False
    if device != "cpu" and HAS_TORCH:
        if device == "auto":
            if torch.cuda.is_available() or (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()):
                # GPU worthwhile for >1000 genes or >1000 samples
                if n_genes > 1000 or n_samples > 1000:
                    use_gpu = True
        else:
            use_gpu = True

    if use_gpu:
        return compute_degs_gpu(
            control, perturbed, gene_names, method,
            pval_threshold, lfc_threshold, use_adjusted_pval,
            device=device if device != "auto" else "auto",
        )
    else:
        return compute_degs_fast(
            control, perturbed, gene_names, method,
            pval_threshold, lfc_threshold, use_adjusted_pval,
            n_jobs=n_jobs,
        )
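
# --- Editorial note: illustrative usage, not part of the packaged file ---
# compute_degs_auto is the convenience entry point: with device="auto" it only
# moves to the GPU when PyTorch can see one and the problem is large (more
# than 1000 genes or 1000 samples); otherwise it runs the NumPy path with the
# requested number of jobs:
#
#     res = compute_degs_auto(control, perturbed, gene_names=gene_names,
#                             method="welch", n_jobs=4, device="auto")
#     print(res)  # DEGResult(n_genes=..., n_degs=..., method='welch', ...)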