combatlearn 0.1.0__tar.gz → 0.1.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {combatlearn-0.1.0/src/combatlearn.egg-info → combatlearn-0.1.2}/PKG-INFO +39 -4
- {combatlearn-0.1.0 → combatlearn-0.1.2}/README.md +36 -1
- {combatlearn-0.1.0/src → combatlearn-0.1.2}/combatlearn/combat.py +243 -136
- {combatlearn-0.1.0 → combatlearn-0.1.2/combatlearn.egg-info}/PKG-INFO +39 -4
- combatlearn-0.1.2/combatlearn.egg-info/SOURCES.txt +11 -0
- {combatlearn-0.1.0 → combatlearn-0.1.2}/pyproject.toml +4 -4
- combatlearn-0.1.0/src/combatlearn.egg-info/SOURCES.txt +0 -11
- {combatlearn-0.1.0 → combatlearn-0.1.2}/LICENSE +0 -0
- {combatlearn-0.1.0/src → combatlearn-0.1.2}/combatlearn/__init__.py +0 -0
- {combatlearn-0.1.0/src → combatlearn-0.1.2}/combatlearn.egg-info/dependency_links.txt +0 -0
- {combatlearn-0.1.0/src → combatlearn-0.1.2}/combatlearn.egg-info/requires.txt +0 -0
- {combatlearn-0.1.0/src → combatlearn-0.1.2}/combatlearn.egg-info/top_level.txt +0 -0
- {combatlearn-0.1.0 → combatlearn-0.1.2}/setup.cfg +0 -0
- {combatlearn-0.1.0 → combatlearn-0.1.2}/tests/test_combat.py +0 -0
**{combatlearn-0.1.0/src/combatlearn.egg-info → combatlearn-0.1.2}/PKG-INFO**

```diff
@@ -1,7 +1,7 @@
 Metadata-Version: 2.4
 Name: combatlearn
-Version: 0.1.0
-Summary: Batch-effect
+Version: 0.1.2
+Summary: Batch-effect harmonization for machine learning frameworks.
 Author-email: Ettore Rocchi <ettoreroc@gmail.com>
 License: MIT License
 
@@ -31,7 +31,7 @@ Classifier: Intended Audience :: Science/Research
 Classifier: License :: OSI Approved :: MIT License
 Classifier: Operating System :: OS Independent
 Classifier: Programming Language :: Python :: 3
-Requires-Python: >=3.
+Requires-Python: >=3.10
 Description-Content-Type: text/markdown
 License-File: LICENSE
 Requires-Dist: pandas>=1.3
@@ -42,6 +42,12 @@ Dynamic: license-file
 
 # **combatlearn**
 
+[](https://www.python.org/)
+[](https://github.com/EttoreRocchi/combatlearn/actions/workflows/test.yaml)
+[](https://pepy.tech/projects/combatlearn)
+[](https://pypi.org/project/combatlearn/)
+[](https://github.com/EttoreRocchi/combatlearn/blob/main/LICENSE)
+
 <div align="center">
 <p><img src="https://raw.githubusercontent.com/EttoreRocchi/combatlearn/main/docs/logo.png" alt="combatlearn logo" width="350" /></p>
 </div>
@@ -50,7 +56,7 @@ Dynamic: license-file
 
 **Three methods**:
 - `method="johnson"` - classic ComBat (Johnson _et al._, 2007)
-- `method="fortin"` -
+- `method="fortin"` - neuroComBat (Fortin _et al._, 2018)
 - `method="chen"` - CovBat (Chen _et al._, 2022)
 
 ## Installation
@@ -107,10 +113,39 @@ print(f"Best CV AUROC: {grid.best_score_:.3f}")
 
 For a full example of how to use **combatlearn** see the [notebook demo](https://github.com/EttoreRocchi/combatlearn/blob/main/demo/combatlearn_demo.ipynb)
 
+## `ComBat` parameters
+
+The following section provides a detailed explanation of all parameters available in the scikit-learn-compatible `ComBat` class.
+
+### Main Parameters
+
+| Parameter | Type | Default | Description |
+| --- | --- | --- | --- |
+| `batch` | array-like or pd.Series | **required** | Vector indicating batch assignment for each sample. This is used to estimate and remove batch effects. |
+| `discrete_covariates` | array-like, pd.Series, or pd.DataFrame | `None` | Optional categorical covariates (e.g., sex, site). Only used in `"fortin"` and `"chen"` methods. |
+| `continuous_covariates` | array-like, pd.Series or pd.DataFrame | `None` | Optional continuous covariates (e.g., age). Only used in `"fortin"` and `"chen"` methods. |
+
+### Algorithm Options
+
+| Parameter | Type | Default | Description |
+| --- | --- | --- | --- |
+| `method` | str | `"johnson"` | ComBat method to use: <ul><li>`"johnson"` - Classical ComBat (_Johnson et al. 2007_)</li><li>`"fortin"` - ComBat with covariates (_Fortin et al. 2018_)</li><li>`"chen"` - CovBat, PCA-based correction (_Chen et al. 2022_)</li></ul> |
+| `parametric` | bool | `True` | Whether to use the **parametric empirical Bayes** formulation. If `False`, a non-parametric iterative scheme is used. |
+| `mean_only` | bool | `False` | If `True`, only the **mean** is corrected, while variances are left unchanged. Useful for preserving variance structure in the data. |
+| `reference_batch` | str or `None` | `None` | If specified, acts as a reference batch - other batches will be corrected to match this one. |
+| `covbat_cov_thresh` | float, int | `0.9` | For `"chen"` method only: Cumulative variance threshold $]0,1[$ to retain PCs in PCA space (e.g., 0.9 = retain 90% explained variance). If an integer is provided, it represents the number of principal components to use. |
+| `eps` | float | `1e-8` | Small jitter value added to variances to prevent divide-by-zero errors during standardization. |
+
 ## Contributing
 
 Pull requests, bug reports, and feature ideas are welcome: feel free to open a PR!
 
+## Author
+
+[**Ettore Rocchi**](https://github.com/ettorerocchi) @ University of Bologna
+
+[Google Scholar](https://scholar.google.com/citations?user=MKHoGnQAAAAJ) $\cdot$ [Scopus](https://www.scopus.com/authid/detail.uri?authorId=57220152522)
+
 ## Acknowledgements
 
 This project builds on the excellent work of the ComBat family of harmonisation methods.
```
**{combatlearn-0.1.0 → combatlearn-0.1.2}/README.md**

```diff
@@ -1,5 +1,11 @@
 # **combatlearn**
 
+[](https://www.python.org/)
+[](https://github.com/EttoreRocchi/combatlearn/actions/workflows/test.yaml)
+[](https://pepy.tech/projects/combatlearn)
+[](https://pypi.org/project/combatlearn/)
+[](https://github.com/EttoreRocchi/combatlearn/blob/main/LICENSE)
+
 <div align="center">
 <p><img src="https://raw.githubusercontent.com/EttoreRocchi/combatlearn/main/docs/logo.png" alt="combatlearn logo" width="350" /></p>
 </div>
@@ -8,7 +14,7 @@
 
 **Three methods**:
 - `method="johnson"` - classic ComBat (Johnson _et al._, 2007)
-- `method="fortin"` -
+- `method="fortin"` - neuroComBat (Fortin _et al._, 2018)
 - `method="chen"` - CovBat (Chen _et al._, 2022)
 
 ## Installation
@@ -65,10 +71,39 @@ print(f"Best CV AUROC: {grid.best_score_:.3f}")
 
 For a full example of how to use **combatlearn** see the [notebook demo](https://github.com/EttoreRocchi/combatlearn/blob/main/demo/combatlearn_demo.ipynb)
 
+## `ComBat` parameters
+
+The following section provides a detailed explanation of all parameters available in the scikit-learn-compatible `ComBat` class.
+
+### Main Parameters
+
+| Parameter | Type | Default | Description |
+| --- | --- | --- | --- |
+| `batch` | array-like or pd.Series | **required** | Vector indicating batch assignment for each sample. This is used to estimate and remove batch effects. |
+| `discrete_covariates` | array-like, pd.Series, or pd.DataFrame | `None` | Optional categorical covariates (e.g., sex, site). Only used in `"fortin"` and `"chen"` methods. |
+| `continuous_covariates` | array-like, pd.Series or pd.DataFrame | `None` | Optional continuous covariates (e.g., age). Only used in `"fortin"` and `"chen"` methods. |
+
+### Algorithm Options
+
+| Parameter | Type | Default | Description |
+| --- | --- | --- | --- |
+| `method` | str | `"johnson"` | ComBat method to use: <ul><li>`"johnson"` - Classical ComBat (_Johnson et al. 2007_)</li><li>`"fortin"` - ComBat with covariates (_Fortin et al. 2018_)</li><li>`"chen"` - CovBat, PCA-based correction (_Chen et al. 2022_)</li></ul> |
+| `parametric` | bool | `True` | Whether to use the **parametric empirical Bayes** formulation. If `False`, a non-parametric iterative scheme is used. |
+| `mean_only` | bool | `False` | If `True`, only the **mean** is corrected, while variances are left unchanged. Useful for preserving variance structure in the data. |
+| `reference_batch` | str or `None` | `None` | If specified, acts as a reference batch - other batches will be corrected to match this one. |
+| `covbat_cov_thresh` | float, int | `0.9` | For `"chen"` method only: Cumulative variance threshold $]0,1[$ to retain PCs in PCA space (e.g., 0.9 = retain 90% explained variance). If an integer is provided, it represents the number of principal components to use. |
+| `eps` | float | `1e-8` | Small jitter value added to variances to prevent divide-by-zero errors during standardization. |
+
 ## Contributing
 
 Pull requests, bug reports, and feature ideas are welcome: feel free to open a PR!
 
+## Author
+
+[**Ettore Rocchi**](https://github.com/ettorerocchi) @ University of Bologna
+
+[Google Scholar](https://scholar.google.com/citations?user=MKHoGnQAAAAJ) $\cdot$ [Scopus](https://www.scopus.com/authid/detail.uri?authorId=57220152522)
+
 ## Acknowledgements
 
 This project builds on the excellent work of the ComBat family of harmonisation methods.
```
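The parameter tables describe a location/scale correction: each batch's per-feature mean (and, unless `mean_only=True`, variance) is aligned with pooled estimates. A minimal numpy sketch of that core idea follows; all names are illustrative, and this is not the package's empirical-Bayes implementation, just the underlying intuition:

```python
import numpy as np

rng = np.random.default_rng(0)
# Two batches of the same signal, offset and rescaled by a batch effect.
x_a = rng.normal(loc=0.0, scale=1.0, size=(50, 3))
x_b = rng.normal(loc=2.0, scale=3.0, size=(50, 3))
X = np.vstack([x_a, x_b])
batch = np.array([0] * 50 + [1] * 50)

pooled_mean = X.mean(axis=0)
pooled_std = X.std(axis=0, ddof=1)

X_adj = X.copy()
for b in np.unique(batch):
    idx = batch == b
    mu_b = X[idx].mean(axis=0)
    sd_b = X[idx].std(axis=0, ddof=1)
    # Location/scale alignment: remove the batch's own mean and spread,
    # then re-express the samples in the pooled location and scale.
    X_adj[idx] = (X[idx] - mu_b) / sd_b * pooled_std + pooled_mean

# After adjustment, the two batches share per-feature mean and spread.
print(np.allclose(X_adj[batch == 0].mean(axis=0), X_adj[batch == 1].mean(axis=0)))
```

The real algorithms shrink the per-batch estimates with empirical Bayes instead of using them raw, which is what the `parametric` flag controls.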
**{combatlearn-0.1.0/src → combatlearn-0.1.2}/combatlearn/combat.py**

```diff
@@ -1,15 +1,14 @@
-__author__ = "Ettore Rocchi"
-
 """ComBat algorithm.
 
 `ComBatModel` implements both:
-* Johnson et
-* Fortin et
-* Chen et
+* Johnson et al. (2007) vanilla ComBat (method="johnson")
+* Fortin et al. (2018) extension with covariates (method="fortin")
+* Chen et al. (2022) CovBat (method="chen")
 
 `ComBat` makes the model compatible with scikit-learn by stashing
 the batch (and optional covariates) at construction.
 """
+from __future__ import annotations
 
 import numpy as np
 import numpy.linalg as la
@@ -17,9 +16,15 @@ import pandas as pd
 from sklearn.base import BaseEstimator, TransformerMixin
 from sklearn.utils.validation import check_is_fitted
 from sklearn.decomposition import PCA
-from typing import Literal
+from typing import Literal, Optional, Union, Dict, Tuple, Any, cast
+import numpy.typing as npt
 import warnings
 
+__author__ = "Ettore Rocchi"
+
+ArrayLike = Union[pd.DataFrame, pd.Series, npt.NDArray[Any]]
+FloatArray = npt.NDArray[np.float64]
+
 
 class ComBatModel:
     """ComBat algorithm.
@@ -27,9 +32,9 @@ class ComBatModel:
     Parameters
     ----------
     method : {'johnson', 'fortin', 'chen'}, default='johnson'
-        * 'johnson'
-        * 'fortin'
-        * 'chen'
+        * 'johnson' - classic ComBat.
+        * 'fortin' - covariate-aware ComBat.
+        * 'chen' - CovBat, PCA-based ComBat.
     parametric : bool, default=True
         Use the parametric empirical Bayes variant.
     mean_only : bool, default=False
@@ -40,7 +45,7 @@ class ComBatModel:
     covbat_cov_thresh : float, default=0.9
         CovBat: cumulative explained variance threshold for PCA.
     eps : float, default=1e-8
-        Numerical jitter to avoid division
+        Numerical jitter to avoid division-by-zero.
     """
 
     def __init__(
@@ -49,21 +54,43 @@ class ComBatModel:
         method: Literal["johnson", "fortin", "chen"] = "johnson",
         parametric: bool = True,
         mean_only: bool = False,
-        reference_batch=None,
+        reference_batch: Optional[str] = None,
         eps: float = 1e-8,
         covbat_cov_thresh: float = 0.9,
     ) -> None:
-        self.method = method
-        self.parametric = parametric
-        self.mean_only = bool(mean_only)
-        self.reference_batch = reference_batch
-        self.eps = float(eps)
-        self.covbat_cov_thresh = float(covbat_cov_thresh)
+        self.method: str = method
+        self.parametric: bool = parametric
+        self.mean_only: bool = bool(mean_only)
+        self.reference_batch: Optional[str] = reference_batch
+        self.eps: float = float(eps)
+        self.covbat_cov_thresh: float = float(covbat_cov_thresh)
+
+        self._batch_levels: pd.Index
+        self._grand_mean: pd.Series
+        self._pooled_var: pd.Series
+        self._gamma_star: FloatArray
+        self._delta_star: FloatArray
+        self._n_per_batch: Dict[str, int]
+        self._reference_batch_idx: Optional[int]
+        self._beta_hat_nonbatch: FloatArray
+        self._n_batch: int
+        self._p_design: int
+        self._covbat_pca: PCA
+        self._covbat_n_pc: int
+        self._batch_levels_pc: pd.Index
+        self._pc_gamma_star: FloatArray
+        self._pc_delta_star: FloatArray
+
         if not (0.0 < self.covbat_cov_thresh <= 1.0):
             raise ValueError("covbat_cov_thresh must be in (0, 1].")
 
     @staticmethod
-    def _as_series(
+    def _as_series(
+        arr: ArrayLike,
+        index: pd.Index,
+        name: str
+    ) -> pd.Series:
+        """Convert array-like to categorical Series with validation."""
         if isinstance(arr, pd.Series):
             ser = arr.copy()
         else:
@@ -73,7 +100,12 @@ class ComBatModel:
         return ser.astype("category")
 
     @staticmethod
-    def _to_df(
+    def _to_df(
+        arr: Optional[ArrayLike],
+        index: pd.Index,
+        name: str
+    ) -> Optional[pd.DataFrame]:
+        """Convert array-like to DataFrame."""
         if arr is None:
             return None
         if isinstance(arr, pd.Series):
@@ -86,13 +118,14 @@ class ComBatModel:
 
     def fit(
         self,
-        X,
-        y=None,
+        X: ArrayLike,
+        y: Optional[ArrayLike] = None,
         *,
-        batch,
-        discrete_covariates=None,
-        continuous_covariates=None,
-    ):
+        batch: ArrayLike,
+        discrete_covariates: Optional[ArrayLike] = None,
+        continuous_covariates: Optional[ArrayLike] = None,
+    ) -> ComBatModel:
+        """Fit the ComBat model."""
         method = self.method.lower()
         if method not in {"johnson", "fortin", "chen"}:
             raise ValueError("method must be 'johnson', 'fortin', or 'chen'.")
@@ -104,10 +137,9 @@ class ComBatModel:
         disc = self._to_df(discrete_covariates, idx, "discrete_covariates")
         cont = self._to_df(continuous_covariates, idx, "continuous_covariates")
 
-
         if self.reference_batch is not None and self.reference_batch not in batch.cat.categories:
             raise ValueError(
-                f"reference_batch={self.reference_batch!r} not present in the data batches
+                f"reference_batch={self.reference_batch!r} not present in the data batches."
                 f"{list(batch.cat.categories)}"
             )
 
@@ -127,38 +159,39 @@ class ComBatModel:
         self,
         X: pd.DataFrame,
         batch: pd.Series
-    ):
-        """
-        Johnson et al. (2007) ComBat.
-        """
+    ) -> None:
+        """Johnson et al. (2007) ComBat."""
         self._batch_levels = batch.cat.categories
         pooled_var = X.var(axis=0, ddof=1) + self.eps
         grand_mean = X.mean(axis=0)
 
         Xs = (X - grand_mean) / np.sqrt(pooled_var)
 
-        n_per_batch:
-        gamma_hat
+        n_per_batch: Dict[str, int] = {}
+        gamma_hat: list[npt.NDArray[np.float64]] = []
+        delta_hat: list[npt.NDArray[np.float64]] = []
+
         for lvl in self._batch_levels:
             idx = batch == lvl
-            n_b = idx.sum()
+            n_b = int(idx.sum())
             if n_b < 2:
                 raise ValueError(f"Batch '{lvl}' has <2 samples.")
-            n_per_batch[lvl] = n_b
+            n_per_batch[str(lvl)] = n_b
             xb = Xs.loc[idx]
             gamma_hat.append(xb.mean(axis=0).values)
             delta_hat.append(xb.var(axis=0, ddof=1).values + self.eps)
-
-
+
+        gamma_hat_arr = np.vstack(gamma_hat)
+        delta_hat_arr = np.vstack(delta_hat)
 
         if self.mean_only:
             gamma_star = self._shrink_gamma(
-
+                gamma_hat_arr, delta_hat_arr, n_per_batch, parametric=self.parametric
             )
-            delta_star = np.ones_like(
+            delta_star = np.ones_like(delta_hat_arr)
         else:
             gamma_star, delta_star = self._shrink_gamma_delta(
-
+                gamma_hat_arr, delta_hat_arr, n_per_batch, parametric=self.parametric
            )
 
         if self.reference_batch is not None:
```
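The `_fit_johnson` hunk above standardizes the data by the grand mean and pooled variance, then records a per-batch mean (`gamma_hat`) and variance (`delta_hat`) for each batch, which the shrinkage step later regularizes. That estimation step can be reproduced standalone; the sketch below mirrors the shapes in the diff but is not the package itself, and all data are illustrative:

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(1)
X = pd.DataFrame(rng.normal(size=(60, 4)))
batch = pd.Series(["a"] * 30 + ["b"] * 30).astype("category")
eps = 1e-8  # jitter, standing in for self.eps

# Standardize by grand mean and pooled variance, as in _fit_johnson.
pooled_var = X.var(axis=0, ddof=1) + eps
grand_mean = X.mean(axis=0)
Xs = (X - grand_mean) / np.sqrt(pooled_var)

gamma_hat, delta_hat = [], []
for lvl in batch.cat.categories:
    xb = Xs.loc[(batch == lvl).values]
    gamma_hat.append(xb.mean(axis=0).values)             # batch location effect
    delta_hat.append(xb.var(axis=0, ddof=1).values + eps)  # batch scale effect

gamma_hat = np.vstack(gamma_hat)  # shape: (n_batches, n_features)
delta_hat = np.vstack(delta_hat)
print(gamma_hat.shape, delta_hat.shape)
```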
```diff
@@ -182,83 +215,107 @@ class ComBatModel:
         self,
         X: pd.DataFrame,
         batch: pd.Series,
-        disc: pd.DataFrame
-        cont: pd.DataFrame
-    ):
-        """
-
-
-        batch_levels = batch.cat.categories
-        n_batch = len(batch_levels)
+        disc: Optional[pd.DataFrame],
+        cont: Optional[pd.DataFrame],
+    ) -> None:
+        """Fortin et al. (2018) neuroComBat."""
+        self._batch_levels = batch.cat.categories
+        n_batch = len(self._batch_levels)
         n_samples = len(X)
 
-        batch_dummies = pd.get_dummies(batch, drop_first=False)
-
+        batch_dummies = pd.get_dummies(batch, drop_first=False).astype(float)
+        if self.reference_batch is not None:
+            if self.reference_batch not in self._batch_levels:
+                raise ValueError(
+                    f"reference_batch={self.reference_batch!r} not present in batches."
+                    f"{list(self._batch_levels)}"
+                )
+            batch_dummies.loc[:, self.reference_batch] = 1.0
+
+        parts: list[pd.DataFrame] = [batch_dummies]
         if disc is not None:
-            parts.append(
+            parts.append(
+                pd.get_dummies(
+                    disc.astype("category"), drop_first=True
+                ).astype(float)
+            )
+
         if cont is not None:
-            parts.append(cont)
-
+            parts.append(cont.astype(float))
+
+        design = pd.concat(parts, axis=1).values
         p_design = design.shape[1]
 
         X_np = X.values
-        beta_hat = la.
+        beta_hat = la.lstsq(design, X_np, rcond=None)[0]
 
-
+        beta_hat_batch = beta_hat[:n_batch]
         self._beta_hat_nonbatch = beta_hat[n_batch:]
 
-        n_per_batch = batch.value_counts().sort_index().values
-        self._n_per_batch = dict(zip(
+        n_per_batch = batch.value_counts().sort_index().astype(int).values
+        self._n_per_batch = dict(zip(self._batch_levels, n_per_batch))
 
-
-
+        if self.reference_batch is not None:
+            ref_idx = list(self._batch_levels).index(self.reference_batch)
+            grand_mean = beta_hat_batch[ref_idx]
+        else:
+            grand_mean = (n_per_batch / n_samples) @ beta_hat_batch
+            ref_idx = None
+
+        self._grand_mean = pd.Series(grand_mean, index=X.columns)
 
-
-
-
+        if self.reference_batch is not None:
+            ref_mask = (batch == self.reference_batch).values
+            resid = X_np[ref_mask] - design[ref_mask] @ beta_hat
+            denom = int(ref_mask.sum())
+        else:
+            resid = X_np - design @ beta_hat
+            denom = n_samples
+        var_pooled = (resid ** 2).sum(axis=0) / denom + self.eps
+        self._pooled_var = pd.Series(var_pooled, index=X.columns)
 
         stand_mean = grand_mean + design[:, n_batch:] @ self._beta_hat_nonbatch
         Xs = (X_np - stand_mean) / np.sqrt(var_pooled)
 
-
-
-
-
+        gamma_hat = np.vstack(
+            [Xs[batch == lvl].mean(axis=0) for lvl in self._batch_levels]
+        )
+        delta_hat = np.vstack(
+            [Xs[batch == lvl].var(axis=0, ddof=1) + self.eps
+             for lvl in self._batch_levels]
+        )
 
         if self.mean_only:
             gamma_star = self._shrink_gamma(
-                gamma_hat, delta_hat, n_per_batch,
+                gamma_hat, delta_hat, n_per_batch,
+                parametric = self.parametric
             )
             delta_star = np.ones_like(delta_hat)
         else:
             gamma_star, delta_star = self._shrink_gamma_delta(
-                gamma_hat, delta_hat, n_per_batch,
+                gamma_hat, delta_hat, n_per_batch,
+                parametric = self.parametric
             )
 
-        if
-            ref_idx =
-            gamma_ref = gamma_star[ref_idx]
-            delta_ref = delta_star[ref_idx]
-            gamma_star = gamma_star - gamma_ref
+        if ref_idx is not None:
+            gamma_star[ref_idx] = 0.0
             if not self.mean_only:
-                delta_star =
-
-        else:
-            self._reference_batch_idx = None
+                delta_star[ref_idx] = 1.0
+        self._reference_batch_idx = ref_idx
 
-        self._batch_levels = batch_levels
         self._gamma_star = gamma_star
         self._delta_star = delta_star
-        self._n_batch
+        self._n_batch = n_batch
         self._p_design = p_design
 
     def _fit_chen(
         self,
         X: pd.DataFrame,
         batch: pd.Series,
-        disc: pd.DataFrame
-        cont: pd.DataFrame
-    ):
+        disc: Optional[pd.DataFrame],
+        cont: Optional[pd.DataFrame],
+    ) -> None:
+        """Chen et al. (2022) CovBat."""
         self._fit_fortin(X, batch, disc, cont)
         X_meanvar_adj = self._transform_fortin(X, batch, disc, cont)
         X_centered = X_meanvar_adj - X_meanvar_adj.mean(axis=0)
@@ -273,23 +330,24 @@ class ComBatModel:
         self._batch_levels_pc = self._batch_levels
         n_per_batch = self._n_per_batch
 
-        gamma_hat
+        gamma_hat: list[npt.NDArray[np.float64]] = []
+        delta_hat: list[npt.NDArray[np.float64]] = []
         for lvl in self._batch_levels_pc:
             idx = batch == lvl
             xb = scores_df.loc[idx]
             gamma_hat.append(xb.mean(axis=0).values)
             delta_hat.append(xb.var(axis=0, ddof=1).values + self.eps)
-
-
+        gamma_hat_arr = np.vstack(gamma_hat)
+        delta_hat_arr = np.vstack(delta_hat)
 
         if self.mean_only:
             gamma_star = self._shrink_gamma(
-
+                gamma_hat_arr, delta_hat_arr, n_per_batch, parametric=self.parametric
             )
-            delta_star = np.ones_like(
+            delta_star = np.ones_like(delta_hat_arr)
         else:
             gamma_star, delta_star = self._shrink_gamma_delta(
-
+                gamma_hat_arr, delta_hat_arr, n_per_batch, parametric=self.parametric
             )
 
         if self.reference_batch is not None:
```
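The rewritten `_fit_fortin` builds a design matrix of batch indicators plus covariate columns and fits per-feature coefficients with least squares: the first `n_batch` rows of `beta_hat` are per-batch locations and the remaining rows are covariate effects, from which a weighted grand mean is derived. A self-contained sketch of that regression step (variable names are illustrative, not the package's API):

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(2)
n = 40
X = pd.DataFrame(rng.normal(size=(n, 3)))
batch = pd.Series(["a"] * 20 + ["b"] * 20).astype("category")
age = pd.DataFrame({"age": rng.uniform(20, 80, size=n)})  # continuous covariate

# Design matrix: one indicator column per batch, then covariate columns.
batch_dummies = pd.get_dummies(batch, drop_first=False).astype(float)
design = pd.concat([batch_dummies, age], axis=1).values  # (n, n_batch + n_cov)

n_batch = batch.cat.categories.size
beta_hat = np.linalg.lstsq(design, X.values, rcond=None)[0]
beta_hat_batch = beta_hat[:n_batch]      # per-batch intercepts
beta_hat_nonbatch = beta_hat[n_batch:]   # covariate slopes

# Sample-size-weighted grand mean across batches, as in the
# diff's non-reference-batch path.
n_per_batch = batch.value_counts().sort_index().values
grand_mean = (n_per_batch / n) @ beta_hat_batch
print(beta_hat.shape, grand_mean.shape)
```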
```diff
@@ -305,14 +363,15 @@ class ComBatModel:
 
     def _shrink_gamma_delta(
         self,
-        gamma_hat:
-        delta_hat:
-        n_per_batch:
+        gamma_hat: FloatArray,
+        delta_hat: FloatArray,
+        n_per_batch: Union[Dict[str, int], FloatArray],
         *,
         parametric: bool,
         max_iter: int = 100,
         tol: float = 1e-4,
-    ):
+    ) -> Tuple[FloatArray, FloatArray]:
+        """Empirical Bayes shrinkage estimation."""
         if parametric:
             gamma_bar = gamma_hat.mean(axis=0)
             t2 = gamma_hat.var(axis=0, ddof=1)
@@ -323,6 +382,7 @@ class ComBatModel:
             gamma_star = np.empty_like(gamma_hat)
             delta_star = np.empty_like(delta_hat)
             n_vec = np.array(list(n_per_batch.values())) if isinstance(n_per_batch, dict) else n_per_batch
+
             for i in range(B):
                 n_i = n_vec[i]
                 g, d = gamma_hat[i], delta_hat[i]
@@ -340,18 +400,29 @@ class ComBatModel:
             gamma_bar = gamma_hat.mean(axis=0)
             t2 = gamma_hat.var(axis=0, ddof=1)
 
-            def postmean(
+            def postmean(
+                g_hat: FloatArray,
+                g_bar: FloatArray,
+                n: float,
+                d_star: FloatArray,
+                t2_: FloatArray
+            ) -> FloatArray:
                 return (t2_ * n * g_hat + d_star * g_bar) / (t2_ * n + d_star)
 
-            def postvar(
+            def postvar(
+                sum2: FloatArray,
+                n: float,
+                a: FloatArray,
+                b: FloatArray
+            ) -> FloatArray:
                 return (0.5 * sum2 + b) / (n / 2.0 + a - 1.0)
 
-            def aprior(delta):
+            def aprior(delta: FloatArray) -> FloatArray:
                 m, s2 = delta.mean(), delta.var()
                 s2 = max(s2, self.eps)
                 return (2 * s2 + m ** 2) / s2
 
-            def bprior(delta):
+            def bprior(delta: FloatArray) -> FloatArray:
                 m, s2 = delta.mean(), delta.var()
                 s2 = max(s2, self.eps)
                 return (m * s2 + m ** 3) / s2
```
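The typed `postmean`/`postvar`/`aprior`/`bprior` helpers are the standard parametric empirical-Bayes pieces: `aprior`/`bprior` moment-match an inverse-gamma prior to the observed batch variances, and `postmean` blends each batch estimate with the across-batch mean, weighted by precision. The formulas can be checked in isolation; this sketch substitutes a literal `1e-8` for `self.eps` and uses made-up data:

```python
import numpy as np

def aprior(delta: np.ndarray) -> float:
    # Moment-matched inverse-gamma shape parameter.
    m, s2 = delta.mean(), max(delta.var(), 1e-8)
    return (2 * s2 + m ** 2) / s2

def bprior(delta: np.ndarray) -> float:
    # Moment-matched inverse-gamma scale parameter.
    m, s2 = delta.mean(), max(delta.var(), 1e-8)
    return (m * s2 + m ** 3) / s2

def postmean(g_hat, g_bar, n, d_star, t2):
    # Posterior mean: precision-weighted blend of the batch estimate
    # and the across-batch prior mean.
    return (t2 * n * g_hat + d_star * g_bar) / (t2 * n + d_star)

def postvar(sum2, n, a, b):
    # Posterior variance under the inverse-gamma prior.
    return (0.5 * sum2 + b) / (n / 2.0 + a - 1.0)

delta_hat = np.array([0.8, 1.1, 1.3, 0.9])
a, b = aprior(delta_hat), bprior(delta_hat)
# With vanishing prior variance t2, the posterior mean collapses
# toward the prior mean g_bar (full shrinkage).
shrunk = postmean(g_hat=2.0, g_bar=0.0, n=10, d_star=1.0, t2=1e-12)
print(a, b, float(shrunk))
```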
```diff
@@ -382,24 +453,25 @@ class ComBatModel:
 
     def _shrink_gamma(
         self,
-        gamma_hat:
-        delta_hat:
-        n_per_batch:
+        gamma_hat: FloatArray,
+        delta_hat: FloatArray,
+        n_per_batch: Union[Dict[str, int], FloatArray],
         *,
         parametric: bool,
-    ) ->
-        """Convenience wrapper that returns only γ⋆ (for *mean
+    ) -> FloatArray:
+        """Convenience wrapper that returns only γ⋆ (for *mean-only* mode)."""
         gamma, _ = self._shrink_gamma_delta(gamma_hat, delta_hat, n_per_batch, parametric=parametric)
         return gamma
 
     def transform(
         self,
-        X,
+        X: ArrayLike,
         *,
-        batch,
-        discrete_covariates=None,
-        continuous_covariates=None,
-    ):
+        batch: ArrayLike,
+        discrete_covariates: Optional[ArrayLike] = None,
+        continuous_covariates: Optional[ArrayLike] = None,
+    ) -> pd.DataFrame:
+        """Transform the data using fitted ComBat parameters."""
         check_is_fitted(self, ["_gamma_star"])
         if not isinstance(X, pd.DataFrame):
             X = pd.DataFrame(X)
@@ -407,7 +479,7 @@ class ComBatModel:
         batch = self._as_series(batch, idx, "batch")
         unseen = set(batch.cat.categories) - set(self._batch_levels)
         if unseen:
-            raise ValueError(f"Unseen batch levels during transform: {unseen}")
+            raise ValueError(f"Unseen batch levels during transform: {unseen}.")
         disc = self._to_df(discrete_covariates, idx, "discrete_covariates")
         cont = self._to_df(continuous_covariates, idx, "continuous_covariates")
 
@@ -418,8 +490,15 @@ class ComBatModel:
             return self._transform_fortin(X, batch, disc, cont)
         elif method == "chen":
             return self._transform_chen(X, batch, disc, cont)
+        else:
+            raise ValueError(f"Unknown method: {method}.")
 
-    def _transform_johnson(
+    def _transform_johnson(
+        self,
+        X: pd.DataFrame,
+        batch: pd.Series
+    ) -> pd.DataFrame:
+        """Johnson transform implementation."""
         pooled = self._pooled_var
         grand = self._grand_mean
 
@@ -431,7 +510,7 @@ class ComBatModel:
             if not idx.any():
                 continue
             if self.reference_batch is not None and lvl == self.reference_batch:
-                X_adj.loc[idx] = X.loc[idx].values
+                X_adj.loc[idx] = X.loc[idx].values
                 continue
 
             g = self._gamma_star[i]
@@ -447,21 +526,32 @@ class ComBatModel:
         self,
         X: pd.DataFrame,
         batch: pd.Series,
-        disc: pd.DataFrame
-        cont: pd.DataFrame
-    ):
-
+        disc: Optional[pd.DataFrame],
+        cont: Optional[pd.DataFrame],
+    ) -> pd.DataFrame:
+        """Fortin transform implementation."""
+        batch_dummies = pd.get_dummies(batch, drop_first=False).astype(float)[self._batch_levels]
+        if self.reference_batch is not None:
+            batch_dummies.loc[:, self.reference_batch] = 1.0
+
         parts = [batch_dummies]
         if disc is not None:
-            parts.append(
+            parts.append(
+                pd.get_dummies(
+                    disc.astype("category"), drop_first=True
+                ).astype(float)
+            )
         if cont is not None:
-            parts.append(cont)
+            parts.append(cont.astype(float))
 
-        design = pd.concat(parts, axis=1).
+        design = pd.concat(parts, axis=1).values
 
         X_np = X.values
-
-
+        stand_mu = (
+            self._grand_mean.values +
+            design[:, self._n_batch:] @ self._beta_hat_nonbatch
+        )
+        Xs = (X_np - stand_mu) / np.sqrt(self._pooled_var.values)
 
         for i, lvl in enumerate(self._batch_levels):
             idx = batch == lvl
```
```diff
@@ -478,19 +568,23 @@ class ComBatModel:
             else:
                 Xs[idx] = (Xs[idx] - g) / np.sqrt(d)
 
-        X_adj =
-
+        X_adj = (
+            Xs * np.sqrt(self._pooled_var.values) +
+            stand_mu
+        )
+        return pd.DataFrame(X_adj, index=X.index, columns=X.columns, dtype=float)
 
     def _transform_chen(
         self,
         X: pd.DataFrame,
         batch: pd.Series,
-        disc: pd.DataFrame
-        cont: pd.DataFrame
-    ):
+        disc: Optional[pd.DataFrame],
+        cont: Optional[pd.DataFrame],
+    ) -> pd.DataFrame:
+        """Chen transform implementation."""
         X_meanvar_adj = self._transform_fortin(X, batch, disc, cont)
         X_centered = X_meanvar_adj - self._covbat_pca.mean_
-        scores = self._covbat_pca.transform(X_centered)
+        scores = self._covbat_pca.transform(X_centered.values)
         n_pc = self._covbat_n_pc
         scores_adj = scores.copy()
 
```
@@ -512,22 +606,22 @@ class ComBatModel:
|
|
|
512
606
|
|
|
513
607
|
|
|
514
608
|
class ComBat(BaseEstimator, TransformerMixin):
|
|
515
|
-
"""Pipeline
|
|
609
|
+
"""Pipeline-friendly wrapper around `ComBatModel`.
|
|
516
610
|
|
|
517
611
|
Stores batch (and optional covariates) passed at construction and
|
|
518
|
-
appropriately
|
|
612
|
+
appropriately uses them for separate `fit` and `transform`.
|
|
519
613
|
"""
|
|
520
614
|
|
|
521
615
|
def __init__(
|
|
522
616
|
self,
|
|
523
|
-
batch,
|
|
617
|
+
batch: ArrayLike,
|
|
524
618
|
*,
|
|
525
|
-
discrete_covariates=None,
|
|
526
|
-
continuous_covariates=None,
|
|
619
|
+
discrete_covariates: Optional[ArrayLike] = None,
|
|
620
|
+
continuous_covariates: Optional[ArrayLike] = None,
|
|
527
621
|
method: str = "johnson",
|
|
528
622
|
parametric: bool = True,
|
|
529
623
|
mean_only: bool = False,
|
|
530
|
-
reference_batch=None,
|
|
624
|
+
reference_batch: Optional[str] = None,
|
|
531
625
|
eps: float = 1e-8,
|
|
532
626
|
covbat_cov_thresh: float = 0.9,
|
|
533
627
|
) -> None:
|
|
@@ -549,8 +643,13 @@ class ComBat(BaseEstimator, TransformerMixin):
|
|
|
549
643
|
covbat_cov_thresh=covbat_cov_thresh,
|
|
550
644
|
)
|
|
551
645
|
|
|
552
|
-
def fit(
|
|
553
|
-
|
|
646
|
+
def fit(
|
|
647
|
+
self,
|
|
648
|
+
X: ArrayLike,
|
|
649
|
+
y: Optional[ArrayLike] = None
|
|
650
|
+
) -> "ComBat":
|
|
651
|
+
"""Fit the ComBat model."""
|
|
652
|
+
idx = X.index if isinstance(X, pd.DataFrame) else pd.RangeIndex(len(X))
|
|
554
653
|
batch_vec = self._subset(self.batch, idx)
|
|
555
654
|
disc = self._subset(self.discrete_covariates, idx)
|
|
556
655
|
cont = self._subset(self.continuous_covariates, idx)
|
|
@@ -562,8 +661,9 @@ class ComBat(BaseEstimator, TransformerMixin):
|
|
|
562
661
|
)
|
|
563
662
|
return self
|
|
564
663
|
|
|
565
|
-
def transform(self, X):
|
|
566
|
-
|
|
664
|
+
def transform(self, X: ArrayLike) -> pd.DataFrame:
|
|
665
|
+
"""Transform the data using fitted ComBat parameters."""
|
|
666
|
+
idx = X.index if isinstance(X, pd.DataFrame) else pd.RangeIndex(len(X))
|
|
567
667
|
batch_vec = self._subset(self.batch, idx)
|
|
568
668
|
disc = self._subset(self.discrete_covariates, idx)
|
|
569
669
|
cont = self._subset(self.continuous_covariates, idx)
|
|
@@ -575,10 +675,17 @@ class ComBat(BaseEstimator, TransformerMixin):
|
|
|
575
675
|
)
|
|
576
676
|
|
|
577
677
|
@staticmethod
|
|
578
|
-
def _subset(
|
|
678
|
+
def _subset(
|
|
679
|
+
obj: Optional[ArrayLike],
|
|
680
|
+
idx: pd.Index
|
|
681
|
+
) -> Optional[Union[pd.DataFrame, pd.Series]]:
|
|
682
|
+
"""Subset array-like object by index."""
|
|
579
683
|
if obj is None:
|
|
580
684
|
return None
|
|
581
685
|
if isinstance(obj, (pd.Series, pd.DataFrame)):
|
|
582
686
|
return obj.loc[idx]
|
|
583
687
|
else:
|
|
584
|
-
|
|
688
|
+
if isinstance(obj, np.ndarray) and obj.ndim == 1:
|
|
689
|
+
return pd.Series(obj, index=idx)
|
|
690
|
+
else:
|
|
691
|
+
return pd.DataFrame(obj, index=idx)
|
|
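The standardize-then-restore round trip that the new `_transform_fortin` performs (subtract `stand_mu`, divide by the pooled standard deviation, batch-adjust, then invert) can be sketched with plain numpy/pandas. This is a minimal illustration under assumed simplifications, not the package's code: the toy `grand_mean`/`beta` bookkeeping here is a stand-in for the fitted `ComBatModel` attributes (`_grand_mean`, `_beta_hat_nonbatch`, `_pooled_var`), and the batch-adjustment step is omitted so the round trip is exact.

```python
import numpy as np
import pandas as pd

# Toy data: 6 samples, 2 features, 2 batches, one continuous covariate.
rng = np.random.default_rng(0)
X = pd.DataFrame(rng.normal(size=(6, 2)), columns=["f1", "f2"])
batch = pd.Series(["a", "a", "a", "b", "b", "b"], index=X.index)
age = pd.DataFrame({"age": [30.0, 40.0, 50.0, 35.0, 45.0, 55.0]}, index=X.index)

# Design matrix as in the diff: batch dummies first, covariates after.
batch_dummies = pd.get_dummies(batch, drop_first=False).astype(float)
design = pd.concat([batch_dummies, age.astype(float)], axis=1).values
n_batch = batch_dummies.shape[1]

# OLS fit of the full model (the fitted model would store these pieces).
beta = np.linalg.lstsq(design, X.values, rcond=None)[0]
weights = batch.value_counts(normalize=True)[batch_dummies.columns].values
grand_mean = weights @ beta[:n_batch]   # batch-size-weighted intercept
beta_nonbatch = beta[n_batch:]          # covariate effects only

# Standardize exactly as the transform does, then invert.
resid = X.values - design @ beta
pooled_var = resid.var(axis=0, ddof=1) + 1e-8   # eps-style jitter
stand_mu = grand_mean + design[:, n_batch:] @ beta_nonbatch
Xs = (X.values - stand_mu) / np.sqrt(pooled_var)
X_back = Xs * np.sqrt(pooled_var) + stand_mu    # inverse of the standardization

assert np.allclose(X_back, X.values)
```

In the real transform, the batch location/scale correction `(Xs - g) / sqrt(d)` is applied between standardization and restoration, which is what actually removes the batch effect.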
{combatlearn-0.1.0 → combatlearn-0.1.2/combatlearn.egg-info}/PKG-INFO

@@ -1,7 +1,7 @@
 Metadata-Version: 2.4
 Name: combatlearn
-Version: 0.1.
-Summary: Batch-effect
+Version: 0.1.2
+Summary: Batch-effect harmonization for machine learning frameworks.
 Author-email: Ettore Rocchi <ettoreroc@gmail.com>
 License: MIT License

@@ -31,7 +31,7 @@ Classifier: Intended Audience :: Science/Research
 Classifier: License :: OSI Approved :: MIT License
 Classifier: Operating System :: OS Independent
 Classifier: Programming Language :: Python :: 3
-Requires-Python: >=3.
+Requires-Python: >=3.10
 Description-Content-Type: text/markdown
 License-File: LICENSE
 Requires-Dist: pandas>=1.3
@@ -42,6 +42,12 @@ Dynamic: license-file

 # **combatlearn**

+[](https://www.python.org/)
+[](https://github.com/EttoreRocchi/combatlearn/actions/workflows/test.yaml)
+[](https://pepy.tech/projects/combatlearn)
+[](https://pypi.org/project/combatlearn/)
+[](https://github.com/EttoreRocchi/combatlearn/blob/main/LICENSE)
+
 <div align="center">
 <p><img src="https://raw.githubusercontent.com/EttoreRocchi/combatlearn/main/docs/logo.png" alt="combatlearn logo" width="350" /></p>
 </div>
@@ -50,7 +56,7 @@ Dynamic: license-file

 **Three methods**:
 - `method="johnson"` - classic ComBat (Johnson _et al._, 2007)
-- `method="fortin"` -
+- `method="fortin"` - neuroComBat (Fortin _et al._, 2018)
 - `method="chen"` - CovBat (Chen _et al._, 2022)

 ## Installation
@@ -107,10 +113,39 @@ print(f"Best CV AUROC: {grid.best_score_:.3f}")

 For a full example of how to use **combatlearn** see the [notebook demo](https://github.com/EttoreRocchi/combatlearn/blob/main/demo/combatlearn_demo.ipynb)

+## `ComBat` parameters
+
+The following section provides a detailed explanation of all parameters available in the scikit-learn-compatible `ComBat` class.
+
+### Main Parameters
+
+| Parameter | Type | Default | Description |
+| --- | --- | --- | --- |
+| `batch` | array-like or pd.Series | **required** | Vector indicating the batch assignment of each sample; used to estimate and remove batch effects. |
+| `discrete_covariates` | array-like, pd.Series, or pd.DataFrame | `None` | Optional categorical covariates (e.g., sex, site). Only used by the `"fortin"` and `"chen"` methods. |
+| `continuous_covariates` | array-like, pd.Series, or pd.DataFrame | `None` | Optional continuous covariates (e.g., age). Only used by the `"fortin"` and `"chen"` methods. |
+
+### Algorithm Options
+
+| Parameter | Type | Default | Description |
+| --- | --- | --- | --- |
+| `method` | str | `"johnson"` | ComBat variant to use: <ul><li>`"johnson"` - classical ComBat (_Johnson et al., 2007_)</li><li>`"fortin"` - ComBat with covariates (_Fortin et al., 2018_)</li><li>`"chen"` - CovBat, PCA-based correction (_Chen et al., 2022_)</li></ul> |
+| `parametric` | bool | `True` | Whether to use the **parametric empirical Bayes** formulation. If `False`, a non-parametric iterative scheme is used. |
+| `mean_only` | bool | `False` | If `True`, only the **mean** is corrected and variances are left unchanged. Useful for preserving the variance structure of the data. |
+| `reference_batch` | str or `None` | `None` | If specified, this batch acts as the reference: all other batches are corrected to match it. |
+| `covbat_cov_thresh` | float or int | `0.9` | For the `"chen"` method only: cumulative variance threshold in $]0,1[$ used to retain PCs in PCA space (e.g., 0.9 = retain 90% of the explained variance). If an integer is provided, it is taken as the number of principal components to use. |
+| `eps` | float | `1e-8` | Small jitter added to variances to prevent divide-by-zero errors during standardization. |
+
 ## Contributing

 Pull requests, bug reports, and feature ideas are welcome: feel free to open a PR!

+## Author
+
+[**Ettore Rocchi**](https://github.com/ettorerocchi) @ University of Bologna
+
+[Google Scholar](https://scholar.google.com/citations?user=MKHoGnQAAAAJ) $\cdot$ [Scopus](https://www.scopus.com/authid/detail.uri?authorId=57220152522)
+
 ## Acknowledgements

 This project builds on the excellent work of the ComBat family of harmonisation methods.
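The `covbat_cov_thresh` behaviour described in the parameters table above (a fraction in $]0,1[$ selects PCs by cumulative explained variance; an integer selects an explicit component count) can be illustrated with a small numpy sketch. This is an assumption-level illustration of the documented rule, not the library's internal implementation; `n_components_for` is a hypothetical helper name.

```python
import numpy as np

def n_components_for(thresh, explained_variance_ratio):
    """Map a covbat_cov_thresh-style value to a number of PCs.

    Fractions in ]0, 1[ select the smallest number of leading PCs whose
    cumulative explained variance reaches the threshold; integers are
    taken as an explicit component count.
    """
    if isinstance(thresh, int):
        return thresh
    cum = np.cumsum(explained_variance_ratio)
    # First index where the cumulative variance reaches the threshold.
    return int(np.searchsorted(cum, thresh) + 1)

ratios = np.array([0.5, 0.3, 0.15, 0.05])  # toy PCA explained-variance ratios
print(n_components_for(0.9, ratios))  # cumulative 0.5, 0.8, 0.95 -> 3 PCs
print(n_components_for(2, ratios))    # integer -> use exactly 2 PCs
```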
combatlearn-0.1.2/combatlearn.egg-info/SOURCES.txt

@@ -0,0 +1,11 @@
+LICENSE
+README.md
+pyproject.toml
+combatlearn/__init__.py
+combatlearn/combat.py
+combatlearn.egg-info/PKG-INFO
+combatlearn.egg-info/SOURCES.txt
+combatlearn.egg-info/dependency_links.txt
+combatlearn.egg-info/requires.txt
+combatlearn.egg-info/top_level.txt
+tests/test_combat.py
{combatlearn-0.1.0 → combatlearn-0.1.2}/pyproject.toml

@@ -4,10 +4,10 @@ build-backend = "setuptools.build_meta"

 [project]
 name = "combatlearn"
-version = "0.1.
-description = "Batch-effect
+version = "0.1.2"
+description = "Batch-effect harmonization for machine learning frameworks."
 authors = [{name="Ettore Rocchi", email="ettoreroc@gmail.com"}]
-requires-python = ">=3.
+requires-python = ">=3.10"
 dependencies = [
     "pandas>=1.3",
     "numpy>=1.21",
@@ -31,5 +31,5 @@ classifiers = [
 ]

 [tool.setuptools.packages.find]
-where = ["
+where = ["."]
 include = ["combatlearn*"]
combatlearn-0.1.0/src/combatlearn.egg-info/SOURCES.txt

@@ -1,11 +0,0 @@
-LICENSE
-README.md
-pyproject.toml
-src/combatlearn/__init__.py
-src/combatlearn/combat.py
-src/combatlearn.egg-info/PKG-INFO
-src/combatlearn.egg-info/SOURCES.txt
-src/combatlearn.egg-info/dependency_links.txt
-src/combatlearn.egg-info/requires.txt
-src/combatlearn.egg-info/top_level.txt
-tests/test_combat.py
File without changes (7 files)