combatlearn 1.1.2__py3-none-any.whl → 1.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- combatlearn/__init__.py +2 -2
- combatlearn/core.py +578 -0
- combatlearn/metrics.py +788 -0
- combatlearn/sklearn_api.py +143 -0
- combatlearn/visualization.py +533 -0
- {combatlearn-1.1.2.dist-info → combatlearn-1.2.0.dist-info}/METADATA +24 -14
- combatlearn-1.2.0.dist-info/RECORD +10 -0
- {combatlearn-1.1.2.dist-info → combatlearn-1.2.0.dist-info}/WHEEL +1 -1
- combatlearn/combat.py +0 -1770
- combatlearn-1.1.2.dist-info/RECORD +0 -7
- {combatlearn-1.1.2.dist-info → combatlearn-1.2.0.dist-info}/licenses/LICENSE +0 -0
- {combatlearn-1.1.2.dist-info → combatlearn-1.2.0.dist-info}/top_level.txt +0 -0
combatlearn/metrics.py
ADDED
|
@@ -0,0 +1,788 @@
|
|
|
1
|
+
"""Batch effect metrics and diagnostics."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import Any, Literal
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
import pandas as pd
|
|
9
|
+
from scipy.spatial.distance import pdist
|
|
10
|
+
from scipy.stats import chi2, levene, spearmanr
|
|
11
|
+
from sklearn.decomposition import PCA
|
|
12
|
+
from sklearn.metrics import davies_bouldin_score, silhouette_score
|
|
13
|
+
from sklearn.neighbors import NearestNeighbors
|
|
14
|
+
|
|
15
|
+
from .core import ArrayLike
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def _compute_pca_embedding(
    X_before: np.ndarray,
    X_after: np.ndarray,
    n_components: int,
) -> tuple[np.ndarray, np.ndarray, PCA]:
    """
    Project the uncorrected and corrected data onto a shared PCA basis.

    The PCA model is fitted on ``X_before`` only and then used to transform
    both matrices, so the two embeddings live in the same coordinate system.
    The requested number of components is clipped to
    ``min(n_components, n_features, n_samples - 1)``.

    Parameters
    ----------
    X_before : np.ndarray
        Original data before correction (used to fit the PCA).
    X_after : np.ndarray
        Corrected data (only transformed).
    n_components : int
        Requested number of PCA components.

    Returns
    -------
    X_before_pca : np.ndarray
        PCA-transformed original data.
    X_after_pca : np.ndarray
        PCA-transformed corrected data.
    pca : PCA
        Fitted PCA model (fixed ``random_state=42``).
    """
    n_rows = X_before.shape[0]
    n_cols = X_before.shape[1]
    effective_components = min(n_components, n_cols, n_rows - 1)

    model = PCA(n_components=effective_components, random_state=42)
    embedded_before = model.fit_transform(X_before)
    embedded_after = model.transform(X_after)
    return embedded_before, embedded_after, model
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def _silhouette_batch(X: np.ndarray, batch_labels: np.ndarray) -> float:
|
|
54
|
+
"""
|
|
55
|
+
Compute silhouette coefficient using batch as cluster labels.
|
|
56
|
+
|
|
57
|
+
Lower values after correction indicate better batch mixing.
|
|
58
|
+
Range: [-1, 1], where -1 = batch mixing, 1 = batch separation.
|
|
59
|
+
|
|
60
|
+
Parameters
|
|
61
|
+
----------
|
|
62
|
+
X : np.ndarray
|
|
63
|
+
Data matrix.
|
|
64
|
+
batch_labels : np.ndarray
|
|
65
|
+
Batch labels for each sample.
|
|
66
|
+
|
|
67
|
+
Returns
|
|
68
|
+
-------
|
|
69
|
+
float
|
|
70
|
+
Silhouette coefficient.
|
|
71
|
+
"""
|
|
72
|
+
unique_batches = np.unique(batch_labels)
|
|
73
|
+
if len(unique_batches) < 2:
|
|
74
|
+
return 0.0
|
|
75
|
+
try:
|
|
76
|
+
return silhouette_score(X, batch_labels, metric="euclidean")
|
|
77
|
+
except Exception:
|
|
78
|
+
return 0.0
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def _davies_bouldin_batch(X: np.ndarray, batch_labels: np.ndarray) -> float:
|
|
82
|
+
"""
|
|
83
|
+
Compute Davies-Bouldin index using batch labels.
|
|
84
|
+
|
|
85
|
+
Lower values indicate better batch mixing.
|
|
86
|
+
Range: [0, inf), 0 = perfect batch overlap.
|
|
87
|
+
|
|
88
|
+
Parameters
|
|
89
|
+
----------
|
|
90
|
+
X : np.ndarray
|
|
91
|
+
Data matrix.
|
|
92
|
+
batch_labels : np.ndarray
|
|
93
|
+
Batch labels for each sample.
|
|
94
|
+
|
|
95
|
+
Returns
|
|
96
|
+
-------
|
|
97
|
+
float
|
|
98
|
+
Davies-Bouldin index.
|
|
99
|
+
"""
|
|
100
|
+
unique_batches = np.unique(batch_labels)
|
|
101
|
+
if len(unique_batches) < 2:
|
|
102
|
+
return 0.0
|
|
103
|
+
try:
|
|
104
|
+
return davies_bouldin_score(X, batch_labels)
|
|
105
|
+
except Exception:
|
|
106
|
+
return 0.0
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def _kbet_score(
|
|
110
|
+
X: np.ndarray,
|
|
111
|
+
batch_labels: np.ndarray,
|
|
112
|
+
k0: int,
|
|
113
|
+
alpha: float = 0.05,
|
|
114
|
+
) -> tuple[float, float]:
|
|
115
|
+
"""
|
|
116
|
+
Compute kBET (k-nearest neighbor Batch Effect Test) acceptance rate.
|
|
117
|
+
|
|
118
|
+
Tests if local batch proportions match global batch proportions.
|
|
119
|
+
Higher acceptance rate = better batch mixing.
|
|
120
|
+
|
|
121
|
+
Reference: Buttner et al. (2019) Nature Methods
|
|
122
|
+
|
|
123
|
+
Parameters
|
|
124
|
+
----------
|
|
125
|
+
X : np.ndarray
|
|
126
|
+
Data matrix.
|
|
127
|
+
batch_labels : np.ndarray
|
|
128
|
+
Batch labels for each sample.
|
|
129
|
+
k0 : int
|
|
130
|
+
Neighborhood size.
|
|
131
|
+
alpha : float
|
|
132
|
+
Significance level for chi-squared test.
|
|
133
|
+
|
|
134
|
+
Returns
|
|
135
|
+
-------
|
|
136
|
+
acceptance_rate : float
|
|
137
|
+
Fraction of samples where H0 (uniform mixing) is accepted.
|
|
138
|
+
mean_stat : float
|
|
139
|
+
Mean chi-squared statistic across samples.
|
|
140
|
+
"""
|
|
141
|
+
n_samples = X.shape[0]
|
|
142
|
+
unique_batches, batch_counts = np.unique(batch_labels, return_counts=True)
|
|
143
|
+
n_batches = len(unique_batches)
|
|
144
|
+
|
|
145
|
+
if n_batches < 2:
|
|
146
|
+
return 1.0, 0.0
|
|
147
|
+
|
|
148
|
+
global_freq = batch_counts / n_samples
|
|
149
|
+
k0 = min(k0, n_samples - 1)
|
|
150
|
+
|
|
151
|
+
nn = NearestNeighbors(n_neighbors=k0 + 1, algorithm="auto")
|
|
152
|
+
nn.fit(X)
|
|
153
|
+
_, indices = nn.kneighbors(X)
|
|
154
|
+
|
|
155
|
+
chi2_stats = []
|
|
156
|
+
p_values = []
|
|
157
|
+
batch_to_idx = {b: i for i, b in enumerate(unique_batches)}
|
|
158
|
+
|
|
159
|
+
for i in range(n_samples):
|
|
160
|
+
neighbors = indices[i, 1 : k0 + 1]
|
|
161
|
+
neighbor_batches = batch_labels[neighbors]
|
|
162
|
+
|
|
163
|
+
observed = np.zeros(n_batches)
|
|
164
|
+
for nb in neighbor_batches:
|
|
165
|
+
observed[batch_to_idx[nb]] += 1
|
|
166
|
+
|
|
167
|
+
expected = global_freq * k0
|
|
168
|
+
|
|
169
|
+
mask = expected > 0
|
|
170
|
+
if mask.sum() < 2:
|
|
171
|
+
continue
|
|
172
|
+
|
|
173
|
+
stat = np.sum((observed[mask] - expected[mask]) ** 2 / expected[mask])
|
|
174
|
+
df = max(1, mask.sum() - 1)
|
|
175
|
+
p_val = 1 - chi2.cdf(stat, df)
|
|
176
|
+
|
|
177
|
+
chi2_stats.append(stat)
|
|
178
|
+
p_values.append(p_val)
|
|
179
|
+
|
|
180
|
+
if len(p_values) == 0:
|
|
181
|
+
return 1.0, 0.0
|
|
182
|
+
|
|
183
|
+
acceptance_rate = np.mean(np.array(p_values) > alpha)
|
|
184
|
+
mean_stat = np.mean(chi2_stats)
|
|
185
|
+
|
|
186
|
+
return acceptance_rate, mean_stat
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
def _find_sigma(distances: np.ndarray, target_perplexity: float, tol: float = 1e-5) -> float:
|
|
190
|
+
"""
|
|
191
|
+
Binary search for sigma to achieve target perplexity.
|
|
192
|
+
|
|
193
|
+
Used in LISI computation.
|
|
194
|
+
|
|
195
|
+
Parameters
|
|
196
|
+
----------
|
|
197
|
+
distances : np.ndarray
|
|
198
|
+
Distances to neighbors.
|
|
199
|
+
target_perplexity : float
|
|
200
|
+
Target perplexity value.
|
|
201
|
+
tol : float
|
|
202
|
+
Tolerance for convergence.
|
|
203
|
+
|
|
204
|
+
Returns
|
|
205
|
+
-------
|
|
206
|
+
float
|
|
207
|
+
Sigma value.
|
|
208
|
+
"""
|
|
209
|
+
target_H = np.log2(target_perplexity + 1e-10)
|
|
210
|
+
|
|
211
|
+
sigma_min, sigma_max = 1e-10, 1e10
|
|
212
|
+
sigma = 1.0
|
|
213
|
+
|
|
214
|
+
for _ in range(50):
|
|
215
|
+
P = np.exp(-(distances**2) / (2 * sigma**2 + 1e-10))
|
|
216
|
+
P_sum = P.sum()
|
|
217
|
+
if P_sum < 1e-10:
|
|
218
|
+
sigma = (sigma + sigma_max) / 2
|
|
219
|
+
continue
|
|
220
|
+
P = P / P_sum
|
|
221
|
+
P = np.clip(P, 1e-10, 1.0)
|
|
222
|
+
H = -np.sum(P * np.log2(P))
|
|
223
|
+
|
|
224
|
+
if abs(H - target_H) < tol:
|
|
225
|
+
break
|
|
226
|
+
elif target_H > H:
|
|
227
|
+
sigma_min = sigma
|
|
228
|
+
else:
|
|
229
|
+
sigma_max = sigma
|
|
230
|
+
sigma = (sigma_min + sigma_max) / 2
|
|
231
|
+
|
|
232
|
+
return sigma
|
|
233
|
+
|
|
234
|
+
|
|
235
|
+
def _lisi_score(
|
|
236
|
+
X: np.ndarray,
|
|
237
|
+
batch_labels: np.ndarray,
|
|
238
|
+
perplexity: int = 30,
|
|
239
|
+
) -> float:
|
|
240
|
+
"""
|
|
241
|
+
Compute mean Local Inverse Simpson's Index (LISI).
|
|
242
|
+
|
|
243
|
+
Range: [1, n_batches], where n_batches = perfect mixing.
|
|
244
|
+
Higher = better batch mixing.
|
|
245
|
+
|
|
246
|
+
Reference: Korsunsky et al. (2019) Nature Methods (Harmony paper)
|
|
247
|
+
|
|
248
|
+
Parameters
|
|
249
|
+
----------
|
|
250
|
+
X : np.ndarray
|
|
251
|
+
Data matrix.
|
|
252
|
+
batch_labels : np.ndarray
|
|
253
|
+
Batch labels for each sample.
|
|
254
|
+
perplexity : int
|
|
255
|
+
Perplexity for Gaussian kernel.
|
|
256
|
+
|
|
257
|
+
Returns
|
|
258
|
+
-------
|
|
259
|
+
float
|
|
260
|
+
Mean LISI score.
|
|
261
|
+
"""
|
|
262
|
+
n_samples = X.shape[0]
|
|
263
|
+
unique_batches = np.unique(batch_labels)
|
|
264
|
+
n_batches = len(unique_batches)
|
|
265
|
+
batch_to_idx = {b: i for i, b in enumerate(unique_batches)}
|
|
266
|
+
|
|
267
|
+
if n_batches < 2:
|
|
268
|
+
return 1.0
|
|
269
|
+
|
|
270
|
+
k = min(3 * perplexity, n_samples - 1)
|
|
271
|
+
|
|
272
|
+
nn = NearestNeighbors(n_neighbors=k + 1, algorithm="auto")
|
|
273
|
+
nn.fit(X)
|
|
274
|
+
distances, indices = nn.kneighbors(X)
|
|
275
|
+
|
|
276
|
+
distances = distances[:, 1:]
|
|
277
|
+
indices = indices[:, 1:]
|
|
278
|
+
|
|
279
|
+
lisi_values = []
|
|
280
|
+
|
|
281
|
+
for i in range(n_samples):
|
|
282
|
+
sigma = _find_sigma(distances[i], perplexity)
|
|
283
|
+
|
|
284
|
+
P = np.exp(-(distances[i] ** 2) / (2 * sigma**2 + 1e-10))
|
|
285
|
+
P_sum = P.sum()
|
|
286
|
+
if P_sum < 1e-10:
|
|
287
|
+
lisi_values.append(1.0)
|
|
288
|
+
continue
|
|
289
|
+
P = P / P_sum
|
|
290
|
+
|
|
291
|
+
neighbor_batches = batch_labels[indices[i]]
|
|
292
|
+
batch_probs = np.zeros(n_batches)
|
|
293
|
+
for j, nb in enumerate(neighbor_batches):
|
|
294
|
+
batch_probs[batch_to_idx[nb]] += P[j]
|
|
295
|
+
|
|
296
|
+
simpson = np.sum(batch_probs**2)
|
|
297
|
+
lisi = n_batches if simpson < 1e-10 else 1.0 / simpson
|
|
298
|
+
lisi_values.append(lisi)
|
|
299
|
+
|
|
300
|
+
return np.mean(lisi_values)
|
|
301
|
+
|
|
302
|
+
|
|
303
|
+
def _variance_ratio(X: np.ndarray, batch_labels: np.ndarray) -> float:
|
|
304
|
+
"""
|
|
305
|
+
Compute between-batch to within-batch variance ratio.
|
|
306
|
+
|
|
307
|
+
Similar to F-statistic in one-way ANOVA.
|
|
308
|
+
Lower ratio after correction = better batch effect removal.
|
|
309
|
+
|
|
310
|
+
Parameters
|
|
311
|
+
----------
|
|
312
|
+
X : np.ndarray
|
|
313
|
+
Data matrix.
|
|
314
|
+
batch_labels : np.ndarray
|
|
315
|
+
Batch labels for each sample.
|
|
316
|
+
|
|
317
|
+
Returns
|
|
318
|
+
-------
|
|
319
|
+
float
|
|
320
|
+
Variance ratio (between/within).
|
|
321
|
+
"""
|
|
322
|
+
unique_batches = np.unique(batch_labels)
|
|
323
|
+
n_batches = len(unique_batches)
|
|
324
|
+
n_samples = X.shape[0]
|
|
325
|
+
|
|
326
|
+
if n_batches < 2:
|
|
327
|
+
return 0.0
|
|
328
|
+
|
|
329
|
+
grand_mean = np.mean(X, axis=0)
|
|
330
|
+
|
|
331
|
+
between_var = 0.0
|
|
332
|
+
within_var = 0.0
|
|
333
|
+
|
|
334
|
+
for batch in unique_batches:
|
|
335
|
+
mask = batch_labels == batch
|
|
336
|
+
n_b = np.sum(mask)
|
|
337
|
+
X_batch = X[mask]
|
|
338
|
+
batch_mean = np.mean(X_batch, axis=0)
|
|
339
|
+
|
|
340
|
+
between_var += n_b * np.sum((batch_mean - grand_mean) ** 2)
|
|
341
|
+
within_var += np.sum((X_batch - batch_mean) ** 2)
|
|
342
|
+
|
|
343
|
+
between_var /= n_batches - 1
|
|
344
|
+
within_var /= n_samples - n_batches
|
|
345
|
+
|
|
346
|
+
if within_var < 1e-10:
|
|
347
|
+
return 0.0
|
|
348
|
+
|
|
349
|
+
return between_var / within_var
|
|
350
|
+
|
|
351
|
+
|
|
352
|
+
def _knn_preservation(
|
|
353
|
+
X_before: np.ndarray,
|
|
354
|
+
X_after: np.ndarray,
|
|
355
|
+
k_values: list[int],
|
|
356
|
+
n_jobs: int = 1,
|
|
357
|
+
) -> dict[int, float]:
|
|
358
|
+
"""
|
|
359
|
+
Compute fraction of k-nearest neighbors preserved after correction.
|
|
360
|
+
|
|
361
|
+
Range: [0, 1], where 1 = perfect preservation.
|
|
362
|
+
Higher = better biological structure preservation.
|
|
363
|
+
|
|
364
|
+
Parameters
|
|
365
|
+
----------
|
|
366
|
+
X_before : np.ndarray
|
|
367
|
+
Original data.
|
|
368
|
+
X_after : np.ndarray
|
|
369
|
+
Corrected data.
|
|
370
|
+
k_values : list of int
|
|
371
|
+
Values of k for k-NN.
|
|
372
|
+
n_jobs : int
|
|
373
|
+
Number of parallel jobs.
|
|
374
|
+
|
|
375
|
+
Returns
|
|
376
|
+
-------
|
|
377
|
+
dict
|
|
378
|
+
Mapping from k to preservation fraction.
|
|
379
|
+
"""
|
|
380
|
+
results = {}
|
|
381
|
+
max_k = max(k_values)
|
|
382
|
+
max_k = min(max_k, X_before.shape[0] - 1)
|
|
383
|
+
|
|
384
|
+
nn_before = NearestNeighbors(n_neighbors=max_k + 1, algorithm="auto", n_jobs=n_jobs)
|
|
385
|
+
nn_before.fit(X_before)
|
|
386
|
+
_, indices_before = nn_before.kneighbors(X_before)
|
|
387
|
+
|
|
388
|
+
nn_after = NearestNeighbors(n_neighbors=max_k + 1, algorithm="auto", n_jobs=n_jobs)
|
|
389
|
+
nn_after.fit(X_after)
|
|
390
|
+
_, indices_after = nn_after.kneighbors(X_after)
|
|
391
|
+
|
|
392
|
+
for k in k_values:
|
|
393
|
+
if k > max_k:
|
|
394
|
+
results[k] = 0.0
|
|
395
|
+
continue
|
|
396
|
+
|
|
397
|
+
overlaps = []
|
|
398
|
+
for i in range(X_before.shape[0]):
|
|
399
|
+
neighbors_before = set(indices_before[i, 1 : k + 1])
|
|
400
|
+
neighbors_after = set(indices_after[i, 1 : k + 1])
|
|
401
|
+
overlap = len(neighbors_before & neighbors_after) / k
|
|
402
|
+
overlaps.append(overlap)
|
|
403
|
+
|
|
404
|
+
results[k] = np.mean(overlaps)
|
|
405
|
+
|
|
406
|
+
return results
|
|
407
|
+
|
|
408
|
+
|
|
409
|
+
def _pairwise_distance_correlation(
|
|
410
|
+
X_before: np.ndarray,
|
|
411
|
+
X_after: np.ndarray,
|
|
412
|
+
subsample: int = 1000,
|
|
413
|
+
random_state: int = 42,
|
|
414
|
+
) -> float:
|
|
415
|
+
"""
|
|
416
|
+
Compute Spearman correlation of pairwise distances.
|
|
417
|
+
|
|
418
|
+
Range: [-1, 1], where 1 = perfect rank preservation.
|
|
419
|
+
Higher = better relative relationship preservation.
|
|
420
|
+
|
|
421
|
+
Parameters
|
|
422
|
+
----------
|
|
423
|
+
X_before : np.ndarray
|
|
424
|
+
Original data.
|
|
425
|
+
X_after : np.ndarray
|
|
426
|
+
Corrected data.
|
|
427
|
+
subsample : int
|
|
428
|
+
Maximum samples to use (for efficiency).
|
|
429
|
+
random_state : int
|
|
430
|
+
Random seed for subsampling.
|
|
431
|
+
|
|
432
|
+
Returns
|
|
433
|
+
-------
|
|
434
|
+
float
|
|
435
|
+
Spearman correlation coefficient.
|
|
436
|
+
"""
|
|
437
|
+
n_samples = X_before.shape[0]
|
|
438
|
+
|
|
439
|
+
if n_samples > subsample:
|
|
440
|
+
rng = np.random.default_rng(random_state)
|
|
441
|
+
idx = rng.choice(n_samples, subsample, replace=False)
|
|
442
|
+
X_before = X_before[idx]
|
|
443
|
+
X_after = X_after[idx]
|
|
444
|
+
|
|
445
|
+
dist_before = pdist(X_before, metric="euclidean")
|
|
446
|
+
dist_after = pdist(X_after, metric="euclidean")
|
|
447
|
+
|
|
448
|
+
if len(dist_before) == 0:
|
|
449
|
+
return 1.0
|
|
450
|
+
|
|
451
|
+
corr, _ = spearmanr(dist_before, dist_after)
|
|
452
|
+
|
|
453
|
+
if np.isnan(corr):
|
|
454
|
+
return 1.0
|
|
455
|
+
|
|
456
|
+
return corr
|
|
457
|
+
|
|
458
|
+
|
|
459
|
+
def _mean_centroid_distance(X: np.ndarray, batch_labels: np.ndarray) -> float:
|
|
460
|
+
"""
|
|
461
|
+
Compute mean pairwise Euclidean distance between batch centroids.
|
|
462
|
+
|
|
463
|
+
Lower after correction = better batch alignment.
|
|
464
|
+
|
|
465
|
+
Parameters
|
|
466
|
+
----------
|
|
467
|
+
X : np.ndarray
|
|
468
|
+
Data matrix.
|
|
469
|
+
batch_labels : np.ndarray
|
|
470
|
+
Batch labels for each sample.
|
|
471
|
+
|
|
472
|
+
Returns
|
|
473
|
+
-------
|
|
474
|
+
float
|
|
475
|
+
Mean pairwise distance between centroids.
|
|
476
|
+
"""
|
|
477
|
+
unique_batches = np.unique(batch_labels)
|
|
478
|
+
n_batches = len(unique_batches)
|
|
479
|
+
|
|
480
|
+
if n_batches < 2:
|
|
481
|
+
return 0.0
|
|
482
|
+
|
|
483
|
+
centroids = []
|
|
484
|
+
for batch in unique_batches:
|
|
485
|
+
mask = batch_labels == batch
|
|
486
|
+
centroid = np.mean(X[mask], axis=0)
|
|
487
|
+
centroids.append(centroid)
|
|
488
|
+
|
|
489
|
+
centroids = np.array(centroids)
|
|
490
|
+
distances = pdist(centroids, metric="euclidean")
|
|
491
|
+
|
|
492
|
+
return np.mean(distances)
|
|
493
|
+
|
|
494
|
+
|
|
495
|
+
def _levene_median_statistic(X: np.ndarray, batch_labels: np.ndarray) -> float:
|
|
496
|
+
"""
|
|
497
|
+
Compute median Levene test statistic across features.
|
|
498
|
+
|
|
499
|
+
Lower statistic = more homogeneous variances across batches.
|
|
500
|
+
|
|
501
|
+
Parameters
|
|
502
|
+
----------
|
|
503
|
+
X : np.ndarray
|
|
504
|
+
Data matrix.
|
|
505
|
+
batch_labels : np.ndarray
|
|
506
|
+
Batch labels for each sample.
|
|
507
|
+
|
|
508
|
+
Returns
|
|
509
|
+
-------
|
|
510
|
+
float
|
|
511
|
+
Median Levene test statistic.
|
|
512
|
+
"""
|
|
513
|
+
unique_batches = np.unique(batch_labels)
|
|
514
|
+
if len(unique_batches) < 2:
|
|
515
|
+
return 0.0
|
|
516
|
+
|
|
517
|
+
levene_stats = []
|
|
518
|
+
for j in range(X.shape[1]):
|
|
519
|
+
groups = [X[batch_labels == b, j] for b in unique_batches]
|
|
520
|
+
groups = [g for g in groups if len(g) > 0]
|
|
521
|
+
if len(groups) < 2:
|
|
522
|
+
continue
|
|
523
|
+
try:
|
|
524
|
+
stat, _ = levene(*groups, center="median")
|
|
525
|
+
if not np.isnan(stat):
|
|
526
|
+
levene_stats.append(stat)
|
|
527
|
+
except Exception:
|
|
528
|
+
continue
|
|
529
|
+
|
|
530
|
+
if len(levene_stats) == 0:
|
|
531
|
+
return 0.0
|
|
532
|
+
|
|
533
|
+
return np.median(levene_stats)
|
|
534
|
+
|
|
535
|
+
|
|
536
|
+
class ComBatMetricsMixin:
    """Mixin providing batch effect metrics for the ComBat wrapper.

    The host class must supply the attributes read here: ``_model`` (whose
    ``_gamma_star``, ``_delta_star`` and ``_grand_mean`` exist once fitted),
    ``batch``, ``mean_only``, ``transform`` and ``_subset``.
    """

    @property
    def metrics_(self) -> dict[str, Any] | None:
        """Return cached metrics from last fit_transform with compute_metrics=True.

        Returns
        -------
        dict or None
            Cached metrics dictionary (the ``_metrics_cache`` attribute), or
            None if no metrics have been computed.
        """
        return getattr(self, "_metrics_cache", None)

    def compute_batch_metrics(
        self,
        X: ArrayLike,
        batch: ArrayLike | None = None,
        *,
        pca_components: int | None = None,
        k_neighbors: list[int] | None = None,
        kbet_k0: int | None = None,
        lisi_perplexity: int = 30,
        n_jobs: int = 1,
    ) -> dict[str, Any]:
        """
        Compute batch effect metrics before and after ComBat correction.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Input data to evaluate.
        batch : array-like of shape (n_samples,), optional
            Batch labels. If None, uses the batch stored at construction.
            A pandas Series is aligned to ``X``'s index; a single-column
            DataFrame is treated as that column.
        pca_components : int, optional
            Number of PCA components for dimensionality reduction before
            computing metrics. If None (default), metrics are computed in
            the original feature space. Must be at least 1 and less than
            min(n_samples, n_features). The Levene statistic is always
            computed on the original features.
        k_neighbors : list of int, default=[5, 10, 50]
            Values of k for k-NN preservation metric.
        kbet_k0 : int, optional
            Neighborhood size for kBET. Default is max(10, 10% of samples).
        lisi_perplexity : int, default=30
            Perplexity for LISI computation.
        n_jobs : int, default=1
            Number of parallel jobs for neighbor computations.

        Returns
        -------
        dict
            Dictionary with three main keys:

            - ``batch_effect``: Silhouette, Davies-Bouldin, kBET, LISI, variance ratio
              (each with 'before' and 'after' values)
            - ``preservation``: k-NN preservation fractions, distance correlation
            - ``alignment``: Centroid distance, Levene statistic (each with
              'before' and 'after' values)

        Raises
        ------
        ValueError
            If the model is not fitted, if pca_components is invalid, or if
            ``batch`` is a DataFrame with more than one column.
        """
        if not hasattr(self._model, "_gamma_star"):
            raise ValueError(
                "This ComBat instance is not fitted yet. Call 'fit' before 'compute_batch_metrics'."
            )

        if not isinstance(X, pd.DataFrame):
            X = pd.DataFrame(X)

        idx = X.index

        if batch is None:
            batch_vec = self._subset(self.batch, idx)
        elif isinstance(batch, pd.DataFrame):
            # A multi-column frame would yield 2-D labels that the metric
            # helpers cannot use; a single column is unambiguous.
            if batch.shape[1] != 1:
                raise ValueError(
                    f"batch DataFrame must have exactly one column, got {batch.shape[1]}."
                )
            batch_vec = batch.iloc[:, 0].loc[idx]
        elif isinstance(batch, pd.Series):
            batch_vec = batch.loc[idx]
        else:
            batch_vec = pd.Series(batch, index=idx)

        batch_labels = np.asarray(batch_vec)

        X_before = X.values
        n_samples, n_features = X_before.shape

        # Validate before the (potentially costly) transform.
        if pca_components is not None:
            if pca_components < 1:
                raise ValueError(f"pca_components={pca_components} must be a positive integer.")
            max_components = min(n_samples, n_features)
            if pca_components >= max_components:
                raise ValueError(
                    f"pca_components={pca_components} must be less than "
                    f"min(n_samples, n_features)={max_components}."
                )

        X_after = self.transform(X).values

        if kbet_k0 is None:
            kbet_k0 = max(10, int(0.10 * n_samples))
        if k_neighbors is None:
            k_neighbors = [5, 10, 50]

        if pca_components is not None:
            X_before_pca, X_after_pca, _ = _compute_pca_embedding(X_before, X_after, pca_components)
        else:
            X_before_pca = X_before
            X_after_pca = X_after

        silhouette_before = _silhouette_batch(X_before_pca, batch_labels)
        silhouette_after = _silhouette_batch(X_after_pca, batch_labels)

        db_before = _davies_bouldin_batch(X_before_pca, batch_labels)
        db_after = _davies_bouldin_batch(X_after_pca, batch_labels)

        kbet_before, _ = _kbet_score(X_before_pca, batch_labels, kbet_k0)
        kbet_after, _ = _kbet_score(X_after_pca, batch_labels, kbet_k0)

        lisi_before = _lisi_score(X_before_pca, batch_labels, lisi_perplexity)
        lisi_after = _lisi_score(X_after_pca, batch_labels, lisi_perplexity)

        var_ratio_before = _variance_ratio(X_before_pca, batch_labels)
        var_ratio_after = _variance_ratio(X_after_pca, batch_labels)

        knn_results = _knn_preservation(X_before_pca, X_after_pca, k_neighbors, n_jobs)
        dist_corr = _pairwise_distance_correlation(X_before_pca, X_after_pca)

        centroid_before = _mean_centroid_distance(X_before_pca, batch_labels)
        centroid_after = _mean_centroid_distance(X_after_pca, batch_labels)

        # Levene is a per-feature variance test, so it uses the original
        # features rather than the (optional) PCA embedding.
        levene_before = _levene_median_statistic(X_before, batch_labels)
        levene_after = _levene_median_statistic(X_after, batch_labels)

        n_batches = len(np.unique(batch_labels))

        return {
            "batch_effect": {
                "silhouette": {
                    "before": silhouette_before,
                    "after": silhouette_after,
                },
                "davies_bouldin": {
                    "before": db_before,
                    "after": db_after,
                },
                "kbet": {
                    "before": kbet_before,
                    "after": kbet_after,
                },
                "lisi": {
                    "before": lisi_before,
                    "after": lisi_after,
                    "max_value": n_batches,
                },
                "variance_ratio": {
                    "before": var_ratio_before,
                    "after": var_ratio_after,
                },
            },
            "preservation": {
                "knn": knn_results,
                "distance_correlation": dist_corr,
            },
            "alignment": {
                "centroid_distance": {
                    "before": centroid_before,
                    "after": centroid_after,
                },
                "levene_statistic": {
                    "before": levene_before,
                    "after": levene_after,
                },
            },
        }

    def feature_batch_importance(
        self,
        mode: Literal["magnitude", "distribution"] = "magnitude",
    ) -> pd.DataFrame:
        """Compute per-feature batch effect magnitude.

        Returns a DataFrame with columns ``location``, ``scale``, and
        ``combined``. Location is the RMS of gamma across batches
        (standardized mean shifts). Scale is the RMS of log-delta across
        batches (log-fold variance change), or zero when ``mean_only`` is set.
        Combined is the Euclidean norm sqrt(location**2 + scale**2). Using
        RMS provides L2-consistent aggregation; using log(delta) ensures symmetry.

        Parameters
        ----------
        mode : {'magnitude', 'distribution'}, default='magnitude'
            - 'magnitude': Returns L2-consistent absolute batch effect magnitudes.
              Suitable for ranking, thresholding, and cross-dataset comparison.
            - 'distribution': Returns column-wise normalized proportions (each column
              sums to 1, values in range [0, 1]), representing the relative contribution
              of each feature to the total location, scale, or combined batch effect.
              A column whose total is zero is left unnormalized (all zeros).
              Note: normalization is applied independently to each column, so the
              Euclidean relationship (combined**2 = location**2 + scale**2) no longer holds.

        Returns
        -------
        pd.DataFrame
            DataFrame with index=feature names, columns=['location', 'scale', 'combined'],
            sorted by 'combined' descending.

        Raises
        ------
        ValueError
            If the model is not fitted or if mode is invalid.
        """
        if not hasattr(self._model, "_gamma_star"):
            raise ValueError(
                "This ComBat instance is not fitted yet. "
                "Call 'fit' before 'feature_batch_importance'."
            )

        if mode not in ("magnitude", "distribution"):
            raise ValueError(f"mode must be 'magnitude' or 'distribution', got '{mode}'")

        feature_names = self._model._grand_mean.index
        gamma_star = self._model._gamma_star
        delta_star = self._model._delta_star

        # Location effect: RMS of gamma across batches (L2 aggregation)
        location = np.sqrt((gamma_star**2).mean(axis=0))

        # Scale effect: RMS of log(delta) across batches
        if not self.mean_only:
            scale = np.sqrt((np.log(delta_star) ** 2).mean(axis=0))
        else:
            scale = np.zeros_like(location)

        # Euclidean norm treats location and scale as orthogonal dimensions
        combined = np.sqrt(location**2 + scale**2)

        if mode == "distribution":
            # Normalize each column independently to sum to 1
            location_sum = location.sum()
            scale_sum = scale.sum()
            combined_sum = combined.sum()

            location = location / location_sum if location_sum > 0 else location
            scale = scale / scale_sum if scale_sum > 0 else scale
            combined = combined / combined_sum if combined_sum > 0 else combined

        return pd.DataFrame(
            {
                "location": location,
                "scale": scale,
                "combined": combined,
            },
            index=feature_names,
        ).sort_values("combined", ascending=False)
|