combatlearn 1.1.0-py3-none-any.whl → 1.1.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
combatlearn/combat.py CHANGED
@@ -8,36 +8,39 @@
 `ComBat` makes the model compatible with scikit-learn by stashing
 the batch (and optional covariates) at construction.
 """
+
 from __future__ import annotations
 
+import warnings
+from typing import Any, Literal
+
+import matplotlib
+import matplotlib.colors as mcolors
+import matplotlib.pyplot as plt
 import numpy as np
 import numpy.linalg as la
+import numpy.typing as npt
 import pandas as pd
+import plotly.graph_objects as go
+import umap
+from plotly.subplots import make_subplots
+from scipy.spatial.distance import pdist
+from scipy.stats import chi2, levene, spearmanr
 from sklearn.base import BaseEstimator, TransformerMixin
 from sklearn.decomposition import PCA
 from sklearn.manifold import TSNE
+from sklearn.metrics import davies_bouldin_score, silhouette_score
 from sklearn.neighbors import NearestNeighbors
-from sklearn.metrics import silhouette_score, davies_bouldin_score
-from scipy.stats import levene, spearmanr, chi2
-from scipy.spatial.distance import pdist
-import matplotlib
-import matplotlib.pyplot as plt
-import matplotlib.colors as mcolors
-from typing import Literal, Optional, Union, Dict, Tuple, Any, List
-import numpy.typing as npt
-import warnings
-import umap
-import plotly.graph_objects as go
-from plotly.subplots import make_subplots
 
-ArrayLike = Union[pd.DataFrame, pd.Series, npt.NDArray[Any]]
+ArrayLike = pd.DataFrame | pd.Series | npt.NDArray[Any]
 FloatArray = npt.NDArray[np.float64]
 
+
 def _compute_pca_embedding(
     X_before: np.ndarray,
     X_after: np.ndarray,
     n_components: int,
-) -> Tuple[np.ndarray, np.ndarray, PCA]:
+) -> tuple[np.ndarray, np.ndarray, PCA]:
     """
     Compute PCA embeddings for both datasets.
 
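Worth noting on this import hunk: the move from `typing.Union`/`Optional` to PEP 604 `X | Y` unions, together with the `zip(..., strict=True)` calls later in this diff, effectively raises the floor to Python 3.10. `from __future__ import annotations` only defers evaluation of annotations; the module-level `ArrayLike` alias is an ordinary assignment and runs at import time. A minimal sketch of the point (the `as_frame` helper is illustrative, not from the package):

```python
import numpy as np
import numpy.typing as npt
import pandas as pd
from typing import Any

# Evaluated at import time: builds a types.UnionType (PEP 604, Python >= 3.10).
# On Python 3.9 this line raises TypeError even with the annotations future import.
ArrayLike = pd.DataFrame | pd.Series | npt.NDArray[Any]

def as_frame(x: ArrayLike) -> pd.DataFrame:
    """Coerce any accepted array-like into a DataFrame (hypothetical helper)."""
    return x if isinstance(x, pd.DataFrame) else pd.DataFrame(np.asarray(x))
```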
@@ -91,7 +94,7 @@ def _silhouette_batch(X: np.ndarray, batch_labels: np.ndarray) -> float:
     if len(unique_batches) < 2:
         return 0.0
     try:
-        return silhouette_score(X, batch_labels, metric='euclidean')
+        return silhouette_score(X, batch_labels, metric="euclidean")
     except Exception:
         return 0.0
 
@@ -129,7 +132,7 @@ def _kbet_score(
     batch_labels: np.ndarray,
     k0: int,
     alpha: float = 0.05,
-) -> Tuple[float, float]:
+) -> tuple[float, float]:
     """
     Compute kBET (k-nearest neighbor Batch Effect Test) acceptance rate.
 
@@ -166,7 +169,7 @@ def _kbet_score(
     global_freq = batch_counts / n_samples
     k0 = min(k0, n_samples - 1)
 
-    nn = NearestNeighbors(n_neighbors=k0 + 1, algorithm='auto')
+    nn = NearestNeighbors(n_neighbors=k0 + 1, algorithm="auto")
     nn.fit(X)
     _, indices = nn.kneighbors(X)
 
@@ -175,7 +178,7 @@ def _kbet_score(
     batch_to_idx = {b: i for i, b in enumerate(unique_batches)}
 
     for i in range(n_samples):
-        neighbors = indices[i, 1:k0+1]
+        neighbors = indices[i, 1 : k0 + 1]
         neighbor_batches = batch_labels[neighbors]
 
         observed = np.zeros(n_batches)
@@ -188,7 +191,7 @@ def _kbet_score(
         if mask.sum() < 2:
             continue
 
-        stat = np.sum((observed[mask] - expected[mask])**2 / expected[mask])
+        stat = np.sum((observed[mask] - expected[mask]) ** 2 / expected[mask])
         df = max(1, mask.sum() - 1)
         p_val = 1 - chi2.cdf(stat, df)
 
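The kBET loop above reduces to a per-neighborhood chi-squared goodness-of-fit test against the global batch frequencies. A self-contained sketch with invented numbers (not the package's code):

```python
import numpy as np
from scipy.stats import chi2

k0 = 20                                  # neighborhood size
global_freq = np.array([0.5, 0.3, 0.2])  # batch shares in the full dataset
observed = np.array([14.0, 4.0, 2.0])    # batch counts among one sample's k0 neighbors

expected = global_freq * k0              # [10, 6, 4]
stat = np.sum((observed - expected) ** 2 / expected)
df = len(expected) - 1
p_val = 1 - chi2.cdf(stat, df)
# kBET's acceptance rate is the fraction of samples whose neighborhood passes
# this test (p_val >= alpha); well-mixed batches give high acceptance.
print(f"chi2={stat:.2f}, p={p_val:.3f}")  # chi2=3.27, p~0.195
```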
@@ -230,7 +233,7 @@ def _find_sigma(distances: np.ndarray, target_perplexity: float, tol: float = 1e
     sigma = 1.0
 
     for _ in range(50):
-        P = np.exp(-distances**2 / (2 * sigma**2 + 1e-10))
+        P = np.exp(-(distances**2) / (2 * sigma**2 + 1e-10))
         P_sum = P.sum()
         if P_sum < 1e-10:
             sigma = (sigma + sigma_max) / 2
@@ -241,7 +244,7 @@ def _find_sigma(distances: np.ndarray, target_perplexity: float, tol: float = 1e
 
         if abs(H - target_H) < tol:
             break
-        elif H < target_H:
+        elif target_H > H:
             sigma_min = sigma
         else:
             sigma_max = sigma
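`_find_sigma` is a bounded bisection on the Gaussian bandwidth: when the entropy H of the neighbor distribution falls below the target, sigma widens; when above, it narrows. An illustrative reimplementation, under the assumption (not shown in this diff) that the target entropy is the natural log of the requested perplexity:

```python
import numpy as np

def find_sigma(distances: np.ndarray, target_perplexity: float, tol: float = 1e-5) -> float:
    """Bisection sketch in the spirit of _find_sigma; not the package's exact code."""
    target_H = np.log(target_perplexity)  # assumed entropy target
    sigma_min, sigma_max, sigma = 1e-10, 1e10, 1.0
    for _ in range(50):
        P = np.exp(-(distances**2) / (2 * sigma**2 + 1e-10))
        P = P / max(P.sum(), 1e-10)
        H = -np.sum(P * np.log(P + 1e-10))  # Shannon entropy of neighbor weights
        if abs(H - target_H) < tol:
            break
        if H < target_H:
            sigma_min = sigma  # too peaked: widen the bandwidth
        else:
            sigma_max = sigma  # too flat: narrow it
        sigma = (sigma_min + sigma_max) / 2
    return sigma
```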
@@ -287,7 +290,7 @@ def _lisi_score(
 
     k = min(3 * perplexity, n_samples - 1)
 
-    nn = NearestNeighbors(n_neighbors=k + 1, algorithm='auto')
+    nn = NearestNeighbors(n_neighbors=k + 1, algorithm="auto")
     nn.fit(X)
     distances, indices = nn.kneighbors(X)
 
@@ -299,7 +302,7 @@ def _lisi_score(
     for i in range(n_samples):
         sigma = _find_sigma(distances[i], perplexity)
 
-        P = np.exp(-distances[i]**2 / (2 * sigma**2 + 1e-10))
+        P = np.exp(-(distances[i] ** 2) / (2 * sigma**2 + 1e-10))
         P_sum = P.sum()
         if P_sum < 1e-10:
             lisi_values.append(1.0)
@@ -312,10 +315,7 @@ def _lisi_score(
             batch_probs[batch_to_idx[nb]] += P[j]
 
         simpson = np.sum(batch_probs**2)
-        if simpson < 1e-10:
-            lisi = n_batches
-        else:
-            lisi = 1.0 / simpson
+        lisi = n_batches if simpson < 1e-10 else 1.0 / simpson
         lisi_values.append(lisi)
 
     return np.mean(lisi_values)
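The branch collapsed above computes LISI as the inverse Simpson index of each neighborhood's batch mixture, ranging from 1 (a single batch) to the number of batches (perfect mixing). In miniature, with invented weights:

```python
import numpy as np

batch_probs = np.array([0.70, 0.20, 0.10])  # Gaussian-weighted batch shares around one sample
simpson = np.sum(batch_probs**2)            # 0.54
lisi = 1.0 / simpson                        # ~1.85 "effective batches" seen locally
```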
@@ -358,11 +358,11 @@ def _variance_ratio(X: np.ndarray, batch_labels: np.ndarray) -> float:
         X_batch = X[mask]
         batch_mean = np.mean(X_batch, axis=0)
 
-        between_var += n_b * np.sum((batch_mean - grand_mean)**2)
-        within_var += np.sum((X_batch - batch_mean)**2)
+        between_var += n_b * np.sum((batch_mean - grand_mean) ** 2)
+        within_var += np.sum((X_batch - batch_mean) ** 2)
 
-    between_var /= (n_batches - 1)
-    within_var /= (n_samples - n_batches)
+    between_var /= n_batches - 1
+    within_var /= n_samples - n_batches
 
     if within_var < 1e-10:
         return 0.0
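The divisors reformatted here are the classic one-way ANOVA decomposition over batches: between-batch mean square over within-batch mean square, so larger values mean a stronger batch effect. A toy computation (illustrative data only):

```python
import numpy as np

X = np.array([[0.0], [0.2], [1.0], [1.2]])  # 4 samples, 1 feature
labels = np.array([0, 0, 1, 1])
grand_mean = X.mean(axis=0)
batches = np.unique(labels)

between = sum(
    (labels == b).sum() * np.sum((X[labels == b].mean(axis=0) - grand_mean) ** 2)
    for b in batches
) / (len(batches) - 1)                      # between-batch mean square: 1.0
within = sum(
    np.sum((X[labels == b] - X[labels == b].mean(axis=0)) ** 2) for b in batches
) / (len(X) - len(batches))                 # within-batch mean square: 0.02
print(between / within)                     # 50.0 -> strong batch separation
```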
@@ -373,9 +373,9 @@ def _variance_ratio(X: np.ndarray, batch_labels: np.ndarray) -> float:
 def _knn_preservation(
     X_before: np.ndarray,
     X_after: np.ndarray,
-    k_values: List[int],
+    k_values: list[int],
     n_jobs: int = 1,
-) -> Dict[int, float]:
+) -> dict[int, float]:
     """
     Compute fraction of k-nearest neighbors preserved after correction.
 
@@ -402,11 +402,11 @@ def _knn_preservation(
     max_k = max(k_values)
     max_k = min(max_k, X_before.shape[0] - 1)
 
-    nn_before = NearestNeighbors(n_neighbors=max_k + 1, algorithm='auto', n_jobs=n_jobs)
+    nn_before = NearestNeighbors(n_neighbors=max_k + 1, algorithm="auto", n_jobs=n_jobs)
     nn_before.fit(X_before)
     _, indices_before = nn_before.kneighbors(X_before)
 
-    nn_after = NearestNeighbors(n_neighbors=max_k + 1, algorithm='auto', n_jobs=n_jobs)
+    nn_after = NearestNeighbors(n_neighbors=max_k + 1, algorithm="auto", n_jobs=n_jobs)
     nn_after.fit(X_after)
     _, indices_after = nn_after.kneighbors(X_after)
 
@@ -417,8 +417,8 @@ def _knn_preservation(
 
         overlaps = []
         for i in range(X_before.shape[0]):
-            neighbors_before = set(indices_before[i, 1:k+1])
-            neighbors_after = set(indices_after[i, 1:k+1])
+            neighbors_before = set(indices_before[i, 1 : k + 1])
+            neighbors_after = set(indices_after[i, 1 : k + 1])
             overlap = len(neighbors_before & neighbors_after) / k
             overlaps.append(overlap)
 
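The slices reformatted above feed the preservation score: for each sample, the fraction of its k nearest neighbors that are identical before and after correction (averaged over samples, then reported per k). In miniature, with hypothetical neighbor sets:

```python
neighbors_before = {3, 7, 12, 19, 24}  # k=5 nearest neighbors in the raw data
neighbors_after = {3, 7, 12, 40, 51}   # the same sample's neighbors after correction
k = 5
overlap = len(neighbors_before & neighbors_after) / k  # 0.6
# A mean overlap of 1.0 across samples means local structure was untouched.
```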
@@ -463,8 +463,8 @@ def _pairwise_distance_correlation(
         X_before = X_before[idx]
         X_after = X_after[idx]
 
-    dist_before = pdist(X_before, metric='euclidean')
-    dist_after = pdist(X_after, metric='euclidean')
+    dist_before = pdist(X_before, metric="euclidean")
+    dist_after = pdist(X_after, metric="euclidean")
 
     if len(dist_before) == 0:
         return 1.0
@@ -508,7 +508,7 @@ def _mean_centroid_distance(X: np.ndarray, batch_labels: np.ndarray) -> float:
         centroids.append(centroid)
 
     centroids = np.array(centroids)
-    distances = pdist(centroids, metric='euclidean')
+    distances = pdist(centroids, metric="euclidean")
 
     return np.mean(distances)
 
@@ -542,7 +542,7 @@ def _levene_median_statistic(X: np.ndarray, batch_labels: np.ndarray) -> float:
         if len(groups) < 2:
             continue
         try:
-            stat, _ = levene(*groups, center='median')
+            stat, _ = levene(*groups, center="median")
             if not np.isnan(stat):
                 levene_stats.append(stat)
         except Exception:
@@ -583,24 +583,24 @@ class ComBatModel:
         method: Literal["johnson", "fortin", "chen"] = "johnson",
         parametric: bool = True,
         mean_only: bool = False,
-        reference_batch: Optional[str] = None,
+        reference_batch: str | None = None,
         eps: float = 1e-8,
-        covbat_cov_thresh: Union[float, int] = 0.9,
+        covbat_cov_thresh: float | int = 0.9,
     ) -> None:
         self.method: str = method
         self.parametric: bool = parametric
         self.mean_only: bool = bool(mean_only)
-        self.reference_batch: Optional[str] = reference_batch
+        self.reference_batch: str | None = reference_batch
         self.eps: float = float(eps)
-        self.covbat_cov_thresh: Union[float, int] = covbat_cov_thresh
+        self.covbat_cov_thresh: float | int = covbat_cov_thresh
 
         self._batch_levels: pd.Index
         self._grand_mean: pd.Series
         self._pooled_var: pd.Series
         self._gamma_star: FloatArray
         self._delta_star: FloatArray
-        self._n_per_batch: Dict[str, int]
-        self._reference_batch_idx: Optional[int]
+        self._n_per_batch: dict[str, int]
+        self._reference_batch_idx: int | None
         self._beta_hat_nonbatch: FloatArray
         self._n_batch: int
         self._p_design: int
@@ -621,26 +621,15 @@ class ComBatModel:
             raise TypeError("covbat_cov_thresh must be float or int.")
 
     @staticmethod
-    def _as_series(
-        arr: ArrayLike,
-        index: pd.Index,
-        name: str
-    ) -> pd.Series:
+    def _as_series(arr: ArrayLike, index: pd.Index, name: str) -> pd.Series:
         """Convert array-like to categorical Series with validation."""
-        if isinstance(arr, pd.Series):
-            ser = arr.copy()
-        else:
-            ser = pd.Series(arr, index=index, name=name)
+        ser = arr.copy() if isinstance(arr, pd.Series) else pd.Series(arr, index=index, name=name)
         if not ser.index.equals(index):
             raise ValueError(f"`{name}` index mismatch with `X`.")
         return ser.astype("category")
 
     @staticmethod
-    def _to_df(
-        arr: Optional[ArrayLike],
-        index: pd.Index,
-        name: str
-    ) -> Optional[pd.DataFrame]:
+    def _to_df(arr: ArrayLike | None, index: pd.Index, name: str) -> pd.DataFrame | None:
         """Convert array-like to DataFrame."""
         if arr is None:
             return None
@@ -655,11 +644,11 @@ class ComBatModel:
     def fit(
         self,
         X: ArrayLike,
-        y: Optional[ArrayLike] = None,
+        y: ArrayLike | None = None,
         *,
         batch: ArrayLike,
-        discrete_covariates: Optional[ArrayLike] = None,
-        continuous_covariates: Optional[ArrayLike] = None,
+        discrete_covariates: ArrayLike | None = None,
+        continuous_covariates: ArrayLike | None = None,
     ) -> ComBatModel:
         """Fit the ComBat model."""
         method = self.method.lower()
@@ -681,9 +670,7 @@ class ComBatModel:
 
         if method == "johnson":
             if disc is not None or cont is not None:
-                warnings.warn(
-                    "Covariates are ignored when using method='johnson'."
-                )
+                warnings.warn("Covariates are ignored when using method='johnson'.", stacklevel=2)
             self._fit_johnson(X, batch)
         elif method == "fortin":
             self._fit_fortin(X, batch, disc, cont)
@@ -691,11 +678,7 @@ class ComBatModel:
             self._fit_chen(X, batch, disc, cont)
         return self
 
-    def _fit_johnson(
-        self,
-        X: pd.DataFrame,
-        batch: pd.Series
-    ) -> None:
+    def _fit_johnson(self, X: pd.DataFrame, batch: pd.Series) -> None:
         """Johnson et al. (2007) ComBat."""
         self._batch_levels = batch.cat.categories
         pooled_var = X.var(axis=0, ddof=1) + self.eps
@@ -703,10 +686,10 @@ class ComBatModel:
 
         Xs = (X - grand_mean) / np.sqrt(pooled_var)
 
-        n_per_batch: Dict[str, int] = {}
+        n_per_batch: dict[str, int] = {}
         gamma_hat: list[npt.NDArray[np.float64]] = []
         delta_hat: list[npt.NDArray[np.float64]] = []
-
+
         for lvl in self._batch_levels:
             idx = batch == lvl
             n_b = int(idx.sum())
@@ -751,8 +734,8 @@
         self,
         X: pd.DataFrame,
         batch: pd.Series,
-        disc: Optional[pd.DataFrame],
-        cont: Optional[pd.DataFrame],
+        disc: pd.DataFrame | None,
+        cont: pd.DataFrame | None,
     ) -> None:
         """Fortin et al. (2018) neuroComBat."""
         self._batch_levels = batch.cat.categories
@@ -770,11 +753,7 @@ class ComBatModel:
 
         parts: list[pd.DataFrame] = [batch_dummies]
         if disc is not None:
-            parts.append(
-                pd.get_dummies(
-                    disc.astype("category"), drop_first=True
-                ).astype(float)
-            )
+            parts.append(pd.get_dummies(disc.astype("category"), drop_first=True).astype(float))
 
         if cont is not None:
             parts.append(cont.astype(float))
@@ -789,7 +768,7 @@ class ComBatModel:
         self._beta_hat_nonbatch = beta_hat[n_batch:]
 
         n_per_batch = batch.value_counts().sort_index().astype(int).values
-        self._n_per_batch = dict(zip(self._batch_levels, n_per_batch))
+        self._n_per_batch = dict(zip(self._batch_levels, n_per_batch, strict=True))
 
         if self.reference_batch is not None:
             ref_idx = list(self._batch_levels).index(self.reference_batch)
@@ -807,30 +786,25 @@ class ComBatModel:
         else:
             resid = X_np - design @ beta_hat
             denom = n_samples
-        var_pooled = (resid ** 2).sum(axis=0) / denom + self.eps
+        var_pooled = (resid**2).sum(axis=0) / denom + self.eps
         self._pooled_var = pd.Series(var_pooled, index=X.columns)
 
         stand_mean = grand_mean + design[:, n_batch:] @ self._beta_hat_nonbatch
         Xs = (X_np - stand_mean) / np.sqrt(var_pooled)
 
-        gamma_hat = np.vstack(
-            [Xs[batch == lvl].mean(axis=0) for lvl in self._batch_levels]
-        )
+        gamma_hat = np.vstack([Xs[batch == lvl].mean(axis=0) for lvl in self._batch_levels])
         delta_hat = np.vstack(
-            [Xs[batch == lvl].var(axis=0, ddof=1) + self.eps
-             for lvl in self._batch_levels]
+            [Xs[batch == lvl].var(axis=0, ddof=1) + self.eps for lvl in self._batch_levels]
         )
 
         if self.mean_only:
             gamma_star = self._shrink_gamma(
-                gamma_hat, delta_hat, n_per_batch,
-                parametric = self.parametric
+                gamma_hat, delta_hat, n_per_batch, parametric=self.parametric
             )
             delta_star = np.ones_like(delta_hat)
         else:
             gamma_star, delta_star = self._shrink_gamma_delta(
-                gamma_hat, delta_hat, n_per_batch,
-                parametric = self.parametric
+                gamma_hat, delta_hat, n_per_batch, parametric=self.parametric
            )
 
         if ref_idx is not None:
@@ -841,15 +815,15 @@ class ComBatModel:
 
         self._gamma_star = gamma_star
         self._delta_star = delta_star
-        self._n_batch  = n_batch
+        self._n_batch = n_batch
         self._p_design = p_design
-
+
     def _fit_chen(
         self,
         X: pd.DataFrame,
         batch: pd.Series,
-        disc: Optional[pd.DataFrame],
-        cont: Optional[pd.DataFrame],
+        disc: pd.DataFrame | None,
+        cont: pd.DataFrame | None,
     ) -> None:
         """Chen et al. (2022) CovBat."""
         self._fit_fortin(X, batch, disc, cont)
@@ -868,7 +842,7 @@ class ComBatModel:
         self._covbat_n_pc = n_pc
 
         scores = pca.transform(X_centered)[:, :n_pc]
-        scores_df = pd.DataFrame(scores, index=X.index, columns=[f"PC{i+1}" for i in range(n_pc)])
+        scores_df = pd.DataFrame(scores, index=X.index, columns=[f"PC{i + 1}" for i in range(n_pc)])
         self._batch_levels_pc = self._batch_levels
         n_per_batch = self._n_per_batch
 
@@ -907,12 +881,12 @@ class ComBatModel:
         self,
         gamma_hat: FloatArray,
         delta_hat: FloatArray,
-        n_per_batch: Union[Dict[str, int], FloatArray],
+        n_per_batch: dict[str, int] | FloatArray,
         *,
         parametric: bool,
         max_iter: int = 100,
         tol: float = 1e-4,
-    ) -> Tuple[FloatArray, FloatArray]:
+    ) -> tuple[FloatArray, FloatArray]:
         """Empirical Bayes shrinkage estimation."""
         if parametric:
             gamma_bar = gamma_hat.mean(axis=0)
@@ -920,10 +894,14 @@ class ComBatModel:
             a_prior = (delta_hat.mean(axis=0) ** 2) / delta_hat.var(axis=0, ddof=1) + 2
             b_prior = delta_hat.mean(axis=0) * (a_prior - 1)
 
-            B, p = gamma_hat.shape
+            B, _p = gamma_hat.shape
             gamma_star = np.empty_like(gamma_hat)
             delta_star = np.empty_like(delta_hat)
-            n_vec = np.array(list(n_per_batch.values())) if isinstance(n_per_batch, dict) else n_per_batch
+            n_vec = (
+                np.array(list(n_per_batch.values()))
+                if isinstance(n_per_batch, dict)
+                else n_per_batch
+            )
 
             for i in range(B):
                 n_i = n_vec[i]
@@ -937,8 +915,12 @@ class ComBatModel:
             return gamma_star, delta_star
 
         else:
-            B, p = gamma_hat.shape
-            n_vec = np.array(list(n_per_batch.values())) if isinstance(n_per_batch, dict) else n_per_batch
+            B, _p = gamma_hat.shape
+            n_vec = (
+                np.array(list(n_per_batch.values()))
+                if isinstance(n_per_batch, dict)
+                else n_per_batch
+            )
             gamma_bar = gamma_hat.mean(axis=0)
             t2 = gamma_hat.var(axis=0, ddof=1)
 
@@ -947,27 +929,22 @@ class ComBatModel:
                 g_bar: FloatArray,
                 n: float,
                 d_star: FloatArray,
-                t2_: FloatArray
+                t2_: FloatArray,
             ) -> FloatArray:
                 return (t2_ * n * g_hat + d_star * g_bar) / (t2_ * n + d_star)
 
-            def postvar(
-                sum2: FloatArray,
-                n: float,
-                a: FloatArray,
-                b: FloatArray
-            ) -> FloatArray:
+            def postvar(sum2: FloatArray, n: float, a: FloatArray, b: FloatArray) -> FloatArray:
                 return (0.5 * sum2 + b) / (n / 2.0 + a - 1.0)
 
             def aprior(delta: FloatArray) -> FloatArray:
                 m, s2 = delta.mean(), delta.var()
                 s2 = max(s2, self.eps)
-                return (2 * s2 + m ** 2) / s2
+                return (2 * s2 + m**2) / s2
 
             def bprior(delta: FloatArray) -> FloatArray:
                 m, s2 = delta.mean(), delta.var()
                 s2 = max(s2, self.eps)
-                return (m * s2 + m ** 3) / s2
+                return (m * s2 + m**3) / s2
 
             gamma_star = np.empty_like(gamma_hat)
             delta_star = np.empty_like(delta_hat)
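For orientation, the `postmean` update reformatted in this hunk shrinks each batch's location estimate toward the across-batch mean, with the pull set by the prior variance `t2` and residual variance `d_star` relative to the batch size `n`. A numeric sketch with invented values:

```python
g_hat, g_bar = 0.8, 0.1   # one batch's raw effect estimate vs. the prior mean
t2, d_star, n = 0.05, 0.4, 12.0

postmean = (t2 * n * g_hat + d_star * g_bar) / (t2 * n + d_star)
print(postmean)  # 0.52: partway between g_hat and g_bar
# Larger n or looser priors (big t2) keep the estimate near g_hat;
# small, noisy batches get pulled toward the shared mean g_bar.
```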
@@ -986,7 +963,8 @@ class ComBatModel:
                     sum2 = (n_i - 1) * d_hat_i + n_i * (g_hat_i - g_new) ** 2
                     d_new = postvar(sum2, n_i, a_i, b_i)
                     if np.max(np.abs(g_new - g_prev) / (np.abs(g_prev) + self.eps)) < tol and (
-                        self.mean_only or np.max(np.abs(d_new - d_prev) / (np.abs(d_prev) + self.eps)) < tol
+                        self.mean_only
+                        or np.max(np.abs(d_new - d_prev) / (np.abs(d_prev) + self.eps)) < tol
                     ):
                         break
                     gamma_star[i] = g_new
@@ -997,12 +975,14 @@ class ComBatModel:
         self,
         gamma_hat: FloatArray,
         delta_hat: FloatArray,
-        n_per_batch: Union[Dict[str, int], FloatArray],
+        n_per_batch: dict[str, int] | FloatArray,
         *,
         parametric: bool,
     ) -> FloatArray:
-        """Convenience wrapper that returns only γ⋆ (for *mean-only* mode)."""
-        gamma, _ = self._shrink_gamma_delta(gamma_hat, delta_hat, n_per_batch, parametric=parametric)
+        """Convenience wrapper that returns only gamma* (for *mean-only* mode)."""
+        gamma, _ = self._shrink_gamma_delta(
+            gamma_hat, delta_hat, n_per_batch, parametric=parametric
+        )
         return gamma
 
     def transform(
@@ -1010,12 +990,14 @@ class ComBatModel:
         X: ArrayLike,
         *,
         batch: ArrayLike,
-        discrete_covariates: Optional[ArrayLike] = None,
-        continuous_covariates: Optional[ArrayLike] = None,
+        discrete_covariates: ArrayLike | None = None,
+        continuous_covariates: ArrayLike | None = None,
     ) -> pd.DataFrame:
         """Transform the data using fitted ComBat parameters."""
         if not hasattr(self, "_gamma_star"):
-            raise ValueError("This ComBatModel instance is not fitted yet. Call 'fit' before 'transform'.")
+            raise ValueError(
+                "This ComBatModel instance is not fitted yet. Call 'fit' before 'transform'."
+            )
         if not isinstance(X, pd.DataFrame):
             X = pd.DataFrame(X)
         idx = X.index
1036
1018
  else:
1037
1019
  raise ValueError(f"Unknown method: {method}.")
1038
1020
 
1039
- def _transform_johnson(
1040
- self,
1041
- X: pd.DataFrame,
1042
- batch: pd.Series
1043
- ) -> pd.DataFrame:
1021
+ def _transform_johnson(self, X: pd.DataFrame, batch: pd.Series) -> pd.DataFrame:
1044
1022
  """Johnson transform implementation."""
1045
1023
  pooled = self._pooled_var
1046
1024
  grand = self._grand_mean
@@ -1058,10 +1036,7 @@ class ComBatModel:
1058
1036
 
1059
1037
  g = self._gamma_star[i]
1060
1038
  d = self._delta_star[i]
1061
- if self.mean_only:
1062
- Xb = Xs.loc[idx] - g
1063
- else:
1064
- Xb = (Xs.loc[idx] - g) / np.sqrt(d)
1039
+ Xb = Xs.loc[idx] - g if self.mean_only else (Xs.loc[idx] - g) / np.sqrt(d)
1065
1040
  X_adj.loc[idx] = (Xb * np.sqrt(pooled) + grand).values
1066
1041
  return X_adj
1067
1042
 
@@ -1069,8 +1044,8 @@ class ComBatModel:
1069
1044
  self,
1070
1045
  X: pd.DataFrame,
1071
1046
  batch: pd.Series,
1072
- disc: Optional[pd.DataFrame],
1073
- cont: Optional[pd.DataFrame],
1047
+ disc: pd.DataFrame | None,
1048
+ cont: pd.DataFrame | None,
1074
1049
  ) -> pd.DataFrame:
1075
1050
  """Fortin transform implementation."""
1076
1051
  batch_dummies = pd.get_dummies(batch, drop_first=False).astype(float)[self._batch_levels]
@@ -1079,21 +1054,14 @@ class ComBatModel:
1079
1054
 
1080
1055
  parts = [batch_dummies]
1081
1056
  if disc is not None:
1082
- parts.append(
1083
- pd.get_dummies(
1084
- disc.astype("category"), drop_first=True
1085
- ).astype(float)
1086
- )
1057
+ parts.append(pd.get_dummies(disc.astype("category"), drop_first=True).astype(float))
1087
1058
  if cont is not None:
1088
1059
  parts.append(cont.astype(float))
1089
1060
 
1090
1061
  design = pd.concat(parts, axis=1).values
1091
1062
 
1092
1063
  X_np = X.values
1093
- stand_mu = (
1094
- self._grand_mean.values +
1095
- design[:, self._n_batch:] @ self._beta_hat_nonbatch
1096
- )
1064
+ stand_mu = self._grand_mean.values + design[:, self._n_batch :] @ self._beta_hat_nonbatch
1097
1065
  Xs = (X_np - stand_mu) / np.sqrt(self._pooled_var.values)
1098
1066
 
1099
1067
  for i, lvl in enumerate(self._batch_levels):
@@ -1111,18 +1079,15 @@ class ComBatModel:
             else:
                 Xs[idx] = (Xs[idx] - g) / np.sqrt(d)
 
-        X_adj = (
-            Xs * np.sqrt(self._pooled_var.values) +
-            stand_mu
-        )
+        X_adj = Xs * np.sqrt(self._pooled_var.values) + stand_mu
         return pd.DataFrame(X_adj, index=X.index, columns=X.columns, dtype=float)
-
+
     def _transform_chen(
         self,
         X: pd.DataFrame,
         batch: pd.Series,
-        disc: Optional[pd.DataFrame],
-        cont: Optional[pd.DataFrame],
+        disc: pd.DataFrame | None,
+        cont: pd.DataFrame | None,
     ) -> pd.DataFrame:
         """Chen transform implementation."""
         X_meanvar_adj = self._transform_fortin(X, batch, disc, cont)
@@ -1159,14 +1124,14 @@ class ComBat(BaseEstimator, TransformerMixin):
         self,
         batch: ArrayLike,
         *,
-        discrete_covariates: Optional[ArrayLike] = None,
-        continuous_covariates: Optional[ArrayLike] = None,
+        discrete_covariates: ArrayLike | None = None,
+        continuous_covariates: ArrayLike | None = None,
         method: str = "johnson",
         parametric: bool = True,
         mean_only: bool = False,
-        reference_batch: Optional[str] = None,
+        reference_batch: str | None = None,
         eps: float = 1e-8,
-        covbat_cov_thresh: Union[float, int] = 0.9,
+        covbat_cov_thresh: float | int = 0.9,
         compute_metrics: bool = False,
     ) -> None:
         self.batch = batch
@@ -1188,11 +1153,7 @@ class ComBat(BaseEstimator, TransformerMixin):
             covbat_cov_thresh=covbat_cov_thresh,
         )
 
-    def fit(
-        self,
-        X: ArrayLike,
-        y: Optional[ArrayLike] = None
-    ) -> "ComBat":
+    def fit(self, X: ArrayLike, y: ArrayLike | None = None) -> ComBat:
         """Fit the ComBat model."""
         idx = X.index if isinstance(X, pd.DataFrame) else pd.RangeIndex(len(X))
         batch_vec = self._subset(self.batch, idx)
@@ -1221,10 +1182,7 @@ class ComBat(BaseEstimator, TransformerMixin):
         )
 
     @staticmethod
-    def _subset(
-        obj: Optional[ArrayLike],
-        idx: pd.Index
-    ) -> Optional[Union[pd.DataFrame, pd.Series]]:
+    def _subset(obj: ArrayLike | None, idx: pd.Index) -> pd.DataFrame | pd.Series | None:
         """Subset array-like object by index."""
         if obj is None:
             return None
@@ -1237,7 +1195,7 @@ class ComBat(BaseEstimator, TransformerMixin):
         return pd.DataFrame(obj, index=idx)
 
     @property
-    def metrics_(self) -> Optional[Dict[str, Any]]:
+    def metrics_(self) -> dict[str, Any] | None:
         """Return cached metrics from last fit_transform with compute_metrics=True.
 
         Returns
@@ -1245,19 +1203,19 @@ class ComBat(BaseEstimator, TransformerMixin):
         dict or None
             Cached metrics dictionary, or None if no metrics have been computed.
         """
-        return getattr(self, '_metrics_cache', None)
+        return getattr(self, "_metrics_cache", None)
 
     def compute_batch_metrics(
         self,
         X: ArrayLike,
-        batch: Optional[ArrayLike] = None,
+        batch: ArrayLike | None = None,
         *,
-        pca_components: Optional[int] = None,
-        k_neighbors: List[int] = [5, 10, 50],
-        kbet_k0: Optional[int] = None,
+        pca_components: int | None = None,
+        k_neighbors: list[int] | None = None,
+        kbet_k0: int | None = None,
         lisi_perplexity: int = 30,
         n_jobs: int = 1,
-    ) -> Dict[str, Any]:
+    ) -> dict[str, Any]:
         """
         Compute batch effect metrics before and after ComBat correction.
 
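Taken together with the `metrics_` property above, the intended flow looks roughly like this (a hedged sketch: `ComBat`, `compute_metrics`, `fit_transform`, and `metrics_` are names from this file; the synthetic data and import path are assumptions):

```python
import numpy as np
import pandas as pd
from combatlearn.combat import ComBat

rng = np.random.default_rng(0)
X = pd.DataFrame(rng.normal(size=(60, 5)))
batch = pd.Series(["a", "b", "c"] * 20, index=X.index)

cb = ComBat(batch=batch, compute_metrics=True)
X_corrected = cb.fit_transform(X)                 # metrics computed and cached here
print(cb.metrics_["batch_effect"]["silhouette"])  # {'before': ..., 'after': ...}
```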
 
@@ -1282,25 +1240,14 @@ class ComBat(BaseEstimator, TransformerMixin):
1282
1240
 
1283
1241
  Returns
1284
1242
  -------
1285
- metrics : dict
1286
- Dictionary with structure:
1287
- {
1288
- 'batch_effect': {
1289
- 'silhouette': {'before': float, 'after': float},
1290
- 'davies_bouldin': {...},
1291
- 'kbet': {...},
1292
- 'lisi': {..., 'max_value': n_batches},
1293
- 'variance_ratio': {...},
1294
- },
1295
- 'preservation': {
1296
- 'knn': {k: fraction for k in k_neighbors},
1297
- 'distance_correlation': float,
1298
- },
1299
- 'alignment': {
1300
- 'centroid_distance': {...},
1301
- 'levene_statistic': {...},
1302
- },
1303
- }
1243
+ dict
1244
+ Dictionary with three main keys:
1245
+
1246
+ - ``batch_effect``: Silhouette, Davies-Bouldin, kBET, LISI, variance ratio
1247
+ (each with 'before' and 'after' values)
1248
+ - ``preservation``: k-NN preservation fractions, distance correlation
1249
+ - ``alignment``: Centroid distance, Levene statistic (each with
1250
+ 'before' and 'after' values)
1304
1251
 
1305
1252
  Raises
1306
1253
  ------
@@ -1309,8 +1256,7 @@ class ComBat(BaseEstimator, TransformerMixin):
1309
1256
  """
1310
1257
  if not hasattr(self._model, "_gamma_star"):
1311
1258
  raise ValueError(
1312
- "This ComBat instance is not fitted yet. "
1313
- "Call 'fit' before 'compute_batch_metrics'."
1259
+ "This ComBat instance is not fitted yet. Call 'fit' before 'compute_batch_metrics'."
1314
1260
  )
1315
1261
 
1316
1262
  if not isinstance(X, pd.DataFrame):
@@ -1322,7 +1268,7 @@ class ComBat(BaseEstimator, TransformerMixin):
1322
1268
  batch_vec = self._subset(self.batch, idx)
1323
1269
  else:
1324
1270
  if isinstance(batch, (pd.Series, pd.DataFrame)):
1325
- batch_vec = batch.loc[idx] if hasattr(batch, 'loc') else batch
1271
+ batch_vec = batch.loc[idx] if hasattr(batch, "loc") else batch
1326
1272
  elif isinstance(batch, np.ndarray):
1327
1273
  batch_vec = pd.Series(batch, index=idx)
1328
1274
  else:
@@ -1336,6 +1282,8 @@ class ComBat(BaseEstimator, TransformerMixin):
         n_samples, n_features = X_before.shape
         if kbet_k0 is None:
             kbet_k0 = max(10, int(0.10 * n_samples))
+        if k_neighbors is None:
+            k_neighbors = [5, 10, 50]
 
         # Validate and apply PCA if requested
         if pca_components is not None:
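The sentinel added here completes the removal of the mutable default `k_neighbors=[5, 10, 50]` from the signature earlier in this diff. Python evaluates default values once, at function definition, so a mutable default is shared by every call; the None-sentinel pattern sidesteps that. A minimal illustration (hypothetical functions):

```python
def bad(ks=[5, 10, 50]):
    ks.append(99)         # mutates the one shared default list
    return ks

def good(ks=None):
    if ks is None:
        ks = [5, 10, 50]  # fresh list on every call
    ks.append(99)
    return ks

bad(); print(bad())    # [5, 10, 50, 99, 99] - state leaked across calls
good(); print(good())  # [5, 10, 50, 99]     - calls stay independent
```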
@@ -1345,9 +1293,7 @@ class ComBat(BaseEstimator, TransformerMixin):
                     f"pca_components={pca_components} must be less than "
                     f"min(n_samples, n_features)={max_components}."
                 )
-            X_before_pca, X_after_pca, _ = _compute_pca_embedding(
-                X_before, X_after, pca_components
-            )
+            X_before_pca, X_after_pca, _ = _compute_pca_embedding(X_before, X_after, pca_components)
         else:
             X_before_pca = X_before
             X_after_pca = X_after
@@ -1379,57 +1325,53 @@ class ComBat(BaseEstimator, TransformerMixin):
         n_batches = len(np.unique(batch_labels))
 
         metrics = {
-            'batch_effect': {
-                'silhouette': {
-                    'before': silhouette_before,
-                    'after': silhouette_after,
+            "batch_effect": {
+                "silhouette": {
+                    "before": silhouette_before,
+                    "after": silhouette_after,
                 },
-                'davies_bouldin': {
-                    'before': db_before,
-                    'after': db_after,
+                "davies_bouldin": {
+                    "before": db_before,
+                    "after": db_after,
                 },
-                'kbet': {
-                    'before': kbet_before,
-                    'after': kbet_after,
+                "kbet": {
+                    "before": kbet_before,
+                    "after": kbet_after,
                 },
-                'lisi': {
-                    'before': lisi_before,
-                    'after': lisi_after,
-                    'max_value': n_batches,
+                "lisi": {
+                    "before": lisi_before,
+                    "after": lisi_after,
+                    "max_value": n_batches,
                 },
-                'variance_ratio': {
-                    'before': var_ratio_before,
-                    'after': var_ratio_after,
+                "variance_ratio": {
+                    "before": var_ratio_before,
+                    "after": var_ratio_after,
                 },
             },
-            'preservation': {
-                'knn': knn_results,
-                'distance_correlation': dist_corr,
+            "preservation": {
+                "knn": knn_results,
+                "distance_correlation": dist_corr,
             },
-            'alignment': {
-                'centroid_distance': {
-                    'before': centroid_before,
-                    'after': centroid_after,
+            "alignment": {
+                "centroid_distance": {
+                    "before": centroid_before,
+                    "after": centroid_after,
                 },
-                'levene_statistic': {
-                    'before': levene_before,
-                    'after': levene_after,
+                "levene_statistic": {
+                    "before": levene_before,
+                    "after": levene_after,
                 },
             },
         }
 
         return metrics
 
-    def fit_transform(
-        self,
-        X: ArrayLike,
-        y: Optional[ArrayLike] = None
-    ) -> pd.DataFrame:
+    def fit_transform(self, X: ArrayLike, y: ArrayLike | None = None) -> pd.DataFrame:
         """
         Fit and transform the data, optionally computing metrics.
 
-        If compute_metrics=True was set at construction, batch effect
-        metrics are computed and cached in metrics_ property.
+        If ``compute_metrics=True`` was set at construction, batch effect
+        metrics are computed and cached in the ``metrics_`` property.
 
         Parameters
         ----------
@@ -1452,19 +1394,21 @@ class ComBat(BaseEstimator, TransformerMixin):
         return X_transformed
 
     def plot_transformation(
-            self,
-            X: ArrayLike, *,
-            reduction_method: Literal['pca', 'tsne', 'umap'] = 'pca',
-            n_components: Literal[2, 3] = 2,
-            plot_type: Literal['static', 'interactive'] = 'static',
-            figsize: Tuple[int, int] = (12, 5),
-            alpha: float = 0.7,
-            point_size: int = 50,
-            cmap: str = 'Set1',
-            title: Optional[str] = None,
-            show_legend: bool = True,
-            return_embeddings: bool = False,
-            **reduction_kwargs) -> Union[Any, Tuple[Any, Dict[str, FloatArray]]]:
+        self,
+        X: ArrayLike,
+        *,
+        reduction_method: Literal["pca", "tsne", "umap"] = "pca",
+        n_components: Literal[2, 3] = 2,
+        plot_type: Literal["static", "interactive"] = "static",
+        figsize: tuple[int, int] = (12, 5),
+        alpha: float = 0.7,
+        point_size: int = 50,
+        cmap: str = "Set1",
+        title: str | None = None,
+        show_legend: bool = True,
+        return_embeddings: bool = False,
+        **reduction_kwargs,
+    ) -> Any | tuple[Any, dict[str, FloatArray]]:
         """
         Visualize the ComBat transformation effect using dimensionality reduction.
 
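A usage sketch for the reformatted `plot_transformation` signature (parameter names are from this diff; the data and output file are invented, and it assumes the fitted estimator applies the correction to `X` internally, as the `X_transformed` context lines below suggest):

```python
import numpy as np
import pandas as pd
from combatlearn.combat import ComBat

rng = np.random.default_rng(1)
X = pd.DataFrame(rng.normal(size=(80, 10)))
batch = pd.Series(["site1", "site2"] * 40, index=X.index)

cb = ComBat(batch=batch).fit(X)
fig, embeddings = cb.plot_transformation(
    X,
    reduction_method="pca",  # or "tsne" / "umap"
    n_components=2,
    plot_type="static",      # "interactive" would return a plotly figure
    return_embeddings=True,
)
fig.savefig("combat_before_after.png")  # static path returns a matplotlib Figure
```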
@@ -1489,28 +1433,32 @@ class ComBat(BaseEstimator, TransformerMixin):
 
         return_embeddings : bool, default=False
             If `True`, return embeddings along with the plot.
-
+
         **reduction_kwargs : dict
             Additional parameters for reduction methods.
-
+
         Returns
         -------
         fig : matplotlib.figure.Figure or plotly.graph_objects.Figure
             The figure object containing the plots.
-
+
         embeddings : dict, optional
             If `return_embeddings=True`, dictionary with:
             - `'original'`: embedding of original data
             - `'transformed'`: embedding of ComBat-transformed data
         """
         if not hasattr(self._model, "_gamma_star"):
-            raise ValueError("This ComBat instance is not fitted yet. Call 'fit' before 'plot_transformation'.")
+            raise ValueError(
+                "This ComBat instance is not fitted yet. Call 'fit' before 'plot_transformation'."
+            )
 
         if n_components not in [2, 3]:
             raise ValueError(f"n_components must be 2 or 3, got {n_components}")
-        if reduction_method not in ['pca', 'tsne', 'umap']:
-            raise ValueError(f"reduction_method must be 'pca', 'tsne', or 'umap', got '{reduction_method}'")
-        if plot_type not in ['static', 'interactive']:
+        if reduction_method not in ["pca", "tsne", "umap"]:
+            raise ValueError(
+                f"reduction_method must be 'pca', 'tsne', or 'umap', got '{reduction_method}'"
+            )
+        if plot_type not in ["static", "interactive"]:
             raise ValueError(f"plot_type must be 'static' or 'interactive', got '{plot_type}'")
 
         if not isinstance(X, pd.DataFrame):
@@ -1526,16 +1474,16 @@ class ComBat(BaseEstimator, TransformerMixin):
         X_np = X.values
         X_trans_np = X_transformed.values
 
-        if reduction_method == 'pca':
+        if reduction_method == "pca":
             reducer_orig = PCA(n_components=n_components, **reduction_kwargs)
             reducer_trans = PCA(n_components=n_components, **reduction_kwargs)
-        elif reduction_method == 'tsne':
-            tsne_params = {'perplexity': 30, 'max_iter': 1000, 'random_state': 42}
+        elif reduction_method == "tsne":
+            tsne_params = {"perplexity": 30, "max_iter": 1000, "random_state": 42}
             tsne_params.update(reduction_kwargs)
             reducer_orig = TSNE(n_components=n_components, **tsne_params)
             reducer_trans = TSNE(n_components=n_components, **tsne_params)
         else:
-            umap_params = {'random_state': 42}
+            umap_params = {"random_state": 42}
             umap_params.update(reduction_kwargs)
             reducer_orig = umap.UMAP(n_components=n_components, **umap_params)
             reducer_trans = umap.UMAP(n_components=n_components, **umap_params)
1543
1491
  X_embedded_orig = reducer_orig.fit_transform(X_np)
1544
1492
  X_embedded_trans = reducer_trans.fit_transform(X_trans_np)
1545
1493
 
1546
- if plot_type == 'static':
1494
+ if plot_type == "static":
1547
1495
  fig = self._create_static_plot(
1548
- X_embedded_orig, X_embedded_trans, batch_vec,
1549
- reduction_method, n_components, figsize, alpha,
1550
- point_size, cmap, title, show_legend
1496
+ X_embedded_orig,
1497
+ X_embedded_trans,
1498
+ batch_vec,
1499
+ reduction_method,
1500
+ n_components,
1501
+ figsize,
1502
+ alpha,
1503
+ point_size,
1504
+ cmap,
1505
+ title,
1506
+ show_legend,
1551
1507
  )
1552
1508
  else:
1553
1509
  fig = self._create_interactive_plot(
1554
- X_embedded_orig, X_embedded_trans, batch_vec,
1555
- reduction_method, n_components, cmap, title, show_legend
1510
+ X_embedded_orig,
1511
+ X_embedded_trans,
1512
+ batch_vec,
1513
+ reduction_method,
1514
+ n_components,
1515
+ cmap,
1516
+ title,
1517
+ show_legend,
1556
1518
  )
1557
1519
 
1558
1520
  if return_embeddings:
1559
- embeddings = {
1560
- 'original': X_embedded_orig,
1561
- 'transformed': X_embedded_trans
1562
- }
1521
+ embeddings = {"original": X_embedded_orig, "transformed": X_embedded_trans}
1563
1522
  return fig, embeddings
1564
1523
  else:
1565
1524
  return fig
1566
1525
 
1567
1526
  def _create_static_plot(
1568
- self,
1569
- X_orig: FloatArray,
1570
- X_trans: FloatArray,
1571
- batch_labels: pd.Series,
1572
- method: str,
1573
- n_components: int,
1574
- figsize: Tuple[int, int],
1575
- alpha: float,
1576
- point_size: int,
1577
- cmap: str,
1578
- title: Optional[str],
1579
- show_legend: bool) -> Any:
1527
+ self,
1528
+ X_orig: FloatArray,
1529
+ X_trans: FloatArray,
1530
+ batch_labels: pd.Series,
1531
+ method: str,
1532
+ n_components: int,
1533
+ figsize: tuple[int, int],
1534
+ alpha: float,
1535
+ point_size: int,
1536
+ cmap: str,
1537
+ title: str | None,
1538
+ show_legend: bool,
1539
+ ) -> Any:
1580
1540
  """Create static plots using matplotlib."""
1581
1541
 
1582
1542
  fig = plt.figure(figsize=figsize)
@@ -1587,119 +1547,130 @@ class ComBat(BaseEstimator, TransformerMixin):
1587
1547
  if n_batches <= 10:
1588
1548
  colors = matplotlib.colormaps.get_cmap(cmap)(np.linspace(0, 1, n_batches))
1589
1549
  else:
1590
- colors = matplotlib.colormaps.get_cmap('tab20')(np.linspace(0, 1, n_batches))
1550
+ colors = matplotlib.colormaps.get_cmap("tab20")(np.linspace(0, 1, n_batches))
1591
1551
 
1592
1552
  if n_components == 2:
1593
1553
  ax1 = plt.subplot(1, 2, 1)
1594
1554
  ax2 = plt.subplot(1, 2, 2)
1595
1555
  else:
1596
- ax1 = fig.add_subplot(121, projection='3d')
1597
- ax2 = fig.add_subplot(122, projection='3d')
1556
+ ax1 = fig.add_subplot(121, projection="3d")
1557
+ ax2 = fig.add_subplot(122, projection="3d")
1598
1558
 
1599
1559
  for i, batch in enumerate(unique_batches):
1600
1560
  mask = batch_labels == batch
1601
1561
  if n_components == 2:
1602
1562
  ax1.scatter(
1603
- X_orig[mask, 0], X_orig[mask, 1],
1563
+ X_orig[mask, 0],
1564
+ X_orig[mask, 1],
1604
1565
  c=[colors[i]],
1605
1566
  s=point_size,
1606
- alpha=alpha,
1607
- label=f'Batch {batch}',
1608
- edgecolors='black',
1609
- linewidth=0.5
1567
+ alpha=alpha,
1568
+ label=f"Batch {batch}",
1569
+ edgecolors="black",
1570
+ linewidth=0.5,
1610
1571
  )
1611
1572
  else:
1612
1573
  ax1.scatter(
1613
- X_orig[mask, 0], X_orig[mask, 1], X_orig[mask, 2],
1574
+ X_orig[mask, 0],
1575
+ X_orig[mask, 1],
1576
+ X_orig[mask, 2],
1614
1577
  c=[colors[i]],
1615
1578
  s=point_size,
1616
- alpha=alpha,
1617
- label=f'Batch {batch}',
1618
- edgecolors='black',
1619
- linewidth=0.5
1579
+ alpha=alpha,
1580
+ label=f"Batch {batch}",
1581
+ edgecolors="black",
1582
+ linewidth=0.5,
1620
1583
  )
1621
1584
 
1622
- ax1.set_title(f'Before ComBat correction\n({method.upper()})')
1623
- ax1.set_xlabel(f'{method.upper()}1')
1624
- ax1.set_ylabel(f'{method.upper()}2')
1585
+ ax1.set_title(f"Before ComBat correction\n({method.upper()})")
1586
+ ax1.set_xlabel(f"{method.upper()}1")
1587
+ ax1.set_ylabel(f"{method.upper()}2")
1625
1588
  if n_components == 3:
1626
- ax1.set_zlabel(f'{method.upper()}3')
1589
+ ax1.set_zlabel(f"{method.upper()}3")
1627
1590
 
1628
1591
  for i, batch in enumerate(unique_batches):
1629
1592
  mask = batch_labels == batch
1630
1593
  if n_components == 2:
1631
1594
  ax2.scatter(
1632
- X_trans[mask, 0], X_trans[mask, 1],
1595
+ X_trans[mask, 0],
1596
+ X_trans[mask, 1],
1633
1597
  c=[colors[i]],
1634
1598
  s=point_size,
1635
- alpha=alpha,
1636
- label=f'Batch {batch}',
1637
- edgecolors='black',
1638
- linewidth=0.5
1599
+ alpha=alpha,
1600
+ label=f"Batch {batch}",
1601
+ edgecolors="black",
1602
+ linewidth=0.5,
1639
1603
  )
1640
1604
  else:
1641
1605
  ax2.scatter(
1642
- X_trans[mask, 0], X_trans[mask, 1], X_trans[mask, 2],
1606
+ X_trans[mask, 0],
1607
+ X_trans[mask, 1],
1608
+ X_trans[mask, 2],
1643
1609
  c=[colors[i]],
1644
1610
  s=point_size,
1645
- alpha=alpha,
1646
- label=f'Batch {batch}',
1647
- edgecolors='black',
1648
- linewidth=0.5
1611
+ alpha=alpha,
1612
+ label=f"Batch {batch}",
1613
+ edgecolors="black",
1614
+ linewidth=0.5,
1649
1615
  )
1650
1616
 
1651
- ax2.set_title(f'After ComBat correction\n({method.upper()})')
1652
- ax2.set_xlabel(f'{method.upper()}1')
1653
- ax2.set_ylabel(f'{method.upper()}2')
1617
+ ax2.set_title(f"After ComBat correction\n({method.upper()})")
1618
+ ax2.set_xlabel(f"{method.upper()}1")
1619
+ ax2.set_ylabel(f"{method.upper()}2")
1654
1620
  if n_components == 3:
1655
- ax2.set_zlabel(f'{method.upper()}3')
1621
+ ax2.set_zlabel(f"{method.upper()}3")
1656
1622
 
1657
1623
  if show_legend and n_batches <= 20:
1658
- ax2.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
1624
+ ax2.legend(bbox_to_anchor=(1.05, 1), loc="upper left")
1659
1625
 
1660
1626
  if title is None:
1661
- title = f'ComBat correction effect visualized with {method.upper()}'
1662
- fig.suptitle(title, fontsize=14, fontweight='bold')
1627
+ title = f"ComBat correction effect visualized with {method.upper()}"
1628
+ fig.suptitle(title, fontsize=14, fontweight="bold")
1663
1629
 
1664
1630
  plt.tight_layout()
1665
1631
  return fig
1666
1632
 
1667
1633
  def _create_interactive_plot(
1668
- self,
1669
- X_orig: FloatArray,
1670
- X_trans: FloatArray,
1671
- batch_labels: pd.Series,
1672
- method: str,
1673
- n_components: int,
1674
- cmap: str,
1675
- title: Optional[str],
1676
- show_legend: bool) -> Any:
1634
+ self,
1635
+ X_orig: FloatArray,
1636
+ X_trans: FloatArray,
1637
+ batch_labels: pd.Series,
1638
+ method: str,
1639
+ n_components: int,
1640
+ cmap: str,
1641
+ title: str | None,
1642
+ show_legend: bool,
1643
+ ) -> Any:
1677
1644
  """Create interactive plots using plotly."""
1678
1645
  if n_components == 2:
1679
1646
  fig = make_subplots(
1680
- rows=1, cols=2,
1647
+ rows=1,
1648
+ cols=2,
1681
1649
  subplot_titles=(
1682
- f'Before ComBat correction ({method.upper()})',
1683
- f'After ComBat correction ({method.upper()})'
1684
- )
1650
+ f"Before ComBat correction ({method.upper()})",
1651
+ f"After ComBat correction ({method.upper()})",
1652
+ ),
1685
1653
  )
1686
1654
  else:
1687
1655
  fig = make_subplots(
1688
- rows=1, cols=2,
1689
- specs=[[{'type': 'scatter3d'}, {'type': 'scatter3d'}]],
1656
+ rows=1,
1657
+ cols=2,
1658
+ specs=[[{"type": "scatter3d"}, {"type": "scatter3d"}]],
1690
1659
  subplot_titles=(
1691
- f'Before ComBat correction ({method.upper()})',
1692
- f'After ComBat correction ({method.upper()})'
1693
- )
1660
+ f"Before ComBat correction ({method.upper()})",
1661
+ f"After ComBat correction ({method.upper()})",
1662
+ ),
1694
1663
  )
1695
1664
 
1696
1665
  unique_batches = batch_labels.drop_duplicates()
1697
1666
 
1698
1667
  n_batches = len(unique_batches)
1699
1668
  cmap_func = matplotlib.colormaps.get_cmap(cmap)
1700
- color_list = [mcolors.to_hex(cmap_func(i / max(n_batches - 1, 1))) for i in range(n_batches)]
1669
+ color_list = [
1670
+ mcolors.to_hex(cmap_func(i / max(n_batches - 1, 1))) for i in range(n_batches)
1671
+ ]
1701
1672
 
1702
- batch_to_color = dict(zip(unique_batches, color_list))
1673
+ batch_to_color = dict(zip(unique_batches, color_list, strict=True))
1703
1674
 
1704
1675
  for batch in unique_batches:
1705
1676
  mask = batch_labels == batch
@@ -1707,72 +1678,86 @@ class ComBat(BaseEstimator, TransformerMixin):
             if n_components == 2:
                 fig.add_trace(
                     go.Scatter(
-                        x=X_orig[mask, 0], y=X_orig[mask, 1],
-                        mode='markers',
-                        name=f'Batch {batch}',
-                        marker=dict(
-                            size=8,
-                            color=batch_to_color[batch],
-                            line=dict(width=1, color='black')
-                        ),
-                        showlegend=False),
-                    row=1, col=1
+                        x=X_orig[mask, 0],
+                        y=X_orig[mask, 1],
+                        mode="markers",
+                        name=f"Batch {batch}",
+                        marker={
+                            "size": 8,
+                            "color": batch_to_color[batch],
+                            "line": {"width": 1, "color": "black"},
+                        },
+                        showlegend=False,
+                    ),
+                    row=1,
+                    col=1,
                 )
 
                 fig.add_trace(
                     go.Scatter(
-                        x=X_trans[mask, 0], y=X_trans[mask, 1],
-                        mode='markers',
-                        name=f'Batch {batch}',
-                        marker=dict(
-                            size=8,
-                            color=batch_to_color[batch],
-                            line=dict(width=1, color='black')
-                        ),
-                        showlegend=show_legend),
-                    row=1, col=2
+                        x=X_trans[mask, 0],
+                        y=X_trans[mask, 1],
+                        mode="markers",
+                        name=f"Batch {batch}",
+                        marker={
+                            "size": 8,
+                            "color": batch_to_color[batch],
+                            "line": {"width": 1, "color": "black"},
+                        },
+                        showlegend=show_legend,
+                    ),
+                    row=1,
+                    col=2,
                 )
             else:
                 fig.add_trace(
                     go.Scatter3d(
-                        x=X_orig[mask, 0], y=X_orig[mask, 1], z=X_orig[mask, 2],
-                        mode='markers',
-                        name=f'Batch {batch}',
-                        marker=dict(
-                            size=5,
-                            color=batch_to_color[batch],
-                            line=dict(width=0.5, color='black')
-                        ),
-                        showlegend=False),
-                    row=1, col=1
+                        x=X_orig[mask, 0],
+                        y=X_orig[mask, 1],
+                        z=X_orig[mask, 2],
+                        mode="markers",
+                        name=f"Batch {batch}",
+                        marker={
+                            "size": 5,
+                            "color": batch_to_color[batch],
+                            "line": {"width": 0.5, "color": "black"},
+                        },
+                        showlegend=False,
+                    ),
+                    row=1,
+                    col=1,
                 )
 
                 fig.add_trace(
                     go.Scatter3d(
-                        x=X_trans[mask, 0], y=X_trans[mask, 1], z=X_trans[mask, 2],
-                        mode='markers',
-                        name=f'Batch {batch}',
-                        marker=dict(
-                            size=5,
-                            color=batch_to_color[batch],
-                            line=dict(width=0.5, color='black')
-                        ),
-                        showlegend=show_legend),
-                    row=1, col=2
+                        x=X_trans[mask, 0],
+                        y=X_trans[mask, 1],
+                        z=X_trans[mask, 2],
+                        mode="markers",
+                        name=f"Batch {batch}",
+                        marker={
+                            "size": 5,
+                            "color": batch_to_color[batch],
+                            "line": {"width": 0.5, "color": "black"},
+                        },
+                        showlegend=show_legend,
+                    ),
+                    row=1,
+                    col=2,
                 )
 
         if title is None:
-            title = f'ComBat correction effect visualized with {method.upper()}'
+            title = f"ComBat correction effect visualized with {method.upper()}"
 
         fig.update_layout(
             title=title,
             title_font_size=16,
             height=600,
             showlegend=show_legend,
-            hovermode='closest'
+            hovermode="closest",
         )
 
-        axis_labels = [f'{method.upper()}{i+1}' for i in range(n_components)]
+        axis_labels = [f"{method.upper()}{i + 1}" for i in range(n_components)]
 
         if n_components == 2:
             fig.update_xaxes(title_text=axis_labels[0])
@@ -1781,7 +1766,7 @@ class ComBat(BaseEstimator, TransformerMixin):
             fig.update_scenes(
                 xaxis_title=axis_labels[0],
                 yaxis_title=axis_labels[1],
-                zaxis_title=axis_labels[2]
+                zaxis_title=axis_labels[2],
             )
 
         return fig