combatlearn 0.2.1__tar.gz → 1.0.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {combatlearn-0.2.1 → combatlearn-1.0.0}/PKG-INFO +24 -3
- {combatlearn-0.2.1 → combatlearn-1.0.0}/README.md +12 -0
- combatlearn-1.0.0/combatlearn/__init__.py +5 -0
- {combatlearn-0.2.1 → combatlearn-1.0.0}/combatlearn/combat.py +88 -61
- {combatlearn-0.2.1 → combatlearn-1.0.0}/combatlearn.egg-info/PKG-INFO +24 -3
- {combatlearn-0.2.1 → combatlearn-1.0.0}/combatlearn.egg-info/SOURCES.txt +1 -0
- combatlearn-1.0.0/combatlearn.egg-info/requires.txt +19 -0
- {combatlearn-0.2.1 → combatlearn-1.0.0}/pyproject.toml +20 -12
- combatlearn-0.2.1/combatlearn.egg-info/requires.txt → combatlearn-1.0.0/requirements.txt +1 -2
- combatlearn-1.0.0/tests/test_combat.py +379 -0
- combatlearn-0.2.1/combatlearn/__init__.py +0 -4
- combatlearn-0.2.1/tests/test_combat.py +0 -150
- {combatlearn-0.2.1 → combatlearn-1.0.0}/LICENSE +0 -0
- {combatlearn-0.2.1 → combatlearn-1.0.0}/combatlearn.egg-info/dependency_links.txt +0 -0
- {combatlearn-0.2.1 → combatlearn-1.0.0}/combatlearn.egg-info/top_level.txt +0 -0
- {combatlearn-0.2.1 → combatlearn-1.0.0}/setup.cfg +0 -0
|
@@ -1,9 +1,9 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: combatlearn
|
|
3
|
-
Version: 0.2.1
|
|
3
|
+
Version: 1.0.0
|
|
4
4
|
Summary: Batch-effect harmonization for machine learning frameworks.
|
|
5
5
|
Author-email: Ettore Rocchi <ettoreroc@gmail.com>
|
|
6
|
-
License-Expression: MIT
|
|
6
|
+
License: MIT
|
|
7
7
|
Keywords: machine-learning,harmonization,combat,preprocessing
|
|
8
8
|
Classifier: Development Status :: 3 - Alpha
|
|
9
9
|
Classifier: Intended Audience :: Science/Research
|
|
@@ -19,13 +19,23 @@ Requires-Dist: matplotlib>=3.4
|
|
|
19
19
|
Requires-Dist: plotly>=5.0
|
|
20
20
|
Requires-Dist: nbformat>=4.2
|
|
21
21
|
Requires-Dist: umap-learn>=0.5
|
|
22
|
-
|
|
22
|
+
Provides-Extra: dev
|
|
23
|
+
Requires-Dist: pytest>=7; extra == "dev"
|
|
24
|
+
Requires-Dist: pytest-cov>=4.0; extra == "dev"
|
|
25
|
+
Requires-Dist: ruff>=0.1; extra == "dev"
|
|
26
|
+
Requires-Dist: mypy>=1.0; extra == "dev"
|
|
27
|
+
Provides-Extra: docs
|
|
28
|
+
Requires-Dist: mkdocs>=1.5.0; extra == "docs"
|
|
29
|
+
Requires-Dist: mkdocs-material>=9.0.0; extra == "docs"
|
|
30
|
+
Requires-Dist: mkdocstrings[python]>=0.24.0; extra == "docs"
|
|
31
|
+
Requires-Dist: pymdown-extensions>=10.0; extra == "docs"
|
|
23
32
|
Dynamic: license-file
|
|
24
33
|
|
|
25
34
|
# **combatlearn**
|
|
26
35
|
|
|
27
36
|
[](https://www.python.org/)
|
|
28
37
|
[](https://github.com/EttoreRocchi/combatlearn/actions/workflows/test.yaml)
|
|
38
|
+
[](https://combatlearn.readthedocs.io)
|
|
29
39
|
[](https://pepy.tech/projects/combatlearn)
|
|
30
40
|
[](https://pypi.org/project/combatlearn/)
|
|
31
41
|
[](https://github.com/EttoreRocchi/combatlearn/blob/main/LICENSE)
|
|
@@ -95,6 +105,17 @@ print(f"Best CV AUROC: {grid.best_score_:.3f}")
|
|
|
95
105
|
|
|
96
106
|
For a full example of how to use **combatlearn** see the [notebook demo](https://github.com/EttoreRocchi/combatlearn/blob/main/docs/demo/combatlearn_demo.ipynb)
|
|
97
107
|
|
|
108
|
+
## Documentation
|
|
109
|
+
|
|
110
|
+
**Full documentation is available at [combatlearn.readthedocs.io](https://combatlearn.readthedocs.io)**
|
|
111
|
+
|
|
112
|
+
The documentation includes:
|
|
113
|
+
- [Installation Guide](https://combatlearn.readthedocs.io/en/latest/installation/)
|
|
114
|
+
- [Quick Start Tutorial](https://combatlearn.readthedocs.io/en/latest/quickstart/)
|
|
115
|
+
- [User Guide](https://combatlearn.readthedocs.io/en/latest/user-guide/overview/)
|
|
116
|
+
- [API Reference](https://combatlearn.readthedocs.io/en/latest/api/)
|
|
117
|
+
- [Examples](https://combatlearn.readthedocs.io/en/latest/examples/basic-usage/)
|
|
118
|
+
|
|
98
119
|
## `ComBat` parameters
|
|
99
120
|
|
|
100
121
|
The following section provides a detailed explanation of all parameters available in the scikit-learn-compatible `ComBat` class.
|
|
@@ -2,6 +2,7 @@
|
|
|
2
2
|
|
|
3
3
|
[](https://www.python.org/)
|
|
4
4
|
[](https://github.com/EttoreRocchi/combatlearn/actions/workflows/test.yaml)
|
|
5
|
+
[](https://combatlearn.readthedocs.io)
|
|
5
6
|
[](https://pepy.tech/projects/combatlearn)
|
|
6
7
|
[](https://pypi.org/project/combatlearn/)
|
|
7
8
|
[](https://github.com/EttoreRocchi/combatlearn/blob/main/LICENSE)
|
|
@@ -71,6 +72,17 @@ print(f"Best CV AUROC: {grid.best_score_:.3f}")
|
|
|
71
72
|
|
|
72
73
|
For a full example of how to use **combatlearn** see the [notebook demo](https://github.com/EttoreRocchi/combatlearn/blob/main/docs/demo/combatlearn_demo.ipynb)
|
|
73
74
|
|
|
75
|
+
## Documentation
|
|
76
|
+
|
|
77
|
+
**Full documentation is available at [combatlearn.readthedocs.io](https://combatlearn.readthedocs.io)**
|
|
78
|
+
|
|
79
|
+
The documentation includes:
|
|
80
|
+
- [Installation Guide](https://combatlearn.readthedocs.io/en/latest/installation/)
|
|
81
|
+
- [Quick Start Tutorial](https://combatlearn.readthedocs.io/en/latest/quickstart/)
|
|
82
|
+
- [User Guide](https://combatlearn.readthedocs.io/en/latest/user-guide/overview/)
|
|
83
|
+
- [API Reference](https://combatlearn.readthedocs.io/en/latest/api/)
|
|
84
|
+
- [Examples](https://combatlearn.readthedocs.io/en/latest/examples/basic-usage/)
|
|
85
|
+
|
|
74
86
|
## `ComBat` parameters
|
|
75
87
|
|
|
76
88
|
The following section provides a detailed explanation of all parameters available in the scikit-learn-compatible `ComBat` class.
|
|
@@ -14,28 +14,17 @@ import numpy as np
|
|
|
14
14
|
import numpy.linalg as la
|
|
15
15
|
import pandas as pd
|
|
16
16
|
from sklearn.base import BaseEstimator, TransformerMixin
|
|
17
|
-
from sklearn.utils.validation import check_is_fitted
|
|
18
17
|
from sklearn.decomposition import PCA
|
|
19
18
|
from sklearn.manifold import TSNE
|
|
19
|
+
import matplotlib
|
|
20
20
|
import matplotlib.pyplot as plt
|
|
21
|
-
|
|
21
|
+
import matplotlib.colors as mcolors
|
|
22
|
+
from typing import Literal, Optional, Union, Dict, Tuple, Any
|
|
22
23
|
import numpy.typing as npt
|
|
23
24
|
import warnings
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
UMAP_AVAILABLE = True
|
|
28
|
-
except ImportError:
|
|
29
|
-
UMAP_AVAILABLE = False
|
|
30
|
-
|
|
31
|
-
try:
|
|
32
|
-
import plotly.graph_objects as go
|
|
33
|
-
from plotly.subplots import make_subplots
|
|
34
|
-
PLOTLY_AVAILABLE = True
|
|
35
|
-
except ImportError:
|
|
36
|
-
PLOTLY_AVAILABLE = False
|
|
37
|
-
|
|
38
|
-
__author__ = "Ettore Rocchi"
|
|
25
|
+
import umap
|
|
26
|
+
import plotly.graph_objects as go
|
|
27
|
+
from plotly.subplots import make_subplots
|
|
39
28
|
|
|
40
29
|
ArrayLike = Union[pd.DataFrame, pd.Series, npt.NDArray[Any]]
|
|
41
30
|
FloatArray = npt.NDArray[np.float64]
|
|
@@ -57,8 +46,9 @@ class ComBatModel:
|
|
|
57
46
|
ignoring the variance (`delta_star`).
|
|
58
47
|
reference_batch : str, optional
|
|
59
48
|
If specified, the batch level to use as reference.
|
|
60
|
-
covbat_cov_thresh : float, default=0.9
|
|
61
|
-
CovBat: cumulative
|
|
49
|
+
covbat_cov_thresh : float or int, default=0.9
|
|
50
|
+
CovBat: cumulative variance threshold (0, 1] to retain PCs, or
|
|
51
|
+
integer >= 1 specifying the number of components directly.
|
|
62
52
|
eps : float, default=1e-8
|
|
63
53
|
Numerical jitter to avoid division-by-zero.
|
|
64
54
|
"""
|
|
@@ -66,19 +56,19 @@ class ComBatModel:
|
|
|
66
56
|
def __init__(
|
|
67
57
|
self,
|
|
68
58
|
*,
|
|
69
|
-
method: Literal["johnson", "fortin", "chen"] = "johnson",
|
|
59
|
+
method: Literal["johnson", "fortin", "chen"] = "johnson",
|
|
70
60
|
parametric: bool = True,
|
|
71
61
|
mean_only: bool = False,
|
|
72
62
|
reference_batch: Optional[str] = None,
|
|
73
63
|
eps: float = 1e-8,
|
|
74
|
-
covbat_cov_thresh: float = 0.9,
|
|
64
|
+
covbat_cov_thresh: Union[float, int] = 0.9,
|
|
75
65
|
) -> None:
|
|
76
66
|
self.method: str = method
|
|
77
67
|
self.parametric: bool = parametric
|
|
78
68
|
self.mean_only: bool = bool(mean_only)
|
|
79
69
|
self.reference_batch: Optional[str] = reference_batch
|
|
80
70
|
self.eps: float = float(eps)
|
|
81
|
-
self.covbat_cov_thresh: float =
|
|
71
|
+
self.covbat_cov_thresh: Union[float, int] = covbat_cov_thresh
|
|
82
72
|
|
|
83
73
|
self._batch_levels: pd.Index
|
|
84
74
|
self._grand_mean: pd.Series
|
|
@@ -95,9 +85,16 @@ class ComBatModel:
|
|
|
95
85
|
self._batch_levels_pc: pd.Index
|
|
96
86
|
self._pc_gamma_star: FloatArray
|
|
97
87
|
self._pc_delta_star: FloatArray
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
88
|
+
|
|
89
|
+
# Validate covbat_cov_thresh
|
|
90
|
+
if isinstance(self.covbat_cov_thresh, float):
|
|
91
|
+
if not (0.0 < self.covbat_cov_thresh <= 1.0):
|
|
92
|
+
raise ValueError("covbat_cov_thresh must be in (0, 1] when float.")
|
|
93
|
+
elif isinstance(self.covbat_cov_thresh, int):
|
|
94
|
+
if self.covbat_cov_thresh < 1:
|
|
95
|
+
raise ValueError("covbat_cov_thresh must be >= 1 when int.")
|
|
96
|
+
else:
|
|
97
|
+
raise TypeError("covbat_cov_thresh must be float or int.")
|
|
101
98
|
|
|
102
99
|
@staticmethod
|
|
103
100
|
def _as_series(
|
|
@@ -335,8 +332,14 @@ class ComBatModel:
|
|
|
335
332
|
X_meanvar_adj = self._transform_fortin(X, batch, disc, cont)
|
|
336
333
|
X_centered = X_meanvar_adj - X_meanvar_adj.mean(axis=0)
|
|
337
334
|
pca = PCA(svd_solver="full", whiten=False).fit(X_centered)
|
|
338
|
-
|
|
339
|
-
|
|
335
|
+
|
|
336
|
+
# Determine number of components based on threshold type
|
|
337
|
+
if isinstance(self.covbat_cov_thresh, int):
|
|
338
|
+
n_pc = min(self.covbat_cov_thresh, len(pca.explained_variance_ratio_))
|
|
339
|
+
else:
|
|
340
|
+
cumulative = np.cumsum(pca.explained_variance_ratio_)
|
|
341
|
+
n_pc = int(np.searchsorted(cumulative, self.covbat_cov_thresh) + 1)
|
|
342
|
+
|
|
340
343
|
self._covbat_pca = pca
|
|
341
344
|
self._covbat_n_pc = n_pc
|
|
342
345
|
|
|
@@ -487,7 +490,8 @@ class ComBatModel:
|
|
|
487
490
|
continuous_covariates: Optional[ArrayLike] = None,
|
|
488
491
|
) -> pd.DataFrame:
|
|
489
492
|
"""Transform the data using fitted ComBat parameters."""
|
|
490
|
-
|
|
493
|
+
if not hasattr(self, "_gamma_star"):
|
|
494
|
+
raise ValueError("This ComBatModel instance is not fitted yet. Call 'fit' before 'transform'.")
|
|
491
495
|
if not isinstance(X, pd.DataFrame):
|
|
492
496
|
X = pd.DataFrame(X)
|
|
493
497
|
idx = X.index
|
|
@@ -599,7 +603,7 @@ class ComBatModel:
|
|
|
599
603
|
"""Chen transform implementation."""
|
|
600
604
|
X_meanvar_adj = self._transform_fortin(X, batch, disc, cont)
|
|
601
605
|
X_centered = X_meanvar_adj - self._covbat_pca.mean_
|
|
602
|
-
scores = self._covbat_pca.transform(X_centered
|
|
606
|
+
scores = self._covbat_pca.transform(X_centered)
|
|
603
607
|
n_pc = self._covbat_n_pc
|
|
604
608
|
scores_adj = scores.copy()
|
|
605
609
|
|
|
@@ -638,7 +642,7 @@ class ComBat(BaseEstimator, TransformerMixin):
|
|
|
638
642
|
mean_only: bool = False,
|
|
639
643
|
reference_batch: Optional[str] = None,
|
|
640
644
|
eps: float = 1e-8,
|
|
641
|
-
covbat_cov_thresh: float = 0.9,
|
|
645
|
+
covbat_cov_thresh: Union[float, int] = 0.9,
|
|
642
646
|
) -> None:
|
|
643
647
|
self.batch = batch
|
|
644
648
|
self.discrete_covariates = discrete_covariates
|
|
@@ -758,7 +762,8 @@ class ComBat(BaseEstimator, TransformerMixin):
|
|
|
758
762
|
- `'original'`: embedding of original data
|
|
759
763
|
- `'transformed'`: embedding of ComBat-transformed data
|
|
760
764
|
"""
|
|
761
|
-
|
|
765
|
+
if not hasattr(self._model, "_gamma_star"):
|
|
766
|
+
raise ValueError("This ComBat instance is not fitted yet. Call 'fit' before 'plot_transformation'.")
|
|
762
767
|
|
|
763
768
|
if n_components not in [2, 3]:
|
|
764
769
|
raise ValueError(f"n_components must be 2 or 3, got {n_components}")
|
|
@@ -767,11 +772,6 @@ class ComBat(BaseEstimator, TransformerMixin):
|
|
|
767
772
|
if plot_type not in ['static', 'interactive']:
|
|
768
773
|
raise ValueError(f"plot_type must be 'static' or 'interactive', got '{plot_type}'")
|
|
769
774
|
|
|
770
|
-
if reduction_method == 'umap' and not UMAP_AVAILABLE:
|
|
771
|
-
raise ImportError("UMAP is not installed. Install with: pip install umap-learn")
|
|
772
|
-
if plot_type == 'interactive' and not PLOTLY_AVAILABLE:
|
|
773
|
-
raise ImportError("Plotly is not installed. Install with: pip install plotly")
|
|
774
|
-
|
|
775
775
|
if not isinstance(X, pd.DataFrame):
|
|
776
776
|
X = pd.DataFrame(X)
|
|
777
777
|
|
|
@@ -796,8 +796,8 @@ class ComBat(BaseEstimator, TransformerMixin):
|
|
|
796
796
|
else:
|
|
797
797
|
umap_params = {'random_state': 42}
|
|
798
798
|
umap_params.update(reduction_kwargs)
|
|
799
|
-
reducer_orig = umap.UMAP(n_components=n_components, **
|
|
800
|
-
reducer_trans = umap.UMAP(n_components=n_components, **
|
|
799
|
+
reducer_orig = umap.UMAP(n_components=n_components, **umap_params)
|
|
800
|
+
reducer_trans = umap.UMAP(n_components=n_components, **umap_params)
|
|
801
801
|
|
|
802
802
|
X_embedded_orig = reducer_orig.fit_transform(X_np)
|
|
803
803
|
X_embedded_trans = reducer_trans.fit_transform(X_trans_np)
|
|
@@ -811,7 +811,7 @@ class ComBat(BaseEstimator, TransformerMixin):
|
|
|
811
811
|
else:
|
|
812
812
|
fig = self._create_interactive_plot(
|
|
813
813
|
X_embedded_orig, X_embedded_trans, batch_vec,
|
|
814
|
-
reduction_method, n_components, title, show_legend
|
|
814
|
+
reduction_method, n_components, cmap, title, show_legend
|
|
815
815
|
)
|
|
816
816
|
|
|
817
817
|
if return_embeddings:
|
|
@@ -844,9 +844,9 @@ class ComBat(BaseEstimator, TransformerMixin):
|
|
|
844
844
|
n_batches = len(unique_batches)
|
|
845
845
|
|
|
846
846
|
if n_batches <= 10:
|
|
847
|
-
colors =
|
|
847
|
+
colors = matplotlib.colormaps.get_cmap(cmap)(np.linspace(0, 1, n_batches))
|
|
848
848
|
else:
|
|
849
|
-
colors =
|
|
849
|
+
colors = matplotlib.colormaps.get_cmap('tab20')(np.linspace(0, 1, n_batches))
|
|
850
850
|
|
|
851
851
|
if n_components == 2:
|
|
852
852
|
ax1 = plt.subplot(1, 2, 1)
|
|
@@ -930,6 +930,7 @@ class ComBat(BaseEstimator, TransformerMixin):
|
|
|
930
930
|
batch_labels: pd.Series,
|
|
931
931
|
method: str,
|
|
932
932
|
n_components: int,
|
|
933
|
+
cmap: str,
|
|
933
934
|
title: Optional[str],
|
|
934
935
|
show_legend: bool) -> Any:
|
|
935
936
|
"""Create interactive plots using plotly."""
|
|
@@ -953,43 +954,69 @@ class ComBat(BaseEstimator, TransformerMixin):
|
|
|
953
954
|
|
|
954
955
|
unique_batches = batch_labels.drop_duplicates()
|
|
955
956
|
|
|
957
|
+
n_batches = len(unique_batches)
|
|
958
|
+
cmap_func = matplotlib.colormaps.get_cmap(cmap)
|
|
959
|
+
color_list = [mcolors.to_hex(cmap_func(i / max(n_batches - 1, 1))) for i in range(n_batches)]
|
|
960
|
+
|
|
961
|
+
batch_to_color = dict(zip(unique_batches, color_list))
|
|
962
|
+
|
|
956
963
|
for batch in unique_batches:
|
|
957
964
|
mask = batch_labels == batch
|
|
958
965
|
|
|
959
966
|
if n_components == 2:
|
|
960
967
|
fig.add_trace(
|
|
961
|
-
go.Scatter(
|
|
962
|
-
|
|
963
|
-
|
|
964
|
-
|
|
965
|
-
|
|
968
|
+
go.Scatter(
|
|
969
|
+
x=X_orig[mask, 0], y=X_orig[mask, 1],
|
|
970
|
+
mode='markers',
|
|
971
|
+
name=f'Batch {batch}',
|
|
972
|
+
marker=dict(
|
|
973
|
+
size=8,
|
|
974
|
+
color=batch_to_color[batch],
|
|
975
|
+
line=dict(width=1, color='black')
|
|
976
|
+
),
|
|
977
|
+
showlegend=False),
|
|
966
978
|
row=1, col=1
|
|
967
979
|
)
|
|
968
980
|
|
|
969
981
|
fig.add_trace(
|
|
970
|
-
go.Scatter(
|
|
971
|
-
|
|
972
|
-
|
|
973
|
-
|
|
974
|
-
|
|
982
|
+
go.Scatter(
|
|
983
|
+
x=X_trans[mask, 0], y=X_trans[mask, 1],
|
|
984
|
+
mode='markers',
|
|
985
|
+
name=f'Batch {batch}',
|
|
986
|
+
marker=dict(
|
|
987
|
+
size=8,
|
|
988
|
+
color=batch_to_color[batch],
|
|
989
|
+
line=dict(width=1, color='black')
|
|
990
|
+
),
|
|
991
|
+
showlegend=show_legend),
|
|
975
992
|
row=1, col=2
|
|
976
993
|
)
|
|
977
994
|
else:
|
|
978
995
|
fig.add_trace(
|
|
979
|
-
go.Scatter3d(
|
|
980
|
-
|
|
981
|
-
|
|
982
|
-
|
|
983
|
-
|
|
996
|
+
go.Scatter3d(
|
|
997
|
+
x=X_orig[mask, 0], y=X_orig[mask, 1], z=X_orig[mask, 2],
|
|
998
|
+
mode='markers',
|
|
999
|
+
name=f'Batch {batch}',
|
|
1000
|
+
marker=dict(
|
|
1001
|
+
size=5,
|
|
1002
|
+
color=batch_to_color[batch],
|
|
1003
|
+
line=dict(width=0.5, color='black')
|
|
1004
|
+
),
|
|
1005
|
+
showlegend=False),
|
|
984
1006
|
row=1, col=1
|
|
985
1007
|
)
|
|
986
1008
|
|
|
987
1009
|
fig.add_trace(
|
|
988
|
-
go.Scatter3d(
|
|
989
|
-
|
|
990
|
-
|
|
991
|
-
|
|
992
|
-
|
|
1010
|
+
go.Scatter3d(
|
|
1011
|
+
x=X_trans[mask, 0], y=X_trans[mask, 1], z=X_trans[mask, 2],
|
|
1012
|
+
mode='markers',
|
|
1013
|
+
name=f'Batch {batch}',
|
|
1014
|
+
marker=dict(
|
|
1015
|
+
size=5,
|
|
1016
|
+
color=batch_to_color[batch],
|
|
1017
|
+
line=dict(width=0.5, color='black')
|
|
1018
|
+
),
|
|
1019
|
+
showlegend=show_legend),
|
|
993
1020
|
row=1, col=2
|
|
994
1021
|
)
|
|
995
1022
|
|
|
@@ -1,9 +1,9 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: combatlearn
|
|
3
|
-
Version: 0.2.1
|
|
3
|
+
Version: 1.0.0
|
|
4
4
|
Summary: Batch-effect harmonization for machine learning frameworks.
|
|
5
5
|
Author-email: Ettore Rocchi <ettoreroc@gmail.com>
|
|
6
|
-
License-Expression: MIT
|
|
6
|
+
License: MIT
|
|
7
7
|
Keywords: machine-learning,harmonization,combat,preprocessing
|
|
8
8
|
Classifier: Development Status :: 3 - Alpha
|
|
9
9
|
Classifier: Intended Audience :: Science/Research
|
|
@@ -19,13 +19,23 @@ Requires-Dist: matplotlib>=3.4
|
|
|
19
19
|
Requires-Dist: plotly>=5.0
|
|
20
20
|
Requires-Dist: nbformat>=4.2
|
|
21
21
|
Requires-Dist: umap-learn>=0.5
|
|
22
|
-
|
|
22
|
+
Provides-Extra: dev
|
|
23
|
+
Requires-Dist: pytest>=7; extra == "dev"
|
|
24
|
+
Requires-Dist: pytest-cov>=4.0; extra == "dev"
|
|
25
|
+
Requires-Dist: ruff>=0.1; extra == "dev"
|
|
26
|
+
Requires-Dist: mypy>=1.0; extra == "dev"
|
|
27
|
+
Provides-Extra: docs
|
|
28
|
+
Requires-Dist: mkdocs>=1.5.0; extra == "docs"
|
|
29
|
+
Requires-Dist: mkdocs-material>=9.0.0; extra == "docs"
|
|
30
|
+
Requires-Dist: mkdocstrings[python]>=0.24.0; extra == "docs"
|
|
31
|
+
Requires-Dist: pymdown-extensions>=10.0; extra == "docs"
|
|
23
32
|
Dynamic: license-file
|
|
24
33
|
|
|
25
34
|
# **combatlearn**
|
|
26
35
|
|
|
27
36
|
[](https://www.python.org/)
|
|
28
37
|
[](https://github.com/EttoreRocchi/combatlearn/actions/workflows/test.yaml)
|
|
38
|
+
[](https://combatlearn.readthedocs.io)
|
|
29
39
|
[](https://pepy.tech/projects/combatlearn)
|
|
30
40
|
[](https://pypi.org/project/combatlearn/)
|
|
31
41
|
[](https://github.com/EttoreRocchi/combatlearn/blob/main/LICENSE)
|
|
@@ -95,6 +105,17 @@ print(f"Best CV AUROC: {grid.best_score_:.3f}")
|
|
|
95
105
|
|
|
96
106
|
For a full example of how to use **combatlearn** see the [notebook demo](https://github.com/EttoreRocchi/combatlearn/blob/main/docs/demo/combatlearn_demo.ipynb)
|
|
97
107
|
|
|
108
|
+
## Documentation
|
|
109
|
+
|
|
110
|
+
**Full documentation is available at [combatlearn.readthedocs.io](https://combatlearn.readthedocs.io)**
|
|
111
|
+
|
|
112
|
+
The documentation includes:
|
|
113
|
+
- [Installation Guide](https://combatlearn.readthedocs.io/en/latest/installation/)
|
|
114
|
+
- [Quick Start Tutorial](https://combatlearn.readthedocs.io/en/latest/quickstart/)
|
|
115
|
+
- [User Guide](https://combatlearn.readthedocs.io/en/latest/user-guide/overview/)
|
|
116
|
+
- [API Reference](https://combatlearn.readthedocs.io/en/latest/api/)
|
|
117
|
+
- [Examples](https://combatlearn.readthedocs.io/en/latest/examples/basic-usage/)
|
|
118
|
+
|
|
98
119
|
## `ComBat` parameters
|
|
99
120
|
|
|
100
121
|
The following section provides a detailed explanation of all parameters available in the scikit-learn-compatible `ComBat` class.
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
pandas>=1.3
|
|
2
|
+
numpy>=1.21
|
|
3
|
+
scikit-learn>=1.2
|
|
4
|
+
matplotlib>=3.4
|
|
5
|
+
plotly>=5.0
|
|
6
|
+
nbformat>=4.2
|
|
7
|
+
umap-learn>=0.5
|
|
8
|
+
|
|
9
|
+
[dev]
|
|
10
|
+
pytest>=7
|
|
11
|
+
pytest-cov>=4.0
|
|
12
|
+
ruff>=0.1
|
|
13
|
+
mypy>=1.0
|
|
14
|
+
|
|
15
|
+
[docs]
|
|
16
|
+
mkdocs>=1.5.0
|
|
17
|
+
mkdocs-material>=9.0.0
|
|
18
|
+
mkdocstrings[python]>=0.24.0
|
|
19
|
+
pymdown-extensions>=10.0
|
|
@@ -4,21 +4,11 @@ build-backend = "setuptools.build_meta"
|
|
|
4
4
|
|
|
5
5
|
[project]
|
|
6
6
|
name = "combatlearn"
|
|
7
|
-
|
|
7
|
+
dynamic = ["version", "dependencies"]
|
|
8
8
|
description = "Batch-effect harmonization for machine learning frameworks."
|
|
9
9
|
authors = [{name="Ettore Rocchi", email="ettoreroc@gmail.com"}]
|
|
10
10
|
requires-python = ">=3.10"
|
|
11
|
-
|
|
12
|
-
"pandas>=1.3",
|
|
13
|
-
"numpy>=1.21",
|
|
14
|
-
"scikit-learn>=1.2",
|
|
15
|
-
"matplotlib>=3.4",
|
|
16
|
-
"plotly>=5.0",
|
|
17
|
-
"nbformat>=4.2",
|
|
18
|
-
"umap-learn>=0.5",
|
|
19
|
-
"pytest>=7"
|
|
20
|
-
]
|
|
21
|
-
license = "MIT"
|
|
11
|
+
license = {text = "MIT"}
|
|
22
12
|
readme = {file="README.md", content-type="text/markdown"}
|
|
23
13
|
keywords = [
|
|
24
14
|
"machine-learning",
|
|
@@ -33,6 +23,24 @@ classifiers = [
|
|
|
33
23
|
"Programming Language :: Python :: 3",
|
|
34
24
|
]
|
|
35
25
|
|
|
26
|
+
[project.optional-dependencies]
|
|
27
|
+
dev = [
|
|
28
|
+
"pytest>=7",
|
|
29
|
+
"pytest-cov>=4.0",
|
|
30
|
+
"ruff>=0.1",
|
|
31
|
+
"mypy>=1.0",
|
|
32
|
+
]
|
|
33
|
+
docs = [
|
|
34
|
+
"mkdocs>=1.5.0",
|
|
35
|
+
"mkdocs-material>=9.0.0",
|
|
36
|
+
"mkdocstrings[python]>=0.24.0",
|
|
37
|
+
"pymdown-extensions>=10.0",
|
|
38
|
+
]
|
|
39
|
+
|
|
36
40
|
[tool.setuptools.packages.find]
|
|
37
41
|
where = ["."]
|
|
38
42
|
include = ["combatlearn*"]
|
|
43
|
+
|
|
44
|
+
[tool.setuptools.dynamic]
|
|
45
|
+
version = {attr = "combatlearn.__version__"}
|
|
46
|
+
dependencies = {file = ["requirements.txt"]}
|
|
@@ -0,0 +1,379 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import pandas as pd
|
|
3
|
+
import pytest
|
|
4
|
+
from sklearn.pipeline import Pipeline
|
|
5
|
+
from sklearn.base import clone
|
|
6
|
+
from sklearn.preprocessing import StandardScaler
|
|
7
|
+
from combatlearn import ComBat
|
|
8
|
+
from combatlearn.combat import ComBatModel
|
|
9
|
+
from utils import simulate_data, simulate_covariate_data
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def test_transform_without_fit_raises():
|
|
13
|
+
"""
|
|
14
|
+
Test that `transform` raises a `ValueError` if not fitted.
|
|
15
|
+
"""
|
|
16
|
+
X, batch = simulate_data()
|
|
17
|
+
model = ComBatModel()
|
|
18
|
+
with pytest.raises(ValueError, match="not fitted"):
|
|
19
|
+
model.transform(X, batch=batch)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def test_unseen_batch_raises_value_error():
|
|
23
|
+
"""
|
|
24
|
+
Test that unseen batch raises a `ValueError`.
|
|
25
|
+
"""
|
|
26
|
+
X, batch = simulate_data()
|
|
27
|
+
model = ComBatModel().fit(X, batch=batch)
|
|
28
|
+
new_batch = pd.Series(["Z"] * len(batch), index=batch.index)
|
|
29
|
+
with pytest.raises(ValueError):
|
|
30
|
+
model.transform(X, batch=new_batch)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def test_single_sample_batch_error():
|
|
34
|
+
"""
|
|
35
|
+
Test that a single sample batch raises a `ValueError`.
|
|
36
|
+
"""
|
|
37
|
+
X, batch = simulate_data()
|
|
38
|
+
batch.iloc[0] = "single"
|
|
39
|
+
with pytest.raises(ValueError):
|
|
40
|
+
ComBatModel().fit(X, batch=batch)
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
@pytest.mark.parametrize("method", ["johnson", "fortin", "chen"])
|
|
44
|
+
def test_dtypes_preserved(method):
|
|
45
|
+
"""All output columns must remain floating dtypes after correction."""
|
|
46
|
+
if method == "johnson":
|
|
47
|
+
X, batch = simulate_data()
|
|
48
|
+
extra = {}
|
|
49
|
+
else: # fortin or chen
|
|
50
|
+
X, batch, disc, cont = simulate_covariate_data()
|
|
51
|
+
extra = dict(discrete_covariates=disc, continuous_covariates=cont)
|
|
52
|
+
|
|
53
|
+
X_corr = ComBat(batch=batch, method=method, **extra).fit_transform(X)
|
|
54
|
+
assert all(np.issubdtype(dt, np.floating) for dt in X_corr.dtypes)
|
|
55
|
+
|
|
56
|
+
def test_wrapper_clone_and_pipeline():
|
|
57
|
+
"""
|
|
58
|
+
Test `ComBat` wrapper can be cloned and used in a `Pipeline`.
|
|
59
|
+
"""
|
|
60
|
+
X, batch = simulate_data()
|
|
61
|
+
wrapper = ComBat(batch=batch, parametric=True)
|
|
62
|
+
pipe = Pipeline([
|
|
63
|
+
("scaler", StandardScaler()),
|
|
64
|
+
("combat", wrapper),
|
|
65
|
+
])
|
|
66
|
+
X_corr = pipe.fit_transform(X)
|
|
67
|
+
pipe_clone: Pipeline = clone(pipe)
|
|
68
|
+
X_corr2 = pipe_clone.fit_transform(X)
|
|
69
|
+
np.testing.assert_allclose(X_corr, X_corr2, rtol=1e-5, atol=1e-5)
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
@pytest.mark.parametrize("method", ["johnson", "fortin", "chen"])
|
|
73
|
+
def test_no_nan_or_inf_in_output(method):
|
|
74
|
+
"""`ComBat` must not introduce NaN or Inf values, for any backend."""
|
|
75
|
+
if method == "johnson":
|
|
76
|
+
X, batch = simulate_data()
|
|
77
|
+
extra = {}
|
|
78
|
+
else: # fortin or chen
|
|
79
|
+
X, batch, disc, cont = simulate_covariate_data()
|
|
80
|
+
extra = dict(discrete_covariates=disc, continuous_covariates=cont)
|
|
81
|
+
|
|
82
|
+
X_corr = ComBat(batch=batch, method=method, **extra).fit_transform(X)
|
|
83
|
+
assert not np.isnan(X_corr.values).any()
|
|
84
|
+
assert not np.isinf(X_corr.values).any()
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
@pytest.mark.parametrize("method", ["johnson", "fortin", "chen"])
|
|
88
|
+
def test_shape_preserved(method):
|
|
89
|
+
"""The (n_samples, n_features) shape must be identical pre- and post-ComBat."""
|
|
90
|
+
if method == "johnson":
|
|
91
|
+
X, batch = simulate_data()
|
|
92
|
+
combat = ComBat(batch=batch, method=method).fit(X)
|
|
93
|
+
elif method in ["fortin", "chen"]:
|
|
94
|
+
X, batch, disc, cont = simulate_covariate_data()
|
|
95
|
+
combat = ComBat(
|
|
96
|
+
batch=batch,
|
|
97
|
+
discrete_covariates=disc,
|
|
98
|
+
continuous_covariates=cont,
|
|
99
|
+
method=method,
|
|
100
|
+
).fit(X)
|
|
101
|
+
|
|
102
|
+
X_corr = combat.transform(X)
|
|
103
|
+
assert X_corr.shape == X.shape
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
def test_johnson_print_warning():
|
|
107
|
+
"""
|
|
108
|
+
Test that a warning is printed when using the Johnson method.
|
|
109
|
+
"""
|
|
110
|
+
X, batch, disc, cont = simulate_covariate_data()
|
|
111
|
+
with pytest.warns(Warning, match="Covariates are ignored when using method='johnson'."):
|
|
112
|
+
_ = ComBat(
|
|
113
|
+
batch=batch,
|
|
114
|
+
discrete_covariates=disc,
|
|
115
|
+
continuous_covariates=cont,
|
|
116
|
+
method="johnson",
|
|
117
|
+
).fit(X)
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
@pytest.mark.parametrize("method", ["johnson", "fortin", "chen"])
|
|
121
|
+
def test_reference_batch_samples_unchanged(method):
|
|
122
|
+
"""
|
|
123
|
+
Samples belonging to the reference batch must come out *numerically identical*
|
|
124
|
+
(within floating-point jitter) after correction.
|
|
125
|
+
"""
|
|
126
|
+
if method == "johnson":
|
|
127
|
+
X, batch = simulate_data()
|
|
128
|
+
extra = {}
|
|
129
|
+
elif method in ["fortin", "chen"]:
|
|
130
|
+
X, batch, disc, cont = simulate_covariate_data()
|
|
131
|
+
extra = dict(discrete_covariates=disc, continuous_covariates=cont)
|
|
132
|
+
|
|
133
|
+
ref_batch = batch.iloc[0]
|
|
134
|
+
combat = ComBat(batch=batch, method=method,
|
|
135
|
+
reference_batch=ref_batch, **extra).fit(X)
|
|
136
|
+
X_corr = combat.transform(X)
|
|
137
|
+
|
|
138
|
+
mask = batch == ref_batch
|
|
139
|
+
np.testing.assert_allclose(X_corr.loc[mask].values,
|
|
140
|
+
X.loc[mask].values,
|
|
141
|
+
rtol=0, atol=1e-10)
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
def test_reference_batch_missing_raises():
|
|
145
|
+
"""
|
|
146
|
+
Asking for a reference batch that doesn't exist should fail.
|
|
147
|
+
"""
|
|
148
|
+
X, batch = simulate_data()
|
|
149
|
+
with pytest.raises(ValueError, match="not present"):
|
|
150
|
+
ComBat(batch=batch, reference_batch="DOES_NOT_EXIST").fit(X)
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
@pytest.mark.parametrize("parametric", [True, False])
|
|
154
|
+
@pytest.mark.parametrize("method", ["johnson", "fortin", "chen"])
|
|
155
|
+
def test_parametric_vs_nonparametric(parametric, method):
|
|
156
|
+
"""
|
|
157
|
+
Test both parametric and non-parametric modes work without errors.
|
|
158
|
+
"""
|
|
159
|
+
if method == "johnson":
|
|
160
|
+
X, batch = simulate_data()
|
|
161
|
+
extra = {}
|
|
162
|
+
else:
|
|
163
|
+
X, batch, disc, cont = simulate_covariate_data()
|
|
164
|
+
extra = dict(discrete_covariates=disc, continuous_covariates=cont)
|
|
165
|
+
|
|
166
|
+
combat = ComBat(batch=batch, method=method, parametric=parametric, **extra)
|
|
167
|
+
X_corr = combat.fit_transform(X)
|
|
168
|
+
assert X_corr.shape == X.shape
|
|
169
|
+
assert not np.isnan(X_corr.values).any()
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
@pytest.mark.parametrize("mean_only", [True, False])
|
|
173
|
+
@pytest.mark.parametrize("method", ["johnson", "fortin", "chen"])
|
|
174
|
+
def test_mean_only_mode(mean_only, method):
|
|
175
|
+
"""
|
|
176
|
+
Test mean_only mode works for all methods.
|
|
177
|
+
"""
|
|
178
|
+
if method == "johnson":
|
|
179
|
+
X, batch = simulate_data()
|
|
180
|
+
extra = {}
|
|
181
|
+
else:
|
|
182
|
+
X, batch, disc, cont = simulate_covariate_data()
|
|
183
|
+
extra = dict(discrete_covariates=disc, continuous_covariates=cont)
|
|
184
|
+
|
|
185
|
+
combat = ComBat(batch=batch, method=method, mean_only=mean_only, **extra)
|
|
186
|
+
X_corr = combat.fit_transform(X)
|
|
187
|
+
assert X_corr.shape == X.shape
|
|
188
|
+
assert not np.isnan(X_corr.values).any()
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
def test_covbat_cov_thresh_as_float():
|
|
192
|
+
"""
|
|
193
|
+
Test CovBat with covbat_cov_thresh as float (cumulative variance threshold).
|
|
194
|
+
"""
|
|
195
|
+
X, batch, disc, cont = simulate_covariate_data()
|
|
196
|
+
combat = ComBat(
|
|
197
|
+
batch=batch,
|
|
198
|
+
discrete_covariates=disc,
|
|
199
|
+
continuous_covariates=cont,
|
|
200
|
+
method="chen",
|
|
201
|
+
covbat_cov_thresh=0.95,
|
|
202
|
+
)
|
|
203
|
+
X_corr = combat.fit_transform(X)
|
|
204
|
+
assert X_corr.shape == X.shape
|
|
205
|
+
assert not np.isnan(X_corr.values).any()
|
|
206
|
+
|
|
207
|
+
|
|
208
|
+
def test_covbat_cov_thresh_as_int():
|
|
209
|
+
"""
|
|
210
|
+
Test CovBat with covbat_cov_thresh as int (number of components).
|
|
211
|
+
"""
|
|
212
|
+
X, batch, disc, cont = simulate_covariate_data()
|
|
213
|
+
n_components = 10
|
|
214
|
+
combat = ComBat(
|
|
215
|
+
batch=batch,
|
|
216
|
+
discrete_covariates=disc,
|
|
217
|
+
continuous_covariates=cont,
|
|
218
|
+
method="chen",
|
|
219
|
+
covbat_cov_thresh=n_components,
|
|
220
|
+
)
|
|
221
|
+
X_corr = combat.fit_transform(X)
|
|
222
|
+
assert X_corr.shape == X.shape
|
|
223
|
+
assert not np.isnan(X_corr.values).any()
|
|
224
|
+
assert combat._model._covbat_n_pc == n_components
|
|
225
|
+
|
|
226
|
+
|
|
227
|
+
def test_covbat_cov_thresh_invalid_float_raises():
|
|
228
|
+
"""
|
|
229
|
+
Test that invalid float values for covbat_cov_thresh raise ValueError.
|
|
230
|
+
"""
|
|
231
|
+
with pytest.raises(ValueError, match="must be in \\(0, 1\\]"):
|
|
232
|
+
ComBatModel(covbat_cov_thresh=1.5)
|
|
233
|
+
|
|
234
|
+
with pytest.raises(ValueError, match="must be in \\(0, 1\\]"):
|
|
235
|
+
ComBatModel(covbat_cov_thresh=0.0)
|
|
236
|
+
|
|
237
|
+
|
|
238
|
+
def test_covbat_cov_thresh_invalid_int_raises():
|
|
239
|
+
"""
|
|
240
|
+
Test that invalid int values for covbat_cov_thresh raise ValueError.
|
|
241
|
+
"""
|
|
242
|
+
with pytest.raises(ValueError, match="must be >= 1"):
|
|
243
|
+
ComBatModel(covbat_cov_thresh=0)
|
|
244
|
+
|
|
245
|
+
|
|
246
|
+
def test_covbat_cov_thresh_invalid_type_raises():
    """
    Check that a string covbat_cov_thresh is rejected.

    Constructing ComBatModel with covbat_cov_thresh="invalid" must raise a
    TypeError whose message matches "must be float or int".
    """
    expect_type_error = pytest.raises(TypeError, match="must be float or int")
    with expect_type_error:
        ComBatModel(covbat_cov_thresh="invalid")
|
|
252
|
+
|
|
253
|
+
|
|
254
|
+
@pytest.mark.parametrize("method", ["johnson", "fortin", "chen"])
def test_index_preserved(method):
    """
    Check that fit_transform returns the same row index it was given.

    The data, the batch labels and (for every method except "johnson") the
    covariates are all relabelled with a custom "sample_<i>" index before
    fitting; the corrected output must carry that same index.
    """
    covariates = {}
    if method == "johnson":
        X, batch = simulate_data()
    else:
        X, batch, disc, cont = simulate_covariate_data()
        covariates = {
            "discrete_covariates": disc,
            "continuous_covariates": cont,
        }

    sample_labels = pd.Index([f"sample_{i}" for i in range(len(X))])
    # Every object handed to ComBat gets the same custom labels.
    for labelled in (X, batch, *covariates.values()):
        labelled.index = sample_labels

    corrected = ComBat(batch=batch, method=method, **covariates).fit_transform(X)
    assert corrected.index.equals(sample_labels)
|
|
276
|
+
|
|
277
|
+
|
|
278
|
+
@pytest.mark.parametrize("method", ["johnson", "fortin", "chen"])
def test_column_names_preserved(method):
    """
    Check that fit_transform returns the same column names it was given.

    The input columns are renamed to "feature_<i>" before fitting; the
    corrected output must list exactly those names in the same order.
    """
    covariates = {}
    if method == "johnson":
        X, batch = simulate_data()
    else:
        X, batch, disc, cont = simulate_covariate_data()
        covariates = {
            "discrete_covariates": disc,
            "continuous_covariates": cont,
        }

    feature_names = [f"feature_{j}" for j in range(X.shape[1])]
    X.columns = feature_names

    harmonizer = ComBat(batch=batch, method=method, **covariates)
    corrected = harmonizer.fit_transform(X)
    assert list(corrected.columns) == feature_names
|
|
296
|
+
|
|
297
|
+
|
|
298
|
+
def test_plot_transformation_static_2d():
    """
    Check that plot_transformation returns something for a static 2D PCA plot.

    Uses a "johnson" ComBat fitted on 100 samples x 20 features and only
    asserts that the returned figure is not None.
    """
    X, batch = simulate_data(n_samples=100, n_features=20)
    fitted = ComBat(batch=batch, method="johnson").fit(X)

    plot_options = dict(
        reduction_method='pca',
        n_components=2,
        plot_type='static',
    )
    figure = fitted.plot_transformation(X, **plot_options)
    assert figure is not None
|
|
312
|
+
|
|
313
|
+
|
|
314
|
+
def test_plot_transformation_static_3d():
    """
    Check that plot_transformation returns something for a static 3D PCA plot.

    Uses a "johnson" ComBat fitted on 100 samples x 20 features and only
    asserts that the returned figure is not None.
    """
    X, batch = simulate_data(n_samples=100, n_features=20)
    fitted = ComBat(batch=batch, method="johnson").fit(X)

    plot_options = dict(
        reduction_method='pca',
        n_components=3,
        plot_type='static',
    )
    figure = fitted.plot_transformation(X, **plot_options)
    assert figure is not None
|
|
328
|
+
|
|
329
|
+
|
|
330
|
+
def test_plot_transformation_return_embeddings():
    """
    Check the embeddings returned when return_embeddings=True.

    plot_transformation must return a (figure, embeddings) pair where
    embeddings has 'original' and 'transformed' entries, each of shape
    (n_samples, n_components) -- here (100, 2).
    """
    n_samples, n_dims = 100, 2
    X, batch = simulate_data(n_samples=n_samples, n_features=20)
    fitted = ComBat(batch=batch, method="johnson").fit(X)

    figure, embeddings = fitted.plot_transformation(
        X,
        reduction_method='pca',
        n_components=n_dims,
        plot_type='static',
        return_embeddings=True,
    )

    assert figure is not None
    expected_keys = ('original', 'transformed')
    # Presence is checked for both keys before any shape is inspected.
    for key in expected_keys:
        assert key in embeddings
    for key in expected_keys:
        assert embeddings[key].shape == (n_samples, n_dims)
|
|
349
|
+
|
|
350
|
+
|
|
351
|
+
def test_plot_transformation_invalid_method_raises():
    """
    Check that an unknown reduction_method is rejected.

    Calling plot_transformation with reduction_method='invalid' must raise a
    ValueError whose message matches "reduction_method must be".
    """
    X, batch = simulate_data(n_samples=100, n_features=20)
    fitted = ComBat(batch=batch, method="johnson").fit(X)

    expect_value_error = pytest.raises(
        ValueError, match="reduction_method must be"
    )
    with expect_value_error:
        fitted.plot_transformation(X, reduction_method='invalid')
|
|
360
|
+
|
|
361
|
+
|
|
362
|
+
def test_plot_transformation_invalid_n_components_raises():
    """
    Check that n_components=4 is rejected by plot_transformation.

    The call must raise a ValueError whose message matches
    "n_components must be 2 or 3".
    """
    X, batch = simulate_data(n_samples=100, n_features=20)
    fitted = ComBat(batch=batch, method="johnson").fit(X)

    expect_value_error = pytest.raises(
        ValueError, match="n_components must be 2 or 3"
    )
    with expect_value_error:
        fitted.plot_transformation(X, n_components=4)
|
|
371
|
+
|
|
372
|
+
|
|
373
|
+
def test_invalid_method_raises():
    """
    Check that an unknown method name is rejected.

    Building ComBatModel(method="invalid") and fitting it must raise a
    ValueError whose message matches "method must be".
    """
    X, batch = simulate_data()
    expect_value_error = pytest.raises(ValueError, match="method must be")
    # Construction and fit both sit inside the context, so the error may
    # come from either step.
    with expect_value_error:
        ComBatModel(method="invalid").fit(X, batch=batch)
|
|
@@ -1,150 +0,0 @@
|
|
|
1
|
-
import numpy as np
|
|
2
|
-
import pandas as pd
|
|
3
|
-
import pytest
|
|
4
|
-
from sklearn.pipeline import Pipeline
|
|
5
|
-
from sklearn.base import clone
|
|
6
|
-
from sklearn.preprocessing import StandardScaler
|
|
7
|
-
from sklearn.exceptions import NotFittedError
|
|
8
|
-
from combatlearn import ComBatModel, ComBat
|
|
9
|
-
from utils import simulate_data, simulate_covariate_data
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
def test_transform_without_fit_raises():
    """
    Check that calling `transform` on an unfitted model raises
    `NotFittedError`.
    """
    X, batch = simulate_data()
    unfitted = ComBatModel()
    expect_not_fitted = pytest.raises(NotFittedError)
    with expect_not_fitted:
        unfitted.transform(X, batch=batch)
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
def test_unseen_batch_raises_value_error():
    """
    Check that transforming with batch labels replaced by "Z" raises a
    `ValueError`.

    The model is fitted on the simulated batch labels, then `transform` is
    called with a Series that has the same index but the value "Z" everywhere.
    """
    X, batch = simulate_data()
    fitted = ComBatModel().fit(X, batch=batch)

    unseen_labels = ["Z" for _ in range(len(batch))]
    unseen_batch = pd.Series(unseen_labels, index=batch.index)

    with pytest.raises(ValueError):
        fitted.transform(X, batch=unseen_batch)
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
def test_single_sample_batch_error():
    """
    Check that fitting raises a `ValueError` after one batch label is
    overwritten.

    The first batch label is replaced with "single" before fitting.
    """
    X, batch = simulate_data()
    batch.iloc[0] = "single"

    unfitted = ComBatModel()
    with pytest.raises(ValueError):
        unfitted.fit(X, batch=batch)
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
@pytest.mark.parametrize("method", ["johnson", "fortin", "chen"])
def test_dtypes_preserved(method):
    """Every column of the corrected output must have a floating dtype."""
    covariates = {}
    if method == "johnson":
        X, batch = simulate_data()
    else:  # fortin or chen
        X, batch, disc, cont = simulate_covariate_data()
        covariates = {
            "discrete_covariates": disc,
            "continuous_covariates": cont,
        }

    harmonizer = ComBat(batch=batch, method=method, **covariates)
    corrected = harmonizer.fit_transform(X)

    column_is_float = [
        np.issubdtype(dtype, np.floating) for dtype in corrected.dtypes
    ]
    assert all(column_is_float)
|
|
55
|
-
|
|
56
|
-
def test_wrapper_clone_and_pipeline():
    """
    Check that a `ComBat` step survives `clone` inside a `Pipeline`.

    A scaler + ComBat pipeline and its clone are each fit-transformed on the
    same data; the two outputs must agree within rtol=1e-5 / atol=1e-5.
    """
    X, batch = simulate_data()

    steps = [
        ("scaler", StandardScaler()),
        ("combat", ComBat(batch=batch, parametric=True)),
    ]
    pipeline = Pipeline(steps)
    first_pass = pipeline.fit_transform(X)

    cloned_pipeline: Pipeline = clone(pipeline)
    second_pass = cloned_pipeline.fit_transform(X)

    np.testing.assert_allclose(first_pass, second_pass, rtol=1e-5, atol=1e-5)
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
@pytest.mark.parametrize("method", ["johnson", "fortin", "chen"])
def test_no_nan_or_inf_in_output(method):
    """The corrected output must contain neither NaN nor Inf, for any method."""
    covariates = {}
    if method == "johnson":
        X, batch = simulate_data()
    else:  # fortin or chen
        X, batch, disc, cont = simulate_covariate_data()
        covariates = {
            "discrete_covariates": disc,
            "continuous_covariates": cont,
        }

    harmonizer = ComBat(batch=batch, method=method, **covariates)
    corrected_values = harmonizer.fit_transform(X).values

    assert not np.isnan(corrected_values).any()
    assert not np.isinf(corrected_values).any()
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
@pytest.mark.parametrize("method", ["johnson", "fortin", "chen"])
def test_shape_preserved(method):
    """The output of `transform` must have the same shape as the input."""
    covariates = {}
    if method == "johnson":
        X, batch = simulate_data()
    elif method in ["fortin", "chen"]:
        X, batch, disc, cont = simulate_covariate_data()
        covariates = {
            "discrete_covariates": disc,
            "continuous_covariates": cont,
        }

    # Fit and transform are kept as separate calls on purpose.
    fitted = ComBat(batch=batch, method=method, **covariates).fit(X)
    corrected = fitted.transform(X)
    assert corrected.shape == X.shape
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
def test_johnson_print_warning():
    """
    Check that fitting method="johnson" with covariates emits a warning.

    The warning text must match
    "Covariates are ignored when using method='johnson'.".
    """
    X, batch, disc, cont = simulate_covariate_data()
    settings = {
        "batch": batch,
        "discrete_covariates": disc,
        "continuous_covariates": cont,
        "method": "johnson",
    }

    expect_warning = pytest.warns(
        Warning, match="Covariates are ignored when using method='johnson'."
    )
    # Construction and fit both sit inside the context, so the warning may
    # come from either step.
    with expect_warning:
        _ = ComBat(**settings).fit(X)
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
@pytest.mark.parametrize("method", ["johnson", "fortin", "chen"])
def test_reference_batch_samples_unchanged(method):
    """
    Rows of the reference batch must be left untouched by the correction.

    The first sample's batch label is used as `reference_batch`; after
    fit + transform, the rows carrying that label must equal the input rows
    within an absolute tolerance of 1e-10 (no relative tolerance).
    """
    if method == "johnson":
        X, batch = simulate_data()
        covariates = {}
    elif method in ["fortin", "chen"]:
        X, batch, disc, cont = simulate_covariate_data()
        covariates = {
            "discrete_covariates": disc,
            "continuous_covariates": cont,
        }

    reference_label = batch.iloc[0]
    fitted = ComBat(
        batch=batch,
        method=method,
        reference_batch=reference_label,
        **covariates,
    ).fit(X)
    corrected = fitted.transform(X)

    in_reference = batch == reference_label
    corrected_rows = corrected.loc[in_reference].values
    original_rows = X.loc[in_reference].values
    np.testing.assert_allclose(corrected_rows, original_rows, rtol=0, atol=1e-10)
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
def test_reference_batch_missing_raises():
    """
    Check that reference_batch="DOES_NOT_EXIST" is rejected.

    Building ComBat with that reference batch and fitting it must raise a
    ValueError whose message matches "not present".
    """
    X, batch = simulate_data()
    expect_value_error = pytest.raises(ValueError, match="not present")
    with expect_value_error:
        ComBat(batch=batch, reference_batch="DOES_NOT_EXIST").fit(X)
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|