combatlearn 1.0.0__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
combatlearn/__init__.py CHANGED
@@ -1,5 +1,5 @@
1
1
  from .combat import ComBat
2
2
 
3
3
  __all__ = ["ComBat"]
4
- __version__ = "1.0.0"
4
+ __version__ = "1.1.0"
5
5
  __author__ = "Ettore Rocchi"
combatlearn/combat.py CHANGED
@@ -16,10 +16,14 @@ import pandas as pd
16
16
  from sklearn.base import BaseEstimator, TransformerMixin
17
17
  from sklearn.decomposition import PCA
18
18
  from sklearn.manifold import TSNE
19
+ from sklearn.neighbors import NearestNeighbors
20
+ from sklearn.metrics import silhouette_score, davies_bouldin_score
21
+ from scipy.stats import levene, spearmanr, chi2
22
+ from scipy.spatial.distance import pdist
19
23
  import matplotlib
20
24
  import matplotlib.pyplot as plt
21
25
  import matplotlib.colors as mcolors
22
- from typing import Literal, Optional, Union, Dict, Tuple, Any
26
+ from typing import Literal, Optional, Union, Dict, Tuple, Any, List
23
27
  import numpy.typing as npt
24
28
  import warnings
25
29
  import umap
@@ -29,6 +33,526 @@ from plotly.subplots import make_subplots
29
33
  ArrayLike = Union[pd.DataFrame, pd.Series, npt.NDArray[Any]]
30
34
  FloatArray = npt.NDArray[np.float64]
31
35
 
36
def _compute_pca_embedding(
    X_before: np.ndarray,
    X_after: np.ndarray,
    n_components: int,
) -> Tuple[np.ndarray, np.ndarray, PCA]:
    """
    Project both datasets into a shared PCA space.

    The PCA model is fitted on ``X_before`` only and then applied to
    both datasets, so the two embeddings live in the same coordinate
    system and are directly comparable.

    Parameters
    ----------
    X_before : np.ndarray
        Original data before correction.
    X_after : np.ndarray
        Corrected data.
    n_components : int
        Requested number of PCA components.

    Returns
    -------
    X_before_pca : np.ndarray
        PCA-transformed original data.
    X_after_pca : np.ndarray
        PCA-transformed corrected data.
    pca : PCA
        Fitted PCA model.
    """
    # Cap the component count by what the data can actually support.
    effective_components = min(n_components, X_before.shape[1], X_before.shape[0] - 1)
    model = PCA(n_components=effective_components, random_state=42)
    embedding_before = model.fit_transform(X_before)
    embedding_after = model.transform(X_after)
    return embedding_before, embedding_after, model
69
+
70
+
71
def _silhouette_batch(X: np.ndarray, batch_labels: np.ndarray) -> float:
    """
    Silhouette coefficient computed with batches as cluster labels.

    Lower values after correction indicate better batch mixing.
    Range: [-1, 1], where -1 = batch mixing, 1 = batch separation.

    Parameters
    ----------
    X : np.ndarray
        Data matrix.
    batch_labels : np.ndarray
        Batch labels for each sample.

    Returns
    -------
    float
        Silhouette coefficient, or 0.0 when it is undefined.
    """
    # The score is undefined with fewer than two batches.
    if np.unique(batch_labels).size < 2:
        return 0.0
    try:
        score = silhouette_score(X, batch_labels, metric='euclidean')
    except Exception:
        # Degenerate inputs (e.g. a one-sample batch): report neutral 0.
        return 0.0
    return score
97
+
98
+
99
def _davies_bouldin_batch(X: np.ndarray, batch_labels: np.ndarray) -> float:
    """
    Davies-Bouldin index computed with batches as cluster labels.

    Lower values indicate better batch mixing.
    Range: [0, inf), 0 = perfect batch overlap.

    Parameters
    ----------
    X : np.ndarray
        Data matrix.
    batch_labels : np.ndarray
        Batch labels for each sample.

    Returns
    -------
    float
        Davies-Bouldin index, or 0.0 when it is undefined.
    """
    # The index is undefined with fewer than two batches.
    if np.unique(batch_labels).size < 2:
        return 0.0
    try:
        index = davies_bouldin_score(X, batch_labels)
    except Exception:
        # Degenerate clustering inputs: report neutral 0.
        return 0.0
    return index
