gengeneeval 0.3.0__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,578 @@
1
+ """
2
+ Fast DEG detection with CPU/GPU acceleration.
3
+
4
+ This module provides vectorized statistical tests for DEG detection:
5
+ - Welch's t-test (default, robust to unequal variance)
6
+ - Student's t-test
7
+ - Wilcoxon rank-sum test
8
+ - Log-fold change thresholding
9
+
10
+ All methods are accelerated using:
11
+ - Vectorized NumPy operations
12
+ - Optional GPU acceleration via PyTorch
13
+ - Parallel computation via joblib
14
+ """
15
+ from __future__ import annotations
16
+
17
+ from typing import Optional, Literal, Dict, Union, Tuple, List
18
+ from dataclasses import dataclass, field
19
+ import numpy as np
20
+ import warnings
21
+
22
+ # Optional dependencies
23
+ try:
24
+ import torch
25
+ HAS_TORCH = True
26
+ except ImportError:
27
+ HAS_TORCH = False
28
+
29
+ try:
30
+ from joblib import Parallel, delayed
31
+ HAS_JOBLIB = True
32
+ except ImportError:
33
+ HAS_JOBLIB = False
34
+
35
+ try:
36
+ from scipy import stats
37
+ from scipy.stats import ttest_ind, mannwhitneyu
38
+ HAS_SCIPY = True
39
+ except ImportError:
40
+ HAS_SCIPY = False
41
+
42
+
43
# Names of the statistical procedures accepted by the DEG functions in this module.
DEGMethod = Literal["welch", "student", "wilcoxon", "logfc"]
45
+
46
+
47
@dataclass
class DEGResult:
    """Results from DEG detection.

    Attributes
    ----------
    gene_names : np.ndarray
        Names of all genes (any array-like is accepted and converted to an
        ndarray on construction)
    pvalues : np.ndarray
        P-values for each gene (NaN for logfc method)
    pvalues_adj : np.ndarray
        Adjusted p-values (Benjamini-Hochberg)
    log_fold_changes : np.ndarray
        Log2 fold changes (mean_perturbed / mean_control)
    mean_control : np.ndarray
        Mean expression in control
    mean_perturbed : np.ndarray
        Mean expression in perturbed
    is_deg : np.ndarray
        Boolean mask of significant DEGs (converted to a bool ndarray on
        construction, so 0/1 sequences are also accepted)
    n_degs : int
        Number of significant DEGs
    method : str
        Method used for detection
    pval_threshold : float
        P-value threshold used
    lfc_threshold : float
        Log fold change threshold used
    deg_indices : np.ndarray
        Integer positions of the significant DEGs; derived from ``is_deg``
        when left empty
    """
    gene_names: np.ndarray
    pvalues: np.ndarray
    pvalues_adj: np.ndarray
    log_fold_changes: np.ndarray
    mean_control: np.ndarray
    mean_perturbed: np.ndarray
    is_deg: np.ndarray
    n_degs: int
    method: str
    pval_threshold: float
    lfc_threshold: float

    # Optional: indices of DEGs for fast slicing
    deg_indices: np.ndarray = field(default_factory=lambda: np.array([], dtype=int))

    def __post_init__(self):
        """Normalise array fields and compute DEG indices if not supplied."""
        # A plain list of names cannot be indexed with a boolean mask, and an
        # integer 0/1 mask would be interpreted as fancy indices; normalise both.
        self.gene_names = np.asarray(self.gene_names)
        self.is_deg = np.asarray(self.is_deg, dtype=bool)
        if len(self.deg_indices) == 0:
            self.deg_indices = np.where(self.is_deg)[0]

    def get_deg_names(self) -> np.ndarray:
        """Get names of significant DEGs."""
        return self.gene_names[self.is_deg]

    def to_dataframe(self):
        """Convert to a pandas DataFrame indexed by gene name."""
        import pandas as pd
        return pd.DataFrame({
            "gene": self.gene_names,
            "pvalue": self.pvalues,
            "pvalue_adj": self.pvalues_adj,
            "log2fc": self.log_fold_changes,
            "mean_control": self.mean_control,
            "mean_perturbed": self.mean_perturbed,
            "is_deg": self.is_deg,
        }).set_index("gene")

    def __repr__(self) -> str:
        return (
            f"DEGResult(n_genes={len(self.gene_names)}, n_degs={self.n_degs}, "
            f"method='{self.method}', pval<{self.pval_threshold}, |lfc|>{self.lfc_threshold})"
        )
