combatlearn 0.1.0-py3-none-any.whl → 0.1.2-py3-none-any.whl

This diff shows the content of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two package versions as they appear in their public registries.
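Before the file-level changes, a brief orientation. The public surface this diff touches is the scikit-learn-compatible `ComBat` wrapper, which stores the batch vector at construction and subsets it by `X.index` inside `fit` and `transform`. A minimal usage sketch follows; the synthetic data and pipeline layout are illustrative assumptions, and it presumes `ComBat` is importable from the package's top level (the `__init__.py` is not shown in this diff):

```python
# Hedged sketch of the ComBat wrapper API visible in this diff.
# Assumption: `from combatlearn import ComBat` works (top-level re-export);
# the data and pipeline below are made up for illustration.
import numpy as np
import pandas as pd
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline

from combatlearn import ComBat

rng = np.random.default_rng(0)
X = pd.DataFrame(rng.normal(size=(60, 5)))
y = rng.integers(0, 2, size=60)
# Three batches of 20 samples each; the index must line up with X.index,
# since the wrapper subsets the stored batch vector by index in fit/transform.
batch = pd.Series(["siteA", "siteB", "siteC"] * 20, index=X.index)

pipe = Pipeline([
    ("combat", ComBat(batch, method="fortin")),  # batch stashed at construction
    ("clf", LogisticRegression()),
])
pipe.fit(X, y)
```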
combatlearn/combat.py CHANGED
@@ -1,15 +1,14 @@
- __author__ = "Ettore Rocchi"
-
  """ComBat algorithm.
  
  `ComBatModel` implements both:
- * Johnson et al. (2007) vanilla ComBat (method=“johnson”)
- * Fortin et al. (2018) extension with covariates (method=“fortin”)
- * Chen et al. (2022) CovBat (method=“chen”)
+ * Johnson et al. (2007) vanilla ComBat (method="johnson")
+ * Fortin et al. (2018) extension with covariates (method="fortin")
+ * Chen et al. (2022) CovBat (method="chen")
  
  `ComBat` makes the model compatible with scikit-learn by stashing
  the batch (and optional covariates) at construction.
  """
+ from __future__ import annotations
  
  import numpy as np
  import numpy.linalg as la
@@ -17,9 +16,15 @@ import pandas as pd
  from sklearn.base import BaseEstimator, TransformerMixin
  from sklearn.utils.validation import check_is_fitted
  from sklearn.decomposition import PCA
- from typing import Literal
+ from typing import Literal, Optional, Union, Dict, Tuple, Any, cast
+ import numpy.typing as npt
  import warnings
  
+ __author__ = "Ettore Rocchi"
+
+ ArrayLike = Union[pd.DataFrame, pd.Series, npt.NDArray[Any]]
+ FloatArray = npt.NDArray[np.float64]
+
  
  class ComBatModel:
      """ComBat algorithm.
@@ -27,9 +32,9 @@ class ComBatModel:
      Parameters
      ----------
      method : {'johnson', 'fortin', 'chen'}, default='johnson'
-         * 'johnson' — classic ComBat.
-         * 'fortin' — covariate‐aware ComBat.
-         * 'chen' — CovBat, PCA‐based ComBat.
+         * 'johnson' - classic ComBat.
+         * 'fortin' - covariate-aware ComBat.
+         * 'chen' - CovBat, PCA-based ComBat.
      parametric : bool, default=True
          Use the parametric empirical Bayes variant.
      mean_only : bool, default=False
@@ -40,7 +45,7 @@
      covbat_cov_thresh : float, default=0.9
          CovBat: cumulative explained variance threshold for PCA.
      eps : float, default=1e-8
-         Numerical jitter to avoid division‐by‐zero.
+         Numerical jitter to avoid division-by-zero.
      """
  
      def __init__(
@@ -49,21 +54,43 @@
          method: Literal["johnson", "fortin", "chen"] = "johnson",
          parametric: bool = True,
          mean_only: bool = False,
-         reference_batch=None,
+         reference_batch: Optional[str] = None,
          eps: float = 1e-8,
          covbat_cov_thresh: float = 0.9,
      ) -> None:
-         self.method = method
-         self.parametric = parametric
-         self.mean_only = bool(mean_only)
-         self.reference_batch = reference_batch
-         self.eps = float(eps)
-         self.covbat_cov_thresh = float(covbat_cov_thresh)
+         self.method: str = method
+         self.parametric: bool = parametric
+         self.mean_only: bool = bool(mean_only)
+         self.reference_batch: Optional[str] = reference_batch
+         self.eps: float = float(eps)
+         self.covbat_cov_thresh: float = float(covbat_cov_thresh)
+
+         self._batch_levels: pd.Index
+         self._grand_mean: pd.Series
+         self._pooled_var: pd.Series
+         self._gamma_star: FloatArray
+         self._delta_star: FloatArray
+         self._n_per_batch: Dict[str, int]
+         self._reference_batch_idx: Optional[int]
+         self._beta_hat_nonbatch: FloatArray
+         self._n_batch: int
+         self._p_design: int
+         self._covbat_pca: PCA
+         self._covbat_n_pc: int
+         self._batch_levels_pc: pd.Index
+         self._pc_gamma_star: FloatArray
+         self._pc_delta_star: FloatArray
+
          if not (0.0 < self.covbat_cov_thresh <= 1.0):
              raise ValueError("covbat_cov_thresh must be in (0, 1].")
  
      @staticmethod
-     def _as_series(arr, index, name):
+     def _as_series(
+         arr: ArrayLike,
+         index: pd.Index,
+         name: str
+     ) -> pd.Series:
+         """Convert array-like to categorical Series with validation."""
          if isinstance(arr, pd.Series):
              ser = arr.copy()
          else:
@@ -73,7 +100,12 @@ class ComBatModel:
          return ser.astype("category")
  
      @staticmethod
-     def _to_df(arr, index, name):
+     def _to_df(
+         arr: Optional[ArrayLike],
+         index: pd.Index,
+         name: str
+     ) -> Optional[pd.DataFrame]:
+         """Convert array-like to DataFrame."""
          if arr is None:
              return None
          if isinstance(arr, pd.Series):
@@ -86,13 +118,14 @@
  
      def fit(
          self,
-         X,
-         y=None,
+         X: ArrayLike,
+         y: Optional[ArrayLike] = None,
          *,
-         batch,
-         discrete_covariates=None,
-         continuous_covariates=None,
-     ):
+         batch: ArrayLike,
+         discrete_covariates: Optional[ArrayLike] = None,
+         continuous_covariates: Optional[ArrayLike] = None,
+     ) -> ComBatModel:
+         """Fit the ComBat model."""
          method = self.method.lower()
          if method not in {"johnson", "fortin", "chen"}:
              raise ValueError("method must be 'johnson', 'fortin', or 'chen'.")
104
137
  disc = self._to_df(discrete_covariates, idx, "discrete_covariates")
105
138
  cont = self._to_df(continuous_covariates, idx, "continuous_covariates")
106
139
 
107
-
108
140
  if self.reference_batch is not None and self.reference_batch not in batch.cat.categories:
109
141
  raise ValueError(
110
- f"reference_batch={self.reference_batch!r} not present in the data batches "
142
+ f"reference_batch={self.reference_batch!r} not present in the data batches."
111
143
  f"{list(batch.cat.categories)}"
112
144
  )
113
145
 
@@ -127,38 +159,39 @@
          self,
          X: pd.DataFrame,
          batch: pd.Series
-     ):
-         """
-         Johnson et al. (2007) ComBat.
-         """
+     ) -> None:
+         """Johnson et al. (2007) ComBat."""
          self._batch_levels = batch.cat.categories
          pooled_var = X.var(axis=0, ddof=1) + self.eps
          grand_mean = X.mean(axis=0)
  
          Xs = (X - grand_mean) / np.sqrt(pooled_var)
  
-         n_per_batch: dict[str, int] = {}
-         gamma_hat, delta_hat = [], []
+         n_per_batch: Dict[str, int] = {}
+         gamma_hat: list[npt.NDArray[np.float64]] = []
+         delta_hat: list[npt.NDArray[np.float64]] = []
+
          for lvl in self._batch_levels:
              idx = batch == lvl
-             n_b = idx.sum()
+             n_b = int(idx.sum())
              if n_b < 2:
                  raise ValueError(f"Batch '{lvl}' has <2 samples.")
-             n_per_batch[lvl] = n_b
+             n_per_batch[str(lvl)] = n_b
              xb = Xs.loc[idx]
              gamma_hat.append(xb.mean(axis=0).values)
              delta_hat.append(xb.var(axis=0, ddof=1).values + self.eps)
-         gamma_hat = np.vstack(gamma_hat)
-         delta_hat = np.vstack(delta_hat)
+
+         gamma_hat_arr = np.vstack(gamma_hat)
+         delta_hat_arr = np.vstack(delta_hat)
  
          if self.mean_only:
              gamma_star = self._shrink_gamma(
-                 gamma_hat, delta_hat, n_per_batch, parametric=self.parametric
+                 gamma_hat_arr, delta_hat_arr, n_per_batch, parametric=self.parametric
              )
-             delta_star = np.ones_like(delta_hat)
+             delta_star = np.ones_like(delta_hat_arr)
          else:
              gamma_star, delta_star = self._shrink_gamma_delta(
-                 gamma_hat, delta_hat, n_per_batch, parametric=self.parametric
+                 gamma_hat_arr, delta_hat_arr, n_per_batch, parametric=self.parametric
              )
  
          if self.reference_batch is not None:
@@ -182,83 +215,107 @@
          self,
          X: pd.DataFrame,
          batch: pd.Series,
-         disc: pd.DataFrame | None,
-         cont: pd.DataFrame | None,
-     ):
-         """
-         Fortin et al. (2018) ComBat.
-         """
-         batch_levels = batch.cat.categories
-         n_batch = len(batch_levels)
+         disc: Optional[pd.DataFrame],
+         cont: Optional[pd.DataFrame],
+     ) -> None:
+         """Fortin et al. (2018) neuroComBat."""
+         self._batch_levels = batch.cat.categories
+         n_batch = len(self._batch_levels)
          n_samples = len(X)
  
-         batch_dummies = pd.get_dummies(batch, drop_first=False)
-         parts = [batch_dummies]
+         batch_dummies = pd.get_dummies(batch, drop_first=False).astype(float)
+         if self.reference_batch is not None:
+             if self.reference_batch not in self._batch_levels:
+                 raise ValueError(
+                     f"reference_batch={self.reference_batch!r} not present in batches."
+                     f"{list(self._batch_levels)}"
+                 )
+             batch_dummies.loc[:, self.reference_batch] = 1.0
+
+         parts: list[pd.DataFrame] = [batch_dummies]
          if disc is not None:
-             parts.append(pd.get_dummies(disc.astype("category"), drop_first=True))
+             parts.append(
+                 pd.get_dummies(
+                     disc.astype("category"), drop_first=True
+                 ).astype(float)
+             )
+
          if cont is not None:
-             parts.append(cont)
-         design = pd.concat(parts, axis=1).astype(float).values
+             parts.append(cont.astype(float))
+
+         design = pd.concat(parts, axis=1).values
          p_design = design.shape[1]
  
          X_np = X.values
-         beta_hat = la.inv(design.T @ design) @ design.T @ X_np
+         beta_hat = la.lstsq(design, X_np, rcond=None)[0]
  
-         gamma_hat = beta_hat[:n_batch]
+         beta_hat_batch = beta_hat[:n_batch]
          self._beta_hat_nonbatch = beta_hat[n_batch:]
  
-         n_per_batch = batch.value_counts().sort_index().values
-         self._n_per_batch = dict(zip(batch_levels, n_per_batch))
+         n_per_batch = batch.value_counts().sort_index().astype(int).values
+         self._n_per_batch = dict(zip(self._batch_levels, n_per_batch))
  
-         grand_mean = (n_per_batch / n_samples) @ gamma_hat
-         self._grand_mean = grand_mean
+         if self.reference_batch is not None:
+             ref_idx = list(self._batch_levels).index(self.reference_batch)
+             grand_mean = beta_hat_batch[ref_idx]
+         else:
+             grand_mean = (n_per_batch / n_samples) @ beta_hat_batch
+             ref_idx = None
+
+         self._grand_mean = pd.Series(grand_mean, index=X.columns)
  
-         resid = X_np - design @ beta_hat
-         var_pooled = (resid ** 2).sum(axis=0) / (n_samples - p_design) + self.eps
-         self._pooled_var = var_pooled
+         if self.reference_batch is not None:
+             ref_mask = (batch == self.reference_batch).values
+             resid = X_np[ref_mask] - design[ref_mask] @ beta_hat
+             denom = int(ref_mask.sum())
+         else:
+             resid = X_np - design @ beta_hat
+             denom = n_samples
+         var_pooled = (resid ** 2).sum(axis=0) / denom + self.eps
+         self._pooled_var = pd.Series(var_pooled, index=X.columns)
  
          stand_mean = grand_mean + design[:, n_batch:] @ self._beta_hat_nonbatch
          Xs = (X_np - stand_mean) / np.sqrt(var_pooled)
  
-         delta_hat = np.empty_like(gamma_hat)
-         for i, lvl in enumerate(batch_levels):
-             idx = batch == lvl
-             delta_hat[i] = Xs[idx].var(axis=0, ddof=1) + self.eps
+         gamma_hat = np.vstack(
+             [Xs[batch == lvl].mean(axis=0) for lvl in self._batch_levels]
+         )
+         delta_hat = np.vstack(
+             [Xs[batch == lvl].var(axis=0, ddof=1) + self.eps
+              for lvl in self._batch_levels]
+         )
  
          if self.mean_only:
              gamma_star = self._shrink_gamma(
-                 gamma_hat, delta_hat, n_per_batch, parametric=self.parametric
+                 gamma_hat, delta_hat, n_per_batch,
+                 parametric = self.parametric
              )
              delta_star = np.ones_like(delta_hat)
          else:
              gamma_star, delta_star = self._shrink_gamma_delta(
-                 gamma_hat, delta_hat, n_per_batch, parametric=self.parametric
+                 gamma_hat, delta_hat, n_per_batch,
+                 parametric = self.parametric
              )
  
-         if self.reference_batch is not None:
-             ref_idx = list(batch_levels).index(self.reference_batch)
-             gamma_ref = gamma_star[ref_idx]
-             delta_ref = delta_star[ref_idx]
-             gamma_star = gamma_star - gamma_ref
+         if ref_idx is not None:
+             gamma_star[ref_idx] = 0.0
              if not self.mean_only:
-                 delta_star = delta_star / delta_ref
-             self._reference_batch_idx = ref_idx
-         else:
-             self._reference_batch_idx = None
+                 delta_star[ref_idx] = 1.0
+         self._reference_batch_idx = ref_idx
  
-         self._batch_levels = batch_levels
          self._gamma_star = gamma_star
          self._delta_star = delta_star
-         self._n_batch = n_batch
+         self._n_batch = n_batch
          self._p_design = p_design
  
      def _fit_chen(
          self,
          X: pd.DataFrame,
          batch: pd.Series,
-         disc: pd.DataFrame | None,
-         cont: pd.DataFrame | None,
-     ):
+         disc: Optional[pd.DataFrame],
+         cont: Optional[pd.DataFrame],
+     ) -> None:
+         """Chen et al. (2022) CovBat."""
          self._fit_fortin(X, batch, disc, cont)
          X_meanvar_adj = self._transform_fortin(X, batch, disc, cont)
          X_centered = X_meanvar_adj - X_meanvar_adj.mean(axis=0)
@@ -273,23 +330,24 @@ class ComBatModel:
          self._batch_levels_pc = self._batch_levels
          n_per_batch = self._n_per_batch
  
-         gamma_hat, delta_hat = [], []
+         gamma_hat: list[npt.NDArray[np.float64]] = []
+         delta_hat: list[npt.NDArray[np.float64]] = []
          for lvl in self._batch_levels_pc:
              idx = batch == lvl
              xb = scores_df.loc[idx]
              gamma_hat.append(xb.mean(axis=0).values)
              delta_hat.append(xb.var(axis=0, ddof=1).values + self.eps)
-         gamma_hat = np.vstack(gamma_hat)
-         delta_hat = np.vstack(delta_hat)
+         gamma_hat_arr = np.vstack(gamma_hat)
+         delta_hat_arr = np.vstack(delta_hat)
  
          if self.mean_only:
              gamma_star = self._shrink_gamma(
-                 gamma_hat, delta_hat, n_per_batch, parametric=self.parametric
+                 gamma_hat_arr, delta_hat_arr, n_per_batch, parametric=self.parametric
              )
-             delta_star = np.ones_like(delta_hat)
+             delta_star = np.ones_like(delta_hat_arr)
          else:
              gamma_star, delta_star = self._shrink_gamma_delta(
-                 gamma_hat, delta_hat, n_per_batch, parametric=self.parametric
+                 gamma_hat_arr, delta_hat_arr, n_per_batch, parametric=self.parametric
              )
  
          if self.reference_batch is not None:
@@ -305,14 +363,15 @@
  
      def _shrink_gamma_delta(
          self,
-         gamma_hat: np.ndarray,
-         delta_hat: np.ndarray,
-         n_per_batch: dict | np.ndarray,
+         gamma_hat: FloatArray,
+         delta_hat: FloatArray,
+         n_per_batch: Union[Dict[str, int], FloatArray],
          *,
          parametric: bool,
          max_iter: int = 100,
          tol: float = 1e-4,
-     ):
+     ) -> Tuple[FloatArray, FloatArray]:
+         """Empirical Bayes shrinkage estimation."""
          if parametric:
              gamma_bar = gamma_hat.mean(axis=0)
              t2 = gamma_hat.var(axis=0, ddof=1)
@@ -323,6 +382,7 @@
              gamma_star = np.empty_like(gamma_hat)
              delta_star = np.empty_like(delta_hat)
              n_vec = np.array(list(n_per_batch.values())) if isinstance(n_per_batch, dict) else n_per_batch
+
              for i in range(B):
                  n_i = n_vec[i]
                  g, d = gamma_hat[i], delta_hat[i]
@@ -340,18 +400,29 @@
              gamma_bar = gamma_hat.mean(axis=0)
              t2 = gamma_hat.var(axis=0, ddof=1)
  
-             def postmean(g_hat, g_bar, n, d_star, t2_):
+             def postmean(
+                 g_hat: FloatArray,
+                 g_bar: FloatArray,
+                 n: float,
+                 d_star: FloatArray,
+                 t2_: FloatArray
+             ) -> FloatArray:
                  return (t2_ * n * g_hat + d_star * g_bar) / (t2_ * n + d_star)
  
-             def postvar(sum2, n, a, b):
+             def postvar(
+                 sum2: FloatArray,
+                 n: float,
+                 a: FloatArray,
+                 b: FloatArray
+             ) -> FloatArray:
                  return (0.5 * sum2 + b) / (n / 2.0 + a - 1.0)
  
-             def aprior(delta):
+             def aprior(delta: FloatArray) -> FloatArray:
                  m, s2 = delta.mean(), delta.var()
                  s2 = max(s2, self.eps)
                  return (2 * s2 + m ** 2) / s2
  
-             def bprior(delta):
+             def bprior(delta: FloatArray) -> FloatArray:
                  m, s2 = delta.mean(), delta.var()
                  s2 = max(s2, self.eps)
                  return (m * s2 + m ** 3) / s2
@@ -382,24 +453,25 @@
  
      def _shrink_gamma(
          self,
-         gamma_hat: np.ndarray,
-         delta_hat: np.ndarray,
-         n_per_batch: dict | np.ndarray,
+         gamma_hat: FloatArray,
+         delta_hat: FloatArray,
+         n_per_batch: Union[Dict[str, int], FloatArray],
          *,
          parametric: bool,
-     ) -> np.ndarray:
-         """Convenience wrapper that returns only γ⋆ (for *mean‐only* mode)."""
+     ) -> FloatArray:
+         """Convenience wrapper that returns only γ⋆ (for *mean-only* mode)."""
          gamma, _ = self._shrink_gamma_delta(gamma_hat, delta_hat, n_per_batch, parametric=parametric)
          return gamma
  
      def transform(
          self,
-         X,
+         X: ArrayLike,
          *,
-         batch,
-         discrete_covariates=None,
-         continuous_covariates=None,
-     ):
+         batch: ArrayLike,
+         discrete_covariates: Optional[ArrayLike] = None,
+         continuous_covariates: Optional[ArrayLike] = None,
+     ) -> pd.DataFrame:
+         """Transform the data using fitted ComBat parameters."""
          check_is_fitted(self, ["_gamma_star"])
          if not isinstance(X, pd.DataFrame):
              X = pd.DataFrame(X)
@@ -407,7 +479,7 @@
          batch = self._as_series(batch, idx, "batch")
          unseen = set(batch.cat.categories) - set(self._batch_levels)
          if unseen:
-             raise ValueError(f"Unseen batch levels during transform: {unseen}")
+             raise ValueError(f"Unseen batch levels during transform: {unseen}.")
          disc = self._to_df(discrete_covariates, idx, "discrete_covariates")
          cont = self._to_df(continuous_covariates, idx, "continuous_covariates")
  
@@ -418,8 +490,15 @@
              return self._transform_fortin(X, batch, disc, cont)
          elif method == "chen":
              return self._transform_chen(X, batch, disc, cont)
+         else:
+             raise ValueError(f"Unknown method: {method}.")
  
-     def _transform_johnson(self, X: pd.DataFrame, batch: pd.Series):
+     def _transform_johnson(
+         self,
+         X: pd.DataFrame,
+         batch: pd.Series
+     ) -> pd.DataFrame:
+         """Johnson transform implementation."""
          pooled = self._pooled_var
          grand = self._grand_mean
  
@@ -431,7 +510,7 @@
              if not idx.any():
                  continue
              if self.reference_batch is not None and lvl == self.reference_batch:
-                 X_adj.loc[idx] = X.loc[idx].values  # untouched
+                 X_adj.loc[idx] = X.loc[idx].values
                  continue
  
              g = self._gamma_star[i]
@@ -447,21 +526,32 @@
          self,
          X: pd.DataFrame,
          batch: pd.Series,
-         disc: pd.DataFrame | None,
-         cont: pd.DataFrame | None,
-     ):
-         batch_dummies = pd.get_dummies(batch, drop_first=False)[self._batch_levels]
+         disc: Optional[pd.DataFrame],
+         cont: Optional[pd.DataFrame],
+     ) -> pd.DataFrame:
+         """Fortin transform implementation."""
+         batch_dummies = pd.get_dummies(batch, drop_first=False).astype(float)[self._batch_levels]
+         if self.reference_batch is not None:
+             batch_dummies.loc[:, self.reference_batch] = 1.0
+
          parts = [batch_dummies]
          if disc is not None:
-             parts.append(pd.get_dummies(disc.astype("category"), drop_first=True))
+             parts.append(
+                 pd.get_dummies(
+                     disc.astype("category"), drop_first=True
+                 ).astype(float)
+             )
          if cont is not None:
-             parts.append(cont)
+             parts.append(cont.astype(float))
  
-         design = pd.concat(parts, axis=1).astype(float).values
+         design = pd.concat(parts, axis=1).values
  
          X_np = X.values
-         stand_mean = self._grand_mean + design[:, self._n_batch:] @ self._beta_hat_nonbatch
-         Xs = (X_np - stand_mean) / np.sqrt(self._pooled_var)
+         stand_mu = (
+             self._grand_mean.values +
+             design[:, self._n_batch:] @ self._beta_hat_nonbatch
+         )
+         Xs = (X_np - stand_mu) / np.sqrt(self._pooled_var.values)
  
          for i, lvl in enumerate(self._batch_levels):
              idx = batch == lvl
@@ -478,19 +568,23 @@
              else:
                  Xs[idx] = (Xs[idx] - g) / np.sqrt(d)
  
-         X_adj = Xs * np.sqrt(self._pooled_var) + stand_mean
-         return pd.DataFrame(X_adj, index=X.index, columns=X.columns)
+         X_adj = (
+             Xs * np.sqrt(self._pooled_var.values) +
+             stand_mu
+         )
+         return pd.DataFrame(X_adj, index=X.index, columns=X.columns, dtype=float)
  
      def _transform_chen(
          self,
          X: pd.DataFrame,
          batch: pd.Series,
-         disc: pd.DataFrame | None,
-         cont: pd.DataFrame | None,
-     ):
+         disc: Optional[pd.DataFrame],
+         cont: Optional[pd.DataFrame],
+     ) -> pd.DataFrame:
+         """Chen transform implementation."""
          X_meanvar_adj = self._transform_fortin(X, batch, disc, cont)
          X_centered = X_meanvar_adj - self._covbat_pca.mean_
-         scores = self._covbat_pca.transform(X_centered)
+         scores = self._covbat_pca.transform(X_centered.values)
          n_pc = self._covbat_n_pc
          scores_adj = scores.copy()
  
@@ -512,22 +606,22 @@
  
  
  class ComBat(BaseEstimator, TransformerMixin):
-     """Pipeline‐friendly wrapper around `ComBatModel`.
+     """Pipeline-friendly wrapper around `ComBatModel`.
  
      Stores batch (and optional covariates) passed at construction and
-     appropriately used them also for separate `fit` and `transform`.
+     appropriately uses them for separate `fit` and `transform`.
      """
  
      def __init__(
          self,
-         batch,
+         batch: ArrayLike,
          *,
-         discrete_covariates=None,
-         continuous_covariates=None,
+         discrete_covariates: Optional[ArrayLike] = None,
+         continuous_covariates: Optional[ArrayLike] = None,
          method: str = "johnson",
          parametric: bool = True,
          mean_only: bool = False,
-         reference_batch=None,
+         reference_batch: Optional[str] = None,
          eps: float = 1e-8,
          covbat_cov_thresh: float = 0.9,
      ) -> None:
@@ -549,8 +643,13 @@ class ComBat(BaseEstimator, TransformerMixin):
              covbat_cov_thresh=covbat_cov_thresh,
          )
  
-     def fit(self, X, y=None):
-         idx = X.index if isinstance(X, pd.DataFrame) else np.arange(len(X))
+     def fit(
+         self,
+         X: ArrayLike,
+         y: Optional[ArrayLike] = None
+     ) -> "ComBat":
+         """Fit the ComBat model."""
+         idx = X.index if isinstance(X, pd.DataFrame) else pd.RangeIndex(len(X))
          batch_vec = self._subset(self.batch, idx)
          disc = self._subset(self.discrete_covariates, idx)
          cont = self._subset(self.continuous_covariates, idx)
@@ -562,8 +661,9 @@
          )
          return self
  
-     def transform(self, X):
-         idx = X.index if isinstance(X, pd.DataFrame) else np.arange(len(X))
+     def transform(self, X: ArrayLike) -> pd.DataFrame:
+         """Transform the data using fitted ComBat parameters."""
+         idx = X.index if isinstance(X, pd.DataFrame) else pd.RangeIndex(len(X))
          batch_vec = self._subset(self.batch, idx)
          disc = self._subset(self.discrete_covariates, idx)
          cont = self._subset(self.continuous_covariates, idx)
@@ -575,10 +675,17 @@
          )
  
      @staticmethod
-     def _subset(obj, idx):
+     def _subset(
+         obj: Optional[ArrayLike],
+         idx: pd.Index
+     ) -> Optional[Union[pd.DataFrame, pd.Series]]:
+         """Subset array-like object by index."""
          if obj is None:
              return None
          if isinstance(obj, (pd.Series, pd.DataFrame)):
              return obj.loc[idx]
          else:
-             return pd.DataFrame(obj).iloc[idx]
+             if isinstance(obj, np.ndarray) and obj.ndim == 1:
+                 return pd.Series(obj, index=idx)
+             else:
+                 return pd.DataFrame(obj, index=idx)
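Two behavioral notes on the `_fit_fortin` changes above. First, 0.1.2 solves the design regression with `numpy.linalg.lstsq` rather than explicitly inverting the Gram matrix, so a singular or ill-conditioned design (for instance, a covariate perfectly confounded with a batch) degrades gracefully to a minimum-norm solution instead of raising. Second, with a `reference_batch` the new code keeps all batch dummy columns, forces the reference column to 1.0, and pools the residual variance over reference samples only, in line with the reference-batch handling in the original neuroComBat implementation. A standalone sketch of the first point, using a hypothetical toy design rather than package code:

```python
# Why lstsq: a covariate that duplicates a batch indicator makes the
# Gram matrix exactly singular, breaking the 0.1.0 normal-equations solve.
import numpy as np
import numpy.linalg as la

d1 = np.array([1.0, 1.0, 1.0, 0.0, 0.0, 0.0])  # batch A dummy
d2 = 1.0 - d1                                   # batch B dummy
cov = d1.copy()                                 # covariate confounded with batch A
design = np.column_stack([d1, d2, cov])         # rank 2, three columns

rng = np.random.default_rng(1)
X = rng.normal(size=(6, 2))

try:
    beta = la.inv(design.T @ design) @ design.T @ X  # 0.1.0: normal equations
except la.LinAlgError as err:
    print("normal-equations solve fails:", err)      # Singular matrix

beta, *_ = la.lstsq(design, X, rcond=None)           # 0.1.2: minimum-norm solution
print(beta.shape)                                    # (3, 2)
```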
combatlearn-0.1.0.dist-info/METADATA → combatlearn-0.1.2.dist-info/METADATA CHANGED
@@ -1,7 +1,7 @@
  Metadata-Version: 2.4
  Name: combatlearn
- Version: 0.1.0
- Summary: Batch-effect harmonisation for machine learning frameworks.
+ Version: 0.1.2
+ Summary: Batch-effect harmonization for machine learning frameworks.
  Author-email: Ettore Rocchi <ettoreroc@gmail.com>
  License: MIT License
  
@@ -31,7 +31,7 @@ Classifier: Intended Audience :: Science/Research
  Classifier: License :: OSI Approved :: MIT License
  Classifier: Operating System :: OS Independent
  Classifier: Programming Language :: Python :: 3
- Requires-Python: >=3.9
+ Requires-Python: >=3.10
  Description-Content-Type: text/markdown
  License-File: LICENSE
  Requires-Dist: pandas>=1.3
@@ -42,6 +42,12 @@ Dynamic: license-file
  
  # **combatlearn**
  
+ [![Python versions](https://img.shields.io/badge/python-%3E%3D3.10-blue?logo=python)](https://www.python.org/)
+ [![Test](https://github.com/EttoreRocchi/combatlearn/actions/workflows/test.yaml/badge.svg)](https://github.com/EttoreRocchi/combatlearn/actions/workflows/test.yaml)
+ [![PyPI Downloads](https://static.pepy.tech/badge/combatlearn)](https://pepy.tech/projects/combatlearn)
+ [![PyPI Version](https://img.shields.io/pypi/v/combatlearn?cacheSeconds=300)](https://pypi.org/project/combatlearn/)
+ [![License](https://img.shields.io/github/license/EttoreRocchi/combatlearn)](https://github.com/EttoreRocchi/combatlearn/blob/main/LICENSE)
+
  <div align="center">
  <p><img src="https://raw.githubusercontent.com/EttoreRocchi/combatlearn/main/docs/logo.png" alt="combatlearn logo" width="350" /></p>
  </div>
@@ -50,7 +56,7 @@ Dynamic: license-file
  
  **Three methods**:
  - `method="johnson"` - classic ComBat (Johnson _et al._, 2007)
- - `method="fortin"` - covariate-aware ComBat (Fortin _et al._, 2018)
+ - `method="fortin"` - neuroComBat (Fortin _et al._, 2018)
  - `method="chen"` - CovBat (Chen _et al._, 2022)
  
  ## Installation
@@ -107,10 +113,39 @@
  
  For a full example of how to use **combatlearn** see the [notebook demo](https://github.com/EttoreRocchi/combatlearn/blob/main/demo/combatlearn_demo.ipynb)
  
+ ## `ComBat` parameters
+
+ The following section provides a detailed explanation of all parameters available in the scikit-learn-compatible `ComBat` class.
+
+ ### Main Parameters
+
+ | Parameter | Type | Default | Description |
+ | --- | --- | --- | --- |
+ | `batch` | array-like or pd.Series | **required** | Vector indicating batch assignment for each sample. This is used to estimate and remove batch effects. |
+ | `discrete_covariates` | array-like, pd.Series, or pd.DataFrame | `None` | Optional categorical covariates (e.g., sex, site). Only used in `"fortin"` and `"chen"` methods. |
+ | `continuous_covariates` | array-like, pd.Series or pd.DataFrame | `None` | Optional continuous covariates (e.g., age). Only used in `"fortin"` and `"chen"` methods. |
+
+ ### Algorithm Options
+
+ | Parameter | Type | Default | Description |
+ | --- | --- | --- | --- |
+ | `method` | str | `"johnson"` | ComBat method to use: <ul><li>`"johnson"` - Classical ComBat (_Johnson et al. 2007_)</li><li>`"fortin"` - ComBat with covariates (_Fortin et al. 2018_)</li><li>`"chen"` - CovBat, PCA-based correction (_Chen et al. 2022_)</li></ul> |
+ | `parametric` | bool | `True` | Whether to use the **parametric empirical Bayes** formulation. If `False`, a non-parametric iterative scheme is used. |
+ | `mean_only` | bool | `False` | If `True`, only the **mean** is corrected, while variances are left unchanged. Useful for preserving variance structure in the data. |
+ | `reference_batch` | str or `None` | `None` | If specified, acts as a reference batch - other batches will be corrected to match this one. |
+ | `covbat_cov_thresh` | float, int | `0.9` | For `"chen"` method only: Cumulative variance threshold $]0,1[$ to retain PCs in PCA space (e.g., 0.9 = retain 90% explained variance). If an integer is provided, it represents the number of principal components to use. |
+ | `eps` | float | `1e-8` | Small jitter value added to variances to prevent divide-by-zero errors during standardization. |
+
  ## Contributing
  
  Pull requests, bug reports, and feature ideas are welcome: feel free to open a PR!
  
+ ## Author
+
+ [**Ettore Rocchi**](https://github.com/ettorerocchi) @ University of Bologna
+
+ [Google Scholar](https://scholar.google.com/citations?user=MKHoGnQAAAAJ) $\cdot$ [Scopus](https://www.scopus.com/authid/detail.uri?authorId=57220152522)
+
  ## Acknowledgements
  
  This project builds on the excellent work of the ComBat family of harmonisation methods.
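The new `reference_batch` row in the README table can be checked against the code earlier in this diff: `_transform_johnson` returns reference-batch rows untouched and maps the other batches toward them. A hedged sketch with synthetic data (it again presumes the top-level `ComBat` import; the data are made up):

```python
# Sketch of the reference_batch behavior described in the new README table:
# reference samples pass through transform unchanged.
import numpy as np
import pandas as pd
from combatlearn import ComBat  # assumption: top-level re-export

rng = np.random.default_rng(42)
X = pd.DataFrame(rng.normal(size=(40, 3)))
X.iloc[20:] += 2.0                                   # crude batch shift
batch = pd.Series(["ref"] * 20 + ["other"] * 20, index=X.index)

cb = ComBat(batch, method="johnson", reference_batch="ref").fit(X)
X_adj = cb.transform(X)

# Reference rows come back untouched (see _transform_johnson in the diff)
print(np.allclose(X_adj.iloc[:20].values, X.iloc[:20].values))  # True
```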
combatlearn-0.1.2.dist-info/RECORD ADDED
@@ -0,0 +1,7 @@
+ combatlearn/__init__.py,sha256=PHezKTkdkd2fnyqihhayxRN8hducHCXug7iQ5-UsfSc,98
+ combatlearn/combat.py,sha256=ghc83DTLC4ukLJN_xqpoWZTPPTxFa4DVtT6C5SVUjFA,25024
+ combatlearn-0.1.2.dist-info/licenses/LICENSE,sha256=O34CBRTmdL59PxDYOa6nq1N0-2A9xyXGkBXKbsL1NeY,1070
+ combatlearn-0.1.2.dist-info/METADATA,sha256=VxQpyJAwOSQqw8ypiSUxq4dmszCDRW3AsO_0XBQq6pk,8213
+ combatlearn-0.1.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ combatlearn-0.1.2.dist-info/top_level.txt,sha256=3cFQv4oj2sh_NKra45cPy8Go0v8W9x9-zkkUibqZCMk,12
+ combatlearn-0.1.2.dist-info/RECORD,,
@@ -1,7 +0,0 @@
1
- combatlearn/__init__.py,sha256=PHezKTkdkd2fnyqihhayxRN8hducHCXug7iQ5-UsfSc,98
2
- combatlearn/combat.py,sha256=Ro9ap_bpWFYoxHAvHpRjdoMl63TOVBglNo7Er6digKg,20914
3
- combatlearn-0.1.0.dist-info/licenses/LICENSE,sha256=O34CBRTmdL59PxDYOa6nq1N0-2A9xyXGkBXKbsL1NeY,1070
4
- combatlearn-0.1.0.dist-info/METADATA,sha256=2-z46YJJ4SeoGhqyMiueufOZ2xm8TINco2pyleEEWJ0,5376
5
- combatlearn-0.1.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
6
- combatlearn-0.1.0.dist-info/top_level.txt,sha256=3cFQv4oj2sh_NKra45cPy8Go0v8W9x9-zkkUibqZCMk,12
7
- combatlearn-0.1.0.dist-info/RECORD,,
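Finally, the empirical Bayes core shared by all three methods is unchanged in substance; the diff only annotates it. The `postmean` helper computes a precision-weighted compromise between the per-batch estimate `g_hat` and the across-batch mean `g_bar`, so shrinkage toward the prior fades as the batch size `n` grows. A toy numeric check with made-up values:

```python
# Toy check of the postmean formula from _shrink_gamma_delta (values made up):
# posterior mean = (t2 * n * g_hat + d_star * g_bar) / (t2 * n + d_star).
import numpy as np

def postmean(g_hat, g_bar, n, d_star, t2_):
    return (t2_ * n * g_hat + d_star * g_bar) / (t2_ * n + d_star)

g_hat, g_bar = np.array([1.0]), np.array([0.0])
d_star, t2 = np.array([1.0]), np.array([0.5])

print(postmean(g_hat, g_bar, 2, d_star, t2))    # [0.5]   small batch: strong shrinkage
print(postmean(g_hat, g_bar, 200, d_star, t2))  # [0.990] large batch: approaches g_hat
```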