125
+
126
+
127
def _kbet_score(
    X: np.ndarray,
    batch_labels: np.ndarray,
    k0: int,
    alpha: float = 0.05,
) -> Tuple[float, float]:
    """
    Compute kBET (k-nearest neighbor Batch Effect Test) acceptance rate.

    Tests if local batch proportions match global batch proportions.
    Higher acceptance rate = better batch mixing.

    Reference: Büttner et al. (2019) Nature Methods

    Parameters
    ----------
    X : np.ndarray
        Data matrix.
    batch_labels : np.ndarray
        Batch labels for each sample.
    k0 : int
        Neighborhood size.
    alpha : float
        Significance level for chi-squared test.

    Returns
    -------
    acceptance_rate : float
        Fraction of samples where H0 (uniform mixing) is accepted.
    mean_stat : float
        Mean chi-squared statistic across samples.
    """
    n_samples = X.shape[0]
    unique_batches, batch_counts = np.unique(batch_labels, return_counts=True)
    n_batches = len(unique_batches)

    # A single batch trivially mixes perfectly.
    if n_batches < 2:
        return 1.0, 0.0

    global_freq = batch_counts / n_samples
    k0 = min(k0, n_samples - 1)

    nn = NearestNeighbors(n_neighbors=k0 + 1, algorithm='auto')
    nn.fit(X)
    _, indices = nn.kneighbors(X)

    # Expected neighborhood composition is the same for every sample,
    # so hoist it (and the chi-squared degrees of freedom) out of the loop.
    expected = global_freq * k0
    mask = expected > 0
    if mask.sum() < 2:
        return 1.0, 0.0
    df = max(1, int(mask.sum()) - 1)

    # Encode labels as integer codes once so counting per sample is a
    # single vectorized bincount instead of a Python loop over neighbors.
    batch_to_idx = {b: i for i, b in enumerate(unique_batches)}
    codes = np.array([batch_to_idx[b] for b in batch_labels])

    chi2_stats = []
    p_values = []

    for i in range(n_samples):
        # Column 0 is the sample itself; use the next k0 neighbors.
        neighbor_codes = codes[indices[i, 1:k0 + 1]]
        observed = np.bincount(neighbor_codes, minlength=n_batches)

        stat = np.sum((observed[mask] - expected[mask])**2 / expected[mask])
        # chi2.sf (survival function) avoids the catastrophic cancellation
        # that 1 - chi2.cdf suffers for very small tail probabilities.
        p_val = chi2.sf(stat, df)

        chi2_stats.append(stat)
        p_values.append(p_val)

    if len(p_values) == 0:
        return 1.0, 0.0

    acceptance_rate = np.mean(np.array(p_values) > alpha)
    mean_stat = np.mean(chi2_stats)

    return acceptance_rate, mean_stat
