combatlearn-1.1.0-py3-none-any.whl → combatlearn-1.1.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- combatlearn/__init__.py +2 -2
- combatlearn/combat.py +354 -369
- {combatlearn-1.1.0.dist-info → combatlearn-1.1.1.dist-info}/METADATA +13 -18
- combatlearn-1.1.1.dist-info/RECORD +7 -0
- combatlearn-1.1.0.dist-info/RECORD +0 -7
- {combatlearn-1.1.0.dist-info → combatlearn-1.1.1.dist-info}/WHEEL +0 -0
- {combatlearn-1.1.0.dist-info → combatlearn-1.1.1.dist-info}/licenses/LICENSE +0 -0
- {combatlearn-1.1.0.dist-info → combatlearn-1.1.1.dist-info}/top_level.txt +0 -0
combatlearn/combat.py
CHANGED
@@ -8,36 +8,39 @@
 `ComBat` makes the model compatible with scikit-learn by stashing
 the batch (and optional covariates) at construction.
 """
+
 from __future__ import annotations
 
+import warnings
+from typing import Any, Literal
+
+import matplotlib
+import matplotlib.colors as mcolors
+import matplotlib.pyplot as plt
 import numpy as np
 import numpy.linalg as la
+import numpy.typing as npt
 import pandas as pd
+import plotly.graph_objects as go
+import umap
+from plotly.subplots import make_subplots
+from scipy.spatial.distance import pdist
+from scipy.stats import chi2, levene, spearmanr
 from sklearn.base import BaseEstimator, TransformerMixin
 from sklearn.decomposition import PCA
 from sklearn.manifold import TSNE
+from sklearn.metrics import davies_bouldin_score, silhouette_score
 from sklearn.neighbors import NearestNeighbors
-from sklearn.metrics import silhouette_score, davies_bouldin_score
-from scipy.stats import levene, spearmanr, chi2
-from scipy.spatial.distance import pdist
-import matplotlib
-import matplotlib.pyplot as plt
-import matplotlib.colors as mcolors
-from typing import Literal, Optional, Union, Dict, Tuple, Any, List
-import numpy.typing as npt
-import warnings
-import umap
-import plotly.graph_objects as go
-from plotly.subplots import make_subplots
 
-ArrayLike = Union[pd.DataFrame, pd.Series, npt.NDArray[Any]]
+ArrayLike = pd.DataFrame | pd.Series | npt.NDArray[Any]
 FloatArray = npt.NDArray[np.float64]
 
+
 def _compute_pca_embedding(
     X_before: np.ndarray,
     X_after: np.ndarray,
     n_components: int,
-) -> Tuple[np.ndarray, np.ndarray, PCA]:
+) -> tuple[np.ndarray, np.ndarray, PCA]:
     """
     Compute PCA embeddings for both datasets.
 
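Taken together, these hunks are a mechanical modernization in the isort/ruff style: grouped imports, double quotes, and PEP 585/604 annotations in place of `typing` generics. A minimal side-by-side sketch of the annotation change, runnable on Python 3.10+ (names here are illustrative, not from the package):

```python
from typing import Any, Dict, Optional, Union

import numpy.typing as npt
import pandas as pd

# 1.1.0 style: typing-module generics and unions.
ArrayLikeOld = Union[pd.DataFrame, pd.Series, npt.NDArray[Any]]

def subset_old(obj: Optional[ArrayLikeOld]) -> Optional[Dict[str, Any]]: ...

# 1.1.1 style: builtin generics (PEP 585) and X | Y unions (PEP 604).
ArrayLikeNew = pd.DataFrame | pd.Series | npt.NDArray[Any]

def subset_new(obj: ArrayLikeNew | None) -> dict[str, Any] | None: ...
```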
@@ -91,7 +94,7 @@ def _silhouette_batch(X: np.ndarray, batch_labels: np.ndarray) -> float:
     if len(unique_batches) < 2:
         return 0.0
     try:
-        return silhouette_score(X, batch_labels, metric='euclidean')
+        return silhouette_score(X, batch_labels, metric="euclidean")
     except Exception:
         return 0.0
 
@@ -129,7 +132,7 @@ def _kbet_score(
     batch_labels: np.ndarray,
     k0: int,
     alpha: float = 0.05,
-) -> Tuple[float, float]:
+) -> tuple[float, float]:
     """
     Compute kBET (k-nearest neighbor Batch Effect Test) acceptance rate.
 
@@ -166,7 +169,7 @@ def _kbet_score(
     global_freq = batch_counts / n_samples
     k0 = min(k0, n_samples - 1)
 
-    nn = NearestNeighbors(n_neighbors=k0 + 1, algorithm='auto')
+    nn = NearestNeighbors(n_neighbors=k0 + 1, algorithm="auto")
     nn.fit(X)
     _, indices = nn.kneighbors(X)
 
@@ -175,7 +178,7 @@ def _kbet_score(
     batch_to_idx = {b: i for i, b in enumerate(unique_batches)}
 
     for i in range(n_samples):
-        neighbors = indices[i, 1:k0+1]
+        neighbors = indices[i, 1 : k0 + 1]
        neighbor_batches = batch_labels[neighbors]
 
         observed = np.zeros(n_batches)
@@ -188,7 +191,7 @@ def _kbet_score(
         if mask.sum() < 2:
             continue
 
-        stat = np.sum((observed[mask] - expected[mask])**2 / expected[mask])
+        stat = np.sum((observed[mask] - expected[mask]) ** 2 / expected[mask])
         df = max(1, mask.sum() - 1)
         p_val = 1 - chi2.cdf(stat, df)
 
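For context, the kBET logic these hunks reformat is a per-neighborhood chi-square goodness-of-fit test: a sample is "accepted" when the batch composition of its k nearest neighbors matches the global batch frequencies, and the kBET score is the overall acceptance rate. A self-contained sketch of the per-sample test, assuming integer-coded batch labels (the function name and signature are illustrative, not the package API):

```python
import numpy as np
from scipy.stats import chi2


def kbet_accepts(neighbor_batches: np.ndarray, global_freq: np.ndarray, alpha: float = 0.05) -> bool:
    """Chi-square test for one neighborhood: do the local batch counts
    match the global batch frequencies? neighbor_batches holds integer ids."""
    k0 = len(neighbor_batches)
    observed = np.bincount(neighbor_batches, minlength=len(global_freq)).astype(float)
    expected = k0 * np.asarray(global_freq, dtype=float)
    mask = expected > 0
    stat = np.sum((observed[mask] - expected[mask]) ** 2 / expected[mask])
    p_val = 1 - chi2.cdf(stat, df=max(1, mask.sum() - 1))
    return p_val >= alpha  # True = well mixed; accepted neighborhoods raise the kBET score
```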
@@ -230,7 +233,7 @@ def _find_sigma(distances: np.ndarray, target_perplexity: float, tol: float = 1e
     sigma = 1.0
 
     for _ in range(50):
-        P = np.exp(-distances**2 / (2 * sigma**2 + 1e-10))
+        P = np.exp(-(distances**2) / (2 * sigma**2 + 1e-10))
         P_sum = P.sum()
         if P_sum < 1e-10:
             sigma = (sigma + sigma_max) / 2
@@ -241,7 +244,7 @@ def _find_sigma(distances: np.ndarray, target_perplexity: float, tol: float = 1e
 
         if abs(H - target_H) < tol:
             break
-        elif
+        elif target_H > H:
             sigma_min = sigma
         else:
             sigma_max = sigma
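`_find_sigma` bisects the Gaussian bandwidth until the neighborhood entropy H matches log(perplexity), the same calibration t-SNE uses. A rough standalone sketch of that idea; the bracketing and update schedule here are a simplification, not the package's exact loop:

```python
import numpy as np


def find_sigma(distances: np.ndarray, target_perplexity: float, tol: float = 1e-5) -> float:
    """Bisection on sigma so that the entropy of the Gaussian neighbor
    distribution equals log(target_perplexity)."""
    target_h = np.log(target_perplexity)
    lo, hi, sigma = 0.0, np.inf, 1.0
    for _ in range(50):
        p = np.exp(-(distances**2) / (2 * sigma**2 + 1e-10))
        p /= max(p.sum(), 1e-10)
        h = -np.sum(p * np.log(p + 1e-10))
        if abs(h - target_h) < tol:
            break
        if target_h > h:  # entropy too low -> widen the kernel
            lo = sigma
            sigma = sigma * 2 if np.isinf(hi) else (lo + hi) / 2
        else:  # entropy too high -> narrow the kernel
            hi = sigma
            sigma = (lo + hi) / 2
    return sigma
```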
@@ -287,7 +290,7 @@ def _lisi_score(
 
     k = min(3 * perplexity, n_samples - 1)
 
-    nn = NearestNeighbors(n_neighbors=k + 1, algorithm='auto')
+    nn = NearestNeighbors(n_neighbors=k + 1, algorithm="auto")
     nn.fit(X)
     distances, indices = nn.kneighbors(X)
 
@@ -299,7 +302,7 @@ def _lisi_score(
     for i in range(n_samples):
         sigma = _find_sigma(distances[i], perplexity)
 
-        P = np.exp(-distances[i]**2 / (2 * sigma**2 + 1e-10))
+        P = np.exp(-(distances[i] ** 2) / (2 * sigma**2 + 1e-10))
         P_sum = P.sum()
         if P_sum < 1e-10:
             lisi_values.append(1.0)
@@ -312,10 +315,7 @@ def _lisi_score(
             batch_probs[batch_to_idx[nb]] += P[j]
 
         simpson = np.sum(batch_probs**2)
-        if simpson < 1e-10:
-            lisi = n_batches
-        else:
-            lisi = 1.0 / simpson
+        lisi = n_batches if simpson < 1e-10 else 1.0 / simpson
         lisi_values.append(lisi)
 
     return np.mean(lisi_values)
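Per sample, LISI is the inverse Simpson index of the Gaussian-weighted batch probabilities, i.e. the effective number of batches in the neighborhood (1 = one batch dominates, n_batches = perfect mixing). The rewritten ternary above computes exactly that; as a tiny sketch with an illustrative name:

```python
import numpy as np


def inverse_simpson(batch_probs: np.ndarray) -> float:
    """LISI for one sample given its normalized neighbor batch probabilities."""
    simpson = float(np.sum(batch_probs**2))
    return float(len(batch_probs)) if simpson < 1e-10 else 1.0 / simpson
```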
@@ -358,11 +358,11 @@ def _variance_ratio(X: np.ndarray, batch_labels: np.ndarray) -> float:
         X_batch = X[mask]
         batch_mean = np.mean(X_batch, axis=0)
 
-        between_var += n_b * np.sum((batch_mean - grand_mean)**2)
-        within_var += np.sum((X_batch - batch_mean)**2)
+        between_var += n_b * np.sum((batch_mean - grand_mean) ** 2)
+        within_var += np.sum((X_batch - batch_mean) ** 2)
 
-    between_var /=
-    within_var /=
+    between_var /= n_batches - 1
+    within_var /= n_samples - n_batches
 
     if within_var < 1e-10:
         return 0.0
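`_variance_ratio` is a one-way-ANOVA-style F ratio; the denominators made explicit above are the usual degrees of freedom (n_batches - 1 between, n_samples - n_batches within). A self-contained sketch of the whole metric:

```python
import numpy as np


def variance_ratio(X: np.ndarray, batch_labels: np.ndarray) -> float:
    """Between-batch over within-batch variance; larger = stronger batch effect."""
    batches = np.unique(batch_labels)
    if len(batches) < 2 or len(X) <= len(batches):
        return 0.0
    grand_mean = X.mean(axis=0)
    between = within = 0.0
    for b in batches:
        X_b = X[batch_labels == b]
        mu_b = X_b.mean(axis=0)
        between += len(X_b) * np.sum((mu_b - grand_mean) ** 2)
        within += np.sum((X_b - mu_b) ** 2)
    between /= len(batches) - 1
    within /= len(X) - len(batches)
    return 0.0 if within < 1e-10 else between / within
```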
@@ -373,9 +373,9 @@ def _variance_ratio(X: np.ndarray, batch_labels: np.ndarray) -> float:
 def _knn_preservation(
     X_before: np.ndarray,
     X_after: np.ndarray,
-    k_values: List[int],
+    k_values: list[int],
     n_jobs: int = 1,
-) -> Dict[int, float]:
+) -> dict[int, float]:
     """
     Compute fraction of k-nearest neighbors preserved after correction.
 
@@ -402,11 +402,11 @@ def _knn_preservation(
     max_k = max(k_values)
     max_k = min(max_k, X_before.shape[0] - 1)
 
-    nn_before = NearestNeighbors(n_neighbors=max_k + 1, algorithm='auto', n_jobs=n_jobs)
+    nn_before = NearestNeighbors(n_neighbors=max_k + 1, algorithm="auto", n_jobs=n_jobs)
     nn_before.fit(X_before)
     _, indices_before = nn_before.kneighbors(X_before)
 
-    nn_after = NearestNeighbors(n_neighbors=max_k + 1, algorithm='auto', n_jobs=n_jobs)
+    nn_after = NearestNeighbors(n_neighbors=max_k + 1, algorithm="auto", n_jobs=n_jobs)
     nn_after.fit(X_after)
     _, indices_after = nn_after.kneighbors(X_after)
 
@@ -417,8 +417,8 @@ def _knn_preservation(
 
         overlaps = []
         for i in range(X_before.shape[0]):
-            neighbors_before = set(indices_before[i, 1:k+1])
-            neighbors_after = set(indices_after[i, 1:k+1])
+            neighbors_before = set(indices_before[i, 1 : k + 1])
+            neighbors_after = set(indices_after[i, 1 : k + 1])
             overlap = len(neighbors_before & neighbors_after) / k
             overlaps.append(overlap)
 
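The preservation metric is the mean fraction of each sample's k nearest neighbors (self excluded, hence the 1 : k + 1 slice) that survive the correction. A compact sketch for a single k, with illustrative names:

```python
import numpy as np
from sklearn.neighbors import NearestNeighbors


def knn_overlap(X_before: np.ndarray, X_after: np.ndarray, k: int = 10) -> float:
    """Mean per-sample overlap between neighbor sets before/after correction."""
    def neighbor_sets(X: np.ndarray) -> list[set]:
        _, idx = NearestNeighbors(n_neighbors=k + 1).fit(X).kneighbors(X)
        return [set(row[1 : k + 1]) for row in idx]  # drop self at position 0

    before, after = neighbor_sets(X_before), neighbor_sets(X_after)
    return float(np.mean([len(b & a) / k for b, a in zip(before, after)]))
```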
@@ -463,8 +463,8 @@ def _pairwise_distance_correlation(
     X_before = X_before[idx]
     X_after = X_after[idx]
 
-    dist_before = pdist(X_before, metric='euclidean')
-    dist_after = pdist(X_after, metric='euclidean')
+    dist_before = pdist(X_before, metric="euclidean")
+    dist_after = pdist(X_after, metric="euclidean")
 
     if len(dist_before) == 0:
         return 1.0
@@ -508,7 +508,7 @@ def _mean_centroid_distance(X: np.ndarray, batch_labels: np.ndarray) -> float:
         centroids.append(centroid)
 
     centroids = np.array(centroids)
-    distances = pdist(centroids, metric='euclidean')
+    distances = pdist(centroids, metric="euclidean")
 
     return np.mean(distances)
 
@@ -542,7 +542,7 @@ def _levene_median_statistic(X: np.ndarray, batch_labels: np.ndarray) -> float:
         if len(groups) < 2:
             continue
         try:
-            stat, _ = levene(*groups, center='median')
+            stat, _ = levene(*groups, center="median")
             if not np.isnan(stat):
                 levene_stats.append(stat)
         except Exception:
@@ -583,24 +583,24 @@ class ComBatModel:
         method: Literal["johnson", "fortin", "chen"] = "johnson",
         parametric: bool = True,
         mean_only: bool = False,
-        reference_batch: Optional[str] = None,
+        reference_batch: str | None = None,
         eps: float = 1e-8,
-        covbat_cov_thresh: Union[float, int] = 0.9,
+        covbat_cov_thresh: float | int = 0.9,
     ) -> None:
         self.method: str = method
         self.parametric: bool = parametric
         self.mean_only: bool = bool(mean_only)
-        self.reference_batch: Optional[str] = reference_batch
+        self.reference_batch: str | None = reference_batch
         self.eps: float = float(eps)
-        self.covbat_cov_thresh: Union[float, int] = covbat_cov_thresh
+        self.covbat_cov_thresh: float | int = covbat_cov_thresh
 
         self._batch_levels: pd.Index
         self._grand_mean: pd.Series
         self._pooled_var: pd.Series
         self._gamma_star: FloatArray
         self._delta_star: FloatArray
-        self._n_per_batch: Dict[str, int]
-        self._reference_batch_idx: Optional[int]
+        self._n_per_batch: dict[str, int]
+        self._reference_batch_idx: int | None
         self._beta_hat_nonbatch: FloatArray
         self._n_batch: int
         self._p_design: int
@@ -621,26 +621,15 @@ class ComBatModel:
             raise TypeError("covbat_cov_thresh must be float or int.")
 
     @staticmethod
-    def _as_series(
-        arr: ArrayLike,
-        index: pd.Index,
-        name: str
-    ) -> pd.Series:
+    def _as_series(arr: ArrayLike, index: pd.Index, name: str) -> pd.Series:
         """Convert array-like to categorical Series with validation."""
-        if isinstance(arr, pd.Series):
-            ser = arr.copy()
-        else:
-            ser = pd.Series(arr, index=index, name=name)
+        ser = arr.copy() if isinstance(arr, pd.Series) else pd.Series(arr, index=index, name=name)
         if not ser.index.equals(index):
             raise ValueError(f"`{name}` index mismatch with `X`.")
         return ser.astype("category")
 
     @staticmethod
-    def _to_df(
-        arr: Optional[ArrayLike],
-        index: pd.Index,
-        name: str
-    ) -> Optional[pd.DataFrame]:
+    def _to_df(arr: ArrayLike | None, index: pd.Index, name: str) -> pd.DataFrame | None:
         """Convert array-like to DataFrame."""
         if arr is None:
             return None
@@ -655,11 +644,11 @@ class ComBatModel:
     def fit(
         self,
         X: ArrayLike,
-        y: Optional[ArrayLike] = None,
+        y: ArrayLike | None = None,
         *,
         batch: ArrayLike,
-        discrete_covariates: Optional[ArrayLike] = None,
-        continuous_covariates: Optional[ArrayLike] = None,
+        discrete_covariates: ArrayLike | None = None,
+        continuous_covariates: ArrayLike | None = None,
     ) -> ComBatModel:
         """Fit the ComBat model."""
         method = self.method.lower()
@@ -681,9 +670,7 @@ class ComBatModel:
 
         if method == "johnson":
             if disc is not None or cont is not None:
-                warnings.warn(
-                    "Covariates are ignored when using method='johnson'."
-                )
+                warnings.warn("Covariates are ignored when using method='johnson'.", stacklevel=2)
             self._fit_johnson(X, batch)
         elif method == "fortin":
             self._fit_fortin(X, batch, disc, cont)
@@ -691,11 +678,7 @@ class ComBatModel:
             self._fit_chen(X, batch, disc, cont)
         return self
 
-    def _fit_johnson(
-        self,
-        X: pd.DataFrame,
-        batch: pd.Series
-    ) -> None:
+    def _fit_johnson(self, X: pd.DataFrame, batch: pd.Series) -> None:
         """Johnson et al. (2007) ComBat."""
         self._batch_levels = batch.cat.categories
         pooled_var = X.var(axis=0, ddof=1) + self.eps
@@ -703,10 +686,10 @@ class ComBatModel:
 
         Xs = (X - grand_mean) / np.sqrt(pooled_var)
 
-        n_per_batch: Dict[str, int] = {}
+        n_per_batch: dict[str, int] = {}
         gamma_hat: list[npt.NDArray[np.float64]] = []
         delta_hat: list[npt.NDArray[np.float64]] = []
-
+
         for lvl in self._batch_levels:
             idx = batch == lvl
             n_b = int(idx.sum())
@@ -751,8 +734,8 @@ class ComBatModel:
         self,
         X: pd.DataFrame,
         batch: pd.Series,
-        disc: Optional[pd.DataFrame],
-        cont: Optional[pd.DataFrame],
+        disc: pd.DataFrame | None,
+        cont: pd.DataFrame | None,
     ) -> None:
         """Fortin et al. (2018) neuroComBat."""
         self._batch_levels = batch.cat.categories
@@ -770,11 +753,7 @@ class ComBatModel:
 
         parts: list[pd.DataFrame] = [batch_dummies]
         if disc is not None:
-            parts.append(
-                pd.get_dummies(
-                    disc.astype("category"), drop_first=True
-                ).astype(float)
-            )
+            parts.append(pd.get_dummies(disc.astype("category"), drop_first=True).astype(float))
 
         if cont is not None:
             parts.append(cont.astype(float))
@@ -789,7 +768,7 @@ class ComBatModel:
         self._beta_hat_nonbatch = beta_hat[n_batch:]
 
         n_per_batch = batch.value_counts().sort_index().astype(int).values
-        self._n_per_batch = dict(zip(self._batch_levels, n_per_batch))
+        self._n_per_batch = dict(zip(self._batch_levels, n_per_batch, strict=True))
 
         if self.reference_batch is not None:
             ref_idx = list(self._batch_levels).index(self.reference_batch)
@@ -807,30 +786,25 @@ class ComBatModel:
         else:
             resid = X_np - design @ beta_hat
             denom = n_samples
-        var_pooled = (resid
+        var_pooled = (resid**2).sum(axis=0) / denom + self.eps
         self._pooled_var = pd.Series(var_pooled, index=X.columns)
 
         stand_mean = grand_mean + design[:, n_batch:] @ self._beta_hat_nonbatch
         Xs = (X_np - stand_mean) / np.sqrt(var_pooled)
 
-        gamma_hat = np.vstack(
-            [Xs[batch == lvl].mean(axis=0) for lvl in self._batch_levels]
-        )
+        gamma_hat = np.vstack([Xs[batch == lvl].mean(axis=0) for lvl in self._batch_levels])
         delta_hat = np.vstack(
-            [Xs[batch == lvl].var(axis=0, ddof=1) + self.eps
-             for lvl in self._batch_levels]
+            [Xs[batch == lvl].var(axis=0, ddof=1) + self.eps for lvl in self._batch_levels]
         )
 
         if self.mean_only:
             gamma_star = self._shrink_gamma(
-                gamma_hat, delta_hat, n_per_batch,
-                parametric = self.parametric
+                gamma_hat, delta_hat, n_per_batch, parametric=self.parametric
             )
             delta_star = np.ones_like(delta_hat)
         else:
             gamma_star, delta_star = self._shrink_gamma_delta(
-                gamma_hat, delta_hat, n_per_batch,
-                parametric = self.parametric
+                gamma_hat, delta_hat, n_per_batch, parametric=self.parametric
             )
 
         if ref_idx is not None:
@@ -841,15 +815,15 @@ class ComBatModel:
 
         self._gamma_star = gamma_star
         self._delta_star = delta_star
-        self._n_batch
+        self._n_batch = n_batch
         self._p_design = p_design
-
+
     def _fit_chen(
         self,
         X: pd.DataFrame,
         batch: pd.Series,
-        disc: Optional[pd.DataFrame],
-        cont: Optional[pd.DataFrame],
+        disc: pd.DataFrame | None,
+        cont: pd.DataFrame | None,
     ) -> None:
         """Chen et al. (2022) CovBat."""
         self._fit_fortin(X, batch, disc, cont)
@@ -868,7 +842,7 @@ class ComBatModel:
         self._covbat_n_pc = n_pc
 
         scores = pca.transform(X_centered)[:, :n_pc]
-        scores_df = pd.DataFrame(scores, index=X.index, columns=[f"PC{i+1}" for i in range(n_pc)])
+        scores_df = pd.DataFrame(scores, index=X.index, columns=[f"PC{i + 1}" for i in range(n_pc)])
         self._batch_levels_pc = self._batch_levels
         n_per_batch = self._n_per_batch
 
@@ -907,12 +881,12 @@ class ComBatModel:
         self,
         gamma_hat: FloatArray,
         delta_hat: FloatArray,
-        n_per_batch: Union[Dict[str, int], FloatArray],
+        n_per_batch: dict[str, int] | FloatArray,
         *,
         parametric: bool,
         max_iter: int = 100,
         tol: float = 1e-4,
-    ) -> Tuple[FloatArray, FloatArray]:
+    ) -> tuple[FloatArray, FloatArray]:
         """Empirical Bayes shrinkage estimation."""
         if parametric:
             gamma_bar = gamma_hat.mean(axis=0)
@@ -920,10 +894,14 @@ class ComBatModel:
             a_prior = (delta_hat.mean(axis=0) ** 2) / delta_hat.var(axis=0, ddof=1) + 2
             b_prior = delta_hat.mean(axis=0) * (a_prior - 1)
 
-            B, p = gamma_hat.shape
+            B, _p = gamma_hat.shape
             gamma_star = np.empty_like(gamma_hat)
             delta_star = np.empty_like(delta_hat)
-            n_vec = np.array(list(n_per_batch.values())) if isinstance(n_per_batch, dict) else n_per_batch
+            n_vec = (
+                np.array(list(n_per_batch.values()))
+                if isinstance(n_per_batch, dict)
+                else n_per_batch
+            )
 
             for i in range(B):
                 n_i = n_vec[i]
@@ -937,8 +915,12 @@ class ComBatModel:
             return gamma_star, delta_star
 
         else:
-            B, p = gamma_hat.shape
-            n_vec = np.array(list(n_per_batch.values())) if isinstance(n_per_batch, dict) else n_per_batch
+            B, _p = gamma_hat.shape
+            n_vec = (
+                np.array(list(n_per_batch.values()))
+                if isinstance(n_per_batch, dict)
+                else n_per_batch
+            )
             gamma_bar = gamma_hat.mean(axis=0)
             t2 = gamma_hat.var(axis=0, ddof=1)
 
@@ -947,27 +929,22 @@ class ComBatModel:
             g_bar: FloatArray,
             n: float,
             d_star: FloatArray,
-            t2_: FloatArray
+            t2_: FloatArray,
        ) -> FloatArray:
             return (t2_ * n * g_hat + d_star * g_bar) / (t2_ * n + d_star)
 
-        def postvar(
-            sum2: FloatArray,
-            n: float,
-            a: FloatArray,
-            b: FloatArray
-        ) -> FloatArray:
+        def postvar(sum2: FloatArray, n: float, a: FloatArray, b: FloatArray) -> FloatArray:
             return (0.5 * sum2 + b) / (n / 2.0 + a - 1.0)
 
         def aprior(delta: FloatArray) -> FloatArray:
             m, s2 = delta.mean(), delta.var()
             s2 = max(s2, self.eps)
-            return (2 * s2 + m ** 2) / s2
+            return (2 * s2 + m**2) / s2
 
         def bprior(delta: FloatArray) -> FloatArray:
             m, s2 = delta.mean(), delta.var()
             s2 = max(s2, self.eps)
-            return (m * s2 + m ** 3) / s2
+            return (m * s2 + m**3) / s2
 
         gamma_star = np.empty_like(gamma_hat)
         delta_star = np.empty_like(delta_hat)
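The helpers above are the parametric empirical Bayes updates from Johnson et al. (2007): the batch location is shrunk toward the across-batch prior mean with precision weights, and the batch variance gets an inverse-gamma posterior. The same formulas, extracted as standalone functions for readability (an illustrative extraction matching the diff, not the package API):

```python
import numpy as np


def postmean(g_hat: np.ndarray, g_bar: np.ndarray, n: float,
             d_star: np.ndarray, t2: np.ndarray) -> np.ndarray:
    """Posterior batch location: precision-weighted blend of the observed
    batch mean g_hat and the across-batch prior mean g_bar."""
    return (t2 * n * g_hat + d_star * g_bar) / (t2 * n + d_star)


def postvar(sum2: np.ndarray, n: float, a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Posterior batch variance under an inverse-gamma(a, b) prior."""
    return (0.5 * sum2 + b) / (n / 2.0 + a - 1.0)
```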
@@ -986,7 +963,8 @@ class ComBatModel:
                 sum2 = (n_i - 1) * d_hat_i + n_i * (g_hat_i - g_new) ** 2
                 d_new = postvar(sum2, n_i, a_i, b_i)
                 if np.max(np.abs(g_new - g_prev) / (np.abs(g_prev) + self.eps)) < tol and (
-                    self.mean_only or np.max(np.abs(d_new - d_prev) / (np.abs(d_prev) + self.eps)) < tol
+                    self.mean_only
+                    or np.max(np.abs(d_new - d_prev) / (np.abs(d_prev) + self.eps)) < tol
                 ):
                     break
                 gamma_star[i] = g_new
@@ -997,12 +975,14 @@ class ComBatModel:
         self,
         gamma_hat: FloatArray,
         delta_hat: FloatArray,
-        n_per_batch: Union[Dict[str, int], FloatArray],
+        n_per_batch: dict[str, int] | FloatArray,
         *,
         parametric: bool,
     ) -> FloatArray:
-        """Convenience wrapper that returns only
-        gamma, _ = self._shrink_gamma_delta(
+        """Convenience wrapper that returns only gamma* (for *mean-only* mode)."""
+        gamma, _ = self._shrink_gamma_delta(
+            gamma_hat, delta_hat, n_per_batch, parametric=parametric
+        )
         return gamma
 
     def transform(
@@ -1010,12 +990,14 @@ class ComBatModel:
         X: ArrayLike,
         *,
         batch: ArrayLike,
-        discrete_covariates: Optional[ArrayLike] = None,
-        continuous_covariates: Optional[ArrayLike] = None,
+        discrete_covariates: ArrayLike | None = None,
+        continuous_covariates: ArrayLike | None = None,
     ) -> pd.DataFrame:
         """Transform the data using fitted ComBat parameters."""
         if not hasattr(self, "_gamma_star"):
-            raise ValueError(
+            raise ValueError(
+                "This ComBatModel instance is not fitted yet. Call 'fit' before 'transform'."
+            )
         if not isinstance(X, pd.DataFrame):
             X = pd.DataFrame(X)
         idx = X.index
@@ -1036,11 +1018,7 @@ class ComBatModel:
         else:
             raise ValueError(f"Unknown method: {method}.")
 
-    def _transform_johnson(
-        self,
-        X: pd.DataFrame,
-        batch: pd.Series
-    ) -> pd.DataFrame:
+    def _transform_johnson(self, X: pd.DataFrame, batch: pd.Series) -> pd.DataFrame:
         """Johnson transform implementation."""
         pooled = self._pooled_var
         grand = self._grand_mean
@@ -1058,10 +1036,7 @@ class ComBatModel:
 
             g = self._gamma_star[i]
             d = self._delta_star[i]
-            if self.mean_only:
-                Xb = Xs.loc[idx] - g
-            else:
-                Xb = (Xs.loc[idx] - g) / np.sqrt(d)
+            Xb = Xs.loc[idx] - g if self.mean_only else (Xs.loc[idx] - g) / np.sqrt(d)
             X_adj.loc[idx] = (Xb * np.sqrt(pooled) + grand).values
         return X_adj
 
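The ternary above is the core ComBat round trip: standardize each feature by the pooled statistics, subtract the shrunken batch location (and divide by the batch scale unless `mean_only`), then map back to the original scale. As a standalone sketch for one batch (illustrative name; it mirrors the `_fit_johnson`/`_transform_johnson` pair in this diff):

```python
import numpy as np
import pandas as pd


def combat_adjust_batch(X_b: pd.DataFrame, grand_mean: pd.Series, pooled_var: pd.Series,
                        gamma: np.ndarray, delta: np.ndarray,
                        mean_only: bool = False) -> pd.DataFrame:
    """Correct one batch given the fitted grand mean, pooled variance, and
    empirical-Bayes batch effects gamma (location) and delta (scale)."""
    Xs = (X_b - grand_mean) / np.sqrt(pooled_var)                     # standardize
    Xb = Xs - gamma if mean_only else (Xs - gamma) / np.sqrt(delta)   # remove batch effect
    return Xb * np.sqrt(pooled_var) + grand_mean                      # back-transform
```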
@@ -1069,8 +1044,8 @@ class ComBatModel:
         self,
         X: pd.DataFrame,
         batch: pd.Series,
-        disc: Optional[pd.DataFrame],
-        cont: Optional[pd.DataFrame],
+        disc: pd.DataFrame | None,
+        cont: pd.DataFrame | None,
     ) -> pd.DataFrame:
         """Fortin transform implementation."""
         batch_dummies = pd.get_dummies(batch, drop_first=False).astype(float)[self._batch_levels]
@@ -1079,21 +1054,14 @@ class ComBatModel:
 
         parts = [batch_dummies]
         if disc is not None:
-            parts.append(
-                pd.get_dummies(
-                    disc.astype("category"), drop_first=True
-                ).astype(float)
-            )
+            parts.append(pd.get_dummies(disc.astype("category"), drop_first=True).astype(float))
         if cont is not None:
             parts.append(cont.astype(float))
 
         design = pd.concat(parts, axis=1).values
 
         X_np = X.values
-        stand_mu = (
-            self._grand_mean.values +
-            design[:, self._n_batch:] @ self._beta_hat_nonbatch
-        )
+        stand_mu = self._grand_mean.values + design[:, self._n_batch :] @ self._beta_hat_nonbatch
         Xs = (X_np - stand_mu) / np.sqrt(self._pooled_var.values)
 
         for i, lvl in enumerate(self._batch_levels):
@@ -1111,18 +1079,15 @@ class ComBatModel:
             else:
                 Xs[idx] = (Xs[idx] - g) / np.sqrt(d)
 
-        X_adj = (
-            Xs * np.sqrt(self._pooled_var.values) +
-            stand_mu
-        )
+        X_adj = Xs * np.sqrt(self._pooled_var.values) + stand_mu
         return pd.DataFrame(X_adj, index=X.index, columns=X.columns, dtype=float)
-
+
     def _transform_chen(
         self,
         X: pd.DataFrame,
         batch: pd.Series,
-        disc: Optional[pd.DataFrame],
-        cont: Optional[pd.DataFrame],
+        disc: pd.DataFrame | None,
+        cont: pd.DataFrame | None,
     ) -> pd.DataFrame:
         """Chen transform implementation."""
         X_meanvar_adj = self._transform_fortin(X, batch, disc, cont)
@@ -1159,14 +1124,14 @@ class ComBat(BaseEstimator, TransformerMixin):
         self,
         batch: ArrayLike,
         *,
-        discrete_covariates: Optional[ArrayLike] = None,
-        continuous_covariates: Optional[ArrayLike] = None,
+        discrete_covariates: ArrayLike | None = None,
+        continuous_covariates: ArrayLike | None = None,
         method: str = "johnson",
         parametric: bool = True,
         mean_only: bool = False,
-        reference_batch: Optional[str] = None,
+        reference_batch: str | None = None,
         eps: float = 1e-8,
-        covbat_cov_thresh: Union[float, int] = 0.9,
+        covbat_cov_thresh: float | int = 0.9,
         compute_metrics: bool = False,
     ) -> None:
         self.batch = batch
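Putting the public estimator together, a hypothetical usage sketch based on the signatures in this diff (the top-level import path is assumed from the package layout, not verified here):

```python
import numpy as np
import pandas as pd

from combatlearn import ComBat  # import path assumed

rng = np.random.default_rng(0)
X = pd.DataFrame(rng.normal(size=(60, 5)))
X.iloc[30:] += 2.0  # crude additive "site" effect on the second half
batch = pd.Series(["site_a"] * 30 + ["site_b"] * 30, index=X.index)

cb = ComBat(batch, method="fortin", compute_metrics=True)
X_corrected = cb.fit_transform(X)

# metrics_ is populated because compute_metrics=True was passed above.
if cb.metrics_ is not None:
    print(cb.metrics_["batch_effect"]["silhouette"])
```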
@@ -1188,11 +1153,7 @@ class ComBat(BaseEstimator, TransformerMixin):
             covbat_cov_thresh=covbat_cov_thresh,
         )
 
-    def fit(
-        self,
-        X: ArrayLike,
-        y: Optional[ArrayLike] = None
-    ) -> "ComBat":
+    def fit(self, X: ArrayLike, y: ArrayLike | None = None) -> ComBat:
         """Fit the ComBat model."""
         idx = X.index if isinstance(X, pd.DataFrame) else pd.RangeIndex(len(X))
         batch_vec = self._subset(self.batch, idx)
@@ -1221,10 +1182,7 @@ class ComBat(BaseEstimator, TransformerMixin):
         )
 
     @staticmethod
-    def _subset(
-        obj: Optional[ArrayLike],
-        idx: pd.Index
-    ) -> Optional[Union[pd.DataFrame, pd.Series]]:
+    def _subset(obj: ArrayLike | None, idx: pd.Index) -> pd.DataFrame | pd.Series | None:
         """Subset array-like object by index."""
         if obj is None:
             return None
@@ -1237,7 +1195,7 @@ class ComBat(BaseEstimator, TransformerMixin):
         return pd.DataFrame(obj, index=idx)
 
     @property
-    def metrics_(self) -> Optional[Dict[str, Any]]:
+    def metrics_(self) -> dict[str, Any] | None:
         """Return cached metrics from last fit_transform with compute_metrics=True.
 
         Returns
@@ -1245,19 +1203,19 @@ class ComBat(BaseEstimator, TransformerMixin):
         dict or None
             Cached metrics dictionary, or None if no metrics have been computed.
         """
-        return getattr(self, '_metrics_cache', None)
+        return getattr(self, "_metrics_cache", None)
 
     def compute_batch_metrics(
         self,
         X: ArrayLike,
-        batch: Optional[ArrayLike] = None,
+        batch: ArrayLike | None = None,
         *,
-        pca_components: Optional[int] = None,
-        k_neighbors: Optional[List[int]] = None,
-        kbet_k0: Optional[int] = None,
+        pca_components: int | None = None,
+        k_neighbors: list[int] | None = None,
+        kbet_k0: int | None = None,
         lisi_perplexity: int = 30,
         n_jobs: int = 1,
-    ) -> Dict[str, Any]:
+    ) -> dict[str, Any]:
         """
         Compute batch effect metrics before and after ComBat correction.
 
@@ -1282,25 +1240,14 @@ class ComBat(BaseEstimator, TransformerMixin):
 
         Returns
         -------
-
-        Dictionary with
-
-
-
-
-
-
-                'variance_ratio': {...},
-            },
-            'preservation': {
-                'knn': {k: fraction for k in k_neighbors},
-                'distance_correlation': float,
-            },
-            'alignment': {
-                'centroid_distance': {...},
-                'levene_statistic': {...},
-            },
-        }
+        dict
+            Dictionary with three main keys:
+
+            - ``batch_effect``: Silhouette, Davies-Bouldin, kBET, LISI, variance ratio
+              (each with 'before' and 'after' values)
+            - ``preservation``: k-NN preservation fractions, distance correlation
+            - ``alignment``: Centroid distance, Levene statistic (each with
+              'before' and 'after' values)
 
         Raises
         ------
@@ -1309,8 +1256,7 @@ class ComBat(BaseEstimator, TransformerMixin):
         """
         if not hasattr(self._model, "_gamma_star"):
             raise ValueError(
-                "This ComBat instance is not fitted yet. "
-                "Call 'fit' before 'compute_batch_metrics'."
+                "This ComBat instance is not fitted yet. Call 'fit' before 'compute_batch_metrics'."
             )
 
         if not isinstance(X, pd.DataFrame):
@@ -1322,7 +1268,7 @@ class ComBat(BaseEstimator, TransformerMixin):
             batch_vec = self._subset(self.batch, idx)
         else:
             if isinstance(batch, (pd.Series, pd.DataFrame)):
-                batch_vec = batch.loc[idx] if hasattr(batch, 'loc') else batch
+                batch_vec = batch.loc[idx] if hasattr(batch, "loc") else batch
             elif isinstance(batch, np.ndarray):
                 batch_vec = pd.Series(batch, index=idx)
             else:
@@ -1336,6 +1282,8 @@ class ComBat(BaseEstimator, TransformerMixin):
         n_samples, n_features = X_before.shape
         if kbet_k0 is None:
             kbet_k0 = max(10, int(0.10 * n_samples))
+        if k_neighbors is None:
+            k_neighbors = [5, 10, 50]
 
         # Validate and apply PCA if requested
         if pca_components is not None:
@@ -1345,9 +1293,7 @@ class ComBat(BaseEstimator, TransformerMixin):
                     f"pca_components={pca_components} must be less than "
                     f"min(n_samples, n_features)={max_components}."
                 )
-            X_before_pca, X_after_pca, _ = _compute_pca_embedding(
-                X_before, X_after, pca_components
-            )
+            X_before_pca, X_after_pca, _ = _compute_pca_embedding(X_before, X_after, pca_components)
         else:
             X_before_pca = X_before
             X_after_pca = X_after
@@ -1379,57 +1325,53 @@ class ComBat(BaseEstimator, TransformerMixin):
         n_batches = len(np.unique(batch_labels))
 
         metrics = {
-            'batch_effect': {
-                'silhouette': {
-                    'before': silhouette_before,
-                    'after': silhouette_after,
+            "batch_effect": {
+                "silhouette": {
+                    "before": silhouette_before,
+                    "after": silhouette_after,
                 },
-                'davies_bouldin': {
-                    'before': db_before,
-                    'after': db_after,
+                "davies_bouldin": {
+                    "before": db_before,
+                    "after": db_after,
                 },
-                'kbet': {
-                    'before': kbet_before,
-                    'after': kbet_after,
+                "kbet": {
+                    "before": kbet_before,
+                    "after": kbet_after,
                 },
-                'lisi': {
-                    'before': lisi_before,
-                    'after': lisi_after,
-                    'max_value': n_batches,
+                "lisi": {
+                    "before": lisi_before,
+                    "after": lisi_after,
+                    "max_value": n_batches,
                 },
-                'variance_ratio': {
-                    'before': var_ratio_before,
-                    'after': var_ratio_after,
+                "variance_ratio": {
+                    "before": var_ratio_before,
+                    "after": var_ratio_after,
                 },
             },
-            'preservation': {
-                'knn': knn_results,
-                'distance_correlation': dist_corr,
+            "preservation": {
+                "knn": knn_results,
+                "distance_correlation": dist_corr,
             },
-            'alignment': {
-                'centroid_distance': {
-                    'before': centroid_before,
-                    'after': centroid_after,
+            "alignment": {
+                "centroid_distance": {
+                    "before": centroid_before,
+                    "after": centroid_after,
                 },
-                'levene_statistic': {
-                    'before': levene_before,
-                    'after': levene_after,
+                "levene_statistic": {
+                    "before": levene_before,
+                    "after": levene_after,
                 },
             },
         }
 
         return metrics
 
-    def fit_transform(
-        self,
-        X: ArrayLike,
-        y: Optional[ArrayLike] = None
-    ) -> pd.DataFrame:
+    def fit_transform(self, X: ArrayLike, y: ArrayLike | None = None) -> pd.DataFrame:
         """
         Fit and transform the data, optionally computing metrics.
 
-        If compute_metrics=True was set at construction, batch effect
-        metrics are computed and cached in metrics_ property.
+        If ``compute_metrics=True`` was set at construction, batch effect
+        metrics are computed and cached in the ``metrics_`` property.
 
         Parameters
         ----------
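Continuing the hypothetical session above, the nested dictionary documented in this hunk reads like this:

```python
metrics = cb.compute_batch_metrics(X, pca_components=2, k_neighbors=[5, 10])
sil = metrics["batch_effect"]["silhouette"]
print(f"batch silhouette: before={sil['before']:.3f}, after={sil['after']:.3f}")
print("k-NN preservation:", metrics["preservation"]["knn"])
```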
@@ -1452,19 +1394,21 @@ class ComBat(BaseEstimator, TransformerMixin):
         return X_transformed
 
     def plot_transformation(
-
-
-
-
-
-
-
-
-
-
-
-
-
+        self,
+        X: ArrayLike,
+        *,
+        reduction_method: Literal["pca", "tsne", "umap"] = "pca",
+        n_components: Literal[2, 3] = 2,
+        plot_type: Literal["static", "interactive"] = "static",
+        figsize: tuple[int, int] = (12, 5),
+        alpha: float = 0.7,
+        point_size: int = 50,
+        cmap: str = "Set1",
+        title: str | None = None,
+        show_legend: bool = True,
+        return_embeddings: bool = False,
+        **reduction_kwargs,
+    ) -> Any | tuple[Any, dict[str, FloatArray]]:
         """
         Visualize the ComBat transformation effect using dimensionality reduction.
 
@@ -1489,28 +1433,32 @@ class ComBat(BaseEstimator, TransformerMixin):
 
         return_embeddings : bool, default=False
             If `True`, return embeddings along with the plot.
-
+
         **reduction_kwargs : dict
             Additional parameters for reduction methods.
-
+
         Returns
         -------
         fig : matplotlib.figure.Figure or plotly.graph_objects.Figure
             The figure object containing the plots.
-
+
         embeddings : dict, optional
             If `return_embeddings=True`, dictionary with:
             - `'original'`: embedding of original data
             - `'transformed'`: embedding of ComBat-transformed data
         """
         if not hasattr(self._model, "_gamma_star"):
-            raise ValueError(
+            raise ValueError(
+                "This ComBat instance is not fitted yet. Call 'fit' before 'plot_transformation'."
+            )
 
         if n_components not in [2, 3]:
             raise ValueError(f"n_components must be 2 or 3, got {n_components}")
-        if reduction_method not in ['pca', 'tsne', 'umap']:
-            raise ValueError(
-        if plot_type not in ['static', 'interactive']:
+        if reduction_method not in ["pca", "tsne", "umap"]:
+            raise ValueError(
+                f"reduction_method must be 'pca', 'tsne', or 'umap', got '{reduction_method}'"
+            )
+        if plot_type not in ["static", "interactive"]:
             raise ValueError(f"plot_type must be 'static' or 'interactive', got '{plot_type}'")
 
         if not isinstance(X, pd.DataFrame):
@@ -1526,16 +1474,16 @@ class ComBat(BaseEstimator, TransformerMixin):
         X_np = X.values
         X_trans_np = X_transformed.values
 
-        if reduction_method == 'pca':
+        if reduction_method == "pca":
             reducer_orig = PCA(n_components=n_components, **reduction_kwargs)
             reducer_trans = PCA(n_components=n_components, **reduction_kwargs)
-        elif reduction_method == 'tsne':
-            tsne_params = {
+        elif reduction_method == "tsne":
+            tsne_params = {"perplexity": 30, "max_iter": 1000, "random_state": 42}
             tsne_params.update(reduction_kwargs)
             reducer_orig = TSNE(n_components=n_components, **tsne_params)
             reducer_trans = TSNE(n_components=n_components, **tsne_params)
         else:
-            umap_params = {
+            umap_params = {"random_state": 42}
             umap_params.update(reduction_kwargs)
             reducer_orig = umap.UMAP(n_components=n_components, **umap_params)
             reducer_trans = umap.UMAP(n_components=n_components, **umap_params)
@@ -1543,40 +1491,52 @@ class ComBat(BaseEstimator, TransformerMixin):
         X_embedded_orig = reducer_orig.fit_transform(X_np)
         X_embedded_trans = reducer_trans.fit_transform(X_trans_np)
 
-        if plot_type == 'static':
+        if plot_type == "static":
             fig = self._create_static_plot(
-                X_embedded_orig,
-
-
+                X_embedded_orig,
+                X_embedded_trans,
+                batch_vec,
+                reduction_method,
+                n_components,
+                figsize,
+                alpha,
+                point_size,
+                cmap,
+                title,
+                show_legend,
             )
         else:
             fig = self._create_interactive_plot(
-                X_embedded_orig,
-
+                X_embedded_orig,
+                X_embedded_trans,
+                batch_vec,
+                reduction_method,
+                n_components,
+                cmap,
+                title,
+                show_legend,
             )
 
         if return_embeddings:
-            embeddings = {
-                'original': X_embedded_orig,
-                'transformed': X_embedded_trans
-            }
+            embeddings = {"original": X_embedded_orig, "transformed": X_embedded_trans}
             return fig, embeddings
         else:
             return fig
 
     def _create_static_plot(
-
-
-
-
-
-
-
-
-
-
-
-
+        self,
+        X_orig: FloatArray,
+        X_trans: FloatArray,
+        batch_labels: pd.Series,
+        method: str,
+        n_components: int,
+        figsize: tuple[int, int],
+        alpha: float,
+        point_size: int,
+        cmap: str,
+        title: str | None,
+        show_legend: bool,
+    ) -> Any:
         """Create static plots using matplotlib."""
 
         fig = plt.figure(figsize=figsize)
@@ -1587,119 +1547,130 @@ class ComBat(BaseEstimator, TransformerMixin):
         if n_batches <= 10:
             colors = matplotlib.colormaps.get_cmap(cmap)(np.linspace(0, 1, n_batches))
         else:
-            colors = matplotlib.colormaps.get_cmap('tab20')(np.linspace(0, 1, n_batches))
+            colors = matplotlib.colormaps.get_cmap("tab20")(np.linspace(0, 1, n_batches))
 
         if n_components == 2:
             ax1 = plt.subplot(1, 2, 1)
             ax2 = plt.subplot(1, 2, 2)
         else:
-            ax1 = fig.add_subplot(121, projection='3d')
-            ax2 = fig.add_subplot(122, projection='3d')
+            ax1 = fig.add_subplot(121, projection="3d")
+            ax2 = fig.add_subplot(122, projection="3d")
 
         for i, batch in enumerate(unique_batches):
             mask = batch_labels == batch
             if n_components == 2:
                 ax1.scatter(
-                    X_orig[mask, 0], X_orig[mask, 1],
+                    X_orig[mask, 0],
+                    X_orig[mask, 1],
                     c=[colors[i]],
                     s=point_size,
-                    alpha=alpha,
-                    label=f'Batch {batch}',
-                    edgecolors='black',
-                    linewidth=0.5
+                    alpha=alpha,
+                    label=f"Batch {batch}",
+                    edgecolors="black",
+                    linewidth=0.5,
                 )
             else:
                 ax1.scatter(
-                    X_orig[mask, 0], X_orig[mask, 1], X_orig[mask, 2],
+                    X_orig[mask, 0],
+                    X_orig[mask, 1],
+                    X_orig[mask, 2],
                     c=[colors[i]],
                     s=point_size,
-                    alpha=alpha,
-                    label=f'Batch {batch}',
-                    edgecolors='black',
-                    linewidth=0.5
+                    alpha=alpha,
+                    label=f"Batch {batch}",
+                    edgecolors="black",
+                    linewidth=0.5,
                 )
 
-        ax1.set_title(f'Before ComBat correction\n({method.upper()})')
-        ax1.set_xlabel(f'{method.upper()}1')
-        ax1.set_ylabel(f'{method.upper()}2')
+        ax1.set_title(f"Before ComBat correction\n({method.upper()})")
+        ax1.set_xlabel(f"{method.upper()}1")
+        ax1.set_ylabel(f"{method.upper()}2")
         if n_components == 3:
-            ax1.set_zlabel(f'{method.upper()}3')
+            ax1.set_zlabel(f"{method.upper()}3")
 
         for i, batch in enumerate(unique_batches):
             mask = batch_labels == batch
             if n_components == 2:
                 ax2.scatter(
-                    X_trans[mask, 0], X_trans[mask, 1],
+                    X_trans[mask, 0],
+                    X_trans[mask, 1],
                     c=[colors[i]],
                     s=point_size,
-                    alpha=alpha,
-                    label=f'Batch {batch}',
-                    edgecolors='black',
-                    linewidth=0.5
+                    alpha=alpha,
+                    label=f"Batch {batch}",
+                    edgecolors="black",
+                    linewidth=0.5,
                 )
             else:
                 ax2.scatter(
-                    X_trans[mask, 0], X_trans[mask, 1], X_trans[mask, 2],
+                    X_trans[mask, 0],
+                    X_trans[mask, 1],
+                    X_trans[mask, 2],
                     c=[colors[i]],
                     s=point_size,
-                    alpha=alpha,
-                    label=f'Batch {batch}',
-                    edgecolors='black',
-                    linewidth=0.5
+                    alpha=alpha,
+                    label=f"Batch {batch}",
+                    edgecolors="black",
+                    linewidth=0.5,
                 )
 
-        ax2.set_title(f'After ComBat correction\n({method.upper()})')
-        ax2.set_xlabel(f'{method.upper()}1')
-        ax2.set_ylabel(f'{method.upper()}2')
+        ax2.set_title(f"After ComBat correction\n({method.upper()})")
+        ax2.set_xlabel(f"{method.upper()}1")
+        ax2.set_ylabel(f"{method.upper()}2")
         if n_components == 3:
-            ax2.set_zlabel(f'{method.upper()}3')
+            ax2.set_zlabel(f"{method.upper()}3")
 
         if show_legend and n_batches <= 20:
-            ax2.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
+            ax2.legend(bbox_to_anchor=(1.05, 1), loc="upper left")
 
         if title is None:
-            title = f'ComBat correction effect visualized with {method.upper()}'
-        fig.suptitle(title, fontsize=14, fontweight='bold')
+            title = f"ComBat correction effect visualized with {method.upper()}"
+        fig.suptitle(title, fontsize=14, fontweight="bold")
 
         plt.tight_layout()
         return fig
 
     def _create_interactive_plot(
-
-
-
-
-
-
-
-
-
+        self,
+        X_orig: FloatArray,
+        X_trans: FloatArray,
+        batch_labels: pd.Series,
+        method: str,
+        n_components: int,
+        cmap: str,
+        title: str | None,
+        show_legend: bool,
+    ) -> Any:
         """Create interactive plots using plotly."""
         if n_components == 2:
             fig = make_subplots(
-                rows=1, cols=2,
+                rows=1,
+                cols=2,
                 subplot_titles=(
-                    f'Before ComBat correction ({method.upper()})',
-                    f'After ComBat correction ({method.upper()})'
-                )
+                    f"Before ComBat correction ({method.upper()})",
+                    f"After ComBat correction ({method.upper()})",
+                ),
             )
         else:
             fig = make_subplots(
-                rows=1, cols=2,
-                specs=[[{'type': 'scatter3d'}, {'type': 'scatter3d'}]],
+                rows=1,
+                cols=2,
+                specs=[[{"type": "scatter3d"}, {"type": "scatter3d"}]],
                 subplot_titles=(
-                    f'Before ComBat correction ({method.upper()})',
-                    f'After ComBat correction ({method.upper()})'
-                )
+                    f"Before ComBat correction ({method.upper()})",
+                    f"After ComBat correction ({method.upper()})",
+                ),
             )
 
         unique_batches = batch_labels.drop_duplicates()
 
         n_batches = len(unique_batches)
         cmap_func = matplotlib.colormaps.get_cmap(cmap)
-        color_list = [mcolors.to_hex(cmap_func(i / max(n_batches - 1, 1))) for i in range(n_batches)]
+        color_list = [
+            mcolors.to_hex(cmap_func(i / max(n_batches - 1, 1))) for i in range(n_batches)
+        ]
 
-        batch_to_color = dict(zip(unique_batches, color_list))
+        batch_to_color = dict(zip(unique_batches, color_list, strict=True))
 
         for batch in unique_batches:
             mask = batch_labels == batch
@@ -1707,72 +1678,86 @@ class ComBat(BaseEstimator, TransformerMixin):
             if n_components == 2:
                 fig.add_trace(
                     go.Scatter(
-                        x=X_orig[mask, 0],
-
-
-
-
-
-
-
-
+                        x=X_orig[mask, 0],
+                        y=X_orig[mask, 1],
+                        mode="markers",
+                        name=f"Batch {batch}",
+                        marker={
+                            "size": 8,
+                            "color": batch_to_color[batch],
+                            "line": {"width": 1, "color": "black"},
+                        },
+                        showlegend=False,
+                    ),
+                    row=1,
+                    col=1,
                 )
 
                 fig.add_trace(
                     go.Scatter(
-                        x=X_trans[mask, 0],
-
-
-
-
-
-
-
-
+                        x=X_trans[mask, 0],
+                        y=X_trans[mask, 1],
+                        mode="markers",
+                        name=f"Batch {batch}",
+                        marker={
+                            "size": 8,
+                            "color": batch_to_color[batch],
+                            "line": {"width": 1, "color": "black"},
+                        },
+                        showlegend=show_legend,
+                    ),
+                    row=1,
+                    col=2,
                )
             else:
                 fig.add_trace(
                     go.Scatter3d(
-                        x=X_orig[mask, 0],
-
-
-
-
-
-
-
-
-
+                        x=X_orig[mask, 0],
+                        y=X_orig[mask, 1],
+                        z=X_orig[mask, 2],
+                        mode="markers",
+                        name=f"Batch {batch}",
+                        marker={
+                            "size": 5,
+                            "color": batch_to_color[batch],
+                            "line": {"width": 0.5, "color": "black"},
+                        },
+                        showlegend=False,
+                    ),
+                    row=1,
+                    col=1,
                 )
 
                 fig.add_trace(
                     go.Scatter3d(
-                        x=X_trans[mask, 0],
-
-
-
-
-
-
-
-
-
+                        x=X_trans[mask, 0],
+                        y=X_trans[mask, 1],
+                        z=X_trans[mask, 2],
+                        mode="markers",
+                        name=f"Batch {batch}",
+                        marker={
+                            "size": 5,
+                            "color": batch_to_color[batch],
+                            "line": {"width": 0.5, "color": "black"},
+                        },
+                        showlegend=show_legend,
+                    ),
+                    row=1,
+                    col=2,
                 )
 
         if title is None:
-            title = f'ComBat correction effect visualized with {method.upper()}'
+            title = f"ComBat correction effect visualized with {method.upper()}"
 
         fig.update_layout(
             title=title,
             title_font_size=16,
             height=600,
             showlegend=show_legend,
-            hovermode='closest',
+            hovermode="closest",
         )
 
-        axis_labels = [f'{method.upper()}{i+1}' for i in range(n_components)]
+        axis_labels = [f"{method.upper()}{i + 1}" for i in range(n_components)]
 
         if n_components == 2:
             fig.update_xaxes(title_text=axis_labels[0])
@@ -1781,7 +1766,7 @@ class ComBat(BaseEstimator, TransformerMixin):
             fig.update_scenes(
                 xaxis_title=axis_labels[0],
                 yaxis_title=axis_labels[1],
-                zaxis_title=axis_labels[2]
+                zaxis_title=axis_labels[2],
             )
 
         return fig
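Finally, continuing the same hypothetical session, the visualization entry point whose signature and internals appear in the hunks above:

```python
fig = cb.plot_transformation(X, reduction_method="pca", plot_type="static")
fig.savefig("combat_pca_before_after.png")  # matplotlib Figure for plot_type="static"
```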