combatlearn 1.1.2__py3-none-any.whl → 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
combatlearn/metrics.py ADDED
@@ -0,0 +1,788 @@
+ """Batch effect metrics and diagnostics."""
+
+ from __future__ import annotations
+
+ from typing import Any, Literal
+
+ import numpy as np
+ import pandas as pd
+ from scipy.spatial.distance import pdist
+ from scipy.stats import chi2, levene, spearmanr
+ from sklearn.decomposition import PCA
+ from sklearn.metrics import davies_bouldin_score, silhouette_score
+ from sklearn.neighbors import NearestNeighbors
+
+ from .core import ArrayLike
+
+
+ def _compute_pca_embedding(
+     X_before: np.ndarray,
+     X_after: np.ndarray,
+     n_components: int,
+ ) -> tuple[np.ndarray, np.ndarray, PCA]:
+     """
+     Compute PCA embeddings for both datasets.
+
+     Fits PCA on X_before and applies it to both datasets.
+
+     Parameters
+     ----------
+     X_before : np.ndarray
+         Original data before correction.
+     X_after : np.ndarray
+         Corrected data.
+     n_components : int
+         Number of PCA components.
+
+     Returns
+     -------
+     X_before_pca : np.ndarray
+         PCA-transformed original data.
+     X_after_pca : np.ndarray
+         PCA-transformed corrected data.
+     pca : PCA
+         Fitted PCA model.
+     """
+     n_components = min(n_components, X_before.shape[1], X_before.shape[0] - 1)
+     pca = PCA(n_components=n_components, random_state=42)
+     X_before_pca = pca.fit_transform(X_before)
+     X_after_pca = pca.transform(X_after)
+     return X_before_pca, X_after_pca, pca
+
+
+ def _silhouette_batch(X: np.ndarray, batch_labels: np.ndarray) -> float:
+     """
+     Compute silhouette coefficient using batch as cluster labels.
+
+     Lower values after correction indicate better batch mixing.
+     Range: [-1, 1]; values near or below 0 indicate batch mixing,
+     values near 1 indicate batch separation.
+
+     Parameters
+     ----------
+     X : np.ndarray
+         Data matrix.
+     batch_labels : np.ndarray
+         Batch labels for each sample.
+
+     Returns
+     -------
+     float
+         Silhouette coefficient.
+     """
+     unique_batches = np.unique(batch_labels)
+     if len(unique_batches) < 2:
+         return 0.0
+     try:
+         return silhouette_score(X, batch_labels, metric="euclidean")
+     except Exception:
+         return 0.0
+
+
+ def _davies_bouldin_batch(X: np.ndarray, batch_labels: np.ndarray) -> float:
+     """
+     Compute Davies-Bouldin index using batch labels.
+
+     Higher values indicate better batch mixing (batches are harder to
+     separate as clusters). Range: [0, inf), where 0 corresponds to
+     perfectly separated batches.
+
+     Parameters
+     ----------
+     X : np.ndarray
+         Data matrix.
+     batch_labels : np.ndarray
+         Batch labels for each sample.
+
+     Returns
+     -------
+     float
+         Davies-Bouldin index.
+     """
+     unique_batches = np.unique(batch_labels)
+     if len(unique_batches) < 2:
+         return 0.0
+     try:
+         return davies_bouldin_score(X, batch_labels)
+     except Exception:
+         return 0.0
+
+
+ def _kbet_score(
+     X: np.ndarray,
+     batch_labels: np.ndarray,
+     k0: int,
+     alpha: float = 0.05,
+ ) -> tuple[float, float]:
+     """
+     Compute kBET (k-nearest neighbor Batch Effect Test) acceptance rate.
+
+     Tests if local batch proportions match global batch proportions.
+     Higher acceptance rate = better batch mixing.
+
+     Reference: Büttner et al. (2019) Nature Methods
+
+     Parameters
+     ----------
+     X : np.ndarray
+         Data matrix.
+     batch_labels : np.ndarray
+         Batch labels for each sample.
+     k0 : int
+         Neighborhood size.
+     alpha : float
+         Significance level for chi-squared test.
+
+     Returns
+     -------
+     acceptance_rate : float
+         Fraction of samples where H0 (uniform mixing) is not rejected.
+     mean_stat : float
+         Mean chi-squared statistic across samples.
+     """
+     n_samples = X.shape[0]
+     unique_batches, batch_counts = np.unique(batch_labels, return_counts=True)
+     n_batches = len(unique_batches)
+
+     if n_batches < 2:
+         return 1.0, 0.0
+
+     global_freq = batch_counts / n_samples
+     k0 = min(k0, n_samples - 1)
+
+     nn = NearestNeighbors(n_neighbors=k0 + 1, algorithm="auto")
+     nn.fit(X)
+     _, indices = nn.kneighbors(X)
+
+     chi2_stats = []
+     p_values = []
+     batch_to_idx = {b: i for i, b in enumerate(unique_batches)}
+
+     for i in range(n_samples):
+         neighbors = indices[i, 1 : k0 + 1]
+         neighbor_batches = batch_labels[neighbors]
+
+         observed = np.zeros(n_batches)
+         for nb in neighbor_batches:
+             observed[batch_to_idx[nb]] += 1
+
+         expected = global_freq * k0
+
+         mask = expected > 0
+         if mask.sum() < 2:
+             continue
+
+         stat = np.sum((observed[mask] - expected[mask]) ** 2 / expected[mask])
+         df = max(1, mask.sum() - 1)
+         p_val = 1 - chi2.cdf(stat, df)
+
+         chi2_stats.append(stat)
+         p_values.append(p_val)
+
+     if len(p_values) == 0:
+         return 1.0, 0.0
+
+     acceptance_rate = np.mean(np.array(p_values) > alpha)
+     mean_stat = np.mean(chi2_stats)
+
+     return acceptance_rate, mean_stat
+
+
+ def _find_sigma(distances: np.ndarray, target_perplexity: float, tol: float = 1e-5) -> float:
+     """
+     Binary search for sigma to achieve target perplexity.
+
+     Used in LISI computation.
+
+     Parameters
+     ----------
+     distances : np.ndarray
+         Distances to neighbors.
+     target_perplexity : float
+         Target perplexity value.
+     tol : float
+         Tolerance for convergence.
+
+     Returns
+     -------
+     float
+         Sigma value.
+     """
+     target_H = np.log2(target_perplexity + 1e-10)
+
+     sigma_min, sigma_max = 1e-10, 1e10
+     sigma = 1.0
+
+     for _ in range(50):
+         P = np.exp(-(distances**2) / (2 * sigma**2 + 1e-10))
+         P_sum = P.sum()
+         if P_sum < 1e-10:
+             sigma = (sigma + sigma_max) / 2
+             continue
+         P = P / P_sum
+         P = np.clip(P, 1e-10, 1.0)
+         H = -np.sum(P * np.log2(P))
+
+         if abs(H - target_H) < tol:
+             break
+         elif target_H > H:
+             sigma_min = sigma
+         else:
+             sigma_max = sigma
+         sigma = (sigma_min + sigma_max) / 2
+
+     return sigma
+
+
+ def _lisi_score(
+     X: np.ndarray,
+     batch_labels: np.ndarray,
+     perplexity: int = 30,
+ ) -> float:
+     """
+     Compute mean Local Inverse Simpson's Index (LISI).
+
+     Range: [1, n_batches], where n_batches = perfect mixing.
+     Higher = better batch mixing.
+
+     Reference: Korsunsky et al. (2019) Nature Methods (Harmony paper)
+
+     Parameters
+     ----------
+     X : np.ndarray
+         Data matrix.
+     batch_labels : np.ndarray
+         Batch labels for each sample.
+     perplexity : int
+         Perplexity for Gaussian kernel.
+
+     Returns
+     -------
+     float
+         Mean LISI score.
+     """
+     n_samples = X.shape[0]
+     unique_batches = np.unique(batch_labels)
+     n_batches = len(unique_batches)
+     batch_to_idx = {b: i for i, b in enumerate(unique_batches)}
+
+     if n_batches < 2:
+         return 1.0
+
+     k = min(3 * perplexity, n_samples - 1)
+
+     nn = NearestNeighbors(n_neighbors=k + 1, algorithm="auto")
+     nn.fit(X)
+     distances, indices = nn.kneighbors(X)
+
+     distances = distances[:, 1:]
+     indices = indices[:, 1:]
+
+     lisi_values = []
+
+     for i in range(n_samples):
+         sigma = _find_sigma(distances[i], perplexity)
+
+         P = np.exp(-(distances[i] ** 2) / (2 * sigma**2 + 1e-10))
+         P_sum = P.sum()
+         if P_sum < 1e-10:
+             lisi_values.append(1.0)
+             continue
+         P = P / P_sum
+
+         neighbor_batches = batch_labels[indices[i]]
+         batch_probs = np.zeros(n_batches)
+         for j, nb in enumerate(neighbor_batches):
+             batch_probs[batch_to_idx[nb]] += P[j]
+
+         simpson = np.sum(batch_probs**2)
+         lisi = n_batches if simpson < 1e-10 else 1.0 / simpson
+         lisi_values.append(lisi)
+
+     return np.mean(lisi_values)
+
+
+ def _variance_ratio(X: np.ndarray, batch_labels: np.ndarray) -> float:
+     """
+     Compute between-batch to within-batch variance ratio.
+
+     Similar to F-statistic in one-way ANOVA.
+     Lower ratio after correction = better batch effect removal.
+
+     Parameters
+     ----------
+     X : np.ndarray
+         Data matrix.
+     batch_labels : np.ndarray
+         Batch labels for each sample.
+
+     Returns
+     -------
+     float
+         Variance ratio (between/within).
+     """
+     unique_batches = np.unique(batch_labels)
+     n_batches = len(unique_batches)
+     n_samples = X.shape[0]
+
+     if n_batches < 2:
+         return 0.0
+
+     grand_mean = np.mean(X, axis=0)
+
+     between_var = 0.0
+     within_var = 0.0
+
+     for batch in unique_batches:
+         mask = batch_labels == batch
+         n_b = np.sum(mask)
+         X_batch = X[mask]
+         batch_mean = np.mean(X_batch, axis=0)
+
+         between_var += n_b * np.sum((batch_mean - grand_mean) ** 2)
+         within_var += np.sum((X_batch - batch_mean) ** 2)
+
+     between_var /= n_batches - 1
+     within_var /= n_samples - n_batches
+
+     if within_var < 1e-10:
+         return 0.0
+
+     return between_var / within_var
+
+
+ def _knn_preservation(
+     X_before: np.ndarray,
+     X_after: np.ndarray,
+     k_values: list[int],
+     n_jobs: int = 1,
+ ) -> dict[int, float]:
+     """
+     Compute fraction of k-nearest neighbors preserved after correction.
+
+     Range: [0, 1], where 1 = perfect preservation.
+     Higher = better biological structure preservation.
+
+     Parameters
+     ----------
+     X_before : np.ndarray
+         Original data.
+     X_after : np.ndarray
+         Corrected data.
+     k_values : list of int
+         Values of k for k-NN.
+     n_jobs : int
+         Number of parallel jobs.
+
+     Returns
+     -------
+     dict
+         Mapping from k to preservation fraction.
+     """
+     results = {}
+     max_k = max(k_values)
+     max_k = min(max_k, X_before.shape[0] - 1)
+
+     nn_before = NearestNeighbors(n_neighbors=max_k + 1, algorithm="auto", n_jobs=n_jobs)
+     nn_before.fit(X_before)
+     _, indices_before = nn_before.kneighbors(X_before)
+
+     nn_after = NearestNeighbors(n_neighbors=max_k + 1, algorithm="auto", n_jobs=n_jobs)
+     nn_after.fit(X_after)
+     _, indices_after = nn_after.kneighbors(X_after)
+
+     for k in k_values:
+         if k > max_k:
+             results[k] = 0.0
+             continue
+
+         overlaps = []
+         for i in range(X_before.shape[0]):
+             neighbors_before = set(indices_before[i, 1 : k + 1])
+             neighbors_after = set(indices_after[i, 1 : k + 1])
+             overlap = len(neighbors_before & neighbors_after) / k
+             overlaps.append(overlap)
+
+         results[k] = np.mean(overlaps)
+
+     return results
+
+
+ def _pairwise_distance_correlation(
+     X_before: np.ndarray,
+     X_after: np.ndarray,
+     subsample: int = 1000,
+     random_state: int = 42,
+ ) -> float:
+     """
+     Compute Spearman correlation of pairwise distances.
+
+     Range: [-1, 1], where 1 = perfect rank preservation.
+     Higher = better relative relationship preservation.
+
+     Parameters
+     ----------
+     X_before : np.ndarray
+         Original data.
+     X_after : np.ndarray
+         Corrected data.
+     subsample : int
+         Maximum samples to use (for efficiency).
+     random_state : int
+         Random seed for subsampling.
+
+     Returns
+     -------
+     float
+         Spearman correlation coefficient.
+     """
+     n_samples = X_before.shape[0]
+
+     if n_samples > subsample:
+         rng = np.random.default_rng(random_state)
+         idx = rng.choice(n_samples, subsample, replace=False)
+         X_before = X_before[idx]
+         X_after = X_after[idx]
+
+     dist_before = pdist(X_before, metric="euclidean")
+     dist_after = pdist(X_after, metric="euclidean")
+
+     if len(dist_before) == 0:
+         return 1.0
+
+     corr, _ = spearmanr(dist_before, dist_after)
+
+     if np.isnan(corr):
+         return 1.0
+
+     return corr
+
+
+ def _mean_centroid_distance(X: np.ndarray, batch_labels: np.ndarray) -> float:
+     """
+     Compute mean pairwise Euclidean distance between batch centroids.
+
+     Lower after correction = better batch alignment.
+
+     Parameters
+     ----------
+     X : np.ndarray
+         Data matrix.
+     batch_labels : np.ndarray
+         Batch labels for each sample.
+
+     Returns
+     -------
+     float
+         Mean pairwise distance between centroids.
+     """
+     unique_batches = np.unique(batch_labels)
+     n_batches = len(unique_batches)
+
+     if n_batches < 2:
+         return 0.0
+
+     centroids = []
+     for batch in unique_batches:
+         mask = batch_labels == batch
+         centroid = np.mean(X[mask], axis=0)
+         centroids.append(centroid)
+
+     centroids = np.array(centroids)
+     distances = pdist(centroids, metric="euclidean")
+
+     return np.mean(distances)
+
+
+ def _levene_median_statistic(X: np.ndarray, batch_labels: np.ndarray) -> float:
+     """
+     Compute median Levene test statistic across features.
+
+     Lower statistic = more homogeneous variances across batches.
+
+     Parameters
+     ----------
+     X : np.ndarray
+         Data matrix.
+     batch_labels : np.ndarray
+         Batch labels for each sample.
+
+     Returns
+     -------
+     float
+         Median Levene test statistic.
+     """
+     unique_batches = np.unique(batch_labels)
+     if len(unique_batches) < 2:
+         return 0.0
+
+     levene_stats = []
+     for j in range(X.shape[1]):
+         groups = [X[batch_labels == b, j] for b in unique_batches]
+         groups = [g for g in groups if len(g) > 0]
+         if len(groups) < 2:
+             continue
+         try:
+             stat, _ = levene(*groups, center="median")
+             if not np.isnan(stat):
+                 levene_stats.append(stat)
+         except Exception:
+             continue
+
+     if len(levene_stats) == 0:
+         return 0.0
+
+     return np.median(levene_stats)
+
+
+ class ComBatMetricsMixin:
+     """Mixin providing batch effect metrics for the ComBat wrapper."""
+
+     @property
+     def metrics_(self) -> dict[str, Any] | None:
+         """Return cached metrics from the last fit_transform call with compute_metrics=True.
+
+         Returns
+         -------
+         dict or None
+             Cached metrics dictionary, or None if no metrics have been computed.
+         """
+         return getattr(self, "_metrics_cache", None)
+
+     def compute_batch_metrics(
+         self,
+         X: ArrayLike,
+         batch: ArrayLike | None = None,
+         *,
+         pca_components: int | None = None,
+         k_neighbors: list[int] | None = None,
+         kbet_k0: int | None = None,
+         lisi_perplexity: int = 30,
+         n_jobs: int = 1,
+     ) -> dict[str, Any]:
+         """
+         Compute batch effect metrics before and after ComBat correction.
+
+         Parameters
+         ----------
+         X : array-like of shape (n_samples, n_features)
+             Input data to evaluate.
+         batch : array-like of shape (n_samples,), optional
+             Batch labels. If None, uses the batch stored at construction.
+         pca_components : int, optional
+             Number of PCA components for dimensionality reduction before
+             computing metrics. If None (default), metrics are computed in
+             the original feature space. Must be less than min(n_samples, n_features).
+         k_neighbors : list of int, default=[5, 10, 50]
+             Values of k for k-NN preservation metric.
+         kbet_k0 : int, optional
+             Neighborhood size for kBET. Default is 10% of the samples (minimum 10).
+         lisi_perplexity : int, default=30
+             Perplexity for LISI computation.
+         n_jobs : int, default=1
+             Number of parallel jobs for neighbor computations.
+
+         Returns
+         -------
+         dict
+             Dictionary with three main keys:
+
+             - ``batch_effect``: Silhouette, Davies-Bouldin, kBET, LISI, variance ratio
+               (each with 'before' and 'after' values)
+             - ``preservation``: k-NN preservation fractions, distance correlation
+             - ``alignment``: Centroid distance, Levene statistic (each with
+               'before' and 'after' values)
+
+         Raises
+         ------
+         ValueError
+             If the model is not fitted or if pca_components is invalid.
+         """
+         if not hasattr(self._model, "_gamma_star"):
+             raise ValueError(
+                 "This ComBat instance is not fitted yet. Call 'fit' before 'compute_batch_metrics'."
+             )
+
+         if not isinstance(X, pd.DataFrame):
+             X = pd.DataFrame(X)
+
+         idx = X.index
+
+         if batch is None:
+             batch_vec = self._subset(self.batch, idx)
+         else:
+             if isinstance(batch, (pd.Series, pd.DataFrame)):
+                 batch_vec = batch.loc[idx] if hasattr(batch, "loc") else batch
+             elif isinstance(batch, np.ndarray):
+                 batch_vec = pd.Series(batch, index=idx)
+             else:
+                 batch_vec = pd.Series(batch, index=idx)
+
+         batch_labels = np.array(batch_vec)
+
+         X_before = X.values
+         X_after = self.transform(X).values
+
+         n_samples, n_features = X_before.shape
+         if kbet_k0 is None:
+             kbet_k0 = max(10, int(0.10 * n_samples))
+         if k_neighbors is None:
+             k_neighbors = [5, 10, 50]
+
+         # Validate and apply PCA if requested
+         if pca_components is not None:
+             max_components = min(n_samples, n_features)
+             if pca_components >= max_components:
+                 raise ValueError(
+                     f"pca_components={pca_components} must be less than "
+                     f"min(n_samples, n_features)={max_components}."
+                 )
+             X_before_pca, X_after_pca, _ = _compute_pca_embedding(X_before, X_after, pca_components)
+         else:
+             X_before_pca = X_before
+             X_after_pca = X_after
+
+         silhouette_before = _silhouette_batch(X_before_pca, batch_labels)
+         silhouette_after = _silhouette_batch(X_after_pca, batch_labels)
+
+         db_before = _davies_bouldin_batch(X_before_pca, batch_labels)
+         db_after = _davies_bouldin_batch(X_after_pca, batch_labels)
+
+         kbet_before, _ = _kbet_score(X_before_pca, batch_labels, kbet_k0)
+         kbet_after, _ = _kbet_score(X_after_pca, batch_labels, kbet_k0)
+
+         lisi_before = _lisi_score(X_before_pca, batch_labels, lisi_perplexity)
+         lisi_after = _lisi_score(X_after_pca, batch_labels, lisi_perplexity)
+
+         var_ratio_before = _variance_ratio(X_before_pca, batch_labels)
+         var_ratio_after = _variance_ratio(X_after_pca, batch_labels)
+
+         knn_results = _knn_preservation(X_before_pca, X_after_pca, k_neighbors, n_jobs)
+         dist_corr = _pairwise_distance_correlation(X_before_pca, X_after_pca)
+
+         centroid_before = _mean_centroid_distance(X_before_pca, batch_labels)
+         centroid_after = _mean_centroid_distance(X_after_pca, batch_labels)
+
+         levene_before = _levene_median_statistic(X_before, batch_labels)
+         levene_after = _levene_median_statistic(X_after, batch_labels)
+
+         n_batches = len(np.unique(batch_labels))
+
+         metrics = {
+             "batch_effect": {
+                 "silhouette": {
+                     "before": silhouette_before,
+                     "after": silhouette_after,
+                 },
+                 "davies_bouldin": {
+                     "before": db_before,
+                     "after": db_after,
+                 },
+                 "kbet": {
+                     "before": kbet_before,
+                     "after": kbet_after,
+                 },
+                 "lisi": {
+                     "before": lisi_before,
+                     "after": lisi_after,
+                     "max_value": n_batches,
+                 },
+                 "variance_ratio": {
+                     "before": var_ratio_before,
+                     "after": var_ratio_after,
+                 },
+             },
+             "preservation": {
+                 "knn": knn_results,
+                 "distance_correlation": dist_corr,
+             },
+             "alignment": {
+                 "centroid_distance": {
+                     "before": centroid_before,
+                     "after": centroid_after,
+                 },
+                 "levene_statistic": {
+                     "before": levene_before,
+                     "after": levene_after,
+                 },
+             },
+         }
+
+         return metrics
+
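# --- Editor's note (illustration only, not part of the packaged file) -----
# A minimal sketch of how the nested dictionary returned by
# compute_batch_metrics() might be read; `report` stands for the return
# value on an already fitted wrapper and is assumed here.
#
#   report["batch_effect"]["silhouette"]["before"]   # batch silhouette pre-correction
#   report["batch_effect"]["lisi"]["after"]          # mean LISI post-correction
#   report["batch_effect"]["lisi"]["max_value"]      # equals the number of batches
#   report["preservation"]["knn"]                    # {k: fraction of k-NN preserved}
#   report["alignment"]["levene_statistic"]          # {'before': ..., 'after': ...}
# ---------------------------------------------------------------------------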
+     def feature_batch_importance(
+         self,
+         mode: Literal["magnitude", "distribution"] = "magnitude",
+     ) -> pd.DataFrame:
+         """Compute per-feature batch effect magnitude.
+
+         Returns a DataFrame with columns ``location``, ``scale``, and
+         ``combined``. Location is the RMS of gamma across batches
+         (standardized mean shifts). Scale is the RMS of log-delta across
+         batches (log-fold variance change). Combined is the Euclidean norm
+         sqrt(location**2 + scale**2). Using RMS provides L2-consistent
+         aggregation; using log(delta) ensures symmetry.
+
+         Parameters
+         ----------
+         mode : {'magnitude', 'distribution'}, default='magnitude'
+             - 'magnitude': Returns L2-consistent absolute batch effect magnitudes.
+               Suitable for ranking, thresholding, and cross-dataset comparison.
+             - 'distribution': Returns column-wise normalized proportions (each column
+               sums to 1, values in range [0, 1]), representing the relative contribution
+               of each feature to the total location, scale, or combined batch effect.
+               Note: normalization is applied independently to each column, so the
+               Euclidean relationship (combined**2 = location**2 + scale**2) no longer holds.
+
+         Returns
+         -------
+         pd.DataFrame
+             DataFrame with index=feature names, columns=['location', 'scale', 'combined'],
+             sorted by 'combined' descending.
+
+         Raises
+         ------
+         ValueError
+             If the model is not fitted or if mode is invalid.
+         """
+         if not hasattr(self._model, "_gamma_star"):
+             raise ValueError(
+                 "This ComBat instance is not fitted yet. "
+                 "Call 'fit' before 'feature_batch_importance'."
+             )
+
+         if mode not in ["magnitude", "distribution"]:
+             raise ValueError(f"mode must be 'magnitude' or 'distribution', got '{mode}'")
+
+         feature_names = self._model._grand_mean.index
+         gamma_star = self._model._gamma_star
+         delta_star = self._model._delta_star
+
+         # Location effect: RMS of gamma across batches (L2 aggregation)
+         location = np.sqrt((gamma_star**2).mean(axis=0))
+
+         # Scale effect: RMS of log(delta) across batches
+         if not self.mean_only:
+             scale = np.sqrt((np.log(delta_star) ** 2).mean(axis=0))
+         else:
+             scale = np.zeros_like(location)
+
+         # Euclidean to treat location and scale as orthogonal dimensions
+         combined = np.sqrt(location**2 + scale**2)
+
+         if mode == "distribution":
+             # Normalize each column independently to sum to 1
+             location_sum = location.sum()
+             scale_sum = scale.sum()
+             combined_sum = combined.sum()
+
+             location = location / location_sum if location_sum > 0 else location
+             scale = scale / scale_sum if scale_sum > 0 else scale
+             combined = combined / combined_sum if combined_sum > 0 else combined
+
+         return pd.DataFrame(
+             {
+                 "location": location,
+                 "scale": scale,
+                 "combined": combined,
+             },
+             index=feature_names,
+         ).sort_values("combined", ascending=False)
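For orientation, a minimal usage sketch of the new metrics API follows. The wrapper class name, import path, and constructor signature below are assumptions made for illustration (the mixin only shows that a batch mapping is stored at construction and that the estimator must be fitted first); the method names, keyword arguments, and returned keys follow the code in the diff. The metrics_ property docstring also indicates the same report is cached when fit_transform is called with compute_metrics=True.

import numpy as np
import pandas as pd

from combatlearn import ComBat  # class name and import path assumed for illustration

rng = np.random.default_rng(0)
X = pd.DataFrame(rng.normal(size=(120, 30)))
batch = pd.Series(np.repeat(["site_a", "site_b", "site_c"], 40), index=X.index)

model = ComBat(batch=batch)  # constructor signature assumed
model.fit(X)

# Before/after diagnostics from the mixin added in this release.
report = model.compute_batch_metrics(X, pca_components=10)
print(report["batch_effect"]["kbet"])            # {'before': ..., 'after': ...}
print(report["preservation"]["knn"][10])         # fraction of 10-NN preserved
print(report["alignment"]["centroid_distance"])  # distance between batch centroids

# Per-feature batch effect magnitudes, largest first.
importance = model.feature_batch_importance(mode="magnitude")
print(importance.head())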