combatlearn 0.1.1__py3-none-any.whl → 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- combatlearn/combat.py +83 -45
- {combatlearn-0.1.1.dist-info → combatlearn-0.1.2.dist-info}/METADATA +4 -4
- combatlearn-0.1.2.dist-info/RECORD +7 -0
- combatlearn-0.1.1.dist-info/RECORD +0 -7
- {combatlearn-0.1.1.dist-info → combatlearn-0.1.2.dist-info}/WHEEL +0 -0
- {combatlearn-0.1.1.dist-info → combatlearn-0.1.2.dist-info}/licenses/LICENSE +0 -0
- {combatlearn-0.1.1.dist-info → combatlearn-0.1.2.dist-info}/top_level.txt +0 -0
combatlearn/combat.py
CHANGED
|
@@ -139,7 +139,7 @@ class ComBatModel:
|
|
|
139
139
|
|
|
140
140
|
if self.reference_batch is not None and self.reference_batch not in batch.cat.categories:
|
|
141
141
|
raise ValueError(
|
|
142
|
-
f"reference_batch={self.reference_batch!r} not present in the data batches
|
|
142
|
+
f"reference_batch={self.reference_batch!r} not present in the data batches."
|
|
143
143
|
f"{list(batch.cat.categories)}"
|
|
144
144
|
)
|
|
145
145
|
|
|
@@ -218,69 +218,94 @@ class ComBatModel:
|
|
|
218
218
|
disc: Optional[pd.DataFrame],
|
|
219
219
|
cont: Optional[pd.DataFrame],
|
|
220
220
|
) -> None:
|
|
221
|
-
"""Fortin et al. (2018)
|
|
222
|
-
|
|
223
|
-
n_batch = len(
|
|
221
|
+
"""Fortin et al. (2018) neuroComBat."""
|
|
222
|
+
self._batch_levels = batch.cat.categories
|
|
223
|
+
n_batch = len(self._batch_levels)
|
|
224
224
|
n_samples = len(X)
|
|
225
225
|
|
|
226
|
-
batch_dummies = pd.get_dummies(batch, drop_first=False)
|
|
226
|
+
batch_dummies = pd.get_dummies(batch, drop_first=False).astype(float)
|
|
227
|
+
if self.reference_batch is not None:
|
|
228
|
+
if self.reference_batch not in self._batch_levels:
|
|
229
|
+
raise ValueError(
|
|
230
|
+
f"reference_batch={self.reference_batch!r} not present in batches."
|
|
231
|
+
f"{list(self._batch_levels)}"
|
|
232
|
+
)
|
|
233
|
+
batch_dummies.loc[:, self.reference_batch] = 1.0
|
|
234
|
+
|
|
227
235
|
parts: list[pd.DataFrame] = [batch_dummies]
|
|
228
236
|
if disc is not None:
|
|
229
|
-
parts.append(
|
|
237
|
+
parts.append(
|
|
238
|
+
pd.get_dummies(
|
|
239
|
+
disc.astype("category"), drop_first=True
|
|
240
|
+
).astype(float)
|
|
241
|
+
)
|
|
242
|
+
|
|
230
243
|
if cont is not None:
|
|
231
|
-
parts.append(cont)
|
|
232
|
-
|
|
244
|
+
parts.append(cont.astype(float))
|
|
245
|
+
|
|
246
|
+
design = pd.concat(parts, axis=1).values
|
|
233
247
|
p_design = design.shape[1]
|
|
234
248
|
|
|
235
249
|
X_np = X.values
|
|
236
250
|
beta_hat = la.lstsq(design, X_np, rcond=None)[0]
|
|
237
251
|
|
|
238
|
-
|
|
252
|
+
beta_hat_batch = beta_hat[:n_batch]
|
|
239
253
|
self._beta_hat_nonbatch = beta_hat[n_batch:]
|
|
240
254
|
|
|
241
|
-
|
|
242
|
-
self._n_per_batch = dict(zip(
|
|
255
|
+
n_per_batch = batch.value_counts().sort_index().astype(int).values
|
|
256
|
+
self._n_per_batch = dict(zip(self._batch_levels, n_per_batch))
|
|
257
|
+
|
|
258
|
+
if self.reference_batch is not None:
|
|
259
|
+
ref_idx = list(self._batch_levels).index(self.reference_batch)
|
|
260
|
+
grand_mean = beta_hat_batch[ref_idx]
|
|
261
|
+
else:
|
|
262
|
+
grand_mean = (n_per_batch / n_samples) @ beta_hat_batch
|
|
263
|
+
ref_idx = None
|
|
243
264
|
|
|
244
|
-
grand_mean = (n_per_batch_arr / n_samples) @ gamma_hat
|
|
245
265
|
self._grand_mean = pd.Series(grand_mean, index=X.columns)
|
|
246
266
|
|
|
247
|
-
|
|
248
|
-
|
|
267
|
+
if self.reference_batch is not None:
|
|
268
|
+
ref_mask = (batch == self.reference_batch).values
|
|
269
|
+
resid = X_np[ref_mask] - design[ref_mask] @ beta_hat
|
|
270
|
+
denom = int(ref_mask.sum())
|
|
271
|
+
else:
|
|
272
|
+
resid = X_np - design @ beta_hat
|
|
273
|
+
denom = n_samples
|
|
274
|
+
var_pooled = (resid ** 2).sum(axis=0) / denom + self.eps
|
|
249
275
|
self._pooled_var = pd.Series(var_pooled, index=X.columns)
|
|
250
276
|
|
|
251
277
|
stand_mean = grand_mean + design[:, n_batch:] @ self._beta_hat_nonbatch
|
|
252
278
|
Xs = (X_np - stand_mean) / np.sqrt(var_pooled)
|
|
253
279
|
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
280
|
+
gamma_hat = np.vstack(
|
|
281
|
+
[Xs[batch == lvl].mean(axis=0) for lvl in self._batch_levels]
|
|
282
|
+
)
|
|
283
|
+
delta_hat = np.vstack(
|
|
284
|
+
[Xs[batch == lvl].var(axis=0, ddof=1) + self.eps
|
|
285
|
+
for lvl in self._batch_levels]
|
|
286
|
+
)
|
|
258
287
|
|
|
259
288
|
if self.mean_only:
|
|
260
289
|
gamma_star = self._shrink_gamma(
|
|
261
|
-
gamma_hat, delta_hat,
|
|
290
|
+
gamma_hat, delta_hat, n_per_batch,
|
|
291
|
+
parametric = self.parametric
|
|
262
292
|
)
|
|
263
293
|
delta_star = np.ones_like(delta_hat)
|
|
264
294
|
else:
|
|
265
295
|
gamma_star, delta_star = self._shrink_gamma_delta(
|
|
266
|
-
gamma_hat, delta_hat,
|
|
296
|
+
gamma_hat, delta_hat, n_per_batch,
|
|
297
|
+
parametric = self.parametric
|
|
267
298
|
)
|
|
268
299
|
|
|
269
|
-
if
|
|
270
|
-
ref_idx =
|
|
271
|
-
gamma_ref = gamma_star[ref_idx]
|
|
272
|
-
delta_ref = delta_star[ref_idx]
|
|
273
|
-
gamma_star = gamma_star - gamma_ref
|
|
300
|
+
if ref_idx is not None:
|
|
301
|
+
gamma_star[ref_idx] = 0.0
|
|
274
302
|
if not self.mean_only:
|
|
275
|
-
delta_star =
|
|
276
|
-
|
|
277
|
-
else:
|
|
278
|
-
self._reference_batch_idx = None
|
|
303
|
+
delta_star[ref_idx] = 1.0
|
|
304
|
+
self._reference_batch_idx = ref_idx
|
|
279
305
|
|
|
280
|
-
self._batch_levels = batch_levels
|
|
281
306
|
self._gamma_star = gamma_star
|
|
282
307
|
self._delta_star = delta_star
|
|
283
|
-
self._n_batch
|
|
308
|
+
self._n_batch = n_batch
|
|
284
309
|
self._p_design = p_design
|
|
285
310
|
|
|
286
311
|
def _fit_chen(
|
|
@@ -434,7 +459,7 @@ class ComBatModel:
|
|
|
434
459
|
*,
|
|
435
460
|
parametric: bool,
|
|
436
461
|
) -> FloatArray:
|
|
437
|
-
"""Convenience wrapper that returns only γ⋆ (for *mean
|
|
462
|
+
"""Convenience wrapper that returns only γ⋆ (for *mean-only* mode)."""
|
|
438
463
|
gamma, _ = self._shrink_gamma_delta(gamma_hat, delta_hat, n_per_batch, parametric=parametric)
|
|
439
464
|
return gamma
|
|
440
465
|
|
|
@@ -454,7 +479,7 @@ class ComBatModel:
|
|
|
454
479
|
batch = self._as_series(batch, idx, "batch")
|
|
455
480
|
unseen = set(batch.cat.categories) - set(self._batch_levels)
|
|
456
481
|
if unseen:
|
|
457
|
-
raise ValueError(f"Unseen batch levels during transform: {unseen}")
|
|
482
|
+
raise ValueError(f"Unseen batch levels during transform: {unseen}.")
|
|
458
483
|
disc = self._to_df(discrete_covariates, idx, "discrete_covariates")
|
|
459
484
|
cont = self._to_df(continuous_covariates, idx, "continuous_covariates")
|
|
460
485
|
|
|
@@ -466,7 +491,7 @@ class ComBatModel:
|
|
|
466
491
|
elif method == "chen":
|
|
467
492
|
return self._transform_chen(X, batch, disc, cont)
|
|
468
493
|
else:
|
|
469
|
-
raise ValueError(f"Unknown method: {method}")
|
|
494
|
+
raise ValueError(f"Unknown method: {method}.")
|
|
470
495
|
|
|
471
496
|
def _transform_johnson(
|
|
472
497
|
self,
|
|
@@ -485,7 +510,7 @@ class ComBatModel:
|
|
|
485
510
|
if not idx.any():
|
|
486
511
|
continue
|
|
487
512
|
if self.reference_batch is not None and lvl == self.reference_batch:
|
|
488
|
-
X_adj.loc[idx] = X.loc[idx].values
|
|
513
|
+
X_adj.loc[idx] = X.loc[idx].values
|
|
489
514
|
continue
|
|
490
515
|
|
|
491
516
|
g = self._gamma_star[i]
|
|
@@ -505,18 +530,28 @@ class ComBatModel:
|
|
|
505
530
|
cont: Optional[pd.DataFrame],
|
|
506
531
|
) -> pd.DataFrame:
|
|
507
532
|
"""Fortin transform implementation."""
|
|
508
|
-
batch_dummies = pd.get_dummies(batch, drop_first=False)[self._batch_levels]
|
|
509
|
-
|
|
533
|
+
batch_dummies = pd.get_dummies(batch, drop_first=False).astype(float)[self._batch_levels]
|
|
534
|
+
if self.reference_batch is not None:
|
|
535
|
+
batch_dummies.loc[:, self.reference_batch] = 1.0
|
|
536
|
+
|
|
537
|
+
parts = [batch_dummies]
|
|
510
538
|
if disc is not None:
|
|
511
|
-
parts.append(
|
|
539
|
+
parts.append(
|
|
540
|
+
pd.get_dummies(
|
|
541
|
+
disc.astype("category"), drop_first=True
|
|
542
|
+
).astype(float)
|
|
543
|
+
)
|
|
512
544
|
if cont is not None:
|
|
513
|
-
parts.append(cont)
|
|
545
|
+
parts.append(cont.astype(float))
|
|
514
546
|
|
|
515
|
-
design = pd.concat(parts, axis=1).
|
|
547
|
+
design = pd.concat(parts, axis=1).values
|
|
516
548
|
|
|
517
549
|
X_np = X.values
|
|
518
|
-
|
|
519
|
-
|
|
550
|
+
stand_mu = (
|
|
551
|
+
self._grand_mean.values +
|
|
552
|
+
design[:, self._n_batch:] @ self._beta_hat_nonbatch
|
|
553
|
+
)
|
|
554
|
+
Xs = (X_np - stand_mu) / np.sqrt(self._pooled_var.values)
|
|
520
555
|
|
|
521
556
|
for i, lvl in enumerate(self._batch_levels):
|
|
522
557
|
idx = batch == lvl
|
|
@@ -533,8 +568,11 @@ class ComBatModel:
|
|
|
533
568
|
else:
|
|
534
569
|
Xs[idx] = (Xs[idx] - g) / np.sqrt(d)
|
|
535
570
|
|
|
536
|
-
X_adj =
|
|
537
|
-
|
|
571
|
+
X_adj = (
|
|
572
|
+
Xs * np.sqrt(self._pooled_var.values) +
|
|
573
|
+
stand_mu
|
|
574
|
+
)
|
|
575
|
+
return pd.DataFrame(X_adj, index=X.index, columns=X.columns, dtype=float)
|
|
538
576
|
|
|
539
577
|
def _transform_chen(
|
|
540
578
|
self,
|
|
@@ -568,7 +606,7 @@ class ComBatModel:
|
|
|
568
606
|
|
|
569
607
|
|
|
570
608
|
class ComBat(BaseEstimator, TransformerMixin):
|
|
571
|
-
"""Pipeline
|
|
609
|
+
"""Pipeline-friendly wrapper around `ComBatModel`.
|
|
572
610
|
|
|
573
611
|
Stores batch (and optional covariates) passed at construction and
|
|
574
612
|
appropriately uses them for separate `fit` and `transform`.
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: combatlearn
|
|
3
|
-
Version: 0.1.
|
|
3
|
+
Version: 0.1.2
|
|
4
4
|
Summary: Batch-effect harmonization for machine learning frameworks.
|
|
5
5
|
Author-email: Ettore Rocchi <ettoreroc@gmail.com>
|
|
6
6
|
License: MIT License
|
|
@@ -42,10 +42,10 @@ Dynamic: license-file
|
|
|
42
42
|
|
|
43
43
|
# **combatlearn**
|
|
44
44
|
|
|
45
|
-
[](https://www.python.org/)
|
|
46
46
|
[](https://github.com/EttoreRocchi/combatlearn/actions/workflows/test.yaml)
|
|
47
47
|
[](https://pepy.tech/projects/combatlearn)
|
|
48
|
-
[](https://pypi.org/project/combatlearn/)
|
|
49
49
|
[](https://github.com/EttoreRocchi/combatlearn/blob/main/LICENSE)
|
|
50
50
|
|
|
51
51
|
<div align="center">
|
|
@@ -56,7 +56,7 @@ Dynamic: license-file
|
|
|
56
56
|
|
|
57
57
|
**Three methods**:
|
|
58
58
|
- `method="johnson"` - classic ComBat (Johnson _et al._, 2007)
|
|
59
|
-
- `method="fortin"` -
|
|
59
|
+
- `method="fortin"` - neuroComBat (Fortin _et al._, 2018)
|
|
60
60
|
- `method="chen"` - CovBat (Chen _et al._, 2022)
|
|
61
61
|
|
|
62
62
|
## Installation
|
|
@@ -0,0 +1,7 @@
|
|
|
1
|
+
combatlearn/__init__.py,sha256=PHezKTkdkd2fnyqihhayxRN8hducHCXug7iQ5-UsfSc,98
|
|
2
|
+
combatlearn/combat.py,sha256=ghc83DTLC4ukLJN_xqpoWZTPPTxFa4DVtT6C5SVUjFA,25024
|
|
3
|
+
combatlearn-0.1.2.dist-info/licenses/LICENSE,sha256=O34CBRTmdL59PxDYOa6nq1N0-2A9xyXGkBXKbsL1NeY,1070
|
|
4
|
+
combatlearn-0.1.2.dist-info/METADATA,sha256=VxQpyJAwOSQqw8ypiSUxq4dmszCDRW3AsO_0XBQq6pk,8213
|
|
5
|
+
combatlearn-0.1.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
6
|
+
combatlearn-0.1.2.dist-info/top_level.txt,sha256=3cFQv4oj2sh_NKra45cPy8Go0v8W9x9-zkkUibqZCMk,12
|
|
7
|
+
combatlearn-0.1.2.dist-info/RECORD,,
|
|
@@ -1,7 +0,0 @@
|
|
|
1
|
-
combatlearn/__init__.py,sha256=PHezKTkdkd2fnyqihhayxRN8hducHCXug7iQ5-UsfSc,98
|
|
2
|
-
combatlearn/combat.py,sha256=3tWZDCtXcJtrv8QECD0OFGTpdo-zNRW1YflMXi7rU0c,24022
|
|
3
|
-
combatlearn-0.1.1.dist-info/licenses/LICENSE,sha256=O34CBRTmdL59PxDYOa6nq1N0-2A9xyXGkBXKbsL1NeY,1070
|
|
4
|
-
combatlearn-0.1.1.dist-info/METADATA,sha256=82rLKgrrISPGM0LPer9tlyfRR6dRyrAFTwNzdkKQDa8,8286
|
|
5
|
-
combatlearn-0.1.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
6
|
-
combatlearn-0.1.1.dist-info/top_level.txt,sha256=3cFQv4oj2sh_NKra45cPy8Go0v8W9x9-zkkUibqZCMk,12
|
|
7
|
-
combatlearn-0.1.1.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|