combatlearn 0.1.0 (tar.gz)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2025 Ettore Rocchi
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
@@ -0,0 +1,132 @@
+ Metadata-Version: 2.4
+ Name: combatlearn
+ Version: 0.1.0
+ Summary: Batch-effect harmonisation for machine learning frameworks.
+ Author-email: Ettore Rocchi <ettoreroc@gmail.com>
+ License: MIT License
+
+ Copyright (c) 2025 Ettore Rocchi
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
+
+ Keywords: machine-learning,harmonization,combat,preprocessing
+ Classifier: Development Status :: 3 - Alpha
+ Classifier: Intended Audience :: Science/Research
+ Classifier: License :: OSI Approved :: MIT License
+ Classifier: Operating System :: OS Independent
+ Classifier: Programming Language :: Python :: 3
+ Requires-Python: >=3.9
+ Description-Content-Type: text/markdown
+ License-File: LICENSE
+ Requires-Dist: pandas>=1.3
+ Requires-Dist: numpy>=1.21
+ Requires-Dist: scikit-learn>=1.2
+ Requires-Dist: pytest>=7
+ Dynamic: license-file
+
+ # **combatlearn**
+
+ <div align="center">
+ <p><img src="https://raw.githubusercontent.com/EttoreRocchi/combatlearn/main/docs/logo.png" alt="combatlearn logo" width="350" /></p>
+ </div>
+
+ **combatlearn** makes the popular _ComBat_ (and _CovBat_) batch-effect correction algorithms available for use in machine learning frameworks. It lets you harmonise high-dimensional data inside a scikit-learn `Pipeline`, so that cross-validation and grid search automatically take batch structure into account, **without data leakage**.
+
+ **Three methods**:
+ - `method="johnson"` - classic ComBat (Johnson _et al._, 2007)
+ - `method="fortin"` - covariate-aware ComBat (Fortin _et al._, 2018)
+ - `method="chen"` - CovBat (Chen _et al._, 2022)
+
+ ## Installation
+
+ ```bash
+ pip install combatlearn
+ ```
+
+ ## Quick start
+
+ ```python
+ import pandas as pd
+ from sklearn.model_selection import GridSearchCV
+ from sklearn.pipeline import Pipeline
+ from sklearn.preprocessing import StandardScaler
+ from sklearn.linear_model import LogisticRegression
+ from combatlearn import ComBat
+
+ df = pd.read_csv("data.csv", index_col=0)
+ X, y = df.drop(columns="y"), df["y"]
+
+ batch = pd.read_csv("batch.csv", index_col=0).squeeze("columns")
+ diag = pd.read_csv("diagnosis.csv", index_col=0)  # categorical
+ age = pd.read_csv("age.csv", index_col=0)  # continuous
+
+ pipe = Pipeline([
+     ("combat", ComBat(
+         batch=batch,
+         discrete_covariates=diag,
+         continuous_covariates=age,
+         method="fortin",  # or "johnson" or "chen"
+         parametric=True
+     )),
+     ("scaler", StandardScaler()),
+     ("clf", LogisticRegression())
+ ])
+
+ param_grid = {
+     "combat__mean_only": [True, False],
+     "clf__C": [0.01, 0.1, 1, 10],
+ }
+
+ grid = GridSearchCV(
+     estimator=pipe,
+     param_grid=param_grid,
+     cv=5,
+     scoring="roc_auc",
+ )
+
+ grid.fit(X, y)
+
+ print("Best parameters:", grid.best_params_)
+ print(f"Best CV AUROC: {grid.best_score_:.3f}")
+ ```
+
+ For a full example of how to use **combatlearn**, see the [notebook demo](https://github.com/EttoreRocchi/combatlearn/blob/main/demo/combatlearn_demo.ipynb).
+
+ ## Contributing
+
+ Pull requests, bug reports, and feature ideas are welcome: feel free to open a PR!
+
+ ## Acknowledgements
+
+ This project builds on the excellent work of the ComBat family of harmonisation methods.
+ We gratefully acknowledge:
+
+ - [**ComBat**](https://rdrr.io/bioc/sva/man/ComBat.html)
+ - [**neuroCombat**](https://github.com/Jfortin1/neuroCombat)
+ - [**CovBat**](https://github.com/andy1764/CovBat_Harmonization)
+
+ ## Citation
+
+ If **combatlearn** is useful in your research, please cite the original papers:
+
+ - Johnson WE, Li C, Rabinovic A. Adjusting batch effects in microarray expression data using empirical Bayes methods. _Biostatistics_. 2007 Jan;8(1):118-27. doi: [10.1093/biostatistics/kxj037](https://doi.org/10.1093/biostatistics/kxj037)
+
+ - Fortin JP, Cullen N, Sheline YI, Taylor WD, Aselcioglu I, Cook PA, Adams P, Cooper C, Fava M, McGrath PJ, McInnis M, Phillips ML, Trivedi MH, Weissman MM, Shinohara RT. Harmonization of cortical thickness measurements across scanners and sites. _Neuroimage_. 2018 Feb 15;167:104-120. doi: [10.1016/j.neuroimage.2017.11.024](https://doi.org/10.1016/j.neuroimage.2017.11.024)
+
+ - Chen AA, Beer JC, Tustison NJ, Cook PA, Shinohara RT, Shou H; Alzheimer's Disease Neuroimaging Initiative. Mitigating site effects in covariance for machine learning in neuroimaging data. _Hum Brain Mapp_. 2022 Mar;43(4):1179-1195. doi: [10.1002/hbm.25688](https://doi.org/10.1002/hbm.25688)
@@ -0,0 +1,90 @@
+ # **combatlearn**
+
+ <div align="center">
+ <p><img src="https://raw.githubusercontent.com/EttoreRocchi/combatlearn/main/docs/logo.png" alt="combatlearn logo" width="350" /></p>
+ </div>
+
+ **combatlearn** makes the popular _ComBat_ (and _CovBat_) batch-effect correction algorithms available for use in machine learning frameworks. It lets you harmonise high-dimensional data inside a scikit-learn `Pipeline`, so that cross-validation and grid search automatically take batch structure into account, **without data leakage**.
+
+ **Three methods**:
+ - `method="johnson"` - classic ComBat (Johnson _et al._, 2007)
+ - `method="fortin"` - covariate-aware ComBat (Fortin _et al._, 2018)
+ - `method="chen"` - CovBat (Chen _et al._, 2022)
+
+ ## Installation
+
+ ```bash
+ pip install combatlearn
+ ```
+
+ ## Quick start
+
+ ```python
+ import pandas as pd
+ from sklearn.model_selection import GridSearchCV
+ from sklearn.pipeline import Pipeline
+ from sklearn.preprocessing import StandardScaler
+ from sklearn.linear_model import LogisticRegression
+ from combatlearn import ComBat
+
+ df = pd.read_csv("data.csv", index_col=0)
+ X, y = df.drop(columns="y"), df["y"]
+
+ batch = pd.read_csv("batch.csv", index_col=0).squeeze("columns")
+ diag = pd.read_csv("diagnosis.csv", index_col=0)  # categorical
+ age = pd.read_csv("age.csv", index_col=0)  # continuous
+
+ pipe = Pipeline([
+     ("combat", ComBat(
+         batch=batch,
+         discrete_covariates=diag,
+         continuous_covariates=age,
+         method="fortin",  # or "johnson" or "chen"
+         parametric=True
+     )),
+     ("scaler", StandardScaler()),
+     ("clf", LogisticRegression())
+ ])
+
+ param_grid = {
+     "combat__mean_only": [True, False],
+     "clf__C": [0.01, 0.1, 1, 10],
+ }
+
+ grid = GridSearchCV(
+     estimator=pipe,
+     param_grid=param_grid,
+     cv=5,
+     scoring="roc_auc",
+ )
+
+ grid.fit(X, y)
+
+ print("Best parameters:", grid.best_params_)
+ print(f"Best CV AUROC: {grid.best_score_:.3f}")
+ ```
+
+ For a full example of how to use **combatlearn**, see the [notebook demo](https://github.com/EttoreRocchi/combatlearn/blob/main/demo/combatlearn_demo.ipynb).
+
+ ## Contributing
+
+ Pull requests, bug reports, and feature ideas are welcome: feel free to open a PR!
+
+ ## Acknowledgements
+
+ This project builds on the excellent work of the ComBat family of harmonisation methods.
+ We gratefully acknowledge:
+
+ - [**ComBat**](https://rdrr.io/bioc/sva/man/ComBat.html)
+ - [**neuroCombat**](https://github.com/Jfortin1/neuroCombat)
+ - [**CovBat**](https://github.com/andy1764/CovBat_Harmonization)
+
+ ## Citation
+
+ If **combatlearn** is useful in your research, please cite the original papers:
+
+ - Johnson WE, Li C, Rabinovic A. Adjusting batch effects in microarray expression data using empirical Bayes methods. _Biostatistics_. 2007 Jan;8(1):118-27. doi: [10.1093/biostatistics/kxj037](https://doi.org/10.1093/biostatistics/kxj037)
+
+ - Fortin JP, Cullen N, Sheline YI, Taylor WD, Aselcioglu I, Cook PA, Adams P, Cooper C, Fava M, McGrath PJ, McInnis M, Phillips ML, Trivedi MH, Weissman MM, Shinohara RT. Harmonization of cortical thickness measurements across scanners and sites. _Neuroimage_. 2018 Feb 15;167:104-120. doi: [10.1016/j.neuroimage.2017.11.024](https://doi.org/10.1016/j.neuroimage.2017.11.024)
+
+ - Chen AA, Beer JC, Tustison NJ, Cook PA, Shinohara RT, Shou H; Alzheimer's Disease Neuroimaging Initiative. Mitigating site effects in covariance for machine learning in neuroimaging data. _Hum Brain Mapp_. 2022 Mar;43(4):1179-1195. doi: [10.1002/hbm.25688](https://doi.org/10.1002/hbm.25688)
@@ -0,0 +1,35 @@
+ [build-system]
+ requires = ["setuptools>=61", "wheel"]
+ build-backend = "setuptools.build_meta"
+
+ [project]
+ name = "combatlearn"
+ version = "0.1.0"
+ description = "Batch-effect harmonisation for machine learning frameworks."
+ authors = [{name="Ettore Rocchi", email="ettoreroc@gmail.com"}]
+ requires-python = ">=3.9"
+ dependencies = [
+     "pandas>=1.3",
+     "numpy>=1.21",
+     "scikit-learn>=1.2",
+     "pytest>=7"
+ ]
+ license = {file="LICENSE"}
+ readme = {file="README.md", content-type="text/markdown"}
+ keywords = [
+     "machine-learning",
+     "harmonization",
+     "combat",
+     "preprocessing",
+ ]
+ classifiers = [
+     "Development Status :: 3 - Alpha",
+     "Intended Audience :: Science/Research",
+     "License :: OSI Approved :: MIT License",
+     "Operating System :: OS Independent",
+     "Programming Language :: Python :: 3",
+ ]
+
+ [tool.setuptools.packages.find]
+ where = ["src"]
+ include = ["combatlearn*"]
@@ -0,0 +1,4 @@
+ [egg_info]
+ tag_build =
+ tag_date = 0
+
@@ -0,0 +1,4 @@
+ from .combat import ComBatModel, ComBat
+
+ __all__ = ["ComBatModel", "ComBat"]
+ __version__ = "0.1.0"
@@ -0,0 +1,584 @@
+ """ComBat algorithm.
+
+ `ComBatModel` implements three variants:
+ * Johnson et al. (2007) vanilla ComBat (method="johnson")
+ * Fortin et al. (2018) extension with covariates (method="fortin")
+ * Chen et al. (2022) CovBat (method="chen")
+
+ `ComBat` makes the model compatible with scikit-learn by stashing
+ the batch (and optional covariates) at construction.
+ """
+
+ from __future__ import annotations  # `pd.DataFrame | None` annotations on Python 3.9
+
+ __author__ = "Ettore Rocchi"
+
+ import numpy as np
+ import numpy.linalg as la
+ import pandas as pd
+ from sklearn.base import BaseEstimator, TransformerMixin
+ from sklearn.utils.validation import check_is_fitted
+ from sklearn.decomposition import PCA
+ from typing import Literal
+ import warnings
+
+
+ class ComBatModel:
+     """ComBat algorithm.
+
+     Parameters
+     ----------
+     method : {'johnson', 'fortin', 'chen'}, default='johnson'
+         * 'johnson' – classic ComBat.
+         * 'fortin' – covariate‑aware ComBat.
+         * 'chen' – CovBat, PCA‑based ComBat.
+     parametric : bool, default=True
+         Use the parametric empirical Bayes variant.
+     mean_only : bool, default=False
+         If True, only the mean is adjusted (`gamma_star`),
+         ignoring the variance (`delta_star`).
+     reference_batch : str, optional
+         If specified, the batch level to use as reference.
+     covbat_cov_thresh : float, default=0.9
+         CovBat: cumulative explained variance threshold for PCA.
+     eps : float, default=1e-8
+         Numerical jitter to avoid division‑by‑zero.
+     """
+
+     def __init__(
+         self,
+         *,
+         method: Literal["johnson", "fortin", "chen"] = "johnson",
+         parametric: bool = True,
+         mean_only: bool = False,
+         reference_batch=None,
+         eps: float = 1e-8,
+         covbat_cov_thresh: float = 0.9,
+     ) -> None:
+         self.method = method
+         self.parametric = parametric
+         self.mean_only = bool(mean_only)
+         self.reference_batch = reference_batch
+         self.eps = float(eps)
+         self.covbat_cov_thresh = float(covbat_cov_thresh)
+         if not (0.0 < self.covbat_cov_thresh <= 1.0):
+             raise ValueError("covbat_cov_thresh must be in (0, 1].")
+
+     @staticmethod
+     def _as_series(arr, index, name):
+         if isinstance(arr, pd.Series):
+             ser = arr.copy()
+         else:
+             ser = pd.Series(arr, index=index, name=name)
+         if not ser.index.equals(index):
+             raise ValueError(f"`{name}` index mismatch with `X`.")
+         return ser.astype("category")
+
+     @staticmethod
+     def _to_df(arr, index, name):
+         if arr is None:
+             return None
+         if isinstance(arr, pd.Series):
+             arr = arr.to_frame()
+         if not isinstance(arr, pd.DataFrame):
+             arr = pd.DataFrame(arr, index=index)
+         if not arr.index.equals(index):
+             raise ValueError(f"`{name}` index mismatch with `X`.")
+         return arr
+
+     def fit(
+         self,
+         X,
+         y=None,
+         *,
+         batch,
+         discrete_covariates=None,
+         continuous_covariates=None,
+     ):
+         method = self.method.lower()
+         if method not in {"johnson", "fortin", "chen"}:
+             raise ValueError("method must be 'johnson', 'fortin', or 'chen'.")
+         if not isinstance(X, pd.DataFrame):
+             X = pd.DataFrame(X)
+         idx = X.index
+         batch = self._as_series(batch, idx, "batch")
+
+         disc = self._to_df(discrete_covariates, idx, "discrete_covariates")
+         cont = self._to_df(continuous_covariates, idx, "continuous_covariates")
+
+         if self.reference_batch is not None and self.reference_batch not in batch.cat.categories:
+             raise ValueError(
+                 f"reference_batch={self.reference_batch!r} not present in the data batches "
+                 f"{list(batch.cat.categories)}"
+             )
+
+         if method == "johnson":
+             if disc is not None or cont is not None:
+                 warnings.warn(
+                     "Covariates are ignored when using method='johnson'."
+                 )
+             self._fit_johnson(X, batch)
+         elif method == "fortin":
+             self._fit_fortin(X, batch, disc, cont)
+         elif method == "chen":
+             self._fit_chen(X, batch, disc, cont)
+         return self
+
+     def _fit_johnson(
+         self,
+         X: pd.DataFrame,
+         batch: pd.Series
+     ):
+         """
+         Johnson et al. (2007) ComBat.
+         """
+         self._batch_levels = batch.cat.categories
+         pooled_var = X.var(axis=0, ddof=1) + self.eps
+         grand_mean = X.mean(axis=0)
+
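+         # Standardise each feature with the grand mean and pooled variance so
+         # that per-batch location/scale effects are estimated on a common scale.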
+         Xs = (X - grand_mean) / np.sqrt(pooled_var)
+
+         n_per_batch: dict[str, int] = {}
+         gamma_hat, delta_hat = [], []
+         for lvl in self._batch_levels:
+             idx = batch == lvl
+             n_b = idx.sum()
+             if n_b < 2:
+                 raise ValueError(f"Batch '{lvl}' has <2 samples.")
+             n_per_batch[lvl] = n_b
+             xb = Xs.loc[idx]
+             gamma_hat.append(xb.mean(axis=0).values)
+             delta_hat.append(xb.var(axis=0, ddof=1).values + self.eps)
+         gamma_hat = np.vstack(gamma_hat)
+         delta_hat = np.vstack(delta_hat)
+
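+         # Empirical Bayes step: shrink the per-batch location (gamma) and
+         # scale (delta) estimates toward their across-feature priors.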
+         if self.mean_only:
+             gamma_star = self._shrink_gamma(
+                 gamma_hat, delta_hat, n_per_batch, parametric=self.parametric
+             )
+             delta_star = np.ones_like(delta_hat)
+         else:
+             gamma_star, delta_star = self._shrink_gamma_delta(
+                 gamma_hat, delta_hat, n_per_batch, parametric=self.parametric
+             )
+
+         if self.reference_batch is not None:
+             ref_idx = list(self._batch_levels).index(self.reference_batch)
+             gamma_ref = gamma_star[ref_idx]
+             delta_ref = delta_star[ref_idx]
+             gamma_star = gamma_star - gamma_ref
+             if not self.mean_only:
+                 delta_star = delta_star / delta_ref
+             self._reference_batch_idx = ref_idx
+         else:
+             self._reference_batch_idx = None
+
+         self._grand_mean = grand_mean
+         self._pooled_var = pooled_var
+         self._gamma_star = gamma_star
+         self._delta_star = delta_star
+         self._n_per_batch = n_per_batch
+
+     def _fit_fortin(
+         self,
+         X: pd.DataFrame,
+         batch: pd.Series,
+         disc: pd.DataFrame | None,
+         cont: pd.DataFrame | None,
+     ):
+         """
+         Fortin et al. (2018) ComBat.
+         """
+         batch_levels = batch.cat.categories
+         n_batch = len(batch_levels)
+         n_samples = len(X)
+
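+         # Design matrix: one dummy column per batch level, followed by the
+         # optional discrete (dummy-coded) and continuous covariate columns.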
+         batch_dummies = pd.get_dummies(batch, drop_first=False)
+         parts = [batch_dummies]
+         if disc is not None:
+             parts.append(pd.get_dummies(disc.astype("category"), drop_first=True))
+         if cont is not None:
+             parts.append(cont)
+         design = pd.concat(parts, axis=1).astype(float).values
+         p_design = design.shape[1]
+
+         X_np = X.values
+         beta_hat = la.inv(design.T @ design) @ design.T @ X_np
+
+         gamma_hat = beta_hat[:n_batch]
+         self._beta_hat_nonbatch = beta_hat[n_batch:]
+
+         n_per_batch = batch.value_counts().sort_index().values
+         self._n_per_batch = dict(zip(batch_levels, n_per_batch))
+
+         grand_mean = (n_per_batch / n_samples) @ gamma_hat
+         self._grand_mean = grand_mean
+
+         resid = X_np - design @ beta_hat
+         var_pooled = (resid ** 2).sum(axis=0) / (n_samples - p_design) + self.eps
+         self._pooled_var = var_pooled
+
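+         # Remove covariate and grand-mean effects, then scale by the pooled
+         # residual variance before estimating per-batch shifts.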
+         stand_mean = grand_mean + design[:, n_batch:] @ self._beta_hat_nonbatch
+         Xs = (X_np - stand_mean) / np.sqrt(var_pooled)
+
+         delta_hat = np.empty_like(gamma_hat)
+         for i, lvl in enumerate(batch_levels):
+             idx = batch == lvl
+             delta_hat[i] = Xs[idx].var(axis=0, ddof=1) + self.eps
+
+         if self.mean_only:
+             gamma_star = self._shrink_gamma(
+                 gamma_hat, delta_hat, n_per_batch, parametric=self.parametric
+             )
+             delta_star = np.ones_like(delta_hat)
+         else:
+             gamma_star, delta_star = self._shrink_gamma_delta(
+                 gamma_hat, delta_hat, n_per_batch, parametric=self.parametric
+             )
+
+         if self.reference_batch is not None:
+             ref_idx = list(batch_levels).index(self.reference_batch)
+             gamma_ref = gamma_star[ref_idx]
+             delta_ref = delta_star[ref_idx]
+             gamma_star = gamma_star - gamma_ref
+             if not self.mean_only:
+                 delta_star = delta_star / delta_ref
+             self._reference_batch_idx = ref_idx
+         else:
+             self._reference_batch_idx = None
+
+         self._batch_levels = batch_levels
+         self._gamma_star = gamma_star
+         self._delta_star = delta_star
+         self._n_batch = n_batch
+         self._p_design = p_design
+
+     def _fit_chen(
+         self,
+         X: pd.DataFrame,
+         batch: pd.Series,
+         disc: pd.DataFrame | None,
+         cont: pd.DataFrame | None,
+     ):
+         self._fit_fortin(X, batch, disc, cont)
+         X_meanvar_adj = self._transform_fortin(X, batch, disc, cont)
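+         # CovBat: after the Fortin mean/variance adjustment, residual batch
+         # effects in the covariance are corrected in PCA score space, keeping
+         # enough components to reach `covbat_cov_thresh` explained variance.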
+         X_centered = X_meanvar_adj - X_meanvar_adj.mean(axis=0)
+         pca = PCA(svd_solver="full", whiten=False).fit(X_centered)
+         cumulative = np.cumsum(pca.explained_variance_ratio_)
+         n_pc = int(np.searchsorted(cumulative, self.covbat_cov_thresh) + 1)
+         self._covbat_pca = pca
+         self._covbat_n_pc = n_pc
+
+         scores = pca.transform(X_centered)[:, :n_pc]
+         scores_df = pd.DataFrame(scores, index=X.index, columns=[f"PC{i+1}" for i in range(n_pc)])
+         self._batch_levels_pc = self._batch_levels
+         n_per_batch = self._n_per_batch
+
+         gamma_hat, delta_hat = [], []
+         for lvl in self._batch_levels_pc:
+             idx = batch == lvl
+             xb = scores_df.loc[idx]
+             gamma_hat.append(xb.mean(axis=0).values)
+             delta_hat.append(xb.var(axis=0, ddof=1).values + self.eps)
+         gamma_hat = np.vstack(gamma_hat)
+         delta_hat = np.vstack(delta_hat)
+
+         if self.mean_only:
+             gamma_star = self._shrink_gamma(
+                 gamma_hat, delta_hat, n_per_batch, parametric=self.parametric
+             )
+             delta_star = np.ones_like(delta_hat)
+         else:
+             gamma_star, delta_star = self._shrink_gamma_delta(
+                 gamma_hat, delta_hat, n_per_batch, parametric=self.parametric
+             )
+
+         if self.reference_batch is not None:
+             ref_idx = list(self._batch_levels_pc).index(self.reference_batch)
+             gamma_ref = gamma_star[ref_idx]
+             delta_ref = delta_star[ref_idx]
+             gamma_star = gamma_star - gamma_ref
+             if not self.mean_only:
+                 delta_star = delta_star / delta_ref
+
+         self._pc_gamma_star = gamma_star
+         self._pc_delta_star = delta_star
+
+     def _shrink_gamma_delta(
+         self,
+         gamma_hat: np.ndarray,
+         delta_hat: np.ndarray,
+         n_per_batch: dict | np.ndarray,
+         *,
+         parametric: bool,
+         max_iter: int = 100,
+         tol: float = 1e-4,
+     ):
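+         # Parametric path: closed-form normal / inverse-gamma posterior means.
+         # Non-parametric path: iterate the coupled posterior updates until the
+         # relative change drops below `tol`.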
+         if parametric:
+             gamma_bar = gamma_hat.mean(axis=0)
+             t2 = gamma_hat.var(axis=0, ddof=1)
+             a_prior = (delta_hat.mean(axis=0) ** 2) / delta_hat.var(axis=0, ddof=1) + 2
+             b_prior = delta_hat.mean(axis=0) * (a_prior - 1)
+
+             B, p = gamma_hat.shape
+             gamma_star = np.empty_like(gamma_hat)
+             delta_star = np.empty_like(delta_hat)
+             n_vec = np.array(list(n_per_batch.values())) if isinstance(n_per_batch, dict) else n_per_batch
+             for i in range(B):
+                 n_i = n_vec[i]
+                 g, d = gamma_hat[i], delta_hat[i]
+                 gamma_post_var = 1.0 / (n_i / d + 1.0 / t2)
+                 gamma_star[i] = gamma_post_var * (n_i * g / d + gamma_bar / t2)
+
+                 a_post = a_prior + n_i / 2.0
+                 b_post = b_prior + 0.5 * n_i * d
+                 delta_star[i] = b_post / (a_post - 1)
+             return gamma_star, delta_star
+
+         else:
+             B, p = gamma_hat.shape
+             n_vec = np.array(list(n_per_batch.values())) if isinstance(n_per_batch, dict) else n_per_batch
+             gamma_bar = gamma_hat.mean(axis=0)
+             t2 = gamma_hat.var(axis=0, ddof=1)
+
+             def postmean(g_hat, g_bar, n, d_star, t2_):
+                 return (t2_ * n * g_hat + d_star * g_bar) / (t2_ * n + d_star)
+
+             def postvar(sum2, n, a, b):
+                 return (0.5 * sum2 + b) / (n / 2.0 + a - 1.0)
+
+             def aprior(delta):
+                 m, s2 = delta.mean(), delta.var()
+                 s2 = max(s2, self.eps)
+                 return (2 * s2 + m ** 2) / s2
+
+             def bprior(delta):
+                 m, s2 = delta.mean(), delta.var()
+                 s2 = max(s2, self.eps)
+                 return (m * s2 + m ** 3) / s2
+
+             gamma_star = np.empty_like(gamma_hat)
+             delta_star = np.empty_like(delta_hat)
+
+             for i in range(B):
+                 n_i = n_vec[i]
+                 g_hat_i = gamma_hat[i]
+                 d_hat_i = delta_hat[i]
+                 a_i = aprior(d_hat_i)
+                 b_i = bprior(d_hat_i)
+
+                 g_new, d_new = g_hat_i.copy(), d_hat_i.copy()
+                 for _ in range(max_iter):
+                     g_prev, d_prev = g_new, d_new
+                     g_new = postmean(g_hat_i, gamma_bar, n_i, d_prev, t2)
+                     sum2 = (n_i - 1) * d_hat_i + n_i * (g_hat_i - g_new) ** 2
+                     d_new = postvar(sum2, n_i, a_i, b_i)
+                     if np.max(np.abs(g_new - g_prev) / (np.abs(g_prev) + self.eps)) < tol and (
+                         self.mean_only or np.max(np.abs(d_new - d_prev) / (np.abs(d_prev) + self.eps)) < tol
+                     ):
+                         break
+                 gamma_star[i] = g_new
+                 delta_star[i] = 1.0 if self.mean_only else d_new
+             return gamma_star, delta_star
+
+     def _shrink_gamma(
+         self,
+         gamma_hat: np.ndarray,
+         delta_hat: np.ndarray,
+         n_per_batch: dict | np.ndarray,
+         *,
+         parametric: bool,
+     ) -> np.ndarray:
+         """Convenience wrapper that returns only γ⋆ (for *mean‑only* mode)."""
+         gamma, _ = self._shrink_gamma_delta(gamma_hat, delta_hat, n_per_batch, parametric=parametric)
+         return gamma
+
+     def transform(
+         self,
+         X,
+         *,
+         batch,
+         discrete_covariates=None,
+         continuous_covariates=None,
+     ):
+         check_is_fitted(self, ["_gamma_star"])
+         if not isinstance(X, pd.DataFrame):
+             X = pd.DataFrame(X)
+         idx = X.index
+         batch = self._as_series(batch, idx, "batch")
+         unseen = set(batch.cat.categories) - set(self._batch_levels)
+         if unseen:
+             raise ValueError(f"Unseen batch levels during transform: {unseen}")
+         disc = self._to_df(discrete_covariates, idx, "discrete_covariates")
+         cont = self._to_df(continuous_covariates, idx, "continuous_covariates")
+
+         method = self.method.lower()
+         if method == "johnson":
+             return self._transform_johnson(X, batch)
+         elif method == "fortin":
+             return self._transform_fortin(X, batch, disc, cont)
+         elif method == "chen":
+             return self._transform_chen(X, batch, disc, cont)
+
+     def _transform_johnson(self, X: pd.DataFrame, batch: pd.Series):
+         pooled = self._pooled_var
+         grand = self._grand_mean
+
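+         # Standardise with fit-time moments, undo the batch's estimated
+         # shift/scale, then map back to the original feature scale.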
+         Xs = (X - grand) / np.sqrt(pooled)
+         X_adj = pd.DataFrame(index=X.index, columns=X.columns, dtype=float)
+
+         for i, lvl in enumerate(self._batch_levels):
+             idx = batch == lvl
+             if not idx.any():
+                 continue
+             if self.reference_batch is not None and lvl == self.reference_batch:
+                 X_adj.loc[idx] = X.loc[idx].values  # untouched
+                 continue
+
+             g = self._gamma_star[i]
+             d = self._delta_star[i]
+             if self.mean_only:
+                 Xb = Xs.loc[idx] - g
+             else:
+                 Xb = (Xs.loc[idx] - g) / np.sqrt(d)
+             X_adj.loc[idx] = (Xb * np.sqrt(pooled) + grand).values
+         return X_adj
+
+     def _transform_fortin(
+         self,
+         X: pd.DataFrame,
+         batch: pd.Series,
+         disc: pd.DataFrame | None,
+         cont: pd.DataFrame | None,
+     ):
+         batch_dummies = pd.get_dummies(batch, drop_first=False)[self._batch_levels]
+         parts = [batch_dummies]
+         if disc is not None:
+             parts.append(pd.get_dummies(disc.astype("category"), drop_first=True))
+         if cont is not None:
+             parts.append(cont)
+
+         design = pd.concat(parts, axis=1).astype(float).values
+
+         X_np = X.values
+         stand_mean = self._grand_mean + design[:, self._n_batch:] @ self._beta_hat_nonbatch
+         Xs = (X_np - stand_mean) / np.sqrt(self._pooled_var)
+
+         for i, lvl in enumerate(self._batch_levels):
+             idx = batch == lvl
+             if not idx.any():
+                 continue
+             if self.reference_batch is not None and lvl == self.reference_batch:
+                 # leave reference samples unchanged
+                 continue
+
+             g = self._gamma_star[i]
+             d = self._delta_star[i]
+             if self.mean_only:
+                 Xs[idx] = Xs[idx] - g
+             else:
+                 Xs[idx] = (Xs[idx] - g) / np.sqrt(d)
+
+         X_adj = Xs * np.sqrt(self._pooled_var) + stand_mean
+         return pd.DataFrame(X_adj, index=X.index, columns=X.columns)
+
+     def _transform_chen(
+         self,
+         X: pd.DataFrame,
+         batch: pd.Series,
+         disc: pd.DataFrame | None,
+         cont: pd.DataFrame | None,
+     ):
+         X_meanvar_adj = self._transform_fortin(X, batch, disc, cont)
+         X_centered = X_meanvar_adj - self._covbat_pca.mean_
+         scores = self._covbat_pca.transform(X_centered)
+         n_pc = self._covbat_n_pc
+         scores_adj = scores.copy()
+
+         for i, lvl in enumerate(self._batch_levels_pc):
+             idx = batch == lvl
+             if not idx.any():
+                 continue
+             if self.reference_batch is not None and lvl == self.reference_batch:
+                 continue
+             g = self._pc_gamma_star[i]
+             d = self._pc_delta_star[i]
+             if self.mean_only:
+                 scores_adj[idx, :n_pc] = scores_adj[idx, :n_pc] - g
+             else:
+                 scores_adj[idx, :n_pc] = (scores_adj[idx, :n_pc] - g) / np.sqrt(d)
+
+         X_recon = self._covbat_pca.inverse_transform(scores_adj) + self._covbat_pca.mean_
+         return pd.DataFrame(X_recon, index=X.index, columns=X.columns)
+
+
+ class ComBat(BaseEstimator, TransformerMixin):
+     """Pipeline‑friendly wrapper around `ComBatModel`.
+
+     Stores the batch (and optional covariates) passed at construction and
+     subsets them to the rows of `X` in each separate `fit` and `transform` call.
+     """
+
+     def __init__(
+         self,
+         batch,
+         *,
+         discrete_covariates=None,
+         continuous_covariates=None,
+         method: str = "johnson",
+         parametric: bool = True,
+         mean_only: bool = False,
+         reference_batch=None,
+         eps: float = 1e-8,
+         covbat_cov_thresh: float = 0.9,
+     ) -> None:
+         self.batch = batch
+         self.discrete_covariates = discrete_covariates
+         self.continuous_covariates = continuous_covariates
+         self.method = method
+         self.parametric = parametric
+         self.mean_only = mean_only
+         self.reference_batch = reference_batch
+         self.eps = eps
+         self.covbat_cov_thresh = covbat_cov_thresh
+         self._model = ComBatModel(
+             method=method,
+             parametric=parametric,
+             mean_only=mean_only,
+             reference_batch=reference_batch,
+             eps=eps,
+             covbat_cov_thresh=covbat_cov_thresh,
+         )
+
+     def fit(self, X, y=None):
+         idx = X.index if isinstance(X, pd.DataFrame) else np.arange(len(X))
+         batch_vec = self._subset(self.batch, idx)
+         disc = self._subset(self.discrete_covariates, idx)
+         cont = self._subset(self.continuous_covariates, idx)
+         self._model.fit(
+             X,
+             batch=batch_vec,
+             discrete_covariates=disc,
+             continuous_covariates=cont,
+         )
+         return self
+
+     def transform(self, X):
+         idx = X.index if isinstance(X, pd.DataFrame) else np.arange(len(X))
+         batch_vec = self._subset(self.batch, idx)
+         disc = self._subset(self.discrete_covariates, idx)
+         cont = self._subset(self.continuous_covariates, idx)
+         return self._model.transform(
+             X,
+             batch=batch_vec,
+             discrete_covariates=disc,
+             continuous_covariates=cont,
+         )
+
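+     # `_subset` aligns the stashed batch/covariates with the rows of `X`,
+     # which is what keeps per-fold fits leakage-free inside cross-validation.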
+     @staticmethod
+     def _subset(obj, idx):
+         if obj is None:
+             return None
+         if isinstance(obj, (pd.Series, pd.DataFrame)):
+             return obj.loc[idx]
+         else:
+             return pd.DataFrame(obj).iloc[idx]
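Because the wrapper keeps the full batch vector and re-indexes it on every call, `ComBat` can sit inside any scikit-learn cross-validation loop and be re-fitted on training folds only. A minimal sketch of how this plays out (synthetic two-batch data; illustrative only, not part of the package):

```python
import numpy as np
import pandas as pd
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score
from sklearn.pipeline import Pipeline
from combatlearn import ComBat

rng = np.random.default_rng(0)
n, p = 100, 20
X = pd.DataFrame(rng.normal(size=(n, p)))
X.iloc[: n // 2] += 2.0  # batch "A" gets an additive shift
batch = pd.Series(["A"] * (n // 2) + ["B"] * (n // 2), index=X.index)
y = pd.Series(rng.integers(0, 2, size=n), index=X.index)

pipe = Pipeline([
    ("combat", ComBat(batch=batch)),  # full batch Series stashed here
    ("clf", LogisticRegression()),
])
# Each fold re-fits ComBat on its own training rows via `_subset`.
print(cross_val_score(pipe, X, y, cv=5))
```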
@@ -0,0 +1,11 @@
+ LICENSE
+ README.md
+ pyproject.toml
+ src/combatlearn/__init__.py
+ src/combatlearn/combat.py
+ src/combatlearn.egg-info/PKG-INFO
+ src/combatlearn.egg-info/SOURCES.txt
+ src/combatlearn.egg-info/dependency_links.txt
+ src/combatlearn.egg-info/requires.txt
+ src/combatlearn.egg-info/top_level.txt
+ tests/test_combat.py
@@ -0,0 +1,4 @@
+ pandas>=1.3
+ numpy>=1.21
+ scikit-learn>=1.2
+ pytest>=7
@@ -0,0 +1 @@
+ combatlearn
@@ -0,0 +1,150 @@
+ import numpy as np
+ import pandas as pd
+ import pytest
+ from sklearn.pipeline import Pipeline
+ from sklearn.base import clone
+ from sklearn.preprocessing import StandardScaler
+ from sklearn.exceptions import NotFittedError
+ from combatlearn import ComBatModel, ComBat
+ from utils import simulate_data, simulate_covariate_data
+
+
+ def test_transform_without_fit_raises():
+     """
+     Test that `transform` raises a `NotFittedError` if not fitted.
+     """
+     X, batch = simulate_data()
+     model = ComBatModel()
+     with pytest.raises(NotFittedError):
+         model.transform(X, batch=batch)
+
+
+ def test_unseen_batch_raises_value_error():
+     """
+     Test that an unseen batch raises a `ValueError`.
+     """
+     X, batch = simulate_data()
+     model = ComBatModel().fit(X, batch=batch)
+     new_batch = pd.Series(["Z"] * len(batch), index=batch.index)
+     with pytest.raises(ValueError):
+         model.transform(X, batch=new_batch)
+
+
+ def test_single_sample_batch_error():
+     """
+     Test that a single-sample batch raises a `ValueError`.
+     """
+     X, batch = simulate_data()
+     batch.iloc[0] = "single"
+     with pytest.raises(ValueError):
+         ComBatModel().fit(X, batch=batch)
+
+
+ @pytest.mark.parametrize("method", ["johnson", "fortin", "chen"])
+ def test_dtypes_preserved(method):
+     """All output columns must remain floating dtypes after correction."""
+     if method == "johnson":
+         X, batch = simulate_data()
+         extra = {}
+     else:  # fortin or chen
+         X, batch, disc, cont = simulate_covariate_data()
+         extra = dict(discrete_covariates=disc, continuous_covariates=cont)
+
+     X_corr = ComBat(batch=batch, method=method, **extra).fit_transform(X)
+     assert all(np.issubdtype(dt, np.floating) for dt in X_corr.dtypes)
+
+
+ def test_wrapper_clone_and_pipeline():
+     """
+     Test that the `ComBat` wrapper can be cloned and used in a `Pipeline`.
+     """
+     X, batch = simulate_data()
+     wrapper = ComBat(batch=batch, parametric=True)
+     pipe = Pipeline([
+         ("scaler", StandardScaler()),
+         ("combat", wrapper),
+     ])
+     X_corr = pipe.fit_transform(X)
+     pipe_clone: Pipeline = clone(pipe)
+     X_corr2 = pipe_clone.fit_transform(X)
+     np.testing.assert_allclose(X_corr, X_corr2, rtol=1e-5, atol=1e-5)
+
+
+ @pytest.mark.parametrize("method", ["johnson", "fortin", "chen"])
+ def test_no_nan_or_inf_in_output(method):
+     """`ComBat` must not introduce NaN or Inf values, for any backend."""
+     if method == "johnson":
+         X, batch = simulate_data()
+         extra = {}
+     else:  # fortin or chen
+         X, batch, disc, cont = simulate_covariate_data()
+         extra = dict(discrete_covariates=disc, continuous_covariates=cont)
+
+     X_corr = ComBat(batch=batch, method=method, **extra).fit_transform(X)
+     assert not np.isnan(X_corr.values).any()
+     assert not np.isinf(X_corr.values).any()
+
+
+ @pytest.mark.parametrize("method", ["johnson", "fortin", "chen"])
+ def test_shape_preserved(method):
+     """The (n_samples, n_features) shape must be identical pre- and post-ComBat."""
+     if method == "johnson":
+         X, batch = simulate_data()
+         combat = ComBat(batch=batch, method=method).fit(X)
+     elif method in ["fortin", "chen"]:
+         X, batch, disc, cont = simulate_covariate_data()
+         combat = ComBat(
+             batch=batch,
+             discrete_covariates=disc,
+             continuous_covariates=cont,
+             method=method,
+         ).fit(X)
+
+     X_corr = combat.transform(X)
+     assert X_corr.shape == X.shape
+
+
+ def test_johnson_print_warning():
+     """
+     Test that a warning is raised when covariates are passed to the Johnson method.
+     """
+     X, batch, disc, cont = simulate_covariate_data()
+     with pytest.warns(Warning, match="Covariates are ignored when using method='johnson'."):
+         _ = ComBat(
+             batch=batch,
+             discrete_covariates=disc,
+             continuous_covariates=cont,
+             method="johnson",
+         ).fit(X)
+
+
+ @pytest.mark.parametrize("method", ["johnson", "fortin", "chen"])
+ def test_reference_batch_samples_unchanged(method):
+     """
+     Samples belonging to the reference batch must come out *numerically identical*
+     (within floating-point jitter) after correction.
+     """
+     if method == "johnson":
+         X, batch = simulate_data()
+         extra = {}
+     elif method in ["fortin", "chen"]:
+         X, batch, disc, cont = simulate_covariate_data()
+         extra = dict(discrete_covariates=disc, continuous_covariates=cont)
+
+     ref_batch = batch.iloc[0]
+     combat = ComBat(batch=batch, method=method,
+                     reference_batch=ref_batch, **extra).fit(X)
+     X_corr = combat.transform(X)
+
+     mask = batch == ref_batch
+     np.testing.assert_allclose(X_corr.loc[mask].values,
+                                X.loc[mask].values,
+                                rtol=0, atol=1e-10)
+
+
+ def test_reference_batch_missing_raises():
+     """
+     Asking for a reference batch that doesn't exist should fail.
+     """
+     X, batch = simulate_data()
+     with pytest.raises(ValueError, match="not present"):
+         ComBat(batch=batch, reference_batch="DOES_NOT_EXIST").fit(X)
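The test module imports `simulate_data` and `simulate_covariate_data` from a local `utils` helper that is not included in this diff (it is also absent from SOURCES.txt). A minimal sketch of what such helpers would need to return for the tests above to run; the shapes, seeds, and covariate names here are assumptions inferred from the call sites, not the package's actual helpers:

```python
# tests/utils.py - hypothetical helpers; not part of the released sdist.
import numpy as np
import pandas as pd


def simulate_data(n=60, p=8, seed=0):
    """Return (X, batch): a feature frame and an aligned two-level batch Series."""
    rng = np.random.default_rng(seed)
    X = pd.DataFrame(rng.normal(size=(n, p)))
    batch = pd.Series(np.where(np.arange(n) < n // 2, "A", "B"), index=X.index)
    return X, batch


def simulate_covariate_data(n=60, p=8, seed=0):
    """Return (X, batch, disc, cont) with one discrete and one continuous covariate."""
    X, batch = simulate_data(n, p, seed)
    rng = np.random.default_rng(seed + 1)
    disc = pd.Series(rng.integers(0, 2, size=n), index=X.index, name="diagnosis")
    cont = pd.Series(rng.normal(size=n), index=X.index, name="age")
    return X, batch, disc, cont
```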