combatlearn 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- combatlearn-0.1.0/LICENSE +21 -0
- combatlearn-0.1.0/PKG-INFO +132 -0
- combatlearn-0.1.0/README.md +90 -0
- combatlearn-0.1.0/pyproject.toml +35 -0
- combatlearn-0.1.0/setup.cfg +4 -0
- combatlearn-0.1.0/src/combatlearn/__init__.py +4 -0
- combatlearn-0.1.0/src/combatlearn/combat.py +584 -0
- combatlearn-0.1.0/src/combatlearn.egg-info/PKG-INFO +132 -0
- combatlearn-0.1.0/src/combatlearn.egg-info/SOURCES.txt +11 -0
- combatlearn-0.1.0/src/combatlearn.egg-info/dependency_links.txt +1 -0
- combatlearn-0.1.0/src/combatlearn.egg-info/requires.txt +4 -0
- combatlearn-0.1.0/src/combatlearn.egg-info/top_level.txt +1 -0
- combatlearn-0.1.0/tests/test_combat.py +150 -0
combatlearn-0.1.0/LICENSE
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2025 Ettore Rocchi

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
combatlearn-0.1.0/PKG-INFO
@@ -0,0 +1,132 @@
Metadata-Version: 2.4
Name: combatlearn
Version: 0.1.0
Summary: Batch-effect harmonisation for machine learning frameworks.
Author-email: Ettore Rocchi <ettoreroc@gmail.com>
License: MIT License

Copyright (c) 2025 Ettore Rocchi

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

Keywords: machine-learning,harmonization,combat,preprocessing
Classifier: Development Status :: 3 - Alpha
Classifier: Intended Audience :: Science/Research
Classifier: License :: OSI Approved :: MIT License
Classifier: Operating System :: OS Independent
Classifier: Programming Language :: Python :: 3
Requires-Python: >=3.9
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: pandas>=1.3
Requires-Dist: numpy>=1.21
Requires-Dist: scikit-learn>=1.2
Requires-Dist: pytest>=7
Dynamic: license-file

# **combatlearn**

<div align="center">
<p><img src="https://raw.githubusercontent.com/EttoreRocchi/combatlearn/main/docs/logo.png" alt="combatlearn logo" width="350" /></p>
</div>

**combatlearn** makes the popular _ComBat_ (and _CovBat_) batch-effect correction algorithms available for use in machine learning frameworks. It lets you harmonise high-dimensional data inside a scikit-learn `Pipeline`, so that cross-validation and grid search automatically take batch structure into account, **without data leakage**.

**Three methods**:
- `method="johnson"` - classic ComBat (Johnson _et al._, 2007)
- `method="fortin"` - covariate-aware ComBat (Fortin _et al._, 2018)
- `method="chen"` - CovBat (Chen _et al._, 2022)

## Installation

```bash
pip install combatlearn
```

## Quick start

```python
import pandas as pd
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV
from combatlearn import ComBat

df = pd.read_csv("data.csv", index_col=0)
X, y = df.drop(columns="y"), df["y"]

batch = pd.read_csv("batch.csv", index_col=0).squeeze("columns")
diag = pd.read_csv("diagnosis.csv", index_col=0)  # categorical
age = pd.read_csv("age.csv", index_col=0)  # continuous

pipe = Pipeline([
    ("combat", ComBat(
        batch=batch,
        discrete_covariates=diag,
        continuous_covariates=age,
        method="fortin",  # or "johnson" or "chen"
        parametric=True
    )),
    ("scaler", StandardScaler()),
    ("clf", LogisticRegression())
])

param_grid = {
    "combat__mean_only": [True, False],
    "clf__C": [0.01, 0.1, 1, 10],
}

grid = GridSearchCV(
    estimator=pipe,
    param_grid=param_grid,
    cv=5,
    scoring="roc_auc",
)

grid.fit(X, y)

print("Best parameters:", grid.best_params_)
print(f"Best CV AUROC: {grid.best_score_:.3f}")
```

For a full example of how to use **combatlearn**, see the [notebook demo](https://github.com/EttoreRocchi/combatlearn/blob/main/demo/combatlearn_demo.ipynb).

## Contributing

Pull requests, bug reports, and feature ideas are welcome: feel free to open a PR!

## Acknowledgements

This project builds on the excellent work of the ComBat family of harmonisation methods.
We gratefully acknowledge:

- [**ComBat**](https://rdrr.io/bioc/sva/man/ComBat.html)
- [**neuroCombat**](https://github.com/Jfortin1/neuroCombat)
- [**CovBat**](https://github.com/andy1764/CovBat_Harmonization)

## Citation

If **combatlearn** is useful in your research, please cite the original
papers:

- Johnson WE, Li C, Rabinovic A. Adjusting batch effects in microarray expression data using empirical Bayes methods. _Biostatistics_. 2007 Jan;8(1):118-27. doi: [10.1093/biostatistics/kxj037](https://doi.org/10.1093/biostatistics/kxj037)

- Fortin JP, Cullen N, Sheline YI, Taylor WD, Aselcioglu I, Cook PA, Adams P, Cooper C, Fava M, McGrath PJ, McInnis M, Phillips ML, Trivedi MH, Weissman MM, Shinohara RT. Harmonization of cortical thickness measurements across scanners and sites. _Neuroimage_. 2018 Feb 15;167:104-120. doi: [10.1016/j.neuroimage.2017.11.024](https://doi.org/10.1016/j.neuroimage.2017.11.024)

- Chen AA, Beer JC, Tustison NJ, Cook PA, Shinohara RT, Shou H; Alzheimer's Disease Neuroimaging Initiative. Mitigating site effects in covariance for machine learning in neuroimaging data. _Hum Brain Mapp_. 2022 Mar;43(4):1179-1195. doi: [10.1002/hbm.25688](https://doi.org/10.1002/hbm.25688)
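The quick start above drives `ComBat` through a scikit-learn `Pipeline`. The underlying `ComBatModel` (see `src/combatlearn/combat.py` further down) can also be used on its own, passing `batch` as a keyword to `fit` and `transform`. A minimal sketch on synthetic data, assuming only the API visible in this diff; the feature and site names are invented for illustration:

```python
import numpy as np
import pandas as pd

from combatlearn import ComBatModel

rng = np.random.default_rng(0)

# Two synthetic acquisition sites; the second gets an additive batch shift.
X = pd.DataFrame(rng.normal(size=(40, 5)), columns=[f"f{i}" for i in range(5)])
X.iloc[20:] += 1.5
batch = pd.Series(["siteA"] * 20 + ["siteB"] * 20, index=X.index)

model = ComBatModel(method="johnson", parametric=True)
model.fit(X, batch=batch)
X_harm = model.transform(X, batch=batch)

# Per-site means should be far closer after harmonisation.
print(X.groupby(batch).mean().round(2))
print(X_harm.groupby(batch).mean().round(2))
```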
combatlearn-0.1.0/README.md
@@ -0,0 +1,90 @@
# **combatlearn**

<div align="center">
<p><img src="https://raw.githubusercontent.com/EttoreRocchi/combatlearn/main/docs/logo.png" alt="combatlearn logo" width="350" /></p>
</div>

**combatlearn** makes the popular _ComBat_ (and _CovBat_) batch-effect correction algorithms available for use in machine learning frameworks. It lets you harmonise high-dimensional data inside a scikit-learn `Pipeline`, so that cross-validation and grid search automatically take batch structure into account, **without data leakage**.

**Three methods**:
- `method="johnson"` - classic ComBat (Johnson _et al._, 2007)
- `method="fortin"` - covariate-aware ComBat (Fortin _et al._, 2018)
- `method="chen"` - CovBat (Chen _et al._, 2022)

## Installation

```bash
pip install combatlearn
```

## Quick start

```python
import pandas as pd
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV
from combatlearn import ComBat

df = pd.read_csv("data.csv", index_col=0)
X, y = df.drop(columns="y"), df["y"]

batch = pd.read_csv("batch.csv", index_col=0).squeeze("columns")
diag = pd.read_csv("diagnosis.csv", index_col=0)  # categorical
age = pd.read_csv("age.csv", index_col=0)  # continuous

pipe = Pipeline([
    ("combat", ComBat(
        batch=batch,
        discrete_covariates=diag,
        continuous_covariates=age,
        method="fortin",  # or "johnson" or "chen"
        parametric=True
    )),
    ("scaler", StandardScaler()),
    ("clf", LogisticRegression())
])

param_grid = {
    "combat__mean_only": [True, False],
    "clf__C": [0.01, 0.1, 1, 10],
}

grid = GridSearchCV(
    estimator=pipe,
    param_grid=param_grid,
    cv=5,
    scoring="roc_auc",
)

grid.fit(X, y)

print("Best parameters:", grid.best_params_)
print(f"Best CV AUROC: {grid.best_score_:.3f}")
```

For a full example of how to use **combatlearn**, see the [notebook demo](https://github.com/EttoreRocchi/combatlearn/blob/main/demo/combatlearn_demo.ipynb).

## Contributing

Pull requests, bug reports, and feature ideas are welcome: feel free to open a PR!

## Acknowledgements

This project builds on the excellent work of the ComBat family of harmonisation methods.
We gratefully acknowledge:

- [**ComBat**](https://rdrr.io/bioc/sva/man/ComBat.html)
- [**neuroCombat**](https://github.com/Jfortin1/neuroCombat)
- [**CovBat**](https://github.com/andy1764/CovBat_Harmonization)

## Citation

If **combatlearn** is useful in your research, please cite the original
papers:

- Johnson WE, Li C, Rabinovic A. Adjusting batch effects in microarray expression data using empirical Bayes methods. _Biostatistics_. 2007 Jan;8(1):118-27. doi: [10.1093/biostatistics/kxj037](https://doi.org/10.1093/biostatistics/kxj037)

- Fortin JP, Cullen N, Sheline YI, Taylor WD, Aselcioglu I, Cook PA, Adams P, Cooper C, Fava M, McGrath PJ, McInnis M, Phillips ML, Trivedi MH, Weissman MM, Shinohara RT. Harmonization of cortical thickness measurements across scanners and sites. _Neuroimage_. 2018 Feb 15;167:104-120. doi: [10.1016/j.neuroimage.2017.11.024](https://doi.org/10.1016/j.neuroimage.2017.11.024)

- Chen AA, Beer JC, Tustison NJ, Cook PA, Shinohara RT, Shou H; Alzheimer's Disease Neuroimaging Initiative. Mitigating site effects in covariance for machine learning in neuroimaging data. _Hum Brain Mapp_. 2022 Mar;43(4):1179-1195. doi: [10.1002/hbm.25688](https://doi.org/10.1002/hbm.25688)
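A note on the leakage claim in the README: the `ComBat` wrapper stores the full `batch` vector at construction and, at each `fit`/`transform` call, subsets it to the rows of the incoming `X` (see `_subset` in `combat.py` below), so each cross-validation training fold re-estimates the batch parameters from its own rows only. A self-contained sketch with invented synthetic data:

```python
import numpy as np
import pandas as pd
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score
from sklearn.pipeline import Pipeline

from combatlearn import ComBat

rng = np.random.default_rng(42)
X = pd.DataFrame(rng.normal(size=(60, 8)))
y = pd.Series(rng.integers(0, 2, size=60), index=X.index)
batch = pd.Series(rng.choice(["A", "B", "C"], size=60), index=X.index)

pipe = Pipeline([
    ("combat", ComBat(batch=batch, method="johnson")),
    ("clf", LogisticRegression()),
])

# Each fold sees only its own rows: the wrapper subsets `batch` by the
# fold's index, so ComBat is re-fitted per training fold without leakage.
scores = cross_val_score(pipe, X, y, cv=3, scoring="roc_auc")
print(scores.round(3))
```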
combatlearn-0.1.0/pyproject.toml
@@ -0,0 +1,35 @@
[build-system]
requires = ["setuptools>=61", "wheel"]
build-backend = "setuptools.build_meta"

[project]
name = "combatlearn"
version = "0.1.0"
description = "Batch-effect harmonisation for machine learning frameworks."
authors = [{name="Ettore Rocchi", email="ettoreroc@gmail.com"}]
requires-python = ">=3.9"
dependencies = [
    "pandas>=1.3",
    "numpy>=1.21",
    "scikit-learn>=1.2",
    "pytest>=7"
]
license = {file="LICENSE"}
readme = {file="README.md", content-type="text/markdown"}
keywords = [
    "machine-learning",
    "harmonization",
    "combat",
    "preprocessing",
]
classifiers = [
    "Development Status :: 3 - Alpha",
    "Intended Audience :: Science/Research",
    "License :: OSI Approved :: MIT License",
    "Operating System :: OS Independent",
    "Programming Language :: Python :: 3",
]

[tool.setuptools.packages.find]
where = ["src"]
include = ["combatlearn*"]
combatlearn-0.1.0/src/combatlearn/combat.py
@@ -0,0 +1,584 @@
"""ComBat algorithm.

`ComBatModel` implements three variants:
* Johnson et al. (2007) vanilla ComBat (method="johnson")
* Fortin et al. (2018) extension with covariates (method="fortin")
* Chen et al. (2022) CovBat (method="chen")

`ComBat` makes the model compatible with scikit-learn by stashing
the batch (and optional covariates) at construction.
"""

from __future__ import annotations  # `pd.DataFrame | None` annotations on Python 3.9

__author__ = "Ettore Rocchi"

import numpy as np
import numpy.linalg as la
import pandas as pd
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.utils.validation import check_is_fitted
from sklearn.decomposition import PCA
from typing import Literal
import warnings


class ComBatModel:
    """ComBat algorithm.

    Parameters
    ----------
    method : {'johnson', 'fortin', 'chen'}, default='johnson'
        * 'johnson' - classic ComBat.
        * 'fortin' - covariate-aware ComBat.
        * 'chen' - CovBat, PCA-based ComBat.
    parametric : bool, default=True
        Use the parametric empirical Bayes variant.
    mean_only : bool, default=False
        If True, only the mean is adjusted (`gamma_star`),
        ignoring the variance (`delta_star`).
    reference_batch : str, optional
        If specified, the batch level to use as reference.
    covbat_cov_thresh : float, default=0.9
        CovBat: cumulative explained variance threshold for PCA.
    eps : float, default=1e-8
        Numerical jitter to avoid division by zero.
    """

    def __init__(
        self,
        *,
        method: Literal["johnson", "fortin", "chen"] = "johnson",
        parametric: bool = True,
        mean_only: bool = False,
        reference_batch=None,
        eps: float = 1e-8,
        covbat_cov_thresh: float = 0.9,
    ) -> None:
        self.method = method
        self.parametric = parametric
        self.mean_only = bool(mean_only)
        self.reference_batch = reference_batch
        self.eps = float(eps)
        self.covbat_cov_thresh = float(covbat_cov_thresh)
        if not (0.0 < self.covbat_cov_thresh <= 1.0):
            raise ValueError("covbat_cov_thresh must be in (0, 1].")

    @staticmethod
    def _as_series(arr, index, name):
        if isinstance(arr, pd.Series):
            ser = arr.copy()
        else:
            ser = pd.Series(arr, index=index, name=name)
        if not ser.index.equals(index):
            raise ValueError(f"`{name}` index mismatch with `X`.")
        return ser.astype("category")

    @staticmethod
    def _to_df(arr, index, name):
        if arr is None:
            return None
        if isinstance(arr, pd.Series):
            arr = arr.to_frame()
        if not isinstance(arr, pd.DataFrame):
            arr = pd.DataFrame(arr, index=index)
        if not arr.index.equals(index):
            raise ValueError(f"`{name}` index mismatch with `X`.")
        return arr

    def fit(
        self,
        X,
        y=None,
        *,
        batch,
        discrete_covariates=None,
        continuous_covariates=None,
    ):
        method = self.method.lower()
        if method not in {"johnson", "fortin", "chen"}:
            raise ValueError("method must be 'johnson', 'fortin', or 'chen'.")
        if not isinstance(X, pd.DataFrame):
            X = pd.DataFrame(X)
        idx = X.index
        batch = self._as_series(batch, idx, "batch")

        disc = self._to_df(discrete_covariates, idx, "discrete_covariates")
        cont = self._to_df(continuous_covariates, idx, "continuous_covariates")

        if self.reference_batch is not None and self.reference_batch not in batch.cat.categories:
            raise ValueError(
                f"reference_batch={self.reference_batch!r} not present in the data batches "
                f"{list(batch.cat.categories)}"
            )

        if method == "johnson":
            if disc is not None or cont is not None:
                warnings.warn(
                    "Covariates are ignored when using method='johnson'."
                )
            self._fit_johnson(X, batch)
        elif method == "fortin":
            self._fit_fortin(X, batch, disc, cont)
        elif method == "chen":
            self._fit_chen(X, batch, disc, cont)
        return self

    def _fit_johnson(
        self,
        X: pd.DataFrame,
        batch: pd.Series
    ):
        """
        Johnson et al. (2007) ComBat.
        """
        self._batch_levels = batch.cat.categories
        pooled_var = X.var(axis=0, ddof=1) + self.eps
        grand_mean = X.mean(axis=0)

        Xs = (X - grand_mean) / np.sqrt(pooled_var)

        n_per_batch: dict[str, int] = {}
        gamma_hat, delta_hat = [], []
        for lvl in self._batch_levels:
            idx = batch == lvl
            n_b = idx.sum()
            if n_b < 2:
                raise ValueError(f"Batch '{lvl}' has <2 samples.")
            n_per_batch[lvl] = n_b
            xb = Xs.loc[idx]
            gamma_hat.append(xb.mean(axis=0).values)
            delta_hat.append(xb.var(axis=0, ddof=1).values + self.eps)
        gamma_hat = np.vstack(gamma_hat)
        delta_hat = np.vstack(delta_hat)

        if self.mean_only:
            gamma_star = self._shrink_gamma(
                gamma_hat, delta_hat, n_per_batch, parametric=self.parametric
            )
            delta_star = np.ones_like(delta_hat)
        else:
            gamma_star, delta_star = self._shrink_gamma_delta(
                gamma_hat, delta_hat, n_per_batch, parametric=self.parametric
            )

        if self.reference_batch is not None:
            ref_idx = list(self._batch_levels).index(self.reference_batch)
            gamma_ref = gamma_star[ref_idx]
            delta_ref = delta_star[ref_idx]
            gamma_star = gamma_star - gamma_ref
            if not self.mean_only:
                delta_star = delta_star / delta_ref
            self._reference_batch_idx = ref_idx
        else:
            self._reference_batch_idx = None

        self._grand_mean = grand_mean
        self._pooled_var = pooled_var
        self._gamma_star = gamma_star
        self._delta_star = delta_star
        self._n_per_batch = n_per_batch

    def _fit_fortin(
        self,
        X: pd.DataFrame,
        batch: pd.Series,
        disc: pd.DataFrame | None,
        cont: pd.DataFrame | None,
    ):
        """
        Fortin et al. (2018) ComBat.
        """
        batch_levels = batch.cat.categories
        n_batch = len(batch_levels)
        n_samples = len(X)

        batch_dummies = pd.get_dummies(batch, drop_first=False)
        parts = [batch_dummies]
        if disc is not None:
            parts.append(pd.get_dummies(disc.astype("category"), drop_first=True))
        if cont is not None:
            parts.append(cont)
        design = pd.concat(parts, axis=1).astype(float).values
        p_design = design.shape[1]

        X_np = X.values
        beta_hat = la.inv(design.T @ design) @ design.T @ X_np

        gamma_hat = beta_hat[:n_batch]
        self._beta_hat_nonbatch = beta_hat[n_batch:]

        n_per_batch = batch.value_counts().sort_index().values
        self._n_per_batch = dict(zip(batch_levels, n_per_batch))

        grand_mean = (n_per_batch / n_samples) @ gamma_hat
        self._grand_mean = grand_mean

        resid = X_np - design @ beta_hat
        var_pooled = (resid ** 2).sum(axis=0) / (n_samples - p_design) + self.eps
        self._pooled_var = var_pooled

        stand_mean = grand_mean + design[:, n_batch:] @ self._beta_hat_nonbatch
        Xs = (X_np - stand_mean) / np.sqrt(var_pooled)

        delta_hat = np.empty_like(gamma_hat)
        for i, lvl in enumerate(batch_levels):
            idx = batch == lvl
            delta_hat[i] = Xs[idx].var(axis=0, ddof=1) + self.eps

        if self.mean_only:
            gamma_star = self._shrink_gamma(
                gamma_hat, delta_hat, n_per_batch, parametric=self.parametric
            )
            delta_star = np.ones_like(delta_hat)
        else:
            gamma_star, delta_star = self._shrink_gamma_delta(
                gamma_hat, delta_hat, n_per_batch, parametric=self.parametric
            )

        if self.reference_batch is not None:
            ref_idx = list(batch_levels).index(self.reference_batch)
            gamma_ref = gamma_star[ref_idx]
            delta_ref = delta_star[ref_idx]
            gamma_star = gamma_star - gamma_ref
            if not self.mean_only:
                delta_star = delta_star / delta_ref
            self._reference_batch_idx = ref_idx
        else:
            self._reference_batch_idx = None

        self._batch_levels = batch_levels
        self._gamma_star = gamma_star
        self._delta_star = delta_star
        self._n_batch = n_batch
        self._p_design = p_design

    def _fit_chen(
        self,
        X: pd.DataFrame,
        batch: pd.Series,
        disc: pd.DataFrame | None,
        cont: pd.DataFrame | None,
    ):
        self._fit_fortin(X, batch, disc, cont)
        X_meanvar_adj = self._transform_fortin(X, batch, disc, cont)
        X_centered = X_meanvar_adj - X_meanvar_adj.mean(axis=0)
        pca = PCA(svd_solver="full", whiten=False).fit(X_centered)
        cumulative = np.cumsum(pca.explained_variance_ratio_)
        n_pc = int(np.searchsorted(cumulative, self.covbat_cov_thresh) + 1)
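        # `n_pc` is the smallest number of leading principal components whose
        # cumulative explained variance reaches `covbat_cov_thresh`; only these
        # scores are ComBat-adjusted, trailing components pass through unchanged.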
        self._covbat_pca = pca
        self._covbat_n_pc = n_pc

        scores = pca.transform(X_centered)[:, :n_pc]
        scores_df = pd.DataFrame(scores, index=X.index, columns=[f"PC{i+1}" for i in range(n_pc)])
        self._batch_levels_pc = self._batch_levels
        n_per_batch = self._n_per_batch

        gamma_hat, delta_hat = [], []
        for lvl in self._batch_levels_pc:
            idx = batch == lvl
            xb = scores_df.loc[idx]
            gamma_hat.append(xb.mean(axis=0).values)
            delta_hat.append(xb.var(axis=0, ddof=1).values + self.eps)
        gamma_hat = np.vstack(gamma_hat)
        delta_hat = np.vstack(delta_hat)

        if self.mean_only:
            gamma_star = self._shrink_gamma(
                gamma_hat, delta_hat, n_per_batch, parametric=self.parametric
            )
            delta_star = np.ones_like(delta_hat)
        else:
            gamma_star, delta_star = self._shrink_gamma_delta(
                gamma_hat, delta_hat, n_per_batch, parametric=self.parametric
            )

        if self.reference_batch is not None:
            ref_idx = list(self._batch_levels_pc).index(self.reference_batch)
            gamma_ref = gamma_star[ref_idx]
            delta_ref = delta_star[ref_idx]
            gamma_star = gamma_star - gamma_ref
            if not self.mean_only:
                delta_star = delta_star / delta_ref

        self._pc_gamma_star = gamma_star
        self._pc_delta_star = delta_star

    def _shrink_gamma_delta(
        self,
        gamma_hat: np.ndarray,
        delta_hat: np.ndarray,
        n_per_batch: dict | np.ndarray,
        *,
        parametric: bool,
        max_iter: int = 100,
        tol: float = 1e-4,
    ):
        if parametric:
            gamma_bar = gamma_hat.mean(axis=0)
            t2 = gamma_hat.var(axis=0, ddof=1)
            a_prior = (delta_hat.mean(axis=0) ** 2) / delta_hat.var(axis=0, ddof=1) + 2
            b_prior = delta_hat.mean(axis=0) * (a_prior - 1)
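            # Moment-matched priors: a normal prior on the batch means and an
            # inverse-gamma prior on the batch variances. The loop below takes
            # the posterior means in closed form:
            #   gamma_star_i = (n_i * g / d + gamma_bar / t2) / (n_i / d + 1 / t2)
            #   delta_star_i = b_post / (a_post - 1),  a_post = a_prior + n_i / 2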
            B, p = gamma_hat.shape
            gamma_star = np.empty_like(gamma_hat)
            delta_star = np.empty_like(delta_hat)
            n_vec = np.array(list(n_per_batch.values())) if isinstance(n_per_batch, dict) else n_per_batch
            for i in range(B):
                n_i = n_vec[i]
                g, d = gamma_hat[i], delta_hat[i]
                gamma_post_var = 1.0 / (n_i / d + 1.0 / t2)
                gamma_star[i] = gamma_post_var * (n_i * g / d + gamma_bar / t2)

                a_post = a_prior + n_i / 2.0
                b_post = b_prior + 0.5 * n_i * d
                delta_star[i] = b_post / (a_post - 1)
            return gamma_star, delta_star

        else:
            B, p = gamma_hat.shape
            n_vec = np.array(list(n_per_batch.values())) if isinstance(n_per_batch, dict) else n_per_batch
            gamma_bar = gamma_hat.mean(axis=0)
            t2 = gamma_hat.var(axis=0, ddof=1)

            def postmean(g_hat, g_bar, n, d_star, t2_):
                return (t2_ * n * g_hat + d_star * g_bar) / (t2_ * n + d_star)

            def postvar(sum2, n, a, b):
                return (0.5 * sum2 + b) / (n / 2.0 + a - 1.0)

            def aprior(delta):
                m, s2 = delta.mean(), delta.var()
                s2 = max(s2, self.eps)
                return (2 * s2 + m ** 2) / s2

            def bprior(delta):
                m, s2 = delta.mean(), delta.var()
                s2 = max(s2, self.eps)
                return (m * s2 + m ** 3) / s2

            gamma_star = np.empty_like(gamma_hat)
            delta_star = np.empty_like(delta_hat)

            for i in range(B):
                n_i = n_vec[i]
                g_hat_i = gamma_hat[i]
                d_hat_i = delta_hat[i]
                a_i = aprior(d_hat_i)
                b_i = bprior(d_hat_i)

                g_new, d_new = g_hat_i.copy(), d_hat_i.copy()
                for _ in range(max_iter):
                    g_prev, d_prev = g_new, d_new
                    g_new = postmean(g_hat_i, gamma_bar, n_i, d_prev, t2)
                    sum2 = (n_i - 1) * d_hat_i + n_i * (g_hat_i - g_new) ** 2
                    d_new = postvar(sum2, n_i, a_i, b_i)
                    if np.max(np.abs(g_new - g_prev) / (np.abs(g_prev) + self.eps)) < tol and (
                        self.mean_only or np.max(np.abs(d_new - d_prev) / (np.abs(d_prev) + self.eps)) < tol
                    ):
                        break
                gamma_star[i] = g_new
                delta_star[i] = 1.0 if self.mean_only else d_new
            return gamma_star, delta_star

    def _shrink_gamma(
        self,
        gamma_hat: np.ndarray,
        delta_hat: np.ndarray,
        n_per_batch: dict | np.ndarray,
        *,
        parametric: bool,
    ) -> np.ndarray:
        """Convenience wrapper that returns only `gamma_star` (for *mean-only* mode)."""
        gamma, _ = self._shrink_gamma_delta(gamma_hat, delta_hat, n_per_batch, parametric=parametric)
        return gamma

    def transform(
        self,
        X,
        *,
        batch,
        discrete_covariates=None,
        continuous_covariates=None,
    ):
        check_is_fitted(self, ["_gamma_star"])
        if not isinstance(X, pd.DataFrame):
            X = pd.DataFrame(X)
        idx = X.index
        batch = self._as_series(batch, idx, "batch")
        unseen = set(batch.cat.categories) - set(self._batch_levels)
        if unseen:
            raise ValueError(f"Unseen batch levels during transform: {unseen}")
        disc = self._to_df(discrete_covariates, idx, "discrete_covariates")
        cont = self._to_df(continuous_covariates, idx, "continuous_covariates")

        method = self.method.lower()
        if method == "johnson":
            return self._transform_johnson(X, batch)
        elif method == "fortin":
            return self._transform_fortin(X, batch, disc, cont)
        elif method == "chen":
            return self._transform_chen(X, batch, disc, cont)

    def _transform_johnson(self, X: pd.DataFrame, batch: pd.Series):
        pooled = self._pooled_var
        grand = self._grand_mean

        Xs = (X - grand) / np.sqrt(pooled)
        X_adj = pd.DataFrame(index=X.index, columns=X.columns, dtype=float)

        for i, lvl in enumerate(self._batch_levels):
            idx = batch == lvl
            if not idx.any():
                continue
            if self.reference_batch is not None and lvl == self.reference_batch:
                X_adj.loc[idx] = X.loc[idx].values  # untouched
                continue

            g = self._gamma_star[i]
            d = self._delta_star[i]
            if self.mean_only:
                Xb = Xs.loc[idx] - g
            else:
                Xb = (Xs.loc[idx] - g) / np.sqrt(d)
            X_adj.loc[idx] = (Xb * np.sqrt(pooled) + grand).values
        return X_adj

    def _transform_fortin(
        self,
        X: pd.DataFrame,
        batch: pd.Series,
        disc: pd.DataFrame | None,
        cont: pd.DataFrame | None,
    ):
        batch_dummies = pd.get_dummies(batch, drop_first=False)[self._batch_levels]
        parts = [batch_dummies]
        if disc is not None:
            parts.append(pd.get_dummies(disc.astype("category"), drop_first=True))
        if cont is not None:
            parts.append(cont)

        design = pd.concat(parts, axis=1).astype(float).values

        X_np = X.values
        stand_mean = self._grand_mean + design[:, self._n_batch:] @ self._beta_hat_nonbatch
        Xs = (X_np - stand_mean) / np.sqrt(self._pooled_var)

        for i, lvl in enumerate(self._batch_levels):
            idx = batch == lvl
            if not idx.any():
                continue
            if self.reference_batch is not None and lvl == self.reference_batch:
                # leave reference samples unchanged
                continue

            g = self._gamma_star[i]
            d = self._delta_star[i]
            if self.mean_only:
                Xs[idx] = Xs[idx] - g
            else:
                Xs[idx] = (Xs[idx] - g) / np.sqrt(d)

        X_adj = Xs * np.sqrt(self._pooled_var) + stand_mean
        return pd.DataFrame(X_adj, index=X.index, columns=X.columns)

    def _transform_chen(
        self,
        X: pd.DataFrame,
        batch: pd.Series,
        disc: pd.DataFrame | None,
        cont: pd.DataFrame | None,
    ):
        X_meanvar_adj = self._transform_fortin(X, batch, disc, cont)
        X_centered = X_meanvar_adj - self._covbat_pca.mean_
        scores = self._covbat_pca.transform(X_centered)
        n_pc = self._covbat_n_pc
        scores_adj = scores.copy()

        for i, lvl in enumerate(self._batch_levels_pc):
            idx = batch == lvl
            if not idx.any():
                continue
            if self.reference_batch is not None and lvl == self.reference_batch:
                continue
            g = self._pc_gamma_star[i]
            d = self._pc_delta_star[i]
            if self.mean_only:
                scores_adj[idx, :n_pc] = scores_adj[idx, :n_pc] - g
            else:
                scores_adj[idx, :n_pc] = (scores_adj[idx, :n_pc] - g) / np.sqrt(d)

        X_recon = self._covbat_pca.inverse_transform(scores_adj) + self._covbat_pca.mean_
        return pd.DataFrame(X_recon, index=X.index, columns=X.columns)


class ComBat(BaseEstimator, TransformerMixin):
    """Pipeline-friendly wrapper around `ComBatModel`.

    Stores the batch (and optional covariates) passed at construction and
    subsets them by the rows of `X`, so the same wrapper works across
    separate `fit` and `transform` calls.
    """

    def __init__(
        self,
        batch,
        *,
        discrete_covariates=None,
        continuous_covariates=None,
        method: str = "johnson",
        parametric: bool = True,
        mean_only: bool = False,
        reference_batch=None,
        eps: float = 1e-8,
        covbat_cov_thresh: float = 0.9,
    ) -> None:
        self.batch = batch
        self.discrete_covariates = discrete_covariates
        self.continuous_covariates = continuous_covariates
        self.method = method
        self.parametric = parametric
        self.mean_only = mean_only
        self.reference_batch = reference_batch
        self.eps = eps
        self.covbat_cov_thresh = covbat_cov_thresh

    def fit(self, X, y=None):
        # Build the underlying model here rather than in `__init__`, so that
        # `set_params` (and hence grid-search over e.g. `combat__mean_only`)
        # takes effect on the model actually fitted.
        self._model = ComBatModel(
            method=self.method,
            parametric=self.parametric,
            mean_only=self.mean_only,
            reference_batch=self.reference_batch,
            eps=self.eps,
            covbat_cov_thresh=self.covbat_cov_thresh,
        )
        idx = X.index if isinstance(X, pd.DataFrame) else np.arange(len(X))
        batch_vec = self._subset(self.batch, idx)
        disc = self._subset(self.discrete_covariates, idx)
        cont = self._subset(self.continuous_covariates, idx)
        self._model.fit(
            X,
            batch=batch_vec,
            discrete_covariates=disc,
            continuous_covariates=cont,
        )
        return self

    def transform(self, X):
        check_is_fitted(self, "_model")
        idx = X.index if isinstance(X, pd.DataFrame) else np.arange(len(X))
        batch_vec = self._subset(self.batch, idx)
        disc = self._subset(self.discrete_covariates, idx)
        cont = self._subset(self.continuous_covariates, idx)
        return self._model.transform(
            X,
            batch=batch_vec,
            discrete_covariates=disc,
            continuous_covariates=cont,
        )

    @staticmethod
    def _subset(obj, idx):
        if obj is None:
            return None
        if isinstance(obj, (pd.Series, pd.DataFrame)):
            return obj.loc[idx]
        else:
            return pd.DataFrame(obj).iloc[idx]
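A short illustration of the `reference_batch` behaviour implemented above: reference-batch samples pass through untouched, while the other batches are mapped onto the reference location and scale (the test suite further down asserts exactly this). The data here are invented for the sketch:

```python
import numpy as np
import pandas as pd

from combatlearn import ComBat

rng = np.random.default_rng(1)

X = pd.DataFrame(rng.normal(size=(60, 4)))
X.iloc[30:] *= 2.0  # scale distortion on the non-reference batch
batch = pd.Series(["ref"] * 30 + ["other"] * 30, index=X.index)

combat = ComBat(batch=batch, method="johnson", reference_batch="ref")
X_harm = combat.fit(X).transform(X)

# Reference-batch rows come back numerically identical...
assert np.allclose(X_harm.loc[batch == "ref"], X.loc[batch == "ref"])
# ...while the other batch is pulled toward the reference scale.
print(X.groupby(batch).std().round(2))
print(X_harm.groupby(batch).std().round(2))
```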
combatlearn-0.1.0/src/combatlearn.egg-info/PKG-INFO
@@ -0,0 +1,132 @@
Metadata-Version: 2.4
Name: combatlearn
Version: 0.1.0
Summary: Batch-effect harmonisation for machine learning frameworks.
Author-email: Ettore Rocchi <ettoreroc@gmail.com>
License: MIT License

Copyright (c) 2025 Ettore Rocchi

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

Keywords: machine-learning,harmonization,combat,preprocessing
Classifier: Development Status :: 3 - Alpha
Classifier: Intended Audience :: Science/Research
Classifier: License :: OSI Approved :: MIT License
Classifier: Operating System :: OS Independent
Classifier: Programming Language :: Python :: 3
Requires-Python: >=3.9
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: pandas>=1.3
Requires-Dist: numpy>=1.21
Requires-Dist: scikit-learn>=1.2
Requires-Dist: pytest>=7
Dynamic: license-file

# **combatlearn**

<div align="center">
<p><img src="https://raw.githubusercontent.com/EttoreRocchi/combatlearn/main/docs/logo.png" alt="combatlearn logo" width="350" /></p>
</div>

**combatlearn** makes the popular _ComBat_ (and _CovBat_) batch-effect correction algorithms available for use in machine learning frameworks. It lets you harmonise high-dimensional data inside a scikit-learn `Pipeline`, so that cross-validation and grid search automatically take batch structure into account, **without data leakage**.

**Three methods**:
- `method="johnson"` - classic ComBat (Johnson _et al._, 2007)
- `method="fortin"` - covariate-aware ComBat (Fortin _et al._, 2018)
- `method="chen"` - CovBat (Chen _et al._, 2022)

## Installation

```bash
pip install combatlearn
```

## Quick start

```python
import pandas as pd
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV
from combatlearn import ComBat

df = pd.read_csv("data.csv", index_col=0)
X, y = df.drop(columns="y"), df["y"]

batch = pd.read_csv("batch.csv", index_col=0).squeeze("columns")
diag = pd.read_csv("diagnosis.csv", index_col=0)  # categorical
age = pd.read_csv("age.csv", index_col=0)  # continuous

pipe = Pipeline([
    ("combat", ComBat(
        batch=batch,
        discrete_covariates=diag,
        continuous_covariates=age,
        method="fortin",  # or "johnson" or "chen"
        parametric=True
    )),
    ("scaler", StandardScaler()),
    ("clf", LogisticRegression())
])

param_grid = {
    "combat__mean_only": [True, False],
    "clf__C": [0.01, 0.1, 1, 10],
}

grid = GridSearchCV(
    estimator=pipe,
    param_grid=param_grid,
    cv=5,
    scoring="roc_auc",
)

grid.fit(X, y)

print("Best parameters:", grid.best_params_)
print(f"Best CV AUROC: {grid.best_score_:.3f}")
```

For a full example of how to use **combatlearn**, see the [notebook demo](https://github.com/EttoreRocchi/combatlearn/blob/main/demo/combatlearn_demo.ipynb).

## Contributing

Pull requests, bug reports, and feature ideas are welcome: feel free to open a PR!

## Acknowledgements

This project builds on the excellent work of the ComBat family of harmonisation methods.
We gratefully acknowledge:

- [**ComBat**](https://rdrr.io/bioc/sva/man/ComBat.html)
- [**neuroCombat**](https://github.com/Jfortin1/neuroCombat)
- [**CovBat**](https://github.com/andy1764/CovBat_Harmonization)

## Citation

If **combatlearn** is useful in your research, please cite the original
papers:

- Johnson WE, Li C, Rabinovic A. Adjusting batch effects in microarray expression data using empirical Bayes methods. _Biostatistics_. 2007 Jan;8(1):118-27. doi: [10.1093/biostatistics/kxj037](https://doi.org/10.1093/biostatistics/kxj037)

- Fortin JP, Cullen N, Sheline YI, Taylor WD, Aselcioglu I, Cook PA, Adams P, Cooper C, Fava M, McGrath PJ, McInnis M, Phillips ML, Trivedi MH, Weissman MM, Shinohara RT. Harmonization of cortical thickness measurements across scanners and sites. _Neuroimage_. 2018 Feb 15;167:104-120. doi: [10.1016/j.neuroimage.2017.11.024](https://doi.org/10.1016/j.neuroimage.2017.11.024)

- Chen AA, Beer JC, Tustison NJ, Cook PA, Shinohara RT, Shou H; Alzheimer's Disease Neuroimaging Initiative. Mitigating site effects in covariance for machine learning in neuroimaging data. _Hum Brain Mapp_. 2022 Mar;43(4):1179-1195. doi: [10.1002/hbm.25688](https://doi.org/10.1002/hbm.25688)
combatlearn-0.1.0/src/combatlearn.egg-info/SOURCES.txt
@@ -0,0 +1,11 @@
LICENSE
README.md
pyproject.toml
src/combatlearn/__init__.py
src/combatlearn/combat.py
src/combatlearn.egg-info/PKG-INFO
src/combatlearn.egg-info/SOURCES.txt
src/combatlearn.egg-info/dependency_links.txt
src/combatlearn.egg-info/requires.txt
src/combatlearn.egg-info/top_level.txt
tests/test_combat.py
combatlearn-0.1.0/src/combatlearn.egg-info/dependency_links.txt
@@ -0,0 +1 @@

combatlearn-0.1.0/src/combatlearn.egg-info/top_level.txt
@@ -0,0 +1 @@
combatlearn
combatlearn-0.1.0/tests/test_combat.py
@@ -0,0 +1,150 @@
import numpy as np
import pandas as pd
import pytest
from sklearn.pipeline import Pipeline
from sklearn.base import clone
from sklearn.preprocessing import StandardScaler
from sklearn.exceptions import NotFittedError
from combatlearn import ComBatModel, ComBat
from utils import simulate_data, simulate_covariate_data


def test_transform_without_fit_raises():
    """
    Test that `transform` raises a `NotFittedError` if not fitted.
    """
    X, batch = simulate_data()
    model = ComBatModel()
    with pytest.raises(NotFittedError):
        model.transform(X, batch=batch)


def test_unseen_batch_raises_value_error():
    """
    Test that an unseen batch level raises a `ValueError`.
    """
    X, batch = simulate_data()
    model = ComBatModel().fit(X, batch=batch)
    new_batch = pd.Series(["Z"] * len(batch), index=batch.index)
    with pytest.raises(ValueError):
        model.transform(X, batch=new_batch)


def test_single_sample_batch_error():
    """
    Test that a batch with a single sample raises a `ValueError`.
    """
    X, batch = simulate_data()
    batch.iloc[0] = "single"
    with pytest.raises(ValueError):
        ComBatModel().fit(X, batch=batch)


@pytest.mark.parametrize("method", ["johnson", "fortin", "chen"])
def test_dtypes_preserved(method):
    """All output columns must remain floating dtypes after correction."""
    if method == "johnson":
        X, batch = simulate_data()
        extra = {}
    else:  # fortin or chen
        X, batch, disc, cont = simulate_covariate_data()
        extra = dict(discrete_covariates=disc, continuous_covariates=cont)

    X_corr = ComBat(batch=batch, method=method, **extra).fit_transform(X)
    assert all(np.issubdtype(dt, np.floating) for dt in X_corr.dtypes)


def test_wrapper_clone_and_pipeline():
    """
    Test that the `ComBat` wrapper can be cloned and used in a `Pipeline`.
    """
    X, batch = simulate_data()
    wrapper = ComBat(batch=batch, parametric=True)
    pipe = Pipeline([
        ("scaler", StandardScaler()),
        ("combat", wrapper),
    ])
    X_corr = pipe.fit_transform(X)
    pipe_clone: Pipeline = clone(pipe)
    X_corr2 = pipe_clone.fit_transform(X)
    np.testing.assert_allclose(X_corr, X_corr2, rtol=1e-5, atol=1e-5)


@pytest.mark.parametrize("method", ["johnson", "fortin", "chen"])
def test_no_nan_or_inf_in_output(method):
    """`ComBat` must not introduce NaN or Inf values, for any backend."""
    if method == "johnson":
        X, batch = simulate_data()
        extra = {}
    else:  # fortin or chen
        X, batch, disc, cont = simulate_covariate_data()
        extra = dict(discrete_covariates=disc, continuous_covariates=cont)

    X_corr = ComBat(batch=batch, method=method, **extra).fit_transform(X)
    assert not np.isnan(X_corr.values).any()
    assert not np.isinf(X_corr.values).any()


@pytest.mark.parametrize("method", ["johnson", "fortin", "chen"])
def test_shape_preserved(method):
    """The (n_samples, n_features) shape must be identical pre- and post-ComBat."""
    if method == "johnson":
        X, batch = simulate_data()
        combat = ComBat(batch=batch, method=method).fit(X)
    elif method in ["fortin", "chen"]:
        X, batch, disc, cont = simulate_covariate_data()
        combat = ComBat(
            batch=batch,
            discrete_covariates=disc,
            continuous_covariates=cont,
            method=method,
        ).fit(X)

    X_corr = combat.transform(X)
    assert X_corr.shape == X.shape


def test_johnson_print_warning():
    """
    Test that a warning is raised when covariates are passed with the Johnson method.
    """
    X, batch, disc, cont = simulate_covariate_data()
    with pytest.warns(Warning, match="Covariates are ignored when using method='johnson'."):
        _ = ComBat(
            batch=batch,
            discrete_covariates=disc,
            continuous_covariates=cont,
            method="johnson",
        ).fit(X)


@pytest.mark.parametrize("method", ["johnson", "fortin", "chen"])
def test_reference_batch_samples_unchanged(method):
    """
    Samples belonging to the reference batch must come out *numerically identical*
    (within floating-point jitter) after correction.
    """
    if method == "johnson":
        X, batch = simulate_data()
        extra = {}
    elif method in ["fortin", "chen"]:
        X, batch, disc, cont = simulate_covariate_data()
        extra = dict(discrete_covariates=disc, continuous_covariates=cont)

    ref_batch = batch.iloc[0]
    combat = ComBat(batch=batch, method=method,
                    reference_batch=ref_batch, **extra).fit(X)
    X_corr = combat.transform(X)

    mask = batch == ref_batch
    np.testing.assert_allclose(X_corr.loc[mask].values,
                               X.loc[mask].values,
                               rtol=0, atol=1e-10)


def test_reference_batch_missing_raises():
    """
    Asking for a reference batch that doesn't exist should fail.
    """
    X, batch = simulate_data()
    with pytest.raises(ValueError, match="not present"):
        ComBat(batch=batch, reference_batch="DOES_NOT_EXIST").fit(X)
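The tests import `simulate_data` and `simulate_covariate_data` from a local `utils` module that is not shown in this diff. Judging only from how the tests call them, helpers along these lines would satisfy the suite; the shapes, names, and defaults here are guesses, not the package's actual code:

```python
# tests/utils.py (hypothetical reconstruction; this file is not part of the diff)
import numpy as np
import pandas as pd


def simulate_data(n_per_batch=30, n_features=10, seed=0):
    """Return (X, batch): two Gaussian batches with a mean offset."""
    rng = np.random.default_rng(seed)
    n = 2 * n_per_batch
    X = pd.DataFrame(rng.normal(size=(n, n_features)))
    X.iloc[n_per_batch:] += 1.0
    batch = pd.Series(["A"] * n_per_batch + ["B"] * n_per_batch, index=X.index)
    return X, batch


def simulate_covariate_data(n_per_batch=30, n_features=10, seed=0):
    """Return (X, batch, disc, cont): adds a discrete and a continuous covariate."""
    X, batch = simulate_data(n_per_batch, n_features, seed)
    rng = np.random.default_rng(seed + 1)
    disc = pd.Series(rng.choice(["M", "F"], size=len(X)), index=X.index, name="sex")
    cont = pd.Series(rng.normal(50, 10, size=len(X)), index=X.index, name="age")
    return X, batch, disc, cont
```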