combatlearn 1.1.1__py3-none-any.whl → 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
combatlearn/combat.py DELETED
@@ -1,1772 +0,0 @@
- """ComBat algorithm.
-
- `ComBatModel` implements three methods:
- * Johnson et al. (2007) vanilla ComBat (method="johnson")
- * Fortin et al. (2018) extension with covariates (method="fortin")
- * Chen et al. (2022) CovBat (method="chen")
-
- `ComBat` makes the model compatible with scikit-learn by stashing
- the batch (and optional covariates) at construction.
- """
-
- from __future__ import annotations
-
- import warnings
- from typing import Any, Literal
-
- import matplotlib
- import matplotlib.colors as mcolors
- import matplotlib.pyplot as plt
- import numpy as np
- import numpy.linalg as la
- import numpy.typing as npt
- import pandas as pd
- import plotly.graph_objects as go
- import umap
- from plotly.subplots import make_subplots
- from scipy.spatial.distance import pdist
- from scipy.stats import chi2, levene, spearmanr
- from sklearn.base import BaseEstimator, TransformerMixin
- from sklearn.decomposition import PCA
- from sklearn.manifold import TSNE
- from sklearn.metrics import davies_bouldin_score, silhouette_score
- from sklearn.neighbors import NearestNeighbors
-
- ArrayLike = pd.DataFrame | pd.Series | npt.NDArray[Any]
- FloatArray = npt.NDArray[np.float64]
-
-
- def _compute_pca_embedding(
-     X_before: np.ndarray,
-     X_after: np.ndarray,
-     n_components: int,
- ) -> tuple[np.ndarray, np.ndarray, PCA]:
-     """
-     Compute PCA embeddings for both datasets.
-
-     Fits PCA on X_before and applies it to both datasets.
-
-     Parameters
-     ----------
-     X_before : np.ndarray
-         Original data before correction.
-     X_after : np.ndarray
-         Corrected data.
-     n_components : int
-         Number of PCA components.
-
-     Returns
-     -------
-     X_before_pca : np.ndarray
-         PCA-transformed original data.
-     X_after_pca : np.ndarray
-         PCA-transformed corrected data.
-     pca : PCA
-         Fitted PCA model.
-     """
-     n_components = min(n_components, X_before.shape[1], X_before.shape[0] - 1)
-     pca = PCA(n_components=n_components, random_state=42)
-     X_before_pca = pca.fit_transform(X_before)
-     X_after_pca = pca.transform(X_after)
-     return X_before_pca, X_after_pca, pca
-
-
- def _silhouette_batch(X: np.ndarray, batch_labels: np.ndarray) -> float:
-     """
-     Compute the silhouette coefficient using batch as cluster labels.
-
-     Lower values after correction indicate better batch mixing.
-     Range: [-1, 1]; values near 0 or below indicate batch mixing,
-     values near 1 indicate batch separation.
-
-     Parameters
-     ----------
-     X : np.ndarray
-         Data matrix.
-     batch_labels : np.ndarray
-         Batch labels for each sample.
-
-     Returns
-     -------
-     float
-         Silhouette coefficient.
-     """
-     unique_batches = np.unique(batch_labels)
-     if len(unique_batches) < 2:
-         return 0.0
-     try:
-         return silhouette_score(X, batch_labels, metric="euclidean")
-     except Exception:
-         return 0.0
-
-
- def _davies_bouldin_batch(X: np.ndarray, batch_labels: np.ndarray) -> float:
-     """
-     Compute the Davies-Bouldin index using batch labels.
-
-     Higher values after correction indicate better batch mixing.
-     Range: [0, inf); values near 0 mean compact, well-separated
-     batch clusters, i.e. a strong batch effect.
-
-     Parameters
-     ----------
-     X : np.ndarray
-         Data matrix.
-     batch_labels : np.ndarray
-         Batch labels for each sample.
-
-     Returns
-     -------
-     float
-         Davies-Bouldin index.
-     """
-     unique_batches = np.unique(batch_labels)
-     if len(unique_batches) < 2:
-         return 0.0
-     try:
-         return davies_bouldin_score(X, batch_labels)
-     except Exception:
-         return 0.0
-
-
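# [Editor's sketch, not part of the package diff.] The two helpers above reuse
# scikit-learn's cluster metrics with batch labels playing the role of cluster
# labels. On data with an injected batch shift, silhouette comes out clearly
# positive (batches separated) and Davies-Bouldin well below 1 (compact,
# well-separated batch clusters):
import numpy as np
from sklearn.metrics import davies_bouldin_score, silhouette_score

rng = np.random.default_rng(0)
X = rng.normal(size=(100, 5))
batches = np.repeat(["site_a", "site_b"], 50)
X[batches == "site_b"] += 3.0  # inject a strong batch shift

print(silhouette_score(X, batches, metric="euclidean"))  # clearly positive
print(davies_bouldin_score(X, batches))                  # well below 1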
- def _kbet_score(
-     X: np.ndarray,
-     batch_labels: np.ndarray,
-     k0: int,
-     alpha: float = 0.05,
- ) -> tuple[float, float]:
-     """
-     Compute kBET (k-nearest neighbor Batch Effect Test) acceptance rate.
-
-     Tests if local batch proportions match global batch proportions.
-     Higher acceptance rate = better batch mixing.
-
-     Reference: Buttner et al. (2019) Nature Methods
-
-     Parameters
-     ----------
-     X : np.ndarray
-         Data matrix.
-     batch_labels : np.ndarray
-         Batch labels for each sample.
-     k0 : int
-         Neighborhood size.
-     alpha : float
-         Significance level for chi-squared test.
-
-     Returns
-     -------
-     acceptance_rate : float
-         Fraction of samples where H0 (uniform mixing) is accepted.
-     mean_stat : float
-         Mean chi-squared statistic across samples.
-     """
-     n_samples = X.shape[0]
-     unique_batches, batch_counts = np.unique(batch_labels, return_counts=True)
-     n_batches = len(unique_batches)
-
-     if n_batches < 2:
-         return 1.0, 0.0
-
-     global_freq = batch_counts / n_samples
-     k0 = min(k0, n_samples - 1)
-
-     nn = NearestNeighbors(n_neighbors=k0 + 1, algorithm="auto")
-     nn.fit(X)
-     _, indices = nn.kneighbors(X)
-
-     chi2_stats = []
-     p_values = []
-     batch_to_idx = {b: i for i, b in enumerate(unique_batches)}
-
-     for i in range(n_samples):
-         neighbors = indices[i, 1 : k0 + 1]
-         neighbor_batches = batch_labels[neighbors]
-
-         observed = np.zeros(n_batches)
-         for nb in neighbor_batches:
-             observed[batch_to_idx[nb]] += 1
-
-         expected = global_freq * k0
-
-         mask = expected > 0
-         if mask.sum() < 2:
-             continue
-
-         stat = np.sum((observed[mask] - expected[mask]) ** 2 / expected[mask])
-         df = max(1, mask.sum() - 1)
-         p_val = 1 - chi2.cdf(stat, df)
-
-         chi2_stats.append(stat)
-         p_values.append(p_val)
-
-     if len(p_values) == 0:
-         return 1.0, 0.0
-
-     acceptance_rate = np.mean(np.array(p_values) > alpha)
-     mean_stat = np.mean(chi2_stats)
-
-     return acceptance_rate, mean_stat
-
-
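# [Editor's sketch, not part of the package diff.] The kBET loop above
# reduces, per sample, to a chi-squared goodness-of-fit test of the
# neighborhood's batch counts against the global batch frequencies:
import numpy as np
from scipy.stats import chi2

global_freq = np.array([0.5, 0.3, 0.2])  # global batch proportions
k0 = 20                                  # neighborhood size
observed = np.array([10, 6, 4])          # batch counts among the k0 neighbors

expected = global_freq * k0
stat = np.sum((observed - expected) ** 2 / expected)
p_val = 1 - chi2.cdf(stat, df=len(observed) - 1)
print(stat, p_val)  # stat=0.0, p=1.0: this neighborhood is perfectly mixed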
- def _find_sigma(distances: np.ndarray, target_perplexity: float, tol: float = 1e-5) -> float:
-     """
-     Binary search for sigma to achieve target perplexity.
-
-     Used in LISI computation.
-
-     Parameters
-     ----------
-     distances : np.ndarray
-         Distances to neighbors.
-     target_perplexity : float
-         Target perplexity value.
-     tol : float
-         Tolerance for convergence.
-
-     Returns
-     -------
-     float
-         Sigma value.
-     """
-     target_H = np.log2(target_perplexity + 1e-10)
-
-     sigma_min, sigma_max = 1e-10, 1e10
-     sigma = 1.0
-
-     for _ in range(50):
-         P = np.exp(-(distances**2) / (2 * sigma**2 + 1e-10))
-         P_sum = P.sum()
-         if P_sum < 1e-10:
-             sigma = (sigma + sigma_max) / 2
-             continue
-         P = P / P_sum
-         P = np.clip(P, 1e-10, 1.0)
-         H = -np.sum(P * np.log2(P))
-
-         if abs(H - target_H) < tol:
-             break
-         elif target_H > H:
-             sigma_min = sigma
-         else:
-             sigma_max = sigma
-         sigma = (sigma_min + sigma_max) / 2
-
-     return sigma
-
-
- def _lisi_score(
-     X: np.ndarray,
-     batch_labels: np.ndarray,
-     perplexity: int = 30,
- ) -> float:
-     """
-     Compute mean Local Inverse Simpson's Index (LISI).
-
-     Range: [1, n_batches], where n_batches = perfect mixing.
-     Higher = better batch mixing.
-
-     Reference: Korsunsky et al. (2019) Nature Methods (Harmony paper)
-
-     Parameters
-     ----------
-     X : np.ndarray
-         Data matrix.
-     batch_labels : np.ndarray
-         Batch labels for each sample.
-     perplexity : int
-         Perplexity for Gaussian kernel.
-
-     Returns
-     -------
-     float
-         Mean LISI score.
-     """
-     n_samples = X.shape[0]
-     unique_batches = np.unique(batch_labels)
-     n_batches = len(unique_batches)
-     batch_to_idx = {b: i for i, b in enumerate(unique_batches)}
-
-     if n_batches < 2:
-         return 1.0
-
-     k = min(3 * perplexity, n_samples - 1)
-
-     nn = NearestNeighbors(n_neighbors=k + 1, algorithm="auto")
-     nn.fit(X)
-     distances, indices = nn.kneighbors(X)
-
-     distances = distances[:, 1:]
-     indices = indices[:, 1:]
-
-     lisi_values = []
-
-     for i in range(n_samples):
-         sigma = _find_sigma(distances[i], perplexity)
-
-         P = np.exp(-(distances[i] ** 2) / (2 * sigma**2 + 1e-10))
-         P_sum = P.sum()
-         if P_sum < 1e-10:
-             lisi_values.append(1.0)
-             continue
-         P = P / P_sum
-
-         neighbor_batches = batch_labels[indices[i]]
-         batch_probs = np.zeros(n_batches)
-         for j, nb in enumerate(neighbor_batches):
-             batch_probs[batch_to_idx[nb]] += P[j]
-
-         simpson = np.sum(batch_probs**2)
-         lisi = n_batches if simpson < 1e-10 else 1.0 / simpson
-         lisi_values.append(lisi)
-
-     return np.mean(lisi_values)
-
-
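# [Editor's sketch, not part of the package diff.] At its core, LISI is an
# inverse Simpson index over the batch composition of a Gaussian-weighted
# neighborhood (weights P as computed above):
import numpy as np

P = np.array([0.4, 0.3, 0.2, 0.1])         # normalized neighbor weights
neighbor_batches = np.array([0, 1, 0, 1])  # each neighbor's batch index

batch_probs = np.zeros(2)
for w, b in zip(P, neighbor_batches):
    batch_probs[b] += w                    # per-batch weight: [0.6, 0.4]

lisi = 1.0 / np.sum(batch_probs**2)
print(lisi)  # ~1.92, near the maximum of 2.0 for two batches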
- def _variance_ratio(X: np.ndarray, batch_labels: np.ndarray) -> float:
-     """
-     Compute between-batch to within-batch variance ratio.
-
-     Similar to F-statistic in one-way ANOVA.
-     Lower ratio after correction = better batch effect removal.
-
-     Parameters
-     ----------
-     X : np.ndarray
-         Data matrix.
-     batch_labels : np.ndarray
-         Batch labels for each sample.
-
-     Returns
-     -------
-     float
-         Variance ratio (between/within).
-     """
-     unique_batches = np.unique(batch_labels)
-     n_batches = len(unique_batches)
-     n_samples = X.shape[0]
-
-     if n_batches < 2:
-         return 0.0
-
-     grand_mean = np.mean(X, axis=0)
-
-     between_var = 0.0
-     within_var = 0.0
-
-     for batch in unique_batches:
-         mask = batch_labels == batch
-         n_b = np.sum(mask)
-         X_batch = X[mask]
-         batch_mean = np.mean(X_batch, axis=0)
-
-         between_var += n_b * np.sum((batch_mean - grand_mean) ** 2)
-         within_var += np.sum((X_batch - batch_mean) ** 2)
-
-     between_var /= n_batches - 1
-     within_var /= n_samples - n_batches
-
-     if within_var < 1e-10:
-         return 0.0
-
-     return between_var / within_var
-
-
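# [Editor's sketch, not part of the package diff.] For a single feature the
# ratio above is exactly the one-way ANOVA F statistic, which scipy can
# cross-check:
import numpy as np
from scipy.stats import f_oneway

rng = np.random.default_rng(0)
g1 = rng.normal(0.0, 1.0, size=30)
g2 = rng.normal(1.0, 1.0, size=30)  # mean-shifted second batch

f_stat, _ = f_oneway(g1, g2)
print(f_stat)  # equals the between/within mean-square ratio for this feature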
- def _knn_preservation(
-     X_before: np.ndarray,
-     X_after: np.ndarray,
-     k_values: list[int],
-     n_jobs: int = 1,
- ) -> dict[int, float]:
-     """
-     Compute fraction of k-nearest neighbors preserved after correction.
-
-     Range: [0, 1], where 1 = perfect preservation.
-     Higher = better biological structure preservation.
-
-     Parameters
-     ----------
-     X_before : np.ndarray
-         Original data.
-     X_after : np.ndarray
-         Corrected data.
-     k_values : list of int
-         Values of k for k-NN.
-     n_jobs : int
-         Number of parallel jobs.
-
-     Returns
-     -------
-     dict
-         Mapping from k to preservation fraction.
-     """
-     results = {}
-     max_k = max(k_values)
-     max_k = min(max_k, X_before.shape[0] - 1)
-
-     nn_before = NearestNeighbors(n_neighbors=max_k + 1, algorithm="auto", n_jobs=n_jobs)
-     nn_before.fit(X_before)
-     _, indices_before = nn_before.kneighbors(X_before)
-
-     nn_after = NearestNeighbors(n_neighbors=max_k + 1, algorithm="auto", n_jobs=n_jobs)
-     nn_after.fit(X_after)
-     _, indices_after = nn_after.kneighbors(X_after)
-
-     for k in k_values:
-         if k > max_k:
-             results[k] = 0.0
-             continue
-
-         overlaps = []
-         for i in range(X_before.shape[0]):
-             neighbors_before = set(indices_before[i, 1 : k + 1])
-             neighbors_after = set(indices_after[i, 1 : k + 1])
-             overlap = len(neighbors_before & neighbors_after) / k
-             overlaps.append(overlap)
-
-         results[k] = np.mean(overlaps)
-
-     return results
-
-
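# [Editor's sketch, not part of the package diff; assumes the module above
# (combatlearn 1.1.1) is importable so _knn_preservation is in scope.] A
# distance-preserving map keeps every Euclidean neighborhood intact, so the
# preserved fraction should be 1.0 for any k (up to floating-point ties):
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(50, 3))
theta = np.pi / 4
R = np.array([[np.cos(theta), -np.sin(theta), 0.0],
              [np.sin(theta), np.cos(theta), 0.0],
              [0.0, 0.0, 1.0]])  # a pure rotation

print(_knn_preservation(X, X @ R.T, k_values=[5, 10]))  # ~{5: 1.0, 10: 1.0}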
- def _pairwise_distance_correlation(
-     X_before: np.ndarray,
-     X_after: np.ndarray,
-     subsample: int = 1000,
-     random_state: int = 42,
- ) -> float:
-     """
-     Compute Spearman correlation of pairwise distances.
-
-     Range: [-1, 1], where 1 = perfect rank preservation.
-     Higher = better relative relationship preservation.
-
-     Parameters
-     ----------
-     X_before : np.ndarray
-         Original data.
-     X_after : np.ndarray
-         Corrected data.
-     subsample : int
-         Maximum samples to use (for efficiency).
-     random_state : int
-         Random seed for subsampling.
-
-     Returns
-     -------
-     float
-         Spearman correlation coefficient.
-     """
-     n_samples = X_before.shape[0]
-
-     if n_samples > subsample:
-         rng = np.random.default_rng(random_state)
-         idx = rng.choice(n_samples, subsample, replace=False)
-         X_before = X_before[idx]
-         X_after = X_after[idx]
-
-     dist_before = pdist(X_before, metric="euclidean")
-     dist_after = pdist(X_after, metric="euclidean")
-
-     if len(dist_before) == 0:
-         return 1.0
-
-     corr, _ = spearmanr(dist_before, dist_after)
-
-     if np.isnan(corr):
-         return 1.0
-
-     return corr
-
-
- def _mean_centroid_distance(X: np.ndarray, batch_labels: np.ndarray) -> float:
-     """
-     Compute mean pairwise Euclidean distance between batch centroids.
-
-     Lower after correction = better batch alignment.
-
-     Parameters
-     ----------
-     X : np.ndarray
-         Data matrix.
-     batch_labels : np.ndarray
-         Batch labels for each sample.
-
-     Returns
-     -------
-     float
-         Mean pairwise distance between centroids.
-     """
-     unique_batches = np.unique(batch_labels)
-     n_batches = len(unique_batches)
-
-     if n_batches < 2:
-         return 0.0
-
-     centroids = []
-     for batch in unique_batches:
-         mask = batch_labels == batch
-         centroid = np.mean(X[mask], axis=0)
-         centroids.append(centroid)
-
-     centroids = np.array(centroids)
-     distances = pdist(centroids, metric="euclidean")
-
-     return np.mean(distances)
-
-
- def _levene_median_statistic(X: np.ndarray, batch_labels: np.ndarray) -> float:
-     """
-     Compute median Levene test statistic across features.
-
-     Lower statistic = more homogeneous variances across batches.
-
-     Parameters
-     ----------
-     X : np.ndarray
-         Data matrix.
-     batch_labels : np.ndarray
-         Batch labels for each sample.
-
-     Returns
-     -------
-     float
-         Median Levene test statistic.
-     """
-     unique_batches = np.unique(batch_labels)
-     if len(unique_batches) < 2:
-         return 0.0
-
-     levene_stats = []
-     for j in range(X.shape[1]):
-         groups = [X[batch_labels == b, j] for b in unique_batches]
-         groups = [g for g in groups if len(g) > 0]
-         if len(groups) < 2:
-             continue
-         try:
-             stat, _ = levene(*groups, center="median")
-             if not np.isnan(stat):
-                 levene_stats.append(stat)
-         except Exception:
-             continue
-
-     if len(levene_stats) == 0:
-         return 0.0
-
-     return np.median(levene_stats)
-
-
- class ComBatModel:
-     """ComBat algorithm.
-
-     Parameters
-     ----------
-     method : {'johnson', 'fortin', 'chen'}, default='johnson'
-         * 'johnson' - classic ComBat.
-         * 'fortin' - covariate-aware ComBat.
-         * 'chen' - CovBat, PCA-based ComBat.
-     parametric : bool, default=True
-         Use the parametric empirical Bayes variant.
-     mean_only : bool, default=False
-         If True, only the mean is adjusted (`gamma_star`),
-         ignoring the variance (`delta_star`).
-     reference_batch : str, optional
-         If specified, the batch level to use as reference.
-     covbat_cov_thresh : float or int, default=0.9
-         CovBat: cumulative variance threshold (0, 1] to retain PCs, or
-         integer >= 1 specifying the number of components directly.
-     eps : float, default=1e-8
-         Numerical jitter to avoid division-by-zero.
-     """
-
-     def __init__(
-         self,
-         *,
-         method: Literal["johnson", "fortin", "chen"] = "johnson",
-         parametric: bool = True,
-         mean_only: bool = False,
-         reference_batch: str | None = None,
-         eps: float = 1e-8,
-         covbat_cov_thresh: float | int = 0.9,
-     ) -> None:
-         self.method: str = method
-         self.parametric: bool = parametric
-         self.mean_only: bool = bool(mean_only)
-         self.reference_batch: str | None = reference_batch
-         self.eps: float = float(eps)
-         self.covbat_cov_thresh: float | int = covbat_cov_thresh
-
-         self._batch_levels: pd.Index
-         self._grand_mean: pd.Series
-         self._pooled_var: pd.Series
-         self._gamma_star: FloatArray
-         self._delta_star: FloatArray
-         self._n_per_batch: dict[str, int]
-         self._reference_batch_idx: int | None
-         self._beta_hat_nonbatch: FloatArray
-         self._n_batch: int
-         self._p_design: int
-         self._covbat_pca: PCA
-         self._covbat_n_pc: int
-         self._batch_levels_pc: pd.Index
-         self._pc_gamma_star: FloatArray
-         self._pc_delta_star: FloatArray
-
-         # Validate covbat_cov_thresh
-         if isinstance(self.covbat_cov_thresh, float):
-             if not (0.0 < self.covbat_cov_thresh <= 1.0):
-                 raise ValueError("covbat_cov_thresh must be in (0, 1] when float.")
-         elif isinstance(self.covbat_cov_thresh, int):
-             if self.covbat_cov_thresh < 1:
-                 raise ValueError("covbat_cov_thresh must be >= 1 when int.")
-         else:
-             raise TypeError("covbat_cov_thresh must be float or int.")
-
-     @staticmethod
-     def _as_series(arr: ArrayLike, index: pd.Index, name: str) -> pd.Series:
-         """Convert array-like to categorical Series with validation."""
-         ser = arr.copy() if isinstance(arr, pd.Series) else pd.Series(arr, index=index, name=name)
-         if not ser.index.equals(index):
-             raise ValueError(f"`{name}` index mismatch with `X`.")
-         return ser.astype("category")
-
-     @staticmethod
-     def _to_df(arr: ArrayLike | None, index: pd.Index, name: str) -> pd.DataFrame | None:
-         """Convert array-like to DataFrame."""
-         if arr is None:
-             return None
-         if isinstance(arr, pd.Series):
-             arr = arr.to_frame()
-         if not isinstance(arr, pd.DataFrame):
-             arr = pd.DataFrame(arr, index=index)
-         if not arr.index.equals(index):
-             raise ValueError(f"`{name}` index mismatch with `X`.")
-         return arr
-
-     def fit(
-         self,
-         X: ArrayLike,
-         y: ArrayLike | None = None,
-         *,
-         batch: ArrayLike,
-         discrete_covariates: ArrayLike | None = None,
-         continuous_covariates: ArrayLike | None = None,
-     ) -> ComBatModel:
-         """Fit the ComBat model."""
-         method = self.method.lower()
-         if method not in {"johnson", "fortin", "chen"}:
-             raise ValueError("method must be 'johnson', 'fortin', or 'chen'.")
-         if not isinstance(X, pd.DataFrame):
-             X = pd.DataFrame(X)
-         idx = X.index
-         batch = self._as_series(batch, idx, "batch")
-
-         disc = self._to_df(discrete_covariates, idx, "discrete_covariates")
-         cont = self._to_df(continuous_covariates, idx, "continuous_covariates")
-
-         if self.reference_batch is not None and self.reference_batch not in batch.cat.categories:
-             raise ValueError(
-                 f"reference_batch={self.reference_batch!r} not present in the data batches: "
-                 f"{list(batch.cat.categories)}"
-             )
-
-         if method == "johnson":
-             if disc is not None or cont is not None:
-                 warnings.warn("Covariates are ignored when using method='johnson'.", stacklevel=2)
-             self._fit_johnson(X, batch)
-         elif method == "fortin":
-             self._fit_fortin(X, batch, disc, cont)
-         elif method == "chen":
-             self._fit_chen(X, batch, disc, cont)
-         return self
-
-     def _fit_johnson(self, X: pd.DataFrame, batch: pd.Series) -> None:
-         """Johnson et al. (2007) ComBat."""
-         self._batch_levels = batch.cat.categories
-         pooled_var = X.var(axis=0, ddof=1) + self.eps
-         grand_mean = X.mean(axis=0)
-
-         Xs = (X - grand_mean) / np.sqrt(pooled_var)
-
-         n_per_batch: dict[str, int] = {}
-         gamma_hat: list[npt.NDArray[np.float64]] = []
-         delta_hat: list[npt.NDArray[np.float64]] = []
-
-         for lvl in self._batch_levels:
-             idx = batch == lvl
-             n_b = int(idx.sum())
-             if n_b < 2:
-                 raise ValueError(f"Batch '{lvl}' has <2 samples.")
-             n_per_batch[str(lvl)] = n_b
-             xb = Xs.loc[idx]
-             gamma_hat.append(xb.mean(axis=0).values)
-             delta_hat.append(xb.var(axis=0, ddof=1).values + self.eps)
-
-         gamma_hat_arr = np.vstack(gamma_hat)
-         delta_hat_arr = np.vstack(delta_hat)
-
-         if self.mean_only:
-             gamma_star = self._shrink_gamma(
-                 gamma_hat_arr, delta_hat_arr, n_per_batch, parametric=self.parametric
-             )
-             delta_star = np.ones_like(delta_hat_arr)
-         else:
-             gamma_star, delta_star = self._shrink_gamma_delta(
-                 gamma_hat_arr, delta_hat_arr, n_per_batch, parametric=self.parametric
-             )
-
-         if self.reference_batch is not None:
-             ref_idx = list(self._batch_levels).index(self.reference_batch)
-             gamma_ref = gamma_star[ref_idx]
-             delta_ref = delta_star[ref_idx]
-             gamma_star = gamma_star - gamma_ref
-             if not self.mean_only:
-                 delta_star = delta_star / delta_ref
-             self._reference_batch_idx = ref_idx
-         else:
-             self._reference_batch_idx = None
-
-         self._grand_mean = grand_mean
-         self._pooled_var = pooled_var
-         self._gamma_star = gamma_star
-         self._delta_star = delta_star
-         self._n_per_batch = n_per_batch
-
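# [Editor's sketch, not part of the package diff.] Per feature,
# _fit_johnson / _transform_johnson implement
#     Z = (X - grand_mean) / pooled_sd
#     X_adj = (Z - gamma*_i) / sqrt(delta*_i) * pooled_sd + grand_mean
# for each batch i. A no-shrinkage version (gamma* = batch mean of Z,
# delta* = batch variance of Z) already equalizes the batch means:
import numpy as np

rng = np.random.default_rng(0)
x = np.r_[rng.normal(0, 1, 50), rng.normal(2, 3, 50)]  # one feature, 2 batches
b = np.r_[np.zeros(50, dtype=int), np.ones(50, dtype=int)]

z = (x - x.mean()) / x.std(ddof=1)
x_adj = x.copy()
for i in (0, 1):
    zi = z[b == i]
    x_adj[b == i] = (zi - zi.mean()) / zi.std(ddof=1) * x.std(ddof=1) + x.mean()

print(x_adj[b == 0].mean(), x_adj[b == 1].mean())  # both equal the grand mean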
-     def _fit_fortin(
-         self,
-         X: pd.DataFrame,
-         batch: pd.Series,
-         disc: pd.DataFrame | None,
-         cont: pd.DataFrame | None,
-     ) -> None:
-         """Fortin et al. (2018) neuroComBat."""
-         self._batch_levels = batch.cat.categories
-         n_batch = len(self._batch_levels)
-         n_samples = len(X)
-
-         batch_dummies = pd.get_dummies(batch, drop_first=False).astype(float)
-         if self.reference_batch is not None:
-             if self.reference_batch not in self._batch_levels:
-                 raise ValueError(
-                     f"reference_batch={self.reference_batch!r} not present in batches: "
-                     f"{list(self._batch_levels)}"
-                 )
-             batch_dummies.loc[:, self.reference_batch] = 1.0
-
-         parts: list[pd.DataFrame] = [batch_dummies]
-         if disc is not None:
-             parts.append(pd.get_dummies(disc.astype("category"), drop_first=True).astype(float))
-
-         if cont is not None:
-             parts.append(cont.astype(float))
-
-         design = pd.concat(parts, axis=1).values
-         p_design = design.shape[1]
-
-         X_np = X.values
-         beta_hat = la.lstsq(design, X_np, rcond=None)[0]
-
-         beta_hat_batch = beta_hat[:n_batch]
-         self._beta_hat_nonbatch = beta_hat[n_batch:]
-
-         n_per_batch = batch.value_counts().sort_index().astype(int).values
-         self._n_per_batch = dict(zip(self._batch_levels, n_per_batch, strict=True))
-
-         if self.reference_batch is not None:
-             ref_idx = list(self._batch_levels).index(self.reference_batch)
-             grand_mean = beta_hat_batch[ref_idx]
-         else:
-             grand_mean = (n_per_batch / n_samples) @ beta_hat_batch
-             ref_idx = None
-
-         self._grand_mean = pd.Series(grand_mean, index=X.columns)
-
-         if self.reference_batch is not None:
-             ref_mask = (batch == self.reference_batch).values
-             resid = X_np[ref_mask] - design[ref_mask] @ beta_hat
-             denom = int(ref_mask.sum())
-         else:
-             resid = X_np - design @ beta_hat
-             denom = n_samples
-         var_pooled = (resid**2).sum(axis=0) / denom + self.eps
-         self._pooled_var = pd.Series(var_pooled, index=X.columns)
-
-         stand_mean = grand_mean + design[:, n_batch:] @ self._beta_hat_nonbatch
-         Xs = (X_np - stand_mean) / np.sqrt(var_pooled)
-
-         gamma_hat = np.vstack([Xs[batch == lvl].mean(axis=0) for lvl in self._batch_levels])
-         delta_hat = np.vstack(
-             [Xs[batch == lvl].var(axis=0, ddof=1) + self.eps for lvl in self._batch_levels]
-         )
-
-         if self.mean_only:
-             gamma_star = self._shrink_gamma(
-                 gamma_hat, delta_hat, n_per_batch, parametric=self.parametric
-             )
-             delta_star = np.ones_like(delta_hat)
-         else:
-             gamma_star, delta_star = self._shrink_gamma_delta(
-                 gamma_hat, delta_hat, n_per_batch, parametric=self.parametric
-             )
-
-         if ref_idx is not None:
-             gamma_star[ref_idx] = 0.0
-             if not self.mean_only:
-                 delta_star[ref_idx] = 1.0
-         self._reference_batch_idx = ref_idx
-
-         self._gamma_star = gamma_star
-         self._delta_star = delta_star
-         self._n_batch = n_batch
-         self._p_design = p_design
-
-     def _fit_chen(
-         self,
-         X: pd.DataFrame,
-         batch: pd.Series,
-         disc: pd.DataFrame | None,
-         cont: pd.DataFrame | None,
-     ) -> None:
-         """Chen et al. (2022) CovBat."""
-         self._fit_fortin(X, batch, disc, cont)
-         X_meanvar_adj = self._transform_fortin(X, batch, disc, cont)
-         X_centered = X_meanvar_adj - X_meanvar_adj.mean(axis=0)
-         pca = PCA(svd_solver="full", whiten=False).fit(X_centered)
-
-         # Determine number of components based on threshold type
-         if isinstance(self.covbat_cov_thresh, int):
-             n_pc = min(self.covbat_cov_thresh, len(pca.explained_variance_ratio_))
-         else:
-             cumulative = np.cumsum(pca.explained_variance_ratio_)
-             n_pc = int(np.searchsorted(cumulative, self.covbat_cov_thresh) + 1)
-
-         self._covbat_pca = pca
-         self._covbat_n_pc = n_pc
-
-         scores = pca.transform(X_centered)[:, :n_pc]
-         scores_df = pd.DataFrame(scores, index=X.index, columns=[f"PC{i + 1}" for i in range(n_pc)])
-         self._batch_levels_pc = self._batch_levels
-         n_per_batch = self._n_per_batch
-
-         gamma_hat: list[npt.NDArray[np.float64]] = []
-         delta_hat: list[npt.NDArray[np.float64]] = []
-         for lvl in self._batch_levels_pc:
-             idx = batch == lvl
-             xb = scores_df.loc[idx]
-             gamma_hat.append(xb.mean(axis=0).values)
-             delta_hat.append(xb.var(axis=0, ddof=1).values + self.eps)
-         gamma_hat_arr = np.vstack(gamma_hat)
-         delta_hat_arr = np.vstack(delta_hat)
-
-         if self.mean_only:
-             gamma_star = self._shrink_gamma(
-                 gamma_hat_arr, delta_hat_arr, n_per_batch, parametric=self.parametric
-             )
-             delta_star = np.ones_like(delta_hat_arr)
-         else:
-             gamma_star, delta_star = self._shrink_gamma_delta(
-                 gamma_hat_arr, delta_hat_arr, n_per_batch, parametric=self.parametric
-             )
-
-         if self.reference_batch is not None:
-             ref_idx = list(self._batch_levels_pc).index(self.reference_batch)
-             gamma_ref = gamma_star[ref_idx]
-             delta_ref = delta_star[ref_idx]
-             gamma_star = gamma_star - gamma_ref
-             if not self.mean_only:
-                 delta_star = delta_star / delta_ref
-
-         self._pc_gamma_star = gamma_star
-         self._pc_delta_star = delta_star
-
-     def _shrink_gamma_delta(
-         self,
-         gamma_hat: FloatArray,
-         delta_hat: FloatArray,
-         n_per_batch: dict[str, int] | FloatArray,
-         *,
-         parametric: bool,
-         max_iter: int = 100,
-         tol: float = 1e-4,
-     ) -> tuple[FloatArray, FloatArray]:
-         """Empirical Bayes shrinkage estimation."""
-         if parametric:
-             gamma_bar = gamma_hat.mean(axis=0)
-             t2 = gamma_hat.var(axis=0, ddof=1)
-             a_prior = (delta_hat.mean(axis=0) ** 2) / delta_hat.var(axis=0, ddof=1) + 2
-             b_prior = delta_hat.mean(axis=0) * (a_prior - 1)
-
-             B, _p = gamma_hat.shape
-             gamma_star = np.empty_like(gamma_hat)
-             delta_star = np.empty_like(delta_hat)
-             n_vec = (
-                 np.array(list(n_per_batch.values()))
-                 if isinstance(n_per_batch, dict)
-                 else n_per_batch
-             )
-
-             for i in range(B):
-                 n_i = n_vec[i]
-                 g, d = gamma_hat[i], delta_hat[i]
-                 gamma_post_var = 1.0 / (n_i / d + 1.0 / t2)
-                 gamma_star[i] = gamma_post_var * (n_i * g / d + gamma_bar / t2)
-
-                 a_post = a_prior + n_i / 2.0
-                 b_post = b_prior + 0.5 * n_i * d
-                 delta_star[i] = b_post / (a_post - 1)
-             return gamma_star, delta_star
-
-         else:
-             B, _p = gamma_hat.shape
-             n_vec = (
-                 np.array(list(n_per_batch.values()))
-                 if isinstance(n_per_batch, dict)
-                 else n_per_batch
-             )
-             gamma_bar = gamma_hat.mean(axis=0)
-             t2 = gamma_hat.var(axis=0, ddof=1)
-
-             def postmean(
-                 g_hat: FloatArray,
-                 g_bar: FloatArray,
-                 n: float,
-                 d_star: FloatArray,
-                 t2_: FloatArray,
-             ) -> FloatArray:
-                 return (t2_ * n * g_hat + d_star * g_bar) / (t2_ * n + d_star)
-
-             def postvar(sum2: FloatArray, n: float, a: FloatArray, b: FloatArray) -> FloatArray:
-                 return (0.5 * sum2 + b) / (n / 2.0 + a - 1.0)
-
-             def aprior(delta: FloatArray) -> float:
-                 m, s2 = delta.mean(), delta.var()
-                 s2 = max(s2, self.eps)
-                 return (2 * s2 + m**2) / s2
-
-             def bprior(delta: FloatArray) -> float:
-                 m, s2 = delta.mean(), delta.var()
-                 s2 = max(s2, self.eps)
-                 return (m * s2 + m**3) / s2
-
-             gamma_star = np.empty_like(gamma_hat)
-             delta_star = np.empty_like(delta_hat)
-
-             for i in range(B):
-                 n_i = n_vec[i]
-                 g_hat_i = gamma_hat[i]
-                 d_hat_i = delta_hat[i]
-                 a_i = aprior(d_hat_i)
-                 b_i = bprior(d_hat_i)
-
-                 g_new, d_new = g_hat_i.copy(), d_hat_i.copy()
-                 for _ in range(max_iter):
-                     g_prev, d_prev = g_new, d_new
-                     g_new = postmean(g_hat_i, gamma_bar, n_i, d_prev, t2)
-                     sum2 = (n_i - 1) * d_hat_i + n_i * (g_hat_i - g_new) ** 2
-                     d_new = postvar(sum2, n_i, a_i, b_i)
-                     if np.max(np.abs(g_new - g_prev) / (np.abs(g_prev) + self.eps)) < tol and (
-                         self.mean_only
-                         or np.max(np.abs(d_new - d_prev) / (np.abs(d_prev) + self.eps)) < tol
-                     ):
-                         break
-                 gamma_star[i] = g_new
-                 delta_star[i] = 1.0 if self.mean_only else d_new
-             return gamma_star, delta_star
-
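# [Editor's sketch, not part of the package diff.] In the parametric branch
# above, each batch's location is shrunk toward the across-batch mean with a
# normal prior, and its scale toward an inverse-gamma posterior mean:
#     gamma*_i = (n_i * gamma_hat_i / delta_i + gamma_bar / t2)
#                / (n_i / delta_i + 1 / t2)
#     delta*_i = (b_prior + n_i * delta_hat_i / 2) / (a_prior + n_i / 2 - 1)
# Numerically, for a small batch with a noisy location estimate:
n_i, gamma_hat_i, delta_i = 5, 1.2, 0.8  # batch size, location, scale
gamma_bar, t2 = 0.0, 0.5                 # prior mean and variance across batches

post_var = 1.0 / (n_i / delta_i + 1.0 / t2)
gamma_star = post_var * (n_i * gamma_hat_i / delta_i + gamma_bar / t2)
print(gamma_star)  # ~0.91: pulled from 1.2 toward the prior mean 0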
-     def _shrink_gamma(
-         self,
-         gamma_hat: FloatArray,
-         delta_hat: FloatArray,
-         n_per_batch: dict[str, int] | FloatArray,
-         *,
-         parametric: bool,
-     ) -> FloatArray:
-         """Convenience wrapper that returns only gamma* (for *mean-only* mode)."""
-         gamma, _ = self._shrink_gamma_delta(
-             gamma_hat, delta_hat, n_per_batch, parametric=parametric
-         )
-         return gamma
-
-     def transform(
-         self,
-         X: ArrayLike,
-         *,
-         batch: ArrayLike,
-         discrete_covariates: ArrayLike | None = None,
-         continuous_covariates: ArrayLike | None = None,
-     ) -> pd.DataFrame:
-         """Transform the data using fitted ComBat parameters."""
-         if not hasattr(self, "_gamma_star"):
-             raise ValueError(
-                 "This ComBatModel instance is not fitted yet. Call 'fit' before 'transform'."
-             )
-         if not isinstance(X, pd.DataFrame):
-             X = pd.DataFrame(X)
-         idx = X.index
-         batch = self._as_series(batch, idx, "batch")
-         unseen = set(batch.cat.categories) - set(self._batch_levels)
-         if unseen:
-             raise ValueError(f"Unseen batch levels during transform: {unseen}.")
-         disc = self._to_df(discrete_covariates, idx, "discrete_covariates")
-         cont = self._to_df(continuous_covariates, idx, "continuous_covariates")
-
-         method = self.method.lower()
-         if method == "johnson":
-             return self._transform_johnson(X, batch)
-         elif method == "fortin":
-             return self._transform_fortin(X, batch, disc, cont)
-         elif method == "chen":
-             return self._transform_chen(X, batch, disc, cont)
-         else:
-             raise ValueError(f"Unknown method: {method}.")
-
-     def _transform_johnson(self, X: pd.DataFrame, batch: pd.Series) -> pd.DataFrame:
-         """Johnson transform implementation."""
-         pooled = self._pooled_var
-         grand = self._grand_mean
-
-         Xs = (X - grand) / np.sqrt(pooled)
-         X_adj = pd.DataFrame(index=X.index, columns=X.columns, dtype=float)
-
-         for i, lvl in enumerate(self._batch_levels):
-             idx = batch == lvl
-             if not idx.any():
-                 continue
-             if self.reference_batch is not None and lvl == self.reference_batch:
-                 X_adj.loc[idx] = X.loc[idx].values
-                 continue
-
-             g = self._gamma_star[i]
-             d = self._delta_star[i]
-             Xb = Xs.loc[idx] - g if self.mean_only else (Xs.loc[idx] - g) / np.sqrt(d)
-             X_adj.loc[idx] = (Xb * np.sqrt(pooled) + grand).values
-         return X_adj
-
-     def _transform_fortin(
-         self,
-         X: pd.DataFrame,
-         batch: pd.Series,
-         disc: pd.DataFrame | None,
-         cont: pd.DataFrame | None,
-     ) -> pd.DataFrame:
-         """Fortin transform implementation."""
-         batch_dummies = pd.get_dummies(batch, drop_first=False).astype(float)[self._batch_levels]
-         if self.reference_batch is not None:
-             batch_dummies.loc[:, self.reference_batch] = 1.0
-
-         parts = [batch_dummies]
-         if disc is not None:
-             parts.append(pd.get_dummies(disc.astype("category"), drop_first=True).astype(float))
-         if cont is not None:
-             parts.append(cont.astype(float))
-
-         design = pd.concat(parts, axis=1).values
-
-         X_np = X.values
-         stand_mu = self._grand_mean.values + design[:, self._n_batch :] @ self._beta_hat_nonbatch
-         Xs = (X_np - stand_mu) / np.sqrt(self._pooled_var.values)
-
-         for i, lvl in enumerate(self._batch_levels):
-             idx = batch == lvl
-             if not idx.any():
-                 continue
-             if self.reference_batch is not None and lvl == self.reference_batch:
-                 # leave reference samples unchanged
-                 continue
-
-             g = self._gamma_star[i]
-             d = self._delta_star[i]
-             if self.mean_only:
-                 Xs[idx] = Xs[idx] - g
-             else:
-                 Xs[idx] = (Xs[idx] - g) / np.sqrt(d)
-
-         X_adj = Xs * np.sqrt(self._pooled_var.values) + stand_mu
-         return pd.DataFrame(X_adj, index=X.index, columns=X.columns, dtype=float)
-
-     def _transform_chen(
-         self,
-         X: pd.DataFrame,
-         batch: pd.Series,
-         disc: pd.DataFrame | None,
-         cont: pd.DataFrame | None,
-     ) -> pd.DataFrame:
-         """Chen transform implementation."""
-         X_meanvar_adj = self._transform_fortin(X, batch, disc, cont)
-         X_centered = X_meanvar_adj - self._covbat_pca.mean_
-         scores = self._covbat_pca.transform(X_centered)
-         n_pc = self._covbat_n_pc
-         scores_adj = scores.copy()
-
-         for i, lvl in enumerate(self._batch_levels_pc):
-             idx = batch == lvl
-             if not idx.any():
-                 continue
-             if self.reference_batch is not None and lvl == self.reference_batch:
-                 continue
-             g = self._pc_gamma_star[i]
-             d = self._pc_delta_star[i]
-             if self.mean_only:
-                 scores_adj[idx, :n_pc] = scores_adj[idx, :n_pc] - g
-             else:
-                 scores_adj[idx, :n_pc] = (scores_adj[idx, :n_pc] - g) / np.sqrt(d)
-
-         X_recon = self._covbat_pca.inverse_transform(scores_adj) + self._covbat_pca.mean_
-         return pd.DataFrame(X_recon, index=X.index, columns=X.columns)
-
-
- class ComBat(BaseEstimator, TransformerMixin):
-     """Pipeline-friendly wrapper around `ComBatModel`.
-
-     Stores the batch labels (and optional covariates) passed at construction
-     and subsets them by `X`'s index inside `fit` and `transform`.
-     """
-
-     def __init__(
-         self,
-         batch: ArrayLike,
-         *,
-         discrete_covariates: ArrayLike | None = None,
-         continuous_covariates: ArrayLike | None = None,
-         method: str = "johnson",
-         parametric: bool = True,
-         mean_only: bool = False,
-         reference_batch: str | None = None,
-         eps: float = 1e-8,
-         covbat_cov_thresh: float | int = 0.9,
-         compute_metrics: bool = False,
-     ) -> None:
-         self.batch = batch
-         self.discrete_covariates = discrete_covariates
-         self.continuous_covariates = continuous_covariates
-         self.method = method
-         self.parametric = parametric
-         self.mean_only = mean_only
-         self.reference_batch = reference_batch
-         self.eps = eps
-         self.covbat_cov_thresh = covbat_cov_thresh
-         self.compute_metrics = compute_metrics
-         self._model = ComBatModel(
-             method=method,
-             parametric=parametric,
-             mean_only=mean_only,
-             reference_batch=reference_batch,
-             eps=eps,
-             covbat_cov_thresh=covbat_cov_thresh,
-         )
-
-     def fit(self, X: ArrayLike, y: ArrayLike | None = None) -> ComBat:
-         """Fit the ComBat model."""
-         idx = X.index if isinstance(X, pd.DataFrame) else pd.RangeIndex(len(X))
-         batch_vec = self._subset(self.batch, idx)
-         disc = self._subset(self.discrete_covariates, idx)
-         cont = self._subset(self.continuous_covariates, idx)
-         self._model.fit(
-             X,
-             batch=batch_vec,
-             discrete_covariates=disc,
-             continuous_covariates=cont,
-         )
-         self._fitted_batch = batch_vec
-         return self
-
-     def transform(self, X: ArrayLike) -> pd.DataFrame:
-         """Transform the data using fitted ComBat parameters."""
-         idx = X.index if isinstance(X, pd.DataFrame) else pd.RangeIndex(len(X))
-         batch_vec = self._subset(self.batch, idx)
-         disc = self._subset(self.discrete_covariates, idx)
-         cont = self._subset(self.continuous_covariates, idx)
-         return self._model.transform(
-             X,
-             batch=batch_vec,
-             discrete_covariates=disc,
-             continuous_covariates=cont,
-         )
-
-     @staticmethod
-     def _subset(obj: ArrayLike | None, idx: pd.Index) -> pd.DataFrame | pd.Series | None:
-         """Subset array-like object by index."""
-         if obj is None:
-             return None
-         if isinstance(obj, (pd.Series, pd.DataFrame)):
-             return obj.loc[idx]
-         else:
-             if isinstance(obj, np.ndarray) and obj.ndim == 1:
-                 return pd.Series(obj, index=idx)
-             else:
-                 return pd.DataFrame(obj, index=idx)
-
-     @property
-     def metrics_(self) -> dict[str, Any] | None:
-         """Return cached metrics from the last fit_transform with compute_metrics=True.
-
-         Returns
-         -------
-         dict or None
-             Cached metrics dictionary, or None if no metrics have been computed.
-         """
-         return getattr(self, "_metrics_cache", None)
-
-     def compute_batch_metrics(
-         self,
-         X: ArrayLike,
-         batch: ArrayLike | None = None,
-         *,
-         pca_components: int | None = None,
-         k_neighbors: list[int] | None = None,
-         kbet_k0: int | None = None,
-         lisi_perplexity: int = 30,
-         n_jobs: int = 1,
-     ) -> dict[str, Any]:
-         """
-         Compute batch effect metrics before and after ComBat correction.
-
-         Parameters
-         ----------
-         X : array-like of shape (n_samples, n_features)
-             Input data to evaluate.
-         batch : array-like of shape (n_samples,), optional
-             Batch labels. If None, uses the batch stored at construction.
-         pca_components : int, optional
-             Number of PCA components for dimensionality reduction before
-             computing metrics. If None (default), metrics are computed in
-             the original feature space. Must be less than min(n_samples, n_features).
-         k_neighbors : list of int, default=[5, 10, 50]
-             Values of k for the k-NN preservation metric.
-         kbet_k0 : int, optional
-             Neighborhood size for kBET. Default is 10% of samples (at least 10).
-         lisi_perplexity : int, default=30
-             Perplexity for LISI computation.
-         n_jobs : int, default=1
-             Number of parallel jobs for neighbor computations.
-
-         Returns
-         -------
-         dict
-             Dictionary with three main keys:
-
-             - ``batch_effect``: Silhouette, Davies-Bouldin, kBET, LISI, variance ratio
-               (each with 'before' and 'after' values)
-             - ``preservation``: k-NN preservation fractions, distance correlation
-             - ``alignment``: Centroid distance, Levene statistic (each with
-               'before' and 'after' values)
-
-         Raises
-         ------
-         ValueError
-             If the model is not fitted or if pca_components is invalid.
-         """
-         if not hasattr(self._model, "_gamma_star"):
-             raise ValueError(
-                 "This ComBat instance is not fitted yet. Call 'fit' before 'compute_batch_metrics'."
-             )
-
-         if not isinstance(X, pd.DataFrame):
-             X = pd.DataFrame(X)
-
-         idx = X.index
-
-         if batch is None:
-             batch_vec = self._subset(self.batch, idx)
-         elif isinstance(batch, (pd.Series, pd.DataFrame)):
-             batch_vec = batch.loc[idx]
-         else:
-             batch_vec = pd.Series(batch, index=idx)
-
-         batch_labels = np.array(batch_vec)
-
-         X_before = X.values
-         X_after = self.transform(X).values
-
-         n_samples, n_features = X_before.shape
-         if kbet_k0 is None:
-             kbet_k0 = max(10, int(0.10 * n_samples))
-         if k_neighbors is None:
-             k_neighbors = [5, 10, 50]
-
-         # Validate and apply PCA if requested
-         if pca_components is not None:
-             max_components = min(n_samples, n_features)
-             if pca_components >= max_components:
-                 raise ValueError(
-                     f"pca_components={pca_components} must be less than "
-                     f"min(n_samples, n_features)={max_components}."
-                 )
-             X_before_pca, X_after_pca, _ = _compute_pca_embedding(X_before, X_after, pca_components)
-         else:
-             X_before_pca = X_before
-             X_after_pca = X_after
-
-         silhouette_before = _silhouette_batch(X_before_pca, batch_labels)
-         silhouette_after = _silhouette_batch(X_after_pca, batch_labels)
-
-         db_before = _davies_bouldin_batch(X_before_pca, batch_labels)
-         db_after = _davies_bouldin_batch(X_after_pca, batch_labels)
-
-         kbet_before, _ = _kbet_score(X_before_pca, batch_labels, kbet_k0)
-         kbet_after, _ = _kbet_score(X_after_pca, batch_labels, kbet_k0)
-
-         lisi_before = _lisi_score(X_before_pca, batch_labels, lisi_perplexity)
-         lisi_after = _lisi_score(X_after_pca, batch_labels, lisi_perplexity)
-
-         var_ratio_before = _variance_ratio(X_before_pca, batch_labels)
-         var_ratio_after = _variance_ratio(X_after_pca, batch_labels)
-
-         knn_results = _knn_preservation(X_before_pca, X_after_pca, k_neighbors, n_jobs)
-         dist_corr = _pairwise_distance_correlation(X_before_pca, X_after_pca)
-
-         centroid_before = _mean_centroid_distance(X_before_pca, batch_labels)
-         centroid_after = _mean_centroid_distance(X_after_pca, batch_labels)
-
-         levene_before = _levene_median_statistic(X_before, batch_labels)
-         levene_after = _levene_median_statistic(X_after, batch_labels)
-
-         n_batches = len(np.unique(batch_labels))
-
-         metrics = {
-             "batch_effect": {
-                 "silhouette": {
-                     "before": silhouette_before,
-                     "after": silhouette_after,
-                 },
-                 "davies_bouldin": {
-                     "before": db_before,
-                     "after": db_after,
-                 },
-                 "kbet": {
-                     "before": kbet_before,
-                     "after": kbet_after,
-                 },
-                 "lisi": {
-                     "before": lisi_before,
-                     "after": lisi_after,
-                     "max_value": n_batches,
-                 },
-                 "variance_ratio": {
-                     "before": var_ratio_before,
-                     "after": var_ratio_after,
-                 },
-             },
-             "preservation": {
-                 "knn": knn_results,
-                 "distance_correlation": dist_corr,
-             },
-             "alignment": {
-                 "centroid_distance": {
-                     "before": centroid_before,
-                     "after": centroid_after,
-                 },
-                 "levene_statistic": {
-                     "before": levene_before,
-                     "after": levene_after,
-                 },
-             },
-         }
-
-         return metrics
-
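# [Editor's sketch, not part of the package diff; assumes the ComBat class
# above (combatlearn 1.1.1) is importable.] End-to-end: fit on a toy frame
# with a synthetic batch shift, then navigate the nested metrics dictionary
# documented above:
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
X = pd.DataFrame(rng.normal(size=(60, 8)))
batch = pd.Series(np.repeat(["a", "b", "c"], 20), index=X.index)
X.loc[batch == "b"] += 1.5  # synthetic batch shift

cb = ComBat(batch).fit(X)
metrics = cb.compute_batch_metrics(X)
sil = metrics["batch_effect"]["silhouette"]
print(f"silhouette before={sil['before']:.3f} after={sil['after']:.3f}")
print("5-NN preserved:", metrics["preservation"]["knn"][5])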
-     def fit_transform(self, X: ArrayLike, y: ArrayLike | None = None) -> pd.DataFrame:
-         """
-         Fit and transform the data, optionally computing metrics.
-
-         If ``compute_metrics=True`` was set at construction, batch effect
-         metrics are computed and cached in the ``metrics_`` property.
-
-         Parameters
-         ----------
-         X : array-like of shape (n_samples, n_features)
-             Input data to fit and transform.
-         y : None
-             Ignored. Present for API compatibility.
-
-         Returns
-         -------
-         X_transformed : pd.DataFrame
-             Batch-corrected data.
-         """
-         self.fit(X, y)
-         X_transformed = self.transform(X)
-
-         if self.compute_metrics:
-             self._metrics_cache = self.compute_batch_metrics(X)
-
-         return X_transformed
-
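# [Editor's sketch, not part of the package diff; assumes the ComBat wrapper
# above is importable.] Because the batch labels are stashed at construction
# and subset by X's index, the wrapper drops into an ordinary scikit-learn
# pipeline and cross-validation loop:
import numpy as np
import pandas as pd
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score
from sklearn.pipeline import make_pipeline

rng = np.random.default_rng(0)
X = pd.DataFrame(rng.normal(size=(80, 10)))
y = pd.Series(rng.integers(0, 2, size=80), index=X.index)
batch = pd.Series(np.repeat(["scanner_1", "scanner_2"], 40), index=X.index)

pipe = make_pipeline(ComBat(batch), LogisticRegression(max_iter=1000))
print(cross_val_score(pipe, X, y, cv=3).mean())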
-     def plot_transformation(
-         self,
-         X: ArrayLike,
-         *,
-         reduction_method: Literal["pca", "tsne", "umap"] = "pca",
-         n_components: Literal[2, 3] = 2,
-         plot_type: Literal["static", "interactive"] = "static",
-         figsize: tuple[int, int] = (12, 5),
-         alpha: float = 0.7,
-         point_size: int = 50,
-         cmap: str = "Set1",
-         title: str | None = None,
-         show_legend: bool = True,
-         return_embeddings: bool = False,
-         **reduction_kwargs,
-     ) -> Any | tuple[Any, dict[str, FloatArray]]:
-         """
-         Visualize the ComBat transformation effect using dimensionality reduction.
-
-         Shows a before/after comparison of data corrected by `ComBat`, using
-         PCA, t-SNE, or UMAP to reduce dimensions for visualization.
-
-         Parameters
-         ----------
-         X : array-like of shape (n_samples, n_features)
-             Input data to transform and visualize.
-
-         reduction_method : {`'pca'`, `'tsne'`, `'umap'`}, default=`'pca'`
-             Dimensionality reduction method.
-
-         n_components : {2, 3}, default=2
-             Number of components for dimensionality reduction.
-
-         plot_type : {`'static'`, `'interactive'`}, default=`'static'`
-             Visualization type:
-             - `'static'`: matplotlib plots (can be saved as images)
-             - `'interactive'`: plotly plots (explorable, requires plotly)
-
-         return_embeddings : bool, default=False
-             If `True`, return embeddings along with the plot.
-
-         **reduction_kwargs : dict
-             Additional parameters for the reduction methods.
-
-         Returns
-         -------
-         fig : matplotlib.figure.Figure or plotly.graph_objects.Figure
-             The figure object containing the plots.
-
-         embeddings : dict, optional
-             If `return_embeddings=True`, dictionary with:
-             - `'original'`: embedding of original data
-             - `'transformed'`: embedding of ComBat-transformed data
-         """
-         if not hasattr(self._model, "_gamma_star"):
-             raise ValueError(
-                 "This ComBat instance is not fitted yet. Call 'fit' before 'plot_transformation'."
-             )
-
-         if n_components not in [2, 3]:
-             raise ValueError(f"n_components must be 2 or 3, got {n_components}")
-         if reduction_method not in ["pca", "tsne", "umap"]:
-             raise ValueError(
-                 f"reduction_method must be 'pca', 'tsne', or 'umap', got '{reduction_method}'"
-             )
-         if plot_type not in ["static", "interactive"]:
-             raise ValueError(f"plot_type must be 'static' or 'interactive', got '{plot_type}'")
-
-         if not isinstance(X, pd.DataFrame):
-             X = pd.DataFrame(X)
-
-         idx = X.index
-         batch_vec = self._subset(self.batch, idx)
-         if batch_vec is None:
-             raise ValueError("Batch information is required for visualization")
-
-         X_transformed = self.transform(X)
-
-         X_np = X.values
-         X_trans_np = X_transformed.values
-
-         if reduction_method == "pca":
-             reducer_orig = PCA(n_components=n_components, **reduction_kwargs)
-             reducer_trans = PCA(n_components=n_components, **reduction_kwargs)
-         elif reduction_method == "tsne":
-             tsne_params = {"perplexity": 30, "max_iter": 1000, "random_state": 42}
-             tsne_params.update(reduction_kwargs)
-             reducer_orig = TSNE(n_components=n_components, **tsne_params)
-             reducer_trans = TSNE(n_components=n_components, **tsne_params)
-         else:
-             umap_params = {"random_state": 42}
-             umap_params.update(reduction_kwargs)
-             reducer_orig = umap.UMAP(n_components=n_components, **umap_params)
-             reducer_trans = umap.UMAP(n_components=n_components, **umap_params)
-
-         X_embedded_orig = reducer_orig.fit_transform(X_np)
-         X_embedded_trans = reducer_trans.fit_transform(X_trans_np)
-
-         if plot_type == "static":
-             fig = self._create_static_plot(
-                 X_embedded_orig,
-                 X_embedded_trans,
-                 batch_vec,
-                 reduction_method,
-                 n_components,
-                 figsize,
-                 alpha,
-                 point_size,
-                 cmap,
-                 title,
-                 show_legend,
-             )
-         else:
-             fig = self._create_interactive_plot(
-                 X_embedded_orig,
-                 X_embedded_trans,
-                 batch_vec,
-                 reduction_method,
-                 n_components,
-                 cmap,
-                 title,
-                 show_legend,
-             )
-
-         if return_embeddings:
-             embeddings = {"original": X_embedded_orig, "transformed": X_embedded_trans}
-             return fig, embeddings
-         else:
-             return fig
-
-     def _create_static_plot(
-         self,
-         X_orig: FloatArray,
-         X_trans: FloatArray,
-         batch_labels: pd.Series,
-         method: str,
-         n_components: int,
-         figsize: tuple[int, int],
-         alpha: float,
-         point_size: int,
-         cmap: str,
-         title: str | None,
-         show_legend: bool,
-     ) -> Any:
-         """Create static plots using matplotlib."""
-
-         fig = plt.figure(figsize=figsize)
-
-         unique_batches = batch_labels.drop_duplicates()
-         n_batches = len(unique_batches)
-
-         if n_batches <= 10:
-             colors = matplotlib.colormaps.get_cmap(cmap)(np.linspace(0, 1, n_batches))
-         else:
-             colors = matplotlib.colormaps.get_cmap("tab20")(np.linspace(0, 1, n_batches))
-
-         if n_components == 2:
-             ax1 = plt.subplot(1, 2, 1)
-             ax2 = plt.subplot(1, 2, 2)
-         else:
-             ax1 = fig.add_subplot(121, projection="3d")
-             ax2 = fig.add_subplot(122, projection="3d")
-
-         for i, batch in enumerate(unique_batches):
-             mask = batch_labels == batch
-             if n_components == 2:
-                 ax1.scatter(
-                     X_orig[mask, 0],
-                     X_orig[mask, 1],
-                     c=[colors[i]],
-                     s=point_size,
-                     alpha=alpha,
-                     label=f"Batch {batch}",
-                     edgecolors="black",
-                     linewidth=0.5,
-                 )
-             else:
-                 ax1.scatter(
-                     X_orig[mask, 0],
-                     X_orig[mask, 1],
-                     X_orig[mask, 2],
-                     c=[colors[i]],
-                     s=point_size,
-                     alpha=alpha,
-                     label=f"Batch {batch}",
-                     edgecolors="black",
-                     linewidth=0.5,
-                 )
-
-         ax1.set_title(f"Before ComBat correction\n({method.upper()})")
-         ax1.set_xlabel(f"{method.upper()}1")
-         ax1.set_ylabel(f"{method.upper()}2")
-         if n_components == 3:
-             ax1.set_zlabel(f"{method.upper()}3")
-
-         for i, batch in enumerate(unique_batches):
-             mask = batch_labels == batch
-             if n_components == 2:
-                 ax2.scatter(
-                     X_trans[mask, 0],
-                     X_trans[mask, 1],
-                     c=[colors[i]],
-                     s=point_size,
-                     alpha=alpha,
-                     label=f"Batch {batch}",
-                     edgecolors="black",
-                     linewidth=0.5,
-                 )
-             else:
-                 ax2.scatter(
-                     X_trans[mask, 0],
-                     X_trans[mask, 1],
-                     X_trans[mask, 2],
-                     c=[colors[i]],
-                     s=point_size,
-                     alpha=alpha,
-                     label=f"Batch {batch}",
-                     edgecolors="black",
-                     linewidth=0.5,
-                 )
-
-         ax2.set_title(f"After ComBat correction\n({method.upper()})")
-         ax2.set_xlabel(f"{method.upper()}1")
-         ax2.set_ylabel(f"{method.upper()}2")
-         if n_components == 3:
-             ax2.set_zlabel(f"{method.upper()}3")
-
-         if show_legend and n_batches <= 20:
-             ax2.legend(bbox_to_anchor=(1.05, 1), loc="upper left")
-
-         if title is None:
-             title = f"ComBat correction effect visualized with {method.upper()}"
-         fig.suptitle(title, fontsize=14, fontweight="bold")
-
-         plt.tight_layout()
-         return fig
-
-     def _create_interactive_plot(
-         self,
-         X_orig: FloatArray,
-         X_trans: FloatArray,
-         batch_labels: pd.Series,
-         method: str,
-         n_components: int,
-         cmap: str,
-         title: str | None,
-         show_legend: bool,
-     ) -> Any:
-         """Create interactive plots using plotly."""
-         if n_components == 2:
-             fig = make_subplots(
-                 rows=1,
-                 cols=2,
-                 subplot_titles=(
-                     f"Before ComBat correction ({method.upper()})",
-                     f"After ComBat correction ({method.upper()})",
-                 ),
-             )
-         else:
-             fig = make_subplots(
-                 rows=1,
-                 cols=2,
-                 specs=[[{"type": "scatter3d"}, {"type": "scatter3d"}]],
-                 subplot_titles=(
-                     f"Before ComBat correction ({method.upper()})",
-                     f"After ComBat correction ({method.upper()})",
-                 ),
-             )
-
-         unique_batches = batch_labels.drop_duplicates()
-
-         n_batches = len(unique_batches)
-         cmap_func = matplotlib.colormaps.get_cmap(cmap)
-         color_list = [
-             mcolors.to_hex(cmap_func(i / max(n_batches - 1, 1))) for i in range(n_batches)
-         ]
-
-         batch_to_color = dict(zip(unique_batches, color_list, strict=True))
-
-         for batch in unique_batches:
-             mask = batch_labels == batch
-
-             if n_components == 2:
-                 fig.add_trace(
-                     go.Scatter(
-                         x=X_orig[mask, 0],
-                         y=X_orig[mask, 1],
-                         mode="markers",
-                         name=f"Batch {batch}",
-                         marker={
-                             "size": 8,
-                             "color": batch_to_color[batch],
-                             "line": {"width": 1, "color": "black"},
-                         },
-                         showlegend=False,
-                     ),
-                     row=1,
-                     col=1,
-                 )
-
-                 fig.add_trace(
-                     go.Scatter(
-                         x=X_trans[mask, 0],
-                         y=X_trans[mask, 1],
-                         mode="markers",
-                         name=f"Batch {batch}",
-                         marker={
-                             "size": 8,
-                             "color": batch_to_color[batch],
-                             "line": {"width": 1, "color": "black"},
-                         },
-                         showlegend=show_legend,
-                     ),
-                     row=1,
-                     col=2,
-                 )
-             else:
-                 fig.add_trace(
-                     go.Scatter3d(
-                         x=X_orig[mask, 0],
-                         y=X_orig[mask, 1],
-                         z=X_orig[mask, 2],
-                         mode="markers",
-                         name=f"Batch {batch}",
-                         marker={
-                             "size": 5,
-                             "color": batch_to_color[batch],
-                             "line": {"width": 0.5, "color": "black"},
-                         },
-                         showlegend=False,
-                     ),
-                     row=1,
-                     col=1,
-                 )
-
-                 fig.add_trace(
-                     go.Scatter3d(
-                         x=X_trans[mask, 0],
-                         y=X_trans[mask, 1],
-                         z=X_trans[mask, 2],
-                         mode="markers",
-                         name=f"Batch {batch}",
-                         marker={
-                             "size": 5,
-                             "color": batch_to_color[batch],
-                             "line": {"width": 0.5, "color": "black"},
-                         },
-                         showlegend=show_legend,
-                     ),
-                     row=1,
-                     col=2,
-                 )
-
-         if title is None:
-             title = f"ComBat correction effect visualized with {method.upper()}"
-
-         fig.update_layout(
-             title=title,
-             title_font_size=16,
-             height=600,
-             showlegend=show_legend,
-             hovermode="closest",
-         )
-
-         axis_labels = [f"{method.upper()}{i + 1}" for i in range(n_components)]
-
-         if n_components == 2:
-             fig.update_xaxes(title_text=axis_labels[0])
-             fig.update_yaxes(title_text=axis_labels[1])
-         else:
-             fig.update_scenes(
-                 xaxis_title=axis_labels[0],
-                 yaxis_title=axis_labels[1],
-                 zaxis_title=axis_labels[2],
-             )
-
-         return fig
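# [Editor's sketch, not part of the package diff; assumes the fitted `cb` and
# frame `X` from the compute_batch_metrics example above.] Static before/after
# view with matplotlib, then an interactive 3-D view with plotly, keeping the
# embeddings for further inspection:
fig = cb.plot_transformation(X, reduction_method="pca", n_components=2)
fig.savefig("combat_before_after.png", dpi=150)

fig3d, emb = cb.plot_transformation(
    X,
    reduction_method="pca",
    n_components=3,
    plot_type="interactive",
    return_embeddings=True,
)
print(emb["original"].shape, emb["transformed"].shape)  # (60, 3) each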