205
+
206
+
207
+ def _find_sigma(distances: np.ndarray, target_perplexity: float, tol: float = 1e-5) -> float:
208
+ """
209
+ Binary search for sigma to achieve target perplexity.
210
+
211
+ Used in LISI computation.
212
+
213
+ Parameters
214
+ ----------
215
+ distances : np.ndarray
216
+ Distances to neighbors.
217
+ target_perplexity : float
218
+ Target perplexity value.
219
+ tol : float
220
+ Tolerance for convergence.
221
+
222
+ Returns
223
+ -------
224
+ float
225
+ Sigma value.
226
+ """
227
+ target_H = np.log2(target_perplexity + 1e-10)
228
+
229
+ sigma_min, sigma_max = 1e-10, 1e10
230
+ sigma = 1.0
231
+
232
+ for _ in range(50):
233
+ P = np.exp(-distances**2 / (2 * sigma**2 + 1e-10))
234
+ P_sum = P.sum()
235
+ if P_sum < 1e-10:
236
+ sigma = (sigma + sigma_max) / 2
237
+ continue
238
+ P = P / P_sum
239
+ P = np.clip(P, 1e-10, 1.0)
240
+ H = -np.sum(P * np.log2(P))
241
+
242
+ if abs(H - target_H) < tol:
243
+ break
244
+ elif H < target_H:
245
+ sigma_min = sigma
246
+ else:
247
+ sigma_max = sigma
248
+ sigma = (sigma_min + sigma_max) / 2
249
+
250
+ return sigma
251
+
252
+
253
def _lisi_score(
    X: np.ndarray,
    batch_labels: np.ndarray,
    perplexity: int = 30,
) -> float:
    """
    Compute mean Local Inverse Simpson's Index (LISI).

    Range: [1, n_batches], where n_batches = perfect mixing.
    Higher = better batch mixing.

    Reference: Korsunsky et al. (2019) Nature Methods (Harmony paper)

    Parameters
    ----------
    X : np.ndarray
        Data matrix.
    batch_labels : np.ndarray
        Batch labels for each sample.
    perplexity : int
        Perplexity for Gaussian kernel.

    Returns
    -------
    float
        Mean LISI score.
    """
    n_samples = X.shape[0]
    batches = np.unique(batch_labels)
    n_batches = len(batches)
    code_of = {b: i for i, b in enumerate(batches)}

    # With a single batch the index is trivially 1.
    if n_batches < 2:
        return 1.0

    # Usual t-SNE-style heuristic: look at ~3x perplexity neighbors.
    n_neighbors = min(3 * perplexity, n_samples - 1)

    finder = NearestNeighbors(n_neighbors=n_neighbors + 1, algorithm='auto')
    finder.fit(X)
    dists, nbr_idx = finder.kneighbors(X)

    # Column 0 is each sample itself; drop it.
    dists = dists[:, 1:]
    nbr_idx = nbr_idx[:, 1:]

    scores = []

    for i in range(n_samples):
        sigma = _find_sigma(dists[i], perplexity)

        weights = np.exp(-dists[i]**2 / (2 * sigma**2 + 1e-10))
        total = weights.sum()
        if total < 1e-10:
            # Kernel collapsed: the sample effectively sees a single batch.
            scores.append(1.0)
            continue
        weights = weights / total

        # Accumulate the kernel-weighted batch composition of the neighborhood.
        probs = np.zeros(n_batches)
        for weight, neighbor_batch in zip(weights, batch_labels[nbr_idx[i]]):
            probs[code_of[neighbor_batch]] += weight

        simpson = np.sum(probs**2)
        scores.append(n_batches if simpson < 1e-10 else 1.0 / simpson)

    return np.mean(scores)
322
+
323
+
324
+ def _variance_ratio(X: np.ndarray, batch_labels: np.ndarray) -> float:
325
+ """
326
+ Compute between-batch to within-batch variance ratio.
327
+
328
+ Similar to F-statistic in one-way ANOVA.
329
+ Lower ratio after correction = better batch effect removal.
330
+
331
+ Parameters
332
+ ----------
333
+ X : np.ndarray
334
+ Data matrix.
335
+ batch_labels : np.ndarray
336
+ Batch labels for each sample.
337
+
338
+ Returns
339
+ -------
340
+ float
341
+ Variance ratio (between/within).
342
+ """
343
+ unique_batches = np.unique(batch_labels)
344
+ n_batches = len(unique_batches)
345
+ n_samples = X.shape[0]
346
+
347
+ if n_batches < 2:
348
+ return 0.0
349
+
350
+ grand_mean = np.mean(X, axis=0)
351
+
352
+ between_var = 0.0
353
+ within_var = 0.0
354
+
355
+ for batch in unique_batches:
356
+ mask = batch_labels == batch
357
+ n_b = np.sum(mask)
358
+ X_batch = X[mask]
359
+ batch_mean = np.mean(X_batch, axis=0)
360
+
361
+ between_var += n_b * np.sum((batch_mean - grand_mean)**2)
362
+ within_var += np.sum((X_batch - batch_mean)**2)
363
+
364
+ between_var /= (n_batches - 1)
365
+ within_var /= (n_samples - n_batches)
366
+
367
+ if within_var < 1e-10:
368
+ return 0.0
369
+
370
+ return between_var / within_var
371
+
372
+
373
def _knn_preservation(
    X_before: np.ndarray,
    X_after: np.ndarray,
    k_values: List[int],
    n_jobs: int = 1,
) -> Dict[int, float]:
    """
    Fraction of k-nearest neighbors preserved after correction.

    Range: [0, 1], where 1 = perfect preservation.
    Higher = better biological structure preservation.

    Parameters
    ----------
    X_before : np.ndarray
        Original data.
    X_after : np.ndarray
        Corrected data.
    k_values : list of int
        Values of k for k-NN.
    n_jobs : int
        Number of parallel jobs.

    Returns
    -------
    dict
        Mapping from k to preservation fraction.
    """
    n_samples = X_before.shape[0]
    # The largest usable neighborhood is bounded by the sample count.
    cap = min(max(k_values), n_samples - 1)

    def _neighbor_table(data: np.ndarray) -> np.ndarray:
        # Returns the (n_samples, cap + 1) neighbor-index table; column 0 is self.
        finder = NearestNeighbors(n_neighbors=cap + 1, algorithm='auto', n_jobs=n_jobs)
        finder.fit(data)
        return finder.kneighbors(data)[1]

    idx_before = _neighbor_table(X_before)
    idx_after = _neighbor_table(X_after)

    scores: Dict[int, float] = {}
    for k in k_values:
        if k > cap:
            # Not enough samples to evaluate this k.
            scores[k] = 0.0
            continue

        # Compare the k neighbors (excluding self) of each sample.
        fractions = [
            len(set(idx_before[i, 1:k + 1]) & set(idx_after[i, 1:k + 1])) / k
            for i in range(n_samples)
        ]
        scores[k] = np.mean(fractions)

    return scores
