combatlearn-1.0.0-py3-none-any.whl → combatlearn-1.1.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
combatlearn/combat.py CHANGED
@@ -8,28 +8,552 @@
  `ComBat` makes the model compatible with scikit-learn by stashing
  the batch (and optional covariates) at construction.
  """
+
  from __future__ import annotations
 
+ import warnings
+ from typing import Any, Literal
+
+ import matplotlib
+ import matplotlib.colors as mcolors
+ import matplotlib.pyplot as plt
  import numpy as np
  import numpy.linalg as la
+ import numpy.typing as npt
  import pandas as pd
+ import plotly.graph_objects as go
+ import umap
+ from plotly.subplots import make_subplots
+ from scipy.spatial.distance import pdist
+ from scipy.stats import chi2, levene, spearmanr
  from sklearn.base import BaseEstimator, TransformerMixin
  from sklearn.decomposition import PCA
  from sklearn.manifold import TSNE
- import matplotlib
- import matplotlib.pyplot as plt
- import matplotlib.colors as mcolors
- from typing import Literal, Optional, Union, Dict, Tuple, Any
- import numpy.typing as npt
- import warnings
- import umap
- import plotly.graph_objects as go
- from plotly.subplots import make_subplots
+ from sklearn.metrics import davies_bouldin_score, silhouette_score
+ from sklearn.neighbors import NearestNeighbors
 
- ArrayLike = Union[pd.DataFrame, pd.Series, npt.NDArray[Any]]
+ ArrayLike = pd.DataFrame | pd.Series | npt.NDArray[Any]
  FloatArray = npt.NDArray[np.float64]
 
 
+ def _compute_pca_embedding(
+     X_before: np.ndarray,
+     X_after: np.ndarray,
+     n_components: int,
+ ) -> tuple[np.ndarray, np.ndarray, PCA]:
+     """
+     Compute PCA embeddings for both datasets.
+
+     Fits PCA on X_before and applies to both datasets.
+
+     Parameters
+     ----------
+     X_before : np.ndarray
+         Original data before correction.
+     X_after : np.ndarray
+         Corrected data.
+     n_components : int
+         Number of PCA components.
+
+     Returns
+     -------
+     X_before_pca : np.ndarray
+         PCA-transformed original data.
+     X_after_pca : np.ndarray
+         PCA-transformed corrected data.
+     pca : PCA
+         Fitted PCA model.
+     """
+     n_components = min(n_components, X_before.shape[1], X_before.shape[0] - 1)
+     pca = PCA(n_components=n_components, random_state=42)
+     X_before_pca = pca.fit_transform(X_before)
+     X_after_pca = pca.transform(X_after)
+     return X_before_pca, X_after_pca, pca
+
+
+ def _silhouette_batch(X: np.ndarray, batch_labels: np.ndarray) -> float:
+     """
+     Compute silhouette coefficient using batch as cluster labels.
+
+     Lower values after correction indicate better batch mixing.
+     Range: [-1, 1], where -1 = batch mixing, 1 = batch separation.
+
+     Parameters
+     ----------
+     X : np.ndarray
+         Data matrix.
+     batch_labels : np.ndarray
+         Batch labels for each sample.
+
+     Returns
+     -------
+     float
+         Silhouette coefficient.
+     """
+     unique_batches = np.unique(batch_labels)
+     if len(unique_batches) < 2:
+         return 0.0
+     try:
+         return silhouette_score(X, batch_labels, metric="euclidean")
+     except Exception:
+         return 0.0
+
+
+ def _davies_bouldin_batch(X: np.ndarray, batch_labels: np.ndarray) -> float:
+     """
+     Compute Davies-Bouldin index using batch labels.
+
+     Lower values indicate better batch mixing.
+     Range: [0, inf), 0 = perfect batch overlap.
+
+     Parameters
+     ----------
+     X : np.ndarray
+         Data matrix.
+     batch_labels : np.ndarray
+         Batch labels for each sample.
+
+     Returns
+     -------
+     float
+         Davies-Bouldin index.
+     """
+     unique_batches = np.unique(batch_labels)
+     if len(unique_batches) < 2:
+         return 0.0
+     try:
+         return davies_bouldin_score(X, batch_labels)
+     except Exception:
+         return 0.0
+
+
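The two batch-compactness helpers above are read as a before/after pair: each score is computed once on the uncorrected matrix and once on the corrected one, and the change is what matters. A minimal, self-contained sketch of that comparison (synthetic `X_raw`/`X_corr`/`batches`, illustrative only, not package code):

    import numpy as np
    from sklearn.metrics import silhouette_score

    rng = np.random.default_rng(0)
    batches = np.array([0] * 30 + [1] * 30)
    X_raw = rng.normal(size=(60, 5))
    X_raw[batches == 1] += 2.0  # inject a mean shift into the second batch

    # Crude stand-in for a correction: remove each batch's own mean.
    X_corr = X_raw.copy()
    for b in (0, 1):
        X_corr[batches == b] -= X_corr[batches == b].mean(axis=0)

    # Silhouette over batch labels drops toward 0 once the batches overlap.
    print(silhouette_score(X_raw, batches), silhouette_score(X_corr, batches))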
+ def _kbet_score(
+     X: np.ndarray,
+     batch_labels: np.ndarray,
+     k0: int,
+     alpha: float = 0.05,
+ ) -> tuple[float, float]:
+     """
+     Compute kBET (k-nearest neighbor Batch Effect Test) acceptance rate.
+
+     Tests if local batch proportions match global batch proportions.
+     Higher acceptance rate = better batch mixing.
+
+     Reference: Buttner et al. (2019) Nature Methods
+
+     Parameters
+     ----------
+     X : np.ndarray
+         Data matrix.
+     batch_labels : np.ndarray
+         Batch labels for each sample.
+     k0 : int
+         Neighborhood size.
+     alpha : float
+         Significance level for chi-squared test.
+
+     Returns
+     -------
+     acceptance_rate : float
+         Fraction of samples where H0 (uniform mixing) is accepted.
+     mean_stat : float
+         Mean chi-squared statistic across samples.
+     """
+     n_samples = X.shape[0]
+     unique_batches, batch_counts = np.unique(batch_labels, return_counts=True)
+     n_batches = len(unique_batches)
+
+     if n_batches < 2:
+         return 1.0, 0.0
+
+     global_freq = batch_counts / n_samples
+     k0 = min(k0, n_samples - 1)
+
+     nn = NearestNeighbors(n_neighbors=k0 + 1, algorithm="auto")
+     nn.fit(X)
+     _, indices = nn.kneighbors(X)
+
+     chi2_stats = []
+     p_values = []
+     batch_to_idx = {b: i for i, b in enumerate(unique_batches)}
+
+     for i in range(n_samples):
+         neighbors = indices[i, 1 : k0 + 1]
+         neighbor_batches = batch_labels[neighbors]
+
+         observed = np.zeros(n_batches)
+         for nb in neighbor_batches:
+             observed[batch_to_idx[nb]] += 1
+
+         expected = global_freq * k0
+
+         mask = expected > 0
+         if mask.sum() < 2:
+             continue
+
+         stat = np.sum((observed[mask] - expected[mask]) ** 2 / expected[mask])
+         df = max(1, mask.sum() - 1)
+         p_val = 1 - chi2.cdf(stat, df)
+
+         chi2_stats.append(stat)
+         p_values.append(p_val)
+
+     if len(p_values) == 0:
+         return 1.0, 0.0
+
+     acceptance_rate = np.mean(np.array(p_values) > alpha)
+     mean_stat = np.mean(chi2_stats)
+
+     return acceptance_rate, mean_stat
+
+
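The per-sample kBET test above is a plain Pearson chi-squared test of a neighborhood's batch counts against the global batch proportions. A worked single-neighborhood example (standalone sketch, not package code):

    from scipy.stats import chi2

    # Two batches at 50/50 globally and k0 = 10 neighbors, so expected = (5, 5).
    observed, expected = [9, 1], [5, 5]
    stat = sum((o - e) ** 2 / e for o, e in zip(observed, expected))  # 6.4
    p_val = 1 - chi2.cdf(stat, df=1)  # ~0.011 < alpha=0.05: rejects uniform mixing
    print(stat, p_val)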
+ def _find_sigma(distances: np.ndarray, target_perplexity: float, tol: float = 1e-5) -> float:
+     """
+     Binary search for sigma to achieve target perplexity.
+
+     Used in LISI computation.
+
+     Parameters
+     ----------
+     distances : np.ndarray
+         Distances to neighbors.
+     target_perplexity : float
+         Target perplexity value.
+     tol : float
+         Tolerance for convergence.
+
+     Returns
+     -------
+     float
+         Sigma value.
+     """
+     target_H = np.log2(target_perplexity + 1e-10)
+
+     sigma_min, sigma_max = 1e-10, 1e10
+     sigma = 1.0
+
+     for _ in range(50):
+         P = np.exp(-(distances**2) / (2 * sigma**2 + 1e-10))
+         P_sum = P.sum()
+         if P_sum < 1e-10:
+             sigma = (sigma + sigma_max) / 2
+             continue
+         P = P / P_sum
+         P = np.clip(P, 1e-10, 1.0)
+         H = -np.sum(P * np.log2(P))
+
+         if abs(H - target_H) < tol:
+             break
+         elif target_H > H:
+             sigma_min = sigma
+         else:
+             sigma_max = sigma
+         sigma = (sigma_min + sigma_max) / 2
+
+     return sigma
+
+
+ def _lisi_score(
+     X: np.ndarray,
+     batch_labels: np.ndarray,
+     perplexity: int = 30,
+ ) -> float:
+     """
+     Compute mean Local Inverse Simpson's Index (LISI).
+
+     Range: [1, n_batches], where n_batches = perfect mixing.
+     Higher = better batch mixing.
+
+     Reference: Korsunsky et al. (2019) Nature Methods (Harmony paper)
+
+     Parameters
+     ----------
+     X : np.ndarray
+         Data matrix.
+     batch_labels : np.ndarray
+         Batch labels for each sample.
+     perplexity : int
+         Perplexity for Gaussian kernel.
+
+     Returns
+     -------
+     float
+         Mean LISI score.
+     """
+     n_samples = X.shape[0]
+     unique_batches = np.unique(batch_labels)
+     n_batches = len(unique_batches)
+     batch_to_idx = {b: i for i, b in enumerate(unique_batches)}
+
+     if n_batches < 2:
+         return 1.0
+
+     k = min(3 * perplexity, n_samples - 1)
+
+     nn = NearestNeighbors(n_neighbors=k + 1, algorithm="auto")
+     nn.fit(X)
+     distances, indices = nn.kneighbors(X)
+
+     distances = distances[:, 1:]
+     indices = indices[:, 1:]
+
+     lisi_values = []
+
+     for i in range(n_samples):
+         sigma = _find_sigma(distances[i], perplexity)
+
+         P = np.exp(-(distances[i] ** 2) / (2 * sigma**2 + 1e-10))
+         P_sum = P.sum()
+         if P_sum < 1e-10:
+             lisi_values.append(1.0)
+             continue
+         P = P / P_sum
+
+         neighbor_batches = batch_labels[indices[i]]
+         batch_probs = np.zeros(n_batches)
+         for j, nb in enumerate(neighbor_batches):
+             batch_probs[batch_to_idx[nb]] += P[j]
+
+         simpson = np.sum(batch_probs**2)
+         lisi = n_batches if simpson < 1e-10 else 1.0 / simpson
+         lisi_values.append(lisi)
+
+     return np.mean(lisi_values)
+
+
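LISI, as implemented above, is the inverse Simpson index of the Gaussian-weighted batch composition around each sample. Two hand-checkable cases for a two-batch setting (illustrative only):

    import numpy as np

    for probs in ([0.5, 0.5], [0.9, 0.1]):
        simpson = float(np.sum(np.square(probs)))
        print(probs, 1.0 / simpson)  # 2.0 = perfect mixing; ~1.22 = mostly one batch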
+ def _variance_ratio(X: np.ndarray, batch_labels: np.ndarray) -> float:
+     """
+     Compute between-batch to within-batch variance ratio.
+
+     Similar to F-statistic in one-way ANOVA.
+     Lower ratio after correction = better batch effect removal.
+
+     Parameters
+     ----------
+     X : np.ndarray
+         Data matrix.
+     batch_labels : np.ndarray
+         Batch labels for each sample.
+
+     Returns
+     -------
+     float
+         Variance ratio (between/within).
+     """
+     unique_batches = np.unique(batch_labels)
+     n_batches = len(unique_batches)
+     n_samples = X.shape[0]
+
+     if n_batches < 2:
+         return 0.0
+
+     grand_mean = np.mean(X, axis=0)
+
+     between_var = 0.0
+     within_var = 0.0
+
+     for batch in unique_batches:
+         mask = batch_labels == batch
+         n_b = np.sum(mask)
+         X_batch = X[mask]
+         batch_mean = np.mean(X_batch, axis=0)
+
+         between_var += n_b * np.sum((batch_mean - grand_mean) ** 2)
+         within_var += np.sum((X_batch - batch_mean) ** 2)
+
+     between_var /= n_batches - 1
+     within_var /= n_samples - n_batches
+
+     if within_var < 1e-10:
+         return 0.0
+
+     return between_var / within_var
+
+
+ def _knn_preservation(
+     X_before: np.ndarray,
+     X_after: np.ndarray,
+     k_values: list[int],
+     n_jobs: int = 1,
+ ) -> dict[int, float]:
+     """
+     Compute fraction of k-nearest neighbors preserved after correction.
+
+     Range: [0, 1], where 1 = perfect preservation.
+     Higher = better biological structure preservation.
+
+     Parameters
+     ----------
+     X_before : np.ndarray
+         Original data.
+     X_after : np.ndarray
+         Corrected data.
+     k_values : list of int
+         Values of k for k-NN.
+     n_jobs : int
+         Number of parallel jobs.
+
+     Returns
+     -------
+     dict
+         Mapping from k to preservation fraction.
+     """
+     results = {}
+     max_k = max(k_values)
+     max_k = min(max_k, X_before.shape[0] - 1)
+
+     nn_before = NearestNeighbors(n_neighbors=max_k + 1, algorithm="auto", n_jobs=n_jobs)
+     nn_before.fit(X_before)
+     _, indices_before = nn_before.kneighbors(X_before)
+
+     nn_after = NearestNeighbors(n_neighbors=max_k + 1, algorithm="auto", n_jobs=n_jobs)
+     nn_after.fit(X_after)
+     _, indices_after = nn_after.kneighbors(X_after)
+
+     for k in k_values:
+         if k > max_k:
+             results[k] = 0.0
+             continue
+
+         overlaps = []
+         for i in range(X_before.shape[0]):
+             neighbors_before = set(indices_before[i, 1 : k + 1])
+             neighbors_after = set(indices_after[i, 1 : k + 1])
+             overlap = len(neighbors_before & neighbors_after) / k
+             overlaps.append(overlap)
+
+         results[k] = np.mean(overlaps)
+
+     return results
+
+
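Because `_knn_preservation` compares only neighbor identities, it is invariant to rigid shifts of the data. A quick standalone check of that property (synthetic data, not package code):

    import numpy as np
    from sklearn.neighbors import NearestNeighbors

    rng = np.random.default_rng(0)
    X = rng.normal(size=(100, 10))
    X_shifted = X + 5.0  # a pure translation keeps every neighbor list intact

    k = 5
    idx_a = NearestNeighbors(n_neighbors=k + 1).fit(X).kneighbors(X)[1][:, 1:]
    idx_b = NearestNeighbors(n_neighbors=k + 1).fit(X_shifted).kneighbors(X_shifted)[1][:, 1:]
    print(np.mean([len(set(a) & set(b)) / k for a, b in zip(idx_a, idx_b)]))  # 1.0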
+ def _pairwise_distance_correlation(
+     X_before: np.ndarray,
+     X_after: np.ndarray,
+     subsample: int = 1000,
+     random_state: int = 42,
+ ) -> float:
+     """
+     Compute Spearman correlation of pairwise distances.
+
+     Range: [-1, 1], where 1 = perfect rank preservation.
+     Higher = better relative relationship preservation.
+
+     Parameters
+     ----------
+     X_before : np.ndarray
+         Original data.
+     X_after : np.ndarray
+         Corrected data.
+     subsample : int
+         Maximum samples to use (for efficiency).
+     random_state : int
+         Random seed for subsampling.
+
+     Returns
+     -------
+     float
+         Spearman correlation coefficient.
+     """
+     n_samples = X_before.shape[0]
+
+     if n_samples > subsample:
+         rng = np.random.default_rng(random_state)
+         idx = rng.choice(n_samples, subsample, replace=False)
+         X_before = X_before[idx]
+         X_after = X_after[idx]
+
+     dist_before = pdist(X_before, metric="euclidean")
+     dist_after = pdist(X_after, metric="euclidean")
+
+     if len(dist_before) == 0:
+         return 1.0
+
+     corr, _ = spearmanr(dist_before, dist_after)
+
+     if np.isnan(corr):
+         return 1.0
+
+     return corr
+
+
+ def _mean_centroid_distance(X: np.ndarray, batch_labels: np.ndarray) -> float:
+     """
+     Compute mean pairwise Euclidean distance between batch centroids.
+
+     Lower after correction = better batch alignment.
+
+     Parameters
+     ----------
+     X : np.ndarray
+         Data matrix.
+     batch_labels : np.ndarray
+         Batch labels for each sample.
+
+     Returns
+     -------
+     float
+         Mean pairwise distance between centroids.
+     """
+     unique_batches = np.unique(batch_labels)
+     n_batches = len(unique_batches)
+
+     if n_batches < 2:
+         return 0.0
+
+     centroids = []
+     for batch in unique_batches:
+         mask = batch_labels == batch
+         centroid = np.mean(X[mask], axis=0)
+         centroids.append(centroid)
+
+     centroids = np.array(centroids)
+     distances = pdist(centroids, metric="euclidean")
+
+     return np.mean(distances)
+
+
+ def _levene_median_statistic(X: np.ndarray, batch_labels: np.ndarray) -> float:
+     """
+     Compute median Levene test statistic across features.
+
+     Lower statistic = more homogeneous variances across batches.
+
+     Parameters
+     ----------
+     X : np.ndarray
+         Data matrix.
+     batch_labels : np.ndarray
+         Batch labels for each sample.
+
+     Returns
+     -------
+     float
+         Median Levene test statistic.
+     """
+     unique_batches = np.unique(batch_labels)
+     if len(unique_batches) < 2:
+         return 0.0
+
+     levene_stats = []
+     for j in range(X.shape[1]):
+         groups = [X[batch_labels == b, j] for b in unique_batches]
+         groups = [g for g in groups if len(g) > 0]
+         if len(groups) < 2:
+             continue
+         try:
+             stat, _ = levene(*groups, center="median")
+             if not np.isnan(stat):
+                 levene_stats.append(stat)
+         except Exception:
+             continue
+
+     if len(levene_stats) == 0:
+         return 0.0
+
+     return np.median(levene_stats)
+
+
  class ComBatModel:
      """ComBat algorithm.
 
@@ -59,24 +583,24 @@ class ComBatModel:
          method: Literal["johnson", "fortin", "chen"] = "johnson",
          parametric: bool = True,
          mean_only: bool = False,
-         reference_batch: Optional[str] = None,
+         reference_batch: str | None = None,
          eps: float = 1e-8,
-         covbat_cov_thresh: Union[float, int] = 0.9,
+         covbat_cov_thresh: float | int = 0.9,
      ) -> None:
          self.method: str = method
          self.parametric: bool = parametric
          self.mean_only: bool = bool(mean_only)
-         self.reference_batch: Optional[str] = reference_batch
+         self.reference_batch: str | None = reference_batch
          self.eps: float = float(eps)
-         self.covbat_cov_thresh: Union[float, int] = covbat_cov_thresh
+         self.covbat_cov_thresh: float | int = covbat_cov_thresh
 
          self._batch_levels: pd.Index
          self._grand_mean: pd.Series
          self._pooled_var: pd.Series
          self._gamma_star: FloatArray
          self._delta_star: FloatArray
-         self._n_per_batch: Dict[str, int]
-         self._reference_batch_idx: Optional[int]
+         self._n_per_batch: dict[str, int]
+         self._reference_batch_idx: int | None
          self._beta_hat_nonbatch: FloatArray
          self._n_batch: int
          self._p_design: int
@@ -97,26 +621,15 @@ class ComBatModel:
              raise TypeError("covbat_cov_thresh must be float or int.")
 
      @staticmethod
-     def _as_series(
-         arr: ArrayLike,
-         index: pd.Index,
-         name: str
-     ) -> pd.Series:
+     def _as_series(arr: ArrayLike, index: pd.Index, name: str) -> pd.Series:
          """Convert array-like to categorical Series with validation."""
-         if isinstance(arr, pd.Series):
-             ser = arr.copy()
-         else:
-             ser = pd.Series(arr, index=index, name=name)
+         ser = arr.copy() if isinstance(arr, pd.Series) else pd.Series(arr, index=index, name=name)
          if not ser.index.equals(index):
              raise ValueError(f"`{name}` index mismatch with `X`.")
          return ser.astype("category")
 
      @staticmethod
-     def _to_df(
-         arr: Optional[ArrayLike],
-         index: pd.Index,
-         name: str
-     ) -> Optional[pd.DataFrame]:
+     def _to_df(arr: ArrayLike | None, index: pd.Index, name: str) -> pd.DataFrame | None:
          """Convert array-like to DataFrame."""
          if arr is None:
              return None
@@ -131,11 +644,11 @@ class ComBatModel:
      def fit(
          self,
          X: ArrayLike,
-         y: Optional[ArrayLike] = None,
+         y: ArrayLike | None = None,
          *,
          batch: ArrayLike,
-         discrete_covariates: Optional[ArrayLike] = None,
-         continuous_covariates: Optional[ArrayLike] = None,
+         discrete_covariates: ArrayLike | None = None,
+         continuous_covariates: ArrayLike | None = None,
      ) -> ComBatModel:
          """Fit the ComBat model."""
          method = self.method.lower()
@@ -157,9 +670,7 @@ class ComBatModel:
 
          if method == "johnson":
              if disc is not None or cont is not None:
-                 warnings.warn(
-                     "Covariates are ignored when using method='johnson'."
-                 )
+                 warnings.warn("Covariates are ignored when using method='johnson'.", stacklevel=2)
              self._fit_johnson(X, batch)
          elif method == "fortin":
              self._fit_fortin(X, batch, disc, cont)
@@ -167,11 +678,7 @@ class ComBatModel:
              self._fit_chen(X, batch, disc, cont)
          return self
 
-     def _fit_johnson(
-         self,
-         X: pd.DataFrame,
-         batch: pd.Series
-     ) -> None:
+     def _fit_johnson(self, X: pd.DataFrame, batch: pd.Series) -> None:
          """Johnson et al. (2007) ComBat."""
          self._batch_levels = batch.cat.categories
          pooled_var = X.var(axis=0, ddof=1) + self.eps
@@ -179,10 +686,10 @@ class ComBatModel:
 
          Xs = (X - grand_mean) / np.sqrt(pooled_var)
 
-         n_per_batch: Dict[str, int] = {}
+         n_per_batch: dict[str, int] = {}
          gamma_hat: list[npt.NDArray[np.float64]] = []
          delta_hat: list[npt.NDArray[np.float64]] = []
-
+
          for lvl in self._batch_levels:
              idx = batch == lvl
              n_b = int(idx.sum())
@@ -227,8 +734,8 @@ class ComBatModel:
          self,
          X: pd.DataFrame,
          batch: pd.Series,
-         disc: Optional[pd.DataFrame],
-         cont: Optional[pd.DataFrame],
+         disc: pd.DataFrame | None,
+         cont: pd.DataFrame | None,
      ) -> None:
          """Fortin et al. (2018) neuroComBat."""
          self._batch_levels = batch.cat.categories
@@ -246,11 +753,7 @@ class ComBatModel:
 
          parts: list[pd.DataFrame] = [batch_dummies]
          if disc is not None:
-             parts.append(
-                 pd.get_dummies(
-                     disc.astype("category"), drop_first=True
-                 ).astype(float)
-             )
+             parts.append(pd.get_dummies(disc.astype("category"), drop_first=True).astype(float))
 
          if cont is not None:
              parts.append(cont.astype(float))
@@ -265,7 +768,7 @@ class ComBatModel:
          self._beta_hat_nonbatch = beta_hat[n_batch:]
 
          n_per_batch = batch.value_counts().sort_index().astype(int).values
-         self._n_per_batch = dict(zip(self._batch_levels, n_per_batch))
+         self._n_per_batch = dict(zip(self._batch_levels, n_per_batch, strict=True))
 
          if self.reference_batch is not None:
              ref_idx = list(self._batch_levels).index(self.reference_batch)
@@ -283,30 +786,25 @@ class ComBatModel:
          else:
              resid = X_np - design @ beta_hat
              denom = n_samples
-         var_pooled = (resid ** 2).sum(axis=0) / denom + self.eps
+         var_pooled = (resid**2).sum(axis=0) / denom + self.eps
          self._pooled_var = pd.Series(var_pooled, index=X.columns)
 
          stand_mean = grand_mean + design[:, n_batch:] @ self._beta_hat_nonbatch
          Xs = (X_np - stand_mean) / np.sqrt(var_pooled)
 
-         gamma_hat = np.vstack(
-             [Xs[batch == lvl].mean(axis=0) for lvl in self._batch_levels]
-         )
+         gamma_hat = np.vstack([Xs[batch == lvl].mean(axis=0) for lvl in self._batch_levels])
          delta_hat = np.vstack(
-             [Xs[batch == lvl].var(axis=0, ddof=1) + self.eps
-              for lvl in self._batch_levels]
+             [Xs[batch == lvl].var(axis=0, ddof=1) + self.eps for lvl in self._batch_levels]
          )
 
          if self.mean_only:
              gamma_star = self._shrink_gamma(
-                 gamma_hat, delta_hat, n_per_batch,
-                 parametric = self.parametric
+                 gamma_hat, delta_hat, n_per_batch, parametric=self.parametric
              )
              delta_star = np.ones_like(delta_hat)
          else:
              gamma_star, delta_star = self._shrink_gamma_delta(
-                 gamma_hat, delta_hat, n_per_batch,
-                 parametric = self.parametric
+                 gamma_hat, delta_hat, n_per_batch, parametric=self.parametric
              )
 
          if ref_idx is not None:
@@ -317,15 +815,15 @@ class ComBatModel:
 
          self._gamma_star = gamma_star
          self._delta_star = delta_star
-         self._n_batch = n_batch
+         self._n_batch = n_batch
          self._p_design = p_design
-
+
      def _fit_chen(
          self,
          X: pd.DataFrame,
          batch: pd.Series,
-         disc: Optional[pd.DataFrame],
-         cont: Optional[pd.DataFrame],
+         disc: pd.DataFrame | None,
+         cont: pd.DataFrame | None,
      ) -> None:
          """Chen et al. (2022) CovBat."""
          self._fit_fortin(X, batch, disc, cont)
@@ -344,7 +842,7 @@ class ComBatModel:
          self._covbat_n_pc = n_pc
 
          scores = pca.transform(X_centered)[:, :n_pc]
-         scores_df = pd.DataFrame(scores, index=X.index, columns=[f"PC{i+1}" for i in range(n_pc)])
+         scores_df = pd.DataFrame(scores, index=X.index, columns=[f"PC{i + 1}" for i in range(n_pc)])
          self._batch_levels_pc = self._batch_levels
          n_per_batch = self._n_per_batch
 
@@ -383,12 +881,12 @@ class ComBatModel:
          self,
          gamma_hat: FloatArray,
          delta_hat: FloatArray,
-         n_per_batch: Union[Dict[str, int], FloatArray],
+         n_per_batch: dict[str, int] | FloatArray,
          *,
          parametric: bool,
          max_iter: int = 100,
          tol: float = 1e-4,
-     ) -> Tuple[FloatArray, FloatArray]:
+     ) -> tuple[FloatArray, FloatArray]:
          """Empirical Bayes shrinkage estimation."""
          if parametric:
              gamma_bar = gamma_hat.mean(axis=0)
@@ -396,10 +894,14 @@ class ComBatModel:
              a_prior = (delta_hat.mean(axis=0) ** 2) / delta_hat.var(axis=0, ddof=1) + 2
              b_prior = delta_hat.mean(axis=0) * (a_prior - 1)
 
-             B, p = gamma_hat.shape
+             B, _p = gamma_hat.shape
              gamma_star = np.empty_like(gamma_hat)
              delta_star = np.empty_like(delta_hat)
-             n_vec = np.array(list(n_per_batch.values())) if isinstance(n_per_batch, dict) else n_per_batch
+             n_vec = (
+                 np.array(list(n_per_batch.values()))
+                 if isinstance(n_per_batch, dict)
+                 else n_per_batch
+             )
 
              for i in range(B):
                  n_i = n_vec[i]
@@ -413,8 +915,12 @@ class ComBatModel:
              return gamma_star, delta_star
 
          else:
-             B, p = gamma_hat.shape
-             n_vec = np.array(list(n_per_batch.values())) if isinstance(n_per_batch, dict) else n_per_batch
+             B, _p = gamma_hat.shape
+             n_vec = (
+                 np.array(list(n_per_batch.values()))
+                 if isinstance(n_per_batch, dict)
+                 else n_per_batch
+             )
              gamma_bar = gamma_hat.mean(axis=0)
              t2 = gamma_hat.var(axis=0, ddof=1)
 
@@ -423,27 +929,22 @@ class ComBatModel:
                  g_bar: FloatArray,
                  n: float,
                  d_star: FloatArray,
-                 t2_: FloatArray
+                 t2_: FloatArray,
              ) -> FloatArray:
                  return (t2_ * n * g_hat + d_star * g_bar) / (t2_ * n + d_star)
 
-             def postvar(
-                 sum2: FloatArray,
-                 n: float,
-                 a: FloatArray,
-                 b: FloatArray
-             ) -> FloatArray:
+             def postvar(sum2: FloatArray, n: float, a: FloatArray, b: FloatArray) -> FloatArray:
                  return (0.5 * sum2 + b) / (n / 2.0 + a - 1.0)
 
              def aprior(delta: FloatArray) -> FloatArray:
                  m, s2 = delta.mean(), delta.var()
                  s2 = max(s2, self.eps)
-                 return (2 * s2 + m ** 2) / s2
+                 return (2 * s2 + m**2) / s2
 
              def bprior(delta: FloatArray) -> FloatArray:
                  m, s2 = delta.mean(), delta.var()
                  s2 = max(s2, self.eps)
-                 return (m * s2 + m ** 3) / s2
+                 return (m * s2 + m**3) / s2
 
              gamma_star = np.empty_like(gamma_hat)
              delta_star = np.empty_like(delta_hat)
@@ -462,7 +963,8 @@ class ComBatModel:
                  sum2 = (n_i - 1) * d_hat_i + n_i * (g_hat_i - g_new) ** 2
                  d_new = postvar(sum2, n_i, a_i, b_i)
                  if np.max(np.abs(g_new - g_prev) / (np.abs(g_prev) + self.eps)) < tol and (
-                     self.mean_only or np.max(np.abs(d_new - d_prev) / (np.abs(d_prev) + self.eps)) < tol
+                     self.mean_only
+                     or np.max(np.abs(d_new - d_prev) / (np.abs(d_prev) + self.eps)) < tol
                  ):
                      break
                  gamma_star[i] = g_new
@@ -473,12 +975,14 @@ class ComBatModel:
          self,
          gamma_hat: FloatArray,
          delta_hat: FloatArray,
-         n_per_batch: Union[Dict[str, int], FloatArray],
+         n_per_batch: dict[str, int] | FloatArray,
          *,
          parametric: bool,
      ) -> FloatArray:
-         """Convenience wrapper that returns only γ⋆ (for *mean-only* mode)."""
-         gamma, _ = self._shrink_gamma_delta(gamma_hat, delta_hat, n_per_batch, parametric=parametric)
+         """Convenience wrapper that returns only gamma* (for *mean-only* mode)."""
+         gamma, _ = self._shrink_gamma_delta(
+             gamma_hat, delta_hat, n_per_batch, parametric=parametric
+         )
          return gamma
 
      def transform(
@@ -486,12 +990,14 @@ class ComBatModel:
          X: ArrayLike,
          *,
          batch: ArrayLike,
-         discrete_covariates: Optional[ArrayLike] = None,
-         continuous_covariates: Optional[ArrayLike] = None,
+         discrete_covariates: ArrayLike | None = None,
+         continuous_covariates: ArrayLike | None = None,
      ) -> pd.DataFrame:
          """Transform the data using fitted ComBat parameters."""
          if not hasattr(self, "_gamma_star"):
-             raise ValueError("This ComBatModel instance is not fitted yet. Call 'fit' before 'transform'.")
+             raise ValueError(
+                 "This ComBatModel instance is not fitted yet. Call 'fit' before 'transform'."
+             )
          if not isinstance(X, pd.DataFrame):
              X = pd.DataFrame(X)
          idx = X.index
@@ -512,11 +1018,7 @@ class ComBatModel:
          else:
              raise ValueError(f"Unknown method: {method}.")
 
-     def _transform_johnson(
-         self,
-         X: pd.DataFrame,
-         batch: pd.Series
-     ) -> pd.DataFrame:
+     def _transform_johnson(self, X: pd.DataFrame, batch: pd.Series) -> pd.DataFrame:
          """Johnson transform implementation."""
          pooled = self._pooled_var
          grand = self._grand_mean
@@ -534,10 +1036,7 @@ class ComBatModel:
 
              g = self._gamma_star[i]
              d = self._delta_star[i]
-             if self.mean_only:
-                 Xb = Xs.loc[idx] - g
-             else:
-                 Xb = (Xs.loc[idx] - g) / np.sqrt(d)
+             Xb = Xs.loc[idx] - g if self.mean_only else (Xs.loc[idx] - g) / np.sqrt(d)
              X_adj.loc[idx] = (Xb * np.sqrt(pooled) + grand).values
          return X_adj
 
@@ -545,8 +1044,8 @@ class ComBatModel:
          self,
          X: pd.DataFrame,
          batch: pd.Series,
-         disc: Optional[pd.DataFrame],
-         cont: Optional[pd.DataFrame],
+         disc: pd.DataFrame | None,
+         cont: pd.DataFrame | None,
      ) -> pd.DataFrame:
          """Fortin transform implementation."""
          batch_dummies = pd.get_dummies(batch, drop_first=False).astype(float)[self._batch_levels]
@@ -555,21 +1054,14 @@ class ComBatModel:
 
          parts = [batch_dummies]
          if disc is not None:
-             parts.append(
-                 pd.get_dummies(
-                     disc.astype("category"), drop_first=True
-                 ).astype(float)
-             )
+             parts.append(pd.get_dummies(disc.astype("category"), drop_first=True).astype(float))
          if cont is not None:
              parts.append(cont.astype(float))
 
          design = pd.concat(parts, axis=1).values
 
          X_np = X.values
-         stand_mu = (
-             self._grand_mean.values +
-             design[:, self._n_batch:] @ self._beta_hat_nonbatch
-         )
+         stand_mu = self._grand_mean.values + design[:, self._n_batch :] @ self._beta_hat_nonbatch
          Xs = (X_np - stand_mu) / np.sqrt(self._pooled_var.values)
 
          for i, lvl in enumerate(self._batch_levels):
@@ -587,18 +1079,15 @@ class ComBatModel:
              else:
                  Xs[idx] = (Xs[idx] - g) / np.sqrt(d)
 
-         X_adj = (
-             Xs * np.sqrt(self._pooled_var.values) +
-             stand_mu
-         )
+         X_adj = Xs * np.sqrt(self._pooled_var.values) + stand_mu
          return pd.DataFrame(X_adj, index=X.index, columns=X.columns, dtype=float)
-
+
      def _transform_chen(
          self,
          X: pd.DataFrame,
          batch: pd.Series,
-         disc: Optional[pd.DataFrame],
-         cont: Optional[pd.DataFrame],
+         disc: pd.DataFrame | None,
+         cont: pd.DataFrame | None,
      ) -> pd.DataFrame:
          """Chen transform implementation."""
          X_meanvar_adj = self._transform_fortin(X, batch, disc, cont)
@@ -635,14 +1124,15 @@ class ComBat(BaseEstimator, TransformerMixin):
          self,
          batch: ArrayLike,
          *,
-         discrete_covariates: Optional[ArrayLike] = None,
-         continuous_covariates: Optional[ArrayLike] = None,
+         discrete_covariates: ArrayLike | None = None,
+         continuous_covariates: ArrayLike | None = None,
          method: str = "johnson",
          parametric: bool = True,
          mean_only: bool = False,
-         reference_batch: Optional[str] = None,
+         reference_batch: str | None = None,
          eps: float = 1e-8,
-         covbat_cov_thresh: Union[float, int] = 0.9,
+         covbat_cov_thresh: float | int = 0.9,
+         compute_metrics: bool = False,
      ) -> None:
          self.batch = batch
          self.discrete_covariates = discrete_covariates
@@ -653,6 +1143,7 @@ class ComBat(BaseEstimator, TransformerMixin):
          self.reference_batch = reference_batch
          self.eps = eps
          self.covbat_cov_thresh = covbat_cov_thresh
+         self.compute_metrics = compute_metrics
          self._model = ComBatModel(
              method=method,
              parametric=parametric,
@@ -662,11 +1153,7 @@ class ComBat(BaseEstimator, TransformerMixin):
              covbat_cov_thresh=covbat_cov_thresh,
          )
 
-     def fit(
-         self,
-         X: ArrayLike,
-         y: Optional[ArrayLike] = None
-     ) -> "ComBat":
+     def fit(self, X: ArrayLike, y: ArrayLike | None = None) -> ComBat:
          """Fit the ComBat model."""
          idx = X.index if isinstance(X, pd.DataFrame) else pd.RangeIndex(len(X))
          batch_vec = self._subset(self.batch, idx)
@@ -695,10 +1182,7 @@ class ComBat(BaseEstimator, TransformerMixin):
          )
 
      @staticmethod
-     def _subset(
-         obj: Optional[ArrayLike],
-         idx: pd.Index
-     ) -> Optional[Union[pd.DataFrame, pd.Series]]:
+     def _subset(obj: ArrayLike | None, idx: pd.Index) -> pd.DataFrame | pd.Series | None:
          """Subset array-like object by index."""
          if obj is None:
              return None
@@ -710,20 +1194,221 @@ class ComBat(BaseEstimator, TransformerMixin):
          else:
              return pd.DataFrame(obj, index=idx)
 
+     @property
+     def metrics_(self) -> dict[str, Any] | None:
+         """Return cached metrics from last fit_transform with compute_metrics=True.
+
+         Returns
+         -------
+         dict or None
+             Cached metrics dictionary, or None if no metrics have been computed.
+         """
+         return getattr(self, "_metrics_cache", None)
+
+     def compute_batch_metrics(
+         self,
+         X: ArrayLike,
+         batch: ArrayLike | None = None,
+         *,
+         pca_components: int | None = None,
+         k_neighbors: list[int] | None = None,
+         kbet_k0: int | None = None,
+         lisi_perplexity: int = 30,
+         n_jobs: int = 1,
+     ) -> dict[str, Any]:
+         """
+         Compute batch effect metrics before and after ComBat correction.
+
+         Parameters
+         ----------
+         X : array-like of shape (n_samples, n_features)
+             Input data to evaluate.
+         batch : array-like of shape (n_samples,), optional
+             Batch labels. If None, uses the batch stored at construction.
+         pca_components : int, optional
+             Number of PCA components for dimensionality reduction before
+             computing metrics. If None (default), metrics are computed in
+             the original feature space. Must be less than min(n_samples, n_features).
+         k_neighbors : list of int, default=[5, 10, 50]
+             Values of k for k-NN preservation metric.
+         kbet_k0 : int, optional
+             Neighborhood size for kBET. Default is 10% of samples.
+         lisi_perplexity : int, default=30
+             Perplexity for LISI computation.
+         n_jobs : int, default=1
+             Number of parallel jobs for neighbor computations.
+
+         Returns
+         -------
+         dict
+             Dictionary with three main keys:
+
+             - ``batch_effect``: Silhouette, Davies-Bouldin, kBET, LISI, variance ratio
+               (each with 'before' and 'after' values)
+             - ``preservation``: k-NN preservation fractions, distance correlation
+             - ``alignment``: Centroid distance, Levene statistic (each with
+               'before' and 'after' values)
+
+         Raises
+         ------
+         ValueError
+             If the model is not fitted or if pca_components is invalid.
+         """
+         if not hasattr(self._model, "_gamma_star"):
+             raise ValueError(
+                 "This ComBat instance is not fitted yet. Call 'fit' before 'compute_batch_metrics'."
+             )
+
+         if not isinstance(X, pd.DataFrame):
+             X = pd.DataFrame(X)
+
+         idx = X.index
+
+         if batch is None:
+             batch_vec = self._subset(self.batch, idx)
+         else:
+             if isinstance(batch, (pd.Series, pd.DataFrame)):
+                 batch_vec = batch.loc[idx] if hasattr(batch, "loc") else batch
+             elif isinstance(batch, np.ndarray):
+                 batch_vec = pd.Series(batch, index=idx)
+             else:
+                 batch_vec = pd.Series(batch, index=idx)
+
+         batch_labels = np.array(batch_vec)
+
+         X_before = X.values
+         X_after = self.transform(X).values
+
+         n_samples, n_features = X_before.shape
+         if kbet_k0 is None:
+             kbet_k0 = max(10, int(0.10 * n_samples))
+         if k_neighbors is None:
+             k_neighbors = [5, 10, 50]
+
+         # Validate and apply PCA if requested
+         if pca_components is not None:
+             max_components = min(n_samples, n_features)
+             if pca_components >= max_components:
+                 raise ValueError(
+                     f"pca_components={pca_components} must be less than "
+                     f"min(n_samples, n_features)={max_components}."
+                 )
+             X_before_pca, X_after_pca, _ = _compute_pca_embedding(X_before, X_after, pca_components)
+         else:
+             X_before_pca = X_before
+             X_after_pca = X_after
+
+         silhouette_before = _silhouette_batch(X_before_pca, batch_labels)
+         silhouette_after = _silhouette_batch(X_after_pca, batch_labels)
+
+         db_before = _davies_bouldin_batch(X_before_pca, batch_labels)
+         db_after = _davies_bouldin_batch(X_after_pca, batch_labels)
+
+         kbet_before, _ = _kbet_score(X_before_pca, batch_labels, kbet_k0)
+         kbet_after, _ = _kbet_score(X_after_pca, batch_labels, kbet_k0)
+
+         lisi_before = _lisi_score(X_before_pca, batch_labels, lisi_perplexity)
+         lisi_after = _lisi_score(X_after_pca, batch_labels, lisi_perplexity)
+
+         var_ratio_before = _variance_ratio(X_before_pca, batch_labels)
+         var_ratio_after = _variance_ratio(X_after_pca, batch_labels)
+
+         knn_results = _knn_preservation(X_before_pca, X_after_pca, k_neighbors, n_jobs)
+         dist_corr = _pairwise_distance_correlation(X_before_pca, X_after_pca)
+
+         centroid_before = _mean_centroid_distance(X_before_pca, batch_labels)
+         centroid_after = _mean_centroid_distance(X_after_pca, batch_labels)
+
+         levene_before = _levene_median_statistic(X_before, batch_labels)
+         levene_after = _levene_median_statistic(X_after, batch_labels)
+
+         n_batches = len(np.unique(batch_labels))
+
+         metrics = {
+             "batch_effect": {
+                 "silhouette": {
+                     "before": silhouette_before,
+                     "after": silhouette_after,
+                 },
+                 "davies_bouldin": {
+                     "before": db_before,
+                     "after": db_after,
+                 },
+                 "kbet": {
+                     "before": kbet_before,
+                     "after": kbet_after,
+                 },
+                 "lisi": {
+                     "before": lisi_before,
+                     "after": lisi_after,
+                     "max_value": n_batches,
+                 },
+                 "variance_ratio": {
+                     "before": var_ratio_before,
+                     "after": var_ratio_after,
+                 },
+             },
+             "preservation": {
+                 "knn": knn_results,
+                 "distance_correlation": dist_corr,
+             },
+             "alignment": {
+                 "centroid_distance": {
+                     "before": centroid_before,
+                     "after": centroid_after,
+                 },
+                 "levene_statistic": {
+                     "before": levene_before,
+                     "after": levene_after,
+                 },
+             },
+         }
+
+         return metrics
+
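A usage sketch for the new method (assuming the class is importable as `from combatlearn import ComBat`; adjust the import if the package exposes it differently):

    import numpy as np
    import pandas as pd
    from combatlearn import ComBat  # assumed public import path

    rng = np.random.default_rng(0)
    X = pd.DataFrame(rng.normal(size=(80, 20)))
    batch = pd.Series(["a"] * 40 + ["b"] * 40, index=X.index)

    cb = ComBat(batch=batch).fit(X)
    metrics = cb.compute_batch_metrics(X, pca_components=10, k_neighbors=[5, 10])
    print(metrics["batch_effect"]["silhouette"])  # {'before': ..., 'after': ...}
    print(metrics["preservation"]["knn"])         # {5: ..., 10: ...}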
+     def fit_transform(self, X: ArrayLike, y: ArrayLike | None = None) -> pd.DataFrame:
+         """
+         Fit and transform the data, optionally computing metrics.
+
+         If ``compute_metrics=True`` was set at construction, batch effect
+         metrics are computed and cached in the ``metrics_`` property.
+
+         Parameters
+         ----------
+         X : array-like of shape (n_samples, n_features)
+             Input data to fit and transform.
+         y : None
+             Ignored. Present for API compatibility.
+
+         Returns
+         -------
+         X_transformed : pd.DataFrame
+             Batch-corrected data.
+         """
+         self.fit(X, y)
+         X_transformed = self.transform(X)
+
+         if self.compute_metrics:
+             self._metrics_cache = self.compute_batch_metrics(X)
+
+         return X_transformed
+
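Continuing the sketch above: with `compute_metrics=True`, the same report is produced as a side effect of `fit_transform` and cached on the `metrics_` property:

    cb = ComBat(batch=batch, method="fortin", compute_metrics=True)
    X_corrected = cb.fit_transform(X)
    print(cb.metrics_["batch_effect"]["kbet"])  # populated by the fit_transform call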
      def plot_transformation(
-         self,
-         X: ArrayLike, *,
-         reduction_method: Literal['pca', 'tsne', 'umap'] = 'pca',
-         n_components: Literal[2, 3] = 2,
-         plot_type: Literal['static', 'interactive'] = 'static',
-         figsize: Tuple[int, int] = (12, 5),
-         alpha: float = 0.7,
-         point_size: int = 50,
-         cmap: str = 'Set1',
-         title: Optional[str] = None,
-         show_legend: bool = True,
-         return_embeddings: bool = False,
-         **reduction_kwargs) -> Union[Any, Tuple[Any, Dict[str, FloatArray]]]:
+         self,
+         X: ArrayLike,
+         *,
+         reduction_method: Literal["pca", "tsne", "umap"] = "pca",
+         n_components: Literal[2, 3] = 2,
+         plot_type: Literal["static", "interactive"] = "static",
+         figsize: tuple[int, int] = (12, 5),
+         alpha: float = 0.7,
+         point_size: int = 50,
+         cmap: str = "Set1",
+         title: str | None = None,
+         show_legend: bool = True,
+         return_embeddings: bool = False,
+         **reduction_kwargs,
+     ) -> Any | tuple[Any, dict[str, FloatArray]]:
          """
          Visualize the ComBat transformation effect using dimensionality reduction.
 
@@ -748,28 +1433,32 @@ class ComBat(BaseEstimator, TransformerMixin):
 
          return_embeddings : bool, default=False
              If `True`, return embeddings along with the plot.
-
+
          **reduction_kwargs : dict
              Additional parameters for reduction methods.
-
+
          Returns
          -------
          fig : matplotlib.figure.Figure or plotly.graph_objects.Figure
              The figure object containing the plots.
-
+
          embeddings : dict, optional
              If `return_embeddings=True`, dictionary with:
              - `'original'`: embedding of original data
              - `'transformed'`: embedding of ComBat-transformed data
          """
          if not hasattr(self._model, "_gamma_star"):
-             raise ValueError("This ComBat instance is not fitted yet. Call 'fit' before 'plot_transformation'.")
+             raise ValueError(
+                 "This ComBat instance is not fitted yet. Call 'fit' before 'plot_transformation'."
+             )
 
          if n_components not in [2, 3]:
              raise ValueError(f"n_components must be 2 or 3, got {n_components}")
-         if reduction_method not in ['pca', 'tsne', 'umap']:
-             raise ValueError(f"reduction_method must be 'pca', 'tsne', or 'umap', got '{reduction_method}'")
-         if plot_type not in ['static', 'interactive']:
+         if reduction_method not in ["pca", "tsne", "umap"]:
+             raise ValueError(
+                 f"reduction_method must be 'pca', 'tsne', or 'umap', got '{reduction_method}'"
+             )
+         if plot_type not in ["static", "interactive"]:
              raise ValueError(f"plot_type must be 'static' or 'interactive', got '{plot_type}'")
 
          if not isinstance(X, pd.DataFrame):
@@ -785,16 +1474,16 @@ class ComBat(BaseEstimator, TransformerMixin):
          X_np = X.values
          X_trans_np = X_transformed.values
 
-         if reduction_method == 'pca':
+         if reduction_method == "pca":
              reducer_orig = PCA(n_components=n_components, **reduction_kwargs)
              reducer_trans = PCA(n_components=n_components, **reduction_kwargs)
-         elif reduction_method == 'tsne':
-             tsne_params = {'perplexity': 30, 'max_iter': 1000, 'random_state': 42}
+         elif reduction_method == "tsne":
+             tsne_params = {"perplexity": 30, "max_iter": 1000, "random_state": 42}
              tsne_params.update(reduction_kwargs)
              reducer_orig = TSNE(n_components=n_components, **tsne_params)
              reducer_trans = TSNE(n_components=n_components, **tsne_params)
          else:
-             umap_params = {'random_state': 42}
+             umap_params = {"random_state": 42}
              umap_params.update(reduction_kwargs)
              reducer_orig = umap.UMAP(n_components=n_components, **umap_params)
              reducer_trans = umap.UMAP(n_components=n_components, **umap_params)
@@ -802,40 +1491,52 @@ class ComBat(BaseEstimator, TransformerMixin):
          X_embedded_orig = reducer_orig.fit_transform(X_np)
          X_embedded_trans = reducer_trans.fit_transform(X_trans_np)
 
-         if plot_type == 'static':
+         if plot_type == "static":
              fig = self._create_static_plot(
-                 X_embedded_orig, X_embedded_trans, batch_vec,
-                 reduction_method, n_components, figsize, alpha,
-                 point_size, cmap, title, show_legend
+                 X_embedded_orig,
+                 X_embedded_trans,
+                 batch_vec,
+                 reduction_method,
+                 n_components,
+                 figsize,
+                 alpha,
+                 point_size,
+                 cmap,
+                 title,
+                 show_legend,
              )
          else:
              fig = self._create_interactive_plot(
-                 X_embedded_orig, X_embedded_trans, batch_vec,
-                 reduction_method, n_components, cmap, title, show_legend
+                 X_embedded_orig,
+                 X_embedded_trans,
+                 batch_vec,
+                 reduction_method,
+                 n_components,
+                 cmap,
+                 title,
+                 show_legend,
              )
 
          if return_embeddings:
-             embeddings = {
-                 'original': X_embedded_orig,
-                 'transformed': X_embedded_trans
-             }
+             embeddings = {"original": X_embedded_orig, "transformed": X_embedded_trans}
              return fig, embeddings
          else:
              return fig
 
      def _create_static_plot(
-         self,
-         X_orig: FloatArray,
-         X_trans: FloatArray,
-         batch_labels: pd.Series,
-         method: str,
-         n_components: int,
-         figsize: Tuple[int, int],
-         alpha: float,
-         point_size: int,
-         cmap: str,
-         title: Optional[str],
-         show_legend: bool) -> Any:
+         self,
+         X_orig: FloatArray,
+         X_trans: FloatArray,
+         batch_labels: pd.Series,
+         method: str,
+         n_components: int,
+         figsize: tuple[int, int],
+         alpha: float,
+         point_size: int,
+         cmap: str,
+         title: str | None,
+         show_legend: bool,
+     ) -> Any:
          """Create static plots using matplotlib."""
 
          fig = plt.figure(figsize=figsize)
@@ -846,119 +1547,130 @@ class ComBat(BaseEstimator, TransformerMixin):
          if n_batches <= 10:
              colors = matplotlib.colormaps.get_cmap(cmap)(np.linspace(0, 1, n_batches))
          else:
-             colors = matplotlib.colormaps.get_cmap('tab20')(np.linspace(0, 1, n_batches))
+             colors = matplotlib.colormaps.get_cmap("tab20")(np.linspace(0, 1, n_batches))
 
          if n_components == 2:
              ax1 = plt.subplot(1, 2, 1)
              ax2 = plt.subplot(1, 2, 2)
          else:
-             ax1 = fig.add_subplot(121, projection='3d')
-             ax2 = fig.add_subplot(122, projection='3d')
+             ax1 = fig.add_subplot(121, projection="3d")
+             ax2 = fig.add_subplot(122, projection="3d")
 
          for i, batch in enumerate(unique_batches):
              mask = batch_labels == batch
              if n_components == 2:
                  ax1.scatter(
-                     X_orig[mask, 0], X_orig[mask, 1],
+                     X_orig[mask, 0],
+                     X_orig[mask, 1],
                      c=[colors[i]],
                      s=point_size,
-                     alpha=alpha,
-                     label=f'Batch {batch}',
-                     edgecolors='black',
-                     linewidth=0.5
+                     alpha=alpha,
+                     label=f"Batch {batch}",
+                     edgecolors="black",
+                     linewidth=0.5,
                  )
              else:
                  ax1.scatter(
-                     X_orig[mask, 0], X_orig[mask, 1], X_orig[mask, 2],
+                     X_orig[mask, 0],
+                     X_orig[mask, 1],
+                     X_orig[mask, 2],
                      c=[colors[i]],
                      s=point_size,
-                     alpha=alpha,
-                     label=f'Batch {batch}',
-                     edgecolors='black',
-                     linewidth=0.5
+                     alpha=alpha,
+                     label=f"Batch {batch}",
+                     edgecolors="black",
+                     linewidth=0.5,
                  )
 
-         ax1.set_title(f'Before ComBat correction\n({method.upper()})')
-         ax1.set_xlabel(f'{method.upper()}1')
-         ax1.set_ylabel(f'{method.upper()}2')
+         ax1.set_title(f"Before ComBat correction\n({method.upper()})")
+         ax1.set_xlabel(f"{method.upper()}1")
+         ax1.set_ylabel(f"{method.upper()}2")
          if n_components == 3:
-             ax1.set_zlabel(f'{method.upper()}3')
+             ax1.set_zlabel(f"{method.upper()}3")
 
          for i, batch in enumerate(unique_batches):
              mask = batch_labels == batch
              if n_components == 2:
                  ax2.scatter(
-                     X_trans[mask, 0], X_trans[mask, 1],
+                     X_trans[mask, 0],
+                     X_trans[mask, 1],
                      c=[colors[i]],
                      s=point_size,
-                     alpha=alpha,
-                     label=f'Batch {batch}',
-                     edgecolors='black',
-                     linewidth=0.5
+                     alpha=alpha,
+                     label=f"Batch {batch}",
+                     edgecolors="black",
+                     linewidth=0.5,
                  )
              else:
                  ax2.scatter(
-                     X_trans[mask, 0], X_trans[mask, 1], X_trans[mask, 2],
+                     X_trans[mask, 0],
+                     X_trans[mask, 1],
+                     X_trans[mask, 2],
                      c=[colors[i]],
                      s=point_size,
-                     alpha=alpha,
-                     label=f'Batch {batch}',
-                     edgecolors='black',
-                     linewidth=0.5
+                     alpha=alpha,
+                     label=f"Batch {batch}",
+                     edgecolors="black",
+                     linewidth=0.5,
                  )
 
-         ax2.set_title(f'After ComBat correction\n({method.upper()})')
-         ax2.set_xlabel(f'{method.upper()}1')
-         ax2.set_ylabel(f'{method.upper()}2')
+         ax2.set_title(f"After ComBat correction\n({method.upper()})")
+         ax2.set_xlabel(f"{method.upper()}1")
+         ax2.set_ylabel(f"{method.upper()}2")
          if n_components == 3:
-             ax2.set_zlabel(f'{method.upper()}3')
+             ax2.set_zlabel(f"{method.upper()}3")
 
          if show_legend and n_batches <= 20:
-             ax2.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
+             ax2.legend(bbox_to_anchor=(1.05, 1), loc="upper left")
 
          if title is None:
-             title = f'ComBat correction effect visualized with {method.upper()}'
-         fig.suptitle(title, fontsize=14, fontweight='bold')
+             title = f"ComBat correction effect visualized with {method.upper()}"
+         fig.suptitle(title, fontsize=14, fontweight="bold")
 
          plt.tight_layout()
          return fig
 
      def _create_interactive_plot(
-         self,
-         X_orig: FloatArray,
-         X_trans: FloatArray,
-         batch_labels: pd.Series,
-         method: str,
-         n_components: int,
-         cmap: str,
-         title: Optional[str],
-         show_legend: bool) -> Any:
+         self,
+         X_orig: FloatArray,
+         X_trans: FloatArray,
+         batch_labels: pd.Series,
+         method: str,
+         n_components: int,
+         cmap: str,
+         title: str | None,
+         show_legend: bool,
+     ) -> Any:
          """Create interactive plots using plotly."""
          if n_components == 2:
              fig = make_subplots(
-                 rows=1, cols=2,
+                 rows=1,
+                 cols=2,
                  subplot_titles=(
-                     f'Before ComBat correction ({method.upper()})',
-                     f'After ComBat correction ({method.upper()})'
-                 )
+                     f"Before ComBat correction ({method.upper()})",
+                     f"After ComBat correction ({method.upper()})",
+                 ),
              )
          else:
              fig = make_subplots(
-                 rows=1, cols=2,
-                 specs=[[{'type': 'scatter3d'}, {'type': 'scatter3d'}]],
+                 rows=1,
+                 cols=2,
+                 specs=[[{"type": "scatter3d"}, {"type": "scatter3d"}]],
                  subplot_titles=(
-                     f'Before ComBat correction ({method.upper()})',
-                     f'After ComBat correction ({method.upper()})'
-                 )
+                     f"Before ComBat correction ({method.upper()})",
+                     f"After ComBat correction ({method.upper()})",
+                 ),
              )
 
          unique_batches = batch_labels.drop_duplicates()
 
          n_batches = len(unique_batches)
          cmap_func = matplotlib.colormaps.get_cmap(cmap)
-         color_list = [mcolors.to_hex(cmap_func(i / max(n_batches - 1, 1))) for i in range(n_batches)]
+         color_list = [
+             mcolors.to_hex(cmap_func(i / max(n_batches - 1, 1))) for i in range(n_batches)
+         ]
 
-         batch_to_color = dict(zip(unique_batches, color_list))
+         batch_to_color = dict(zip(unique_batches, color_list, strict=True))
 
          for batch in unique_batches:
              mask = batch_labels == batch
@@ -966,72 +1678,86 @@ class ComBat(BaseEstimator, TransformerMixin):
              if n_components == 2:
                  fig.add_trace(
                      go.Scatter(
-                         x=X_orig[mask, 0], y=X_orig[mask, 1],
-                         mode='markers',
-                         name=f'Batch {batch}',
-                         marker=dict(
-                             size=8,
-                             color=batch_to_color[batch],
-                             line=dict(width=1, color='black')
-                         ),
-                         showlegend=False),
-                     row=1, col=1
+                         x=X_orig[mask, 0],
+                         y=X_orig[mask, 1],
+                         mode="markers",
+                         name=f"Batch {batch}",
+                         marker={
+                             "size": 8,
+                             "color": batch_to_color[batch],
+                             "line": {"width": 1, "color": "black"},
+                         },
+                         showlegend=False,
+                     ),
+                     row=1,
+                     col=1,
                  )
 
                  fig.add_trace(
                      go.Scatter(
-                         x=X_trans[mask, 0], y=X_trans[mask, 1],
-                         mode='markers',
-                         name=f'Batch {batch}',
-                         marker=dict(
-                             size=8,
-                             color=batch_to_color[batch],
-                             line=dict(width=1, color='black')
-                         ),
-                         showlegend=show_legend),
-                     row=1, col=2
+                         x=X_trans[mask, 0],
+                         y=X_trans[mask, 1],
+                         mode="markers",
+                         name=f"Batch {batch}",
+                         marker={
+                             "size": 8,
+                             "color": batch_to_color[batch],
+                             "line": {"width": 1, "color": "black"},
+                         },
+                         showlegend=show_legend,
+                     ),
+                     row=1,
+                     col=2,
                  )
              else:
                  fig.add_trace(
                      go.Scatter3d(
-                         x=X_orig[mask, 0], y=X_orig[mask, 1], z=X_orig[mask, 2],
-                         mode='markers',
-                         name=f'Batch {batch}',
-                         marker=dict(
-                             size=5,
-                             color=batch_to_color[batch],
-                             line=dict(width=0.5, color='black')
-                         ),
-                         showlegend=False),
-                     row=1, col=1
+                         x=X_orig[mask, 0],
+                         y=X_orig[mask, 1],
+                         z=X_orig[mask, 2],
+                         mode="markers",
+                         name=f"Batch {batch}",
+                         marker={
+                             "size": 5,
+                             "color": batch_to_color[batch],
+                             "line": {"width": 0.5, "color": "black"},
+                         },
+                         showlegend=False,
+                     ),
+                     row=1,
+                     col=1,
                  )
 
                  fig.add_trace(
                      go.Scatter3d(
-                         x=X_trans[mask, 0], y=X_trans[mask, 1], z=X_trans[mask, 2],
-                         mode='markers',
-                         name=f'Batch {batch}',
-                         marker=dict(
-                             size=5,
-                             color=batch_to_color[batch],
-                             line=dict(width=0.5, color='black')
-                         ),
-                         showlegend=show_legend),
-                     row=1, col=2
+                         x=X_trans[mask, 0],
+                         y=X_trans[mask, 1],
+                         z=X_trans[mask, 2],
+                         mode="markers",
+                         name=f"Batch {batch}",
+                         marker={
+                             "size": 5,
+                             "color": batch_to_color[batch],
+                             "line": {"width": 0.5, "color": "black"},
+                         },
+                         showlegend=show_legend,
+                     ),
+                     row=1,
+                     col=2,
                  )
 
          if title is None:
-             title = f'ComBat correction effect visualized with {method.upper()}'
+             title = f"ComBat correction effect visualized with {method.upper()}"
 
          fig.update_layout(
              title=title,
              title_font_size=16,
              height=600,
              showlegend=show_legend,
-             hovermode='closest'
+             hovermode="closest",
          )
 
-         axis_labels = [f'{method.upper()}{i+1}' for i in range(n_components)]
+         axis_labels = [f"{method.upper()}{i + 1}" for i in range(n_components)]
 
          if n_components == 2:
              fig.update_xaxes(title_text=axis_labels[0])
@@ -1040,7 +1766,7 @@ class ComBat(BaseEstimator, TransformerMixin):
              fig.update_scenes(
                  xaxis_title=axis_labels[0],
                  yaxis_title=axis_labels[1],
-                 zaxis_title=axis_labels[2]
+                 zaxis_title=axis_labels[2],
              )
 
          return fig
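A closing usage sketch for the plotting API touched above (behavior unchanged in 1.1.1, only reformatted); it reuses the fitted `cb` and `X` from the earlier sketches:

    # Static matplotlib figure: side-by-side before/after PCA panels.
    fig = cb.plot_transformation(X, reduction_method="pca", n_components=2)
    fig.savefig("combat_pca.png", dpi=150)

    # Interactive plotly figure: 3-D UMAP.
    fig3d = cb.plot_transformation(
        X, reduction_method="umap", n_components=3, plot_type="interactive"
    )
    fig3d.show()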