combatlearn 0.1.1__tar.gz → 0.2.0__tar.gz

This diff shows the changes between publicly released versions of the package as they appear in the public registry; it is provided for informational purposes only.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: combatlearn
3
- Version: 0.1.1
3
+ Version: 0.2.0
4
4
  Summary: Batch-effect harmonization for machine learning frameworks.
5
5
  Author-email: Ettore Rocchi <ettoreroc@gmail.com>
6
6
  License: MIT License
@@ -37,15 +37,18 @@ License-File: LICENSE
37
37
  Requires-Dist: pandas>=1.3
38
38
  Requires-Dist: numpy>=1.21
39
39
  Requires-Dist: scikit-learn>=1.2
40
+ Requires-Dist: plotly>=5.0
41
+ Requires-Dist: nbformat>=4.2
42
+ Requires-Dist: umap-learn>=0.5
40
43
  Requires-Dist: pytest>=7
41
44
  Dynamic: license-file
42
45
 
43
46
  # **combatlearn**
44
47
 
45
- [![Python versions](https://img.shields.io/badge/python-3.10%20%7C%203.11%20%7C%203.12-blue?logo=python)](https://img.shields.io/badge/python-3.10%20%7C%203.11%20%7C%203.12-blue?logo=python)
48
+ [![Python versions](https://img.shields.io/badge/python-%3E%3D3.10-blue?logo=python)](https://www.python.org/)
46
49
  [![Test](https://github.com/EttoreRocchi/combatlearn/actions/workflows/test.yaml/badge.svg)](https://github.com/EttoreRocchi/combatlearn/actions/workflows/test.yaml)
47
50
  [![PyPI Downloads](https://static.pepy.tech/badge/combatlearn)](https://pepy.tech/projects/combatlearn)
48
- [![PyPI version](https://badge.fury.io/py/combatlearn.svg)](https://pypi.org/project/combatlearn/)
51
+ [![PyPI Version](https://img.shields.io/pypi/v/combatlearn?cacheSeconds=300)](https://pypi.org/project/combatlearn/)
49
52
  [![License](https://img.shields.io/github/license/EttoreRocchi/combatlearn)](https://github.com/EttoreRocchi/combatlearn/blob/main/LICENSE)
50
53
 
51
54
  <div align="center">
@@ -56,7 +59,7 @@ Dynamic: license-file
56
59
 
57
60
  **Three methods**:
58
61
  - `method="johnson"` - classic ComBat (Johnson _et al._, 2007)
59
- - `method="fortin"` - covariate-aware ComBat (Fortin _et al._, 2018)
62
+ - `method="fortin"` - neuroComBat (Fortin _et al._, 2018)
60
63
  - `method="chen"` - CovBat (Chen _et al._, 2022)
61
64
 
62
65
  ## Installation
@@ -111,7 +114,7 @@ print("Best parameters:", grid.best_params_)
111
114
  print(f"Best CV AUROC: {grid.best_score_:.3f}")
112
115
  ```
113
116
 
114
- For a full example of how to use **combatlearn** see the [notebook demo](https://github.com/EttoreRocchi/combatlearn/blob/main/demo/combatlearn_demo.ipynb)
117
+ For a full example of how to use **combatlearn** see the [notebook demo](https://github.com/EttoreRocchi/combatlearn/blob/main/docs/demo/combatlearn_demo.ipynb)
115
118
 
116
119
  ## `ComBat` parameters
117
120
 
@@ -136,6 +139,13 @@ The following section provides a detailed explanation of all parameters availabl
136
139
  | `covbat_cov_thresh` | float, int | `0.9` | For `"chen"` method only: Cumulative variance threshold $]0,1[$ to retain PCs in PCA space (e.g., 0.9 = retain 90% explained variance). If an integer is provided, it represents the number of principal components to use. |
137
140
  | `eps` | float | `1e-8` | Small jitter value added to variances to prevent divide-by-zero errors during standardization. |
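As a quick illustration of the two rows above (a hedged sketch, not from the README: the section documents `ComBat` parameters, so passing them as constructor keywords is assumed, as is the `batch` argument described for the wrapper in `combat.py`):

```python
from combatlearn import ComBat

# Hypothetical data: X is a (samples x features) DataFrame,
# batch a Series of batch labels aligned with X.index.
combat = ComBat(
    batch=batch,
    method="chen",            # CovBat (Chen et al., 2022)
    covbat_cov_thresh=0.9,    # keep principal components explaining 90% of the variance
    eps=1e-8,                 # jitter added to variances before standardization
)
X_harmonized = combat.fit_transform(X)
```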
138
141
 
142
+
143
+ ### Batch Effect Correction Visualization
144
+
145
+ The `plot_transformation` method lets you visualize the effect of the **ComBat** transformation: it shows a side-by-side before/after comparison of the data, reduced to two or three dimensions with PCA, t-SNE, or UMAP.
146
+
147
+ For further details see the [notebook demo](https://github.com/EttoreRocchi/combatlearn/blob/main/docs/demo/combatlearn_demo.ipynb).
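A minimal usage sketch (not part of the package README: the `ComBat(batch=..., method=...)` construction and the prior `fit` call are assumed from the surrounding examples, while the keyword arguments match the `plot_transformation` signature in `combat.py`):

```python
from combatlearn import ComBat

combat = ComBat(batch=batch, method="fortin")  # batch: labels aligned with X.index (assumed)
combat.fit(X)                                  # must be fitted before plotting

# Static matplotlib figure: 2-D PCA, before vs. after harmonization
fig = combat.plot_transformation(X)
fig.savefig("combat_pca.png", dpi=150)

# Interactive plotly figure: 3-D UMAP, also returning the raw embeddings
fig, embeddings = combat.plot_transformation(
    X,
    reduction_method="umap",
    n_components=3,
    plot_type="interactive",
    return_embeddings=True,
)
fig.show()
```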
148
+
139
149
  ## Contributing
140
150
 
141
151
  Pull requests, bug reports, and feature ideas are welcome: feel free to open a PR!
@@ -144,7 +154,7 @@ Pull requests, bug reports, and feature ideas are welcome: feel free to open a P
144
154
 
145
155
  [**Ettore Rocchi**](https://github.com/ettorerocchi) @ University of Bologna
146
156
 
147
- [Google Scholar](https://scholar.google.com/citations?user=MKHoGnQAAAAJ) $\cdot$ [Scopus](https://www.scopus.com/authid/detail.uri?authorId=57220152522)
157
+ [Google Scholar](https://scholar.google.com/citations?user=MKHoGnQAAAAJ) | [Scopus](https://www.scopus.com/authid/detail.uri?authorId=57220152522)
148
158
 
149
159
  ## Acknowledgements
150
160
 
@@ -1,9 +1,9 @@
1
1
  # **combatlearn**
2
2
 
3
- [![Python versions](https://img.shields.io/badge/python-3.10%20%7C%203.11%20%7C%203.12-blue?logo=python)](https://img.shields.io/badge/python-3.10%20%7C%203.11%20%7C%203.12-blue?logo=python)
3
+ [![Python versions](https://img.shields.io/badge/python-%3E%3D3.10-blue?logo=python)](https://www.python.org/)
4
4
  [![Test](https://github.com/EttoreRocchi/combatlearn/actions/workflows/test.yaml/badge.svg)](https://github.com/EttoreRocchi/combatlearn/actions/workflows/test.yaml)
5
5
  [![PyPI Downloads](https://static.pepy.tech/badge/combatlearn)](https://pepy.tech/projects/combatlearn)
6
- [![PyPI version](https://badge.fury.io/py/combatlearn.svg)](https://pypi.org/project/combatlearn/)
6
+ [![PyPI Version](https://img.shields.io/pypi/v/combatlearn?cacheSeconds=300)](https://pypi.org/project/combatlearn/)
7
7
  [![License](https://img.shields.io/github/license/EttoreRocchi/combatlearn)](https://github.com/EttoreRocchi/combatlearn/blob/main/LICENSE)
8
8
 
9
9
  <div align="center">
@@ -14,7 +14,7 @@
14
14
 
15
15
  **Three methods**:
16
16
  - `method="johnson"` - classic ComBat (Johnson _et al._, 2007)
17
- - `method="fortin"` - covariate-aware ComBat (Fortin _et al._, 2018)
17
+ - `method="fortin"` - neuroComBat (Fortin _et al._, 2018)
18
18
  - `method="chen"` - CovBat (Chen _et al._, 2022)
19
19
 
20
20
  ## Installation
@@ -69,7 +69,7 @@ print("Best parameters:", grid.best_params_)
69
69
  print(f"Best CV AUROC: {grid.best_score_:.3f}")
70
70
  ```
71
71
 
72
- For a full example of how to use **combatlearn** see the [notebook demo](https://github.com/EttoreRocchi/combatlearn/blob/main/demo/combatlearn_demo.ipynb)
72
+ For a full example of how to use **combatlearn** see the [notebook demo](https://github.com/EttoreRocchi/combatlearn/blob/main/docs/demo/combatlearn_demo.ipynb)
73
73
 
74
74
  ## `ComBat` parameters
75
75
 
@@ -94,6 +94,13 @@ The following section provides a detailed explanation of all parameters availabl
94
94
  | `covbat_cov_thresh` | float, int | `0.9` | For `"chen"` method only: Cumulative variance threshold $]0,1[$ to retain PCs in PCA space (e.g., 0.9 = retain 90% explained variance). If an integer is provided, it represents the number of principal components to use. |
95
95
  | `eps` | float | `1e-8` | Small jitter value added to variances to prevent divide-by-zero errors during standardization. |
96
96
 
97
+
98
+ ### Batch Effect Correction Visualization
99
+
100
+ The `plot_transformation` method lets you visualize the effect of the **ComBat** transformation: it shows a side-by-side before/after comparison of the data, reduced to two or three dimensions with PCA, t-SNE, or UMAP.
101
+
102
+ For further details see the [notebook demo](https://github.com/EttoreRocchi/combatlearn/blob/main/docs/demo/combatlearn_demo.ipynb).
103
+
97
104
  ## Contributing
98
105
 
99
106
  Pull requests, bug reports, and feature ideas are welcome: feel free to open a PR!
@@ -102,7 +109,7 @@ Pull requests, bug reports, and feature ideas are welcome: feel free to open a P
102
109
 
103
110
  [**Ettore Rocchi**](https://github.com/ettorerocchi) @ University of Bologna
104
111
 
105
- [Google Scholar](https://scholar.google.com/citations?user=MKHoGnQAAAAJ) $\cdot$ [Scopus](https://www.scopus.com/authid/detail.uri?authorId=57220152522)
112
+ [Google Scholar](https://scholar.google.com/citations?user=MKHoGnQAAAAJ) | [Scopus](https://www.scopus.com/authid/detail.uri?authorId=57220152522)
106
113
 
107
114
  ## Acknowledgements
108
115
 
@@ -1,4 +1,4 @@
1
1
  from .combat import ComBatModel, ComBat
2
2
 
3
3
  __all__ = ["ComBatModel", "ComBat"]
4
- __version__ = "0.1.0"
4
+ __version__ = "0.2.0"
@@ -16,10 +16,25 @@ import pandas as pd
16
16
  from sklearn.base import BaseEstimator, TransformerMixin
17
17
  from sklearn.utils.validation import check_is_fitted
18
18
  from sklearn.decomposition import PCA
19
+ from sklearn.manifold import TSNE
20
+ import matplotlib.pyplot as plt
19
21
  from typing import Literal, Optional, Union, Dict, Tuple, Any, cast
20
22
  import numpy.typing as npt
21
23
  import warnings
22
24
 
25
+ try:
26
+ import umap
27
+ UMAP_AVAILABLE = True
28
+ except ImportError:
29
+ UMAP_AVAILABLE = False
30
+
31
+ try:
32
+ import plotly.graph_objects as go
33
+ from plotly.subplots import make_subplots
34
+ PLOTLY_AVAILABLE = True
35
+ except ImportError:
36
+ PLOTLY_AVAILABLE = False
37
+
23
38
  __author__ = "Ettore Rocchi"
24
39
 
25
40
  ArrayLike = Union[pd.DataFrame, pd.Series, npt.NDArray[Any]]
@@ -139,7 +154,7 @@ class ComBatModel:
139
154
 
140
155
  if self.reference_batch is not None and self.reference_batch not in batch.cat.categories:
141
156
  raise ValueError(
142
- f"reference_batch={self.reference_batch!r} not present in the data batches "
157
+ f"reference_batch={self.reference_batch!r} not present in the data batches."
143
158
  f"{list(batch.cat.categories)}"
144
159
  )
145
160
 
@@ -218,69 +233,94 @@ class ComBatModel:
218
233
  disc: Optional[pd.DataFrame],
219
234
  cont: Optional[pd.DataFrame],
220
235
  ) -> None:
221
- """Fortin et al. (2018) ComBat."""
222
- batch_levels = batch.cat.categories
223
- n_batch = len(batch_levels)
236
+ """Fortin et al. (2018) neuroComBat."""
237
+ self._batch_levels = batch.cat.categories
238
+ n_batch = len(self._batch_levels)
224
239
  n_samples = len(X)
225
240
 
226
- batch_dummies = pd.get_dummies(batch, drop_first=False)
241
+ batch_dummies = pd.get_dummies(batch, drop_first=False).astype(float)
242
+ if self.reference_batch is not None:
243
+ if self.reference_batch not in self._batch_levels:
244
+ raise ValueError(
245
+ f"reference_batch={self.reference_batch!r} not present in batches."
246
+ f"{list(self._batch_levels)}"
247
+ )
248
+ batch_dummies.loc[:, self.reference_batch] = 1.0
249
+
227
250
  parts: list[pd.DataFrame] = [batch_dummies]
228
251
  if disc is not None:
229
- parts.append(pd.get_dummies(disc.astype("category"), drop_first=True))
252
+ parts.append(
253
+ pd.get_dummies(
254
+ disc.astype("category"), drop_first=True
255
+ ).astype(float)
256
+ )
257
+
230
258
  if cont is not None:
231
- parts.append(cont)
232
- design = pd.concat(parts, axis=1).astype(float).values
259
+ parts.append(cont.astype(float))
260
+
261
+ design = pd.concat(parts, axis=1).values
233
262
  p_design = design.shape[1]
234
263
 
235
264
  X_np = X.values
236
265
  beta_hat = la.lstsq(design, X_np, rcond=None)[0]
237
266
 
238
- gamma_hat = beta_hat[:n_batch]
267
+ beta_hat_batch = beta_hat[:n_batch]
239
268
  self._beta_hat_nonbatch = beta_hat[n_batch:]
240
269
 
241
- n_per_batch_arr = batch.value_counts().sort_index().values
242
- self._n_per_batch = dict(zip(batch_levels, n_per_batch_arr))
270
+ n_per_batch = batch.value_counts().sort_index().astype(int).values
271
+ self._n_per_batch = dict(zip(self._batch_levels, n_per_batch))
272
+
273
+ if self.reference_batch is not None:
274
+ ref_idx = list(self._batch_levels).index(self.reference_batch)
275
+ grand_mean = beta_hat_batch[ref_idx]
276
+ else:
277
+ grand_mean = (n_per_batch / n_samples) @ beta_hat_batch
278
+ ref_idx = None
243
279
 
244
- grand_mean = (n_per_batch_arr / n_samples) @ gamma_hat
245
280
  self._grand_mean = pd.Series(grand_mean, index=X.columns)
246
281
 
247
- resid = X_np - design @ beta_hat
248
- var_pooled = (resid ** 2).sum(axis=0) / (n_samples - p_design) + self.eps
282
+ if self.reference_batch is not None:
283
+ ref_mask = (batch == self.reference_batch).values
284
+ resid = X_np[ref_mask] - design[ref_mask] @ beta_hat
285
+ denom = int(ref_mask.sum())
286
+ else:
287
+ resid = X_np - design @ beta_hat
288
+ denom = n_samples
289
+ var_pooled = (resid ** 2).sum(axis=0) / denom + self.eps
249
290
  self._pooled_var = pd.Series(var_pooled, index=X.columns)
250
291
 
251
292
  stand_mean = grand_mean + design[:, n_batch:] @ self._beta_hat_nonbatch
252
293
  Xs = (X_np - stand_mean) / np.sqrt(var_pooled)
253
294
 
254
- delta_hat = np.empty_like(gamma_hat)
255
- for i, lvl in enumerate(batch_levels):
256
- idx = batch == lvl
257
- delta_hat[i] = Xs[idx].var(axis=0, ddof=1) + self.eps
295
+ gamma_hat = np.vstack(
296
+ [Xs[batch == lvl].mean(axis=0) for lvl in self._batch_levels]
297
+ )
298
+ delta_hat = np.vstack(
299
+ [Xs[batch == lvl].var(axis=0, ddof=1) + self.eps
300
+ for lvl in self._batch_levels]
301
+ )
258
302
 
259
303
  if self.mean_only:
260
304
  gamma_star = self._shrink_gamma(
261
- gamma_hat, delta_hat, n_per_batch_arr, parametric=self.parametric
305
+ gamma_hat, delta_hat, n_per_batch,
306
+ parametric=self.parametric
262
307
  )
263
308
  delta_star = np.ones_like(delta_hat)
264
309
  else:
265
310
  gamma_star, delta_star = self._shrink_gamma_delta(
266
- gamma_hat, delta_hat, n_per_batch_arr, parametric=self.parametric
311
+ gamma_hat, delta_hat, n_per_batch,
312
+ parametric=self.parametric
267
313
  )
268
314
 
269
- if self.reference_batch is not None:
270
- ref_idx = list(batch_levels).index(self.reference_batch)
271
- gamma_ref = gamma_star[ref_idx]
272
- delta_ref = delta_star[ref_idx]
273
- gamma_star = gamma_star - gamma_ref
315
+ if ref_idx is not None:
316
+ gamma_star[ref_idx] = 0.0
274
317
  if not self.mean_only:
275
- delta_star = delta_star / delta_ref
276
- self._reference_batch_idx = ref_idx
277
- else:
278
- self._reference_batch_idx = None
318
+ delta_star[ref_idx] = 1.0
319
+ self._reference_batch_idx = ref_idx
279
320
 
280
- self._batch_levels = batch_levels
281
321
  self._gamma_star = gamma_star
282
322
  self._delta_star = delta_star
283
- self._n_batch = n_batch
323
+ self._n_batch = n_batch
284
324
  self._p_design = p_design
285
325
 
286
326
  def _fit_chen(
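For reference, the standardization that `_fit_fortin` implements above, written out (a summary in the spirit of Fortin et al. (2018), not an excerpt from the package): for sample $j$ in batch $i$ and feature $g$,

```latex
% Standardize each feature using the grand mean, the non-batch covariate
% effects, and the pooled residual variance estimated by least squares:
z_{ijg} \;=\; \frac{x_{ijg} - \hat\alpha_g - \mathbf{c}_{ij}^{\top}\hat\beta_g}{\hat\sigma_g},
\qquad
\hat\gamma_{ig} = \operatorname{mean}_j\, z_{ijg},
\qquad
\hat\delta_{ig} = \operatorname{var}_j\, z_{ijg}.
```

The per-batch location and scale estimates $\hat\gamma_{ig}$, $\hat\delta_{ig}$ are then shrunk by the (parametric or non-parametric) empirical-Bayes step into $\gamma^{*}_{ig}$, $\delta^{*}_{ig}$; when `reference_batch` is set, the grand mean and pooled variance are taken from the reference batch alone and its row is fixed to $\gamma^{*} = 0$, $\delta^{*} = 1$.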
@@ -434,7 +474,7 @@ class ComBatModel:
434
474
  *,
435
475
  parametric: bool,
436
476
  ) -> FloatArray:
437
- """Convenience wrapper that returns only γ⋆ (for *meanonly* mode)."""
477
+ """Convenience wrapper that returns only γ⋆ (for *mean-only* mode)."""
438
478
  gamma, _ = self._shrink_gamma_delta(gamma_hat, delta_hat, n_per_batch, parametric=parametric)
439
479
  return gamma
440
480
 
@@ -454,7 +494,7 @@ class ComBatModel:
454
494
  batch = self._as_series(batch, idx, "batch")
455
495
  unseen = set(batch.cat.categories) - set(self._batch_levels)
456
496
  if unseen:
457
- raise ValueError(f"Unseen batch levels during transform: {unseen}")
497
+ raise ValueError(f"Unseen batch levels during transform: {unseen}.")
458
498
  disc = self._to_df(discrete_covariates, idx, "discrete_covariates")
459
499
  cont = self._to_df(continuous_covariates, idx, "continuous_covariates")
460
500
 
@@ -466,7 +506,7 @@ class ComBatModel:
466
506
  elif method == "chen":
467
507
  return self._transform_chen(X, batch, disc, cont)
468
508
  else:
469
- raise ValueError(f"Unknown method: {method}")
509
+ raise ValueError(f"Unknown method: {method}.")
470
510
 
471
511
  def _transform_johnson(
472
512
  self,
@@ -485,7 +525,7 @@ class ComBatModel:
485
525
  if not idx.any():
486
526
  continue
487
527
  if self.reference_batch is not None and lvl == self.reference_batch:
488
- X_adj.loc[idx] = X.loc[idx].values # untouched
528
+ X_adj.loc[idx] = X.loc[idx].values
489
529
  continue
490
530
 
491
531
  g = self._gamma_star[i]
@@ -505,18 +545,28 @@ class ComBatModel:
505
545
  cont: Optional[pd.DataFrame],
506
546
  ) -> pd.DataFrame:
507
547
  """Fortin transform implementation."""
508
- batch_dummies = pd.get_dummies(batch, drop_first=False)[self._batch_levels]
509
- parts: list[pd.DataFrame] = [batch_dummies]
548
+ batch_dummies = pd.get_dummies(batch, drop_first=False).astype(float)[self._batch_levels]
549
+ if self.reference_batch is not None:
550
+ batch_dummies.loc[:, self.reference_batch] = 1.0
551
+
552
+ parts = [batch_dummies]
510
553
  if disc is not None:
511
- parts.append(pd.get_dummies(disc.astype("category"), drop_first=True))
554
+ parts.append(
555
+ pd.get_dummies(
556
+ disc.astype("category"), drop_first=True
557
+ ).astype(float)
558
+ )
512
559
  if cont is not None:
513
- parts.append(cont)
560
+ parts.append(cont.astype(float))
514
561
 
515
- design = pd.concat(parts, axis=1).astype(float).values
562
+ design = pd.concat(parts, axis=1).values
516
563
 
517
564
  X_np = X.values
518
- stand_mean = self._grand_mean.values + design[:, self._n_batch:] @ self._beta_hat_nonbatch
519
- Xs = (X_np - stand_mean) / np.sqrt(self._pooled_var.values)
565
+ stand_mu = (
566
+ self._grand_mean.values +
567
+ design[:, self._n_batch:] @ self._beta_hat_nonbatch
568
+ )
569
+ Xs = (X_np - stand_mu) / np.sqrt(self._pooled_var.values)
520
570
 
521
571
  for i, lvl in enumerate(self._batch_levels):
522
572
  idx = batch == lvl
@@ -533,8 +583,11 @@ class ComBatModel:
533
583
  else:
534
584
  Xs[idx] = (Xs[idx] - g) / np.sqrt(d)
535
585
 
536
- X_adj = Xs * np.sqrt(self._pooled_var.values) + stand_mean
537
- return pd.DataFrame(X_adj, index=X.index, columns=X.columns)
586
+ X_adj = (
587
+ Xs * np.sqrt(self._pooled_var.values) +
588
+ stand_mu
589
+ )
590
+ return pd.DataFrame(X_adj, index=X.index, columns=X.columns, dtype=float)
538
591
 
539
592
  def _transform_chen(
540
593
  self,
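The harmonization applied by `_transform_fortin` above is then (again a summary, not package code):

```latex
% Re-scale and re-center, then add the standardization mean back:
x^{\mathrm{ComBat}}_{ijg}
  \;=\; \hat\sigma_g \,\frac{z_{ijg} - \gamma^{*}_{ig}}{\sqrt{\delta^{*}_{ig}}}
  \;+\; \hat\alpha_g + \mathbf{c}_{ij}^{\top}\hat\beta_g .
```

Since the reference batch has $\gamma^{*} = 0$ and $\delta^{*} = 1$, its samples pass through unchanged.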
@@ -568,7 +621,7 @@ class ComBatModel:
568
621
 
569
622
 
570
623
  class ComBat(BaseEstimator, TransformerMixin):
571
- """Pipelinefriendly wrapper around `ComBatModel`.
624
+ """Pipeline-friendly wrapper around `ComBatModel`.
572
625
 
573
626
  Stores batch (and optional covariates) passed at construction and
574
627
  appropriately uses them for separate `fit` and `transform`.
@@ -621,6 +674,7 @@ class ComBat(BaseEstimator, TransformerMixin):
621
674
  discrete_covariates=disc,
622
675
  continuous_covariates=cont,
623
676
  )
677
+ self._fitted_batch = batch_vec
624
678
  return self
625
679
 
626
680
  def transform(self, X: ArrayLike) -> pd.DataFrame:
@@ -651,3 +705,315 @@ class ComBat(BaseEstimator, TransformerMixin):
651
705
  return pd.Series(obj, index=idx)
652
706
  else:
653
707
  return pd.DataFrame(obj, index=idx)
708
+
709
+ def plot_transformation(
710
+ self,
711
+ X: ArrayLike, *,
712
+ reduction_method: Literal['pca', 'tsne', 'umap'] = 'pca',
713
+ n_components: Literal[2, 3] = 2,
714
+ plot_type: Literal['static', 'interactive'] = 'static',
715
+ figsize: Tuple[int, int] = (12, 5),
716
+ alpha: float = 0.7,
717
+ point_size: int = 50,
718
+ cmap: str = 'Set1',
719
+ title: Optional[str] = None,
720
+ show_legend: bool = True,
721
+ return_embeddings: bool = False,
722
+ **reduction_kwargs) -> Union[Any, Tuple[Any, Dict[str, FloatArray]]]:
723
+ """
724
+ Visualize the ComBat transformation effect using dimensionality reduction.
725
+
726
+ It shows a before/after comparison of data transformed by `ComBat` using
727
+ PCA, t-SNE, or UMAP to reduce dimensions for visualization.
728
+
729
+ Parameters
730
+ ----------
731
+ X : array-like of shape (n_samples, n_features)
732
+ Input data to transform and visualize.
733
+
734
+ reduction_method : {`'pca'`, `'tsne'`, `'umap'`}, default=`'pca'`
735
+ Dimensionality reduction method.
736
+
737
+ n_components : {2, 3}, default=2
738
+ Number of components for dimensionality reduction.
739
+
740
+ plot_type : {`'static'`, `'interactive'`}, default=`'static'`
741
+ Visualization type:
742
+ - `'static'`: matplotlib plots (can be saved as images)
743
+ - `'interactive'`: plotly plots (explorable, requires plotly)
744
+
745
+ return_embeddings : bool, default=False
746
+ If `True`, return embeddings along with the plot.
747
+
748
+ **reduction_kwargs : dict
749
+ Additional parameters for reduction methods.
750
+
751
+ Returns
752
+ -------
753
+ fig : matplotlib.figure.Figure or plotly.graph_objects.Figure
754
+ The figure object containing the plots.
755
+
756
+ embeddings : dict, optional
757
+ If `return_embeddings=True`, dictionary with:
758
+ - `'original'`: embedding of original data
759
+ - `'transformed'`: embedding of ComBat-transformed data
760
+ """
761
+ check_is_fitted(self._model, ["_gamma_star"])
762
+
763
+ if n_components not in [2, 3]:
764
+ raise ValueError(f"n_components must be 2 or 3, got {n_components}")
765
+ if reduction_method not in ['pca', 'tsne', 'umap']:
766
+ raise ValueError(f"reduction_method must be 'pca', 'tsne', or 'umap', got '{reduction_method}'")
767
+ if plot_type not in ['static', 'interactive']:
768
+ raise ValueError(f"plot_type must be 'static' or 'interactive', got '{plot_type}'")
769
+
770
+ if reduction_method == 'umap' and not UMAP_AVAILABLE:
771
+ raise ImportError("UMAP is not installed. Install with: pip install umap-learn")
772
+ if plot_type == 'interactive' and not PLOTLY_AVAILABLE:
773
+ raise ImportError("Plotly is not installed. Install with: pip install plotly")
774
+
775
+ if not isinstance(X, pd.DataFrame):
776
+ X = pd.DataFrame(X)
777
+
778
+ idx = X.index
779
+ batch_vec = self._subset(self.batch, idx)
780
+ if batch_vec is None:
781
+ raise ValueError("Batch information is required for visualization")
782
+
783
+ X_transformed = self.transform(X)
784
+
785
+ X_np = X.values
786
+ X_trans_np = X_transformed.values
787
+
788
+ if reduction_method == 'pca':
789
+ reducer_orig = PCA(n_components=n_components, **reduction_kwargs)
790
+ reducer_trans = PCA(n_components=n_components, **reduction_kwargs)
791
+ elif reduction_method == 'tsne':
792
+ tsne_params = {'perplexity': 30, 'max_iter': 1000, 'random_state': 42}
793
+ tsne_params.update(reduction_kwargs)
794
+ reducer_orig = TSNE(n_components=n_components, **tsne_params)
795
+ reducer_trans = TSNE(n_components=n_components, **tsne_params)
796
+ else:
797
+ umap_params = {'random_state': 42}
798
+ umap_params.update(reduction_kwargs)
799
+ reducer_orig = umap.UMAP(n_components=n_components, **reduction_kwargs)
800
+ reducer_trans = umap.UMAP(n_components=n_components, **reduction_kwargs)
801
+
802
+ X_embedded_orig = reducer_orig.fit_transform(X_np)
803
+ X_embedded_trans = reducer_trans.fit_transform(X_trans_np)
804
+
805
+ if plot_type == 'static':
806
+ fig = self._create_static_plot(
807
+ X_embedded_orig, X_embedded_trans, batch_vec,
808
+ reduction_method, n_components, figsize, alpha,
809
+ point_size, cmap, title, show_legend
810
+ )
811
+ else:
812
+ fig = self._create_interactive_plot(
813
+ X_embedded_orig, X_embedded_trans, batch_vec,
814
+ reduction_method, n_components, title, show_legend
815
+ )
816
+
817
+ if return_embeddings:
818
+ embeddings = {
819
+ 'original': X_embedded_orig,
820
+ 'transformed': X_embedded_trans
821
+ }
822
+ return fig, embeddings
823
+ else:
824
+ return fig
825
+
826
+ def _create_static_plot(
827
+ self,
828
+ X_orig: FloatArray,
829
+ X_trans: FloatArray,
830
+ batch_labels: pd.Series,
831
+ method: str,
832
+ n_components: int,
833
+ figsize: Tuple[int, int],
834
+ alpha: float,
835
+ point_size: int,
836
+ cmap: str,
837
+ title: Optional[str],
838
+ show_legend: bool) -> Any:
839
+ """Create static plots using matplotlib."""
840
+
841
+ fig = plt.figure(figsize=figsize)
842
+
843
+ unique_batches = batch_labels.drop_duplicates()
844
+ n_batches = len(unique_batches)
845
+
846
+ if n_batches <= 10:
847
+ colors = plt.cm.get_cmap(cmap)(np.linspace(0, 1, n_batches))
848
+ else:
849
+ colors = plt.cm.get_cmap('tab20')(np.linspace(0, 1, n_batches))
850
+
851
+ if n_components == 2:
852
+ ax1 = plt.subplot(1, 2, 1)
853
+ ax2 = plt.subplot(1, 2, 2)
854
+ else:
855
+ ax1 = fig.add_subplot(121, projection='3d')
856
+ ax2 = fig.add_subplot(122, projection='3d')
857
+
858
+ for i, batch in enumerate(unique_batches):
859
+ mask = batch_labels == batch
860
+ if n_components == 2:
861
+ ax1.scatter(
862
+ X_orig[mask, 0], X_orig[mask, 1],
863
+ c=[colors[i]],
864
+ s=point_size,
865
+ alpha=alpha,
866
+ label=f'Batch {batch}',
867
+ edgecolors='black',
868
+ linewidth=0.5
869
+ )
870
+ else:
871
+ ax1.scatter(
872
+ X_orig[mask, 0], X_orig[mask, 1], X_orig[mask, 2],
873
+ c=[colors[i]],
874
+ s=point_size,
875
+ alpha=alpha,
876
+ label=f'Batch {batch}',
877
+ edgecolors='black',
878
+ linewidth=0.5
879
+ )
880
+
881
+ ax1.set_title(f'Before ComBat correction\n({method.upper()})')
882
+ ax1.set_xlabel(f'{method.upper()}1')
883
+ ax1.set_ylabel(f'{method.upper()}2')
884
+ if n_components == 3:
885
+ ax1.set_zlabel(f'{method.upper()}3')
886
+
887
+ for i, batch in enumerate(unique_batches):
888
+ mask = batch_labels == batch
889
+ if n_components == 2:
890
+ ax2.scatter(
891
+ X_trans[mask, 0], X_trans[mask, 1],
892
+ c=[colors[i]],
893
+ s=point_size,
894
+ alpha=alpha,
895
+ label=f'Batch {batch}',
896
+ edgecolors='black',
897
+ linewidth=0.5
898
+ )
899
+ else:
900
+ ax2.scatter(
901
+ X_trans[mask, 0], X_trans[mask, 1], X_trans[mask, 2],
902
+ c=[colors[i]],
903
+ s=point_size,
904
+ alpha=alpha,
905
+ label=f'Batch {batch}',
906
+ edgecolors='black',
907
+ linewidth=0.5
908
+ )
909
+
910
+ ax2.set_title(f'After ComBat correction\n({method.upper()})')
911
+ ax2.set_xlabel(f'{method.upper()}1')
912
+ ax2.set_ylabel(f'{method.upper()}2')
913
+ if n_components == 3:
914
+ ax2.set_zlabel(f'{method.upper()}3')
915
+
916
+ if show_legend and n_batches <= 20:
917
+ ax2.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
918
+
919
+ if title is None:
920
+ title = f'ComBat correction effect visualized with {method.upper()}'
921
+ fig.suptitle(title, fontsize=14, fontweight='bold')
922
+
923
+ plt.tight_layout()
924
+ return fig
925
+
926
+ def _create_interactive_plot(
927
+ self,
928
+ X_orig: FloatArray,
929
+ X_trans: FloatArray,
930
+ batch_labels: pd.Series,
931
+ method: str,
932
+ n_components: int,
933
+ title: Optional[str],
934
+ show_legend: bool) -> Any:
935
+ """Create interactive plots using plotly."""
936
+ if n_components == 2:
937
+ fig = make_subplots(
938
+ rows=1, cols=2,
939
+ subplot_titles=(
940
+ f'Before ComBat correction ({method.upper()})',
941
+ f'After ComBat correction ({method.upper()})'
942
+ )
943
+ )
944
+ else:
945
+ fig = make_subplots(
946
+ rows=1, cols=2,
947
+ specs=[[{'type': 'scatter3d'}, {'type': 'scatter3d'}]],
948
+ subplot_titles=(
949
+ f'Before ComBat correction ({method.upper()})',
950
+ f'After ComBat correction ({method.upper()})'
951
+ )
952
+ )
953
+
954
+ unique_batches = batch_labels.drop_duplicates()
955
+
956
+ for batch in unique_batches:
957
+ mask = batch_labels == batch
958
+
959
+ if n_components == 2:
960
+ fig.add_trace(
961
+ go.Scatter(x=X_orig[mask, 0], y=X_orig[mask, 1],
962
+ mode='markers',
963
+ name=f'Batch {batch}',
964
+ marker=dict(size=8, line=dict(width=1, color='black')),
965
+ showlegend=False),
966
+ row=1, col=1
967
+ )
968
+
969
+ fig.add_trace(
970
+ go.Scatter(x=X_trans[mask, 0], y=X_trans[mask, 1],
971
+ mode='markers',
972
+ name=f'Batch {batch}',
973
+ marker=dict(size=8, line=dict(width=1, color='black')),
974
+ showlegend=show_legend),
975
+ row=1, col=2
976
+ )
977
+ else:
978
+ fig.add_trace(
979
+ go.Scatter3d(x=X_orig[mask, 0], y=X_orig[mask, 1], z=X_orig[mask, 2],
980
+ mode='markers',
981
+ name=f'Batch {batch}',
982
+ marker=dict(size=5, line=dict(width=0.5, color='black')),
983
+ showlegend=False),
984
+ row=1, col=1
985
+ )
986
+
987
+ fig.add_trace(
988
+ go.Scatter3d(x=X_trans[mask, 0], y=X_trans[mask, 1], z=X_trans[mask, 2],
989
+ mode='markers',
990
+ name=f'Batch {batch}',
991
+ marker=dict(size=5, line=dict(width=0.5, color='black')),
992
+ showlegend=show_legend),
993
+ row=1, col=2
994
+ )
995
+
996
+ if title is None:
997
+ title = f'ComBat correction effect visualized with {method.upper()}'
998
+
999
+ fig.update_layout(
1000
+ title=title,
1001
+ title_font_size=16,
1002
+ height=600,
1003
+ showlegend=show_legend,
1004
+ hovermode='closest'
1005
+ )
1006
+
1007
+ axis_labels = [f'{method.upper()}{i+1}' for i in range(n_components)]
1008
+
1009
+ if n_components == 2:
1010
+ fig.update_xaxes(title_text=axis_labels[0])
1011
+ fig.update_yaxes(title_text=axis_labels[1])
1012
+ else:
1013
+ fig.update_scenes(
1014
+ xaxis_title=axis_labels[0],
1015
+ yaxis_title=axis_labels[1],
1016
+ zaxis_title=axis_labels[2]
1017
+ )
1018
+
1019
+ return fig
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: combatlearn
3
- Version: 0.1.1
3
+ Version: 0.2.0
4
4
  Summary: Batch-effect harmonization for machine learning frameworks.
5
5
  Author-email: Ettore Rocchi <ettoreroc@gmail.com>
6
6
  License: MIT License
@@ -37,15 +37,18 @@ License-File: LICENSE
37
37
  Requires-Dist: pandas>=1.3
38
38
  Requires-Dist: numpy>=1.21
39
39
  Requires-Dist: scikit-learn>=1.2
40
+ Requires-Dist: plotly>=5.0
41
+ Requires-Dist: nbformat>=4.2
42
+ Requires-Dist: umap-learn>=0.5
40
43
  Requires-Dist: pytest>=7
41
44
  Dynamic: license-file
42
45
 
43
46
  # **combatlearn**
44
47
 
45
- [![Python versions](https://img.shields.io/badge/python-3.10%20%7C%203.11%20%7C%203.12-blue?logo=python)](https://img.shields.io/badge/python-3.10%20%7C%203.11%20%7C%203.12-blue?logo=python)
48
+ [![Python versions](https://img.shields.io/badge/python-%3E%3D3.10-blue?logo=python)](https://www.python.org/)
46
49
  [![Test](https://github.com/EttoreRocchi/combatlearn/actions/workflows/test.yaml/badge.svg)](https://github.com/EttoreRocchi/combatlearn/actions/workflows/test.yaml)
47
50
  [![PyPI Downloads](https://static.pepy.tech/badge/combatlearn)](https://pepy.tech/projects/combatlearn)
48
- [![PyPI version](https://badge.fury.io/py/combatlearn.svg)](https://pypi.org/project/combatlearn/)
51
+ [![PyPI Version](https://img.shields.io/pypi/v/combatlearn?cacheSeconds=300)](https://pypi.org/project/combatlearn/)
49
52
  [![License](https://img.shields.io/github/license/EttoreRocchi/combatlearn)](https://github.com/EttoreRocchi/combatlearn/blob/main/LICENSE)
50
53
 
51
54
  <div align="center">
@@ -56,7 +59,7 @@ Dynamic: license-file
56
59
 
57
60
  **Three methods**:
58
61
  - `method="johnson"` - classic ComBat (Johnson _et al._, 2007)
59
- - `method="fortin"` - covariate-aware ComBat (Fortin _et al._, 2018)
62
+ - `method="fortin"` - neuroComBat (Fortin _et al._, 2018)
60
63
  - `method="chen"` - CovBat (Chen _et al._, 2022)
61
64
 
62
65
  ## Installation
@@ -111,7 +114,7 @@ print("Best parameters:", grid.best_params_)
111
114
  print(f"Best CV AUROC: {grid.best_score_:.3f}")
112
115
  ```
113
116
 
114
- For a full example of how to use **combatlearn** see the [notebook demo](https://github.com/EttoreRocchi/combatlearn/blob/main/demo/combatlearn_demo.ipynb)
117
+ For a full example of how to use **combatlearn** see the [notebook demo](https://github.com/EttoreRocchi/combatlearn/blob/main/docs/demo/combatlearn_demo.ipynb)
115
118
 
116
119
  ## `ComBat` parameters
117
120
 
@@ -136,6 +139,13 @@ The following section provides a detailed explanation of all parameters availabl
136
139
  | `covbat_cov_thresh` | float, int | `0.9` | For `"chen"` method only: Cumulative variance threshold $]0,1[$ to retain PCs in PCA space (e.g., 0.9 = retain 90% explained variance). If an integer is provided, it represents the number of principal components to use. |
137
140
  | `eps` | float | `1e-8` | Small jitter value added to variances to prevent divide-by-zero errors during standardization. |
138
141
 
142
+
143
+ ### Batch Effect Correction Visualization
144
+
145
+ The `plot_transformation` method lets you visualize the effect of the **ComBat** transformation: it shows a side-by-side before/after comparison of the data, reduced to two or three dimensions with PCA, t-SNE, or UMAP.
146
+
147
+ For further details see the [notebook demo](https://github.com/EttoreRocchi/combatlearn/blob/main/docs/demo/combatlearn_demo.ipynb).
148
+
139
149
  ## Contributing
140
150
 
141
151
  Pull requests, bug reports, and feature ideas are welcome: feel free to open a PR!
@@ -144,7 +154,7 @@ Pull requests, bug reports, and feature ideas are welcome: feel free to open a P
144
154
 
145
155
  [**Ettore Rocchi**](https://github.com/ettorerocchi) @ University of Bologna
146
156
 
147
- [Google Scholar](https://scholar.google.com/citations?user=MKHoGnQAAAAJ) $\cdot$ [Scopus](https://www.scopus.com/authid/detail.uri?authorId=57220152522)
157
+ [Google Scholar](https://scholar.google.com/citations?user=MKHoGnQAAAAJ) | [Scopus](https://www.scopus.com/authid/detail.uri?authorId=57220152522)
148
158
 
149
159
  ## Acknowledgements
150
160
 
@@ -0,0 +1,11 @@
1
+ LICENSE
2
+ README.md
3
+ pyproject.toml
4
+ combatlearn/__init__.py
5
+ combatlearn/combat.py
6
+ combatlearn.egg-info/PKG-INFO
7
+ combatlearn.egg-info/SOURCES.txt
8
+ combatlearn.egg-info/dependency_links.txt
9
+ combatlearn.egg-info/requires.txt
10
+ combatlearn.egg-info/top_level.txt
11
+ tests/test_combat.py
@@ -1,4 +1,7 @@
1
1
  pandas>=1.3
2
2
  numpy>=1.21
3
3
  scikit-learn>=1.2
4
+ plotly>=5.0
5
+ nbformat>=4.2
6
+ umap-learn>=0.5
4
7
  pytest>=7
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "combatlearn"
7
- version = "0.1.1"
7
+ version = "0.2.0"
8
8
  description = "Batch-effect harmonization for machine learning frameworks."
9
9
  authors = [{name="Ettore Rocchi", email="ettoreroc@gmail.com"}]
10
10
  requires-python = ">=3.10"
@@ -12,6 +12,9 @@ dependencies = [
12
12
  "pandas>=1.3",
13
13
  "numpy>=1.21",
14
14
  "scikit-learn>=1.2",
15
+ "plotly>=5.0",
16
+ "nbformat>=4.2",
17
+ "umap-learn>=0.5",
15
18
  "pytest>=7"
16
19
  ]
17
20
  license = {file="LICENSE"}
@@ -31,5 +34,5 @@ classifiers = [
31
34
  ]
32
35
 
33
36
  [tool.setuptools.packages.find]
34
- where = ["src"]
37
+ where = ["."]
35
38
  include = ["combatlearn*"]
@@ -1,11 +0,0 @@
1
- LICENSE
2
- README.md
3
- pyproject.toml
4
- src/combatlearn/__init__.py
5
- src/combatlearn/combat.py
6
- src/combatlearn.egg-info/PKG-INFO
7
- src/combatlearn.egg-info/SOURCES.txt
8
- src/combatlearn.egg-info/dependency_links.txt
9
- src/combatlearn.egg-info/requires.txt
10
- src/combatlearn.egg-info/top_level.txt
11
- tests/test_combat.py
File without changes
File without changes