combatlearn 1.1.2__py3-none-any.whl → 1.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- combatlearn/__init__.py +2 -2
- combatlearn/core.py +578 -0
- combatlearn/metrics.py +788 -0
- combatlearn/sklearn_api.py +143 -0
- combatlearn/visualization.py +533 -0
- {combatlearn-1.1.2.dist-info → combatlearn-1.2.0.dist-info}/METADATA +24 -14
- combatlearn-1.2.0.dist-info/RECORD +10 -0
- {combatlearn-1.1.2.dist-info → combatlearn-1.2.0.dist-info}/WHEEL +1 -1
- combatlearn/combat.py +0 -1770
- combatlearn-1.1.2.dist-info/RECORD +0 -7
- {combatlearn-1.1.2.dist-info → combatlearn-1.2.0.dist-info}/licenses/LICENSE +0 -0
- {combatlearn-1.1.2.dist-info → combatlearn-1.2.0.dist-info}/top_level.txt +0 -0
combatlearn/combat.py
DELETED
@@ -1,1770 +0,0 @@
"""ComBat algorithm.
|
|
2
|
-
|
|
3
|
-
`ComBatModel` implements both:
|
|
4
|
-
* Johnson et al. (2007) vanilla ComBat (method="johnson")
|
|
5
|
-
* Fortin et al. (2018) extension with covariates (method="fortin")
|
|
6
|
-
* Chen et al. (2022) CovBat (method="chen")
|
|
7
|
-
|
|
8
|
-
`ComBat` makes the model compatible with scikit-learn by stashing
|
|
9
|
-
the batch (and optional covariates) at construction.
|
|
10
|
-
"""
|
|
11
|
-
|
|
12
|
-
from __future__ import annotations
|
|
13
|
-
|
|
14
|
-
import warnings
|
|
15
|
-
from typing import Any, Literal
|
|
16
|
-
|
|
17
|
-
import matplotlib
|
|
18
|
-
import matplotlib.colors as mcolors
|
|
19
|
-
import matplotlib.pyplot as plt
|
|
20
|
-
import numpy as np
|
|
21
|
-
import numpy.linalg as la
|
|
22
|
-
import numpy.typing as npt
|
|
23
|
-
import pandas as pd
|
|
24
|
-
import plotly.graph_objects as go
|
|
25
|
-
import umap
|
|
26
|
-
from plotly.subplots import make_subplots
|
|
27
|
-
from scipy.spatial.distance import pdist
|
|
28
|
-
from scipy.stats import chi2, levene, spearmanr
|
|
29
|
-
from sklearn.base import BaseEstimator, TransformerMixin
|
|
30
|
-
from sklearn.decomposition import PCA
|
|
31
|
-
from sklearn.manifold import TSNE
|
|
32
|
-
from sklearn.metrics import davies_bouldin_score, silhouette_score
|
|
33
|
-
from sklearn.neighbors import NearestNeighbors
|
|
34
|
-
|
|
35
|
-
ArrayLike = pd.DataFrame | pd.Series | npt.NDArray[Any]
|
|
36
|
-
FloatArray = npt.NDArray[np.float64]
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
def _compute_pca_embedding(
    X_before: np.ndarray,
    X_after: np.ndarray,
    n_components: int,
) -> tuple[np.ndarray, np.ndarray, PCA]:
    """
    Compute PCA embeddings for both datasets.

    Fits PCA on X_before and applies it to both datasets.

    Parameters
    ----------
    X_before : np.ndarray
        Original data before correction.
    X_after : np.ndarray
        Corrected data.
    n_components : int
        Number of PCA components.

    Returns
    -------
    X_before_pca : np.ndarray
        PCA-transformed original data.
    X_after_pca : np.ndarray
        PCA-transformed corrected data.
    pca : PCA
        Fitted PCA model.
    """
    n_components = min(n_components, X_before.shape[1], X_before.shape[0] - 1)
    pca = PCA(n_components=n_components, random_state=42)
    X_before_pca = pca.fit_transform(X_before)
    X_after_pca = pca.transform(X_after)
    return X_before_pca, X_after_pca, pca

def _silhouette_batch(X: np.ndarray, batch_labels: np.ndarray) -> float:
    """
    Compute silhouette coefficient using batch as cluster labels.

    Lower values after correction indicate better batch mixing.
    Range: [-1, 1]; values near 0 indicate batch mixing, values near 1
    indicate batch separation.

    Parameters
    ----------
    X : np.ndarray
        Data matrix.
    batch_labels : np.ndarray
        Batch labels for each sample.

    Returns
    -------
    float
        Silhouette coefficient.
    """
    unique_batches = np.unique(batch_labels)
    if len(unique_batches) < 2:
        return 0.0
    try:
        return silhouette_score(X, batch_labels, metric="euclidean")
    except Exception:
        return 0.0

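Editor's sketch (not part of combat.py): how `_silhouette_batch` reads on toy data; `rng`, `X_toy`, and `labels` are hypothetical names for this illustration.

import numpy as np

rng = np.random.default_rng(0)
# Two batches with shifted means: a strong, detectable batch effect.
X_toy = np.vstack([rng.normal(0.0, 1.0, (50, 5)), rng.normal(4.0, 1.0, (50, 5))])
labels = np.repeat(["site_a", "site_b"], 50)

print(_silhouette_batch(X_toy, labels))                   # large positive: batches separate
print(_silhouette_batch(X_toy, rng.permutation(labels)))  # near 0: labels look mixed
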
def _davies_bouldin_batch(X: np.ndarray, batch_labels: np.ndarray) -> float:
    """
    Compute Davies-Bouldin index using batch labels.

    Higher values indicate better batch mixing.
    Range: [0, inf); compact, well-separated batches give values near 0.

    Parameters
    ----------
    X : np.ndarray
        Data matrix.
    batch_labels : np.ndarray
        Batch labels for each sample.

    Returns
    -------
    float
        Davies-Bouldin index.
    """
    unique_batches = np.unique(batch_labels)
    if len(unique_batches) < 2:
        return 0.0
    try:
        return davies_bouldin_score(X, batch_labels)
    except Exception:
        return 0.0

def _kbet_score(
    X: np.ndarray,
    batch_labels: np.ndarray,
    k0: int,
    alpha: float = 0.05,
) -> tuple[float, float]:
    """
    Compute kBET (k-nearest neighbor Batch Effect Test) acceptance rate.

    Tests if local batch proportions match global batch proportions.
    Higher acceptance rate = better batch mixing.

    Reference: Büttner et al. (2019) Nature Methods

    Parameters
    ----------
    X : np.ndarray
        Data matrix.
    batch_labels : np.ndarray
        Batch labels for each sample.
    k0 : int
        Neighborhood size.
    alpha : float
        Significance level for chi-squared test.

    Returns
    -------
    acceptance_rate : float
        Fraction of samples where H0 (uniform mixing) is accepted.
    mean_stat : float
        Mean chi-squared statistic across samples.
    """
    n_samples = X.shape[0]
    unique_batches, batch_counts = np.unique(batch_labels, return_counts=True)
    n_batches = len(unique_batches)

    if n_batches < 2:
        return 1.0, 0.0

    global_freq = batch_counts / n_samples
    k0 = min(k0, n_samples - 1)

    nn = NearestNeighbors(n_neighbors=k0 + 1, algorithm="auto")
    nn.fit(X)
    _, indices = nn.kneighbors(X)

    chi2_stats = []
    p_values = []
    batch_to_idx = {b: i for i, b in enumerate(unique_batches)}

    for i in range(n_samples):
        neighbors = indices[i, 1 : k0 + 1]
        neighbor_batches = batch_labels[neighbors]

        observed = np.zeros(n_batches)
        for nb in neighbor_batches:
            observed[batch_to_idx[nb]] += 1

        expected = global_freq * k0

        mask = expected > 0
        if mask.sum() < 2:
            continue

        stat = np.sum((observed[mask] - expected[mask]) ** 2 / expected[mask])
        df = max(1, mask.sum() - 1)
        p_val = 1 - chi2.cdf(stat, df)

        chi2_stats.append(stat)
        p_values.append(p_val)

    if len(p_values) == 0:
        return 1.0, 0.0

    acceptance_rate = np.mean(np.array(p_values) > alpha)
    mean_stat = np.mean(chi2_stats)

    return acceptance_rate, mean_stat

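Editor's sketch (not part of combat.py): the per-sample chi-squared test inside `_kbet_score`, written out for a single neighborhood; all numbers are hypothetical.

import numpy as np
from scipy.stats import chi2

k0 = 20
global_freq = np.array([0.5, 0.5])   # two equally sized batches overall
observed = np.array([15.0, 5.0])     # batch counts among one sample's k0 neighbors
expected = global_freq * k0          # [10, 10] under uniform mixing

stat = np.sum((observed - expected) ** 2 / expected)   # 5.0
p_val = 1 - chi2.cdf(stat, df=len(observed) - 1)       # ~0.025, rejected at alpha=0.05
print(stat, p_val)
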
def _find_sigma(distances: np.ndarray, target_perplexity: float, tol: float = 1e-5) -> float:
    """
    Binary search for sigma to achieve target perplexity.

    Used in LISI computation.

    Parameters
    ----------
    distances : np.ndarray
        Distances to neighbors.
    target_perplexity : float
        Target perplexity value.
    tol : float
        Tolerance for convergence.

    Returns
    -------
    float
        Sigma value.
    """
    target_H = np.log2(target_perplexity + 1e-10)

    sigma_min, sigma_max = 1e-10, 1e10
    sigma = 1.0

    for _ in range(50):
        P = np.exp(-(distances**2) / (2 * sigma**2 + 1e-10))
        P_sum = P.sum()
        if P_sum < 1e-10:
            sigma = (sigma + sigma_max) / 2
            continue
        P = P / P_sum
        P = np.clip(P, 1e-10, 1.0)
        H = -np.sum(P * np.log2(P))

        if abs(H - target_H) < tol:
            break
        elif target_H > H:
            sigma_min = sigma
        else:
            sigma_max = sigma
        sigma = (sigma_min + sigma_max) / 2

    return sigma

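Editor's sketch (not part of combat.py): a quick check that the bisection in `_find_sigma` hits H = log2(perplexity); `distances` below is a hypothetical neighbor-distance vector.

import numpy as np

distances = np.linspace(0.1, 2.0, 90)  # distances to 90 neighbors
sigma = _find_sigma(distances, target_perplexity=30.0)

P = np.exp(-(distances**2) / (2 * sigma**2))
P = P / P.sum()
perplexity = 2 ** (-np.sum(P * np.log2(P)))  # should land near 30
print(sigma, perplexity)
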
def _lisi_score(
    X: np.ndarray,
    batch_labels: np.ndarray,
    perplexity: int = 30,
) -> float:
    """
    Compute mean Local Inverse Simpson's Index (LISI).

    Range: [1, n_batches], where n_batches = perfect mixing.
    Higher = better batch mixing.

    Reference: Korsunsky et al. (2019) Nature Methods (Harmony paper)

    Parameters
    ----------
    X : np.ndarray
        Data matrix.
    batch_labels : np.ndarray
        Batch labels for each sample.
    perplexity : int
        Perplexity for Gaussian kernel.

    Returns
    -------
    float
        Mean LISI score.
    """
    n_samples = X.shape[0]
    unique_batches = np.unique(batch_labels)
    n_batches = len(unique_batches)
    batch_to_idx = {b: i for i, b in enumerate(unique_batches)}

    if n_batches < 2:
        return 1.0

    k = min(3 * perplexity, n_samples - 1)

    nn = NearestNeighbors(n_neighbors=k + 1, algorithm="auto")
    nn.fit(X)
    distances, indices = nn.kneighbors(X)

    distances = distances[:, 1:]
    indices = indices[:, 1:]

    lisi_values = []

    for i in range(n_samples):
        sigma = _find_sigma(distances[i], perplexity)

        P = np.exp(-(distances[i] ** 2) / (2 * sigma**2 + 1e-10))
        P_sum = P.sum()
        if P_sum < 1e-10:
            lisi_values.append(1.0)
            continue
        P = P / P_sum

        neighbor_batches = batch_labels[indices[i]]
        batch_probs = np.zeros(n_batches)
        for j, nb in enumerate(neighbor_batches):
            batch_probs[batch_to_idx[nb]] += P[j]

        simpson = np.sum(batch_probs**2)
        lisi = n_batches if simpson < 1e-10 else 1.0 / simpson
        lisi_values.append(lisi)

    return np.mean(lisi_values)

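Editor's sketch (not part of combat.py): the inverse Simpson step at the heart of `_lisi_score`, applied to hypothetical per-batch kernel weights for one neighborhood.

import numpy as np

mixed = np.array([0.5, 0.5])    # both batches equally represented locally
skewed = np.array([0.9, 0.1])   # neighborhood dominated by one batch

print(1.0 / np.sum(mixed**2))   # 2.0: perfect mixing for two batches
print(1.0 / np.sum(skewed**2))  # ~1.22: poor mixing
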
def _variance_ratio(X: np.ndarray, batch_labels: np.ndarray) -> float:
    """
    Compute between-batch to within-batch variance ratio.

    Similar to the F-statistic in one-way ANOVA.
    Lower ratio after correction = better batch effect removal.

    Parameters
    ----------
    X : np.ndarray
        Data matrix.
    batch_labels : np.ndarray
        Batch labels for each sample.

    Returns
    -------
    float
        Variance ratio (between/within).
    """
    unique_batches = np.unique(batch_labels)
    n_batches = len(unique_batches)
    n_samples = X.shape[0]

    if n_batches < 2:
        return 0.0

    grand_mean = np.mean(X, axis=0)

    between_var = 0.0
    within_var = 0.0

    for batch in unique_batches:
        mask = batch_labels == batch
        n_b = np.sum(mask)
        X_batch = X[mask]
        batch_mean = np.mean(X_batch, axis=0)

        between_var += n_b * np.sum((batch_mean - grand_mean) ** 2)
        within_var += np.sum((X_batch - batch_mean) ** 2)

    between_var /= n_batches - 1
    within_var /= n_samples - n_batches

    if within_var < 1e-10:
        return 0.0

    return between_var / within_var

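Editor's sketch (not part of combat.py): `_variance_ratio` on a mean-shifted toy set drops to ~0 once each batch is centered; `X_a`, `X_b`, and `labels` are hypothetical.

import numpy as np

rng = np.random.default_rng(1)
X_a = rng.normal(0.0, 1.0, (40, 3))
X_b = rng.normal(3.0, 1.0, (40, 3))
labels = np.repeat(["a", "b"], 40)

print(_variance_ratio(np.vstack([X_a, X_b]), labels))  # large: batch means differ
X_centered = np.vstack([X_a - X_a.mean(0), X_b - X_b.mean(0)])
print(_variance_ratio(X_centered, labels))             # 0.0: batch means aligned
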
def _knn_preservation(
    X_before: np.ndarray,
    X_after: np.ndarray,
    k_values: list[int],
    n_jobs: int = 1,
) -> dict[int, float]:
    """
    Compute fraction of k-nearest neighbors preserved after correction.

    Range: [0, 1], where 1 = perfect preservation.
    Higher = better biological structure preservation.

    Parameters
    ----------
    X_before : np.ndarray
        Original data.
    X_after : np.ndarray
        Corrected data.
    k_values : list of int
        Values of k for k-NN.
    n_jobs : int
        Number of parallel jobs.

    Returns
    -------
    dict
        Mapping from k to preservation fraction.
    """
    results = {}
    max_k = max(k_values)
    max_k = min(max_k, X_before.shape[0] - 1)

    nn_before = NearestNeighbors(n_neighbors=max_k + 1, algorithm="auto", n_jobs=n_jobs)
    nn_before.fit(X_before)
    _, indices_before = nn_before.kneighbors(X_before)

    nn_after = NearestNeighbors(n_neighbors=max_k + 1, algorithm="auto", n_jobs=n_jobs)
    nn_after.fit(X_after)
    _, indices_after = nn_after.kneighbors(X_after)

    for k in k_values:
        if k > max_k:
            results[k] = 0.0
            continue

        overlaps = []
        for i in range(X_before.shape[0]):
            neighbors_before = set(indices_before[i, 1 : k + 1])
            neighbors_after = set(indices_after[i, 1 : k + 1])
            overlap = len(neighbors_before & neighbors_after) / k
            overlaps.append(overlap)

        results[k] = np.mean(overlaps)

    return results

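Editor's sketch (not part of combat.py): `_knn_preservation` is invariant to maps that keep neighborhoods, such as a global shift; `X_toy` is hypothetical.

import numpy as np

rng = np.random.default_rng(2)
X_toy = rng.normal(size=(100, 5))

print(_knn_preservation(X_toy, X_toy + 10.0, k_values=[5, 10]))  # {5: 1.0, 10: 1.0}
print(_knn_preservation(X_toy, rng.normal(size=(100, 5)), k_values=[5, 10]))  # near chance
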
def _pairwise_distance_correlation(
    X_before: np.ndarray,
    X_after: np.ndarray,
    subsample: int = 1000,
    random_state: int = 42,
) -> float:
    """
    Compute Spearman correlation of pairwise distances.

    Range: [-1, 1], where 1 = perfect rank preservation.
    Higher = better relative relationship preservation.

    Parameters
    ----------
    X_before : np.ndarray
        Original data.
    X_after : np.ndarray
        Corrected data.
    subsample : int
        Maximum samples to use (for efficiency).
    random_state : int
        Random seed for subsampling.

    Returns
    -------
    float
        Spearman correlation coefficient.
    """
    n_samples = X_before.shape[0]

    if n_samples > subsample:
        rng = np.random.default_rng(random_state)
        idx = rng.choice(n_samples, subsample, replace=False)
        X_before = X_before[idx]
        X_after = X_after[idx]

    dist_before = pdist(X_before, metric="euclidean")
    dist_after = pdist(X_after, metric="euclidean")

    if len(dist_before) == 0:
        return 1.0

    corr, _ = spearmanr(dist_before, dist_after)

    if np.isnan(corr):
        return 1.0

    return corr

def _mean_centroid_distance(X: np.ndarray, batch_labels: np.ndarray) -> float:
    """
    Compute mean pairwise Euclidean distance between batch centroids.

    Lower after correction = better batch alignment.

    Parameters
    ----------
    X : np.ndarray
        Data matrix.
    batch_labels : np.ndarray
        Batch labels for each sample.

    Returns
    -------
    float
        Mean pairwise distance between centroids.
    """
    unique_batches = np.unique(batch_labels)
    n_batches = len(unique_batches)

    if n_batches < 2:
        return 0.0

    centroids = []
    for batch in unique_batches:
        mask = batch_labels == batch
        centroid = np.mean(X[mask], axis=0)
        centroids.append(centroid)

    centroids = np.array(centroids)
    distances = pdist(centroids, metric="euclidean")

    return np.mean(distances)

def _levene_median_statistic(X: np.ndarray, batch_labels: np.ndarray) -> float:
    """
    Compute median Levene test statistic across features.

    Lower statistic = more homogeneous variances across batches.

    Parameters
    ----------
    X : np.ndarray
        Data matrix.
    batch_labels : np.ndarray
        Batch labels for each sample.

    Returns
    -------
    float
        Median Levene test statistic.
    """
    unique_batches = np.unique(batch_labels)
    if len(unique_batches) < 2:
        return 0.0

    levene_stats = []
    for j in range(X.shape[1]):
        groups = [X[batch_labels == b, j] for b in unique_batches]
        groups = [g for g in groups if len(g) > 0]
        if len(groups) < 2:
            continue
        try:
            stat, _ = levene(*groups, center="median")
            if not np.isnan(stat):
                levene_stats.append(stat)
        except Exception:
            continue

    if len(levene_stats) == 0:
        return 0.0

    return np.median(levene_stats)

class ComBatModel:
    """ComBat algorithm.

    Parameters
    ----------
    method : {'johnson', 'fortin', 'chen'}, default='johnson'
        * 'johnson' - classic ComBat.
        * 'fortin' - covariate-aware ComBat.
        * 'chen' - CovBat, PCA-based ComBat.
    parametric : bool, default=True
        Use the parametric empirical Bayes variant.
    mean_only : bool, default=False
        If True, only the mean is adjusted (`gamma_star`),
        ignoring the variance (`delta_star`).
    reference_batch : str, optional
        If specified, the batch level to use as reference.
    covbat_cov_thresh : float or int, default=0.9
        CovBat: cumulative variance threshold in (0, 1] to retain PCs, or an
        integer >= 1 specifying the number of components directly.
    eps : float, default=1e-8
        Numerical jitter to avoid division by zero.
    """

    def __init__(
        self,
        *,
        method: Literal["johnson", "fortin", "chen"] = "johnson",
        parametric: bool = True,
        mean_only: bool = False,
        reference_batch: str | None = None,
        eps: float = 1e-8,
        covbat_cov_thresh: float | int = 0.9,
    ) -> None:
        self.method: str = method
        self.parametric: bool = parametric
        self.mean_only: bool = bool(mean_only)
        self.reference_batch: str | None = reference_batch
        self.eps: float = float(eps)
        self.covbat_cov_thresh: float | int = covbat_cov_thresh

        self._batch_levels: pd.Index
        self._grand_mean: pd.Series
        self._pooled_var: pd.Series
        self._gamma_star: FloatArray
        self._delta_star: FloatArray
        self._n_per_batch: dict[str, int]
        self._reference_batch_idx: int | None
        self._beta_hat_nonbatch: FloatArray
        self._n_batch: int
        self._p_design: int
        self._covbat_pca: PCA
        self._covbat_n_pc: int
        self._batch_levels_pc: pd.Index
        self._pc_gamma_star: FloatArray
        self._pc_delta_star: FloatArray

        # Validate covbat_cov_thresh
        if isinstance(self.covbat_cov_thresh, float):
            if not (0.0 < self.covbat_cov_thresh <= 1.0):
                raise ValueError("covbat_cov_thresh must be in (0, 1] when float.")
        elif isinstance(self.covbat_cov_thresh, int):
            if self.covbat_cov_thresh < 1:
                raise ValueError("covbat_cov_thresh must be >= 1 when int.")
        else:
            raise TypeError("covbat_cov_thresh must be float or int.")

    @staticmethod
    def _as_series(arr: ArrayLike, index: pd.Index, name: str) -> pd.Series:
        """Convert array-like to categorical Series with validation."""
        ser = arr.copy() if isinstance(arr, pd.Series) else pd.Series(arr, index=index, name=name)
        if not ser.index.equals(index):
            raise ValueError(f"`{name}` index mismatch with `X`.")
        return ser.astype("category")

    @staticmethod
    def _to_df(arr: ArrayLike | None, index: pd.Index, name: str) -> pd.DataFrame | None:
        """Convert array-like to DataFrame."""
        if arr is None:
            return None
        if isinstance(arr, pd.Series):
            arr = arr.to_frame()
        if not isinstance(arr, pd.DataFrame):
            arr = pd.DataFrame(arr, index=index)
        if not arr.index.equals(index):
            raise ValueError(f"`{name}` index mismatch with `X`.")
        return arr

    def fit(
        self,
        X: ArrayLike,
        y: ArrayLike | None = None,
        *,
        batch: ArrayLike,
        discrete_covariates: ArrayLike | None = None,
        continuous_covariates: ArrayLike | None = None,
    ) -> ComBatModel:
        """Fit the ComBat model."""
        method = self.method.lower()
        if method not in {"johnson", "fortin", "chen"}:
            raise ValueError("method must be 'johnson', 'fortin', or 'chen'.")
        if not isinstance(X, pd.DataFrame):
            X = pd.DataFrame(X)
        idx = X.index
        batch = self._as_series(batch, idx, "batch")

        disc = self._to_df(discrete_covariates, idx, "discrete_covariates")
        cont = self._to_df(continuous_covariates, idx, "continuous_covariates")

        if self.reference_batch is not None and self.reference_batch not in batch.cat.categories:
            raise ValueError(
                f"reference_batch={self.reference_batch!r} not present in the data "
                f"batches: {list(batch.cat.categories)}."
            )

        if method == "johnson":
            if disc is not None or cont is not None:
                warnings.warn("Covariates are ignored when using method='johnson'.", stacklevel=2)
            self._fit_johnson(X, batch)
        elif method == "fortin":
            self._fit_fortin(X, batch, disc, cont)
        elif method == "chen":
            self._fit_chen(X, batch, disc, cont)
        return self

    def _fit_johnson(self, X: pd.DataFrame, batch: pd.Series) -> None:
        """Johnson et al. (2007) ComBat."""
        self._batch_levels = batch.cat.categories
        pooled_var = X.var(axis=0, ddof=1) + self.eps
        grand_mean = X.mean(axis=0)

        Xs = (X - grand_mean) / np.sqrt(pooled_var)

        n_per_batch: dict[str, int] = {}
        gamma_hat: list[npt.NDArray[np.float64]] = []
        delta_hat: list[npt.NDArray[np.float64]] = []

        for lvl in self._batch_levels:
            idx = batch == lvl
            n_b = int(idx.sum())
            if n_b < 2:
                raise ValueError(f"Batch '{lvl}' has <2 samples.")
            n_per_batch[str(lvl)] = n_b
            xb = Xs.loc[idx]
            gamma_hat.append(xb.mean(axis=0).values)
            delta_hat.append(xb.var(axis=0, ddof=1).values + self.eps)

        gamma_hat_arr = np.vstack(gamma_hat)
        delta_hat_arr = np.vstack(delta_hat)

        if self.mean_only:
            gamma_star = self._shrink_gamma(
                gamma_hat_arr, delta_hat_arr, n_per_batch, parametric=self.parametric
            )
            delta_star = np.ones_like(delta_hat_arr)
        else:
            gamma_star, delta_star = self._shrink_gamma_delta(
                gamma_hat_arr, delta_hat_arr, n_per_batch, parametric=self.parametric
            )

        if self.reference_batch is not None:
            ref_idx = list(self._batch_levels).index(self.reference_batch)
            gamma_ref = gamma_star[ref_idx]
            delta_ref = delta_star[ref_idx]
            gamma_star = gamma_star - gamma_ref
            if not self.mean_only:
                delta_star = delta_star / delta_ref
            self._reference_batch_idx = ref_idx
        else:
            self._reference_batch_idx = None

        self._grand_mean = grand_mean
        self._pooled_var = pooled_var
        self._gamma_star = gamma_star
        self._delta_star = delta_star
        self._n_per_batch = n_per_batch

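Editor's sketch (not part of combat.py): the location-scale map applied downstream of `_fit_johnson`, reduced to one feature and one batch; all values are hypothetical.

import numpy as np

# z-score against pooled statistics, remove the shrunken batch shift gamma*
# and scale delta*, then map back to the original scale.
x, grand_mean, pooled_var = 7.0, 5.0, 4.0
gamma_star, delta_star = 0.8, 1.5

z = (x - grand_mean) / np.sqrt(pooled_var)  # 1.0
x_adj = ((z - gamma_star) / np.sqrt(delta_star)) * np.sqrt(pooled_var) + grand_mean
print(x_adj)  # ~5.33
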
    def _fit_fortin(
        self,
        X: pd.DataFrame,
        batch: pd.Series,
        disc: pd.DataFrame | None,
        cont: pd.DataFrame | None,
    ) -> None:
        """Fortin et al. (2018) neuroComBat."""
        self._batch_levels = batch.cat.categories
        n_batch = len(self._batch_levels)
        n_samples = len(X)

        batch_dummies = pd.get_dummies(batch, drop_first=False).astype(float)
        if self.reference_batch is not None:
            if self.reference_batch not in self._batch_levels:
                raise ValueError(
                    f"reference_batch={self.reference_batch!r} not present in "
                    f"batches: {list(self._batch_levels)}."
                )
            batch_dummies.loc[:, self.reference_batch] = 1.0

        parts: list[pd.DataFrame] = [batch_dummies]
        if disc is not None:
            parts.append(pd.get_dummies(disc.astype("category"), drop_first=True).astype(float))

        if cont is not None:
            parts.append(cont.astype(float))

        design = pd.concat(parts, axis=1).values
        p_design = design.shape[1]

        X_np = X.values
        beta_hat = la.lstsq(design, X_np, rcond=None)[0]

        beta_hat_batch = beta_hat[:n_batch]
        self._beta_hat_nonbatch = beta_hat[n_batch:]

        n_per_batch = batch.value_counts().sort_index().astype(int).values
        self._n_per_batch = dict(zip(self._batch_levels, n_per_batch, strict=True))

        if self.reference_batch is not None:
            ref_idx = list(self._batch_levels).index(self.reference_batch)
            grand_mean = beta_hat_batch[ref_idx]
        else:
            grand_mean = (n_per_batch / n_samples) @ beta_hat_batch
            ref_idx = None

        self._grand_mean = pd.Series(grand_mean, index=X.columns)

        if self.reference_batch is not None:
            ref_mask = (batch == self.reference_batch).values
            resid = X_np[ref_mask] - design[ref_mask] @ beta_hat
            denom = int(ref_mask.sum())
        else:
            resid = X_np - design @ beta_hat
            denom = n_samples
        var_pooled = (resid**2).sum(axis=0) / denom + self.eps
        self._pooled_var = pd.Series(var_pooled, index=X.columns)

        stand_mean = grand_mean + design[:, n_batch:] @ self._beta_hat_nonbatch
        Xs = (X_np - stand_mean) / np.sqrt(var_pooled)

        gamma_hat = np.vstack([Xs[batch == lvl].mean(axis=0) for lvl in self._batch_levels])
        delta_hat = np.vstack(
            [Xs[batch == lvl].var(axis=0, ddof=1) + self.eps for lvl in self._batch_levels]
        )

        if self.mean_only:
            gamma_star = self._shrink_gamma(
                gamma_hat, delta_hat, n_per_batch, parametric=self.parametric
            )
            delta_star = np.ones_like(delta_hat)
        else:
            gamma_star, delta_star = self._shrink_gamma_delta(
                gamma_hat, delta_hat, n_per_batch, parametric=self.parametric
            )

        if ref_idx is not None:
            gamma_star[ref_idx] = 0.0
            if not self.mean_only:
                delta_star[ref_idx] = 1.0
        self._reference_batch_idx = ref_idx

        self._gamma_star = gamma_star
        self._delta_star = delta_star
        self._n_batch = n_batch
        self._p_design = p_design

    def _fit_chen(
        self,
        X: pd.DataFrame,
        batch: pd.Series,
        disc: pd.DataFrame | None,
        cont: pd.DataFrame | None,
    ) -> None:
        """Chen et al. (2022) CovBat."""
        self._fit_fortin(X, batch, disc, cont)
        X_meanvar_adj = self._transform_fortin(X, batch, disc, cont)
        pca = PCA(svd_solver="full", whiten=False).fit(X_meanvar_adj)

        # Determine number of components based on threshold type
        if isinstance(self.covbat_cov_thresh, int):
            n_pc = min(self.covbat_cov_thresh, len(pca.explained_variance_ratio_))
        else:
            cumulative = np.cumsum(pca.explained_variance_ratio_)
            n_pc = int(np.searchsorted(cumulative, self.covbat_cov_thresh) + 1)

        self._covbat_pca = pca
        self._covbat_n_pc = n_pc

        scores = pca.transform(X_meanvar_adj)[:, :n_pc]
        scores_df = pd.DataFrame(scores, index=X.index, columns=[f"PC{i + 1}" for i in range(n_pc)])
        self._batch_levels_pc = self._batch_levels
        n_per_batch = self._n_per_batch

        gamma_hat: list[npt.NDArray[np.float64]] = []
        delta_hat: list[npt.NDArray[np.float64]] = []
        for lvl in self._batch_levels_pc:
            idx = batch == lvl
            xb = scores_df.loc[idx]
            gamma_hat.append(xb.mean(axis=0).values)
            delta_hat.append(xb.var(axis=0, ddof=1).values + self.eps)
        gamma_hat_arr = np.vstack(gamma_hat)
        delta_hat_arr = np.vstack(delta_hat)

        if self.mean_only:
            gamma_star = self._shrink_gamma(
                gamma_hat_arr, delta_hat_arr, n_per_batch, parametric=self.parametric
            )
            delta_star = np.ones_like(delta_hat_arr)
        else:
            gamma_star, delta_star = self._shrink_gamma_delta(
                gamma_hat_arr, delta_hat_arr, n_per_batch, parametric=self.parametric
            )

        if self.reference_batch is not None:
            ref_idx = list(self._batch_levels_pc).index(self.reference_batch)
            gamma_ref = gamma_star[ref_idx]
            delta_ref = delta_star[ref_idx]
            gamma_star = gamma_star - gamma_ref
            if not self.mean_only:
                delta_star = delta_star / delta_ref

        self._pc_gamma_star = gamma_star
        self._pc_delta_star = delta_star

    def _shrink_gamma_delta(
        self,
        gamma_hat: FloatArray,
        delta_hat: FloatArray,
        n_per_batch: dict[str, int] | FloatArray,
        *,
        parametric: bool,
        max_iter: int = 100,
        tol: float = 1e-4,
    ) -> tuple[FloatArray, FloatArray]:
        """Empirical Bayes shrinkage estimation."""
        if parametric:
            gamma_bar = gamma_hat.mean(axis=0)
            t2 = gamma_hat.var(axis=0, ddof=1)
            a_prior = (delta_hat.mean(axis=0) ** 2) / delta_hat.var(axis=0, ddof=1) + 2
            b_prior = delta_hat.mean(axis=0) * (a_prior - 1)

            B, _p = gamma_hat.shape
            gamma_star = np.empty_like(gamma_hat)
            delta_star = np.empty_like(delta_hat)
            n_vec = (
                np.array(list(n_per_batch.values()))
                if isinstance(n_per_batch, dict)
                else n_per_batch
            )

            for i in range(B):
                n_i = n_vec[i]
                g, d = gamma_hat[i], delta_hat[i]
                gamma_post_var = 1.0 / (n_i / d + 1.0 / t2)
                gamma_star[i] = gamma_post_var * (n_i * g / d + gamma_bar / t2)

                a_post = a_prior + n_i / 2.0
                b_post = b_prior + 0.5 * n_i * d
                delta_star[i] = b_post / (a_post - 1)
            return gamma_star, delta_star

        else:
            B, _p = gamma_hat.shape
            n_vec = (
                np.array(list(n_per_batch.values()))
                if isinstance(n_per_batch, dict)
                else n_per_batch
            )
            gamma_bar = gamma_hat.mean(axis=0)
            t2 = gamma_hat.var(axis=0, ddof=1)

            def postmean(
                g_hat: FloatArray,
                g_bar: FloatArray,
                n: float,
                d_star: FloatArray,
                t2_: FloatArray,
            ) -> FloatArray:
                return (t2_ * n * g_hat + d_star * g_bar) / (t2_ * n + d_star)

            def postvar(sum2: FloatArray, n: float, a: FloatArray, b: FloatArray) -> FloatArray:
                return (0.5 * sum2 + b) / (n / 2.0 + a - 1.0)

            def aprior(delta: FloatArray) -> FloatArray:
                m, s2 = delta.mean(), delta.var()
                s2 = max(s2, self.eps)
                return (2 * s2 + m**2) / s2

            def bprior(delta: FloatArray) -> FloatArray:
                m, s2 = delta.mean(), delta.var()
                s2 = max(s2, self.eps)
                return (m * s2 + m**3) / s2

            gamma_star = np.empty_like(gamma_hat)
            delta_star = np.empty_like(delta_hat)

            for i in range(B):
                n_i = n_vec[i]
                g_hat_i = gamma_hat[i]
                d_hat_i = delta_hat[i]
                a_i = aprior(d_hat_i)
                b_i = bprior(d_hat_i)

                g_new, d_new = g_hat_i.copy(), d_hat_i.copy()
                for _ in range(max_iter):
                    g_prev, d_prev = g_new, d_new
                    g_new = postmean(g_hat_i, gamma_bar, n_i, d_prev, t2)
                    sum2 = (n_i - 1) * d_hat_i + n_i * (g_hat_i - g_new) ** 2
                    d_new = postvar(sum2, n_i, a_i, b_i)
                    if np.max(np.abs(g_new - g_prev) / (np.abs(g_prev) + self.eps)) < tol and (
                        self.mean_only
                        or np.max(np.abs(d_new - d_prev) / (np.abs(d_prev) + self.eps)) < tol
                    ):
                        break
                gamma_star[i] = g_new
                delta_star[i] = 1.0 if self.mean_only else d_new
            return gamma_star, delta_star

    def _shrink_gamma(
        self,
        gamma_hat: FloatArray,
        delta_hat: FloatArray,
        n_per_batch: dict[str, int] | FloatArray,
        *,
        parametric: bool,
    ) -> FloatArray:
        """Convenience wrapper that returns only gamma* (for *mean-only* mode)."""
        gamma, _ = self._shrink_gamma_delta(
            gamma_hat, delta_hat, n_per_batch, parametric=parametric
        )
        return gamma

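Editor's sketch (not part of combat.py): the parametric posterior mean from `_shrink_gamma_delta` in scalar form, a precision-weighted blend of the batch estimate and the across-batch prior; all numbers are hypothetical.

n, g, d = 10.0, 1.2, 2.0   # batch size, batch mean estimate, batch variance
gamma_bar, t2 = 0.3, 0.5   # prior mean and variance across batches

gamma_post_var = 1.0 / (n / d + 1.0 / t2)
gamma_star = gamma_post_var * (n * g / d + gamma_bar / t2)
print(gamma_star)  # ~0.94: shrunk from 1.2 toward 0.3
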
    def transform(
        self,
        X: ArrayLike,
        *,
        batch: ArrayLike,
        discrete_covariates: ArrayLike | None = None,
        continuous_covariates: ArrayLike | None = None,
    ) -> pd.DataFrame:
        """Transform the data using fitted ComBat parameters."""
        if not hasattr(self, "_gamma_star"):
            raise ValueError(
                "This ComBatModel instance is not fitted yet. Call 'fit' before 'transform'."
            )
        if not isinstance(X, pd.DataFrame):
            X = pd.DataFrame(X)
        idx = X.index
        batch = self._as_series(batch, idx, "batch")
        unseen = set(batch.cat.categories) - set(self._batch_levels)
        if unseen:
            raise ValueError(f"Unseen batch levels during transform: {unseen}.")
        disc = self._to_df(discrete_covariates, idx, "discrete_covariates")
        cont = self._to_df(continuous_covariates, idx, "continuous_covariates")

        method = self.method.lower()
        if method == "johnson":
            return self._transform_johnson(X, batch)
        elif method == "fortin":
            return self._transform_fortin(X, batch, disc, cont)
        elif method == "chen":
            return self._transform_chen(X, batch, disc, cont)
        else:
            raise ValueError(f"Unknown method: {method}.")

    def _transform_johnson(self, X: pd.DataFrame, batch: pd.Series) -> pd.DataFrame:
        """Johnson transform implementation."""
        pooled = self._pooled_var
        grand = self._grand_mean

        Xs = (X - grand) / np.sqrt(pooled)
        X_adj = pd.DataFrame(index=X.index, columns=X.columns, dtype=float)

        for i, lvl in enumerate(self._batch_levels):
            idx = batch == lvl
            if not idx.any():
                continue
            if self.reference_batch is not None and lvl == self.reference_batch:
                X_adj.loc[idx] = X.loc[idx].values
                continue

            g = self._gamma_star[i]
            d = self._delta_star[i]
            Xb = Xs.loc[idx] - g if self.mean_only else (Xs.loc[idx] - g) / np.sqrt(d)
            X_adj.loc[idx] = (Xb * np.sqrt(pooled) + grand).values
        return X_adj

    def _transform_fortin(
        self,
        X: pd.DataFrame,
        batch: pd.Series,
        disc: pd.DataFrame | None,
        cont: pd.DataFrame | None,
    ) -> pd.DataFrame:
        """Fortin transform implementation."""
        batch_dummies = pd.get_dummies(batch, drop_first=False).astype(float)[self._batch_levels]
        if self.reference_batch is not None:
            batch_dummies.loc[:, self.reference_batch] = 1.0

        parts = [batch_dummies]
        if disc is not None:
            parts.append(pd.get_dummies(disc.astype("category"), drop_first=True).astype(float))
        if cont is not None:
            parts.append(cont.astype(float))

        design = pd.concat(parts, axis=1).values

        X_np = X.values
        stand_mu = self._grand_mean.values + design[:, self._n_batch :] @ self._beta_hat_nonbatch
        Xs = (X_np - stand_mu) / np.sqrt(self._pooled_var.values)

        for i, lvl in enumerate(self._batch_levels):
            idx = batch == lvl
            if not idx.any():
                continue
            if self.reference_batch is not None and lvl == self.reference_batch:
                # leave reference samples unchanged
                continue

            g = self._gamma_star[i]
            d = self._delta_star[i]
            if self.mean_only:
                Xs[idx] = Xs[idx] - g
            else:
                Xs[idx] = (Xs[idx] - g) / np.sqrt(d)

        X_adj = Xs * np.sqrt(self._pooled_var.values) + stand_mu
        return pd.DataFrame(X_adj, index=X.index, columns=X.columns, dtype=float)

    def _transform_chen(
        self,
        X: pd.DataFrame,
        batch: pd.Series,
        disc: pd.DataFrame | None,
        cont: pd.DataFrame | None,
    ) -> pd.DataFrame:
        """Chen transform implementation."""
        X_meanvar_adj = self._transform_fortin(X, batch, disc, cont)
        scores = self._covbat_pca.transform(X_meanvar_adj)
        n_pc = self._covbat_n_pc
        scores_adj = scores.copy()

        for i, lvl in enumerate(self._batch_levels_pc):
            idx = batch == lvl
            if not idx.any():
                continue
            if self.reference_batch is not None and lvl == self.reference_batch:
                continue
            g = self._pc_gamma_star[i]
            d = self._pc_delta_star[i]
            if self.mean_only:
                scores_adj[idx, :n_pc] = scores_adj[idx, :n_pc] - g
            else:
                scores_adj[idx, :n_pc] = (scores_adj[idx, :n_pc] - g) / np.sqrt(d)

        X_recon = self._covbat_pca.inverse_transform(scores_adj)
        return pd.DataFrame(X_recon, index=X.index, columns=X.columns)

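Editor's sketch (not part of combat.py): direct use of `ComBatModel` outside a pipeline; `df` and `sites` are hypothetical.

import numpy as np
import pandas as pd

rng = np.random.default_rng(3)
df = pd.DataFrame(rng.normal(size=(60, 4)), columns=list("abcd"))
sites = pd.Series(np.repeat(["s1", "s2", "s3"], 20), index=df.index)
df = df.add(sites.map({"s1": 0.0, "s2": 1.0, "s3": -1.0}), axis=0)  # inject a site shift

model = ComBatModel(method="johnson").fit(df, batch=sites)
df_adj = model.transform(df, batch=sites)
print(df_adj.groupby(sites).mean().round(2))  # per-site means pulled together
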
class ComBat(BaseEstimator, TransformerMixin):
    """Pipeline-friendly wrapper around `ComBatModel`.

    Stores the batch (and optional covariates) passed at construction and
    uses them appropriately for separate `fit` and `transform` calls.
    """

    def __init__(
        self,
        batch: ArrayLike,
        *,
        discrete_covariates: ArrayLike | None = None,
        continuous_covariates: ArrayLike | None = None,
        method: str = "johnson",
        parametric: bool = True,
        mean_only: bool = False,
        reference_batch: str | None = None,
        eps: float = 1e-8,
        covbat_cov_thresh: float | int = 0.9,
        compute_metrics: bool = False,
    ) -> None:
        self.batch = batch
        self.discrete_covariates = discrete_covariates
        self.continuous_covariates = continuous_covariates
        self.method = method
        self.parametric = parametric
        self.mean_only = mean_only
        self.reference_batch = reference_batch
        self.eps = eps
        self.covbat_cov_thresh = covbat_cov_thresh
        self.compute_metrics = compute_metrics
        self._model = ComBatModel(
            method=method,
            parametric=parametric,
            mean_only=mean_only,
            reference_batch=reference_batch,
            eps=eps,
            covbat_cov_thresh=covbat_cov_thresh,
        )

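Editor's sketch (not part of combat.py): the constructor-stashed batch is what lets the wrapper sit in a scikit-learn pipeline under cross-validation; `X_df`, `site`, and `y` are hypothetical, and scores are chance-level on random labels.

import numpy as np
import pandas as pd
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score
from sklearn.pipeline import make_pipeline

rng = np.random.default_rng(4)
X_df = pd.DataFrame(rng.normal(size=(90, 6)))
site = pd.Series(np.repeat(["s1", "s2", "s3"], 30), index=X_df.index)
y = pd.Series(rng.integers(0, 2, 90), index=X_df.index)

# ComBat is refit on each training fold; batch rows are selected by index.
pipe = make_pipeline(ComBat(batch=site), LogisticRegression())
print(cross_val_score(pipe, X_df, y, cv=3).mean())
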
    def fit(self, X: ArrayLike, y: ArrayLike | None = None) -> ComBat:
        """Fit the ComBat model."""
        idx = X.index if isinstance(X, pd.DataFrame) else pd.RangeIndex(len(X))
        batch_vec = self._subset(self.batch, idx)
        disc = self._subset(self.discrete_covariates, idx)
        cont = self._subset(self.continuous_covariates, idx)
        self._model.fit(
            X,
            batch=batch_vec,
            discrete_covariates=disc,
            continuous_covariates=cont,
        )
        self._fitted_batch = batch_vec
        return self

    def transform(self, X: ArrayLike) -> pd.DataFrame:
        """Transform the data using fitted ComBat parameters."""
        idx = X.index if isinstance(X, pd.DataFrame) else pd.RangeIndex(len(X))
        batch_vec = self._subset(self.batch, idx)
        disc = self._subset(self.discrete_covariates, idx)
        cont = self._subset(self.continuous_covariates, idx)
        return self._model.transform(
            X,
            batch=batch_vec,
            discrete_covariates=disc,
            continuous_covariates=cont,
        )

    @staticmethod
    def _subset(obj: ArrayLike | None, idx: pd.Index) -> pd.DataFrame | pd.Series | None:
        """Subset array-like object by index."""
        if obj is None:
            return None
        if isinstance(obj, (pd.Series, pd.DataFrame)):
            return obj.loc[idx]
        if isinstance(obj, np.ndarray) and obj.ndim == 1:
            return pd.Series(obj, index=idx)
        return pd.DataFrame(obj, index=idx)

    @property
    def metrics_(self) -> dict[str, Any] | None:
        """Return cached metrics from last fit_transform with compute_metrics=True.

        Returns
        -------
        dict or None
            Cached metrics dictionary, or None if no metrics have been computed.
        """
        return getattr(self, "_metrics_cache", None)

    def compute_batch_metrics(
        self,
        X: ArrayLike,
        batch: ArrayLike | None = None,
        *,
        pca_components: int | None = None,
        k_neighbors: list[int] | None = None,
        kbet_k0: int | None = None,
        lisi_perplexity: int = 30,
        n_jobs: int = 1,
    ) -> dict[str, Any]:
        """
        Compute batch effect metrics before and after ComBat correction.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Input data to evaluate.
        batch : array-like of shape (n_samples,), optional
            Batch labels. If None, uses the batch stored at construction.
        pca_components : int, optional
            Number of PCA components for dimensionality reduction before
            computing metrics. If None (default), metrics are computed in
            the original feature space. Must be less than min(n_samples, n_features).
        k_neighbors : list of int, default=[5, 10, 50]
            Values of k for k-NN preservation metric.
        kbet_k0 : int, optional
            Neighborhood size for kBET. Default is 10% of samples (at least 10).
        lisi_perplexity : int, default=30
            Perplexity for LISI computation.
        n_jobs : int, default=1
            Number of parallel jobs for neighbor computations.

        Returns
        -------
        dict
            Dictionary with three main keys:

            - ``batch_effect``: Silhouette, Davies-Bouldin, kBET, LISI, variance ratio
              (each with 'before' and 'after' values)
            - ``preservation``: k-NN preservation fractions, distance correlation
            - ``alignment``: Centroid distance, Levene statistic (each with
              'before' and 'after' values)

        Raises
        ------
        ValueError
            If the model is not fitted or if pca_components is invalid.
        """
        if not hasattr(self._model, "_gamma_star"):
            raise ValueError(
                "This ComBat instance is not fitted yet. Call 'fit' before 'compute_batch_metrics'."
            )

        if not isinstance(X, pd.DataFrame):
            X = pd.DataFrame(X)

        idx = X.index

        if batch is None:
            batch_vec = self._subset(self.batch, idx)
        elif isinstance(batch, (pd.Series, pd.DataFrame)):
            batch_vec = batch.loc[idx]
        else:
            batch_vec = pd.Series(batch, index=idx)

        batch_labels = np.array(batch_vec)

        X_before = X.values
        X_after = self.transform(X).values

        n_samples, n_features = X_before.shape
        if kbet_k0 is None:
            kbet_k0 = max(10, int(0.10 * n_samples))
        if k_neighbors is None:
            k_neighbors = [5, 10, 50]

        # Validate and apply PCA if requested
        if pca_components is not None:
            max_components = min(n_samples, n_features)
            if pca_components >= max_components:
                raise ValueError(
                    f"pca_components={pca_components} must be less than "
                    f"min(n_samples, n_features)={max_components}."
                )
            X_before_pca, X_after_pca, _ = _compute_pca_embedding(X_before, X_after, pca_components)
        else:
            X_before_pca = X_before
            X_after_pca = X_after

        silhouette_before = _silhouette_batch(X_before_pca, batch_labels)
        silhouette_after = _silhouette_batch(X_after_pca, batch_labels)

        db_before = _davies_bouldin_batch(X_before_pca, batch_labels)
        db_after = _davies_bouldin_batch(X_after_pca, batch_labels)

        kbet_before, _ = _kbet_score(X_before_pca, batch_labels, kbet_k0)
        kbet_after, _ = _kbet_score(X_after_pca, batch_labels, kbet_k0)

        lisi_before = _lisi_score(X_before_pca, batch_labels, lisi_perplexity)
        lisi_after = _lisi_score(X_after_pca, batch_labels, lisi_perplexity)

        var_ratio_before = _variance_ratio(X_before_pca, batch_labels)
        var_ratio_after = _variance_ratio(X_after_pca, batch_labels)

        knn_results = _knn_preservation(X_before_pca, X_after_pca, k_neighbors, n_jobs)
        dist_corr = _pairwise_distance_correlation(X_before_pca, X_after_pca)

        centroid_before = _mean_centroid_distance(X_before_pca, batch_labels)
        centroid_after = _mean_centroid_distance(X_after_pca, batch_labels)

        levene_before = _levene_median_statistic(X_before, batch_labels)
        levene_after = _levene_median_statistic(X_after, batch_labels)

        n_batches = len(np.unique(batch_labels))

        metrics = {
            "batch_effect": {
                "silhouette": {
                    "before": silhouette_before,
                    "after": silhouette_after,
                },
                "davies_bouldin": {
                    "before": db_before,
                    "after": db_after,
                },
                "kbet": {
                    "before": kbet_before,
                    "after": kbet_after,
                },
                "lisi": {
                    "before": lisi_before,
                    "after": lisi_after,
                    "max_value": n_batches,
                },
                "variance_ratio": {
                    "before": var_ratio_before,
                    "after": var_ratio_after,
                },
            },
            "preservation": {
                "knn": knn_results,
                "distance_correlation": dist_corr,
            },
            "alignment": {
                "centroid_distance": {
                    "before": centroid_before,
                    "after": centroid_after,
                },
                "levene_statistic": {
                    "before": levene_before,
                    "after": levene_after,
                },
            },
        }

        return metrics

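Editor's sketch (not part of combat.py): reading the nested dict returned by `compute_batch_metrics`, reusing the hypothetical `X_df` and `site` from the pipeline sketch above.

cb = ComBat(batch=site).fit(X_df)
report = cb.compute_batch_metrics(X_df, pca_components=5)

sil = report["batch_effect"]["silhouette"]
print(f"silhouette before={sil['before']:.3f} after={sil['after']:.3f}")
print(report["preservation"]["knn"])              # {5: ..., 10: ..., 50: ...}
print(report["alignment"]["centroid_distance"])   # {'before': ..., 'after': ...}
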
    def fit_transform(self, X: ArrayLike, y: ArrayLike | None = None) -> pd.DataFrame:
        """
        Fit and transform the data, optionally computing metrics.

        If ``compute_metrics=True`` was set at construction, batch effect
        metrics are computed and cached in the ``metrics_`` property.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Input data to fit and transform.
        y : None
            Ignored. Present for API compatibility.

        Returns
        -------
        X_transformed : pd.DataFrame
            Batch-corrected data.
        """
        self.fit(X, y)
        X_transformed = self.transform(X)

        if self.compute_metrics:
            self._metrics_cache = self.compute_batch_metrics(X)

        return X_transformed

    def plot_transformation(
        self,
        X: ArrayLike,
        *,
        reduction_method: Literal["pca", "tsne", "umap"] = "pca",
        n_components: Literal[2, 3] = 2,
        plot_type: Literal["static", "interactive"] = "static",
        figsize: tuple[int, int] = (12, 5),
        alpha: float = 0.7,
        point_size: int = 50,
        cmap: str = "Set1",
        title: str | None = None,
        show_legend: bool = True,
        return_embeddings: bool = False,
        **reduction_kwargs,
    ) -> Any | tuple[Any, dict[str, FloatArray]]:
        """
        Visualize the ComBat transformation effect using dimensionality reduction.

        Shows a before/after comparison of data transformed by `ComBat`, using
        PCA, t-SNE, or UMAP to reduce dimensions for visualization.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Input data to transform and visualize.

        reduction_method : {`'pca'`, `'tsne'`, `'umap'`}, default=`'pca'`
            Dimensionality reduction method.

        n_components : {2, 3}, default=2
            Number of components for dimensionality reduction.

        plot_type : {`'static'`, `'interactive'`}, default=`'static'`
            Visualization type:
            - `'static'`: matplotlib plots (can be saved as images)
            - `'interactive'`: plotly plots (explorable, requires plotly)

        return_embeddings : bool, default=False
            If `True`, return embeddings along with the plot.

        **reduction_kwargs : dict
            Additional parameters for reduction methods.

        Returns
        -------
        fig : matplotlib.figure.Figure or plotly.graph_objects.Figure
            The figure object containing the plots.

        embeddings : dict, optional
            If `return_embeddings=True`, dictionary with:
            - `'original'`: embedding of original data
            - `'transformed'`: embedding of ComBat-transformed data
        """
        if not hasattr(self._model, "_gamma_star"):
            raise ValueError(
                "This ComBat instance is not fitted yet. Call 'fit' before 'plot_transformation'."
            )

        if n_components not in [2, 3]:
            raise ValueError(f"n_components must be 2 or 3, got {n_components}")
        if reduction_method not in ["pca", "tsne", "umap"]:
            raise ValueError(
                f"reduction_method must be 'pca', 'tsne', or 'umap', got '{reduction_method}'"
            )
        if plot_type not in ["static", "interactive"]:
            raise ValueError(f"plot_type must be 'static' or 'interactive', got '{plot_type}'")

        if not isinstance(X, pd.DataFrame):
            X = pd.DataFrame(X)

        idx = X.index
        batch_vec = self._subset(self.batch, idx)
        if batch_vec is None:
            raise ValueError("Batch information is required for visualization")

        X_transformed = self.transform(X)

        X_np = X.values
        X_trans_np = X_transformed.values

        if reduction_method == "pca":
            reducer_orig = PCA(n_components=n_components, **reduction_kwargs)
            reducer_trans = PCA(n_components=n_components, **reduction_kwargs)
        elif reduction_method == "tsne":
            tsne_params = {"perplexity": 30, "max_iter": 1000, "random_state": 42}
            tsne_params.update(reduction_kwargs)
            reducer_orig = TSNE(n_components=n_components, **tsne_params)
            reducer_trans = TSNE(n_components=n_components, **tsne_params)
        else:
            umap_params = {"random_state": 42}
            umap_params.update(reduction_kwargs)
            reducer_orig = umap.UMAP(n_components=n_components, **umap_params)
            reducer_trans = umap.UMAP(n_components=n_components, **umap_params)

        X_embedded_orig = reducer_orig.fit_transform(X_np)
        X_embedded_trans = reducer_trans.fit_transform(X_trans_np)

        if plot_type == "static":
            fig = self._create_static_plot(
                X_embedded_orig,
                X_embedded_trans,
                batch_vec,
                reduction_method,
                n_components,
                figsize,
                alpha,
                point_size,
                cmap,
                title,
                show_legend,
            )
        else:
            fig = self._create_interactive_plot(
                X_embedded_orig,
                X_embedded_trans,
                batch_vec,
                reduction_method,
                n_components,
                cmap,
                title,
                show_legend,
            )

        if return_embeddings:
            embeddings = {"original": X_embedded_orig, "transformed": X_embedded_trans}
            return fig, embeddings
        else:
            return fig
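A hedged sketch of calling `plot_transformation` on a fitted instance, continuing the hypothetical `combat` object from the earlier sketch; keyword names and return values follow the signature and docstring above:

    # Static 2-D before/after comparison (matplotlib figure).
    fig = combat.plot_transformation(
        X,
        reduction_method="umap",   # or "pca" / "tsne"
        n_components=2,
        plot_type="static",
    )
    fig.savefig("combat_before_after.png", dpi=150)

    # Keep the low-dimensional coordinates for downstream use.
    fig, embeddings = combat.plot_transformation(X, return_embeddings=True)
    emb_before = embeddings["original"]
    emb_after = embeddings["transformed"]

Note that the method fits separate reducers on the original and corrected data, so for t-SNE and UMAP the two panels do not share a coordinate system; only the within-panel batch structure is comparable.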
    def _create_static_plot(
        self,
        X_orig: FloatArray,
        X_trans: FloatArray,
        batch_labels: pd.Series,
        method: str,
        n_components: int,
        figsize: tuple[int, int],
        alpha: float,
        point_size: int,
        cmap: str,
        title: str | None,
        show_legend: bool,
    ) -> Any:
        """Create static plots using matplotlib."""

        fig = plt.figure(figsize=figsize)

        unique_batches = batch_labels.drop_duplicates()
        n_batches = len(unique_batches)

        if n_batches <= 10:
            colors = matplotlib.colormaps.get_cmap(cmap)(np.linspace(0, 1, n_batches))
        else:
            colors = matplotlib.colormaps.get_cmap("tab20")(np.linspace(0, 1, n_batches))

        if n_components == 2:
            ax1 = plt.subplot(1, 2, 1)
            ax2 = plt.subplot(1, 2, 2)
        else:
            ax1 = fig.add_subplot(121, projection="3d")
            ax2 = fig.add_subplot(122, projection="3d")

        for i, batch in enumerate(unique_batches):
            mask = batch_labels == batch
            if n_components == 2:
                ax1.scatter(
                    X_orig[mask, 0],
                    X_orig[mask, 1],
                    c=[colors[i]],
                    s=point_size,
                    alpha=alpha,
                    label=f"Batch {batch}",
                    edgecolors="black",
                    linewidth=0.5,
                )
            else:
                ax1.scatter(
                    X_orig[mask, 0],
                    X_orig[mask, 1],
                    X_orig[mask, 2],
                    c=[colors[i]],
                    s=point_size,
                    alpha=alpha,
                    label=f"Batch {batch}",
                    edgecolors="black",
                    linewidth=0.5,
                )

        ax1.set_title(f"Before ComBat correction\n({method.upper()})")
        ax1.set_xlabel(f"{method.upper()}1")
        ax1.set_ylabel(f"{method.upper()}2")
        if n_components == 3:
            ax1.set_zlabel(f"{method.upper()}3")

        for i, batch in enumerate(unique_batches):
            mask = batch_labels == batch
            if n_components == 2:
                ax2.scatter(
                    X_trans[mask, 0],
                    X_trans[mask, 1],
                    c=[colors[i]],
                    s=point_size,
                    alpha=alpha,
                    label=f"Batch {batch}",
                    edgecolors="black",
                    linewidth=0.5,
                )
            else:
                ax2.scatter(
                    X_trans[mask, 0],
                    X_trans[mask, 1],
                    X_trans[mask, 2],
                    c=[colors[i]],
                    s=point_size,
                    alpha=alpha,
                    label=f"Batch {batch}",
                    edgecolors="black",
                    linewidth=0.5,
                )

        ax2.set_title(f"After ComBat correction\n({method.upper()})")
        ax2.set_xlabel(f"{method.upper()}1")
        ax2.set_ylabel(f"{method.upper()}2")
        if n_components == 3:
            ax2.set_zlabel(f"{method.upper()}3")

        if show_legend and n_batches <= 20:
            ax2.legend(bbox_to_anchor=(1.05, 1), loc="upper left")

        if title is None:
            title = f"ComBat correction effect visualized with {method.upper()}"
        fig.suptitle(title, fontsize=14, fontweight="bold")

        plt.tight_layout()
        return fig
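As an aside, the per-batch colors above come from calling a matplotlib colormap object on evenly spaced values in [0, 1]; a self-contained sketch of that idiom (standalone names, not part of the original file):

    import matplotlib
    import numpy as np

    n_batches = 4
    cmap = matplotlib.colormaps.get_cmap("Set1")
    colors = cmap(np.linspace(0, 1, n_batches))  # (n_batches, 4) array of RGBA rows

The fallback to "tab20" beyond 10 batches keeps adjacent batch colors distinguishable, since qualitative maps like "Set1" define only a handful of distinct hues.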
    def _create_interactive_plot(
        self,
        X_orig: FloatArray,
        X_trans: FloatArray,
        batch_labels: pd.Series,
        method: str,
        n_components: int,
        cmap: str,
        title: str | None,
        show_legend: bool,
    ) -> Any:
        """Create interactive plots using plotly."""
        if n_components == 2:
            fig = make_subplots(
                rows=1,
                cols=2,
                subplot_titles=(
                    f"Before ComBat correction ({method.upper()})",
                    f"After ComBat correction ({method.upper()})",
                ),
            )
        else:
            fig = make_subplots(
                rows=1,
                cols=2,
                specs=[[{"type": "scatter3d"}, {"type": "scatter3d"}]],
                subplot_titles=(
                    f"Before ComBat correction ({method.upper()})",
                    f"After ComBat correction ({method.upper()})",
                ),
            )

        unique_batches = batch_labels.drop_duplicates()

        n_batches = len(unique_batches)
        cmap_func = matplotlib.colormaps.get_cmap(cmap)
        color_list = [
            mcolors.to_hex(cmap_func(i / max(n_batches - 1, 1))) for i in range(n_batches)
        ]

        batch_to_color = dict(zip(unique_batches, color_list, strict=True))

        for batch in unique_batches:
            mask = batch_labels == batch

            if n_components == 2:
                fig.add_trace(
                    go.Scatter(
                        x=X_orig[mask, 0],
                        y=X_orig[mask, 1],
                        mode="markers",
                        name=f"Batch {batch}",
                        marker={
                            "size": 8,
                            "color": batch_to_color[batch],
                            "line": {"width": 1, "color": "black"},
                        },
                        showlegend=False,
                    ),
                    row=1,
                    col=1,
                )

                fig.add_trace(
                    go.Scatter(
                        x=X_trans[mask, 0],
                        y=X_trans[mask, 1],
                        mode="markers",
                        name=f"Batch {batch}",
                        marker={
                            "size": 8,
                            "color": batch_to_color[batch],
                            "line": {"width": 1, "color": "black"},
                        },
                        showlegend=show_legend,
                    ),
                    row=1,
                    col=2,
                )
            else:
                fig.add_trace(
                    go.Scatter3d(
                        x=X_orig[mask, 0],
                        y=X_orig[mask, 1],
                        z=X_orig[mask, 2],
                        mode="markers",
                        name=f"Batch {batch}",
                        marker={
                            "size": 5,
                            "color": batch_to_color[batch],
                            "line": {"width": 0.5, "color": "black"},
                        },
                        showlegend=False,
                    ),
                    row=1,
                    col=1,
                )

                fig.add_trace(
                    go.Scatter3d(
                        x=X_trans[mask, 0],
                        y=X_trans[mask, 1],
                        z=X_trans[mask, 2],
                        mode="markers",
                        name=f"Batch {batch}",
                        marker={
                            "size": 5,
                            "color": batch_to_color[batch],
                            "line": {"width": 0.5, "color": "black"},
                        },
                        showlegend=show_legend,
                    ),
                    row=1,
                    col=2,
                )

        if title is None:
            title = f"ComBat correction effect visualized with {method.upper()}"

        fig.update_layout(
            title=title,
            title_font_size=16,
            height=600,
            showlegend=show_legend,
            hovermode="closest",
        )

        axis_labels = [f"{method.upper()}{i + 1}" for i in range(n_components)]

        if n_components == 2:
            fig.update_xaxes(title_text=axis_labels[0])
            fig.update_yaxes(title_text=axis_labels[1])
        else:
            fig.update_scenes(
                xaxis_title=axis_labels[0],
                yaxis_title=axis_labels[1],
                zaxis_title=axis_labels[2],
            )

        return fig
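One design note: plotly expects CSS-style color strings rather than matplotlib RGBA tuples, which is why the method converts the sampled colormap values with `mcolors.to_hex`. A standalone sketch of the conversion (illustrative names only):

    import matplotlib
    import matplotlib.colors as mcolors

    n_batches = 3
    cmap = matplotlib.colormaps.get_cmap("Set1")
    hex_colors = [
        mcolors.to_hex(cmap(i / max(n_batches - 1, 1)))  # e.g. "#e41a1c"
        for i in range(n_batches)
    ]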