combatlearn 0.1.0__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- combatlearn/combat.py +172 -103
- {combatlearn-0.1.0.dist-info → combatlearn-0.1.1.dist-info}/METADATA +38 -3
- combatlearn-0.1.1.dist-info/RECORD +7 -0
- combatlearn-0.1.0.dist-info/RECORD +0 -7
- {combatlearn-0.1.0.dist-info → combatlearn-0.1.1.dist-info}/WHEEL +0 -0
- {combatlearn-0.1.0.dist-info → combatlearn-0.1.1.dist-info}/licenses/LICENSE +0 -0
- {combatlearn-0.1.0.dist-info → combatlearn-0.1.1.dist-info}/top_level.txt +0 -0
combatlearn/combat.py
CHANGED
|
@@ -1,15 +1,14 @@
|
|
|
1
|
-
__author__ = "Ettore Rocchi"
|
|
2
|
-
|
|
3
1
|
"""ComBat algorithm.
|
|
4
2
|
|
|
5
3
|
`ComBatModel` implements both:
|
|
6
|
-
* Johnson et
|
|
7
|
-
* Fortin et
|
|
8
|
-
* Chen et
|
|
4
|
+
* Johnson et al. (2007) vanilla ComBat (method="johnson")
|
|
5
|
+
* Fortin et al. (2018) extension with covariates (method="fortin")
|
|
6
|
+
* Chen et al. (2022) CovBat (method="chen")
|
|
9
7
|
|
|
10
8
|
`ComBat` makes the model compatible with scikit-learn by stashing
|
|
11
9
|
the batch (and optional covariates) at construction.
|
|
12
10
|
"""
|
|
11
|
+
from __future__ import annotations
|
|
13
12
|
|
|
14
13
|
import numpy as np
|
|
15
14
|
import numpy.linalg as la
|
|
@@ -17,9 +16,15 @@ import pandas as pd
|
|
|
17
16
|
from sklearn.base import BaseEstimator, TransformerMixin
|
|
18
17
|
from sklearn.utils.validation import check_is_fitted
|
|
19
18
|
from sklearn.decomposition import PCA
|
|
20
|
-
from typing import Literal
|
|
19
|
+
from typing import Literal, Optional, Union, Dict, Tuple, Any, cast
|
|
20
|
+
import numpy.typing as npt
|
|
21
21
|
import warnings
|
|
22
22
|
|
|
23
|
+
__author__ = "Ettore Rocchi"
|
|
24
|
+
|
|
25
|
+
ArrayLike = Union[pd.DataFrame, pd.Series, npt.NDArray[Any]]
|
|
26
|
+
FloatArray = npt.NDArray[np.float64]
|
|
27
|
+
|
|
23
28
|
|
|
24
29
|
class ComBatModel:
|
|
25
30
|
"""ComBat algorithm.
|
|
@@ -27,9 +32,9 @@ class ComBatModel:
|
|
|
27
32
|
Parameters
|
|
28
33
|
----------
|
|
29
34
|
method : {'johnson', 'fortin', 'chen'}, default='johnson'
|
|
30
|
-
* 'johnson'
|
|
31
|
-
* 'fortin'
|
|
32
|
-
* 'chen'
|
|
35
|
+
* 'johnson' - classic ComBat.
|
|
36
|
+
* 'fortin' - covariate-aware ComBat.
|
|
37
|
+
* 'chen' - CovBat, PCA-based ComBat.
|
|
33
38
|
parametric : bool, default=True
|
|
34
39
|
Use the parametric empirical Bayes variant.
|
|
35
40
|
mean_only : bool, default=False
|
|
@@ -40,7 +45,7 @@ class ComBatModel:
|
|
|
40
45
|
covbat_cov_thresh : float, default=0.9
|
|
41
46
|
CovBat: cumulative explained variance threshold for PCA.
|
|
42
47
|
eps : float, default=1e-8
|
|
43
|
-
Numerical jitter to avoid division
|
|
48
|
+
Numerical jitter to avoid division-by-zero.
|
|
44
49
|
"""
|
|
45
50
|
|
|
46
51
|
def __init__(
|
|
@@ -49,21 +54,43 @@ class ComBatModel:
|
|
|
49
54
|
method: Literal["johnson", "fortin", "chen"] = "johnson",
|
|
50
55
|
parametric: bool = True,
|
|
51
56
|
mean_only: bool = False,
|
|
52
|
-
reference_batch=None,
|
|
57
|
+
reference_batch: Optional[str] = None,
|
|
53
58
|
eps: float = 1e-8,
|
|
54
59
|
covbat_cov_thresh: float = 0.9,
|
|
55
60
|
) -> None:
|
|
56
|
-
self.method = method
|
|
57
|
-
self.parametric = parametric
|
|
58
|
-
self.mean_only = bool(mean_only)
|
|
59
|
-
self.reference_batch = reference_batch
|
|
60
|
-
self.eps = float(eps)
|
|
61
|
-
self.covbat_cov_thresh = float(covbat_cov_thresh)
|
|
61
|
+
self.method: str = method
|
|
62
|
+
self.parametric: bool = parametric
|
|
63
|
+
self.mean_only: bool = bool(mean_only)
|
|
64
|
+
self.reference_batch: Optional[str] = reference_batch
|
|
65
|
+
self.eps: float = float(eps)
|
|
66
|
+
self.covbat_cov_thresh: float = float(covbat_cov_thresh)
|
|
67
|
+
|
|
68
|
+
self._batch_levels: pd.Index
|
|
69
|
+
self._grand_mean: pd.Series
|
|
70
|
+
self._pooled_var: pd.Series
|
|
71
|
+
self._gamma_star: FloatArray
|
|
72
|
+
self._delta_star: FloatArray
|
|
73
|
+
self._n_per_batch: Dict[str, int]
|
|
74
|
+
self._reference_batch_idx: Optional[int]
|
|
75
|
+
self._beta_hat_nonbatch: FloatArray
|
|
76
|
+
self._n_batch: int
|
|
77
|
+
self._p_design: int
|
|
78
|
+
self._covbat_pca: PCA
|
|
79
|
+
self._covbat_n_pc: int
|
|
80
|
+
self._batch_levels_pc: pd.Index
|
|
81
|
+
self._pc_gamma_star: FloatArray
|
|
82
|
+
self._pc_delta_star: FloatArray
|
|
83
|
+
|
|
62
84
|
if not (0.0 < self.covbat_cov_thresh <= 1.0):
|
|
63
85
|
raise ValueError("covbat_cov_thresh must be in (0, 1].")
|
|
64
86
|
|
|
65
87
|
@staticmethod
|
|
66
|
-
def _as_series(
|
|
88
|
+
def _as_series(
|
|
89
|
+
arr: ArrayLike,
|
|
90
|
+
index: pd.Index,
|
|
91
|
+
name: str
|
|
92
|
+
) -> pd.Series:
|
|
93
|
+
"""Convert array-like to categorical Series with validation."""
|
|
67
94
|
if isinstance(arr, pd.Series):
|
|
68
95
|
ser = arr.copy()
|
|
69
96
|
else:
|
|
@@ -73,7 +100,12 @@ class ComBatModel:
|
|
|
73
100
|
return ser.astype("category")
|
|
74
101
|
|
|
75
102
|
@staticmethod
|
|
76
|
-
def _to_df(
|
|
103
|
+
def _to_df(
|
|
104
|
+
arr: Optional[ArrayLike],
|
|
105
|
+
index: pd.Index,
|
|
106
|
+
name: str
|
|
107
|
+
) -> Optional[pd.DataFrame]:
|
|
108
|
+
"""Convert array-like to DataFrame."""
|
|
77
109
|
if arr is None:
|
|
78
110
|
return None
|
|
79
111
|
if isinstance(arr, pd.Series):
|
|
@@ -86,13 +118,14 @@ class ComBatModel:
|
|
|
86
118
|
|
|
87
119
|
def fit(
|
|
88
120
|
self,
|
|
89
|
-
X,
|
|
90
|
-
y=None,
|
|
121
|
+
X: ArrayLike,
|
|
122
|
+
y: Optional[ArrayLike] = None,
|
|
91
123
|
*,
|
|
92
|
-
batch,
|
|
93
|
-
discrete_covariates=None,
|
|
94
|
-
continuous_covariates=None,
|
|
95
|
-
):
|
|
124
|
+
batch: ArrayLike,
|
|
125
|
+
discrete_covariates: Optional[ArrayLike] = None,
|
|
126
|
+
continuous_covariates: Optional[ArrayLike] = None,
|
|
127
|
+
) -> ComBatModel:
|
|
128
|
+
"""Fit the ComBat model."""
|
|
96
129
|
method = self.method.lower()
|
|
97
130
|
if method not in {"johnson", "fortin", "chen"}:
|
|
98
131
|
raise ValueError("method must be 'johnson', 'fortin', or 'chen'.")
|
|
@@ -104,7 +137,6 @@ class ComBatModel:
|
|
|
104
137
|
disc = self._to_df(discrete_covariates, idx, "discrete_covariates")
|
|
105
138
|
cont = self._to_df(continuous_covariates, idx, "continuous_covariates")
|
|
106
139
|
|
|
107
|
-
|
|
108
140
|
if self.reference_batch is not None and self.reference_batch not in batch.cat.categories:
|
|
109
141
|
raise ValueError(
|
|
110
142
|
f"reference_batch={self.reference_batch!r} not present in the data batches "
|
|
@@ -127,38 +159,39 @@ class ComBatModel:
|
|
|
127
159
|
self,
|
|
128
160
|
X: pd.DataFrame,
|
|
129
161
|
batch: pd.Series
|
|
130
|
-
):
|
|
131
|
-
"""
|
|
132
|
-
Johnson et al. (2007) ComBat.
|
|
133
|
-
"""
|
|
162
|
+
) -> None:
|
|
163
|
+
"""Johnson et al. (2007) ComBat."""
|
|
134
164
|
self._batch_levels = batch.cat.categories
|
|
135
165
|
pooled_var = X.var(axis=0, ddof=1) + self.eps
|
|
136
166
|
grand_mean = X.mean(axis=0)
|
|
137
167
|
|
|
138
168
|
Xs = (X - grand_mean) / np.sqrt(pooled_var)
|
|
139
169
|
|
|
140
|
-
n_per_batch:
|
|
141
|
-
gamma_hat
|
|
170
|
+
n_per_batch: Dict[str, int] = {}
|
|
171
|
+
gamma_hat: list[npt.NDArray[np.float64]] = []
|
|
172
|
+
delta_hat: list[npt.NDArray[np.float64]] = []
|
|
173
|
+
|
|
142
174
|
for lvl in self._batch_levels:
|
|
143
175
|
idx = batch == lvl
|
|
144
|
-
n_b = idx.sum()
|
|
176
|
+
n_b = int(idx.sum())
|
|
145
177
|
if n_b < 2:
|
|
146
178
|
raise ValueError(f"Batch '{lvl}' has <2 samples.")
|
|
147
|
-
n_per_batch[lvl] = n_b
|
|
179
|
+
n_per_batch[str(lvl)] = n_b
|
|
148
180
|
xb = Xs.loc[idx]
|
|
149
181
|
gamma_hat.append(xb.mean(axis=0).values)
|
|
150
182
|
delta_hat.append(xb.var(axis=0, ddof=1).values + self.eps)
|
|
151
|
-
|
|
152
|
-
|
|
183
|
+
|
|
184
|
+
gamma_hat_arr = np.vstack(gamma_hat)
|
|
185
|
+
delta_hat_arr = np.vstack(delta_hat)
|
|
153
186
|
|
|
154
187
|
if self.mean_only:
|
|
155
188
|
gamma_star = self._shrink_gamma(
|
|
156
|
-
|
|
189
|
+
gamma_hat_arr, delta_hat_arr, n_per_batch, parametric=self.parametric
|
|
157
190
|
)
|
|
158
|
-
delta_star = np.ones_like(
|
|
191
|
+
delta_star = np.ones_like(delta_hat_arr)
|
|
159
192
|
else:
|
|
160
193
|
gamma_star, delta_star = self._shrink_gamma_delta(
|
|
161
|
-
|
|
194
|
+
gamma_hat_arr, delta_hat_arr, n_per_batch, parametric=self.parametric
|
|
162
195
|
)
|
|
163
196
|
|
|
164
197
|
if self.reference_batch is not None:
|
|
@@ -182,18 +215,16 @@ class ComBatModel:
|
|
|
182
215
|
self,
|
|
183
216
|
X: pd.DataFrame,
|
|
184
217
|
batch: pd.Series,
|
|
185
|
-
disc: pd.DataFrame
|
|
186
|
-
cont: pd.DataFrame
|
|
187
|
-
):
|
|
188
|
-
"""
|
|
189
|
-
Fortin et al. (2018) ComBat.
|
|
190
|
-
"""
|
|
218
|
+
disc: Optional[pd.DataFrame],
|
|
219
|
+
cont: Optional[pd.DataFrame],
|
|
220
|
+
) -> None:
|
|
221
|
+
"""Fortin et al. (2018) ComBat."""
|
|
191
222
|
batch_levels = batch.cat.categories
|
|
192
223
|
n_batch = len(batch_levels)
|
|
193
224
|
n_samples = len(X)
|
|
194
225
|
|
|
195
226
|
batch_dummies = pd.get_dummies(batch, drop_first=False)
|
|
196
|
-
parts = [batch_dummies]
|
|
227
|
+
parts: list[pd.DataFrame] = [batch_dummies]
|
|
197
228
|
if disc is not None:
|
|
198
229
|
parts.append(pd.get_dummies(disc.astype("category"), drop_first=True))
|
|
199
230
|
if cont is not None:
|
|
@@ -202,20 +233,20 @@ class ComBatModel:
|
|
|
202
233
|
p_design = design.shape[1]
|
|
203
234
|
|
|
204
235
|
X_np = X.values
|
|
205
|
-
beta_hat = la.
|
|
236
|
+
beta_hat = la.lstsq(design, X_np, rcond=None)[0]
|
|
206
237
|
|
|
207
238
|
gamma_hat = beta_hat[:n_batch]
|
|
208
239
|
self._beta_hat_nonbatch = beta_hat[n_batch:]
|
|
209
240
|
|
|
210
|
-
|
|
211
|
-
self._n_per_batch = dict(zip(batch_levels,
|
|
241
|
+
n_per_batch_arr = batch.value_counts().sort_index().values
|
|
242
|
+
self._n_per_batch = dict(zip(batch_levels, n_per_batch_arr))
|
|
212
243
|
|
|
213
|
-
grand_mean = (
|
|
214
|
-
self._grand_mean = grand_mean
|
|
244
|
+
grand_mean = (n_per_batch_arr / n_samples) @ gamma_hat
|
|
245
|
+
self._grand_mean = pd.Series(grand_mean, index=X.columns)
|
|
215
246
|
|
|
216
247
|
resid = X_np - design @ beta_hat
|
|
217
248
|
var_pooled = (resid ** 2).sum(axis=0) / (n_samples - p_design) + self.eps
|
|
218
|
-
self._pooled_var = var_pooled
|
|
249
|
+
self._pooled_var = pd.Series(var_pooled, index=X.columns)
|
|
219
250
|
|
|
220
251
|
stand_mean = grand_mean + design[:, n_batch:] @ self._beta_hat_nonbatch
|
|
221
252
|
Xs = (X_np - stand_mean) / np.sqrt(var_pooled)
|
|
@@ -227,12 +258,12 @@ class ComBatModel:
|
|
|
227
258
|
|
|
228
259
|
if self.mean_only:
|
|
229
260
|
gamma_star = self._shrink_gamma(
|
|
230
|
-
gamma_hat, delta_hat,
|
|
261
|
+
gamma_hat, delta_hat, n_per_batch_arr, parametric=self.parametric
|
|
231
262
|
)
|
|
232
263
|
delta_star = np.ones_like(delta_hat)
|
|
233
264
|
else:
|
|
234
265
|
gamma_star, delta_star = self._shrink_gamma_delta(
|
|
235
|
-
gamma_hat, delta_hat,
|
|
266
|
+
gamma_hat, delta_hat, n_per_batch_arr, parametric=self.parametric
|
|
236
267
|
)
|
|
237
268
|
|
|
238
269
|
if self.reference_batch is not None:
|
|
@@ -256,9 +287,10 @@ class ComBatModel:
|
|
|
256
287
|
self,
|
|
257
288
|
X: pd.DataFrame,
|
|
258
289
|
batch: pd.Series,
|
|
259
|
-
disc: pd.DataFrame
|
|
260
|
-
cont: pd.DataFrame
|
|
261
|
-
):
|
|
290
|
+
disc: Optional[pd.DataFrame],
|
|
291
|
+
cont: Optional[pd.DataFrame],
|
|
292
|
+
) -> None:
|
|
293
|
+
"""Chen et al. (2022) CovBat."""
|
|
262
294
|
self._fit_fortin(X, batch, disc, cont)
|
|
263
295
|
X_meanvar_adj = self._transform_fortin(X, batch, disc, cont)
|
|
264
296
|
X_centered = X_meanvar_adj - X_meanvar_adj.mean(axis=0)
|
|
@@ -273,23 +305,24 @@ class ComBatModel:
|
|
|
273
305
|
self._batch_levels_pc = self._batch_levels
|
|
274
306
|
n_per_batch = self._n_per_batch
|
|
275
307
|
|
|
276
|
-
gamma_hat
|
|
308
|
+
gamma_hat: list[npt.NDArray[np.float64]] = []
|
|
309
|
+
delta_hat: list[npt.NDArray[np.float64]] = []
|
|
277
310
|
for lvl in self._batch_levels_pc:
|
|
278
311
|
idx = batch == lvl
|
|
279
312
|
xb = scores_df.loc[idx]
|
|
280
313
|
gamma_hat.append(xb.mean(axis=0).values)
|
|
281
314
|
delta_hat.append(xb.var(axis=0, ddof=1).values + self.eps)
|
|
282
|
-
|
|
283
|
-
|
|
315
|
+
gamma_hat_arr = np.vstack(gamma_hat)
|
|
316
|
+
delta_hat_arr = np.vstack(delta_hat)
|
|
284
317
|
|
|
285
318
|
if self.mean_only:
|
|
286
319
|
gamma_star = self._shrink_gamma(
|
|
287
|
-
|
|
320
|
+
gamma_hat_arr, delta_hat_arr, n_per_batch, parametric=self.parametric
|
|
288
321
|
)
|
|
289
|
-
delta_star = np.ones_like(
|
|
322
|
+
delta_star = np.ones_like(delta_hat_arr)
|
|
290
323
|
else:
|
|
291
324
|
gamma_star, delta_star = self._shrink_gamma_delta(
|
|
292
|
-
|
|
325
|
+
gamma_hat_arr, delta_hat_arr, n_per_batch, parametric=self.parametric
|
|
293
326
|
)
|
|
294
327
|
|
|
295
328
|
if self.reference_batch is not None:
|
|
@@ -305,14 +338,15 @@ class ComBatModel:
|
|
|
305
338
|
|
|
306
339
|
def _shrink_gamma_delta(
|
|
307
340
|
self,
|
|
308
|
-
gamma_hat:
|
|
309
|
-
delta_hat:
|
|
310
|
-
n_per_batch:
|
|
341
|
+
gamma_hat: FloatArray,
|
|
342
|
+
delta_hat: FloatArray,
|
|
343
|
+
n_per_batch: Union[Dict[str, int], FloatArray],
|
|
311
344
|
*,
|
|
312
345
|
parametric: bool,
|
|
313
346
|
max_iter: int = 100,
|
|
314
347
|
tol: float = 1e-4,
|
|
315
|
-
):
|
|
348
|
+
) -> Tuple[FloatArray, FloatArray]:
|
|
349
|
+
"""Empirical Bayes shrinkage estimation."""
|
|
316
350
|
if parametric:
|
|
317
351
|
gamma_bar = gamma_hat.mean(axis=0)
|
|
318
352
|
t2 = gamma_hat.var(axis=0, ddof=1)
|
|
@@ -323,6 +357,7 @@ class ComBatModel:
|
|
|
323
357
|
gamma_star = np.empty_like(gamma_hat)
|
|
324
358
|
delta_star = np.empty_like(delta_hat)
|
|
325
359
|
n_vec = np.array(list(n_per_batch.values())) if isinstance(n_per_batch, dict) else n_per_batch
|
|
360
|
+
|
|
326
361
|
for i in range(B):
|
|
327
362
|
n_i = n_vec[i]
|
|
328
363
|
g, d = gamma_hat[i], delta_hat[i]
|
|
@@ -340,18 +375,29 @@ class ComBatModel:
|
|
|
340
375
|
gamma_bar = gamma_hat.mean(axis=0)
|
|
341
376
|
t2 = gamma_hat.var(axis=0, ddof=1)
|
|
342
377
|
|
|
343
|
-
def postmean(
|
|
378
|
+
def postmean(
|
|
379
|
+
g_hat: FloatArray,
|
|
380
|
+
g_bar: FloatArray,
|
|
381
|
+
n: float,
|
|
382
|
+
d_star: FloatArray,
|
|
383
|
+
t2_: FloatArray
|
|
384
|
+
) -> FloatArray:
|
|
344
385
|
return (t2_ * n * g_hat + d_star * g_bar) / (t2_ * n + d_star)
|
|
345
386
|
|
|
346
|
-
def postvar(
|
|
387
|
+
def postvar(
|
|
388
|
+
sum2: FloatArray,
|
|
389
|
+
n: float,
|
|
390
|
+
a: FloatArray,
|
|
391
|
+
b: FloatArray
|
|
392
|
+
) -> FloatArray:
|
|
347
393
|
return (0.5 * sum2 + b) / (n / 2.0 + a - 1.0)
|
|
348
394
|
|
|
349
|
-
def aprior(delta):
|
|
395
|
+
def aprior(delta: FloatArray) -> FloatArray:
|
|
350
396
|
m, s2 = delta.mean(), delta.var()
|
|
351
397
|
s2 = max(s2, self.eps)
|
|
352
398
|
return (2 * s2 + m ** 2) / s2
|
|
353
399
|
|
|
354
|
-
def bprior(delta):
|
|
400
|
+
def bprior(delta: FloatArray) -> FloatArray:
|
|
355
401
|
m, s2 = delta.mean(), delta.var()
|
|
356
402
|
s2 = max(s2, self.eps)
|
|
357
403
|
return (m * s2 + m ** 3) / s2
|
|
@@ -382,24 +428,25 @@ class ComBatModel:
|
|
|
382
428
|
|
|
383
429
|
def _shrink_gamma(
|
|
384
430
|
self,
|
|
385
|
-
gamma_hat:
|
|
386
|
-
delta_hat:
|
|
387
|
-
n_per_batch:
|
|
431
|
+
gamma_hat: FloatArray,
|
|
432
|
+
delta_hat: FloatArray,
|
|
433
|
+
n_per_batch: Union[Dict[str, int], FloatArray],
|
|
388
434
|
*,
|
|
389
435
|
parametric: bool,
|
|
390
|
-
) ->
|
|
436
|
+
) -> FloatArray:
|
|
391
437
|
"""Convenience wrapper that returns only γ⋆ (for *mean‑only* mode)."""
|
|
392
438
|
gamma, _ = self._shrink_gamma_delta(gamma_hat, delta_hat, n_per_batch, parametric=parametric)
|
|
393
439
|
return gamma
|
|
394
440
|
|
|
395
441
|
def transform(
|
|
396
442
|
self,
|
|
397
|
-
X,
|
|
443
|
+
X: ArrayLike,
|
|
398
444
|
*,
|
|
399
|
-
batch,
|
|
400
|
-
discrete_covariates=None,
|
|
401
|
-
continuous_covariates=None,
|
|
402
|
-
):
|
|
445
|
+
batch: ArrayLike,
|
|
446
|
+
discrete_covariates: Optional[ArrayLike] = None,
|
|
447
|
+
continuous_covariates: Optional[ArrayLike] = None,
|
|
448
|
+
) -> pd.DataFrame:
|
|
449
|
+
"""Transform the data using fitted ComBat parameters."""
|
|
403
450
|
check_is_fitted(self, ["_gamma_star"])
|
|
404
451
|
if not isinstance(X, pd.DataFrame):
|
|
405
452
|
X = pd.DataFrame(X)
|
|
@@ -418,8 +465,15 @@ class ComBatModel:
|
|
|
418
465
|
return self._transform_fortin(X, batch, disc, cont)
|
|
419
466
|
elif method == "chen":
|
|
420
467
|
return self._transform_chen(X, batch, disc, cont)
|
|
468
|
+
else:
|
|
469
|
+
raise ValueError(f"Unknown method: {method}")
|
|
421
470
|
|
|
422
|
-
def _transform_johnson(
|
|
471
|
+
def _transform_johnson(
|
|
472
|
+
self,
|
|
473
|
+
X: pd.DataFrame,
|
|
474
|
+
batch: pd.Series
|
|
475
|
+
) -> pd.DataFrame:
|
|
476
|
+
"""Johnson transform implementation."""
|
|
423
477
|
pooled = self._pooled_var
|
|
424
478
|
grand = self._grand_mean
|
|
425
479
|
|
|
@@ -447,11 +501,12 @@ class ComBatModel:
|
|
|
447
501
|
self,
|
|
448
502
|
X: pd.DataFrame,
|
|
449
503
|
batch: pd.Series,
|
|
450
|
-
disc: pd.DataFrame
|
|
451
|
-
cont: pd.DataFrame
|
|
452
|
-
):
|
|
504
|
+
disc: Optional[pd.DataFrame],
|
|
505
|
+
cont: Optional[pd.DataFrame],
|
|
506
|
+
) -> pd.DataFrame:
|
|
507
|
+
"""Fortin transform implementation."""
|
|
453
508
|
batch_dummies = pd.get_dummies(batch, drop_first=False)[self._batch_levels]
|
|
454
|
-
parts = [batch_dummies]
|
|
509
|
+
parts: list[pd.DataFrame] = [batch_dummies]
|
|
455
510
|
if disc is not None:
|
|
456
511
|
parts.append(pd.get_dummies(disc.astype("category"), drop_first=True))
|
|
457
512
|
if cont is not None:
|
|
@@ -460,8 +515,8 @@ class ComBatModel:
|
|
|
460
515
|
design = pd.concat(parts, axis=1).astype(float).values
|
|
461
516
|
|
|
462
517
|
X_np = X.values
|
|
463
|
-
stand_mean = self._grand_mean + design[:, self._n_batch:] @ self._beta_hat_nonbatch
|
|
464
|
-
Xs = (X_np - stand_mean) / np.sqrt(self._pooled_var)
|
|
518
|
+
stand_mean = self._grand_mean.values + design[:, self._n_batch:] @ self._beta_hat_nonbatch
|
|
519
|
+
Xs = (X_np - stand_mean) / np.sqrt(self._pooled_var.values)
|
|
465
520
|
|
|
466
521
|
for i, lvl in enumerate(self._batch_levels):
|
|
467
522
|
idx = batch == lvl
|
|
@@ -478,19 +533,20 @@ class ComBatModel:
|
|
|
478
533
|
else:
|
|
479
534
|
Xs[idx] = (Xs[idx] - g) / np.sqrt(d)
|
|
480
535
|
|
|
481
|
-
X_adj = Xs * np.sqrt(self._pooled_var) + stand_mean
|
|
536
|
+
X_adj = Xs * np.sqrt(self._pooled_var.values) + stand_mean
|
|
482
537
|
return pd.DataFrame(X_adj, index=X.index, columns=X.columns)
|
|
483
538
|
|
|
484
539
|
def _transform_chen(
|
|
485
540
|
self,
|
|
486
541
|
X: pd.DataFrame,
|
|
487
542
|
batch: pd.Series,
|
|
488
|
-
disc: pd.DataFrame
|
|
489
|
-
cont: pd.DataFrame
|
|
490
|
-
):
|
|
543
|
+
disc: Optional[pd.DataFrame],
|
|
544
|
+
cont: Optional[pd.DataFrame],
|
|
545
|
+
) -> pd.DataFrame:
|
|
546
|
+
"""Chen transform implementation."""
|
|
491
547
|
X_meanvar_adj = self._transform_fortin(X, batch, disc, cont)
|
|
492
548
|
X_centered = X_meanvar_adj - self._covbat_pca.mean_
|
|
493
|
-
scores = self._covbat_pca.transform(X_centered)
|
|
549
|
+
scores = self._covbat_pca.transform(X_centered.values)
|
|
494
550
|
n_pc = self._covbat_n_pc
|
|
495
551
|
scores_adj = scores.copy()
|
|
496
552
|
|
|
@@ -515,19 +571,19 @@ class ComBat(BaseEstimator, TransformerMixin):
|
|
|
515
571
|
"""Pipeline‑friendly wrapper around `ComBatModel`.
|
|
516
572
|
|
|
517
573
|
Stores batch (and optional covariates) passed at construction and
|
|
518
|
-
appropriately
|
|
574
|
+
appropriately uses them for separate `fit` and `transform`.
|
|
519
575
|
"""
|
|
520
576
|
|
|
521
577
|
def __init__(
|
|
522
578
|
self,
|
|
523
|
-
batch,
|
|
579
|
+
batch: ArrayLike,
|
|
524
580
|
*,
|
|
525
|
-
discrete_covariates=None,
|
|
526
|
-
continuous_covariates=None,
|
|
581
|
+
discrete_covariates: Optional[ArrayLike] = None,
|
|
582
|
+
continuous_covariates: Optional[ArrayLike] = None,
|
|
527
583
|
method: str = "johnson",
|
|
528
584
|
parametric: bool = True,
|
|
529
585
|
mean_only: bool = False,
|
|
530
|
-
reference_batch=None,
|
|
586
|
+
reference_batch: Optional[str] = None,
|
|
531
587
|
eps: float = 1e-8,
|
|
532
588
|
covbat_cov_thresh: float = 0.9,
|
|
533
589
|
) -> None:
|
|
@@ -549,8 +605,13 @@ class ComBat(BaseEstimator, TransformerMixin):
|
|
|
549
605
|
covbat_cov_thresh=covbat_cov_thresh,
|
|
550
606
|
)
|
|
551
607
|
|
|
552
|
-
def fit(
|
|
553
|
-
|
|
608
|
+
def fit(
|
|
609
|
+
self,
|
|
610
|
+
X: ArrayLike,
|
|
611
|
+
y: Optional[ArrayLike] = None
|
|
612
|
+
) -> "ComBat":
|
|
613
|
+
"""Fit the ComBat model."""
|
|
614
|
+
idx = X.index if isinstance(X, pd.DataFrame) else pd.RangeIndex(len(X))
|
|
554
615
|
batch_vec = self._subset(self.batch, idx)
|
|
555
616
|
disc = self._subset(self.discrete_covariates, idx)
|
|
556
617
|
cont = self._subset(self.continuous_covariates, idx)
|
|
@@ -562,8 +623,9 @@ class ComBat(BaseEstimator, TransformerMixin):
|
|
|
562
623
|
)
|
|
563
624
|
return self
|
|
564
625
|
|
|
565
|
-
def transform(self, X):
|
|
566
|
-
|
|
626
|
+
def transform(self, X: ArrayLike) -> pd.DataFrame:
|
|
627
|
+
"""Transform the data using fitted ComBat parameters."""
|
|
628
|
+
idx = X.index if isinstance(X, pd.DataFrame) else pd.RangeIndex(len(X))
|
|
567
629
|
batch_vec = self._subset(self.batch, idx)
|
|
568
630
|
disc = self._subset(self.discrete_covariates, idx)
|
|
569
631
|
cont = self._subset(self.continuous_covariates, idx)
|
|
@@ -575,10 +637,17 @@ class ComBat(BaseEstimator, TransformerMixin):
|
|
|
575
637
|
)
|
|
576
638
|
|
|
577
639
|
@staticmethod
|
|
578
|
-
def _subset(
|
|
640
|
+
def _subset(
|
|
641
|
+
obj: Optional[ArrayLike],
|
|
642
|
+
idx: pd.Index
|
|
643
|
+
) -> Optional[Union[pd.DataFrame, pd.Series]]:
|
|
644
|
+
"""Subset array-like object by index."""
|
|
579
645
|
if obj is None:
|
|
580
646
|
return None
|
|
581
647
|
if isinstance(obj, (pd.Series, pd.DataFrame)):
|
|
582
648
|
return obj.loc[idx]
|
|
583
649
|
else:
|
|
584
|
-
|
|
650
|
+
if isinstance(obj, np.ndarray) and obj.ndim == 1:
|
|
651
|
+
return pd.Series(obj, index=idx)
|
|
652
|
+
else:
|
|
653
|
+
return pd.DataFrame(obj, index=idx)
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: combatlearn
|
|
3
|
-
Version: 0.1.0
|
|
4
|
-
Summary: Batch-effect
|
|
3
|
+
Version: 0.1.1
|
|
4
|
+
Summary: Batch-effect harmonization for machine learning frameworks.
|
|
5
5
|
Author-email: Ettore Rocchi <ettoreroc@gmail.com>
|
|
6
6
|
License: MIT License
|
|
7
7
|
|
|
@@ -31,7 +31,7 @@ Classifier: Intended Audience :: Science/Research
|
|
|
31
31
|
Classifier: License :: OSI Approved :: MIT License
|
|
32
32
|
Classifier: Operating System :: OS Independent
|
|
33
33
|
Classifier: Programming Language :: Python :: 3
|
|
34
|
-
Requires-Python: >=3.
|
|
34
|
+
Requires-Python: >=3.10
|
|
35
35
|
Description-Content-Type: text/markdown
|
|
36
36
|
License-File: LICENSE
|
|
37
37
|
Requires-Dist: pandas>=1.3
|
|
@@ -42,6 +42,12 @@ Dynamic: license-file
|
|
|
42
42
|
|
|
43
43
|
# **combatlearn**
|
|
44
44
|
|
|
45
|
+
[Python versions: 3.10, 3.11, 3.12](https://img.shields.io/badge/python-3.10%20%7C%203.11%20%7C%203.12-blue?logo=python)
|
|
46
|
+
[Test workflow status](https://github.com/EttoreRocchi/combatlearn/actions/workflows/test.yaml)
|
|
47
|
+
[Downloads](https://pepy.tech/projects/combatlearn)
|
|
48
|
+
[PyPI package](https://pypi.org/project/combatlearn/)
|
|
49
|
+
[MIT License](https://github.com/EttoreRocchi/combatlearn/blob/main/LICENSE)
|
|
50
|
+
|
|
45
51
|
<div align="center">
|
|
46
52
|
<p><img src="https://raw.githubusercontent.com/EttoreRocchi/combatlearn/main/docs/logo.png" alt="combatlearn logo" width="350" /></p>
|
|
47
53
|
</div>
|
|
@@ -107,10 +113,39 @@ print(f"Best CV AUROC: {grid.best_score_:.3f}")
|
|
|
107
113
|
|
|
108
114
|
For a full example of how to use **combatlearn**, see the [notebook demo](https://github.com/EttoreRocchi/combatlearn/blob/main/demo/combatlearn_demo.ipynb).
|
|
109
115
|
|
|
116
|
+
## `ComBat` parameters
|
|
117
|
+
|
|
118
|
+
The following section provides a detailed explanation of all parameters available in the scikit-learn-compatible `ComBat` class.
|
|
119
|
+
|
|
120
|
+
### Main Parameters
|
|
121
|
+
|
|
122
|
+
| Parameter | Type | Default | Description |
|
|
123
|
+
| --- | --- | --- | --- |
|
|
124
|
+
| `batch` | array-like or pd.Series | **required** | Vector indicating batch assignment for each sample. This is used to estimate and remove batch effects. |
|
|
125
|
+
| `discrete_covariates` | array-like, pd.Series, or pd.DataFrame | `None` | Optional categorical covariates (e.g., sex, site). Only used in `"fortin"` and `"chen"` methods. |
|
|
126
|
+
| `continuous_covariates` | array-like, pd.Series, or pd.DataFrame | `None` | Optional continuous covariates (e.g., age). Only used in `"fortin"` and `"chen"` methods. |
|
|
127
|
+
|
|
128
|
+
### Algorithm Options
|
|
129
|
+
|
|
130
|
+
| Parameter | Type | Default | Description |
|
|
131
|
+
| --- | --- | --- | --- |
|
|
132
|
+
| `method` | str | `"johnson"` | ComBat method to use: <ul><li>`"johnson"` - Classical ComBat (_Johnson et al. 2007_)</li><li>`"fortin"` - ComBat with covariates (_Fortin et al. 2018_)</li><li>`"chen"` - CovBat, PCA-based correction (_Chen et al. 2022_)</li></ul> |
|
|
133
|
+
| `parametric` | bool | `True` | Whether to use the **parametric empirical Bayes** formulation. If `False`, a non-parametric iterative scheme is used. |
|
|
134
|
+
| `mean_only` | bool | `False` | If `True`, only the **mean** is corrected, while variances are left unchanged. Useful for preserving variance structure in the data. |
|
|
135
|
+
| `reference_batch` | str or `None` | `None` | If specified, acts as a reference batch - other batches will be corrected to match this one. |
|
|
136
|
+
| `covbat_cov_thresh` | float | `0.9` | For `"chen"` method only: cumulative explained-variance threshold in $(0, 1]$ used to decide how many PCs to retain in PCA space (e.g., 0.9 = retain 90% of explained variance). Values outside $(0, 1]$, including integer component counts greater than 1, raise a `ValueError`. |
|
|
137
|
+
| `eps` | float | `1e-8` | Small jitter value added to variances to prevent divide-by-zero errors during standardization. |
|
|
138
|
+
|
|
110
139
|
## Contributing
|
|
111
140
|
|
|
112
141
|
Pull requests, bug reports, and feature ideas are welcome: feel free to open a PR!
|
|
113
142
|
|
|
143
|
+
## Author
|
|
144
|
+
|
|
145
|
+
[**Ettore Rocchi**](https://github.com/ettorerocchi) @ University of Bologna
|
|
146
|
+
|
|
147
|
+
[Google Scholar](https://scholar.google.com/citations?user=MKHoGnQAAAAJ) $\cdot$ [Scopus](https://www.scopus.com/authid/detail.uri?authorId=57220152522)
|
|
148
|
+
|
|
114
149
|
## Acknowledgements
|
|
115
150
|
|
|
116
151
|
This project builds on the excellent work of the ComBat family of harmonisation methods.
|
|
@@ -0,0 +1,7 @@
|
|
|
1
|
+
combatlearn/__init__.py,sha256=PHezKTkdkd2fnyqihhayxRN8hducHCXug7iQ5-UsfSc,98
|
|
2
|
+
combatlearn/combat.py,sha256=3tWZDCtXcJtrv8QECD0OFGTpdo-zNRW1YflMXi7rU0c,24022
|
|
3
|
+
combatlearn-0.1.1.dist-info/licenses/LICENSE,sha256=O34CBRTmdL59PxDYOa6nq1N0-2A9xyXGkBXKbsL1NeY,1070
|
|
4
|
+
combatlearn-0.1.1.dist-info/METADATA,sha256=82rLKgrrISPGM0LPer9tlyfRR6dRyrAFTwNzdkKQDa8,8286
|
|
5
|
+
combatlearn-0.1.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
6
|
+
combatlearn-0.1.1.dist-info/top_level.txt,sha256=3cFQv4oj2sh_NKra45cPy8Go0v8W9x9-zkkUibqZCMk,12
|
|
7
|
+
combatlearn-0.1.1.dist-info/RECORD,,
|
|
@@ -1,7 +0,0 @@
|
|
|
1
|
-
combatlearn/__init__.py,sha256=PHezKTkdkd2fnyqihhayxRN8hducHCXug7iQ5-UsfSc,98
|
|
2
|
-
combatlearn/combat.py,sha256=Ro9ap_bpWFYoxHAvHpRjdoMl63TOVBglNo7Er6digKg,20914
|
|
3
|
-
combatlearn-0.1.0.dist-info/licenses/LICENSE,sha256=O34CBRTmdL59PxDYOa6nq1N0-2A9xyXGkBXKbsL1NeY,1070
|
|
4
|
-
combatlearn-0.1.0.dist-info/METADATA,sha256=2-z46YJJ4SeoGhqyMiueufOZ2xm8TINco2pyleEEWJ0,5376
|
|
5
|
-
combatlearn-0.1.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
6
|
-
combatlearn-0.1.0.dist-info/top_level.txt,sha256=3cFQv4oj2sh_NKra45cPy8Go0v8W9x9-zkkUibqZCMk,12
|
|
7
|
-
combatlearn-0.1.0.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|