combatlearn 0.2.1__tar.gz → 1.0.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,9 +1,9 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: combatlearn
3
- Version: 0.2.1
3
+ Version: 1.0.0
4
4
  Summary: Batch-effect harmonization for machine learning frameworks.
5
5
  Author-email: Ettore Rocchi <ettoreroc@gmail.com>
6
- License-Expression: MIT
6
+ License: MIT
7
7
  Keywords: machine-learning,harmonization,combat,preprocessing
8
8
  Classifier: Development Status :: 3 - Alpha
9
9
  Classifier: Intended Audience :: Science/Research
@@ -19,13 +19,23 @@ Requires-Dist: matplotlib>=3.4
19
19
  Requires-Dist: plotly>=5.0
20
20
  Requires-Dist: nbformat>=4.2
21
21
  Requires-Dist: umap-learn>=0.5
22
- Requires-Dist: pytest>=7
22
+ Provides-Extra: dev
23
+ Requires-Dist: pytest>=7; extra == "dev"
24
+ Requires-Dist: pytest-cov>=4.0; extra == "dev"
25
+ Requires-Dist: ruff>=0.1; extra == "dev"
26
+ Requires-Dist: mypy>=1.0; extra == "dev"
27
+ Provides-Extra: docs
28
+ Requires-Dist: mkdocs>=1.5.0; extra == "docs"
29
+ Requires-Dist: mkdocs-material>=9.0.0; extra == "docs"
30
+ Requires-Dist: mkdocstrings[python]>=0.24.0; extra == "docs"
31
+ Requires-Dist: pymdown-extensions>=10.0; extra == "docs"
23
32
  Dynamic: license-file
24
33
 
25
34
  # **combatlearn**
26
35
 