118
+
119
+
120
+ def _benjamini_hochberg(pvalues: np.ndarray) -> np.ndarray:
121
+ """Apply Benjamini-Hochberg correction for multiple testing.
122
+
123
+ Parameters
124
+ ----------
125
+ pvalues : np.ndarray
126
+ Raw p-values
127
+
128
+ Returns
129
+ -------
130
+ np.ndarray
131
+ Adjusted p-values (FDR)
132
+ """
133
+ n = len(pvalues)
134
+ if n == 0:
135
+ return pvalues
136
+
137
+ # Handle NaN values
138
+ valid_mask = ~np.isnan(pvalues)
139
+ pvalues_adj = np.full_like(pvalues, np.nan)
140
+
141
+ if not np.any(valid_mask):
142
+ return pvalues_adj
143
+
144
+ valid_pvals = pvalues[valid_mask]
145
+
146
+ # Sort p-values
147
+ sorted_idx = np.argsort(valid_pvals)
148
+ sorted_pvals = valid_pvals[sorted_idx]
149
+
150
+ # BH correction
151
+ n_valid = len(sorted_pvals)
152
+ rank = np.arange(1, n_valid + 1)
153
+ adjusted = sorted_pvals * n_valid / rank
154
+
155
+ # Ensure monotonicity (cumulative minimum from right)
156
+ adjusted = np.minimum.accumulate(adjusted[::-1])[::-1]
157
+
158
+ # Clip to [0, 1]
159
+ adjusted = np.clip(adjusted, 0, 1)
160
+
161
+ # Restore original order
162
+ unsorted_adj = np.empty_like(adjusted)
163
+ unsorted_adj[sorted_idx] = adjusted
164
+
165
+ pvalues_adj[valid_mask] = unsorted_adj
166
+
167
+ return pvalues_adj
168
+
169
+
170
def compute_degs_fast(
    control: np.ndarray,
    perturbed: np.ndarray,
    gene_names: Optional[np.ndarray] = None,
    method: DEGMethod = "welch",
    pval_threshold: float = 0.05,
    lfc_threshold: float = 0.5,
    use_adjusted_pval: bool = True,
    n_jobs: int = 1,
) -> DEGResult:
    """
    Fast DEG detection using vectorized statistical tests.

    Parameters
    ----------
    control : np.ndarray
        Control expression matrix (n_samples_control, n_genes)
    perturbed : np.ndarray
        Perturbed expression matrix (n_samples_perturbed, n_genes)
    gene_names : np.ndarray, optional
        Gene names, one per column. Any array-like (list, pandas Index) is
        accepted. If None, names ``Gene_0``, ``Gene_1``, ... are generated.
    method : str
        Statistical test: "welch", "student", "wilcoxon", "logfc"
    pval_threshold : float
        P-value threshold for significance (not used for "logfc")
    lfc_threshold : float
        Absolute log2 fold change threshold
    use_adjusted_pval : bool
        If True, use adjusted p-values (BH correction)
    n_jobs : int
        Number of parallel jobs (only for wilcoxon)

    Returns
    -------
    DEGResult
        DEG detection results. A gene is a DEG when its (adjusted) p-value is
        below ``pval_threshold`` and ``|log2fc| > lfc_threshold``. For
        "logfc", only the fold-change condition applies and p-values are NaN.

    Raises
    ------
    ValueError
        If ``method`` is unknown, the inputs are not 2-D, the two matrices
        have different numbers of genes, or ``gene_names`` has the wrong
        length.

    Notes
    -----
    Fold changes are ``log2((mean_perturbed + 1) / (mean_control + 1))``.
    The pseudocount of 1 assumes non-negative expression values. Group means
    at or below -1 produce NaN or misleading fold changes.

    Examples
    --------
    >>> control = np.random.randn(100, 1000)  # 100 control cells, 1000 genes
    >>> perturbed = control + np.random.randn(100, 1000) * 0.5  # Add noise
    >>> perturbed[:, :50] += 2  # Make first 50 genes differentially expressed
    >>> result = compute_degs_fast(control, perturbed, method="welch")
    >>> print(f"Found {result.n_degs} DEGs")
    """
    # Fail fast on an unknown method, before any computation.
    if method not in ("welch", "student", "wilcoxon", "logfc"):
        raise ValueError(f"Unknown method: {method}. Use 'welch', 'student', 'wilcoxon', or 'logfc'")

    control = np.asarray(control)
    perturbed = np.asarray(perturbed)
    if control.ndim != 2 or perturbed.ndim != 2:
        raise ValueError(
            "control and perturbed must be 2-D (n_samples, n_genes) matrices, "
            f"got shapes {control.shape} and {perturbed.shape}"
        )
    n_genes = control.shape[1]
    if perturbed.shape[1] != n_genes:
        raise ValueError(
            f"control has {n_genes} genes but perturbed has {perturbed.shape[1]}"
        )

    # Gene names
    if gene_names is None:
        gene_names = np.array([f"Gene_{i}" for i in range(n_genes)])
    else:
        gene_names = np.asarray(gene_names)
        if gene_names.shape != (n_genes,):
            raise ValueError(
                f"gene_names has shape {gene_names.shape}, expected ({n_genes},)"
            )

    mean_control = np.mean(control, axis=0)
    mean_perturbed = np.mean(perturbed, axis=0)

    # Pseudocount of 1 keeps the ratio finite for genes with zero mean
    # (common convention for count / log-normalized RNA-seq data).
    eps = 1.0
    log_fold_changes = np.log2((mean_perturbed + eps) / (mean_control + eps))
    passes_lfc = np.abs(log_fold_changes) > lfc_threshold

    if method == "logfc":
        # No statistical test, just fold change thresholding.
        pvalues = np.full(n_genes, np.nan)
        pvalues_adj = pvalues.copy()
        is_deg = passes_lfc
    else:
        if method == "welch":
            pvalues = _welch_ttest_vectorized(control, perturbed)
        elif method == "student":
            pvalues = _student_ttest_vectorized(control, perturbed)
        else:  # "wilcoxon"
            pvalues = _wilcoxon_vectorized(control, perturbed, n_jobs=n_jobs)
        pvalues_adj = _benjamini_hochberg(pvalues)
        pval_test = pvalues_adj if use_adjusted_pval else pvalues
        is_deg = (pval_test < pval_threshold) & passes_lfc

    return DEGResult(
        gene_names=gene_names,
        pvalues=pvalues,
        pvalues_adj=pvalues_adj,
        log_fold_changes=log_fold_changes,
        mean_control=mean_control,
        mean_perturbed=mean_perturbed,
        is_deg=is_deg,
        n_degs=int(np.sum(is_deg)),
        method=method,
        pval_threshold=pval_threshold,
        lfc_threshold=lfc_threshold,
    )
