combatlearn 1.0.0__py3-none-any.whl → 1.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- combatlearn/__init__.py +2 -2
- combatlearn/combat.py +998 -272
- {combatlearn-1.0.0.dist-info → combatlearn-1.1.1.dist-info}/METADATA +28 -26
- combatlearn-1.1.1.dist-info/RECORD +7 -0
- combatlearn-1.0.0.dist-info/RECORD +0 -7
- {combatlearn-1.0.0.dist-info → combatlearn-1.1.1.dist-info}/WHEEL +0 -0
- {combatlearn-1.0.0.dist-info → combatlearn-1.1.1.dist-info}/licenses/LICENSE +0 -0
- {combatlearn-1.0.0.dist-info → combatlearn-1.1.1.dist-info}/top_level.txt +0 -0
combatlearn/combat.py
CHANGED
@@ -8,28 +8,552 @@
 `ComBat` makes the model compatible with scikit-learn by stashing
 the batch (and optional covariates) at construction.
 """
+
 from __future__ import annotations
 
+import warnings
+from typing import Any, Literal
+
+import matplotlib
+import matplotlib.colors as mcolors
+import matplotlib.pyplot as plt
 import numpy as np
 import numpy.linalg as la
+import numpy.typing as npt
 import pandas as pd
+import plotly.graph_objects as go
+import umap
+from plotly.subplots import make_subplots
+from scipy.spatial.distance import pdist
+from scipy.stats import chi2, levene, spearmanr
 from sklearn.base import BaseEstimator, TransformerMixin
 from sklearn.decomposition import PCA
 from sklearn.manifold import TSNE
-import
-
-import matplotlib.colors as mcolors
-from typing import Literal, Optional, Union, Dict, Tuple, Any
-import numpy.typing as npt
-import warnings
-import umap
-import plotly.graph_objects as go
-from plotly.subplots import make_subplots
+from sklearn.metrics import davies_bouldin_score, silhouette_score
+from sklearn.neighbors import NearestNeighbors
 
-ArrayLike =
+ArrayLike = pd.DataFrame | pd.Series | npt.NDArray[Any]
 FloatArray = npt.NDArray[np.float64]
 
 
+def _compute_pca_embedding(
+    X_before: np.ndarray,
+    X_after: np.ndarray,
+    n_components: int,
+) -> tuple[np.ndarray, np.ndarray, PCA]:
+    """
+    Compute PCA embeddings for both datasets.
+
+    Fits PCA on X_before and applies to both datasets.
+
+    Parameters
+    ----------
+    X_before : np.ndarray
+        Original data before correction.
+    X_after : np.ndarray
+        Corrected data.
+    n_components : int
+        Number of PCA components.
+
+    Returns
+    -------
+    X_before_pca : np.ndarray
+        PCA-transformed original data.
+    X_after_pca : np.ndarray
+        PCA-transformed corrected data.
+    pca : PCA
+        Fitted PCA model.
+    """
+    n_components = min(n_components, X_before.shape[1], X_before.shape[0] - 1)
+    pca = PCA(n_components=n_components, random_state=42)
+    X_before_pca = pca.fit_transform(X_before)
+    X_after_pca = pca.transform(X_after)
+    return X_before_pca, X_after_pca, pca
+
+
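A point worth pausing on in `_compute_pca_embedding`: the PCA basis is fit on the uncorrected data only and then applied to both matrices, so the before/after embeddings share one set of axes and their coordinates are directly comparable. A minimal standalone sketch of the same idea (synthetic data, not part of the package):

import numpy as np
from sklearn.decomposition import PCA

rng = np.random.default_rng(0)
X_before = rng.normal(size=(100, 20))
X_after = X_before - X_before.mean(axis=0)  # stand-in for a "corrected" matrix

# Fit on the original data only, then project both into that one basis.
pca = PCA(n_components=2, random_state=42)
emb_before = pca.fit_transform(X_before)
emb_after = pca.transform(X_after)
print(emb_before.shape, emb_after.shape)  # (100, 2) (100, 2)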
+def _silhouette_batch(X: np.ndarray, batch_labels: np.ndarray) -> float:
+    """
+    Compute silhouette coefficient using batch as cluster labels.
+
+    Lower values after correction indicate better batch mixing.
+    Range: [-1, 1], where -1 = batch mixing, 1 = batch separation.
+
+    Parameters
+    ----------
+    X : np.ndarray
+        Data matrix.
+    batch_labels : np.ndarray
+        Batch labels for each sample.
+
+    Returns
+    -------
+    float
+        Silhouette coefficient.
+    """
+    unique_batches = np.unique(batch_labels)
+    if len(unique_batches) < 2:
+        return 0.0
+    try:
+        return silhouette_score(X, batch_labels, metric="euclidean")
+    except Exception:
+        return 0.0
+
+
+def _davies_bouldin_batch(X: np.ndarray, batch_labels: np.ndarray) -> float:
+    """
+    Compute Davies-Bouldin index using batch labels.
+
+    Lower values indicate better batch mixing.
+    Range: [0, inf), 0 = perfect batch overlap.
+
+    Parameters
+    ----------
+    X : np.ndarray
+        Data matrix.
+    batch_labels : np.ndarray
+        Batch labels for each sample.
+
+    Returns
+    -------
+    float
+        Davies-Bouldin index.
+    """
+    unique_batches = np.unique(batch_labels)
+    if len(unique_batches) < 2:
+        return 0.0
+    try:
+        return davies_bouldin_score(X, batch_labels)
+    except Exception:
+        return 0.0
+
+
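Both helpers reuse scikit-learn's clustering scores with batch labels standing in for cluster labels, so a silhouette near 0 after correction means the batches are no longer separable. (For Davies-Bouldin used this way, overlapping batches drive the index up, so the docstring's "lower = better mixing" arguably reads inverted; treat the before/after direction with care.) A minimal sketch with scikit-learn directly (synthetic data, not package code):

import numpy as np
from sklearn.metrics import davies_bouldin_score, silhouette_score

rng = np.random.default_rng(0)
labels = np.repeat([0, 1], 50)

X_separated = rng.normal(size=(100, 5)) + labels[:, None] * 4.0  # strong batch shift
X_mixed = rng.normal(size=(100, 5))                              # no batch effect

for name, X in [("separated", X_separated), ("mixed", X_mixed)]:
    print(name,
          round(silhouette_score(X, labels), 3),       # near 0 when mixed
          round(davies_bouldin_score(X, labels), 3))   # large when mixed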
+def _kbet_score(
+    X: np.ndarray,
+    batch_labels: np.ndarray,
+    k0: int,
+    alpha: float = 0.05,
+) -> tuple[float, float]:
+    """
+    Compute kBET (k-nearest neighbor Batch Effect Test) acceptance rate.
+
+    Tests if local batch proportions match global batch proportions.
+    Higher acceptance rate = better batch mixing.
+
+    Reference: Buttner et al. (2019) Nature Methods
+
+    Parameters
+    ----------
+    X : np.ndarray
+        Data matrix.
+    batch_labels : np.ndarray
+        Batch labels for each sample.
+    k0 : int
+        Neighborhood size.
+    alpha : float
+        Significance level for chi-squared test.
+
+    Returns
+    -------
+    acceptance_rate : float
+        Fraction of samples where H0 (uniform mixing) is accepted.
+    mean_stat : float
+        Mean chi-squared statistic across samples.
+    """
+    n_samples = X.shape[0]
+    unique_batches, batch_counts = np.unique(batch_labels, return_counts=True)
+    n_batches = len(unique_batches)
+
+    if n_batches < 2:
+        return 1.0, 0.0
+
+    global_freq = batch_counts / n_samples
+    k0 = min(k0, n_samples - 1)
+
+    nn = NearestNeighbors(n_neighbors=k0 + 1, algorithm="auto")
+    nn.fit(X)
+    _, indices = nn.kneighbors(X)
+
+    chi2_stats = []
+    p_values = []
+    batch_to_idx = {b: i for i, b in enumerate(unique_batches)}
+
+    for i in range(n_samples):
+        neighbors = indices[i, 1 : k0 + 1]
+        neighbor_batches = batch_labels[neighbors]
+
+        observed = np.zeros(n_batches)
+        for nb in neighbor_batches:
+            observed[batch_to_idx[nb]] += 1
+
+        expected = global_freq * k0
+
+        mask = expected > 0
+        if mask.sum() < 2:
+            continue
+
+        stat = np.sum((observed[mask] - expected[mask]) ** 2 / expected[mask])
+        df = max(1, mask.sum() - 1)
+        p_val = 1 - chi2.cdf(stat, df)
+
+        chi2_stats.append(stat)
+        p_values.append(p_val)
+
+    if len(p_values) == 0:
+        return 1.0, 0.0
+
+    acceptance_rate = np.mean(np.array(p_values) > alpha)
+    mean_stat = np.mean(chi2_stats)
+
+    return acceptance_rate, mean_stat
+
+
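The heart of `_kbet_score` is a per-sample chi-squared goodness-of-fit test: the batch composition of each point's k0 nearest neighbors is compared against the global batch frequencies, and the acceptance rate is the fraction of neighborhoods statistically indistinguishable from global mixing. A worked single-neighborhood sketch (toy numbers, not package code):

import numpy as np
from scipy.stats import chi2

# Global batch frequencies and one sample's k0 = 20 nearest neighbors.
global_freq = np.array([0.5, 0.3, 0.2])
observed = np.array([12, 5, 3])            # neighbor counts per batch
expected = global_freq * observed.sum()    # (10, 6, 4) under perfect mixing

stat = np.sum((observed - expected) ** 2 / expected)
p_val = 1 - chi2.cdf(stat, df=len(observed) - 1)
print(round(stat, 2), round(p_val, 2))     # p ~ 0.66 > 0.05: neighborhood accepts H0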
+def _find_sigma(distances: np.ndarray, target_perplexity: float, tol: float = 1e-5) -> float:
+    """
+    Binary search for sigma to achieve target perplexity.
+
+    Used in LISI computation.
+
+    Parameters
+    ----------
+    distances : np.ndarray
+        Distances to neighbors.
+    target_perplexity : float
+        Target perplexity value.
+    tol : float
+        Tolerance for convergence.
+
+    Returns
+    -------
+    float
+        Sigma value.
+    """
+    target_H = np.log2(target_perplexity + 1e-10)
+
+    sigma_min, sigma_max = 1e-10, 1e10
+    sigma = 1.0
+
+    for _ in range(50):
+        P = np.exp(-(distances**2) / (2 * sigma**2 + 1e-10))
+        P_sum = P.sum()
+        if P_sum < 1e-10:
+            sigma = (sigma + sigma_max) / 2
+            continue
+        P = P / P_sum
+        P = np.clip(P, 1e-10, 1.0)
+        H = -np.sum(P * np.log2(P))
+
+        if abs(H - target_H) < tol:
+            break
+        elif target_H > H:
+            sigma_min = sigma
+        else:
+            sigma_max = sigma
+        sigma = (sigma_min + sigma_max) / 2
+
+    return sigma
+
+
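`_find_sigma` bisects on the Gaussian bandwidth until the neighbor distribution's entropy hits log2(perplexity); perplexity is 2^H, the effective number of neighbors. The bisection relies on entropy increasing monotonically with sigma, which a quick sketch confirms (illustrative only):

import numpy as np

def entropy_bits(distances: np.ndarray, sigma: float) -> float:
    # Entropy of the Gaussian-kernel neighbor distribution, in bits.
    p = np.exp(-(distances ** 2) / (2 * sigma ** 2))
    p = np.clip(p / p.sum(), 1e-10, 1.0)
    return float(-(p * np.log2(p)).sum())

d = np.linspace(0.1, 3.0, 90)  # distances to 90 neighbors
for sigma in (0.3, 1.0, 3.0):
    H = entropy_bits(d, sigma)
    print(sigma, round(H, 2), round(2 ** H, 1))  # perplexity 2**H grows with sigma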
+def _lisi_score(
+    X: np.ndarray,
+    batch_labels: np.ndarray,
+    perplexity: int = 30,
+) -> float:
+    """
+    Compute mean Local Inverse Simpson's Index (LISI).
+
+    Range: [1, n_batches], where n_batches = perfect mixing.
+    Higher = better batch mixing.
+
+    Reference: Korsunsky et al. (2019) Nature Methods (Harmony paper)
+
+    Parameters
+    ----------
+    X : np.ndarray
+        Data matrix.
+    batch_labels : np.ndarray
+        Batch labels for each sample.
+    perplexity : int
+        Perplexity for Gaussian kernel.
+
+    Returns
+    -------
+    float
+        Mean LISI score.
+    """
+    n_samples = X.shape[0]
+    unique_batches = np.unique(batch_labels)
+    n_batches = len(unique_batches)
+    batch_to_idx = {b: i for i, b in enumerate(unique_batches)}
+
+    if n_batches < 2:
+        return 1.0
+
+    k = min(3 * perplexity, n_samples - 1)
+
+    nn = NearestNeighbors(n_neighbors=k + 1, algorithm="auto")
+    nn.fit(X)
+    distances, indices = nn.kneighbors(X)
+
+    distances = distances[:, 1:]
+    indices = indices[:, 1:]
+
+    lisi_values = []
+
+    for i in range(n_samples):
+        sigma = _find_sigma(distances[i], perplexity)
+
+        P = np.exp(-(distances[i] ** 2) / (2 * sigma**2 + 1e-10))
+        P_sum = P.sum()
+        if P_sum < 1e-10:
+            lisi_values.append(1.0)
+            continue
+        P = P / P_sum
+
+        neighbor_batches = batch_labels[indices[i]]
+        batch_probs = np.zeros(n_batches)
+        for j, nb in enumerate(neighbor_batches):
+            batch_probs[batch_to_idx[nb]] += P[j]
+
+        simpson = np.sum(batch_probs**2)
+        lisi = n_batches if simpson < 1e-10 else 1.0 / simpson
+        lisi_values.append(lisi)
+
+    return np.mean(lisi_values)
+
+
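LISI is the inverse Simpson index of the distance-weighted batch mixture around each point: with B batches it runs from 1 (neighbors all from one batch) to B (all batches equally weighted). Two toy neighborhoods make the endpoints concrete:

import numpy as np

def inverse_simpson(batch_probs: np.ndarray) -> float:
    # batch_probs: weight of each batch among a point's neighbors, summing to 1.
    return 1.0 / np.sum(batch_probs ** 2)

print(inverse_simpson(np.array([1.0, 0.0])))  # 1.0   -> neighbors from one batch
print(inverse_simpson(np.array([0.5, 0.5])))  # 2.0   -> perfect two-batch mixing
print(inverse_simpson(np.array([0.7, 0.3])))  # ~1.72 -> partial mixing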
+def _variance_ratio(X: np.ndarray, batch_labels: np.ndarray) -> float:
+    """
+    Compute between-batch to within-batch variance ratio.
+
+    Similar to F-statistic in one-way ANOVA.
+    Lower ratio after correction = better batch effect removal.
+
+    Parameters
+    ----------
+    X : np.ndarray
+        Data matrix.
+    batch_labels : np.ndarray
+        Batch labels for each sample.
+
+    Returns
+    -------
+    float
+        Variance ratio (between/within).
+    """
+    unique_batches = np.unique(batch_labels)
+    n_batches = len(unique_batches)
+    n_samples = X.shape[0]
+
+    if n_batches < 2:
+        return 0.0
+
+    grand_mean = np.mean(X, axis=0)
+
+    between_var = 0.0
+    within_var = 0.0
+
+    for batch in unique_batches:
+        mask = batch_labels == batch
+        n_b = np.sum(mask)
+        X_batch = X[mask]
+        batch_mean = np.mean(X_batch, axis=0)
+
+        between_var += n_b * np.sum((batch_mean - grand_mean) ** 2)
+        within_var += np.sum((X_batch - batch_mean) ** 2)
+
+    between_var /= n_batches - 1
+    within_var /= n_samples - n_batches
+
+    if within_var < 1e-10:
+        return 0.0
+
+    return between_var / within_var
+
+
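`_variance_ratio` is the one-way ANOVA F-statistic with batch as the grouping factor, summed across features: between-batch mean squares over within-batch mean squares. For a single feature it reduces to scipy's `f_oneway`, which makes a handy cross-check (synthetic data; the scipy call is an independent implementation, not the package's):

import numpy as np
from scipy.stats import f_oneway

rng = np.random.default_rng(0)
a = rng.normal(0.0, 1.0, size=200)  # feature values in batch A
b = rng.normal(1.5, 1.0, size=200)  # batch B with a shifted mean

stat, p = f_oneway(a, b)  # between/within variance ratio for one feature
print(round(stat, 1), p)  # large F -> strong batch effect on this feature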
+def _knn_preservation(
+    X_before: np.ndarray,
+    X_after: np.ndarray,
+    k_values: list[int],
+    n_jobs: int = 1,
+) -> dict[int, float]:
+    """
+    Compute fraction of k-nearest neighbors preserved after correction.
+
+    Range: [0, 1], where 1 = perfect preservation.
+    Higher = better biological structure preservation.
+
+    Parameters
+    ----------
+    X_before : np.ndarray
+        Original data.
+    X_after : np.ndarray
+        Corrected data.
+    k_values : list of int
+        Values of k for k-NN.
+    n_jobs : int
+        Number of parallel jobs.
+
+    Returns
+    -------
+    dict
+        Mapping from k to preservation fraction.
+    """
+    results = {}
+    max_k = max(k_values)
+    max_k = min(max_k, X_before.shape[0] - 1)
+
+    nn_before = NearestNeighbors(n_neighbors=max_k + 1, algorithm="auto", n_jobs=n_jobs)
+    nn_before.fit(X_before)
+    _, indices_before = nn_before.kneighbors(X_before)
+
+    nn_after = NearestNeighbors(n_neighbors=max_k + 1, algorithm="auto", n_jobs=n_jobs)
+    nn_after.fit(X_after)
+    _, indices_after = nn_after.kneighbors(X_after)
+
+    for k in k_values:
+        if k > max_k:
+            results[k] = 0.0
+            continue
+
+        overlaps = []
+        for i in range(X_before.shape[0]):
+            neighbors_before = set(indices_before[i, 1 : k + 1])
+            neighbors_after = set(indices_after[i, 1 : k + 1])
+            overlap = len(neighbors_before & neighbors_after) / k
+            overlaps.append(overlap)
+
+        results[k] = np.mean(overlaps)
+
+    return results
+
+
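k-NN preservation asks what fraction of each sample's k nearest neighbors survive the correction; a rigid shift of the whole matrix preserves every neighborhood, while unrelated data destroys them. A minimal sketch calling the helper directly (it is a private function, so the import below is an assumption that may break across versions):

import numpy as np
from combatlearn.combat import _knn_preservation  # private helper; import path may change

rng = np.random.default_rng(0)
X = rng.normal(size=(60, 8))

# A rigid translation leaves every neighborhood intact.
print(_knn_preservation(X, X + 5.0, k_values=[5, 10]))  # {5: 1.0, 10: 1.0}
# Unrelated data destroys neighborhoods almost entirely.
print(_knn_preservation(X, rng.normal(size=(60, 8)), k_values=[5, 10]))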
+def _pairwise_distance_correlation(
+    X_before: np.ndarray,
+    X_after: np.ndarray,
+    subsample: int = 1000,
+    random_state: int = 42,
+) -> float:
+    """
+    Compute Spearman correlation of pairwise distances.
+
+    Range: [-1, 1], where 1 = perfect rank preservation.
+    Higher = better relative relationship preservation.
+
+    Parameters
+    ----------
+    X_before : np.ndarray
+        Original data.
+    X_after : np.ndarray
+        Corrected data.
+    subsample : int
+        Maximum samples to use (for efficiency).
+    random_state : int
+        Random seed for subsampling.
+
+    Returns
+    -------
+    float
+        Spearman correlation coefficient.
+    """
+    n_samples = X_before.shape[0]
+
+    if n_samples > subsample:
+        rng = np.random.default_rng(random_state)
+        idx = rng.choice(n_samples, subsample, replace=False)
+        X_before = X_before[idx]
+        X_after = X_after[idx]
+
+    dist_before = pdist(X_before, metric="euclidean")
+    dist_after = pdist(X_after, metric="euclidean")
+
+    if len(dist_before) == 0:
+        return 1.0
+
+    corr, _ = spearmanr(dist_before, dist_after)
+
+    if np.isnan(corr):
+        return 1.0
+
+    return corr
+
+
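`_pairwise_distance_correlation` checks rank preservation: the Spearman correlation between the condensed `pdist` vectors is 1.0 under any monotone distortion of distances, such as uniform scaling. A small sketch of that invariance (synthetic data):

import numpy as np
from scipy.spatial.distance import pdist
from scipy.stats import spearmanr

rng = np.random.default_rng(0)
X = rng.normal(size=(50, 4))

corr_scaled, _ = spearmanr(pdist(X), pdist(3.0 * X))  # uniform scaling keeps ranks
corr_random, _ = spearmanr(pdist(X), pdist(rng.normal(size=(50, 4))))
print(round(corr_scaled, 3), round(corr_random, 3))   # 1.0 vs roughly 0.0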
+def _mean_centroid_distance(X: np.ndarray, batch_labels: np.ndarray) -> float:
+    """
+    Compute mean pairwise Euclidean distance between batch centroids.
+
+    Lower after correction = better batch alignment.
+
+    Parameters
+    ----------
+    X : np.ndarray
+        Data matrix.
+    batch_labels : np.ndarray
+        Batch labels for each sample.
+
+    Returns
+    -------
+    float
+        Mean pairwise distance between centroids.
+    """
+    unique_batches = np.unique(batch_labels)
+    n_batches = len(unique_batches)
+
+    if n_batches < 2:
+        return 0.0
+
+    centroids = []
+    for batch in unique_batches:
+        mask = batch_labels == batch
+        centroid = np.mean(X[mask], axis=0)
+        centroids.append(centroid)
+
+    centroids = np.array(centroids)
+    distances = pdist(centroids, metric="euclidean")
+
+    return np.mean(distances)
+
+
+def _levene_median_statistic(X: np.ndarray, batch_labels: np.ndarray) -> float:
+    """
+    Compute median Levene test statistic across features.
+
+    Lower statistic = more homogeneous variances across batches.
+
+    Parameters
+    ----------
+    X : np.ndarray
+        Data matrix.
+    batch_labels : np.ndarray
+        Batch labels for each sample.
+
+    Returns
+    -------
+    float
+        Median Levene test statistic.
+    """
+    unique_batches = np.unique(batch_labels)
+    if len(unique_batches) < 2:
+        return 0.0
+
+    levene_stats = []
+    for j in range(X.shape[1]):
+        groups = [X[batch_labels == b, j] for b in unique_batches]
+        groups = [g for g in groups if len(g) > 0]
+        if len(groups) < 2:
+            continue
+        try:
+            stat, _ = levene(*groups, center="median")
+            if not np.isnan(stat):
+                levene_stats.append(stat)
+        except Exception:
+            continue
+
+    if len(levene_stats) == 0:
+        return 0.0
+
+    return np.median(levene_stats)
+
+
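`levene(..., center="median")` is the Brown-Forsythe variant of Levene's test, robust to heavy-tailed features, and taking the median statistic across features keeps a few pathological columns from dominating. A one-feature sketch (synthetic data):

import numpy as np
from scipy.stats import levene

rng = np.random.default_rng(0)
a = rng.normal(0.0, 1.0, size=100)  # batch A
b = rng.normal(0.0, 3.0, size=100)  # batch B: same mean, inflated variance

stat, p = levene(a, b, center="median")
print(round(stat, 2), p)  # large statistic -> heterogeneous variances across batches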
 class ComBatModel:
     """ComBat algorithm.
 
@@ -59,24 +583,24 @@ class ComBatModel:
         method: Literal["johnson", "fortin", "chen"] = "johnson",
         parametric: bool = True,
         mean_only: bool = False,
-        reference_batch:
+        reference_batch: str | None = None,
         eps: float = 1e-8,
-        covbat_cov_thresh:
+        covbat_cov_thresh: float | int = 0.9,
     ) -> None:
         self.method: str = method
         self.parametric: bool = parametric
         self.mean_only: bool = bool(mean_only)
-        self.reference_batch:
+        self.reference_batch: str | None = reference_batch
         self.eps: float = float(eps)
-        self.covbat_cov_thresh:
+        self.covbat_cov_thresh: float | int = covbat_cov_thresh
 
         self._batch_levels: pd.Index
         self._grand_mean: pd.Series
         self._pooled_var: pd.Series
         self._gamma_star: FloatArray
         self._delta_star: FloatArray
-        self._n_per_batch:
-        self._reference_batch_idx:
+        self._n_per_batch: dict[str, int]
+        self._reference_batch_idx: int | None
         self._beta_hat_nonbatch: FloatArray
         self._n_batch: int
         self._p_design: int
@@ -97,26 +621,15 @@ class ComBatModel:
             raise TypeError("covbat_cov_thresh must be float or int.")
 
     @staticmethod
-    def _as_series(
-        arr: ArrayLike,
-        index: pd.Index,
-        name: str
-    ) -> pd.Series:
+    def _as_series(arr: ArrayLike, index: pd.Index, name: str) -> pd.Series:
         """Convert array-like to categorical Series with validation."""
-        if isinstance(arr, pd.Series)
-            ser = arr.copy()
-        else:
-            ser = pd.Series(arr, index=index, name=name)
+        ser = arr.copy() if isinstance(arr, pd.Series) else pd.Series(arr, index=index, name=name)
         if not ser.index.equals(index):
             raise ValueError(f"`{name}` index mismatch with `X`.")
         return ser.astype("category")
 
     @staticmethod
-    def _to_df(
-        arr: Optional[ArrayLike],
-        index: pd.Index,
-        name: str
-    ) -> Optional[pd.DataFrame]:
+    def _to_df(arr: ArrayLike | None, index: pd.Index, name: str) -> pd.DataFrame | None:
         """Convert array-like to DataFrame."""
         if arr is None:
             return None
@@ -131,11 +644,11 @@ class ComBatModel:
     def fit(
         self,
         X: ArrayLike,
-        y:
+        y: ArrayLike | None = None,
         *,
         batch: ArrayLike,
-        discrete_covariates:
-        continuous_covariates:
+        discrete_covariates: ArrayLike | None = None,
+        continuous_covariates: ArrayLike | None = None,
     ) -> ComBatModel:
         """Fit the ComBat model."""
         method = self.method.lower()
@@ -157,9 +670,7 @@
 
         if method == "johnson":
             if disc is not None or cont is not None:
-                warnings.warn(
-                    "Covariates are ignored when using method='johnson'."
-                )
+                warnings.warn("Covariates are ignored when using method='johnson'.", stacklevel=2)
             self._fit_johnson(X, batch)
         elif method == "fortin":
             self._fit_fortin(X, batch, disc, cont)
@@ -167,11 +678,7 @@
             self._fit_chen(X, batch, disc, cont)
         return self
 
-    def _fit_johnson(
-        self,
-        X: pd.DataFrame,
-        batch: pd.Series
-    ) -> None:
+    def _fit_johnson(self, X: pd.DataFrame, batch: pd.Series) -> None:
         """Johnson et al. (2007) ComBat."""
         self._batch_levels = batch.cat.categories
         pooled_var = X.var(axis=0, ddof=1) + self.eps
@@ -179,10 +686,10 @@
 
         Xs = (X - grand_mean) / np.sqrt(pooled_var)
 
-        n_per_batch:
+        n_per_batch: dict[str, int] = {}
         gamma_hat: list[npt.NDArray[np.float64]] = []
         delta_hat: list[npt.NDArray[np.float64]] = []
-
+
         for lvl in self._batch_levels:
             idx = batch == lvl
             n_b = int(idx.sum())
@@ -227,8 +734,8 @@
         self,
         X: pd.DataFrame,
         batch: pd.Series,
-        disc:
-        cont:
+        disc: pd.DataFrame | None,
+        cont: pd.DataFrame | None,
     ) -> None:
         """Fortin et al. (2018) neuroComBat."""
         self._batch_levels = batch.cat.categories
@@ -246,11 +753,7 @@
 
         parts: list[pd.DataFrame] = [batch_dummies]
         if disc is not None:
-            parts.append(
-                pd.get_dummies(
-                    disc.astype("category"), drop_first=True
-                ).astype(float)
-            )
+            parts.append(pd.get_dummies(disc.astype("category"), drop_first=True).astype(float))
 
         if cont is not None:
             parts.append(cont.astype(float))
@@ -265,7 +768,7 @@
         self._beta_hat_nonbatch = beta_hat[n_batch:]
 
         n_per_batch = batch.value_counts().sort_index().astype(int).values
-        self._n_per_batch = dict(zip(self._batch_levels, n_per_batch))
+        self._n_per_batch = dict(zip(self._batch_levels, n_per_batch, strict=True))
 
         if self.reference_batch is not None:
             ref_idx = list(self._batch_levels).index(self.reference_batch)
@@ -283,30 +786,25 @@
         else:
             resid = X_np - design @ beta_hat
             denom = n_samples
-        var_pooled = (resid
+        var_pooled = (resid**2).sum(axis=0) / denom + self.eps
         self._pooled_var = pd.Series(var_pooled, index=X.columns)
 
         stand_mean = grand_mean + design[:, n_batch:] @ self._beta_hat_nonbatch
         Xs = (X_np - stand_mean) / np.sqrt(var_pooled)
 
-        gamma_hat = np.vstack(
-            [Xs[batch == lvl].mean(axis=0) for lvl in self._batch_levels]
-        )
+        gamma_hat = np.vstack([Xs[batch == lvl].mean(axis=0) for lvl in self._batch_levels])
         delta_hat = np.vstack(
-            [Xs[batch == lvl].var(axis=0, ddof=1) + self.eps
-             for lvl in self._batch_levels]
+            [Xs[batch == lvl].var(axis=0, ddof=1) + self.eps for lvl in self._batch_levels]
         )
 
         if self.mean_only:
             gamma_star = self._shrink_gamma(
-                gamma_hat, delta_hat, n_per_batch,
-                parametric = self.parametric
+                gamma_hat, delta_hat, n_per_batch, parametric=self.parametric
             )
             delta_star = np.ones_like(delta_hat)
         else:
             gamma_star, delta_star = self._shrink_gamma_delta(
-                gamma_hat, delta_hat, n_per_batch,
-                parametric = self.parametric
+                gamma_hat, delta_hat, n_per_batch, parametric=self.parametric
             )
 
         if ref_idx is not None:
@@ -317,15 +815,15 @@
 
         self._gamma_star = gamma_star
         self._delta_star = delta_star
-        self._n_batch
+        self._n_batch = n_batch
         self._p_design = p_design
-
+
     def _fit_chen(
         self,
         X: pd.DataFrame,
         batch: pd.Series,
-        disc:
-        cont:
+        disc: pd.DataFrame | None,
+        cont: pd.DataFrame | None,
     ) -> None:
         """Chen et al. (2022) CovBat."""
         self._fit_fortin(X, batch, disc, cont)
@@ -344,7 +842,7 @@
         self._covbat_n_pc = n_pc
 
         scores = pca.transform(X_centered)[:, :n_pc]
-        scores_df = pd.DataFrame(scores, index=X.index, columns=[f"PC{i+1}" for i in range(n_pc)])
+        scores_df = pd.DataFrame(scores, index=X.index, columns=[f"PC{i + 1}" for i in range(n_pc)])
         self._batch_levels_pc = self._batch_levels
         n_per_batch = self._n_per_batch
 
@@ -383,12 +881,12 @@
         self,
         gamma_hat: FloatArray,
         delta_hat: FloatArray,
-        n_per_batch:
+        n_per_batch: dict[str, int] | FloatArray,
         *,
         parametric: bool,
         max_iter: int = 100,
         tol: float = 1e-4,
-    ) ->
+    ) -> tuple[FloatArray, FloatArray]:
         """Empirical Bayes shrinkage estimation."""
         if parametric:
             gamma_bar = gamma_hat.mean(axis=0)
@@ -396,10 +894,14 @@
             a_prior = (delta_hat.mean(axis=0) ** 2) / delta_hat.var(axis=0, ddof=1) + 2
             b_prior = delta_hat.mean(axis=0) * (a_prior - 1)
 
-            B,
+            B, _p = gamma_hat.shape
             gamma_star = np.empty_like(gamma_hat)
             delta_star = np.empty_like(delta_hat)
-            n_vec =
+            n_vec = (
+                np.array(list(n_per_batch.values()))
+                if isinstance(n_per_batch, dict)
+                else n_per_batch
+            )
 
             for i in range(B):
                 n_i = n_vec[i]
@@ -413,8 +915,12 @@
             return gamma_star, delta_star
 
         else:
-            B,
-            n_vec =
+            B, _p = gamma_hat.shape
+            n_vec = (
+                np.array(list(n_per_batch.values()))
+                if isinstance(n_per_batch, dict)
+                else n_per_batch
+            )
             gamma_bar = gamma_hat.mean(axis=0)
             t2 = gamma_hat.var(axis=0, ddof=1)
 
@@ -423,27 +929,22 @@
                 g_bar: FloatArray,
                 n: float,
                 d_star: FloatArray,
-                t2_: FloatArray
+                t2_: FloatArray,
             ) -> FloatArray:
                 return (t2_ * n * g_hat + d_star * g_bar) / (t2_ * n + d_star)
 
-            def postvar(
-                sum2: FloatArray,
-                n: float,
-                a: FloatArray,
-                b: FloatArray
-            ) -> FloatArray:
+            def postvar(sum2: FloatArray, n: float, a: FloatArray, b: FloatArray) -> FloatArray:
                 return (0.5 * sum2 + b) / (n / 2.0 + a - 1.0)
 
             def aprior(delta: FloatArray) -> FloatArray:
                 m, s2 = delta.mean(), delta.var()
                 s2 = max(s2, self.eps)
-                return (2 * s2 + m
+                return (2 * s2 + m**2) / s2
 
             def bprior(delta: FloatArray) -> FloatArray:
                 m, s2 = delta.mean(), delta.var()
                 s2 = max(s2, self.eps)
-                return (m * s2 + m
+                return (m * s2 + m**3) / s2
 
             gamma_star = np.empty_like(gamma_hat)
             delta_star = np.empty_like(delta_hat)
@@ -462,7 +963,8 @@
                     sum2 = (n_i - 1) * d_hat_i + n_i * (g_hat_i - g_new) ** 2
                     d_new = postvar(sum2, n_i, a_i, b_i)
                     if np.max(np.abs(g_new - g_prev) / (np.abs(g_prev) + self.eps)) < tol and (
-                        self.mean_only
+                        self.mean_only
+                        or np.max(np.abs(d_new - d_prev) / (np.abs(d_prev) + self.eps)) < tol
                     ):
                         break
                 gamma_star[i] = g_new
@@ -473,12 +975,14 @@
         self,
         gamma_hat: FloatArray,
         delta_hat: FloatArray,
-        n_per_batch:
+        n_per_batch: dict[str, int] | FloatArray,
         *,
         parametric: bool,
     ) -> FloatArray:
-        """Convenience wrapper that returns only
-        gamma, _ = self._shrink_gamma_delta(
+        """Convenience wrapper that returns only gamma* (for *mean-only* mode)."""
+        gamma, _ = self._shrink_gamma_delta(
+            gamma_hat, delta_hat, n_per_batch, parametric=parametric
+        )
         return gamma
 
     def transform(
@@ -486,12 +990,14 @@
         X: ArrayLike,
         *,
         batch: ArrayLike,
-        discrete_covariates:
-        continuous_covariates:
+        discrete_covariates: ArrayLike | None = None,
+        continuous_covariates: ArrayLike | None = None,
     ) -> pd.DataFrame:
         """Transform the data using fitted ComBat parameters."""
         if not hasattr(self, "_gamma_star"):
-            raise ValueError(
+            raise ValueError(
+                "This ComBatModel instance is not fitted yet. Call 'fit' before 'transform'."
+            )
         if not isinstance(X, pd.DataFrame):
             X = pd.DataFrame(X)
         idx = X.index
@@ -512,11 +1018,7 @@
         else:
             raise ValueError(f"Unknown method: {method}.")
 
-    def _transform_johnson(
-        self,
-        X: pd.DataFrame,
-        batch: pd.Series
-    ) -> pd.DataFrame:
+    def _transform_johnson(self, X: pd.DataFrame, batch: pd.Series) -> pd.DataFrame:
         """Johnson transform implementation."""
         pooled = self._pooled_var
         grand = self._grand_mean
@@ -534,10 +1036,7 @@
 
             g = self._gamma_star[i]
             d = self._delta_star[i]
-            if self.mean_only
-                Xb = Xs.loc[idx] - g
-            else:
-                Xb = (Xs.loc[idx] - g) / np.sqrt(d)
+            Xb = Xs.loc[idx] - g if self.mean_only else (Xs.loc[idx] - g) / np.sqrt(d)
             X_adj.loc[idx] = (Xb * np.sqrt(pooled) + grand).values
         return X_adj
 
@@ -545,8 +1044,8 @@
         self,
         X: pd.DataFrame,
         batch: pd.Series,
-        disc:
-        cont:
+        disc: pd.DataFrame | None,
+        cont: pd.DataFrame | None,
     ) -> pd.DataFrame:
         """Fortin transform implementation."""
         batch_dummies = pd.get_dummies(batch, drop_first=False).astype(float)[self._batch_levels]
@@ -555,21 +1054,14 @@
 
         parts = [batch_dummies]
         if disc is not None:
-            parts.append(
-                pd.get_dummies(
-                    disc.astype("category"), drop_first=True
-                ).astype(float)
-            )
+            parts.append(pd.get_dummies(disc.astype("category"), drop_first=True).astype(float))
         if cont is not None:
             parts.append(cont.astype(float))
 
         design = pd.concat(parts, axis=1).values
 
         X_np = X.values
-        stand_mu =
-            self._grand_mean.values +
-            design[:, self._n_batch:] @ self._beta_hat_nonbatch
-        )
+        stand_mu = self._grand_mean.values + design[:, self._n_batch :] @ self._beta_hat_nonbatch
         Xs = (X_np - stand_mu) / np.sqrt(self._pooled_var.values)
 
         for i, lvl in enumerate(self._batch_levels):
@@ -587,18 +1079,15 @@
             else:
                 Xs[idx] = (Xs[idx] - g) / np.sqrt(d)
 
-        X_adj = (
-            Xs * np.sqrt(self._pooled_var.values) +
-            stand_mu
-        )
+        X_adj = Xs * np.sqrt(self._pooled_var.values) + stand_mu
         return pd.DataFrame(X_adj, index=X.index, columns=X.columns, dtype=float)
-
+
     def _transform_chen(
         self,
         X: pd.DataFrame,
         batch: pd.Series,
-        disc:
-        cont:
+        disc: pd.DataFrame | None,
+        cont: pd.DataFrame | None,
     ) -> pd.DataFrame:
         """Chen transform implementation."""
         X_meanvar_adj = self._transform_fortin(X, batch, disc, cont)
@@ -635,14 +1124,15 @@ class ComBat(BaseEstimator, TransformerMixin):
         self,
         batch: ArrayLike,
         *,
-        discrete_covariates:
-        continuous_covariates:
+        discrete_covariates: ArrayLike | None = None,
+        continuous_covariates: ArrayLike | None = None,
         method: str = "johnson",
         parametric: bool = True,
         mean_only: bool = False,
-        reference_batch:
+        reference_batch: str | None = None,
        eps: float = 1e-8,
-        covbat_cov_thresh:
+        covbat_cov_thresh: float | int = 0.9,
+        compute_metrics: bool = False,
     ) -> None:
         self.batch = batch
         self.discrete_covariates = discrete_covariates
@@ -653,6 +1143,7 @@ class ComBat(BaseEstimator, TransformerMixin):
         self.reference_batch = reference_batch
         self.eps = eps
         self.covbat_cov_thresh = covbat_cov_thresh
+        self.compute_metrics = compute_metrics
         self._model = ComBatModel(
             method=method,
             parametric=parametric,
@@ -662,11 +1153,7 @@ class ComBat(BaseEstimator, TransformerMixin):
             covbat_cov_thresh=covbat_cov_thresh,
         )
 
-    def fit(
-        self,
-        X: ArrayLike,
-        y: Optional[ArrayLike] = None
-    ) -> "ComBat":
+    def fit(self, X: ArrayLike, y: ArrayLike | None = None) -> ComBat:
         """Fit the ComBat model."""
         idx = X.index if isinstance(X, pd.DataFrame) else pd.RangeIndex(len(X))
         batch_vec = self._subset(self.batch, idx)
@@ -695,10 +1182,7 @@ class ComBat(BaseEstimator, TransformerMixin):
         )
 
     @staticmethod
-    def _subset(
-        obj: Optional[ArrayLike],
-        idx: pd.Index
-    ) -> Optional[Union[pd.DataFrame, pd.Series]]:
+    def _subset(obj: ArrayLike | None, idx: pd.Index) -> pd.DataFrame | pd.Series | None:
         """Subset array-like object by index."""
         if obj is None:
             return None
@@ -710,20 +1194,221 @@ class ComBat(BaseEstimator, TransformerMixin):
         else:
             return pd.DataFrame(obj, index=idx)
 
+    @property
+    def metrics_(self) -> dict[str, Any] | None:
+        """Return cached metrics from last fit_transform with compute_metrics=True.
+
+        Returns
+        -------
+        dict or None
+            Cached metrics dictionary, or None if no metrics have been computed.
+        """
+        return getattr(self, "_metrics_cache", None)
+
def compute_batch_metrics(
|
|
1209
|
+
self,
|
|
1210
|
+
X: ArrayLike,
|
|
1211
|
+
batch: ArrayLike | None = None,
|
|
1212
|
+
*,
|
|
1213
|
+
pca_components: int | None = None,
|
|
1214
|
+
k_neighbors: list[int] | None = None,
|
|
1215
|
+
kbet_k0: int | None = None,
|
|
1216
|
+
lisi_perplexity: int = 30,
|
|
1217
|
+
n_jobs: int = 1,
|
|
1218
|
+
) -> dict[str, Any]:
|
|
1219
|
+
"""
|
|
1220
|
+
Compute batch effect metrics before and after ComBat correction.
|
|
1221
|
+
|
|
1222
|
+
Parameters
|
|
1223
|
+
----------
|
|
1224
|
+
X : array-like of shape (n_samples, n_features)
|
|
1225
|
+
Input data to evaluate.
|
|
1226
|
+
batch : array-like of shape (n_samples,), optional
|
|
1227
|
+
Batch labels. If None, uses the batch stored at construction.
|
|
1228
|
+
pca_components : int, optional
|
|
1229
|
+
Number of PCA components for dimensionality reduction before
|
|
1230
|
+
computing metrics. If None (default), metrics are computed in
|
|
1231
|
+
the original feature space. Must be less than min(n_samples, n_features).
|
|
1232
|
+
k_neighbors : list of int, default=[5, 10, 50]
|
|
1233
|
+
Values of k for k-NN preservation metric.
|
|
1234
|
+
kbet_k0 : int, optional
|
|
1235
|
+
Neighborhood size for kBET. Default is 10% of samples.
|
|
1236
|
+
lisi_perplexity : int, default=30
|
|
1237
|
+
Perplexity for LISI computation.
|
|
1238
|
+
n_jobs : int, default=1
|
|
1239
|
+
Number of parallel jobs for neighbor computations.
|
|
1240
|
+
|
|
1241
|
+
Returns
|
|
1242
|
+
-------
|
|
1243
|
+
dict
|
|
1244
|
+
Dictionary with three main keys:
|
|
1245
|
+
|
|
1246
|
+
- ``batch_effect``: Silhouette, Davies-Bouldin, kBET, LISI, variance ratio
|
|
1247
|
+
(each with 'before' and 'after' values)
|
|
1248
|
+
- ``preservation``: k-NN preservation fractions, distance correlation
|
|
1249
|
+
- ``alignment``: Centroid distance, Levene statistic (each with
|
|
1250
|
+
'before' and 'after' values)
|
|
1251
|
+
|
|
1252
|
+
Raises
|
|
1253
|
+
------
|
|
1254
|
+
ValueError
|
|
1255
|
+
If the model is not fitted or if pca_components is invalid.
|
|
1256
|
+
"""
|
|
1257
|
+
if not hasattr(self._model, "_gamma_star"):
|
|
1258
|
+
raise ValueError(
|
|
1259
|
+
"This ComBat instance is not fitted yet. Call 'fit' before 'compute_batch_metrics'."
|
|
1260
|
+
)
|
|
1261
|
+
|
|
1262
|
+
if not isinstance(X, pd.DataFrame):
|
|
1263
|
+
X = pd.DataFrame(X)
|
|
1264
|
+
|
|
1265
|
+
idx = X.index
|
|
1266
|
+
|
|
1267
|
+
if batch is None:
|
|
1268
|
+
batch_vec = self._subset(self.batch, idx)
|
|
1269
|
+
else:
|
|
1270
|
+
if isinstance(batch, (pd.Series, pd.DataFrame)):
|
|
1271
|
+
batch_vec = batch.loc[idx] if hasattr(batch, "loc") else batch
|
|
1272
|
+
elif isinstance(batch, np.ndarray):
|
|
1273
|
+
batch_vec = pd.Series(batch, index=idx)
|
|
1274
|
+
else:
|
|
1275
|
+
batch_vec = pd.Series(batch, index=idx)
|
|
1276
|
+
|
|
1277
|
+
batch_labels = np.array(batch_vec)
|
|
1278
|
+
|
|
1279
|
+
X_before = X.values
|
|
1280
|
+
X_after = self.transform(X).values
|
|
1281
|
+
|
|
1282
|
+
n_samples, n_features = X_before.shape
|
|
1283
|
+
if kbet_k0 is None:
|
|
1284
|
+
kbet_k0 = max(10, int(0.10 * n_samples))
|
|
1285
|
+
if k_neighbors is None:
|
|
1286
|
+
k_neighbors = [5, 10, 50]
|
|
1287
|
+
|
|
1288
|
+
# Validate and apply PCA if requested
|
|
1289
|
+
if pca_components is not None:
|
|
1290
|
+
max_components = min(n_samples, n_features)
|
|
1291
|
+
if pca_components >= max_components:
|
|
1292
|
+
raise ValueError(
|
|
1293
|
+
f"pca_components={pca_components} must be less than "
|
|
1294
|
+
f"min(n_samples, n_features)={max_components}."
|
|
1295
|
+
)
|
|
1296
|
+
X_before_pca, X_after_pca, _ = _compute_pca_embedding(X_before, X_after, pca_components)
|
|
1297
|
+
else:
|
|
1298
|
+
X_before_pca = X_before
|
|
1299
|
+
X_after_pca = X_after
|
|
1300
|
+
|
|
1301
|
+
silhouette_before = _silhouette_batch(X_before_pca, batch_labels)
|
|
1302
|
+
silhouette_after = _silhouette_batch(X_after_pca, batch_labels)
|
|
1303
|
+
|
|
1304
|
+
db_before = _davies_bouldin_batch(X_before_pca, batch_labels)
|
|
1305
|
+
db_after = _davies_bouldin_batch(X_after_pca, batch_labels)
|
|
1306
|
+
|
|
1307
|
+
kbet_before, _ = _kbet_score(X_before_pca, batch_labels, kbet_k0)
|
|
1308
|
+
kbet_after, _ = _kbet_score(X_after_pca, batch_labels, kbet_k0)
|
|
1309
|
+
|
|
1310
|
+
lisi_before = _lisi_score(X_before_pca, batch_labels, lisi_perplexity)
|
|
1311
|
+
lisi_after = _lisi_score(X_after_pca, batch_labels, lisi_perplexity)
|
|
1312
|
+
|
|
1313
|
+
var_ratio_before = _variance_ratio(X_before_pca, batch_labels)
|
|
1314
|
+
var_ratio_after = _variance_ratio(X_after_pca, batch_labels)
|
|
1315
|
+
|
|
1316
|
+
knn_results = _knn_preservation(X_before_pca, X_after_pca, k_neighbors, n_jobs)
|
|
1317
|
+
dist_corr = _pairwise_distance_correlation(X_before_pca, X_after_pca)
|
|
1318
|
+
|
|
1319
|
+
centroid_before = _mean_centroid_distance(X_before_pca, batch_labels)
|
|
1320
|
+
centroid_after = _mean_centroid_distance(X_after_pca, batch_labels)
|
|
1321
|
+
|
|
1322
|
+
levene_before = _levene_median_statistic(X_before, batch_labels)
|
|
1323
|
+
levene_after = _levene_median_statistic(X_after, batch_labels)
|
|
1324
|
+
|
|
1325
|
+
n_batches = len(np.unique(batch_labels))
|
|
1326
|
+
|
|
1327
|
+
metrics = {
|
|
1328
|
+
"batch_effect": {
|
|
1329
|
+
"silhouette": {
|
|
1330
|
+
"before": silhouette_before,
|
|
1331
|
+
"after": silhouette_after,
|
|
1332
|
+
},
|
|
1333
|
+
"davies_bouldin": {
|
|
1334
|
+
"before": db_before,
|
|
1335
|
+
"after": db_after,
|
|
1336
|
+
},
|
|
1337
|
+
"kbet": {
|
|
1338
|
+
"before": kbet_before,
|
|
1339
|
+
"after": kbet_after,
|
|
1340
|
+
},
|
|
1341
|
+
"lisi": {
|
|
1342
|
+
"before": lisi_before,
|
|
1343
|
+
"after": lisi_after,
|
|
1344
|
+
"max_value": n_batches,
|
|
1345
|
+
},
|
|
1346
|
+
"variance_ratio": {
|
|
1347
|
+
"before": var_ratio_before,
|
|
1348
|
+
"after": var_ratio_after,
|
|
1349
|
+
},
|
|
1350
|
+
},
|
|
1351
|
+
"preservation": {
|
|
1352
|
+
"knn": knn_results,
|
|
1353
|
+
"distance_correlation": dist_corr,
|
|
1354
|
+
},
|
|
1355
|
+
"alignment": {
|
|
1356
|
+
"centroid_distance": {
|
|
1357
|
+
"before": centroid_before,
|
|
1358
|
+
"after": centroid_after,
|
|
1359
|
+
},
|
|
1360
|
+
"levene_statistic": {
|
|
1361
|
+
"before": levene_before,
|
|
1362
|
+
"after": levene_after,
|
|
1363
|
+
},
|
|
1364
|
+
},
|
|
1365
|
+
}
|
|
1366
|
+
|
|
1367
|
+
return metrics
|
|
1368
|
+
|
|
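End to end, the new evaluation API reads as in the sketch below (synthetic data; the top-level `ComBat` import assumes the package re-exports it from `combatlearn/__init__.py`, and the batch names and sizes are illustrative):

import numpy as np
import pandas as pd
from combatlearn import ComBat  # assuming __init__.py re-exports ComBat

rng = np.random.default_rng(0)
batch = pd.Series(rng.choice(["scanner_a", "scanner_b"], size=120))
X = pd.DataFrame(rng.normal(size=(120, 30)))
X.loc[batch == "scanner_b"] += 2.0  # inject an additive batch shift

cb = ComBat(batch=batch, method="fortin").fit(X)
metrics = cb.compute_batch_metrics(X, pca_components=10, k_neighbors=[5, 10])
print(metrics["batch_effect"]["kbet"])   # {'before': ..., 'after': ...}
print(metrics["preservation"]["knn"])    # {5: ..., 10: ...}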
+    def fit_transform(self, X: ArrayLike, y: ArrayLike | None = None) -> pd.DataFrame:
+        """
+        Fit and transform the data, optionally computing metrics.
+
+        If ``compute_metrics=True`` was set at construction, batch effect
+        metrics are computed and cached in the ``metrics_`` property.
+
+        Parameters
+        ----------
+        X : array-like of shape (n_samples, n_features)
+            Input data to fit and transform.
+        y : None
+            Ignored. Present for API compatibility.
+
+        Returns
+        -------
+        X_transformed : pd.DataFrame
+            Batch-corrected data.
+        """
+        self.fit(X, y)
+        X_transformed = self.transform(X)
+
+        if self.compute_metrics:
+            self._metrics_cache = self.compute_batch_metrics(X)
+
+        return X_transformed
+
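With `compute_metrics=True`, the same numbers fall out of `fit_transform` as a side effect and are cached on the `metrics_` property; continuing the sketch above:

cb = ComBat(batch=batch, method="fortin", compute_metrics=True)
X_corrected = cb.fit_transform(X)
print(cb.metrics_["batch_effect"]["silhouette"])  # {'before': ..., 'after': ...}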
     def plot_transformation(
-
-
-
-
-
-
-
-
-
-
-
-
-
+        self,
+        X: ArrayLike,
+        *,
+        reduction_method: Literal["pca", "tsne", "umap"] = "pca",
+        n_components: Literal[2, 3] = 2,
+        plot_type: Literal["static", "interactive"] = "static",
+        figsize: tuple[int, int] = (12, 5),
+        alpha: float = 0.7,
+        point_size: int = 50,
+        cmap: str = "Set1",
+        title: str | None = None,
+        show_legend: bool = True,
+        return_embeddings: bool = False,
+        **reduction_kwargs,
+    ) -> Any | tuple[Any, dict[str, FloatArray]]:
         """
         Visualize the ComBat transformation effect using dimensionality reduction.
 
@@ -748,28 +1433,32 @@ class ComBat(BaseEstimator, TransformerMixin):
 
         return_embeddings : bool, default=False
             If `True`, return embeddings along with the plot.
-
+
         **reduction_kwargs : dict
             Additional parameters for reduction methods.
-
+
         Returns
         -------
         fig : matplotlib.figure.Figure or plotly.graph_objects.Figure
             The figure object containing the plots.
-
+
         embeddings : dict, optional
             If `return_embeddings=True`, dictionary with:
             - `'original'`: embedding of original data
             - `'transformed'`: embedding of ComBat-transformed data
         """
         if not hasattr(self._model, "_gamma_star"):
-            raise ValueError(
+            raise ValueError(
+                "This ComBat instance is not fitted yet. Call 'fit' before 'plot_transformation'."
+            )
 
         if n_components not in [2, 3]:
             raise ValueError(f"n_components must be 2 or 3, got {n_components}")
-        if reduction_method not in [
-            raise ValueError(
-
+        if reduction_method not in ["pca", "tsne", "umap"]:
+            raise ValueError(
+                f"reduction_method must be 'pca', 'tsne', or 'umap', got '{reduction_method}'"
+            )
+        if plot_type not in ["static", "interactive"]:
             raise ValueError(f"plot_type must be 'static' or 'interactive', got '{plot_type}'")
 
         if not isinstance(X, pd.DataFrame):
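A usage sketch for the reworked plotting API, continuing the example above (whether `fig.show()` opens a window depends on the plotly/matplotlib backend):

# Interactive plotly figure on a UMAP embedding of the same data.
fig = cb.plot_transformation(
    X,
    reduction_method="umap",
    n_components=2,
    plot_type="interactive",
)
fig.show()

# Static matplotlib figure, keeping the 2-D embeddings for further use.
fig2, emb = cb.plot_transformation(X, return_embeddings=True)
print(emb["original"].shape, emb["transformed"].shape)  # (120, 2) (120, 2)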
@@ -785,16 +1474,16 @@
         X_np = X.values
         X_trans_np = X_transformed.values
 
-        if reduction_method ==
+        if reduction_method == "pca":
             reducer_orig = PCA(n_components=n_components, **reduction_kwargs)
             reducer_trans = PCA(n_components=n_components, **reduction_kwargs)
-        elif reduction_method ==
-            tsne_params = {
+        elif reduction_method == "tsne":
+            tsne_params = {"perplexity": 30, "max_iter": 1000, "random_state": 42}
             tsne_params.update(reduction_kwargs)
             reducer_orig = TSNE(n_components=n_components, **tsne_params)
             reducer_trans = TSNE(n_components=n_components, **tsne_params)
         else:
-            umap_params = {
+            umap_params = {"random_state": 42}
             umap_params.update(reduction_kwargs)
             reducer_orig = umap.UMAP(n_components=n_components, **umap_params)
             reducer_trans = umap.UMAP(n_components=n_components, **umap_params)
@@ -802,40 +1491,52 @@
         X_embedded_orig = reducer_orig.fit_transform(X_np)
         X_embedded_trans = reducer_trans.fit_transform(X_trans_np)
 
-        if plot_type ==
+        if plot_type == "static":
             fig = self._create_static_plot(
-                X_embedded_orig,
-
-
+                X_embedded_orig,
+                X_embedded_trans,
+                batch_vec,
+                reduction_method,
+                n_components,
+                figsize,
+                alpha,
+                point_size,
+                cmap,
+                title,
+                show_legend,
             )
         else:
             fig = self._create_interactive_plot(
-                X_embedded_orig,
-
+                X_embedded_orig,
+                X_embedded_trans,
+                batch_vec,
+                reduction_method,
+                n_components,
+                cmap,
+                title,
+                show_legend,
             )
 
         if return_embeddings:
-            embeddings = {
-                'original': X_embedded_orig,
-                'transformed': X_embedded_trans
-            }
+            embeddings = {"original": X_embedded_orig, "transformed": X_embedded_trans}
             return fig, embeddings
         else:
             return fig
 
     def _create_static_plot(
-
-
-
-
-
-
-
-
-
-
-
-
+        self,
+        X_orig: FloatArray,
+        X_trans: FloatArray,
+        batch_labels: pd.Series,
+        method: str,
+        n_components: int,
+        figsize: tuple[int, int],
+        alpha: float,
+        point_size: int,
+        cmap: str,
+        title: str | None,
+        show_legend: bool,
+    ) -> Any:
         """Create static plots using matplotlib."""
 
         fig = plt.figure(figsize=figsize)
@@ -846,119 +1547,130 @@ class ComBat(BaseEstimator, TransformerMixin):
         if n_batches <= 10:
             colors = matplotlib.colormaps.get_cmap(cmap)(np.linspace(0, 1, n_batches))
         else:
-            colors = matplotlib.colormaps.get_cmap(
+            colors = matplotlib.colormaps.get_cmap("tab20")(np.linspace(0, 1, n_batches))
 
         if n_components == 2:
             ax1 = plt.subplot(1, 2, 1)
             ax2 = plt.subplot(1, 2, 2)
         else:
-            ax1 = fig.add_subplot(121, projection=
-            ax2 = fig.add_subplot(122, projection=
+            ax1 = fig.add_subplot(121, projection="3d")
+            ax2 = fig.add_subplot(122, projection="3d")
 
         for i, batch in enumerate(unique_batches):
             mask = batch_labels == batch
             if n_components == 2:
                 ax1.scatter(
-                    X_orig[mask, 0],
+                    X_orig[mask, 0],
+                    X_orig[mask, 1],
                     c=[colors[i]],
                     s=point_size,
-                    alpha=alpha,
-                    label=f
-                    edgecolors=
-                    linewidth=0.5
+                    alpha=alpha,
+                    label=f"Batch {batch}",
+                    edgecolors="black",
+                    linewidth=0.5,
                 )
             else:
                 ax1.scatter(
-                    X_orig[mask, 0],
+                    X_orig[mask, 0],
+                    X_orig[mask, 1],
+                    X_orig[mask, 2],
                     c=[colors[i]],
                     s=point_size,
-                    alpha=alpha,
-                    label=f
-                    edgecolors=
-                    linewidth=0.5
+                    alpha=alpha,
+                    label=f"Batch {batch}",
+                    edgecolors="black",
+                    linewidth=0.5,
                 )
 
-        ax1.set_title(f
-        ax1.set_xlabel(f
-        ax1.set_ylabel(f
+        ax1.set_title(f"Before ComBat correction\n({method.upper()})")
+        ax1.set_xlabel(f"{method.upper()}1")
+        ax1.set_ylabel(f"{method.upper()}2")
         if n_components == 3:
-            ax1.set_zlabel(f
+            ax1.set_zlabel(f"{method.upper()}3")
 
         for i, batch in enumerate(unique_batches):
             mask = batch_labels == batch
             if n_components == 2:
                 ax2.scatter(
-                    X_trans[mask, 0],
+                    X_trans[mask, 0],
+                    X_trans[mask, 1],
                     c=[colors[i]],
                     s=point_size,
-                    alpha=alpha,
-                    label=f
-                    edgecolors=
-                    linewidth=0.5
+                    alpha=alpha,
+                    label=f"Batch {batch}",
+                    edgecolors="black",
+                    linewidth=0.5,
                 )
             else:
                 ax2.scatter(
-                    X_trans[mask, 0],
+                    X_trans[mask, 0],
+                    X_trans[mask, 1],
+                    X_trans[mask, 2],
                     c=[colors[i]],
                     s=point_size,
-                    alpha=alpha,
-                    label=f
-                    edgecolors=
-                    linewidth=0.5
+                    alpha=alpha,
+                    label=f"Batch {batch}",
+                    edgecolors="black",
+                    linewidth=0.5,
                 )
 
-        ax2.set_title(f
-        ax2.set_xlabel(f
-        ax2.set_ylabel(f
+        ax2.set_title(f"After ComBat correction\n({method.upper()})")
+        ax2.set_xlabel(f"{method.upper()}1")
+        ax2.set_ylabel(f"{method.upper()}2")
         if n_components == 3:
-            ax2.set_zlabel(f
+            ax2.set_zlabel(f"{method.upper()}3")
 
         if show_legend and n_batches <= 20:
-            ax2.legend(bbox_to_anchor=(1.05, 1), loc=
+            ax2.legend(bbox_to_anchor=(1.05, 1), loc="upper left")
 
         if title is None:
-            title = f
-        fig.suptitle(title, fontsize=14, fontweight=
+            title = f"ComBat correction effect visualized with {method.upper()}"
+        fig.suptitle(title, fontsize=14, fontweight="bold")
 
         plt.tight_layout()
         return fig
 
     def _create_interactive_plot(
-
-
-
-
-
-
-
-
-
+        self,
+        X_orig: FloatArray,
+        X_trans: FloatArray,
+        batch_labels: pd.Series,
+        method: str,
+        n_components: int,
+        cmap: str,
+        title: str | None,
+        show_legend: bool,
+    ) -> Any:
         """Create interactive plots using plotly."""
         if n_components == 2:
             fig = make_subplots(
-                rows=1,
+                rows=1,
+                cols=2,
                 subplot_titles=(
-                    f
-                    f
-                )
+                    f"Before ComBat correction ({method.upper()})",
+                    f"After ComBat correction ({method.upper()})",
+                ),
             )
         else:
             fig = make_subplots(
-                rows=1,
-
+                rows=1,
+                cols=2,
+                specs=[[{"type": "scatter3d"}, {"type": "scatter3d"}]],
                 subplot_titles=(
-                    f
-                    f
-                )
+                    f"Before ComBat correction ({method.upper()})",
+                    f"After ComBat correction ({method.upper()})",
+                ),
             )
 
         unique_batches = batch_labels.drop_duplicates()
 
         n_batches = len(unique_batches)
         cmap_func = matplotlib.colormaps.get_cmap(cmap)
-        color_list = [
+        color_list = [
+            mcolors.to_hex(cmap_func(i / max(n_batches - 1, 1))) for i in range(n_batches)
+        ]
 
-        batch_to_color = dict(zip(unique_batches, color_list))
+        batch_to_color = dict(zip(unique_batches, color_list, strict=True))
 
         for batch in unique_batches:
             mask = batch_labels == batch
@@ -966,72 +1678,86 @@ class ComBat(BaseEstimator, TransformerMixin):
             if n_components == 2:
                 fig.add_trace(
                     go.Scatter(
-                        x=X_orig[mask, 0],
-
-
-
-
-
-
-
-
-
+                        x=X_orig[mask, 0],
+                        y=X_orig[mask, 1],
+                        mode="markers",
+                        name=f"Batch {batch}",
+                        marker={
+                            "size": 8,
+                            "color": batch_to_color[batch],
+                            "line": {"width": 1, "color": "black"},
+                        },
+                        showlegend=False,
+                    ),
+                    row=1,
+                    col=1,
                 )
 
                 fig.add_trace(
                     go.Scatter(
-                        x=X_trans[mask, 0],
-
-
-
-
-
-
-
-
-
+                        x=X_trans[mask, 0],
+                        y=X_trans[mask, 1],
+                        mode="markers",
+                        name=f"Batch {batch}",
+                        marker={
+                            "size": 8,
+                            "color": batch_to_color[batch],
+                            "line": {"width": 1, "color": "black"},
+                        },
+                        showlegend=show_legend,
+                    ),
+                    row=1,
+                    col=2,
                 )
             else:
                 fig.add_trace(
                     go.Scatter3d(
-                        x=X_orig[mask, 0],
-
-
-
-
-
-
-
-
-
+                        x=X_orig[mask, 0],
+                        y=X_orig[mask, 1],
+                        z=X_orig[mask, 2],
+                        mode="markers",
+                        name=f"Batch {batch}",
+                        marker={
+                            "size": 5,
+                            "color": batch_to_color[batch],
+                            "line": {"width": 0.5, "color": "black"},
+                        },
+                        showlegend=False,
+                    ),
+                    row=1,
+                    col=1,
                 )
 
                 fig.add_trace(
                     go.Scatter3d(
-                        x=X_trans[mask, 0],
-
-
-
-
-
-
-
-
-
+                        x=X_trans[mask, 0],
+                        y=X_trans[mask, 1],
+                        z=X_trans[mask, 2],
+                        mode="markers",
+                        name=f"Batch {batch}",
+                        marker={
+                            "size": 5,
+                            "color": batch_to_color[batch],
+                            "line": {"width": 0.5, "color": "black"},
+                        },
+                        showlegend=show_legend,
+                    ),
+                    row=1,
+                    col=2,
                )
 
         if title is None:
-            title = f
+            title = f"ComBat correction effect visualized with {method.upper()}"
 
         fig.update_layout(
             title=title,
             title_font_size=16,
             height=600,
             showlegend=show_legend,
-            hovermode=
+            hovermode="closest",
         )
 
-        axis_labels = [f
+        axis_labels = [f"{method.upper()}{i + 1}" for i in range(n_components)]
 
         if n_components == 2:
             fig.update_xaxes(title_text=axis_labels[0])
@@ -1040,7 +1766,7 @@ class ComBat(BaseEstimator, TransformerMixin):
             fig.update_scenes(
                 xaxis_title=axis_labels[0],
                 yaxis_title=axis_labels[1],
-                zaxis_title=axis_labels[2]
+                zaxis_title=axis_labels[2],
             )
 
         return fig