gengeneeval 0.2.1__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,857 @@
+ """
+ Accelerated metric computation with CPU parallelization and GPU support.
+
+ This module provides performance optimizations for metric computation:
+ - CPU parallelization via joblib for multi-core speedup
+ - GPU acceleration via PyTorch/geomloss for batch computation
+ - Vectorized operations for improved NumPy performance
+
+ Example usage:
+ >>> from geneval.metrics.accelerated import ParallelMetricComputer
+ >>> computer = ParallelMetricComputer(n_jobs=8, device="cuda")
+ >>> result = computer.compute_metric_parallel(metric, real, generated)
+ """
+ from __future__ import annotations
+
+ import warnings
+ from typing import List, Optional, Dict, Any, Union, Literal
+ from dataclasses import dataclass
+ import numpy as np
+
+ from .base_metric import BaseMetric, MetricResult
+
+
+ # Check for optional dependencies
+ try:
+     from joblib import Parallel, delayed
+     HAS_JOBLIB = True
+ except ImportError:
+     HAS_JOBLIB = False
+
+ try:
+     import torch
+     HAS_TORCH = True
+ except ImportError:
+     HAS_TORCH = False
+
+ try:
+     from geomloss import SamplesLoss
+     HAS_GEOMLOSS = True
+ except ImportError:
+     HAS_GEOMLOSS = False
+
+
+ @dataclass
+ class AccelerationConfig:
+     """Configuration for accelerated metric computation.
+
+     Attributes
+     ----------
+     n_jobs : int
+         Number of CPU jobs for parallel computation.
+         -1 uses all available cores. Default is 1 (no parallelization).
+     device : str
+         Device for computation: "cpu", "cuda", "cuda:0", etc.
+         Default is "cpu".
+     batch_genes : bool
+         If True, batch all genes for GPU computation. Default is True.
+     gene_batch_size : int or None
+         If set, process genes in batches of this size to manage memory.
+         None means process all genes at once.
+     prefer_gpu : bool
+         If True and GPU is available, prefer GPU implementations.
+         Default is True.
+     verbose : bool
+         Print acceleration info. Default is False.
+     """
+     n_jobs: int = 1
+     device: str = "cpu"
+     batch_genes: bool = True
+     gene_batch_size: Optional[int] = None
+     prefer_gpu: bool = True
+     verbose: bool = False
+
+
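For orientation, a brief sketch of filling in this config. Nothing in this module consumes the dataclass directly, so it acts as a plain container; the values shown and the forwarding pattern are illustrative, not recommended defaults:

    from geneval.metrics.accelerated import AccelerationConfig

    cfg = AccelerationConfig(n_jobs=-1, device="auto", verbose=True)
    # Fields are ordinary attributes and can be forwarded as keyword arguments,
    # e.g. compute_metrics_accelerated(real, generated, n_jobs=cfg.n_jobs,
    #      device=cfg.device, verbose=cfg.verbose)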
+ def get_available_backends() -> Dict[str, bool]:
+     """Check which acceleration backends are available.
+
+     Returns
+     -------
+     Dict[str, bool]
+         Dictionary with backend availability.
+     """
+     backends = {
+         "joblib": HAS_JOBLIB,
+         "torch": HAS_TORCH,
+         "geomloss": HAS_GEOMLOSS,
+         "cuda": HAS_TORCH and torch.cuda.is_available(),
+         "mps": HAS_TORCH and hasattr(torch.backends, "mps") and torch.backends.mps.is_available(),
+     }
+     return backends
+
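A quick sketch of the intended pattern: query the backends once, then pick a device string for the classes below. The fallback order shown mirrors what the module itself does for device="auto"; the explicit branching here is only for illustration:

    from geneval.metrics.accelerated import get_available_backends

    backends = get_available_backends()
    if backends["cuda"]:
        device = "cuda"
    elif backends["mps"]:
        device = "mps"
    else:
        device = "cpu"
    # Equivalent to passing device="auto" to compute_metrics_accelerated,
    # which performs the same cuda -> mps -> cpu fallback internally.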
+
+ def _get_device(device: str) -> "torch.device":
+     """Get PyTorch device, handling availability checks.
+
+     Parameters
+     ----------
+     device : str
+         Device string ("cpu", "cuda", "cuda:0", "mps", "auto")
+
+     Returns
+     -------
+     torch.device
+         PyTorch device object
+     """
+     if not HAS_TORCH:
+         raise ImportError("PyTorch is required for GPU acceleration")
+
+     if device == "auto":
+         if torch.cuda.is_available():
+             return torch.device("cuda")
+         elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
+             return torch.device("mps")
+         else:
+             return torch.device("cpu")
+
+     return torch.device(device)
+
+
+ class ParallelMetricComputer:
+     """Parallel and GPU-accelerated metric computation.
+
+     This class wraps metric computation with parallelization and GPU
+     acceleration options for significant speedups on large datasets.
+
+     Parameters
+     ----------
+     n_jobs : int
+         Number of parallel jobs. -1 for all cores.
+     device : str
+         Compute device ("cpu", "cuda", "auto")
+     batch_genes : bool
+         Whether to batch genes for GPU computation.
+     gene_batch_size : int, optional
+         Process genes in chunks of this size.
+     verbose : bool
+         Print progress information.
+
+     Examples
+     --------
+     >>> computer = ParallelMetricComputer(n_jobs=8)
+     >>> results = computer.compute_metric_parallel(metric, real, generated)
+
+     >>> # GPU acceleration
+     >>> computer = ParallelMetricComputer(device="cuda")
+     >>> results = computer.compute_metric_parallel(metric, real, generated)
+     """
+
+     def __init__(
+         self,
+         n_jobs: int = 1,
+         device: str = "cpu",
+         batch_genes: bool = True,
+         gene_batch_size: Optional[int] = None,
+         verbose: bool = False,
+     ):
+         self.n_jobs = n_jobs
+         self.device = device
+         self.batch_genes = batch_genes
+         self.gene_batch_size = gene_batch_size
+         self.verbose = verbose
+
+         # Validate configuration
+         if n_jobs != 1 and not HAS_JOBLIB:
+             warnings.warn("joblib not available, falling back to sequential processing")
+             self.n_jobs = 1
+
+         if device != "cpu" and not HAS_TORCH:
+             warnings.warn("PyTorch not available, falling back to CPU")
+             self.device = "cpu"
+
+         if self.verbose:
+             backends = get_available_backends()
+             print(f"Acceleration backends: {backends}")
+             print(f"Using n_jobs={self.n_jobs}, device={self.device}")
+
+     def compute_metric_parallel(
+         self,
+         metric: BaseMetric,
+         real: np.ndarray,
+         generated: np.ndarray,
+         gene_names: Optional[List[str]] = None,
+     ) -> MetricResult:
+         """Compute a metric with CPU parallelization.
+
+         Splits genes across multiple CPU cores for parallel computation.
+
+         Parameters
+         ----------
+         metric : BaseMetric
+             Metric to compute
+         real : np.ndarray
+             Real data, shape (n_samples, n_genes)
+         generated : np.ndarray
+             Generated data, shape (n_samples, n_genes)
+         gene_names : List[str], optional
+             Gene names
+
+         Returns
+         -------
+         MetricResult
+             Computed metric result
+         """
+         n_genes = real.shape[1]
+         if gene_names is None:
+             gene_names = [f"gene_{i}" for i in range(n_genes)]
+
+         if self.n_jobs == 1 or not HAS_JOBLIB:
+             # Sequential computation
+             per_gene = metric.compute_per_gene(real, generated)
+         else:
+             # Parallel computation across genes
+             if self.gene_batch_size:
+                 # Process in batches
+                 batches = [
+                     (i, min(i + self.gene_batch_size, n_genes))
+                     for i in range(0, n_genes, self.gene_batch_size)
+                 ]
+             else:
+                 # Split evenly across jobs (n_jobs <= 0 falls back to 8 chunks)
+                 n_effective_jobs = min(self.n_jobs if self.n_jobs > 0 else 8, n_genes)
+                 batch_size = max(1, n_genes // n_effective_jobs)
+                 batches = [
+                     (i, min(i + batch_size, n_genes))
+                     for i in range(0, n_genes, batch_size)
+                 ]
+
+             def compute_batch(start: int, end: int) -> np.ndarray:
+                 return metric.compute_per_gene(
+                     real[:, start:end],
+                     generated[:, start:end]
+                 )
+
+             results = Parallel(n_jobs=self.n_jobs, prefer="threads")(
+                 delayed(compute_batch)(start, end) for start, end in batches
+             )
+
+             per_gene = np.concatenate(results)
+
+         aggregate = metric.compute_aggregate(per_gene, method="mean")
+
+         return MetricResult(
+             name=metric.name,
+             per_gene_values=per_gene,
+             gene_names=gene_names,
+             aggregate_value=aggregate,
+             aggregate_method="mean",
+             metadata={
+                 "higher_is_better": metric.higher_is_better,
+                 "accelerated": True,
+                 "n_jobs": self.n_jobs,
+             }
+         )
+
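To make the flow concrete, a minimal sketch with a toy metric object. MeanDifference is hypothetical and merely duck-types the interface compute_metric_parallel relies on (name, higher_is_better, compute_per_gene, compute_aggregate); real metrics would subclass BaseMetric:

    import numpy as np
    from geneval.metrics.accelerated import ParallelMetricComputer


    class MeanDifference:
        """Toy stand-in exposing the interface compute_metric_parallel uses."""
        name = "mean_difference"
        higher_is_better = False

        def compute_per_gene(self, real, generated):
            # Per-gene absolute difference of means, shape (n_genes,)
            return np.abs(real.mean(axis=0) - generated.mean(axis=0))

        def compute_aggregate(self, per_gene, method="mean"):
            return float(np.nanmean(per_gene))


    rng = np.random.default_rng(0)
    real = rng.normal(size=(1000, 5000))
    generated = rng.normal(size=(800, 5000))

    computer = ParallelMetricComputer(n_jobs=4)
    result = computer.compute_metric_parallel(MeanDifference(), real, generated)
    print(result.name, result.aggregate_value)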
+
+ # =============================================================================
+ # GPU-Accelerated Distance Metrics
+ # =============================================================================
+
+ class GPUWasserstein1:
+     """GPU-accelerated Wasserstein-1 distance computation.
+
+     Computes W1 distance for all genes in parallel on GPU using
+     vectorized sorting and quantile interpolation.
+     """
+
+     def __init__(self, device: str = "cuda"):
+         if not HAS_TORCH:
+             raise ImportError("PyTorch required for GPU acceleration")
+         self.device = _get_device(device)
+
+     def compute_batch(
+         self,
+         real: np.ndarray,
+         generated: np.ndarray,
+     ) -> np.ndarray:
+         """Compute W1 for all genes in batch on GPU.
+
+         Parameters
+         ----------
+         real : np.ndarray
+             Real data, shape (n_samples_real, n_genes)
+         generated : np.ndarray
+             Generated data, shape (n_samples_gen, n_genes)
+
+         Returns
+         -------
+         np.ndarray
+             W1 distance per gene
+         """
+         # Move to GPU
+         real_t = torch.tensor(real, dtype=torch.float32, device=self.device)
+         gen_t = torch.tensor(generated, dtype=torch.float32, device=self.device)
+
+         n_genes = real_t.shape[1]
+         n_quantiles = max(real_t.shape[0], gen_t.shape[0])
+
+         # Sort each gene column
+         real_sorted, _ = torch.sort(real_t, dim=0)
+         gen_sorted, _ = torch.sort(gen_t, dim=0)
+
+         # Interpolate to same number of quantiles
+         quantile_positions = torch.linspace(0, 1, n_quantiles, device=self.device)
+
+         # Interpolate real
+         real_indices = quantile_positions * (real_sorted.shape[0] - 1)
+         real_floor = real_indices.long().clamp(0, real_sorted.shape[0] - 2)
+         real_frac = (real_indices - real_floor.float()).unsqueeze(1)
+         real_interp = (
+             real_sorted[real_floor] * (1 - real_frac) +
+             real_sorted[real_floor + 1] * real_frac
+         )
+
+         # Interpolate generated
+         gen_indices = quantile_positions * (gen_sorted.shape[0] - 1)
+         gen_floor = gen_indices.long().clamp(0, gen_sorted.shape[0] - 2)
+         gen_frac = (gen_indices - gen_floor.float()).unsqueeze(1)
+         gen_interp = (
+             gen_sorted[gen_floor] * (1 - gen_frac) +
+             gen_sorted[gen_floor + 1] * gen_frac
+         )
+
+         # W1 = mean absolute difference
+         w1 = torch.mean(torch.abs(real_interp - gen_interp), dim=0)
+
+         return w1.cpu().numpy()
+
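Since this class uses its own quantile interpolation rather than scipy, one way to gain confidence in it is a cross-check against scipy.stats.wasserstein_distance on CPU. A sketch (device="auto" falls back to CPU tensors when no GPU is present; with equal sample sizes the quantile grids line up, so the two should agree to float32 round-off):

    import numpy as np
    from scipy.stats import wasserstein_distance
    from geneval.metrics.accelerated import GPUWasserstein1

    rng = np.random.default_rng(0)
    real = rng.gamma(2.0, size=(512, 100))
    generated = rng.gamma(2.5, size=(512, 100))

    gpu_w1 = GPUWasserstein1(device="auto").compute_batch(real, generated)
    cpu_w1 = np.array([
        wasserstein_distance(real[:, g], generated[:, g]) for g in range(real.shape[1])
    ])
    print(np.max(np.abs(gpu_w1 - cpu_w1)))  # small when sample sizes match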
+
+ class GPUWasserstein2:
+     """GPU-accelerated Wasserstein-2 distance using geomloss.
+
+     Computes an entropic (Sinkhorn) approximation of the W2 cost for each gene on the GPU.
+     """
+
+     def __init__(self, device: str = "cuda", blur: float = 0.01):
+         if not HAS_TORCH:
+             raise ImportError("PyTorch required for GPU acceleration")
+         if not HAS_GEOMLOSS:
+             raise ImportError("geomloss required for Wasserstein-2 GPU acceleration")
+
+         self.device = _get_device(device)
+         self.blur = blur
+         self.loss_fn = SamplesLoss(loss="sinkhorn", p=2, blur=blur, backend="tensorized")
+
+     def compute_batch(
+         self,
+         real: np.ndarray,
+         generated: np.ndarray,
+     ) -> np.ndarray:
+         """Compute W2 for all genes in batch on GPU.
+
+         Parameters
+         ----------
+         real : np.ndarray
+             Real data, shape (n_samples_real, n_genes)
+         generated : np.ndarray
+             Generated data, shape (n_samples_gen, n_genes)
+
+         Returns
+         -------
+         np.ndarray
+             W2 distance per gene
+         """
+         n_genes = real.shape[1]
+
+         # Move to GPU
+         real_t = torch.tensor(real, dtype=torch.float32, device=self.device)
+         gen_t = torch.tensor(generated, dtype=torch.float32, device=self.device)
+
+         distances = torch.zeros(n_genes, device=self.device)
+
+         # Each gene is an independent 1D distribution pair, so the Sinkhorn loss
+         # is evaluated gene by gene (stacking genes along geomloss's batch
+         # dimension would be a further optimization).
+         for i in range(n_genes):
+             r = real_t[:, i:i+1]  # Keep 2D
+             g = gen_t[:, i:i+1]
+             distances[i] = self.loss_fn(r, g)
+
+         return distances.cpu().numpy()
+
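geomloss is an optional dependency, so guarded construction is the safer pattern. A sketch (the blur value trades speed against accuracy, and the Sinkhorn divergence only approximates the exact W2 cost):

    import numpy as np
    from geneval.metrics.accelerated import HAS_GEOMLOSS, GPUWasserstein2

    rng = np.random.default_rng(0)
    real = rng.normal(size=(256, 50))
    generated = rng.normal(loc=0.5, size=(256, 50))

    if HAS_GEOMLOSS:
        # Larger blur -> faster but blurrier optimal-transport estimate.
        w2 = GPUWasserstein2(device="auto", blur=0.01).compute_batch(real, generated)
    # Otherwise compute_metrics_accelerated(...) falls back to the quantile-based
    # CPU implementation (_compute_w2_cpu) for "wasserstein_2".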
+
+ class GPUMMD:
+     """GPU-accelerated MMD computation with RBF kernel.
+
+     Uses PyTorch to evaluate the RBF kernel matrices on the GPU for each gene.
+     """
+
+     def __init__(self, device: str = "cuda", sigma: Optional[float] = None):
+         if not HAS_TORCH:
+             raise ImportError("PyTorch required for GPU acceleration")
+
+         self.device = _get_device(device)
+         self.sigma = sigma
+
+     def compute_batch(
+         self,
+         real: np.ndarray,
+         generated: np.ndarray,
+     ) -> np.ndarray:
+         """Compute MMD for all genes in batch on GPU.
+
+         Parameters
+         ----------
+         real : np.ndarray
+             Real data, shape (n_samples_real, n_genes)
+         generated : np.ndarray
+             Generated data, shape (n_samples_gen, n_genes)
+
+         Returns
+         -------
+         np.ndarray
+             MMD per gene
+         """
+         real_t = torch.tensor(real, dtype=torch.float32, device=self.device)
+         gen_t = torch.tensor(generated, dtype=torch.float32, device=self.device)
+
+         n_genes = real_t.shape[1]
+         n_x, n_y = real_t.shape[0], gen_t.shape[0]
+
+         mmd_values = torch.zeros(n_genes, device=self.device)
+
+         for g in range(n_genes):
+             x = real_t[:, g:g+1]
+             y = gen_t[:, g:g+1]
+
+             # Median heuristic for sigma
+             if self.sigma is None:
+                 combined = torch.cat([x, y], dim=0)
+                 pairwise = torch.abs(combined - combined.T)
+                 sigma = torch.median(pairwise[pairwise > 0]).item()
+                 if sigma == 0:
+                     sigma = 1.0
+             else:
+                 sigma = self.sigma
+
+             # RBF kernel
+             def rbf(a, b, s):
+                 sq_dist = (a - b.T) ** 2
+                 return torch.exp(-sq_dist / (2 * s ** 2))
+
+             K_xx = rbf(x, x, sigma)
+             K_yy = rbf(y, y, sigma)
+             K_xy = rbf(x, y, sigma)
+
+             # Unbiased estimator of squared MMD (diagonal terms excluded)
+             mmd = (
+                 (K_xx.sum() - K_xx.trace()) / (n_x * (n_x - 1)) +
+                 (K_yy.sum() - K_yy.trace()) / (n_y * (n_y - 1)) -
+                 2 * K_xy.sum() / (n_x * n_y)
+             )
+
+             mmd_values[g] = torch.clamp(mmd, min=0)
+
+         return mmd_values.cpu().numpy()
+
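A quick sanity check on the estimator: for two samples drawn from the same distribution the (clamped) unbiased MMD² estimate should sit near zero, while a mean shift pushes it up. An illustrative sketch (device="auto" runs on CPU tensors when no GPU is present):

    import numpy as np
    from geneval.metrics.accelerated import GPUMMD

    rng = np.random.default_rng(0)
    same_a = rng.normal(size=(300, 20))
    same_b = rng.normal(size=(300, 20))
    shifted = rng.normal(loc=1.0, size=(300, 20))

    mmd = GPUMMD(device="auto")
    print(mmd.compute_batch(same_a, same_b).mean())   # close to 0
    print(mmd.compute_batch(same_a, shifted).mean())  # clearly larger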
+
+ class GPUEnergyDistance:
+     """GPU-accelerated Energy distance computation."""
+
+     def __init__(self, device: str = "cuda"):
+         if not HAS_TORCH:
+             raise ImportError("PyTorch required for GPU acceleration")
+
+         self.device = _get_device(device)
+
+     def compute_batch(
+         self,
+         real: np.ndarray,
+         generated: np.ndarray,
+     ) -> np.ndarray:
+         """Compute Energy distance for all genes in batch on GPU.
+
+         Parameters
+         ----------
+         real : np.ndarray
+             Real data, shape (n_samples_real, n_genes)
+         generated : np.ndarray
+             Generated data, shape (n_samples_gen, n_genes)
+
+         Returns
+         -------
+         np.ndarray
+             Energy distance per gene
+         """
+         real_t = torch.tensor(real, dtype=torch.float32, device=self.device)
+         gen_t = torch.tensor(generated, dtype=torch.float32, device=self.device)
+
+         n_genes = real_t.shape[1]
+
+         energy_values = torch.zeros(n_genes, device=self.device)
+
+         for g in range(n_genes):
+             x = real_t[:, g]
+             y = gen_t[:, g]
+
+             # E[|X - Y|]
+             xy_dist = torch.mean(torch.abs(x.unsqueeze(1) - y.unsqueeze(0)))
+
+             # E[|X - X'|]
+             xx_dist = torch.mean(torch.abs(x.unsqueeze(1) - x.unsqueeze(0)))
+
+             # E[|Y - Y'|]
+             yy_dist = torch.mean(torch.abs(y.unsqueeze(1) - y.unsqueeze(0)))
+
+             energy = 2 * xy_dist - xx_dist - yy_dist
+             energy_values[g] = torch.clamp(energy, min=0)
+
+         return energy_values.cpu().numpy()
+
+
+ # =============================================================================
+ # Vectorized NumPy Implementations (for CPU speedup without joblib)
+ # =============================================================================
+
+ def vectorized_wasserstein1(
+     real: np.ndarray,
+     generated: np.ndarray,
+ ) -> np.ndarray:
+     """Compute W1 for all genes using vectorized NumPy.
+
+     Sorting is vectorized across genes, so this is typically faster than
+     calling scipy.stats.wasserstein_distance gene by gene.
+
+     Parameters
+     ----------
+     real : np.ndarray
+         Real data, shape (n_samples_real, n_genes)
+     generated : np.ndarray
+         Generated data, shape (n_samples_gen, n_genes)
+
+     Returns
+     -------
+     np.ndarray
+         W1 distance per gene
+     """
+     n_genes = real.shape[1]
+     n_quantiles = max(real.shape[0], generated.shape[0])
+
+     # Sort each column
+     real_sorted = np.sort(real, axis=0)
+     gen_sorted = np.sort(generated, axis=0)
+
+     # Interpolate to same number of quantiles
+     real_positions = np.linspace(0, 1, real_sorted.shape[0])
+     gen_positions = np.linspace(0, 1, gen_sorted.shape[0])
+     target_positions = np.linspace(0, 1, n_quantiles)
+
+     # Interpolate each gene column
+     real_interp = np.zeros((n_quantiles, n_genes))
+     gen_interp = np.zeros((n_quantiles, n_genes))
+
+     for g in range(n_genes):
+         real_interp[:, g] = np.interp(target_positions, real_positions, real_sorted[:, g])
+         gen_interp[:, g] = np.interp(target_positions, gen_positions, gen_sorted[:, g])
+
+     # W1 = mean absolute difference
+     return np.mean(np.abs(real_interp - gen_interp), axis=0)
+
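When the two samples have different sizes, the shared quantile grid makes this an approximation; a sketch comparing it with scipy's exact computation on a small example gives a feel for the gap (array sizes and distributions here are illustrative):

    import numpy as np
    from scipy.stats import wasserstein_distance
    from geneval.metrics.accelerated import vectorized_wasserstein1

    rng = np.random.default_rng(0)
    real = rng.exponential(size=(1000, 10))
    generated = rng.exponential(scale=1.3, size=(750, 10))  # unequal sample sizes

    approx = vectorized_wasserstein1(real, generated)
    exact = np.array([
        wasserstein_distance(real[:, g], generated[:, g]) for g in range(10)
    ])
    print(np.max(np.abs(approx - exact)))  # typically a small discretization gap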
+
+ def vectorized_mmd(
+     real: np.ndarray,
+     generated: np.ndarray,
+     sigma: Optional[float] = None,
+ ) -> np.ndarray:
+     """Compute MMD for all genes using vectorized NumPy.
+
+     Parameters
+     ----------
+     real : np.ndarray
+         Real data, shape (n_samples_real, n_genes)
+     generated : np.ndarray
+         Generated data, shape (n_samples_gen, n_genes)
+     sigma : float, optional
+         Kernel bandwidth. Uses median heuristic if None.
+
+     Returns
+     -------
+     np.ndarray
+         MMD per gene
+     """
+     n_genes = real.shape[1]
+     n_x, n_y = real.shape[0], generated.shape[0]
+
+     mmd_values = np.zeros(n_genes)
+
+     for g in range(n_genes):
+         x = real[:, g:g+1]
+         y = generated[:, g:g+1]
+
+         # Median heuristic
+         if sigma is None:
+             combined = np.vstack([x, y])
+             pairwise = np.abs(combined - combined.T)
+             s = float(np.median(pairwise[pairwise > 0]))
+             if s == 0:
+                 s = 1.0
+         else:
+             s = sigma
+
+         # RBF kernel
+         K_xx = np.exp(-(x - x.T) ** 2 / (2 * s ** 2))
+         K_yy = np.exp(-(y - y.T) ** 2 / (2 * s ** 2))
+         K_xy = np.exp(-(x - y.T) ** 2 / (2 * s ** 2))
+
+         # Unbiased estimator of squared MMD (diagonal terms excluded)
+         mmd = (
+             (np.sum(K_xx) - np.trace(K_xx)) / (n_x * (n_x - 1)) +
+             (np.sum(K_yy) - np.trace(K_yy)) / (n_y * (n_y - 1)) -
+             2 * np.sum(K_xy) / (n_x * n_y)
+         )
+
+         mmd_values[g] = max(0, mmd)
+
+     return mmd_values
+
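Because the median-heuristic bandwidth is recomputed per gene, each gene's MMD is measured on its own kernel scale; when cross-gene comparability matters, one option is to pass a shared sigma. An illustrative sketch:

    import numpy as np
    from geneval.metrics.accelerated import vectorized_mmd

    rng = np.random.default_rng(0)
    real = rng.normal(size=(400, 30))
    generated = rng.normal(loc=0.3, size=(400, 30))

    adaptive = vectorized_mmd(real, generated)            # per-gene median heuristic
    shared = vectorized_mmd(real, generated, sigma=1.0)   # one bandwidth for all genes
    print(adaptive.mean(), shared.mean())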
+
+ # =============================================================================
+ # High-Level Accelerated Evaluation Interface
+ # =============================================================================
+
+ def compute_metrics_accelerated(
+     real: np.ndarray,
+     generated: np.ndarray,
+     metrics: List[str] = ["wasserstein_1", "wasserstein_2", "mmd", "energy"],
+     n_jobs: int = 1,
+     device: str = "cpu",
+     gene_names: Optional[List[str]] = None,
+     verbose: bool = False,
+ ) -> Dict[str, MetricResult]:
+     """Compute multiple metrics with acceleration.
+
+     This is the main entry point for accelerated metric computation.
+     Automatically selects the best available backend.
+
+     Parameters
+     ----------
+     real : np.ndarray
+         Real data, shape (n_samples_real, n_genes)
+     generated : np.ndarray
+         Generated data, shape (n_samples_gen, n_genes)
+     metrics : List[str]
+         Metrics to compute: "wasserstein_1", "wasserstein_2", "mmd", "energy"
+     n_jobs : int
+         Number of CPU jobs (-1 for all cores)
+     device : str
+         Compute device ("cpu", "cuda", "auto")
+     gene_names : List[str], optional
+         Gene names
+     verbose : bool
+         Print progress
+
+     Returns
+     -------
+     Dict[str, MetricResult]
+         Dictionary of metric results
+     """
+     backends = get_available_backends()
+
+     if device == "auto":
+         if backends["cuda"]:
+             device = "cuda"
+         elif backends["mps"]:
+             device = "mps"
+         else:
+             device = "cpu"
+
+     if verbose:
+         print(f"Using device: {device}, n_jobs: {n_jobs}")
+         print(f"Available backends: {backends}")
+
+     n_genes = real.shape[1]
+     if gene_names is None:
+         gene_names = [f"gene_{i}" for i in range(n_genes)]
+
+     results = {}
+
+     for metric_name in metrics:
+         if verbose:
+             print(f"Computing {metric_name}...")
+
+         if device != "cpu" and backends["torch"]:
+             # GPU path
+             if metric_name == "wasserstein_1":
+                 gpu_metric = GPUWasserstein1(device=device)
+                 per_gene = gpu_metric.compute_batch(real, generated)
+             elif metric_name == "wasserstein_2" and backends["geomloss"]:
+                 gpu_metric = GPUWasserstein2(device=device)
+                 per_gene = gpu_metric.compute_batch(real, generated)
+             elif metric_name == "mmd":
+                 gpu_metric = GPUMMD(device=device)
+                 per_gene = gpu_metric.compute_batch(real, generated)
+             elif metric_name == "energy":
+                 gpu_metric = GPUEnergyDistance(device=device)
+                 per_gene = gpu_metric.compute_batch(real, generated)
+             else:
+                 # Fallback to vectorized CPU
+                 per_gene = _compute_cpu_metric(metric_name, real, generated, n_jobs)
+         else:
+             # CPU path
+             per_gene = _compute_cpu_metric(metric_name, real, generated, n_jobs)
+
+         results[metric_name] = MetricResult(
+             name=metric_name,
+             per_gene_values=per_gene,
+             gene_names=gene_names,
+             aggregate_value=float(np.nanmean(per_gene)),
+             aggregate_method="mean",
+             metadata={
+                 "device": device,
+                 "n_jobs": n_jobs,
+                 "accelerated": True,
+             }
+         )
+
+     return results
+
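End to end, a minimal sketch of the high-level entry point. Random arrays stand in for real expression data; device="auto" and n_jobs=-1 simply use whatever get_available_backends() reports:

    import numpy as np
    from geneval.metrics.accelerated import compute_metrics_accelerated

    rng = np.random.default_rng(0)
    real = rng.lognormal(size=(2000, 1000))       # (n_samples_real, n_genes)
    generated = rng.lognormal(size=(1500, 1000))  # (n_samples_gen, n_genes)

    results = compute_metrics_accelerated(
        real,
        generated,
        metrics=["wasserstein_1", "mmd", "energy"],
        n_jobs=-1,
        device="auto",
        verbose=True,
    )
    for name, res in results.items():
        print(name, res.aggregate_value)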
+
+ def _compute_cpu_metric(
+     metric_name: str,
+     real: np.ndarray,
+     generated: np.ndarray,
+     n_jobs: int,
+ ) -> np.ndarray:
+     """Compute metric on CPU with optional parallelization."""
+     if metric_name == "wasserstein_1":
+         if n_jobs != 1 and HAS_JOBLIB:
+             return _parallel_w1(real, generated, n_jobs)
+         else:
+             return vectorized_wasserstein1(real, generated)
+     elif metric_name == "wasserstein_2":
+         return _compute_w2_cpu(real, generated, n_jobs)
+     elif metric_name == "mmd":
+         if n_jobs != 1 and HAS_JOBLIB:
+             return _parallel_mmd(real, generated, n_jobs)
+         else:
+             return vectorized_mmd(real, generated)
+     elif metric_name == "energy":
+         return _compute_energy_cpu(real, generated, n_jobs)
+     else:
+         raise ValueError(f"Unknown metric: {metric_name}")
+
+
+ def _parallel_w1(real: np.ndarray, generated: np.ndarray, n_jobs: int) -> np.ndarray:
+     """Parallel W1 computation."""
+     from scipy.stats import wasserstein_distance
+
+     n_genes = real.shape[1]
+
+     def compute_single(g):
+         r = real[:, g]
+         gen = generated[:, g]
+         r = r[~np.isnan(r)]
+         gen = gen[~np.isnan(gen)]
+         if len(r) == 0 or len(gen) == 0:
+             return np.nan
+         return wasserstein_distance(r, gen)
+
+     results = Parallel(n_jobs=n_jobs)(
+         delayed(compute_single)(g) for g in range(n_genes)
+     )
+
+     return np.array(results)
+
+
+ def _parallel_mmd(real: np.ndarray, generated: np.ndarray, n_jobs: int) -> np.ndarray:
+     """Parallel MMD computation."""
+     n_genes = real.shape[1]
+
+     def compute_single(g):
+         x = real[:, g:g+1]
+         y = generated[:, g:g+1]
+
+         combined = np.vstack([x, y])
+         pairwise = np.abs(combined - combined.T)
+         sigma = float(np.median(pairwise[pairwise > 0]))
+         if sigma == 0:
+             sigma = 1.0
+
+         n_x, n_y = len(x), len(y)
+
+         K_xx = np.exp(-(x - x.T) ** 2 / (2 * sigma ** 2))
+         K_yy = np.exp(-(y - y.T) ** 2 / (2 * sigma ** 2))
+         K_xy = np.exp(-(x - y.T) ** 2 / (2 * sigma ** 2))
+
+         mmd = (
+             (np.sum(K_xx) - np.trace(K_xx)) / (n_x * (n_x - 1)) +
+             (np.sum(K_yy) - np.trace(K_yy)) / (n_y * (n_y - 1)) -
+             2 * np.sum(K_xy) / (n_x * n_y)
+         )
+
+         return max(0, mmd)
+
+     results = Parallel(n_jobs=n_jobs)(
+         delayed(compute_single)(g) for g in range(n_genes)
+     )
+
+     return np.array(results)
+
+
+ def _compute_w2_cpu(real: np.ndarray, generated: np.ndarray, n_jobs: int) -> np.ndarray:
+     """CPU W2 computation (quantile-based)."""
+     n_genes = real.shape[1]
+
+     def compute_single(g):
+         r = real[:, g]
+         gen = generated[:, g]
+
+         r = r[~np.isnan(r)]
+         gen = gen[~np.isnan(gen)]
+
+         if len(r) == 0 or len(gen) == 0:
+             return np.nan
+
+         r_sorted = np.sort(r)
+         g_sorted = np.sort(gen)
+
+         n = max(len(r_sorted), len(g_sorted))
+         r_q = np.interp(np.linspace(0, 1, n), np.linspace(0, 1, len(r_sorted)), r_sorted)
+         g_q = np.interp(np.linspace(0, 1, n), np.linspace(0, 1, len(g_sorted)), g_sorted)
+
+         return np.sqrt(np.mean((r_q - g_q) ** 2))
+
+     if n_jobs != 1 and HAS_JOBLIB:
+         results = Parallel(n_jobs=n_jobs)(
+             delayed(compute_single)(g) for g in range(n_genes)
+         )
+         return np.array(results)
+     else:
+         return np.array([compute_single(g) for g in range(n_genes)])
+
+
+ def _compute_energy_cpu(real: np.ndarray, generated: np.ndarray, n_jobs: int) -> np.ndarray:
+     """CPU Energy distance computation."""
+     n_genes = real.shape[1]
+
+     def compute_single(g):
+         x = real[:, g]
+         y = generated[:, g]
+
+         x = x[~np.isnan(x)]
+         y = y[~np.isnan(y)]
+
+         if len(x) < 2 or len(y) < 2:
+             return np.nan
+
+         xy_dist = np.mean(np.abs(x[:, np.newaxis] - y[np.newaxis, :]))
+         xx_dist = np.mean(np.abs(x[:, np.newaxis] - x[np.newaxis, :]))
+         yy_dist = np.mean(np.abs(y[:, np.newaxis] - y[np.newaxis, :]))
+
+         return max(0, 2 * xy_dist - xx_dist - yy_dist)
+
+     if n_jobs != 1 and HAS_JOBLIB:
+         results = Parallel(n_jobs=n_jobs)(
+             delayed(compute_single)(g) for g in range(n_genes)
+         )
+         return np.array(results)
+     else:
+         return np.array([compute_single(g) for g in range(n_genes)])