combatlearn 0.2.2__py3-none-any.whl → 1.1.0__py3-none-any.whl

combatlearn/__init__.py CHANGED
@@ -1,4 +1,5 @@
1
- from .combat import ComBatModel, ComBat
1
+ from .combat import ComBat
2
2
 
3
- __all__ = ["ComBatModel", "ComBat"]
4
- __version__ = "0.2.2"
3
+ __all__ = ["ComBat"]
4
+ __version__ = "1.1.0"
5
+ __author__ = "Ettore Rocchi"
combatlearn/combat.py CHANGED
@@ -14,32 +14,544 @@ import numpy as np
14
14
  import numpy.linalg as la
15
15
  import pandas as pd
16
16
  from sklearn.base import BaseEstimator, TransformerMixin
17
- from sklearn.utils.validation import check_is_fitted
18
17
  from sklearn.decomposition import PCA
19
18
  from sklearn.manifold import TSNE
19
+ from sklearn.neighbors import NearestNeighbors
20
+ from sklearn.metrics import silhouette_score, davies_bouldin_score
21
+ from scipy.stats import levene, spearmanr, chi2
22
+ from scipy.spatial.distance import pdist
23
+ import matplotlib
20
24
  import matplotlib.pyplot as plt
21
25
  import matplotlib.colors as mcolors
22
- from typing import Literal, Optional, Union, Dict, Tuple, Any, cast
26
+ from typing import Literal, Optional, Union, Dict, Tuple, Any, List
23
27
  import numpy.typing as npt
24
28
  import warnings
29
+ import umap
30
+ import plotly.graph_objects as go
31
+ from plotly.subplots import make_subplots
25
32
 
26
- try:
27
- import umap
28
- UMAP_AVAILABLE = True
29
- except ImportError:
30
- UMAP_AVAILABLE = False
33
+ ArrayLike = Union[pd.DataFrame, pd.Series, npt.NDArray[Any]]
34
+ FloatArray = npt.NDArray[np.float64]
31
35
 
32
- try:
33
- import plotly.graph_objects as go
34
- from plotly.subplots import make_subplots
35
- PLOTLY_AVAILABLE = True
36
- except ImportError:
37
- PLOTLY_AVAILABLE = False
36
+ def _compute_pca_embedding(
37
+ X_before: np.ndarray,
38
+ X_after: np.ndarray,
39
+ n_components: int,
40
+ ) -> Tuple[np.ndarray, np.ndarray, PCA]:
41
+ """
42
+ Compute PCA embeddings for both datasets.
38
43
 
39
- __author__ = "Ettore Rocchi"
44
+ Fits PCA on X_before and applies it to both datasets.
40
45
 
41
- ArrayLike = Union[pd.DataFrame, pd.Series, npt.NDArray[Any]]
42
- FloatArray = npt.NDArray[np.float64]
46
+ Parameters
47
+ ----------
48
+ X_before : np.ndarray
49
+ Original data before correction.
50
+ X_after : np.ndarray
51
+ Corrected data.
52
+ n_components : int
53
+ Number of PCA components.
54
+
55
+ Returns
56
+ -------
57
+ X_before_pca : np.ndarray
58
+ PCA-transformed original data.
59
+ X_after_pca : np.ndarray
60
+ PCA-transformed corrected data.
61
+ pca : PCA
62
+ Fitted PCA model.
63
+ """
64
+ n_components = min(n_components, X_before.shape[1], X_before.shape[0] - 1)
65
+ pca = PCA(n_components=n_components, random_state=42)
66
+ X_before_pca = pca.fit_transform(X_before)
67
+ X_after_pca = pca.transform(X_after)
68
+ return X_before_pca, X_after_pca, pca
69
+
70
+
71
+ def _silhouette_batch(X: np.ndarray, batch_labels: np.ndarray) -> float:
72
+ """
73
+ Compute silhouette coefficient using batch as cluster labels.
74
+
75
+ Lower values after correction indicate better batch mixing.
76
+ Range: [-1, 1], where -1 = batch mixing, 1 = batch separation.
77
+
78
+ Parameters
79
+ ----------
80
+ X : np.ndarray
81
+ Data matrix.
82
+ batch_labels : np.ndarray
83
+ Batch labels for each sample.
84
+
85
+ Returns
86
+ -------
87
+ float
88
+ Silhouette coefficient.
89
+ """
90
+ unique_batches = np.unique(batch_labels)
91
+ if len(unique_batches) < 2:
92
+ return 0.0
93
+ try:
94
+ return silhouette_score(X, batch_labels, metric='euclidean')
95
+ except Exception:
96
+ return 0.0
97
+
98
+
99
+ def _davies_bouldin_batch(X: np.ndarray, batch_labels: np.ndarray) -> float:
100
+ """
101
+ Compute Davies-Bouldin index using batch labels.
102
+
103
+ Lower values indicate better batch mixing.
104
+ Range: [0, inf), 0 = perfect batch overlap.
105
+
106
+ Parameters
107
+ ----------
108
+ X : np.ndarray
109
+ Data matrix.
110
+ batch_labels : np.ndarray
111
+ Batch labels for each sample.
112
+
113
+ Returns
114
+ -------
115
+ float
116
+ Davies-Bouldin index.
117
+ """
118
+ unique_batches = np.unique(batch_labels)
119
+ if len(unique_batches) < 2:
120
+ return 0.0
121
+ try:
122
+ return davies_bouldin_score(X, batch_labels)
123
+ except Exception:
124
+ return 0.0
125
+
126
+
127
+ def _kbet_score(
128
+ X: np.ndarray,
129
+ batch_labels: np.ndarray,
130
+ k0: int,
131
+ alpha: float = 0.05,
132
+ ) -> Tuple[float, float]:
133
+ """
134
+ Compute kBET (k-nearest neighbor Batch Effect Test) acceptance rate.
135
+
136
+ Tests if local batch proportions match global batch proportions.
137
+ Higher acceptance rate = better batch mixing.
138
+
139
+ Reference: Buttner et al. (2019) Nature Methods
140
+
141
+ Parameters
142
+ ----------
143
+ X : np.ndarray
144
+ Data matrix.
145
+ batch_labels : np.ndarray
146
+ Batch labels for each sample.
147
+ k0 : int
148
+ Neighborhood size.
149
+ alpha : float
150
+ Significance level for chi-squared test.
151
+
152
+ Returns
153
+ -------
154
+ acceptance_rate : float
155
+ Fraction of samples where H0 (uniform mixing) is accepted.
156
+ mean_stat : float
157
+ Mean chi-squared statistic across samples.
158
+ """
159
+ n_samples = X.shape[0]
160
+ unique_batches, batch_counts = np.unique(batch_labels, return_counts=True)
161
+ n_batches = len(unique_batches)
162
+
163
+ if n_batches < 2:
164
+ return 1.0, 0.0
165
+
166
+ global_freq = batch_counts / n_samples
167
+ k0 = min(k0, n_samples - 1)
168
+
169
+ nn = NearestNeighbors(n_neighbors=k0 + 1, algorithm='auto')
170
+ nn.fit(X)
171
+ _, indices = nn.kneighbors(X)
172
+
173
+ chi2_stats = []
174
+ p_values = []
175
+ batch_to_idx = {b: i for i, b in enumerate(unique_batches)}
176
+
177
+ for i in range(n_samples):
178
+ neighbors = indices[i, 1:k0+1]
179
+ neighbor_batches = batch_labels[neighbors]
180
+
181
+ observed = np.zeros(n_batches)
182
+ for nb in neighbor_batches:
183
+ observed[batch_to_idx[nb]] += 1
184
+
185
+ expected = global_freq * k0
186
+
187
+ mask = expected > 0
188
+ if mask.sum() < 2:
189
+ continue
190
+
191
+ stat = np.sum((observed[mask] - expected[mask])**2 / expected[mask])
192
+ df = max(1, mask.sum() - 1)
193
+ p_val = 1 - chi2.cdf(stat, df)
194
+
195
+ chi2_stats.append(stat)
196
+ p_values.append(p_val)
197
+
198
+ if len(p_values) == 0:
199
+ return 1.0, 0.0
200
+
201
+ acceptance_rate = np.mean(np.array(p_values) > alpha)
202
+ mean_stat = np.mean(chi2_stats)
203
+
204
+ return acceptance_rate, mean_stat
205
+
206
+
207
+ def _find_sigma(distances: np.ndarray, target_perplexity: float, tol: float = 1e-5) -> float:
208
+ """
209
+ Binary search for sigma to achieve target perplexity.
210
+
211
+ Used in LISI computation.
212
+
213
+ Parameters
214
+ ----------
215
+ distances : np.ndarray
216
+ Distances to neighbors.
217
+ target_perplexity : float
218
+ Target perplexity value.
219
+ tol : float
220
+ Tolerance for convergence.
221
+
222
+ Returns
223
+ -------
224
+ float
225
+ Sigma value.
226
+ """
227
+ target_H = np.log2(target_perplexity + 1e-10)
228
+
229
+ sigma_min, sigma_max = 1e-10, 1e10
230
+ sigma = 1.0
231
+
232
+ for _ in range(50):
233
+ P = np.exp(-distances**2 / (2 * sigma**2 + 1e-10))
234
+ P_sum = P.sum()
235
+ if P_sum < 1e-10:
236
+ sigma = (sigma + sigma_max) / 2
237
+ continue
238
+ P = P / P_sum
239
+ P = np.clip(P, 1e-10, 1.0)
240
+ H = -np.sum(P * np.log2(P))
241
+
242
+ if abs(H - target_H) < tol:
243
+ break
244
+ elif H < target_H:
245
+ sigma_min = sigma
246
+ else:
247
+ sigma_max = sigma
248
+ sigma = (sigma_min + sigma_max) / 2
249
+
250
+ return sigma
251
+
252
+
253
+ def _lisi_score(
254
+ X: np.ndarray,
255
+ batch_labels: np.ndarray,
256
+ perplexity: int = 30,
257
+ ) -> float:
258
+ """
259
+ Compute mean Local Inverse Simpson's Index (LISI).
260
+
261
+ Range: [1, n_batches], where n_batches = perfect mixing.
262
+ Higher = better batch mixing.
263
+
264
+ Reference: Korsunsky et al. (2019) Nature Methods (Harmony paper)
265
+
266
+ Parameters
267
+ ----------
268
+ X : np.ndarray
269
+ Data matrix.
270
+ batch_labels : np.ndarray
271
+ Batch labels for each sample.
272
+ perplexity : int
273
+ Perplexity for Gaussian kernel.
274
+
275
+ Returns
276
+ -------
277
+ float
278
+ Mean LISI score.
279
+ """
280
+ n_samples = X.shape[0]
281
+ unique_batches = np.unique(batch_labels)
282
+ n_batches = len(unique_batches)
283
+ batch_to_idx = {b: i for i, b in enumerate(unique_batches)}
284
+
285
+ if n_batches < 2:
286
+ return 1.0
287
+
288
+ k = min(3 * perplexity, n_samples - 1)
289
+
290
+ nn = NearestNeighbors(n_neighbors=k + 1, algorithm='auto')
291
+ nn.fit(X)
292
+ distances, indices = nn.kneighbors(X)
293
+
294
+ distances = distances[:, 1:]
295
+ indices = indices[:, 1:]
296
+
297
+ lisi_values = []
298
+
299
+ for i in range(n_samples):
300
+ sigma = _find_sigma(distances[i], perplexity)
301
+
302
+ P = np.exp(-distances[i]**2 / (2 * sigma**2 + 1e-10))
303
+ P_sum = P.sum()
304
+ if P_sum < 1e-10:
305
+ lisi_values.append(1.0)
306
+ continue
307
+ P = P / P_sum
308
+
309
+ neighbor_batches = batch_labels[indices[i]]
310
+ batch_probs = np.zeros(n_batches)
311
+ for j, nb in enumerate(neighbor_batches):
312
+ batch_probs[batch_to_idx[nb]] += P[j]
313
+
314
+ simpson = np.sum(batch_probs**2)
315
+ if simpson < 1e-10:
316
+ lisi = n_batches
317
+ else:
318
+ lisi = 1.0 / simpson
319
+ lisi_values.append(lisi)
320
+
321
+ return np.mean(lisi_values)
322
+
323
+
324
+ def _variance_ratio(X: np.ndarray, batch_labels: np.ndarray) -> float:
325
+ """
326
+ Compute between-batch to within-batch variance ratio.
327
+
328
+ Similar to F-statistic in one-way ANOVA.
329
+ Lower ratio after correction = better batch effect removal.
330
+
331
+ Parameters
332
+ ----------
333
+ X : np.ndarray
334
+ Data matrix.
335
+ batch_labels : np.ndarray
336
+ Batch labels for each sample.
337
+
338
+ Returns
339
+ -------
340
+ float
341
+ Variance ratio (between/within).
342
+ """
343
+ unique_batches = np.unique(batch_labels)
344
+ n_batches = len(unique_batches)
345
+ n_samples = X.shape[0]
346
+
347
+ if n_batches < 2:
348
+ return 0.0
349
+
350
+ grand_mean = np.mean(X, axis=0)
351
+
352
+ between_var = 0.0
353
+ within_var = 0.0
354
+
355
+ for batch in unique_batches:
356
+ mask = batch_labels == batch
357
+ n_b = np.sum(mask)
358
+ X_batch = X[mask]
359
+ batch_mean = np.mean(X_batch, axis=0)
360
+
361
+ between_var += n_b * np.sum((batch_mean - grand_mean)**2)
362
+ within_var += np.sum((X_batch - batch_mean)**2)
363
+
364
+ between_var /= (n_batches - 1)
365
+ within_var /= (n_samples - n_batches)
366
+
367
+ if within_var < 1e-10:
368
+ return 0.0
369
+
370
+ return between_var / within_var
371
+
372
+
373
+ def _knn_preservation(
374
+ X_before: np.ndarray,
375
+ X_after: np.ndarray,
376
+ k_values: List[int],
377
+ n_jobs: int = 1,
378
+ ) -> Dict[int, float]:
379
+ """
380
+ Compute fraction of k-nearest neighbors preserved after correction.
381
+
382
+ Range: [0, 1], where 1 = perfect preservation.
383
+ Higher = better biological structure preservation.
384
+
385
+ Parameters
386
+ ----------
387
+ X_before : np.ndarray
388
+ Original data.
389
+ X_after : np.ndarray
390
+ Corrected data.
391
+ k_values : list of int
392
+ Values of k for k-NN.
393
+ n_jobs : int
394
+ Number of parallel jobs.
395
+
396
+ Returns
397
+ -------
398
+ dict
399
+ Mapping from k to preservation fraction.
400
+ """
401
+ results = {}
402
+ max_k = max(k_values)
403
+ max_k = min(max_k, X_before.shape[0] - 1)
404
+
405
+ nn_before = NearestNeighbors(n_neighbors=max_k + 1, algorithm='auto', n_jobs=n_jobs)
406
+ nn_before.fit(X_before)
407
+ _, indices_before = nn_before.kneighbors(X_before)
408
+
409
+ nn_after = NearestNeighbors(n_neighbors=max_k + 1, algorithm='auto', n_jobs=n_jobs)
410
+ nn_after.fit(X_after)
411
+ _, indices_after = nn_after.kneighbors(X_after)
412
+
413
+ for k in k_values:
414
+ if k > max_k:
415
+ results[k] = 0.0
416
+ continue
417
+
418
+ overlaps = []
419
+ for i in range(X_before.shape[0]):
420
+ neighbors_before = set(indices_before[i, 1:k+1])
421
+ neighbors_after = set(indices_after[i, 1:k+1])
422
+ overlap = len(neighbors_before & neighbors_after) / k
423
+ overlaps.append(overlap)
424
+
425
+ results[k] = np.mean(overlaps)
426
+
427
+ return results
428
+
429
+
430
+ def _pairwise_distance_correlation(
431
+ X_before: np.ndarray,
432
+ X_after: np.ndarray,
433
+ subsample: int = 1000,
434
+ random_state: int = 42,
435
+ ) -> float:
436
+ """
437
+ Compute Spearman correlation of pairwise distances.
438
+
439
+ Range: [-1, 1], where 1 = perfect rank preservation.
440
+ Higher = better relative relationship preservation.
441
+
442
+ Parameters
443
+ ----------
444
+ X_before : np.ndarray
445
+ Original data.
446
+ X_after : np.ndarray
447
+ Corrected data.
448
+ subsample : int
449
+ Maximum samples to use (for efficiency).
450
+ random_state : int
451
+ Random seed for subsampling.
452
+
453
+ Returns
454
+ -------
455
+ float
456
+ Spearman correlation coefficient.
457
+ """
458
+ n_samples = X_before.shape[0]
459
+
460
+ if n_samples > subsample:
461
+ rng = np.random.default_rng(random_state)
462
+ idx = rng.choice(n_samples, subsample, replace=False)
463
+ X_before = X_before[idx]
464
+ X_after = X_after[idx]
465
+
466
+ dist_before = pdist(X_before, metric='euclidean')
467
+ dist_after = pdist(X_after, metric='euclidean')
468
+
469
+ if len(dist_before) == 0:
470
+ return 1.0
471
+
472
+ corr, _ = spearmanr(dist_before, dist_after)
473
+
474
+ if np.isnan(corr):
475
+ return 1.0
476
+
477
+ return corr
478
+
479
+
480
+ def _mean_centroid_distance(X: np.ndarray, batch_labels: np.ndarray) -> float:
481
+ """
482
+ Compute mean pairwise Euclidean distance between batch centroids.
483
+
484
+ Lower after correction = better batch alignment.
485
+
486
+ Parameters
487
+ ----------
488
+ X : np.ndarray
489
+ Data matrix.
490
+ batch_labels : np.ndarray
491
+ Batch labels for each sample.
492
+
493
+ Returns
494
+ -------
495
+ float
496
+ Mean pairwise distance between centroids.
497
+ """
498
+ unique_batches = np.unique(batch_labels)
499
+ n_batches = len(unique_batches)
500
+
501
+ if n_batches < 2:
502
+ return 0.0
503
+
504
+ centroids = []
505
+ for batch in unique_batches:
506
+ mask = batch_labels == batch
507
+ centroid = np.mean(X[mask], axis=0)
508
+ centroids.append(centroid)
509
+
510
+ centroids = np.array(centroids)
511
+ distances = pdist(centroids, metric='euclidean')
512
+
513
+ return np.mean(distances)
514
+
515
+
516
+ def _levene_median_statistic(X: np.ndarray, batch_labels: np.ndarray) -> float:
517
+ """
518
+ Compute median Levene test statistic across features.
519
+
520
+ Lower statistic = more homogeneous variances across batches.
521
+
522
+ Parameters
523
+ ----------
524
+ X : np.ndarray
525
+ Data matrix.
526
+ batch_labels : np.ndarray
527
+ Batch labels for each sample.
528
+
529
+ Returns
530
+ -------
531
+ float
532
+ Median Levene test statistic.
533
+ """
534
+ unique_batches = np.unique(batch_labels)
535
+ if len(unique_batches) < 2:
536
+ return 0.0
537
+
538
+ levene_stats = []
539
+ for j in range(X.shape[1]):
540
+ groups = [X[batch_labels == b, j] for b in unique_batches]
541
+ groups = [g for g in groups if len(g) > 0]
542
+ if len(groups) < 2:
543
+ continue
544
+ try:
545
+ stat, _ = levene(*groups, center='median')
546
+ if not np.isnan(stat):
547
+ levene_stats.append(stat)
548
+ except Exception:
549
+ continue
550
+
551
+ if len(levene_stats) == 0:
552
+ return 0.0
553
+
554
+ return np.median(levene_stats)
43
555
 
44
556
 
45
557
  class ComBatModel:
@@ -58,8 +570,9 @@ class ComBatModel:
58
570
  ignoring the variance (`delta_star`).
59
571
  reference_batch : str, optional
60
572
  If specified, the batch level to use as reference.
61
- covbat_cov_thresh : float, default=0.9
62
- CovBat: cumulative explained variance threshold for PCA.
573
+ covbat_cov_thresh : float or int, default=0.9
574
+ CovBat: cumulative variance threshold (0, 1] to retain PCs, or
575
+ integer >= 1 specifying the number of components directly.
63
576
  eps : float, default=1e-8
64
577
  Numerical jitter to avoid division-by-zero.
65
578
  """
@@ -67,19 +580,19 @@ class ComBatModel:
67
580
  def __init__(
68
581
  self,
69
582
  *,
70
- method: Literal["johnson", "fortin", "chen"] = "johnson",
583
+ method: Literal["johnson", "fortin", "chen"] = "johnson",
71
584
  parametric: bool = True,
72
585
  mean_only: bool = False,
73
586
  reference_batch: Optional[str] = None,
74
587
  eps: float = 1e-8,
75
- covbat_cov_thresh: float = 0.9,
588
+ covbat_cov_thresh: Union[float, int] = 0.9,
76
589
  ) -> None:
77
590
  self.method: str = method
78
591
  self.parametric: bool = parametric
79
592
  self.mean_only: bool = bool(mean_only)
80
593
  self.reference_batch: Optional[str] = reference_batch
81
594
  self.eps: float = float(eps)
82
- self.covbat_cov_thresh: float = float(covbat_cov_thresh)
595
+ self.covbat_cov_thresh: Union[float, int] = covbat_cov_thresh
83
596
 
84
597
  self._batch_levels: pd.Index
85
598
  self._grand_mean: pd.Series
@@ -96,9 +609,16 @@ class ComBatModel:
96
609
  self._batch_levels_pc: pd.Index
97
610
  self._pc_gamma_star: FloatArray
98
611
  self._pc_delta_star: FloatArray
99
-
100
- if not (0.0 < self.covbat_cov_thresh <= 1.0):
101
- raise ValueError("covbat_cov_thresh must be in (0, 1].")
612
+
613
+ # Validate covbat_cov_thresh
614
+ if isinstance(self.covbat_cov_thresh, float):
615
+ if not (0.0 < self.covbat_cov_thresh <= 1.0):
616
+ raise ValueError("covbat_cov_thresh must be in (0, 1] when float.")
617
+ elif isinstance(self.covbat_cov_thresh, int):
618
+ if self.covbat_cov_thresh < 1:
619
+ raise ValueError("covbat_cov_thresh must be >= 1 when int.")
620
+ else:
621
+ raise TypeError("covbat_cov_thresh must be float or int.")
102
622
 
103
623
  @staticmethod
104
624
  def _as_series(
@@ -336,8 +856,14 @@ class ComBatModel:
336
856
  X_meanvar_adj = self._transform_fortin(X, batch, disc, cont)
337
857
  X_centered = X_meanvar_adj - X_meanvar_adj.mean(axis=0)
338
858
  pca = PCA(svd_solver="full", whiten=False).fit(X_centered)
339
- cumulative = np.cumsum(pca.explained_variance_ratio_)
340
- n_pc = int(np.searchsorted(cumulative, self.covbat_cov_thresh) + 1)
859
+
860
+ # Determine number of components based on threshold type
861
+ if isinstance(self.covbat_cov_thresh, int):
862
+ n_pc = min(self.covbat_cov_thresh, len(pca.explained_variance_ratio_))
863
+ else:
864
+ cumulative = np.cumsum(pca.explained_variance_ratio_)
865
+ n_pc = int(np.searchsorted(cumulative, self.covbat_cov_thresh) + 1)
866
+
341
867
  self._covbat_pca = pca
342
868
  self._covbat_n_pc = n_pc
343
869
 
@@ -488,7 +1014,8 @@ class ComBatModel:
488
1014
  continuous_covariates: Optional[ArrayLike] = None,
489
1015
  ) -> pd.DataFrame:
490
1016
  """Transform the data using fitted ComBat parameters."""
491
- check_is_fitted(self, ["_gamma_star"])
1017
+ if not hasattr(self, "_gamma_star"):
1018
+ raise ValueError("This ComBatModel instance is not fitted yet. Call 'fit' before 'transform'.")
492
1019
  if not isinstance(X, pd.DataFrame):
493
1020
  X = pd.DataFrame(X)
494
1021
  idx = X.index
@@ -600,7 +1127,7 @@ class ComBatModel:
600
1127
  """Chen transform implementation."""
601
1128
  X_meanvar_adj = self._transform_fortin(X, batch, disc, cont)
602
1129
  X_centered = X_meanvar_adj - self._covbat_pca.mean_
603
- scores = self._covbat_pca.transform(X_centered.values)
1130
+ scores = self._covbat_pca.transform(X_centered)
604
1131
  n_pc = self._covbat_n_pc
605
1132
  scores_adj = scores.copy()
606
1133
 
@@ -639,7 +1166,8 @@ class ComBat(BaseEstimator, TransformerMixin):
639
1166
  mean_only: bool = False,
640
1167
  reference_batch: Optional[str] = None,
641
1168
  eps: float = 1e-8,
642
- covbat_cov_thresh: float = 0.9,
1169
+ covbat_cov_thresh: Union[float, int] = 0.9,
1170
+ compute_metrics: bool = False,
643
1171
  ) -> None:
644
1172
  self.batch = batch
645
1173
  self.discrete_covariates = discrete_covariates
@@ -650,6 +1178,7 @@ class ComBat(BaseEstimator, TransformerMixin):
650
1178
  self.reference_batch = reference_batch
651
1179
  self.eps = eps
652
1180
  self.covbat_cov_thresh = covbat_cov_thresh
1181
+ self.compute_metrics = compute_metrics
653
1182
  self._model = ComBatModel(
654
1183
  method=method,
655
1184
  parametric=parametric,
@@ -707,6 +1236,221 @@ class ComBat(BaseEstimator, TransformerMixin):
707
1236
  else:
708
1237
  return pd.DataFrame(obj, index=idx)
709
1238
 
1239
+ @property
1240
+ def metrics_(self) -> Optional[Dict[str, Any]]:
1241
+ """Return cached metrics from last fit_transform with compute_metrics=True.
1242
+
1243
+ Returns
1244
+ -------
1245
+ dict or None
1246
+ Cached metrics dictionary, or None if no metrics have been computed.
1247
+ """
1248
+ return getattr(self, '_metrics_cache', None)
1249
+
1250
+ def compute_batch_metrics(
1251
+ self,
1252
+ X: ArrayLike,
1253
+ batch: Optional[ArrayLike] = None,
1254
+ *,
1255
+ pca_components: Optional[int] = None,
1256
+ k_neighbors: List[int] = [5, 10, 50],
1257
+ kbet_k0: Optional[int] = None,
1258
+ lisi_perplexity: int = 30,
1259
+ n_jobs: int = 1,
1260
+ ) -> Dict[str, Any]:
1261
+ """
1262
+ Compute batch effect metrics before and after ComBat correction.
1263
+
1264
+ Parameters
1265
+ ----------
1266
+ X : array-like of shape (n_samples, n_features)
1267
+ Input data to evaluate.
1268
+ batch : array-like of shape (n_samples,), optional
1269
+ Batch labels. If None, uses the batch stored at construction.
1270
+ pca_components : int, optional
1271
+ Number of PCA components for dimensionality reduction before
1272
+ computing metrics. If None (default), metrics are computed in
1273
+ the original feature space. Must be less than min(n_samples, n_features).
1274
+ k_neighbors : list of int, default=[5, 10, 50]
1275
+ Values of k for k-NN preservation metric.
1276
+ kbet_k0 : int, optional
1277
+ Neighborhood size for kBET. Default is 10% of the samples (at least 10).
1278
+ lisi_perplexity : int, default=30
1279
+ Perplexity for LISI computation.
1280
+ n_jobs : int, default=1
1281
+ Number of parallel jobs for neighbor computations.
1282
+
1283
+ Returns
1284
+ -------
1285
+ metrics : dict
1286
+ Dictionary with structure:
1287
+ {
1288
+ 'batch_effect': {
1289
+ 'silhouette': {'before': float, 'after': float},
1290
+ 'davies_bouldin': {...},
1291
+ 'kbet': {...},
1292
+ 'lisi': {..., 'max_value': n_batches},
1293
+ 'variance_ratio': {...},
1294
+ },
1295
+ 'preservation': {
1296
+ 'knn': {k: fraction for k in k_neighbors},
1297
+ 'distance_correlation': float,
1298
+ },
1299
+ 'alignment': {
1300
+ 'centroid_distance': {...},
1301
+ 'levene_statistic': {...},
1302
+ },
1303
+ }
1304
+
1305
+ Raises
1306
+ ------
1307
+ ValueError
1308
+ If the model is not fitted or if pca_components is invalid.
1309
+ """
1310
+ if not hasattr(self._model, "_gamma_star"):
1311
+ raise ValueError(
1312
+ "This ComBat instance is not fitted yet. "
1313
+ "Call 'fit' before 'compute_batch_metrics'."
1314
+ )
1315
+
1316
+ if not isinstance(X, pd.DataFrame):
1317
+ X = pd.DataFrame(X)
1318
+
1319
+ idx = X.index
1320
+
1321
+ if batch is None:
1322
+ batch_vec = self._subset(self.batch, idx)
1323
+ else:
1324
+ if isinstance(batch, (pd.Series, pd.DataFrame)):
1325
+ batch_vec = batch.loc[idx] if hasattr(batch, 'loc') else batch
1326
+ elif isinstance(batch, np.ndarray):
1327
+ batch_vec = pd.Series(batch, index=idx)
1328
+ else:
1329
+ batch_vec = pd.Series(batch, index=idx)
1330
+
1331
+ batch_labels = np.array(batch_vec)
1332
+
1333
+ X_before = X.values
1334
+ X_after = self.transform(X).values
1335
+
1336
+ n_samples, n_features = X_before.shape
1337
+ if kbet_k0 is None:
1338
+ kbet_k0 = max(10, int(0.10 * n_samples))
1339
+
1340
+ # Validate and apply PCA if requested
1341
+ if pca_components is not None:
1342
+ max_components = min(n_samples, n_features)
1343
+ if pca_components >= max_components:
1344
+ raise ValueError(
1345
+ f"pca_components={pca_components} must be less than "
1346
+ f"min(n_samples, n_features)={max_components}."
1347
+ )
1348
+ X_before_pca, X_after_pca, _ = _compute_pca_embedding(
1349
+ X_before, X_after, pca_components
1350
+ )
1351
+ else:
1352
+ X_before_pca = X_before
1353
+ X_after_pca = X_after
1354
+
1355
+ silhouette_before = _silhouette_batch(X_before_pca, batch_labels)
1356
+ silhouette_after = _silhouette_batch(X_after_pca, batch_labels)
1357
+
1358
+ db_before = _davies_bouldin_batch(X_before_pca, batch_labels)
1359
+ db_after = _davies_bouldin_batch(X_after_pca, batch_labels)
1360
+
1361
+ kbet_before, _ = _kbet_score(X_before_pca, batch_labels, kbet_k0)
1362
+ kbet_after, _ = _kbet_score(X_after_pca, batch_labels, kbet_k0)
1363
+
1364
+ lisi_before = _lisi_score(X_before_pca, batch_labels, lisi_perplexity)
1365
+ lisi_after = _lisi_score(X_after_pca, batch_labels, lisi_perplexity)
1366
+
1367
+ var_ratio_before = _variance_ratio(X_before_pca, batch_labels)
1368
+ var_ratio_after = _variance_ratio(X_after_pca, batch_labels)
1369
+
1370
+ knn_results = _knn_preservation(X_before_pca, X_after_pca, k_neighbors, n_jobs)
1371
+ dist_corr = _pairwise_distance_correlation(X_before_pca, X_after_pca)
1372
+
1373
+ centroid_before = _mean_centroid_distance(X_before_pca, batch_labels)
1374
+ centroid_after = _mean_centroid_distance(X_after_pca, batch_labels)
1375
+
1376
+ levene_before = _levene_median_statistic(X_before, batch_labels)
1377
+ levene_after = _levene_median_statistic(X_after, batch_labels)
1378
+
1379
+ n_batches = len(np.unique(batch_labels))
1380
+
1381
+ metrics = {
1382
+ 'batch_effect': {
1383
+ 'silhouette': {
1384
+ 'before': silhouette_before,
1385
+ 'after': silhouette_after,
1386
+ },
1387
+ 'davies_bouldin': {
1388
+ 'before': db_before,
1389
+ 'after': db_after,
1390
+ },
1391
+ 'kbet': {
1392
+ 'before': kbet_before,
1393
+ 'after': kbet_after,
1394
+ },
1395
+ 'lisi': {
1396
+ 'before': lisi_before,
1397
+ 'after': lisi_after,
1398
+ 'max_value': n_batches,
1399
+ },
1400
+ 'variance_ratio': {
1401
+ 'before': var_ratio_before,
1402
+ 'after': var_ratio_after,
1403
+ },
1404
+ },
1405
+ 'preservation': {
1406
+ 'knn': knn_results,
1407
+ 'distance_correlation': dist_corr,
1408
+ },
1409
+ 'alignment': {
1410
+ 'centroid_distance': {
1411
+ 'before': centroid_before,
1412
+ 'after': centroid_after,
1413
+ },
1414
+ 'levene_statistic': {
1415
+ 'before': levene_before,
1416
+ 'after': levene_after,
1417
+ },
1418
+ },
1419
+ }
1420
+
1421
+ return metrics
1422
+
1423
+ def fit_transform(
1424
+ self,
1425
+ X: ArrayLike,
1426
+ y: Optional[ArrayLike] = None
1427
+ ) -> pd.DataFrame:
1428
+ """
1429
+ Fit and transform the data, optionally computing metrics.
1430
+
1431
+ If compute_metrics=True was set at construction, batch effect
1432
+ metrics are computed and cached in the metrics_ property.
1433
+
1434
+ Parameters
1435
+ ----------
1436
+ X : array-like of shape (n_samples, n_features)
1437
+ Input data to fit and transform.
1438
+ y : None
1439
+ Ignored. Present for API compatibility.
1440
+
1441
+ Returns
1442
+ -------
1443
+ X_transformed : pd.DataFrame
1444
+ Batch-corrected data.
1445
+ """
1446
+ self.fit(X, y)
1447
+ X_transformed = self.transform(X)
1448
+
1449
+ if self.compute_metrics:
1450
+ self._metrics_cache = self.compute_batch_metrics(X)
1451
+
1452
+ return X_transformed
1453
+
710
1454
  def plot_transformation(
711
1455
  self,
712
1456
  X: ArrayLike, *,
@@ -759,7 +1503,8 @@ class ComBat(BaseEstimator, TransformerMixin):
759
1503
  - `'original'`: embedding of original data
760
1504
  - `'transformed'`: embedding of ComBat-transformed data
761
1505
  """
762
- check_is_fitted(self._model, ["_gamma_star"])
1506
+ if not hasattr(self._model, "_gamma_star"):
1507
+ raise ValueError("This ComBat instance is not fitted yet. Call 'fit' before 'plot_transformation'.")
763
1508
 
764
1509
  if n_components not in [2, 3]:
765
1510
  raise ValueError(f"n_components must be 2 or 3, got {n_components}")
@@ -768,11 +1513,6 @@ class ComBat(BaseEstimator, TransformerMixin):
768
1513
  if plot_type not in ['static', 'interactive']:
769
1514
  raise ValueError(f"plot_type must be 'static' or 'interactive', got '{plot_type}'")
770
1515
 
771
- if reduction_method == 'umap' and not UMAP_AVAILABLE:
772
- raise ImportError("UMAP is not installed. Install with: pip install umap-learn")
773
- if plot_type == 'interactive' and not PLOTLY_AVAILABLE:
774
- raise ImportError("Plotly is not installed. Install with: pip install plotly")
775
-
776
1516
  if not isinstance(X, pd.DataFrame):
777
1517
  X = pd.DataFrame(X)
778
1518
 
@@ -797,8 +1537,8 @@ class ComBat(BaseEstimator, TransformerMixin):
797
1537
  else:
798
1538
  umap_params = {'random_state': 42}
799
1539
  umap_params.update(reduction_kwargs)
800
- reducer_orig = umap.UMAP(n_components=n_components, **reduction_kwargs)
801
- reducer_trans = umap.UMAP(n_components=n_components, **reduction_kwargs)
1540
+ reducer_orig = umap.UMAP(n_components=n_components, **umap_params)
1541
+ reducer_trans = umap.UMAP(n_components=n_components, **umap_params)
802
1542
 
803
1543
  X_embedded_orig = reducer_orig.fit_transform(X_np)
804
1544
  X_embedded_trans = reducer_trans.fit_transform(X_trans_np)
@@ -845,9 +1585,9 @@ class ComBat(BaseEstimator, TransformerMixin):
845
1585
  n_batches = len(unique_batches)
846
1586
 
847
1587
  if n_batches <= 10:
848
- colors = plt.cm.get_cmap(cmap)(np.linspace(0, 1, n_batches))
1588
+ colors = matplotlib.colormaps.get_cmap(cmap)(np.linspace(0, 1, n_batches))
849
1589
  else:
850
- colors = plt.cm.get_cmap('tab20')(np.linspace(0, 1, n_batches))
1590
+ colors = matplotlib.colormaps.get_cmap('tab20')(np.linspace(0, 1, n_batches))
851
1591
 
852
1592
  if n_components == 2:
853
1593
  ax1 = plt.subplot(1, 2, 1)
@@ -956,7 +1696,7 @@ class ComBat(BaseEstimator, TransformerMixin):
956
1696
  unique_batches = batch_labels.drop_duplicates()
957
1697
 
958
1698
  n_batches = len(unique_batches)
959
- cmap_func = plt.cm.get_cmap(cmap)
1699
+ cmap_func = matplotlib.colormaps.get_cmap(cmap)
960
1700
  color_list = [mcolors.to_hex(cmap_func(i / max(n_batches - 1, 1))) for i in range(n_batches)]
961
1701
 
962
1702
  batch_to_color = dict(zip(unique_batches, color_list))
@@ -1,9 +1,9 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: combatlearn
3
- Version: 0.2.2
3
+ Version: 1.1.0
4
4
  Summary: Batch-effect harmonization for machine learning frameworks.
5
5
  Author-email: Ettore Rocchi <ettoreroc@gmail.com>
6
- License-Expression: MIT
6
+ License: MIT
7
7
  Keywords: machine-learning,harmonization,combat,preprocessing
8
8
  Classifier: Development Status :: 3 - Alpha
9
9
  Classifier: Intended Audience :: Science/Research
@@ -19,13 +19,23 @@ Requires-Dist: matplotlib>=3.4
19
19
  Requires-Dist: plotly>=5.0
20
20
  Requires-Dist: nbformat>=4.2
21
21
  Requires-Dist: umap-learn>=0.5
22
- Requires-Dist: pytest>=7
22
+ Provides-Extra: dev
23
+ Requires-Dist: pytest>=7; extra == "dev"
24
+ Requires-Dist: pytest-cov>=4.0; extra == "dev"
25
+ Requires-Dist: ruff>=0.1; extra == "dev"
26
+ Requires-Dist: mypy>=1.0; extra == "dev"
27
+ Provides-Extra: docs
28
+ Requires-Dist: mkdocs>=1.5.0; extra == "docs"
29
+ Requires-Dist: mkdocs-material>=9.0.0; extra == "docs"
30
+ Requires-Dist: mkdocstrings[python]>=0.24.0; extra == "docs"
31
+ Requires-Dist: pymdown-extensions>=10.0; extra == "docs"
23
32
  Dynamic: license-file
24
33
 
25
34
  # **combatlearn**
26
35
 
27
36
  [![Python versions](https://img.shields.io/badge/python-%3E%3D3.10-blue?logo=python)](https://www.python.org/)
28
37
  [![Test](https://github.com/EttoreRocchi/combatlearn/actions/workflows/test.yaml/badge.svg)](https://github.com/EttoreRocchi/combatlearn/actions/workflows/test.yaml)
38
+ [![Documentation](https://readthedocs.org/projects/combatlearn/badge/?version=latest)](https://combatlearn.readthedocs.io)
29
39
  [![PyPI Downloads](https://static.pepy.tech/badge/combatlearn)](https://pepy.tech/projects/combatlearn)
30
40
  [![PyPI Version](https://img.shields.io/pypi/v/combatlearn?cacheSeconds=300)](https://pypi.org/project/combatlearn/)
31
41
  [![License](https://img.shields.io/github/license/EttoreRocchi/combatlearn)](https://github.com/EttoreRocchi/combatlearn/blob/main/LICENSE)
@@ -47,8 +57,21 @@ Dynamic: license-file
47
57
  pip install combatlearn
48
58
  ```
49
59
 
60
+ ## Documentation
61
+
62
+ **Full documentation is available at [combatlearn.readthedocs.io](https://combatlearn.readthedocs.io)**
63
+
64
+ The documentation includes:
65
+ - [Installation Guide](https://combatlearn.readthedocs.io/en/latest/installation/)
66
+ - [Quick Start Tutorial](https://combatlearn.readthedocs.io/en/latest/quickstart/)
67
+ - [User Guide](https://combatlearn.readthedocs.io/en/latest/user-guide/overview/)
68
+ - [API Reference](https://combatlearn.readthedocs.io/en/latest/api/)
69
+ - [Examples](https://combatlearn.readthedocs.io/en/latest/examples/basic-usage/)
70
+
50
71
  ## Quick start
51
72
 
73
+ For more details, see the [Quick Start Tutorial](https://combatlearn.readthedocs.io/en/latest/quickstart/).
74
+
52
75
  ```python
53
76
  import pandas as pd
54
77
  from sklearn.pipeline import Pipeline
@@ -97,7 +120,7 @@ For a full example of how to use **combatlearn** see the [notebook demo](https:/
97
120
 
98
121
  ## `ComBat` parameters
99
122
 
100
- The following section provides a detailed explanation of all parameters available in the scikit-learn-compatible `ComBat` class.
123
+ The following section provides a detailed explanation of all parameters available in the scikit-learn-compatible `ComBat` class. For complete API documentation, see the [API Reference](https://combatlearn.readthedocs.io/en/latest/api/).
101
124
 
102
125
  ### Main Parameters
103
126
 
@@ -119,11 +142,17 @@ The following section provides a detailed explanation of all parameters availabl
119
142
  | `eps` | float | `1e-8` | Small jitter value added to variances to prevent divide-by-zero errors during standardization. |
120
143
 
121
144
 
122
- ### Batch Effect Correction Visualization
145
+ ### Batch Effect Correction Visualization
123
146
 
124
147
  The `plot_transformation` method visualizes the effect of the **ComBat** correction, showing a before/after comparison of the data with PCA, t-SNE, or UMAP used to reduce the dimensionality for plotting.
125
148
 
126
- For further details see the [notebook demo](https://github.com/EttoreRocchi/combatlearn/blob/main/docs/demo/combatlearn_demo.ipynb).
149
+ For further details see the [Visualization Guide](https://combatlearn.readthedocs.io/en/latest/user-guide/visualization/) and the [notebook demo](https://github.com/EttoreRocchi/combatlearn/blob/main/docs/demo/combatlearn_demo.ipynb).
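
As an illustrative sketch (the toy data here are invented, and the `'pca'` value for `reduction_method` is assumed from the README wording; `plot_type='static'` and `n_components=2` match the checks visible in the diff above):

```python
import numpy as np
import pandas as pd
from combatlearn import ComBat

# Toy data: 60 samples, 20 features, two hypothetical acquisition sites.
rng = np.random.default_rng(0)
X = pd.DataFrame(rng.normal(size=(60, 20)))
batch = pd.Series(["site_A"] * 30 + ["site_B"] * 30, index=X.index)

combat = ComBat(batch=batch).fit(X)

# Side-by-side 2-D embeddings of the original and ComBat-corrected data.
combat.plot_transformation(
    X,
    reduction_method="pca",
    plot_type="static",
    n_components=2,
)
```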
150
+
151
+ ### Batch Effect Metrics
152
+
153
+ The `compute_batch_metrics` method provides a quantitative assessment of batch-correction quality. It computes batch-effect metrics (silhouette coefficient, Davies-Bouldin index, kBET, LISI, variance ratio), structure-preservation metrics (k-NN preservation, pairwise distance correlation), and alignment metrics (batch centroid distance, Levene statistic).
154
+
155
+ For further details see the [Metrics Guide](https://combatlearn.readthedocs.io/en/latest/user-guide/metrics/) and the [notebook demo](https://github.com/EttoreRocchi/combatlearn/blob/main/docs/demo/combatlearn_demo.ipynb).
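
A minimal usage sketch (toy data invented for illustration; the nested keys follow the dictionary structure documented for `compute_batch_metrics` in the diff above):

```python
import numpy as np
import pandas as pd
from combatlearn import ComBat

# Toy data: 60 samples, 20 features, two hypothetical batches.
rng = np.random.default_rng(1)
X = pd.DataFrame(rng.normal(size=(60, 20)))
batch = pd.Series(["scanner_1"] * 30 + ["scanner_2"] * 30, index=X.index)

# With compute_metrics=True, fit_transform also caches the metrics in metrics_.
combat = ComBat(batch=batch, compute_metrics=True)
X_corrected = combat.fit_transform(X)

print(combat.metrics_["batch_effect"]["silhouette"])  # {'before': ..., 'after': ...}
print(combat.metrics_["preservation"]["knn"])         # {5: ..., 10: ..., 50: ...}

# Metrics can also be computed on demand, optionally in a PCA subspace.
metrics = combat.compute_batch_metrics(X, pca_components=10, k_neighbors=[5, 10])
```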
127
156
 
128
157
  ## Contributing
129
158
 
@@ -146,8 +175,7 @@ We gratefully acknowledge:
146
175
 
147
176
  ## Citation
148
177
 
149
- If **combatlearn** is useful in your research, please cite the original
150
- papers:
178
+ If **combatlearn** is useful in your research, please cite the original papers:
151
179
 
152
180
  - Johnson WE, Li C, Rabinovic A. Adjusting batch effects in microarray expression data using empirical Bayes methods. _Biostatistics_. 2007 Jan;8(1):118-27. doi: [10.1093/biostatistics/kxj037](https://doi.org/10.1093/biostatistics/kxj037)
153
181
 
@@ -0,0 +1,7 @@
1
+ combatlearn/__init__.py,sha256=L4sPJuJzLJIODAuSXdNQECVoFJXcmVss7SzoqX6MlYg,99
2
+ combatlearn/combat.py,sha256=yLoppVuLBvqO-0a01xWQve-8ZEMkEICabcGPf5u1goI,59309
3
+ combatlearn-1.1.0.dist-info/licenses/LICENSE,sha256=O34CBRTmdL59PxDYOa6nq1N0-2A9xyXGkBXKbsL1NeY,1070
4
+ combatlearn-1.1.0.dist-info/METADATA,sha256=s4Y3G0ou1TnNZJpu1eRC7VN2lDVan_BqEm_qpVQf2lk,9446
5
+ combatlearn-1.1.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
6
+ combatlearn-1.1.0.dist-info/top_level.txt,sha256=3cFQv4oj2sh_NKra45cPy8Go0v8W9x9-zkkUibqZCMk,12
7
+ combatlearn-1.1.0.dist-info/RECORD,,
@@ -1,7 +0,0 @@
1
- combatlearn/__init__.py,sha256=qZK8xAUibzM9TQJ-xho1cjMYmTGkdWvpFRTXOokNvMY,98
2
- combatlearn/combat.py,sha256=pVauFEgZ7wiYRimGZe7ZhBWZN7sGQ67A3o_SrBUtoJ8,38126
3
- combatlearn-0.2.2.dist-info/licenses/LICENSE,sha256=O34CBRTmdL59PxDYOa6nq1N0-2A9xyXGkBXKbsL1NeY,1070
4
- combatlearn-0.2.2.dist-info/METADATA,sha256=CNm0pbXPVVWORk4pI97WS1DohjWOu7fB88JS1JZ-3-A,7491
5
- combatlearn-0.2.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
6
- combatlearn-0.2.2.dist-info/top_level.txt,sha256=3cFQv4oj2sh_NKra45cPy8Go0v8W9x9-zkkUibqZCMk,12
7
- combatlearn-0.2.2.dist-info/RECORD,,