267
+
268
+
269
def _welch_ttest_vectorized(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    """
    Two-sided Welch's t-test (unequal variances) for every gene at once.

    Column ``j`` of ``x`` is compared with column ``j`` of ``y``. Array
    arithmetic runs over all columns together instead of making one
    ``scipy.stats.ttest_ind`` call per gene.

    Parameters
    ----------
    x : np.ndarray
        First group, shape (n_samples_x, n_genes).
    y : np.ndarray
        Second group, shape (n_samples_y, n_genes).

    Returns
    -------
    np.ndarray
        Per-gene p-values, with NaN results (e.g. zero variance and equal
        means) reported as 1.0. Without scipy, a normal approximation that
        ignores the degrees of freedom is used.
    """
    count_x = x.shape[0]
    count_y = y.shape[0]

    avg_x = np.mean(x, axis=0)
    avg_y = np.mean(y, axis=0)

    # Squared standard error of each group mean: unbiased variance / n.
    sq_err_x = np.var(x, axis=0, ddof=1) / count_x
    sq_err_y = np.var(y, axis=0, ddof=1) / count_y
    combined = sq_err_x + sq_err_y
    std_err = np.sqrt(combined)

    with np.errstate(divide='ignore', invalid='ignore'):
        t_values = (avg_x - avg_y) / std_err
        # Welch-Satterthwaite approximation of the degrees of freedom.
        dof = combined ** 2 / (
            sq_err_x ** 2 / (count_x - 1) + sq_err_y ** 2 / (count_y - 1)
        )

    # Degenerate genes: df < 1 is raised to 1, NaN becomes 1
    # (nan_to_num also maps +inf to the largest finite float).
    dof = np.nan_to_num(np.clip(dof, 1, np.inf), nan=1.0)

    magnitude = np.abs(t_values)
    if not HAS_SCIPY:
        # Normal approximation; only reasonable for large degrees of freedom.
        raw = 2 * (1 - _normal_cdf(magnitude))
    else:
        raw = 2 * stats.t.sf(magnitude, dof)
    return np.nan_to_num(raw, nan=1.0)
310
+
311
+
312
def _student_ttest_vectorized(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    """
    Two-sided Student's t-test (equal-variance assumption) for every gene at once.

    Parameters
    ----------
    x : np.ndarray
        First group, shape (n_samples_x, n_genes).
    y : np.ndarray
        Second group, shape (n_samples_y, n_genes).

    Returns
    -------
    np.ndarray
        Per-gene p-values with ``n_x + n_y - 2`` degrees of freedom. NaN results
        are reported as 1.0. Without scipy, a normal approximation is used.
    """
    count_x, count_y = x.shape[0], y.shape[0]
    dof = count_x + count_y - 2

    avg_x = np.mean(x, axis=0)
    avg_y = np.mean(y, axis=0)

    # Pool the two unbiased variances, weighting each by its own degrees of freedom.
    var_x = np.var(x, axis=0, ddof=1)
    var_y = np.var(y, axis=0, ddof=1)
    pooled = ((count_x - 1) * var_x + (count_y - 1) * var_y) / dof
    std_err = np.sqrt(pooled * (1 / count_x + 1 / count_y))

    with np.errstate(divide='ignore', invalid='ignore'):
        magnitude = np.abs((avg_x - avg_y) / std_err)

    if HAS_SCIPY:
        raw = 2 * stats.t.sf(magnitude, dof)
    else:
        raw = 2 * (1 - _normal_cdf(magnitude))
    return np.nan_to_num(raw, nan=1.0)
344
+
345
+
346
def _wilcoxon_vectorized(
    x: np.ndarray,
    y: np.ndarray,
    n_jobs: int = 1
) -> np.ndarray:
    """
    Wilcoxon rank-sum (Mann-Whitney U) test per gene, optionally in parallel.

    This is slower than the vectorized t-tests (one scipy call per gene) but
    makes no normality assumption.

    Parameters
    ----------
    x : np.ndarray
        First group, shape (n_samples_x, n_genes).
    y : np.ndarray
        Second group, shape (n_samples_y, n_genes).
    n_jobs : int
        Number of joblib workers. Any value other than 1 runs in parallel when
        joblib is installed. Otherwise genes are processed sequentially.

    Returns
    -------
    np.ndarray
        Two-sided p-values as float64. Genes for which ``mannwhitneyu`` raises
        ``ValueError`` (scipy's error for invalid input, such as empty samples)
        or returns NaN get 1.0, matching the t-test helpers.

    Raises
    ------
    ImportError
        If scipy is not installed.
    """
    if not HAS_SCIPY:
        raise ImportError("scipy is required for Wilcoxon test")

    def _gene_pvalue(i: int) -> float:
        # P-value for gene column i. Best-effort: an untestable gene is
        # treated as non-significant rather than aborting the whole run.
        try:
            _, pval = mannwhitneyu(x[:, i], y[:, i], alternative='two-sided')
        except ValueError:
            return 1.0
        return float(pval)

    gene_ids = range(x.shape[1])
    if HAS_JOBLIB and n_jobs != 1:
        pvalues = Parallel(n_jobs=n_jobs)(
            delayed(_gene_pvalue)(i) for i in gene_ids
        )
    else:
        pvalues = [_gene_pvalue(i) for i in gene_ids]

    # NaN -> 1.0 so downstream thresholding behaves like the t-test paths.
    return np.nan_to_num(np.asarray(pvalues, dtype=np.float64), nan=1.0)
383
+
384
+
385
+ def _normal_cdf(x: np.ndarray) -> np.ndarray:
386
+ """Approximate normal CDF without scipy."""
387
+ return 0.5 * (1 + np.tanh(np.sqrt(2/np.pi) * (x + 0.044715 * x**3)))
388
+
389
+
390
def compute_degs_gpu(
    control: np.ndarray,
    perturbed: np.ndarray,
    gene_names: Optional[np.ndarray] = None,
    method: DEGMethod = "welch",
    pval_threshold: float = 0.05,
    lfc_threshold: float = 0.5,
    use_adjusted_pval: bool = True,
    device: str = "cuda",
) -> DEGResult:
    """
    GPU-accelerated DEG detection using PyTorch.

    Only the O(n_samples * n_genes) reductions (per-gene means and variances)
    run on the torch device. The resulting per-gene vectors are copied to the
    CPU, and the t statistics, degrees of freedom, p-values, BH correction and
    thresholding are computed in float64 NumPy using the same formulas as
    ``compute_degs_fast``. This keeps the two paths consistent. It also avoids
    adding epsilons to float32 denominators, which collapses the Welch degrees
    of freedom for low-variance genes with many samples.

    Falls back to ``compute_degs_fast`` (with a warning) when PyTorch is not
    available or ``method == "wilcoxon"``.

    Parameters
    ----------
    control : np.ndarray
        Control expression matrix (n_samples_control, n_genes)
    perturbed : np.ndarray
        Perturbed expression matrix (n_samples_perturbed, n_genes)
    gene_names : np.ndarray, optional
        Gene names. If None, uses indices.
    method : str
        Statistical test: "welch", "student" or "logfc" (wilcoxon falls back to CPU)
    pval_threshold : float
        P-value threshold for significance
    lfc_threshold : float
        Absolute log2 fold change threshold
    use_adjusted_pval : bool
        If True, use adjusted p-values (BH correction)
    device : str
        Torch device, e.g. "cuda", "cuda:0", "mps" or "cpu". "auto" picks CUDA,
        then MPS, then CPU.

    Returns
    -------
    DEGResult
        DEG detection results. Numeric arrays are float64 NumPy arrays.

    Raises
    ------
    ValueError
        If ``method`` is unknown, the inputs are not 2-D, or the two matrices
        have different numbers of genes.
    """
    # Previously an unknown method silently ran Student's t-test.
    if method not in ("welch", "student", "wilcoxon", "logfc"):
        raise ValueError(f"Unknown method: {method}. Use 'welch', 'student', 'wilcoxon', or 'logfc'")

    if not HAS_TORCH:
        warnings.warn("PyTorch not available, falling back to CPU")
        return compute_degs_fast(
            control, perturbed, gene_names, method,
            pval_threshold, lfc_threshold, use_adjusted_pval
        )

    if method == "wilcoxon":
        warnings.warn("Wilcoxon test not supported on GPU, falling back to CPU")
        return compute_degs_fast(
            control, perturbed, gene_names, method,
            pval_threshold, lfc_threshold, use_adjusted_pval
        )

    # Resolve "auto" to the best available backend.
    if device == "auto":
        if torch.cuda.is_available():
            device = "cuda"
        elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            device = "mps"
        else:
            device = "cpu"
    torch_device = torch.device(device)

    # float32 on the device: small transfers, and supported by every backend
    # (MPS has no float64).
    x = torch.as_tensor(control, dtype=torch.float32, device=torch_device)
    y = torch.as_tensor(perturbed, dtype=torch.float32, device=torch_device)
    if x.ndim != 2 or y.ndim != 2 or x.shape[1] != y.shape[1]:
        raise ValueError(
            "control and perturbed must be 2-D (n_samples, n_genes) matrices with "
            f"the same number of genes, got shapes {tuple(x.shape)} and {tuple(y.shape)}"
        )
    n1, n2 = x.shape[0], y.shape[0]
    n_genes = x.shape[1]

    if gene_names is None:
        gene_names = np.array([f"Gene_{i}" for i in range(n_genes)])
    else:
        gene_names = np.asarray(gene_names)

    # Heavy reductions on the device; results are small (n_genes,) vectors.
    mean_control = x.mean(dim=0).cpu().numpy().astype(np.float64)
    mean_perturbed = y.mean(dim=0).cpu().numpy().astype(np.float64)

    # Log fold change (pseudocount of 1, as in compute_degs_fast).
    eps = 1.0
    log_fold_changes = np.log2((mean_perturbed + eps) / (mean_control + eps))
    passes_lfc = np.abs(log_fold_changes) > lfc_threshold

    if method == "logfc":
        pvalues = np.full(n_genes, np.nan)
        pvalues_adj = pvalues.copy()
        is_deg = passes_lfc
    else:
        # Tensor.var applies Bessel's correction by default (same as ddof=1).
        var1 = x.var(dim=0).cpu().numpy().astype(np.float64)
        var2 = y.var(dim=0).cpu().numpy().astype(np.float64)

        with np.errstate(divide="ignore", invalid="ignore"):
            if method == "welch":
                sq_se1 = var1 / n1
                sq_se2 = var2 / n2
                se = np.sqrt(sq_se1 + sq_se2)
                # Welch-Satterthwaite degrees of freedom
                df = (sq_se1 + sq_se2) ** 2 / (
                    sq_se1 ** 2 / (n1 - 1) + sq_se2 ** 2 / (n2 - 1)
                )
                # Degenerate genes fall back to df = 1, as on the CPU path.
                df = np.nan_to_num(np.clip(df, 1, np.inf), nan=1.0)
            else:  # "student"
                dof = n1 + n2 - 2
                pooled_var = ((n1 - 1) * var1 + (n2 - 1) * var2) / dof
                se = np.sqrt(pooled_var * (1 / n1 + 1 / n2))
                df = np.full(n_genes, float(dof))
            # Zero variance gives inf (different means) or NaN (equal means),
            # which become p = 0 and p = 1 below.
            abs_t = np.abs((mean_control - mean_perturbed) / se)

        if HAS_SCIPY:
            pvalues = 2 * stats.t.sf(abs_t, df)
        else:
            # Normal approximation ignores df; adequate only for large samples.
            pvalues = 2 * (1 - _normal_cdf(abs_t))
        pvalues = np.nan_to_num(pvalues, nan=1.0)
        pvalues_adj = _benjamini_hochberg(pvalues)

        pval_test = pvalues_adj if use_adjusted_pval else pvalues
        is_deg = (pval_test < pval_threshold) & passes_lfc

    return DEGResult(
        gene_names=gene_names,
        pvalues=pvalues,
        pvalues_adj=pvalues_adj,
        log_fold_changes=log_fold_changes,
        mean_control=mean_control,
        mean_perturbed=mean_perturbed,
        is_deg=is_deg,
        n_degs=int(np.sum(is_deg)),
        method=method,
        pval_threshold=pval_threshold,
        lfc_threshold=lfc_threshold,
    )
534
+
535
+
536
def compute_degs_auto(
    control: np.ndarray,
    perturbed: np.ndarray,
    gene_names: Optional[np.ndarray] = None,
    method: DEGMethod = "welch",
    pval_threshold: float = 0.05,
    lfc_threshold: float = 0.5,
    use_adjusted_pval: bool = True,
    n_jobs: int = 1,
    device: str = "auto",
) -> DEGResult:
    """
    Automatically select the fastest DEG computation method.

    Routing rules:

    - ``method="wilcoxon"`` always runs on the CPU via ``compute_degs_fast``.
      The GPU path has no Wilcoxon implementation, and its CPU fallback does
      not forward ``n_jobs``.
    - ``device="cpu"``, or PyTorch not installed: CPU.
    - ``device="auto"``: GPU only if CUDA or MPS is available and the data has
      more than 1000 genes or more than 1000 samples (control + perturbed).
    - Any other ``device`` string (e.g. "cuda:1", "mps"): GPU on that device.

    Parameters
    ----------
    control, perturbed : np.ndarray
        Expression matrices of shape (n_samples, n_genes).
    gene_names, method, pval_threshold, lfc_threshold, use_adjusted_pval
        Passed through unchanged; see ``compute_degs_fast``.
    n_jobs : int
        Parallel jobs for the Wilcoxon test (CPU path only).
    device : str
        "auto", "cpu", or a torch device string.

    Returns
    -------
    DEGResult
        DEG detection results.
    """
    n_genes = control.shape[1]
    n_samples = control.shape[0] + perturbed.shape[0]

    use_gpu = False
    if method != "wilcoxon" and device != "cpu" and HAS_TORCH:
        if device == "auto":
            has_accelerator = torch.cuda.is_available() or (
                hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
            )
            # GPU transfer overhead only pays off for reasonably large inputs.
            use_gpu = has_accelerator and (n_genes > 1000 or n_samples > 1000)
        else:
            use_gpu = True

    if use_gpu:
        return compute_degs_gpu(
            control, perturbed, gene_names, method,
            pval_threshold, lfc_threshold, use_adjusted_pval,
            device=device,
        )
    return compute_degs_fast(
        control, perturbed, gene_names, method,
        pval_threshold, lfc_threshold, use_adjusted_pval,
        n_jobs=n_jobs,
    )