combatlearn 0.1.0__tar.gz → 0.1.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,7 +1,7 @@
 Metadata-Version: 2.4
 Name: combatlearn
-Version: 0.1.0
-Summary: Batch-effect harmonisation for machine learning frameworks.
+Version: 0.1.1
+Summary: Batch-effect harmonization for machine learning frameworks.
 Author-email: Ettore Rocchi <ettoreroc@gmail.com>
 License: MIT License
 
@@ -31,7 +31,7 @@ Classifier: Intended Audience :: Science/Research
 Classifier: License :: OSI Approved :: MIT License
 Classifier: Operating System :: OS Independent
 Classifier: Programming Language :: Python :: 3
-Requires-Python: >=3.9
+Requires-Python: >=3.10
 Description-Content-Type: text/markdown
 License-File: LICENSE
 Requires-Dist: pandas>=1.3
@@ -42,6 +42,12 @@ Dynamic: license-file
 
 # **combatlearn**
 
+[![Python versions](https://img.shields.io/badge/python-3.10%20%7C%203.11%20%7C%203.12-blue?logo=python)](https://img.shields.io/badge/python-3.10%20%7C%203.11%20%7C%203.12-blue?logo=python)
+[![Test](https://github.com/EttoreRocchi/combatlearn/actions/workflows/test.yaml/badge.svg)](https://github.com/EttoreRocchi/combatlearn/actions/workflows/test.yaml)
+[![PyPI Downloads](https://static.pepy.tech/badge/combatlearn)](https://pepy.tech/projects/combatlearn)
+[![PyPI version](https://badge.fury.io/py/combatlearn.svg)](https://pypi.org/project/combatlearn/)
+[![License](https://img.shields.io/github/license/EttoreRocchi/combatlearn)](https://github.com/EttoreRocchi/combatlearn/blob/main/LICENSE)
+
 <div align="center">
 <p><img src="https://raw.githubusercontent.com/EttoreRocchi/combatlearn/main/docs/logo.png" alt="combatlearn logo" width="350" /></p>
 </div>
@@ -107,10 +113,39 @@ print(f"Best CV AUROC: {grid.best_score_:.3f}")
 
 For a full example of how to use **combatlearn** see the [notebook demo](https://github.com/EttoreRocchi/combatlearn/blob/main/demo/combatlearn_demo.ipynb)
 
+## `ComBat` parameters
+
+The following section provides a detailed explanation of all parameters available in the scikit-learn-compatible `ComBat` class.
+
+### Main Parameters
+
+| Parameter | Type | Default | Description |
+| --- | --- | --- | --- |
+| `batch` | array-like or pd.Series | **required** | Vector indicating batch assignment for each sample. This is used to estimate and remove batch effects. |
+| `discrete_covariates` | array-like, pd.Series, or pd.DataFrame | `None` | Optional categorical covariates (e.g., sex, site). Only used in `"fortin"` and `"chen"` methods. |
+| `continuous_covariates` | array-like, pd.Series or pd.DataFrame | `None` | Optional continuous covariates (e.g., age). Only used in `"fortin"` and `"chen"` methods. |
+
+### Algorithm Options
+
+| Parameter | Type | Default | Description |
+| --- | --- | --- | --- |
+| `method` | str | `"johnson"` | ComBat method to use: <ul><li>`"johnson"` - Classical ComBat (_Johnson et al. 2007_)</li><li>`"fortin"` - ComBat with covariates (_Fortin et al. 2018_)</li><li>`"chen"` - CovBat, PCA-based correction (_Chen et al. 2022_)</li></ul> |
+| `parametric` | bool | `True` | Whether to use the **parametric empirical Bayes** formulation. If `False`, a non-parametric iterative scheme is used. |
+| `mean_only` | bool | `False` | If `True`, only the **mean** is corrected, while variances are left unchanged. Useful for preserving variance structure in the data. |
+| `reference_batch` | str or `None` | `None` | If specified, acts as a reference batch - other batches will be corrected to match this one. |
+| `covbat_cov_thresh` | float, int | `0.9` | For `"chen"` method only: Cumulative variance threshold $]0,1[$ to retain PCs in PCA space (e.g., 0.9 = retain 90% explained variance). If an integer is provided, it represents the number of principal components to use. |
+| `eps` | float | `1e-8` | Small jitter value added to variances to prevent divide-by-zero errors during standardization. |
+
 ## Contributing
 
 Pull requests, bug reports, and feature ideas are welcome: feel free to open a PR!
 
+## Author
+
+[**Ettore Rocchi**](https://github.com/ettorerocchi) @ University of Bologna
+
+[Google Scholar](https://scholar.google.com/citations?user=MKHoGnQAAAAJ) $\cdot$ [Scopus](https://www.scopus.com/authid/detail.uri?authorId=57220152522)
+
 ## Acknowledgements
 
 This project builds on the excellent work of the ComBat family of harmonisation methods.
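The `ComBat` parameter table above can be made concrete with a small, self-contained sketch of what the classical (`"johnson"`) correction does: standardize features against pooled statistics, estimate each batch's location and scale effect, and remove them. This is an illustrative NumPy approximation (it omits the empirical Bayes shrinkage that the package applies), not the package's own implementation:

```python
import numpy as np

rng = np.random.default_rng(0)

# Two batches of the same "biology", but batch B is shifted and rescaled.
batch_a = rng.normal(loc=0.0, scale=1.0, size=(50, 3))
batch_b = rng.normal(loc=2.0, scale=2.0, size=(50, 3))
X = np.vstack([batch_a, batch_b])
batch = np.array(["A"] * 50 + ["B"] * 50)

# Standardize against the pooled grand mean / variance.
grand_mean = X.mean(axis=0)
pooled_var = X.var(axis=0, ddof=1) + 1e-8  # eps jitter, as in the table
Xs = (X - grand_mean) / np.sqrt(pooled_var)

# Per-batch location/scale adjustment (no EB shrinkage in this sketch).
X_adj = Xs.copy()
for lvl in ("A", "B"):
    m = batch == lvl
    gamma = Xs[m].mean(axis=0)         # batch effect on the mean
    delta = Xs[m].var(axis=0, ddof=1)  # batch effect on the variance
    X_adj[m] = (Xs[m] - gamma) / np.sqrt(delta)

# Map back to the original scale.
X_adj = X_adj * np.sqrt(pooled_var) + grand_mean

gap_before = np.abs(X[batch == "A"].mean(0) - X[batch == "B"].mean(0)).max()
gap_after = np.abs(X_adj[batch == "A"].mean(0) - X_adj[batch == "B"].mean(0)).max()
print(gap_before > gap_after)
```

With `mean_only=True` the analogous sketch would skip the division by `np.sqrt(delta)`, leaving each batch's variance untouched.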
@@ -1,5 +1,11 @@
 # **combatlearn**
 
+[![Python versions](https://img.shields.io/badge/python-3.10%20%7C%203.11%20%7C%203.12-blue?logo=python)](https://img.shields.io/badge/python-3.10%20%7C%203.11%20%7C%203.12-blue?logo=python)
+[![Test](https://github.com/EttoreRocchi/combatlearn/actions/workflows/test.yaml/badge.svg)](https://github.com/EttoreRocchi/combatlearn/actions/workflows/test.yaml)
+[![PyPI Downloads](https://static.pepy.tech/badge/combatlearn)](https://pepy.tech/projects/combatlearn)
+[![PyPI version](https://badge.fury.io/py/combatlearn.svg)](https://pypi.org/project/combatlearn/)
+[![License](https://img.shields.io/github/license/EttoreRocchi/combatlearn)](https://github.com/EttoreRocchi/combatlearn/blob/main/LICENSE)
+
 <div align="center">
 <p><img src="https://raw.githubusercontent.com/EttoreRocchi/combatlearn/main/docs/logo.png" alt="combatlearn logo" width="350" /></p>
 </div>
@@ -65,10 +71,39 @@ print(f"Best CV AUROC: {grid.best_score_:.3f}")
 
 For a full example of how to use **combatlearn** see the [notebook demo](https://github.com/EttoreRocchi/combatlearn/blob/main/demo/combatlearn_demo.ipynb)
 
+## `ComBat` parameters
+
+The following section provides a detailed explanation of all parameters available in the scikit-learn-compatible `ComBat` class.
+
+### Main Parameters
+
+| Parameter | Type | Default | Description |
+| --- | --- | --- | --- |
+| `batch` | array-like or pd.Series | **required** | Vector indicating batch assignment for each sample. This is used to estimate and remove batch effects. |
+| `discrete_covariates` | array-like, pd.Series, or pd.DataFrame | `None` | Optional categorical covariates (e.g., sex, site). Only used in `"fortin"` and `"chen"` methods. |
+| `continuous_covariates` | array-like, pd.Series or pd.DataFrame | `None` | Optional continuous covariates (e.g., age). Only used in `"fortin"` and `"chen"` methods. |
+
+### Algorithm Options
+
+| Parameter | Type | Default | Description |
+| --- | --- | --- | --- |
+| `method` | str | `"johnson"` | ComBat method to use: <ul><li>`"johnson"` - Classical ComBat (_Johnson et al. 2007_)</li><li>`"fortin"` - ComBat with covariates (_Fortin et al. 2018_)</li><li>`"chen"` - CovBat, PCA-based correction (_Chen et al. 2022_)</li></ul> |
+| `parametric` | bool | `True` | Whether to use the **parametric empirical Bayes** formulation. If `False`, a non-parametric iterative scheme is used. |
+| `mean_only` | bool | `False` | If `True`, only the **mean** is corrected, while variances are left unchanged. Useful for preserving variance structure in the data. |
+| `reference_batch` | str or `None` | `None` | If specified, acts as a reference batch - other batches will be corrected to match this one. |
+| `covbat_cov_thresh` | float, int | `0.9` | For `"chen"` method only: Cumulative variance threshold $]0,1[$ to retain PCs in PCA space (e.g., 0.9 = retain 90% explained variance). If an integer is provided, it represents the number of principal components to use. |
+| `eps` | float | `1e-8` | Small jitter value added to variances to prevent divide-by-zero errors during standardization. |
+
 ## Contributing
 
 Pull requests, bug reports, and feature ideas are welcome: feel free to open a PR!
 
+## Author
+
+[**Ettore Rocchi**](https://github.com/ettorerocchi) @ University of Bologna
+
+[Google Scholar](https://scholar.google.com/citations?user=MKHoGnQAAAAJ) $\cdot$ [Scopus](https://www.scopus.com/authid/detail.uri?authorId=57220152522)
+
 ## Acknowledgements
 
 This project builds on the excellent work of the ComBat family of harmonisation methods.
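The dual semantics of `covbat_cov_thresh` described in the README table (a float in $]0,1[$ is a cumulative-variance cutoff, an int is an explicit PC count) can be sketched with scikit-learn's `PCA`. The helper `n_components_for` below is a hypothetical illustration of that rule, not the package's internal code:

```python
import numpy as np
from sklearn.decomposition import PCA

rng = np.random.default_rng(0)
X = rng.normal(size=(100, 10))

pca = PCA().fit(X)

def n_components_for(thresh):
    """Resolve a covbat_cov_thresh-style value: an int is taken as a PC
    count, a float as a cumulative explained-variance cutoff."""
    if isinstance(thresh, int):
        return thresh
    # Smallest number of PCs whose cumulative variance ratio reaches thresh.
    cumvar = np.cumsum(pca.explained_variance_ratio_)
    return int(np.searchsorted(cumvar, thresh) + 1)

print(n_components_for(3))    # an int selects exactly that many PCs
print(n_components_for(0.9))  # a float selects enough PCs for 90% variance
```

The same convention appears in `PCA(n_components=...)` itself, where a float in (0, 1) also means "retain this fraction of variance".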
@@ -4,10 +4,10 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "combatlearn"
-version = "0.1.0"
-description = "Batch-effect harmonisation for machine learning frameworks."
+version = "0.1.1"
+description = "Batch-effect harmonization for machine learning frameworks."
 authors = [{name="Ettore Rocchi", email="ettoreroc@gmail.com"}]
-requires-python = ">=3.9"
+requires-python = ">=3.10"
 dependencies = [
     "pandas>=1.3",
     "numpy>=1.21",
@@ -1,15 +1,14 @@
-__author__ = "Ettore Rocchi"
-
 """ComBat algorithm.
 
 `ComBatModel` implements both:
-* Johnson et al. (2007) vanilla ComBat (method="johnson")
-* Fortin et al. (2018) extension with covariates (method="fortin")
-* Chen et al. (2022) CovBat (method="chen")
+* Johnson et al. (2007) vanilla ComBat (method="johnson")
+* Fortin et al. (2018) extension with covariates (method="fortin")
+* Chen et al. (2022) CovBat (method="chen")
 
 `ComBat` makes the model compatible with scikit-learn by stashing
 the batch (and optional covariates) at construction.
 """
+from __future__ import annotations
 
 import numpy as np
 import numpy.linalg as la
@@ -17,9 +16,15 @@ import pandas as pd
 from sklearn.base import BaseEstimator, TransformerMixin
 from sklearn.utils.validation import check_is_fitted
 from sklearn.decomposition import PCA
-from typing import Literal
+from typing import Literal, Optional, Union, Dict, Tuple, Any, cast
+import numpy.typing as npt
 import warnings
 
+__author__ = "Ettore Rocchi"
+
+ArrayLike = Union[pd.DataFrame, pd.Series, npt.NDArray[Any]]
+FloatArray = npt.NDArray[np.float64]
+
 
 class ComBatModel:
     """ComBat algorithm.
@@ -27,9 +32,9 @@ class ComBatModel:
     Parameters
     ----------
     method : {'johnson', 'fortin', 'chen'}, default='johnson'
-        * 'johnson' classic ComBat.
-        * 'fortin' covariateaware ComBat.
-        * 'chen' CovBat, PCAbased ComBat.
+        * 'johnson' - classic ComBat.
+        * 'fortin' - covariate-aware ComBat.
+        * 'chen' - CovBat, PCA-based ComBat.
     parametric : bool, default=True
         Use the parametric empirical Bayes variant.
     mean_only : bool, default=False
@@ -40,7 +45,7 @@
     covbat_cov_thresh : float, default=0.9
         CovBat: cumulative explained variance threshold for PCA.
     eps : float, default=1e-8
-        Numerical jitter to avoid divisionbyzero.
+        Numerical jitter to avoid division-by-zero.
     """
 
     def __init__(
@@ -49,21 +54,43 @@
         method: Literal["johnson", "fortin", "chen"] = "johnson",
         parametric: bool = True,
         mean_only: bool = False,
-        reference_batch=None,
+        reference_batch: Optional[str] = None,
        eps: float = 1e-8,
        covbat_cov_thresh: float = 0.9,
     ) -> None:
-        self.method = method
-        self.parametric = parametric
-        self.mean_only = bool(mean_only)
-        self.reference_batch = reference_batch
-        self.eps = float(eps)
-        self.covbat_cov_thresh = float(covbat_cov_thresh)
+        self.method: str = method
+        self.parametric: bool = parametric
+        self.mean_only: bool = bool(mean_only)
+        self.reference_batch: Optional[str] = reference_batch
+        self.eps: float = float(eps)
+        self.covbat_cov_thresh: float = float(covbat_cov_thresh)
+
+        self._batch_levels: pd.Index
+        self._grand_mean: pd.Series
+        self._pooled_var: pd.Series
+        self._gamma_star: FloatArray
+        self._delta_star: FloatArray
+        self._n_per_batch: Dict[str, int]
+        self._reference_batch_idx: Optional[int]
+        self._beta_hat_nonbatch: FloatArray
+        self._n_batch: int
+        self._p_design: int
+        self._covbat_pca: PCA
+        self._covbat_n_pc: int
+        self._batch_levels_pc: pd.Index
+        self._pc_gamma_star: FloatArray
+        self._pc_delta_star: FloatArray
+
         if not (0.0 < self.covbat_cov_thresh <= 1.0):
             raise ValueError("covbat_cov_thresh must be in (0, 1].")
 
     @staticmethod
-    def _as_series(arr, index, name):
+    def _as_series(
+        arr: ArrayLike,
+        index: pd.Index,
+        name: str
+    ) -> pd.Series:
+        """Convert array-like to categorical Series with validation."""
         if isinstance(arr, pd.Series):
             ser = arr.copy()
         else:
@@ -73,7 +100,12 @@ class ComBatModel:
         return ser.astype("category")
 
     @staticmethod
-    def _to_df(arr, index, name):
+    def _to_df(
+        arr: Optional[ArrayLike],
+        index: pd.Index,
+        name: str
+    ) -> Optional[pd.DataFrame]:
+        """Convert array-like to DataFrame."""
         if arr is None:
             return None
         if isinstance(arr, pd.Series):
@@ -86,13 +118,14 @@
 
     def fit(
         self,
-        X,
-        y=None,
+        X: ArrayLike,
+        y: Optional[ArrayLike] = None,
         *,
-        batch,
-        discrete_covariates=None,
-        continuous_covariates=None,
-    ):
+        batch: ArrayLike,
+        discrete_covariates: Optional[ArrayLike] = None,
+        continuous_covariates: Optional[ArrayLike] = None,
+    ) -> ComBatModel:
+        """Fit the ComBat model."""
         method = self.method.lower()
         if method not in {"johnson", "fortin", "chen"}:
             raise ValueError("method must be 'johnson', 'fortin', or 'chen'.")
@@ -104,7 +137,6 @@
         disc = self._to_df(discrete_covariates, idx, "discrete_covariates")
         cont = self._to_df(continuous_covariates, idx, "continuous_covariates")
 
-
         if self.reference_batch is not None and self.reference_batch not in batch.cat.categories:
             raise ValueError(
                 f"reference_batch={self.reference_batch!r} not present in the data batches "
@@ -127,38 +159,39 @@
         self,
         X: pd.DataFrame,
         batch: pd.Series
-    ):
-        """
-        Johnson et al. (2007) ComBat.
-        """
+    ) -> None:
+        """Johnson et al. (2007) ComBat."""
         self._batch_levels = batch.cat.categories
         pooled_var = X.var(axis=0, ddof=1) + self.eps
         grand_mean = X.mean(axis=0)
 
         Xs = (X - grand_mean) / np.sqrt(pooled_var)
 
-        n_per_batch: dict[str, int] = {}
-        gamma_hat, delta_hat = [], []
+        n_per_batch: Dict[str, int] = {}
+        gamma_hat: list[npt.NDArray[np.float64]] = []
+        delta_hat: list[npt.NDArray[np.float64]] = []
+
         for lvl in self._batch_levels:
             idx = batch == lvl
-            n_b = idx.sum()
+            n_b = int(idx.sum())
             if n_b < 2:
                 raise ValueError(f"Batch '{lvl}' has <2 samples.")
-            n_per_batch[lvl] = n_b
+            n_per_batch[str(lvl)] = n_b
             xb = Xs.loc[idx]
             gamma_hat.append(xb.mean(axis=0).values)
             delta_hat.append(xb.var(axis=0, ddof=1).values + self.eps)
-        gamma_hat = np.vstack(gamma_hat)
-        delta_hat = np.vstack(delta_hat)
+
+        gamma_hat_arr = np.vstack(gamma_hat)
+        delta_hat_arr = np.vstack(delta_hat)
 
         if self.mean_only:
             gamma_star = self._shrink_gamma(
-                gamma_hat, delta_hat, n_per_batch, parametric=self.parametric
+                gamma_hat_arr, delta_hat_arr, n_per_batch, parametric=self.parametric
             )
-            delta_star = np.ones_like(delta_hat)
+            delta_star = np.ones_like(delta_hat_arr)
         else:
             gamma_star, delta_star = self._shrink_gamma_delta(
-                gamma_hat, delta_hat, n_per_batch, parametric=self.parametric
+                gamma_hat_arr, delta_hat_arr, n_per_batch, parametric=self.parametric
             )
 
         if self.reference_batch is not None:
@@ -182,18 +215,16 @@
         self,
         X: pd.DataFrame,
         batch: pd.Series,
-        disc: pd.DataFrame | None,
-        cont: pd.DataFrame | None,
-    ):
-        """
-        Fortin et al. (2018) ComBat.
-        """
+        disc: Optional[pd.DataFrame],
+        cont: Optional[pd.DataFrame],
+    ) -> None:
+        """Fortin et al. (2018) ComBat."""
         batch_levels = batch.cat.categories
         n_batch = len(batch_levels)
         n_samples = len(X)
 
         batch_dummies = pd.get_dummies(batch, drop_first=False)
-        parts = [batch_dummies]
+        parts: list[pd.DataFrame] = [batch_dummies]
         if disc is not None:
             parts.append(pd.get_dummies(disc.astype("category"), drop_first=True))
         if cont is not None:
@@ -202,20 +233,20 @@
         p_design = design.shape[1]
 
         X_np = X.values
-        beta_hat = la.inv(design.T @ design) @ design.T @ X_np
+        beta_hat = la.lstsq(design, X_np, rcond=None)[0]
 
         gamma_hat = beta_hat[:n_batch]
         self._beta_hat_nonbatch = beta_hat[n_batch:]
 
-        n_per_batch = batch.value_counts().sort_index().values
-        self._n_per_batch = dict(zip(batch_levels, n_per_batch))
+        n_per_batch_arr = batch.value_counts().sort_index().values
+        self._n_per_batch = dict(zip(batch_levels, n_per_batch_arr))
 
-        grand_mean = (n_per_batch / n_samples) @ gamma_hat
-        self._grand_mean = grand_mean
+        grand_mean = (n_per_batch_arr / n_samples) @ gamma_hat
+        self._grand_mean = pd.Series(grand_mean, index=X.columns)
 
         resid = X_np - design @ beta_hat
         var_pooled = (resid ** 2).sum(axis=0) / (n_samples - p_design) + self.eps
-        self._pooled_var = var_pooled
+        self._pooled_var = pd.Series(var_pooled, index=X.columns)
 
         stand_mean = grand_mean + design[:, n_batch:] @ self._beta_hat_nonbatch
         Xs = (X_np - stand_mean) / np.sqrt(var_pooled)
@@ -227,12 +258,12 @@
 
         if self.mean_only:
             gamma_star = self._shrink_gamma(
-                gamma_hat, delta_hat, n_per_batch, parametric=self.parametric
+                gamma_hat, delta_hat, n_per_batch_arr, parametric=self.parametric
             )
             delta_star = np.ones_like(delta_hat)
         else:
             gamma_star, delta_star = self._shrink_gamma_delta(
-                gamma_hat, delta_hat, n_per_batch, parametric=self.parametric
+                gamma_hat, delta_hat, n_per_batch_arr, parametric=self.parametric
             )
 
         if self.reference_batch is not None:
@@ -256,9 +287,10 @@
         self,
         X: pd.DataFrame,
         batch: pd.Series,
-        disc: pd.DataFrame | None,
-        cont: pd.DataFrame | None,
-    ):
+        disc: Optional[pd.DataFrame],
+        cont: Optional[pd.DataFrame],
+    ) -> None:
+        """Chen et al. (2022) CovBat."""
         self._fit_fortin(X, batch, disc, cont)
         X_meanvar_adj = self._transform_fortin(X, batch, disc, cont)
         X_centered = X_meanvar_adj - X_meanvar_adj.mean(axis=0)
@@ -273,23 +305,24 @@
         self._batch_levels_pc = self._batch_levels
         n_per_batch = self._n_per_batch
 
-        gamma_hat, delta_hat = [], []
+        gamma_hat: list[npt.NDArray[np.float64]] = []
+        delta_hat: list[npt.NDArray[np.float64]] = []
         for lvl in self._batch_levels_pc:
             idx = batch == lvl
             xb = scores_df.loc[idx]
             gamma_hat.append(xb.mean(axis=0).values)
             delta_hat.append(xb.var(axis=0, ddof=1).values + self.eps)
-        gamma_hat = np.vstack(gamma_hat)
-        delta_hat = np.vstack(delta_hat)
+        gamma_hat_arr = np.vstack(gamma_hat)
+        delta_hat_arr = np.vstack(delta_hat)
 
         if self.mean_only:
             gamma_star = self._shrink_gamma(
-                gamma_hat, delta_hat, n_per_batch, parametric=self.parametric
+                gamma_hat_arr, delta_hat_arr, n_per_batch, parametric=self.parametric
             )
-            delta_star = np.ones_like(delta_hat)
+            delta_star = np.ones_like(delta_hat_arr)
         else:
             gamma_star, delta_star = self._shrink_gamma_delta(
-                gamma_hat, delta_hat, n_per_batch, parametric=self.parametric
+                gamma_hat_arr, delta_hat_arr, n_per_batch, parametric=self.parametric
             )
 
         if self.reference_batch is not None:
@@ -305,14 +338,15 @@
 
     def _shrink_gamma_delta(
         self,
-        gamma_hat: np.ndarray,
-        delta_hat: np.ndarray,
-        n_per_batch: dict | np.ndarray,
+        gamma_hat: FloatArray,
+        delta_hat: FloatArray,
+        n_per_batch: Union[Dict[str, int], FloatArray],
         *,
         parametric: bool,
         max_iter: int = 100,
         tol: float = 1e-4,
-    ):
+    ) -> Tuple[FloatArray, FloatArray]:
+        """Empirical Bayes shrinkage estimation."""
         if parametric:
             gamma_bar = gamma_hat.mean(axis=0)
             t2 = gamma_hat.var(axis=0, ddof=1)
@@ -323,6 +357,7 @@
             gamma_star = np.empty_like(gamma_hat)
             delta_star = np.empty_like(delta_hat)
             n_vec = np.array(list(n_per_batch.values())) if isinstance(n_per_batch, dict) else n_per_batch
+
             for i in range(B):
                 n_i = n_vec[i]
                 g, d = gamma_hat[i], delta_hat[i]
@@ -340,18 +375,29 @@
             gamma_bar = gamma_hat.mean(axis=0)
             t2 = gamma_hat.var(axis=0, ddof=1)
 
-            def postmean(g_hat, g_bar, n, d_star, t2_):
+            def postmean(
+                g_hat: FloatArray,
+                g_bar: FloatArray,
+                n: float,
+                d_star: FloatArray,
+                t2_: FloatArray
+            ) -> FloatArray:
                 return (t2_ * n * g_hat + d_star * g_bar) / (t2_ * n + d_star)
 
-            def postvar(sum2, n, a, b):
+            def postvar(
+                sum2: FloatArray,
+                n: float,
+                a: FloatArray,
+                b: FloatArray
+            ) -> FloatArray:
                 return (0.5 * sum2 + b) / (n / 2.0 + a - 1.0)
 
-            def aprior(delta):
+            def aprior(delta: FloatArray) -> FloatArray:
                 m, s2 = delta.mean(), delta.var()
                 s2 = max(s2, self.eps)
                 return (2 * s2 + m ** 2) / s2
 
-            def bprior(delta):
+            def bprior(delta: FloatArray) -> FloatArray:
                 m, s2 = delta.mean(), delta.var()
                 s2 = max(s2, self.eps)
                 return (m * s2 + m ** 3) / s2
@@ -382,24 +428,25 @@
 
     def _shrink_gamma(
         self,
-        gamma_hat: np.ndarray,
-        delta_hat: np.ndarray,
-        n_per_batch: dict | np.ndarray,
+        gamma_hat: FloatArray,
+        delta_hat: FloatArray,
+        n_per_batch: Union[Dict[str, int], FloatArray],
         *,
         parametric: bool,
-    ) -> np.ndarray:
+    ) -> FloatArray:
         """Convenience wrapper that returns only γ⋆ (for *mean‑only* mode)."""
         gamma, _ = self._shrink_gamma_delta(gamma_hat, delta_hat, n_per_batch, parametric=parametric)
         return gamma
 
     def transform(
         self,
-        X,
+        X: ArrayLike,
         *,
-        batch,
-        discrete_covariates=None,
-        continuous_covariates=None,
-    ):
+        batch: ArrayLike,
+        discrete_covariates: Optional[ArrayLike] = None,
+        continuous_covariates: Optional[ArrayLike] = None,
+    ) -> pd.DataFrame:
+        """Transform the data using fitted ComBat parameters."""
         check_is_fitted(self, ["_gamma_star"])
         if not isinstance(X, pd.DataFrame):
             X = pd.DataFrame(X)
@@ -418,8 +465,15 @@
             return self._transform_fortin(X, batch, disc, cont)
         elif method == "chen":
             return self._transform_chen(X, batch, disc, cont)
+        else:
+            raise ValueError(f"Unknown method: {method}")
 
-    def _transform_johnson(self, X: pd.DataFrame, batch: pd.Series):
+    def _transform_johnson(
+        self,
+        X: pd.DataFrame,
+        batch: pd.Series
+    ) -> pd.DataFrame:
+        """Johnson transform implementation."""
         pooled = self._pooled_var
         grand = self._grand_mean
 
@@ -447,11 +501,12 @@
         self,
         X: pd.DataFrame,
         batch: pd.Series,
-        disc: pd.DataFrame | None,
-        cont: pd.DataFrame | None,
-    ):
+        disc: Optional[pd.DataFrame],
+        cont: Optional[pd.DataFrame],
+    ) -> pd.DataFrame:
+        """Fortin transform implementation."""
         batch_dummies = pd.get_dummies(batch, drop_first=False)[self._batch_levels]
-        parts = [batch_dummies]
+        parts: list[pd.DataFrame] = [batch_dummies]
         if disc is not None:
             parts.append(pd.get_dummies(disc.astype("category"), drop_first=True))
         if cont is not None:
@@ -460,8 +515,8 @@
         design = pd.concat(parts, axis=1).astype(float).values
 
         X_np = X.values
-        stand_mean = self._grand_mean + design[:, self._n_batch:] @ self._beta_hat_nonbatch
-        Xs = (X_np - stand_mean) / np.sqrt(self._pooled_var)
+        stand_mean = self._grand_mean.values + design[:, self._n_batch:] @ self._beta_hat_nonbatch
+        Xs = (X_np - stand_mean) / np.sqrt(self._pooled_var.values)
 
         for i, lvl in enumerate(self._batch_levels):
             idx = batch == lvl
@@ -478,19 +533,20 @@
             else:
                 Xs[idx] = (Xs[idx] - g) / np.sqrt(d)
 
-        X_adj = Xs * np.sqrt(self._pooled_var) + stand_mean
+        X_adj = Xs * np.sqrt(self._pooled_var.values) + stand_mean
         return pd.DataFrame(X_adj, index=X.index, columns=X.columns)
 
     def _transform_chen(
         self,
         X: pd.DataFrame,
         batch: pd.Series,
-        disc: pd.DataFrame | None,
-        cont: pd.DataFrame | None,
-    ):
+        disc: Optional[pd.DataFrame],
+        cont: Optional[pd.DataFrame],
+    ) -> pd.DataFrame:
+        """Chen transform implementation."""
         X_meanvar_adj = self._transform_fortin(X, batch, disc, cont)
         X_centered = X_meanvar_adj - self._covbat_pca.mean_
-        scores = self._covbat_pca.transform(X_centered)
+        scores = self._covbat_pca.transform(X_centered.values)
         n_pc = self._covbat_n_pc
         scores_adj = scores.copy()
 
@@ -515,19 +571,19 @@ class ComBat(BaseEstimator, TransformerMixin):
     """Pipeline‑friendly wrapper around `ComBatModel`.
 
     Stores batch (and optional covariates) passed at construction and
-    appropriately used them also for separate `fit` and `transform`.
+    appropriately uses them for separate `fit` and `transform`.
     """
 
     def __init__(
         self,
-        batch,
+        batch: ArrayLike,
         *,
-        discrete_covariates=None,
-        continuous_covariates=None,
+        discrete_covariates: Optional[ArrayLike] = None,
+        continuous_covariates: Optional[ArrayLike] = None,
         method: str = "johnson",
         parametric: bool = True,
         mean_only: bool = False,
-        reference_batch=None,
+        reference_batch: Optional[str] = None,
         eps: float = 1e-8,
         covbat_cov_thresh: float = 0.9,
     ) -> None:
@@ -549,8 +605,13 @@
             covbat_cov_thresh=covbat_cov_thresh,
         )
 
-    def fit(self, X, y=None):
-        idx = X.index if isinstance(X, pd.DataFrame) else np.arange(len(X))
+    def fit(
+        self,
+        X: ArrayLike,
+        y: Optional[ArrayLike] = None
+    ) -> "ComBat":
+        """Fit the ComBat model."""
+        idx = X.index if isinstance(X, pd.DataFrame) else pd.RangeIndex(len(X))
         batch_vec = self._subset(self.batch, idx)
         disc = self._subset(self.discrete_covariates, idx)
         cont = self._subset(self.continuous_covariates, idx)
@@ -562,8 +623,9 @@
         )
         return self
 
-    def transform(self, X):
-        idx = X.index if isinstance(X, pd.DataFrame) else np.arange(len(X))
+    def transform(self, X: ArrayLike) -> pd.DataFrame:
+        """Transform the data using fitted ComBat parameters."""
+        idx = X.index if isinstance(X, pd.DataFrame) else pd.RangeIndex(len(X))
         batch_vec = self._subset(self.batch, idx)
         disc = self._subset(self.discrete_covariates, idx)
         cont = self._subset(self.continuous_covariates, idx)
@@ -575,10 +637,17 @@
         )
 
     @staticmethod
-    def _subset(obj, idx):
+    def _subset(
+        obj: Optional[ArrayLike],
+        idx: pd.Index
+    ) -> Optional[Union[pd.DataFrame, pd.Series]]:
+        """Subset array-like object by index."""
         if obj is None:
             return None
         if isinstance(obj, (pd.Series, pd.DataFrame)):
             return obj.loc[idx]
         else:
-            return pd.DataFrame(obj).iloc[idx]
+            if isinstance(obj, np.ndarray) and obj.ndim == 1:
+                return pd.Series(obj, index=idx)
+            else:
+                return pd.DataFrame(obj, index=idx)
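One substantive change in the source diff above replaces the explicit normal-equations solve `la.inv(design.T @ design) @ design.T @ X_np` with `la.lstsq(design, X_np, rcond=None)`. The least-squares solver is numerically more stable and does not crash when `design.T @ design` is singular. A quick sketch (with hypothetical random data) of their equivalence on a full-rank design:

```python
import numpy as np
import numpy.linalg as la

rng = np.random.default_rng(0)
design = rng.normal(size=(60, 4))  # full-rank design matrix (samples x predictors)
X_np = rng.normal(size=(60, 3))    # response matrix (samples x features)

# Old approach: explicit normal equations; fails if design.T @ design is singular.
beta_normal = la.inv(design.T @ design) @ design.T @ X_np

# New approach: least-squares solve; stable, handles rank-deficient designs.
beta_lstsq = la.lstsq(design, X_np, rcond=None)[0]

# On a full-rank design both yield the same coefficients.
print(np.allclose(beta_normal, beta_lstsq))
```

With a rank-deficient `design` (e.g., a duplicated column), the `la.inv` version raises or returns garbage, while `la.lstsq` returns the minimum-norm solution.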
File without changes
File without changes