27
36
  [![Python versions](https://img.shields.io/badge/python-%3E%3D3.10-blue?logo=python)](https://www.python.org/)
28
37
  [![Test](https://github.com/EttoreRocchi/combatlearn/actions/workflows/test.yaml/badge.svg)](https://github.com/EttoreRocchi/combatlearn/actions/workflows/test.yaml)
38
+ [![Documentation](https://readthedocs.org/projects/combatlearn/badge/?version=latest)](https://combatlearn.readthedocs.io)
29
39
  [![PyPI Downloads](https://static.pepy.tech/badge/combatlearn)](https://pepy.tech/projects/combatlearn)
30
40
  [![PyPI Version](https://img.shields.io/pypi/v/combatlearn?cacheSeconds=300)](https://pypi.org/project/combatlearn/)
31
41
  [![License](https://img.shields.io/github/license/EttoreRocchi/combatlearn)](https://github.com/EttoreRocchi/combatlearn/blob/main/LICENSE)
@@ -95,6 +105,17 @@ print(f"Best CV AUROC: {grid.best_score_:.3f}")
95
105
 
96
106
  For a full example of how to use **combatlearn** see the [notebook demo](https://github.com/EttoreRocchi/combatlearn/blob/main/docs/demo/combatlearn_demo.ipynb)
97
107
 
108
+ ## Documentation
109
+
110
+ **Full documentation is available at [combatlearn.readthedocs.io](https://combatlearn.readthedocs.io)**
111
+
112
+ The documentation includes:
113
+ - [Installation Guide](https://combatlearn.readthedocs.io/en/latest/installation/)
114
+ - [Quick Start Tutorial](https://combatlearn.readthedocs.io/en/latest/quickstart/)
115
+ - [User Guide](https://combatlearn.readthedocs.io/en/latest/user-guide/overview/)
116
+ - [API Reference](https://combatlearn.readthedocs.io/en/latest/api/)
117
+ - [Examples](https://combatlearn.readthedocs.io/en/latest/examples/basic-usage/)
118
+
98
119
  ## `ComBat` parameters
99
120
 
100
121
  The following section provides a detailed explanation of all parameters available in the scikit-learn-compatible `ComBat` class.
@@ -2,6 +2,7 @@
2
2
 
3
3
  [![Python versions](https://img.shields.io/badge/python-%3E%3D3.10-blue?logo=python)](https://www.python.org/)
4
4
  [![Test](https://github.com/EttoreRocchi/combatlearn/actions/workflows/test.yaml/badge.svg)](https://github.com/EttoreRocchi/combatlearn/actions/workflows/test.yaml)
5
+ [![Documentation](https://readthedocs.org/projects/combatlearn/badge/?version=latest)](https://combatlearn.readthedocs.io)
5
6
  [![PyPI Downloads](https://static.pepy.tech/badge/combatlearn)](https://pepy.tech/projects/combatlearn)
6
7
  [![PyPI Version](https://img.shields.io/pypi/v/combatlearn?cacheSeconds=300)](https://pypi.org/project/combatlearn/)
7
8
  [![License](https://img.shields.io/github/license/EttoreRocchi/combatlearn)](https://github.com/EttoreRocchi/combatlearn/blob/main/LICENSE)
@@ -71,6 +72,17 @@ print(f"Best CV AUROC: {grid.best_score_:.3f}")
71
72
 
72
73
  For a full example of how to use **combatlearn** see the [notebook demo](https://github.com/EttoreRocchi/combatlearn/blob/main/docs/demo/combatlearn_demo.ipynb)
73
74
 
75
+ ## Documentation
76
+
77
+ **Full documentation is available at [combatlearn.readthedocs.io](https://combatlearn.readthedocs.io)**
78
+
79
+ The documentation includes:
80
+ - [Installation Guide](https://combatlearn.readthedocs.io/en/latest/installation/)
81
+ - [Quick Start Tutorial](https://combatlearn.readthedocs.io/en/latest/quickstart/)
82
+ - [User Guide](https://combatlearn.readthedocs.io/en/latest/user-guide/overview/)
83
+ - [API Reference](https://combatlearn.readthedocs.io/en/latest/api/)
84
+ - [Examples](https://combatlearn.readthedocs.io/en/latest/examples/basic-usage/)
85
+
74
86
  ## `ComBat` parameters
75
87
 
76
88
  The following section provides a detailed explanation of all parameters available in the scikit-learn-compatible `ComBat` class.
@@ -0,0 +1,5 @@
1
+ from .combat import ComBat
2
+
3
+ __all__ = ["ComBat"]
4
+ __version__ = "1.0.0"
5
+ __author__ = "Ettore Rocchi"
@@ -14,28 +14,17 @@ import numpy as np
14
14
  import numpy.linalg as la
15
15
  import pandas as pd
16
16
  from sklearn.base import BaseEstimator, TransformerMixin
17
- from sklearn.utils.validation import check_is_fitted
18
17
  from sklearn.decomposition import PCA
19
18
  from sklearn.manifold import TSNE
19
+ import matplotlib
20
20
  import matplotlib.pyplot as plt
21
- from typing import Literal, Optional, Union, Dict, Tuple, Any, cast
21
+ import matplotlib.colors as mcolors
22
+ from typing import Literal, Optional, Union, Dict, Tuple, Any
22
23
  import numpy.typing as npt
23
24
  import warnings
24
-
25
- try:
26
- import umap
27
- UMAP_AVAILABLE = True
28
- except ImportError:
29
- UMAP_AVAILABLE = False
30
-
31
- try:
32
- import plotly.graph_objects as go
33
- from plotly.subplots import make_subplots
34
- PLOTLY_AVAILABLE = True
35
- except ImportError:
36
- PLOTLY_AVAILABLE = False
37
-
38
- __author__ = "Ettore Rocchi"
25
+ import umap
26
+ import plotly.graph_objects as go
27
+ from plotly.subplots import make_subplots
39
28
 
40
29
  ArrayLike = Union[pd.DataFrame, pd.Series, npt.NDArray[Any]]
41
30
  FloatArray = npt.NDArray[np.float64]
@@ -57,8 +46,9 @@ class ComBatModel:
57
46
  ignoring the variance (`delta_star`).
58
47
  reference_batch : str, optional
59
48
  If specified, the batch level to use as reference.
60
- covbat_cov_thresh : float, default=0.9
61
- CovBat: cumulative explained variance threshold for PCA.
49
+ covbat_cov_thresh : float or int, default=0.9
50
+ CovBat: cumulative variance threshold (0, 1] to retain PCs, or
51
+ integer >= 1 specifying the number of components directly.
62
52
  eps : float, default=1e-8
63
53
  Numerical jitter to avoid division-by-zero.
64
54
  """
@@ -66,19 +56,19 @@ class ComBatModel:
66
56
  def __init__(
67
57
  self,
68
58
  *,
69
- method: Literal["johnson", "fortin", "chen"] = "johnson",
59
+ method: Literal["johnson", "fortin", "chen"] = "johnson",
70
60
  parametric: bool = True,
71
61
  mean_only: bool = False,
72
62
  reference_batch: Optional[str] = None,
73
63
  eps: float = 1e-8,
74
- covbat_cov_thresh: float = 0.9,
64
+ covbat_cov_thresh: Union[float, int] = 0.9,
75
65
  ) -> None:
76
66
  self.method: str = method
77
67
  self.parametric: bool = parametric
78
68
  self.mean_only: bool = bool(mean_only)
79
69
  self.reference_batch: Optional[str] = reference_batch
80
70
  self.eps: float = float(eps)
81
- self.covbat_cov_thresh: float = float(covbat_cov_thresh)
71
+ self.covbat_cov_thresh: Union[float, int] = covbat_cov_thresh
82
72
 
83
73
  self._batch_levels: pd.Index
84
74
  self._grand_mean: pd.Series
@@ -95,9 +85,16 @@ class ComBatModel:
95
85
  self._batch_levels_pc: pd.Index
96
86
  self._pc_gamma_star: FloatArray
97
87
  self._pc_delta_star: FloatArray
98
-
99
- if not (0.0 < self.covbat_cov_thresh <= 1.0):
100
- raise ValueError("covbat_cov_thresh must be in (0, 1].")
88
+
89
+ # Validate covbat_cov_thresh
90
+ if isinstance(self.covbat_cov_thresh, float):
91
+ if not (0.0 < self.covbat_cov_thresh <= 1.0):
92
+ raise ValueError("covbat_cov_thresh must be in (0, 1] when float.")
93
+ elif isinstance(self.covbat_cov_thresh, int):
94
+ if self.covbat_cov_thresh < 1:
95
+ raise ValueError("covbat_cov_thresh must be >= 1 when int.")
96
+ else:
97
+ raise TypeError("covbat_cov_thresh must be float or int.")
101
98
 
102
99
  @staticmethod
103
100
  def _as_series(
@@ -335,8 +332,14 @@ class ComBatModel:
335
332
  X_meanvar_adj = self._transform_fortin(X, batch, disc, cont)
336
333
  X_centered = X_meanvar_adj - X_meanvar_adj.mean(axis=0)
337
334
  pca = PCA(svd_solver="full", whiten=False).fit(X_centered)
338
- cumulative = np.cumsum(pca.explained_variance_ratio_)
339
- n_pc = int(np.searchsorted(cumulative, self.covbat_cov_thresh) + 1)
335
+
336
+ # Determine number of components based on threshold type
337
+ if isinstance(self.covbat_cov_thresh, int):
338
+ n_pc = min(self.covbat_cov_thresh, len(pca.explained_variance_ratio_))
339
+ else:
340
+ cumulative = np.cumsum(pca.explained_variance_ratio_)
341
+ n_pc = int(np.searchsorted(cumulative, self.covbat_cov_thresh) + 1)
342
+
340
343
  self._covbat_pca = pca
341
344
  self._covbat_n_pc = n_pc
342
345
 
@@ -487,7 +490,8 @@ class ComBatModel:
487
490
  continuous_covariates: Optional[ArrayLike] = None,
488
491
  ) -> pd.DataFrame:
489
492
  """Transform the data using fitted ComBat parameters."""
490
- check_is_fitted(self, ["_gamma_star"])
493
+ if not hasattr(self, "_gamma_star"):
494
+ raise ValueError("This ComBatModel instance is not fitted yet. Call 'fit' before 'transform'.")
491
495
  if not isinstance(X, pd.DataFrame):
492
496
  X = pd.DataFrame(X)
493
497
  idx = X.index
@@ -599,7 +603,7 @@ class ComBatModel:
599
603
  """Chen transform implementation."""
600
604
  X_meanvar_adj = self._transform_fortin(X, batch, disc, cont)
601
605
  X_centered = X_meanvar_adj - self._covbat_pca.mean_
602
- scores = self._covbat_pca.transform(X_centered.values)
606
+ scores = self._covbat_pca.transform(X_centered)
603
607
  n_pc = self._covbat_n_pc
604
608
  scores_adj = scores.copy()
605
609
 
@@ -638,7 +642,7 @@ class ComBat(BaseEstimator, TransformerMixin):
638
642
  mean_only: bool = False,
639
643
  reference_batch: Optional[str] = None,
640
644
  eps: float = 1e-8,
641
- covbat_cov_thresh: float = 0.9,
645
+ covbat_cov_thresh: Union[float, int] = 0.9,
642
646
  ) -> None:
643
647
  self.batch = batch
644
648
  self.discrete_covariates = discrete_covariates
@@ -758,7 +762,8 @@ class ComBat(BaseEstimator, TransformerMixin):
758
762
  - `'original'`: embedding of original data
759
763
  - `'transformed'`: embedding of ComBat-transformed data
760
764
  """
761
- check_is_fitted(self._model, ["_gamma_star"])
765
+ if not hasattr(self._model, "_gamma_star"):
766
+ raise ValueError("This ComBat instance is not fitted yet. Call 'fit' before 'plot_transformation'.")
762
767
 
763
768
  if n_components not in [2, 3]:
764
769
  raise ValueError(f"n_components must be 2 or 3, got {n_components}")
@@ -767,11 +772,6 @@ class ComBat(BaseEstimator, TransformerMixin):
767
772
  if plot_type not in ['static', 'interactive']:
768
773
  raise ValueError(f"plot_type must be 'static' or 'interactive', got '{plot_type}'")
769
774
 
770
- if reduction_method == 'umap' and not UMAP_AVAILABLE:
771
- raise ImportError("UMAP is not installed. Install with: pip install umap-learn")
772
- if plot_type == 'interactive' and not PLOTLY_AVAILABLE:
773
- raise ImportError("Plotly is not installed. Install with: pip install plotly")
774
-
775
775
  if not isinstance(X, pd.DataFrame):
776
776
  X = pd.DataFrame(X)
777
777
 
@@ -796,8 +796,8 @@ class ComBat(BaseEstimator, TransformerMixin):
796
796
  else:
797
797
  umap_params = {'random_state': 42}
798
798
  umap_params.update(reduction_kwargs)
799
- reducer_orig = umap.UMAP(n_components=n_components, **reduction_kwargs)
800
- reducer_trans = umap.UMAP(n_components=n_components, **reduction_kwargs)
799
+ reducer_orig = umap.UMAP(n_components=n_components, **umap_params)
800
+ reducer_trans = umap.UMAP(n_components=n_components, **umap_params)
801
801
 
802
802
  X_embedded_orig = reducer_orig.fit_transform(X_np)
803
803
  X_embedded_trans = reducer_trans.fit_transform(X_trans_np)
@@ -811,7 +811,7 @@ class ComBat(BaseEstimator, TransformerMixin):
811
811
  else:
812
812
  fig = self._create_interactive_plot(
813
813
  X_embedded_orig, X_embedded_trans, batch_vec,
814
- reduction_method, n_components, title, show_legend
814
+ reduction_method, n_components, cmap, title, show_legend
815
815
  )
816
816
 
817
817
  if return_embeddings:
@@ -844,9 +844,9 @@ class ComBat(BaseEstimator, TransformerMixin):
844
844
  n_batches = len(unique_batches)
845
845
 
846
846
  if n_batches <= 10:
847
- colors = plt.cm.get_cmap(cmap)(np.linspace(0, 1, n_batches))
847
+ colors = matplotlib.colormaps.get_cmap(cmap)(np.linspace(0, 1, n_batches))
848
848
  else:
849
- colors = plt.cm.get_cmap('tab20')(np.linspace(0, 1, n_batches))
849
+ colors = matplotlib.colormaps.get_cmap('tab20')(np.linspace(0, 1, n_batches))
850
850
 
851
851
  if n_components == 2:
852
852
  ax1 = plt.subplot(1, 2, 1)
@@ -930,6 +930,7 @@ class ComBat(BaseEstimator, TransformerMixin):
930
930
  batch_labels: pd.Series,
931
931
  method: str,
932
932
  n_components: int,
933
+ cmap: str,
933
934
  title: Optional[str],
934
935
  show_legend: bool) -> Any:
935
936
  """Create interactive plots using plotly."""
@@ -953,43 +954,69 @@ class ComBat(BaseEstimator, TransformerMixin):
953
954
 
954
955
  unique_batches = batch_labels.drop_duplicates()
955
956
 
957
+ n_batches = len(unique_batches)
958
+ cmap_func = matplotlib.colormaps.get_cmap(cmap)
959
+ color_list = [mcolors.to_hex(cmap_func(i / max(n_batches - 1, 1))) for i in range(n_batches)]
960
+
961
+ batch_to_color = dict(zip(unique_batches, color_list))
962
+
956
963
  for batch in unique_batches:
957
964
  mask = batch_labels == batch
958
965
 
959
966
  if n_components == 2:
960
967
  fig.add_trace(
961
- go.Scatter(x=X_orig[mask, 0], y=X_orig[mask, 1],
962
- mode='markers',
963
- name=f'Batch {batch}',
964
- marker=dict(size=8, line=dict(width=1, color='black')),
965
- showlegend=False),
968
+ go.Scatter(
969
+ x=X_orig[mask, 0], y=X_orig[mask, 1],
970
+ mode='markers',
971
+ name=f'Batch {batch}',
972
+ marker=dict(
973
+ size=8,
974
+ color=batch_to_color[batch],
975
+ line=dict(width=1, color='black')
976
+ ),
977
+ showlegend=False),
966
978
  row=1, col=1
967
979
  )
968
980
 
969
981
  fig.add_trace(
970
- go.Scatter(x=X_trans[mask, 0], y=X_trans[mask, 1],
971
- mode='markers',
972
- name=f'Batch {batch}',
973
- marker=dict(size=8, line=dict(width=1, color='black')),
974
- showlegend=show_legend),
982
+ go.Scatter(
983
+ x=X_trans[mask, 0], y=X_trans[mask, 1],
984
+ mode='markers',
985
+ name=f'Batch {batch}',
986
+ marker=dict(
987
+ size=8,
988
+ color=batch_to_color[batch],
989
+ line=dict(width=1, color='black')
990
+ ),
991
+ showlegend=show_legend),
975
992
  row=1, col=2
976
993
  )
977
994
  else:
978
995
  fig.add_trace(
979
- go.Scatter3d(x=X_orig[mask, 0], y=X_orig[mask, 1], z=X_orig[mask, 2],
980
- mode='markers',
981
- name=f'Batch {batch}',
982
- marker=dict(size=5, line=dict(width=0.5, color='black')),
983
- showlegend=False),
996
+ go.Scatter3d(
997
+ x=X_orig[mask, 0], y=X_orig[mask, 1], z=X_orig[mask, 2],
998
+ mode='markers',
999
+ name=f'Batch {batch}',
1000
+ marker=dict(
1001
+ size=5,
1002
+ color=batch_to_color[batch],
1003
+ line=dict(width=0.5, color='black')
1004
+ ),
1005
+ showlegend=False),
984
1006
  row=1, col=1
985
1007
  )
986
1008
 
987
1009
  fig.add_trace(
988
- go.Scatter3d(x=X_trans[mask, 0], y=X_trans[mask, 1], z=X_trans[mask, 2],
989
- mode='markers',
990
- name=f'Batch {batch}',
991
- marker=dict(size=5, line=dict(width=0.5, color='black')),
992
- showlegend=show_legend),
1010
+ go.Scatter3d(
1011
+ x=X_trans[mask, 0], y=X_trans[mask, 1], z=X_trans[mask, 2],
1012
+ mode='markers',
1013
+ name=f'Batch {batch}',
1014
+ marker=dict(
1015
+ size=5,
1016
+ color=batch_to_color[batch],
1017
+ line=dict(width=0.5, color='black')
1018
+ ),
1019
+ showlegend=show_legend),
993
1020
  row=1, col=2
994
1021
  )
995
1022
 
@@ -1,9 +1,9 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: combatlearn
3
- Version: 0.2.1
3
+ Version: 1.0.0
4
4
  Summary: Batch-effect harmonization for machine learning frameworks.
5
5
  Author-email: Ettore Rocchi <ettoreroc@gmail.com>
6
- License-Expression: MIT
6
+ License: MIT
7
7
  Keywords: machine-learning,harmonization,combat,preprocessing
8
8
  Classifier: Development Status :: 3 - Alpha
9
9
  Classifier: Intended Audience :: Science/Research
@@ -19,13 +19,23 @@ Requires-Dist: matplotlib>=3.4
19
19
  Requires-Dist: plotly>=5.0
20
20
  Requires-Dist: nbformat>=4.2
21
21
  Requires-Dist: umap-learn>=0.5
22
- Requires-Dist: pytest>=7
22
+ Provides-Extra: dev
23
+ Requires-Dist: pytest>=7; extra == "dev"
24
+ Requires-Dist: pytest-cov>=4.0; extra == "dev"
25
+ Requires-Dist: ruff>=0.1; extra == "dev"
26
+ Requires-Dist: mypy>=1.0; extra == "dev"
27
+ Provides-Extra: docs
28
+ Requires-Dist: mkdocs>=1.5.0; extra == "docs"
29
+ Requires-Dist: mkdocs-material>=9.0.0; extra == "docs"
30
+ Requires-Dist: mkdocstrings[python]>=0.24.0; extra == "docs"
31
+ Requires-Dist: pymdown-extensions>=10.0; extra == "docs"
23
32
  Dynamic: license-file
24
33
 
25
34
  # **combatlearn**
26
35
 
27
36
  [![Python versions](https://img.shields.io/badge/python-%3E%3D3.10-blue?logo=python)](https://www.python.org/)
28
37
  [![Test](https://github.com/EttoreRocchi/combatlearn/actions/workflows/test.yaml/badge.svg)](https://github.com/EttoreRocchi/combatlearn/actions/workflows/test.yaml)
38
+ [![Documentation](https://readthedocs.org/projects/combatlearn/badge/?version=latest)](https://combatlearn.readthedocs.io)
29
39
  [![PyPI Downloads](https://static.pepy.tech/badge/combatlearn)](https://pepy.tech/projects/combatlearn)
30
40
  [![PyPI Version](https://img.shields.io/pypi/v/combatlearn?cacheSeconds=300)](https://pypi.org/project/combatlearn/)
31
41
  [![License](https://img.shields.io/github/license/EttoreRocchi/combatlearn)](https://github.com/EttoreRocchi/combatlearn/blob/main/LICENSE)
@@ -95,6 +105,17 @@ print(f"Best CV AUROC: {grid.best_score_:.3f}")
95
105
 
96
106
  For a full example of how to use **combatlearn** see the [notebook demo](https://github.com/EttoreRocchi/combatlearn/blob/main/docs/demo/combatlearn_demo.ipynb)
97
107
 
108
+ ## Documentation
109
+
110
+ **Full documentation is available at [combatlearn.readthedocs.io](https://combatlearn.readthedocs.io)**
111
+
112
+ The documentation includes:
113
+ - [Installation Guide](https://combatlearn.readthedocs.io/en/latest/installation/)
114
+ - [Quick Start Tutorial](https://combatlearn.readthedocs.io/en/latest/quickstart/)
115
+ - [User Guide](https://combatlearn.readthedocs.io/en/latest/user-guide/overview/)
116
+ - [API Reference](https://combatlearn.readthedocs.io/en/latest/api/)
117
+ - [Examples](https://combatlearn.readthedocs.io/en/latest/examples/basic-usage/)
118
+
98
119
  ## `ComBat` parameters
99
120
 
100
121
  The following section provides a detailed explanation of all parameters available in the scikit-learn-compatible `ComBat` class.
@@ -1,6 +1,7 @@
1
1
  LICENSE
2
2
  README.md
3
3
  pyproject.toml
4
+ requirements.txt
4
5
  combatlearn/__init__.py
5
6
  combatlearn/combat.py
6
7
  combatlearn.egg-info/PKG-INFO
@@ -0,0 +1,19 @@
1
+ pandas>=1.3
2
+ numpy>=1.21
3
+ scikit-learn>=1.2
4
+ matplotlib>=3.4
5
+ plotly>=5.0
6
+ nbformat>=4.2
7
+ umap-learn>=0.5
8
+
9
+ [dev]
10
+ pytest>=7
11
+ pytest-cov>=4.0
12
+ ruff>=0.1
13
+ mypy>=1.0
14
+
15
+ [docs]
16
+ mkdocs>=1.5.0
17
+ mkdocs-material>=9.0.0
18
+ mkdocstrings[python]>=0.24.0
19
+ pymdown-extensions>=10.0
@@ -4,21 +4,11 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "combatlearn"
7
- version = "0.2.1"
7
+ dynamic = ["version", "dependencies"]
8
8
  description = "Batch-effect harmonization for machine learning frameworks."
9
9
  authors = [{name="Ettore Rocchi", email="ettoreroc@gmail.com"}]
10
10
  requires-python = ">=3.10"
11
- dependencies = [
12
- "pandas>=1.3",
13
- "numpy>=1.21",
14
- "scikit-learn>=1.2",
15
- "matplotlib>=3.4",
16
- "plotly>=5.0",
17
- "nbformat>=4.2",
18
- "umap-learn>=0.5",
19
- "pytest>=7"
20
- ]
21
- license = "MIT"
11
+ license = {text = "MIT"}
22
12
  readme = {file="README.md", content-type="text/markdown"}
23
13
  keywords = [
24
14
  "machine-learning",
@@ -33,6 +23,24 @@ classifiers = [
33
23
  "Programming Language :: Python :: 3",
34
24
  ]
35
25
 
26
+ [project.optional-dependencies]
27
+ dev = [
28
+ "pytest>=7",
29
+ "pytest-cov>=4.0",
30
+ "ruff>=0.1",
31
+ "mypy>=1.0",
32
+ ]
33
+ docs = [
34
+ "mkdocs>=1.5.0",
35
+ "mkdocs-material>=9.0.0",
36
+ "mkdocstrings[python]>=0.24.0",
37
+ "pymdown-extensions>=10.0",
38
+ ]
39
+
36
40
  [tool.setuptools.packages.find]
37
41
  where = ["."]
38
42
  include = ["combatlearn*"]
43
+
44
+ [tool.setuptools.dynamic]
45
+ version = {attr = "combatlearn.__version__"}
46
+ dependencies = {file = ["requirements.txt"]}
@@ -4,5 +4,4 @@ scikit-learn>=1.2
4
4
  matplotlib>=3.4
5
5
  plotly>=5.0
6
6
  nbformat>=4.2
7
- umap-learn>=0.5
8
- pytest>=7
7
+ umap-learn>=0.5
@@ -0,0 +1,379 @@
1
+ import numpy as np
2
+ import pandas as pd
3
+ import pytest
4
+ from sklearn.pipeline import Pipeline
5
+ from sklearn.base import clone
6
+ from sklearn.preprocessing import StandardScaler
7
+ from combatlearn import ComBat
8
+ from combatlearn.combat import ComBatModel
9
+ from utils import simulate_data, simulate_covariate_data
10
+
11
+
12
+ def test_transform_without_fit_raises():
13
+ """
14
+ Test that `transform` raises a `ValueError` if not fitted.
15
+ """
16
+ X, batch = simulate_data()
17
+ model = ComBatModel()
18
+ with pytest.raises(ValueError, match="not fitted"):
19
+ model.transform(X, batch=batch)
20
+
21
+
22
+ def test_unseen_batch_raises_value_error():
23
+ """
24
+ Test that unseen batch raises a `ValueError`.
25
+ """
26
+ X, batch = simulate_data()
27
+ model = ComBatModel().fit(X, batch=batch)
28
+ new_batch = pd.Series(["Z"] * len(batch), index=batch.index)
29
+ with pytest.raises(ValueError):
30
+ model.transform(X, batch=new_batch)
31
+
32
+
33
+ def test_single_sample_batch_error():
34
+ """
35
+ Test that a single sample batch raises a `ValueError`.
36
+ """
37
+ X, batch = simulate_data()
38
+ batch.iloc[0] = "single"
39
+ with pytest.raises(ValueError):
40
+ ComBatModel().fit(X, batch=batch)
41
+
42
+
43
+ @pytest.mark.parametrize("method", ["johnson", "fortin", "chen"])
44
+ def test_dtypes_preserved(method):
45
+ """All output columns must remain floating dtypes after correction."""
46
+ if method == "johnson":
47
+ X, batch = simulate_data()
48
+ extra = {}
49
+ else: # fortin or chen
50
+ X, batch, disc, cont = simulate_covariate_data()
51
+ extra = dict(discrete_covariates=disc, continuous_covariates=cont)
52
+
53
+ X_corr = ComBat(batch=batch, method=method, **extra).fit_transform(X)
54
+ assert all(np.issubdtype(dt, np.floating) for dt in X_corr.dtypes)
55
+
56
+ def test_wrapper_clone_and_pipeline():
57
+ """
58
+ Test `ComBat` wrapper can be cloned and used in a `Pipeline`.
59
+ """
60
+ X, batch = simulate_data()
61
+ wrapper = ComBat(batch=batch, parametric=True)
62
+ pipe = Pipeline([
63
+ ("scaler", StandardScaler()),
64
+ ("combat", wrapper),
65
+ ])
66
+ X_corr = pipe.fit_transform(X)
67
+ pipe_clone: Pipeline = clone(pipe)
68
+ X_corr2 = pipe_clone.fit_transform(X)
69
+ np.testing.assert_allclose(X_corr, X_corr2, rtol=1e-5, atol=1e-5)
70
+
71
+
72
+ @pytest.mark.parametrize("method", ["johnson", "fortin", "chen"])
73
+ def test_no_nan_or_inf_in_output(method):
74
+ """`ComBat` must not introduce NaN or Inf values, for any backend."""
75
+ if method == "johnson":
76
+ X, batch = simulate_data()
77
+ extra = {}
78
+ else: # fortin or chen
79
+ X, batch, disc, cont = simulate_covariate_data()
80
+ extra = dict(discrete_covariates=disc, continuous_covariates=cont)
81
+
82
+ X_corr = ComBat(batch=batch, method=method, **extra).fit_transform(X)
83
+ assert not np.isnan(X_corr.values).any()
84
+ assert not np.isinf(X_corr.values).any()
85
+
86
+
87
+ @pytest.mark.parametrize("method", ["johnson", "fortin", "chen"])
88
+ def test_shape_preserved(method):
89
+ """The (n_samples, n_features) shape must be identical pre- and post-ComBat."""
90
+ if method == "johnson":
91
+ X, batch = simulate_data()
92
+ combat = ComBat(batch=batch, method=method).fit(X)
93
+ elif method in ["fortin", "chen"]:
94
+ X, batch, disc, cont = simulate_covariate_data()
95
+ combat = ComBat(
96
+ batch=batch,
97
+ discrete_covariates=disc,
98
+ continuous_covariates=cont,
99
+ method=method,
100
+ ).fit(X)
101
+
102
+ X_corr = combat.transform(X)
103
+ assert X_corr.shape == X.shape
104
+
105
+
106
+ def test_johnson_print_warning():
107
+ """
108
+ Test that a warning is printed when using the Johnson method.
109
+ """
110
+ X, batch, disc, cont = simulate_covariate_data()
111
+ with pytest.warns(Warning, match="Covariates are ignored when using method='johnson'."):
112
+ _ = ComBat(
113
+ batch=batch,
114
+ discrete_covariates=disc,
115
+ continuous_covariates=cont,
116
+ method="johnson",
117
+ ).fit(X)
118
+
119
+
120
+ @pytest.mark.parametrize("method", ["johnson", "fortin", "chen"])
121
+ def test_reference_batch_samples_unchanged(method):
122
+ """
123
+ Samples belonging to the reference batch must come out *numerically identical*
124
+ (within floating-point jitter) after correction.
125
+ """
126
+ if method == "johnson":
127
+ X, batch = simulate_data()
128
+ extra = {}
129
+ elif method in ["fortin", "chen"]:
130
+ X, batch, disc, cont = simulate_covariate_data()
131
+ extra = dict(discrete_covariates=disc, continuous_covariates=cont)
132
+
133
+ ref_batch = batch.iloc[0]
134
+ combat = ComBat(batch=batch, method=method,
135
+ reference_batch=ref_batch, **extra).fit(X)
136
+ X_corr = combat.transform(X)
137
+
138
+ mask = batch == ref_batch
139
+ np.testing.assert_allclose(X_corr.loc[mask].values,
140
+ X.loc[mask].values,
141
+ rtol=0, atol=1e-10)
142
+
143
+
144
+ def test_reference_batch_missing_raises():
145
+ """
146
+ Asking for a reference batch that doesn't exist should fail.
147
+ """
148
+ X, batch = simulate_data()
149
+ with pytest.raises(ValueError, match="not present"):
150
+ ComBat(batch=batch, reference_batch="DOES_NOT_EXIST").fit(X)
151
+
152
+
153
+ @pytest.mark.parametrize("parametric", [True, False])
154
+ @pytest.mark.parametrize("method", ["johnson", "fortin", "chen"])
155
+ def test_parametric_vs_nonparametric(parametric, method):
156
+ """
157
+ Test both parametric and non-parametric modes work without errors.
158
+ """
159
+ if method == "johnson":
160
+ X, batch = simulate_data()
161
+ extra = {}
162
+ else:
163
+ X, batch, disc, cont = simulate_covariate_data()
164
+ extra = dict(discrete_covariates=disc, continuous_covariates=cont)
165
+
166
+ combat = ComBat(batch=batch, method=method, parametric=parametric, **extra)
167
+ X_corr = combat.fit_transform(X)
168
+ assert X_corr.shape == X.shape
169
+ assert not np.isnan(X_corr.values).any()
170
+
171
+
172
+ @pytest.mark.parametrize("mean_only", [True, False])
173
+ @pytest.mark.parametrize("method", ["johnson", "fortin", "chen"])
174
+ def test_mean_only_mode(mean_only, method):
175
+ """
176
+ Test mean_only mode works for all methods.
177
+ """
178
+ if method == "johnson":
179
+ X, batch = simulate_data()
180
+ extra = {}
181
+ else:
182
+ X, batch, disc, cont = simulate_covariate_data()
183
+ extra = dict(discrete_covariates=disc, continuous_covariates=cont)
184
+
185
+ combat = ComBat(batch=batch, method=method, mean_only=mean_only, **extra)
186
+ X_corr = combat.fit_transform(X)
187
+ assert X_corr.shape == X.shape
188
+ assert not np.isnan(X_corr.values).any()
189
+
190
+
191
+ def test_covbat_cov_thresh_as_float():
192
+ """
193
+ Test CovBat with covbat_cov_thresh as float (cumulative variance threshold).
194
+ """
195
+ X, batch, disc, cont = simulate_covariate_data()
196
+ combat = ComBat(
197
+ batch=batch,
198
+ discrete_covariates=disc,
199
+ continuous_covariates=cont,
200
+ method="chen",
201
+ covbat_cov_thresh=0.95,
202
+ )
203
+ X_corr = combat.fit_transform(X)
204
+ assert X_corr.shape == X.shape
205
+ assert not np.isnan(X_corr.values).any()
206
+
207
+
208
+ def test_covbat_cov_thresh_as_int():
209
+ """
210
+ Test CovBat with covbat_cov_thresh as int (number of components).
211
+ """
212
+ X, batch, disc, cont = simulate_covariate_data()
213
+ n_components = 10
214
+ combat = ComBat(
215
+ batch=batch,
216
+ discrete_covariates=disc,
217
+ continuous_covariates=cont,
218
+ method="chen",
219
+ covbat_cov_thresh=n_components,
220
+ )
221
+ X_corr = combat.fit_transform(X)
222
+ assert X_corr.shape == X.shape
223
+ assert not np.isnan(X_corr.values).any()
224
+ assert combat._model._covbat_n_pc == n_components
225
+
226
+
227
+ def test_covbat_cov_thresh_invalid_float_raises():
228
+ """
229
+ Test that invalid float values for covbat_cov_thresh raise ValueError.
230
+ """
231
+ with pytest.raises(ValueError, match="must be in \\(0, 1\\]"):
232
+ ComBatModel(covbat_cov_thresh=1.5)
233
+
234
+ with pytest.raises(ValueError, match="must be in \\(0, 1\\]"):
235
+ ComBatModel(covbat_cov_thresh=0.0)
236
+
237
+
238
+ def test_covbat_cov_thresh_invalid_int_raises():
239
+ """
240
+ Test that invalid int values for covbat_cov_thresh raise ValueError.
241
+ """
242
+ with pytest.raises(ValueError, match="must be >= 1"):
243
+ ComBatModel(covbat_cov_thresh=0)
244
+
245
+
246
+ def test_covbat_cov_thresh_invalid_type_raises():
247
+ """
248
+ Test that invalid types for covbat_cov_thresh raise TypeError.
249
+ """
250
+ with pytest.raises(TypeError, match="must be float or int"):
251
+ ComBatModel(covbat_cov_thresh="invalid")
252
+
253
+
254
+ @pytest.mark.parametrize("method", ["johnson", "fortin", "chen"])
255
+ def test_index_preserved(method):
256
+ """
257
+ Test that the index is preserved after transformation.
258
+ """
259
+ if method == "johnson":
260
+ X, batch = simulate_data()
261
+ extra = {}
262
+ else:
263
+ X, batch, disc, cont = simulate_covariate_data()
264
+ extra = dict(discrete_covariates=disc, continuous_covariates=cont)
265
+
266
+ custom_index = pd.Index([f"sample_{i}" for i in range(len(X))])
267
+ X.index = custom_index
268
+ batch.index = custom_index
269
+ if method != "johnson":
270
+ disc.index = custom_index
271
+ cont.index = custom_index
272
+
273
+ combat = ComBat(batch=batch, method=method, **extra)
274
+ X_corr = combat.fit_transform(X)
275
+ assert X_corr.index.equals(custom_index)
276
+
277
+
278
+ @pytest.mark.parametrize("method", ["johnson", "fortin", "chen"])
279
+ def test_column_names_preserved(method):
280
+ """
281
+ Test that column names are preserved after transformation.
282
+ """
283
+ if method == "johnson":
284
+ X, batch = simulate_data()
285
+ extra = {}
286
+ else:
287
+ X, batch, disc, cont = simulate_covariate_data()
288
+ extra = dict(discrete_covariates=disc, continuous_covariates=cont)
289
+
290
+ custom_columns = [f"feature_{i}" for i in range(X.shape[1])]
291
+ X.columns = custom_columns
292
+
293
+ combat = ComBat(batch=batch, method=method, **extra)
294
+ X_corr = combat.fit_transform(X)
295
+ assert list(X_corr.columns) == custom_columns
296
+
297
+
298
+ def test_plot_transformation_static_2d():
299
+ """
300
+ Test plot_transformation with static 2D PCA visualization.
301
+ """
302
+ X, batch = simulate_data(n_samples=100, n_features=20)
303
+ combat = ComBat(batch=batch, method="johnson").fit(X)
304
+
305
+ fig = combat.plot_transformation(
306
+ X,
307
+ reduction_method='pca',
308
+ n_components=2,
309
+ plot_type='static'
310
+ )
311
+ assert fig is not None
312
+
313
+
314
+ def test_plot_transformation_static_3d():
315
+ """
316
+ Test plot_transformation with static 3D PCA visualization.
317
+ """
318
+ X, batch = simulate_data(n_samples=100, n_features=20)
319
+ combat = ComBat(batch=batch, method="johnson").fit(X)
320
+
321
+ fig = combat.plot_transformation(
322
+ X,
323
+ reduction_method='pca',
324
+ n_components=3,
325
+ plot_type='static'
326
+ )
327
+ assert fig is not None
328
+
329
+
330
+ def test_plot_transformation_return_embeddings():
331
+ """
332
+ Test that plot_transformation can return embeddings.
333
+ """
334
+ X, batch = simulate_data(n_samples=100, n_features=20)
335
+ combat = ComBat(batch=batch, method="johnson").fit(X)
336
+
337
+ fig, embeddings = combat.plot_transformation(
338
+ X,
339
+ reduction_method='pca',
340
+ n_components=2,
341
+ plot_type='static',
342
+ return_embeddings=True
343
+ )
344
+ assert fig is not None
345
+ assert 'original' in embeddings
346
+ assert 'transformed' in embeddings
347
+ assert embeddings['original'].shape == (100, 2)
348
+ assert embeddings['transformed'].shape == (100, 2)
349
+
350
+
351
+ def test_plot_transformation_invalid_method_raises():
352
+ """
353
+ Test that invalid reduction_method raises ValueError.
354
+ """
355
+ X, batch = simulate_data(n_samples=100, n_features=20)
356
+ combat = ComBat(batch=batch, method="johnson").fit(X)
357
+
358
+ with pytest.raises(ValueError, match="reduction_method must be"):
359
+ combat.plot_transformation(X, reduction_method='invalid')
360
+
361
+
362
+ def test_plot_transformation_invalid_n_components_raises():
363
+ """
364
+ Test that invalid n_components raises ValueError.
365
+ """
366
+ X, batch = simulate_data(n_samples=100, n_features=20)
367
+ combat = ComBat(batch=batch, method="johnson").fit(X)
368
+
369
+ with pytest.raises(ValueError, match="n_components must be 2 or 3"):
370
+ combat.plot_transformation(X, n_components=4)
371
+
372
+
373
+ def test_invalid_method_raises():
374
+ """
375
+ Test that an invalid method raises ValueError.
376
+ """
377
+ X, batch = simulate_data()
378
+ with pytest.raises(ValueError, match="method must be"):
379
+ ComBatModel(method="invalid").fit(X, batch=batch)
@@ -1,4 +0,0 @@
1
- from .combat import ComBatModel, ComBat
2
-
3
- __all__ = ["ComBatModel", "ComBat"]
4
- __version__ = "0.2.1"
@@ -1,150 +0,0 @@
1
- import numpy as np
2
- import pandas as pd
3
- import pytest
4
- from sklearn.pipeline import Pipeline
5
- from sklearn.base import clone
6
- from sklearn.preprocessing import StandardScaler
7
- from sklearn.exceptions import NotFittedError
8
- from combatlearn import ComBatModel, ComBat
9
- from utils import simulate_data, simulate_covariate_data
10
-
11
-
12
- def test_transform_without_fit_raises():
13
- """
14
- Test that `transform` raises a `NotFittedError` if not fitted.
15
- """
16
- X, batch = simulate_data()
17
- model = ComBatModel()
18
- with pytest.raises(NotFittedError):
19
- model.transform(X, batch=batch)
20
-
21
-
22
- def test_unseen_batch_raises_value_error():
23
- """
24
- Test that unseen batch raises a `ValueError`.
25
- """
26
- X, batch = simulate_data()
27
- model = ComBatModel().fit(X, batch=batch)
28
- new_batch = pd.Series(["Z"] * len(batch), index=batch.index)
29
- with pytest.raises(ValueError):
30
- model.transform(X, batch=new_batch)
31
-
32
-
33
- def test_single_sample_batch_error():
34
- """
35
- Test that a single sample batch raises a `ValueError`.
36
- """
37
- X, batch = simulate_data()
38
- batch.iloc[0] = "single"
39
- with pytest.raises(ValueError):
40
- ComBatModel().fit(X, batch=batch)
41
-
42
-
43
- @pytest.mark.parametrize("method", ["johnson", "fortin", "chen"])
44
- def test_dtypes_preserved(method):
45
- """All output columns must remain floating dtypes after correction."""
46
- if method == "johnson":
47
- X, batch = simulate_data()
48
- extra = {}
49
- else: # fortin or chen
50
- X, batch, disc, cont = simulate_covariate_data()
51
- extra = dict(discrete_covariates=disc, continuous_covariates=cont)
52
-
53
- X_corr = ComBat(batch=batch, method=method, **extra).fit_transform(X)
54
- assert all(np.issubdtype(dt, np.floating) for dt in X_corr.dtypes)
55
-
56
- def test_wrapper_clone_and_pipeline():
57
- """
58
- Test `ComBat` wrapper can be cloned and used in a `Pipeline`.
59
- """
60
- X, batch = simulate_data()
61
- wrapper = ComBat(batch=batch, parametric=True)
62
- pipe = Pipeline([
63
- ("scaler", StandardScaler()),
64
- ("combat", wrapper),
65
- ])
66
- X_corr = pipe.fit_transform(X)
67
- pipe_clone: Pipeline = clone(pipe)
68
- X_corr2 = pipe_clone.fit_transform(X)
69
- np.testing.assert_allclose(X_corr, X_corr2, rtol=1e-5, atol=1e-5)
70
-
71
-
72
- @pytest.mark.parametrize("method", ["johnson", "fortin", "chen"])
73
- def test_no_nan_or_inf_in_output(method):
74
- """`ComBat` must not introduce NaN or Inf values, for any backend."""
75
- if method == "johnson":
76
- X, batch = simulate_data()
77
- extra = {}
78
- else: # fortin or chen
79
- X, batch, disc, cont = simulate_covariate_data()
80
- extra = dict(discrete_covariates=disc, continuous_covariates=cont)
81
-
82
- X_corr = ComBat(batch=batch, method=method, **extra).fit_transform(X)
83
- assert not np.isnan(X_corr.values).any()
84
- assert not np.isinf(X_corr.values).any()
85
-
86
-
87
- @pytest.mark.parametrize("method", ["johnson", "fortin", "chen"])
88
- def test_shape_preserved(method):
89
- """The (n_samples, n_features) shape must be identical pre- and post-ComBat."""
90
- if method == "johnson":
91
- X, batch = simulate_data()
92
- combat = ComBat(batch=batch, method=method).fit(X)
93
- elif method in ["fortin", "chen"]:
94
- X, batch, disc, cont = simulate_covariate_data()
95
- combat = ComBat(
96
- batch=batch,
97
- discrete_covariates=disc,
98
- continuous_covariates=cont,
99
- method=method,
100
- ).fit(X)
101
-
102
- X_corr = combat.transform(X)
103
- assert X_corr.shape == X.shape
104
-
105
-
106
- def test_johnson_print_warning():
107
- """
108
- Test that a warning is printed when using the Johnson method.
109
- """
110
- X, batch, disc, cont = simulate_covariate_data()
111
- with pytest.warns(Warning, match="Covariates are ignored when using method='johnson'."):
112
- _ = ComBat(
113
- batch=batch,
114
- discrete_covariates=disc,
115
- continuous_covariates=cont,
116
- method="johnson",
117
- ).fit(X)
118
-
119
-
120
- @pytest.mark.parametrize("method", ["johnson", "fortin", "chen"])
121
- def test_reference_batch_samples_unchanged(method):
122
- """
123
- Samples belonging to the reference batch must come out *numerically identical*
124
- (within floating-point jitter) after correction.
125
- """
126
- if method == "johnson":
127
- X, batch = simulate_data()
128
- extra = {}
129
- elif method in ["fortin", "chen"]:
130
- X, batch, disc, cont = simulate_covariate_data()
131
- extra = dict(discrete_covariates=disc, continuous_covariates=cont)
132
-
133
- ref_batch = batch.iloc[0]
134
- combat = ComBat(batch=batch, method=method,
135
- reference_batch=ref_batch, **extra).fit(X)
136
- X_corr = combat.transform(X)
137
-
138
- mask = batch == ref_batch
139
- np.testing.assert_allclose(X_corr.loc[mask].values,
140
- X.loc[mask].values,
141
- rtol=0, atol=1e-10)
142
-
143
-
144
- def test_reference_batch_missing_raises():
145
- """
146
- Asking for a reference batch that doesn't exist should fail.
147
- """
148
- X, batch = simulate_data()
149
- with pytest.raises(ValueError, match="not present"):
150
- ComBat(batch=batch, reference_batch="DOES_NOT_EXIST").fit(X)
File without changes
File without changes