428
+
429
+
430
+ def _pairwise_distance_correlation(
431
+ X_before: np.ndarray,
432
+ X_after: np.ndarray,
433
+ subsample: int = 1000,
434
+ random_state: int = 42,
435
+ ) -> float:
436
+ """
437
+ Compute Spearman correlation of pairwise distances.
438
+
439
+ Range: [-1, 1], where 1 = perfect rank preservation.
440
+ Higher = better relative relationship preservation.
441
+
442
+ Parameters
443
+ ----------
444
+ X_before : np.ndarray
445
+ Original data.
446
+ X_after : np.ndarray
447
+ Corrected data.
448
+ subsample : int
449
+ Maximum samples to use (for efficiency).
450
+ random_state : int
451
+ Random seed for subsampling.
452
+
453
+ Returns
454
+ -------
455
+ float
456
+ Spearman correlation coefficient.
457
+ """
458
+ n_samples = X_before.shape[0]
459
+
460
+ if n_samples > subsample:
461
+ rng = np.random.default_rng(random_state)
462
+ idx = rng.choice(n_samples, subsample, replace=False)
463
+ X_before = X_before[idx]
464
+ X_after = X_after[idx]
465
+
466
+ dist_before = pdist(X_before, metric='euclidean')
467
+ dist_after = pdist(X_after, metric='euclidean')
468
+
469
+ if len(dist_before) == 0:
470
+ return 1.0
471
+
472
+ corr, _ = spearmanr(dist_before, dist_after)
473
+
474
+ if np.isnan(corr):
475
+ return 1.0
476
+
477
+ return corr
478
+
479
+
480
+ def _mean_centroid_distance(X: np.ndarray, batch_labels: np.ndarray) -> float:
481
+ """
482
+ Compute mean pairwise Euclidean distance between batch centroids.
483
+
484
+ Lower after correction = better batch alignment.
485
+
486
+ Parameters
487
+ ----------
488
+ X : np.ndarray
489
+ Data matrix.
490
+ batch_labels : np.ndarray
491
+ Batch labels for each sample.
492
+
493
+ Returns
494
+ -------
495
+ float
496
+ Mean pairwise distance between centroids.
497
+ """
498
+ unique_batches = np.unique(batch_labels)
499
+ n_batches = len(unique_batches)
500
+
501
+ if n_batches < 2:
502
+ return 0.0
503
+
504
+ centroids = []
505
+ for batch in unique_batches:
506
+ mask = batch_labels == batch
507
+ centroid = np.mean(X[mask], axis=0)
508
+ centroids.append(centroid)
509
+
510
+ centroids = np.array(centroids)
511
+ distances = pdist(centroids, metric='euclidean')
512
+
513
+ return np.mean(distances)
514
+
515
+
516
+ def _levene_median_statistic(X: np.ndarray, batch_labels: np.ndarray) -> float:
517
+ """
518
+ Compute median Levene test statistic across features.
519
+
520
+ Lower statistic = more homogeneous variances across batches.
521
+
522
+ Parameters
523
+ ----------
524
+ X : np.ndarray
525
+ Data matrix.
526
+ batch_labels : np.ndarray
527
+ Batch labels for each sample.
528
+
529
+ Returns
530
+ -------
531
+ float
532
+ Median Levene test statistic.
533
+ """
534
+ unique_batches = np.unique(batch_labels)
535
+ if len(unique_batches) < 2:
536
+ return 0.0
537
+
538
+ levene_stats = []
539
+ for j in range(X.shape[1]):
540
+ groups = [X[batch_labels == b, j] for b in unique_batches]
541
+ groups = [g for g in groups if len(g) > 0]
542
+ if len(groups) < 2:
543
+ continue
544
+ try:
545
+ stat, _ = levene(*groups, center='median')
546
+ if not np.isnan(stat):
547
+ levene_stats.append(stat)
548
+ except Exception:
549
+ continue
550
+
551
+ if len(levene_stats) == 0:
552
+ return 0.0
553
+
554
+ return np.median(levene_stats)
555
+
32
556
 
33
557
  class ComBatModel:
34
558
  """ComBat algorithm.
@@ -643,6 +1167,7 @@ class ComBat(BaseEstimator, TransformerMixin):
643
1167
  reference_batch: Optional[str] = None,
644
1168
  eps: float = 1e-8,
645
1169
  covbat_cov_thresh: Union[float, int] = 0.9,
1170
+ compute_metrics: bool = False,
646
1171
  ) -> None:
647
1172
  self.batch = batch
648
1173
  self.discrete_covariates = discrete_covariates
@@ -653,6 +1178,7 @@ class ComBat(BaseEstimator, TransformerMixin):
653
1178
  self.reference_batch = reference_batch
654
1179
  self.eps = eps
655
1180
  self.covbat_cov_thresh = covbat_cov_thresh
1181
+ self.compute_metrics = compute_metrics
656
1182
  self._model = ComBatModel(
657
1183
  method=method,
658
1184
  parametric=parametric,
@@ -710,6 +1236,221 @@ class ComBat(BaseEstimator, TransformerMixin):
710
1236
  else:
711
1237
  return pd.DataFrame(obj, index=idx)
712
1238
 
1239
+ @property
1240
+ def metrics_(self) -> Optional[Dict[str, Any]]:
1241
+ """Return cached metrics from last fit_transform with compute_metrics=True.
1242
+
1243
+ Returns
1244
+ -------
1245
+ dict or None
1246
+ Cached metrics dictionary, or None if no metrics have been computed.
1247
+ """
1248
+ return getattr(self, '_metrics_cache', None)
1249
+
1250
+ def compute_batch_metrics(
1251
+ self,
1252
+ X: ArrayLike,
1253
+ batch: Optional[ArrayLike] = None,
1254
+ *,
1255
+ pca_components: Optional[int] = None,
1256
+ k_neighbors: List[int] = [5, 10, 50],
1257
+ kbet_k0: Optional[int] = None,
1258
+ lisi_perplexity: int = 30,
1259
+ n_jobs: int = 1,
1260
+ ) -> Dict[str, Any]:
1261
+ """
1262
+ Compute batch effect metrics before and after ComBat correction.
1263
+
1264
+ Parameters
1265
+ ----------
1266
+ X : array-like of shape (n_samples, n_features)
1267
+ Input data to evaluate.
1268
+ batch : array-like of shape (n_samples,), optional
1269
+ Batch labels. If None, uses the batch stored at construction.
1270
+ pca_components : int, optional
1271
+ Number of PCA components for dimensionality reduction before
1272
+ computing metrics. If None (default), metrics are computed in
1273
+ the original feature space. Must be less than min(n_samples, n_features).
1274
+ k_neighbors : list of int, default=[5, 10, 50]
1275
+ Values of k for k-NN preservation metric.
1276
+ kbet_k0 : int, optional
1277
+ Neighborhood size for kBET. Default is 10% of samples.
1278
+ lisi_perplexity : int, default=30
1279
+ Perplexity for LISI computation.
1280
+ n_jobs : int, default=1
1281
+ Number of parallel jobs for neighbor computations.
1282
+
1283
+ Returns
1284
+ -------
1285
+ metrics : dict
1286
+ Dictionary with structure:
1287
+ {
1288
+ 'batch_effect': {
1289
+ 'silhouette': {'before': float, 'after': float},
1290
+ 'davies_bouldin': {...},
1291
+ 'kbet': {...},
1292
+ 'lisi': {..., 'max_value': n_batches},
1293
+ 'variance_ratio': {...},
1294
+ },
1295
+ 'preservation': {
1296
+ 'knn': {k: fraction for k in k_neighbors},
1297
+ 'distance_correlation': float,
1298
+ },
1299
+ 'alignment': {
1300
+ 'centroid_distance': {...},
1301
+ 'levene_statistic': {...},
1302
+ },
1303
+ }
1304
+
1305
+ Raises
1306
+ ------
1307
+ ValueError
1308
+ If the model is not fitted or if pca_components is invalid.
1309
+ """
1310
+ if not hasattr(self._model, "_gamma_star"):
1311
+ raise ValueError(
1312
+ "This ComBat instance is not fitted yet. "
1313
+ "Call 'fit' before 'compute_batch_metrics'."
1314
+ )
1315
+
1316
+ if not isinstance(X, pd.DataFrame):
1317
+ X = pd.DataFrame(X)
1318
+
1319
+ idx = X.index
1320
+
1321
+ if batch is None:
1322
+ batch_vec = self._subset(self.batch, idx)
1323
+ else:
1324
+ if isinstance(batch, (pd.Series, pd.DataFrame)):
1325
+ batch_vec = batch.loc[idx] if hasattr(batch, 'loc') else batch
1326
+ elif isinstance(batch, np.ndarray):
1327
+ batch_vec = pd.Series(batch, index=idx)
1328
+ else:
1329
+ batch_vec = pd.Series(batch, index=idx)
1330
+
1331
+ batch_labels = np.array(batch_vec)
1332
+
1333
+ X_before = X.values
1334
+ X_after = self.transform(X).values
1335
+
1336
+ n_samples, n_features = X_before.shape
1337
+ if kbet_k0 is None:
1338
+ kbet_k0 = max(10, int(0.10 * n_samples))
1339
+
1340
+ # Validate and apply PCA if requested
1341
+ if pca_components is not None:
1342
+ max_components = min(n_samples, n_features)
1343
+ if pca_components >= max_components:
1344
+ raise ValueError(
1345
+ f"pca_components={pca_components} must be less than "
1346
+ f"min(n_samples, n_features)={max_components}."
1347
+ )
1348
+ X_before_pca, X_after_pca, _ = _compute_pca_embedding(
1349
+ X_before, X_after, pca_components
1350
+ )
1351
+ else:
1352
+ X_before_pca = X_before
1353
+ X_after_pca = X_after
1354
+
1355
+ silhouette_before = _silhouette_batch(X_before_pca, batch_labels)
1356
+ silhouette_after = _silhouette_batch(X_after_pca, batch_labels)
1357
+
1358
+ db_before = _davies_bouldin_batch(X_before_pca, batch_labels)
1359
+ db_after = _davies_bouldin_batch(X_after_pca, batch_labels)
1360
+
1361
+ kbet_before, _ = _kbet_score(X_before_pca, batch_labels, kbet_k0)
1362
+ kbet_after, _ = _kbet_score(X_after_pca, batch_labels, kbet_k0)
1363
+
1364
+ lisi_before = _lisi_score(X_before_pca, batch_labels, lisi_perplexity)
1365
+ lisi_after = _lisi_score(X_after_pca, batch_labels, lisi_perplexity)
1366
+
1367
+ var_ratio_before = _variance_ratio(X_before_pca, batch_labels)
1368
+ var_ratio_after = _variance_ratio(X_after_pca, batch_labels)
1369
+
1370
+ knn_results = _knn_preservation(X_before_pca, X_after_pca, k_neighbors, n_jobs)
1371
+ dist_corr = _pairwise_distance_correlation(X_before_pca, X_after_pca)
1372
+
1373
+ centroid_before = _mean_centroid_distance(X_before_pca, batch_labels)
1374
+ centroid_after = _mean_centroid_distance(X_after_pca, batch_labels)
1375
+
1376
+ levene_before = _levene_median_statistic(X_before, batch_labels)
1377
+ levene_after = _levene_median_statistic(X_after, batch_labels)
1378
+
1379
+ n_batches = len(np.unique(batch_labels))
1380
+
1381
+ metrics = {
1382
+ 'batch_effect': {
1383
+ 'silhouette': {
1384
+ 'before': silhouette_before,
1385
+ 'after': silhouette_after,
1386
+ },
1387
+ 'davies_bouldin': {
1388
+ 'before': db_before,
1389
+ 'after': db_after,
1390
+ },
1391
+ 'kbet': {
1392
+ 'before': kbet_before,
1393
+ 'after': kbet_after,
1394
+ },
1395
+ 'lisi': {
1396
+ 'before': lisi_before,
1397
+ 'after': lisi_after,
1398
+ 'max_value': n_batches,
1399
+ },
1400
+ 'variance_ratio': {
1401
+ 'before': var_ratio_before,
1402
+ 'after': var_ratio_after,
1403
+ },
1404
+ },
1405
+ 'preservation': {
1406
+ 'knn': knn_results,
1407
+ 'distance_correlation': dist_corr,
1408
+ },
1409
+ 'alignment': {
1410
+ 'centroid_distance': {
1411
+ 'before': centroid_before,
1412
+ 'after': centroid_after,
1413
+ },
1414
+ 'levene_statistic': {
1415
+ 'before': levene_before,
1416
+ 'after': levene_after,
1417
+ },
1418
+ },
1419
+ }
1420
+
1421
+ return metrics
1422
+
1423
+ def fit_transform(
1424
+ self,
1425
+ X: ArrayLike,
1426
+ y: Optional[ArrayLike] = None
1427
+ ) -> pd.DataFrame:
1428
+ """
1429
+ Fit and transform the data, optionally computing metrics.
1430
+
1431
+ If compute_metrics=True was set at construction, batch effect
1432
+ metrics are computed and cached in metrics_ property.
1433
+
1434
+ Parameters
1435
+ ----------
1436
+ X : array-like of shape (n_samples, n_features)
1437
+ Input data to fit and transform.
1438
+ y : None
1439
+ Ignored. Present for API compatibility.
1440
+
1441
+ Returns
1442
+ -------
1443
+ X_transformed : pd.DataFrame
1444
+ Batch-corrected data.
1445
+ """
1446
+ self.fit(X, y)
1447
+ X_transformed = self.transform(X)
1448
+
1449
+ if self.compute_metrics:
1450
+ self._metrics_cache = self.compute_batch_metrics(X)
1451
+
1452
+ return X_transformed
1453
+
713
1454
  def plot_transformation(
714
1455
  self,
715
1456
  X: ArrayLike, *,
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: combatlearn
3
- Version: 1.0.0
3
+ Version: 1.1.0
4
4
  Summary: Batch-effect harmonization for machine learning frameworks.
5
5
  Author-email: Ettore Rocchi <ettoreroc@gmail.com>
6
6
  License: MIT
@@ -57,8 +57,21 @@ Dynamic: license-file
57
57
  pip install combatlearn
58
58
  ```
59
59
 
60
+ ## Documentation
61
+
62
+ **Full documentation is available at [combatlearn.readthedocs.io](https://combatlearn.readthedocs.io)**
63
+
64
+ The documentation includes:
65
+ - [Installation Guide](https://combatlearn.readthedocs.io/en/latest/installation/)
66
+ - [Quick Start Tutorial](https://combatlearn.readthedocs.io/en/latest/quickstart/)
67
+ - [User Guide](https://combatlearn.readthedocs.io/en/latest/user-guide/overview/)
68
+ - [API Reference](https://combatlearn.readthedocs.io/en/latest/api/)
69
+ - [Examples](https://combatlearn.readthedocs.io/en/latest/examples/basic-usage/)
70
+
60
71
  ## Quick start
61
72
 
73
+ For more details, see the [Quick Start Tutorial](https://combatlearn.readthedocs.io/en/latest/quickstart/).
74
+
62
75
  ```python
63
76
  import pandas as pd
64
77
  from sklearn.pipeline import Pipeline
@@ -105,20 +118,9 @@ print(f"Best CV AUROC: {grid.best_score_:.3f}")
105
118
 
106
119
  For a full example of how to use **combatlearn** see the [notebook demo](https://github.com/EttoreRocchi/combatlearn/blob/main/docs/demo/combatlearn_demo.ipynb)
107
120
 
108
- ## Documentation
109
-
110
- **Full documentation is available at [combatlearn.readthedocs.io](https://combatlearn.readthedocs.io)**
111
-
112
- The documentation includes:
113
- - [Installation Guide](https://combatlearn.readthedocs.io/en/latest/installation/)
114
- - [Quick Start Tutorial](https://combatlearn.readthedocs.io/en/latest/quickstart/)
115
- - [User Guide](https://combatlearn.readthedocs.io/en/latest/user-guide/overview/)
116
- - [API Reference](https://combatlearn.readthedocs.io/en/latest/api/)
117
- - [Examples](https://combatlearn.readthedocs.io/en/latest/examples/basic-usage/)
118
-
119
121
  ## `ComBat` parameters
120
122
 
121
- The following section provides a detailed explanation of all parameters available in the scikit-learn-compatible `ComBat` class.
123
+ The following section provides a detailed explanation of all parameters available in the scikit-learn-compatible `ComBat` class. For complete API documentation, see the [API Reference](https://combatlearn.readthedocs.io/en/latest/api/).
122
124
 
123
125
  ### Main Parameters
124
126
 
@@ -140,11 +142,17 @@ The following section provides a detailed explanation of all parameters availabl
140
142
  | `eps` | float | `1e-8` | Small jitter value added to variances to prevent divide-by-zero errors during standardization. |
141
143
 
142
144
 
143
- ### Batch Effect Correction Visualization
145
+ ### Batch Effect Correction Visualization
144
146
 
145
147
 The `plot_transformation` method allows you to visualize the effect of the **ComBat** transformation using dimensionality reduction, showing a before/after comparison of data transformed by `ComBat` using PCA, t-SNE, or UMAP.
146
148
 
147
- For further details see the [notebook demo](https://github.com/EttoreRocchi/combatlearn/blob/main/docs/demo/combatlearn_demo.ipynb).
149
+ For further details see the [Visualization Guide](https://combatlearn.readthedocs.io/en/latest/user-guide/visualization/) and the [notebook demo](https://github.com/EttoreRocchi/combatlearn/blob/main/docs/demo/combatlearn_demo.ipynb).
150
+
151
+ ### Batch Effect Metrics
152
+
153
+ The `compute_batch_metrics` method provides a quantitative assessment of batch correction quality. It computes metrics including the Silhouette coefficient, Davies-Bouldin index, kBET, LISI, and variance ratio for batch effect quantification, as well as k-NN preservation and distance correlation for structure preservation.
154
+
155
+ For further details see the [Metrics Guide](https://combatlearn.readthedocs.io/en/latest/user-guide/metrics/) and the [notebook demo](https://github.com/EttoreRocchi/combatlearn/blob/main/docs/demo/combatlearn_demo.ipynb).
148
156
 
149
157
  ## Contributing
150
158
 
@@ -167,8 +175,7 @@ We gratefully acknowledge:
167
175
 
168
176
  ## Citation
169
177
 
170
- If **combatlearn** is useful in your research, please cite the original
171
- papers:
178
+ If **combatlearn** is useful in your research, please cite the original papers:
172
179
 
173
180
  - Johnson WE, Li C, Rabinovic A. Adjusting batch effects in microarray expression data using empirical Bayes methods. _Biostatistics_. 2007 Jan;8(1):118-27. doi: [10.1093/biostatistics/kxj037](https://doi.org/10.1093/biostatistics/kxj037)
174
181
 
@@ -0,0 +1,7 @@
1
+ combatlearn/__init__.py,sha256=L4sPJuJzLJIODAuSXdNQECVoFJXcmVss7SzoqX6MlYg,99
2
+ combatlearn/combat.py,sha256=yLoppVuLBvqO-0a01xWQve-8ZEMkEICabcGPf5u1goI,59309
3
+ combatlearn-1.1.0.dist-info/licenses/LICENSE,sha256=O34CBRTmdL59PxDYOa6nq1N0-2A9xyXGkBXKbsL1NeY,1070
4
+ combatlearn-1.1.0.dist-info/METADATA,sha256=s4Y3G0ou1TnNZJpu1eRC7VN2lDVan_BqEm_qpVQf2lk,9446
5
+ combatlearn-1.1.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
6
+ combatlearn-1.1.0.dist-info/top_level.txt,sha256=3cFQv4oj2sh_NKra45cPy8Go0v8W9x9-zkkUibqZCMk,12
7
+ combatlearn-1.1.0.dist-info/RECORD,,
@@ -1,7 +0,0 @@
1
- combatlearn/__init__.py,sha256=ck_EGW8iqLGUebg2wc-h794lwG3uAkHn9GaWjHgUIX4,99
2
- combatlearn/combat.py,sha256=Hri1XwnfSXWLoC1KD2VkqtNLkZpixI5ax0UrT1HtjyU,38505
3
- combatlearn-1.0.0.dist-info/licenses/LICENSE,sha256=O34CBRTmdL59PxDYOa6nq1N0-2A9xyXGkBXKbsL1NeY,1070
4
- combatlearn-1.0.0.dist-info/METADATA,sha256=hJvZEiA_ekTq06wzfOf2p6M_4vwNXGOdoS-K5MvT4P0,8558
5
- combatlearn-1.0.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
6
- combatlearn-1.0.0.dist-info/top_level.txt,sha256=3cFQv4oj2sh_NKra45cPy8Go0v8W9x9-zkkUibqZCMk,12
7
- combatlearn-1.0.0.dist-info/RECORD,,