combatlearn 1.0.0__tar.gz → 1.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {combatlearn-1.0.0 → combatlearn-1.1.0}/PKG-INFO +24 -17
- {combatlearn-1.0.0 → combatlearn-1.1.0}/README.md +23 -16
- combatlearn-1.1.0/combatlearn/__init__.py +5 -0
- {combatlearn-1.0.0 → combatlearn-1.1.0}/combatlearn/combat.py +742 -1
- {combatlearn-1.0.0 → combatlearn-1.1.0}/combatlearn.egg-info/PKG-INFO +24 -17
- {combatlearn-1.0.0 → combatlearn-1.1.0}/tests/test_combat.py +111 -1
- combatlearn-1.0.0/combatlearn/__init__.py +0 -5
- {combatlearn-1.0.0 → combatlearn-1.1.0}/LICENSE +0 -0
- {combatlearn-1.0.0 → combatlearn-1.1.0}/combatlearn.egg-info/SOURCES.txt +0 -0
- {combatlearn-1.0.0 → combatlearn-1.1.0}/combatlearn.egg-info/dependency_links.txt +0 -0
- {combatlearn-1.0.0 → combatlearn-1.1.0}/combatlearn.egg-info/requires.txt +0 -0
- {combatlearn-1.0.0 → combatlearn-1.1.0}/combatlearn.egg-info/top_level.txt +0 -0
- {combatlearn-1.0.0 → combatlearn-1.1.0}/pyproject.toml +0 -0
- {combatlearn-1.0.0 → combatlearn-1.1.0}/requirements.txt +0 -0
- {combatlearn-1.0.0 → combatlearn-1.1.0}/setup.cfg +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: combatlearn
|
|
3
|
-
Version: 1.0.0
|
|
3
|
+
Version: 1.1.0
|
|
4
4
|
Summary: Batch-effect harmonization for machine learning frameworks.
|
|
5
5
|
Author-email: Ettore Rocchi <ettoreroc@gmail.com>
|
|
6
6
|
License: MIT
|
|
@@ -57,8 +57,21 @@ Dynamic: license-file
|
|
|
57
57
|
pip install combatlearn
|
|
58
58
|
```
|
|
59
59
|
|
|
60
|
+
## Documentation
|
|
61
|
+
|
|
62
|
+
**Full documentation is available at [combatlearn.readthedocs.io](https://combatlearn.readthedocs.io)**
|
|
63
|
+
|
|
64
|
+
The documentation includes:
|
|
65
|
+
- [Installation Guide](https://combatlearn.readthedocs.io/en/latest/installation/)
|
|
66
|
+
- [Quick Start Tutorial](https://combatlearn.readthedocs.io/en/latest/quickstart/)
|
|
67
|
+
- [User Guide](https://combatlearn.readthedocs.io/en/latest/user-guide/overview/)
|
|
68
|
+
- [API Reference](https://combatlearn.readthedocs.io/en/latest/api/)
|
|
69
|
+
- [Examples](https://combatlearn.readthedocs.io/en/latest/examples/basic-usage/)
|
|
70
|
+
|
|
60
71
|
## Quick start
|
|
61
72
|
|
|
73
|
+
For more details, see the [Quick Start Tutorial](https://combatlearn.readthedocs.io/en/latest/quickstart/).
|
|
74
|
+
|
|
62
75
|
```python
|
|
63
76
|
import pandas as pd
|
|
64
77
|
from sklearn.pipeline import Pipeline
|
|
@@ -105,20 +118,9 @@ print(f"Best CV AUROC: {grid.best_score_:.3f}")
|
|
|
105
118
|
|
|
106
119
|
For a full example of how to use **combatlearn**, see the [notebook demo](https://github.com/EttoreRocchi/combatlearn/blob/main/docs/demo/combatlearn_demo.ipynb).
|
|
107
120
|
|
|
108
|
-
## Documentation
|
|
109
|
-
|
|
110
|
-
**Full documentation is available at [combatlearn.readthedocs.io](https://combatlearn.readthedocs.io)**
|
|
111
|
-
|
|
112
|
-
The documentation includes:
|
|
113
|
-
- [Installation Guide](https://combatlearn.readthedocs.io/en/latest/installation/)
|
|
114
|
-
- [Quick Start Tutorial](https://combatlearn.readthedocs.io/en/latest/quickstart/)
|
|
115
|
-
- [User Guide](https://combatlearn.readthedocs.io/en/latest/user-guide/overview/)
|
|
116
|
-
- [API Reference](https://combatlearn.readthedocs.io/en/latest/api/)
|
|
117
|
-
- [Examples](https://combatlearn.readthedocs.io/en/latest/examples/basic-usage/)
|
|
118
|
-
|
|
119
121
|
## `ComBat` parameters
|
|
120
122
|
|
|
121
|
-
The following section provides a detailed explanation of all parameters available in the scikit-learn-compatible `ComBat` class.
|
|
123
|
+
The following section provides a detailed explanation of all parameters available in the scikit-learn-compatible `ComBat` class. For complete API documentation, see the [API Reference](https://combatlearn.readthedocs.io/en/latest/api/).
|
|
122
124
|
|
|
123
125
|
### Main Parameters
|
|
124
126
|
|
|
@@ -140,11 +142,17 @@ The following section provides a detailed explanation of all parameters availabl
|
|
|
140
142
|
| `eps` | float | `1e-8` | Small jitter value added to variances to prevent divide-by-zero errors during standardization. |
|
|
141
143
|
|
|
142
144
|
|
|
143
|
-
### Batch Effect Correction Visualization
|
|
145
|
+
### Batch Effect Correction Visualization
|
|
144
146
|
|
|
145
147
|
The `plot_transformation` method lets you visualize the effect of the **ComBat** transformation: it shows a before/after comparison of the data transformed by `ComBat`, using PCA, t-SNE, or UMAP to reduce the dimensions for visualization.
|
|
146
148
|
|
|
147
|
-
For further details see the [notebook demo](https://github.com/EttoreRocchi/combatlearn/blob/main/docs/demo/combatlearn_demo.ipynb).
|
|
149
|
+
For further details see the [Visualization Guide](https://combatlearn.readthedocs.io/en/latest/user-guide/visualization/) and the [notebook demo](https://github.com/EttoreRocchi/combatlearn/blob/main/docs/demo/combatlearn_demo.ipynb).
|
|
150
|
+
|
|
151
|
+
### Batch Effect Metrics
|
|
152
|
+
|
|
153
|
+
The `compute_batch_metrics` method provides a quantitative assessment of batch-correction quality. It computes batch-effect metrics (silhouette coefficient, Davies-Bouldin index, kBET, LISI, and variance ratio) as well as structure-preservation metrics (k-NN preservation and distance correlation).
|
|
154
|
+
|
|
155
|
+
For further details see the [Metrics Guide](https://combatlearn.readthedocs.io/en/latest/user-guide/metrics/) and the [notebook demo](https://github.com/EttoreRocchi/combatlearn/blob/main/docs/demo/combatlearn_demo.ipynb).
|
|
148
156
|
|
|
149
157
|
## Contributing
|
|
150
158
|
|
|
@@ -167,8 +175,7 @@ We gratefully acknowledge:
|
|
|
167
175
|
|
|
168
176
|
## Citation
|
|
169
177
|
|
|
170
|
-
If **combatlearn** is useful in your research, please cite the original
|
|
171
|
-
papers:
|
|
178
|
+
If **combatlearn** is useful in your research, please cite the original papers:
|
|
172
179
|
|
|
173
180
|
- Johnson WE, Li C, Rabinovic A. Adjusting batch effects in microarray expression data using empirical Bayes methods. _Biostatistics_. 2007 Jan;8(1):118-27. doi: [10.1093/biostatistics/kxj037](https://doi.org/10.1093/biostatistics/kxj037)
|
|
174
181
|
|
|
@@ -24,8 +24,21 @@
|
|
|
24
24
|
pip install combatlearn
|
|
25
25
|
```
|
|
26
26
|
|
|
27
|
+
## Documentation
|
|
28
|
+
|
|
29
|
+
**Full documentation is available at [combatlearn.readthedocs.io](https://combatlearn.readthedocs.io)**
|
|
30
|
+
|
|
31
|
+
The documentation includes:
|
|
32
|
+
- [Installation Guide](https://combatlearn.readthedocs.io/en/latest/installation/)
|
|
33
|
+
- [Quick Start Tutorial](https://combatlearn.readthedocs.io/en/latest/quickstart/)
|
|
34
|
+
- [User Guide](https://combatlearn.readthedocs.io/en/latest/user-guide/overview/)
|
|
35
|
+
- [API Reference](https://combatlearn.readthedocs.io/en/latest/api/)
|
|
36
|
+
- [Examples](https://combatlearn.readthedocs.io/en/latest/examples/basic-usage/)
|
|
37
|
+
|
|
27
38
|
## Quick start
|
|
28
39
|
|
|
40
|
+
For more details, see the [Quick Start Tutorial](https://combatlearn.readthedocs.io/en/latest/quickstart/).
|
|
41
|
+
|
|
29
42
|
```python
|
|
30
43
|
import pandas as pd
|
|
31
44
|
from sklearn.pipeline import Pipeline
|
|
@@ -72,20 +85,9 @@ print(f"Best CV AUROC: {grid.best_score_:.3f}")
|
|
|
72
85
|
|
|
73
86
|
For a full example of how to use **combatlearn**, see the [notebook demo](https://github.com/EttoreRocchi/combatlearn/blob/main/docs/demo/combatlearn_demo.ipynb).
|
|
74
87
|
|
|
75
|
-
## Documentation
|
|
76
|
-
|
|
77
|
-
**Full documentation is available at [combatlearn.readthedocs.io](https://combatlearn.readthedocs.io)**
|
|
78
|
-
|
|
79
|
-
The documentation includes:
|
|
80
|
-
- [Installation Guide](https://combatlearn.readthedocs.io/en/latest/installation/)
|
|
81
|
-
- [Quick Start Tutorial](https://combatlearn.readthedocs.io/en/latest/quickstart/)
|
|
82
|
-
- [User Guide](https://combatlearn.readthedocs.io/en/latest/user-guide/overview/)
|
|
83
|
-
- [API Reference](https://combatlearn.readthedocs.io/en/latest/api/)
|
|
84
|
-
- [Examples](https://combatlearn.readthedocs.io/en/latest/examples/basic-usage/)
|
|
85
|
-
|
|
86
88
|
## `ComBat` parameters
|
|
87
89
|
|
|
88
|
-
The following section provides a detailed explanation of all parameters available in the scikit-learn-compatible `ComBat` class.
|
|
90
|
+
The following section provides a detailed explanation of all parameters available in the scikit-learn-compatible `ComBat` class. For complete API documentation, see the [API Reference](https://combatlearn.readthedocs.io/en/latest/api/).
|
|
89
91
|
|
|
90
92
|
### Main Parameters
|
|
91
93
|
|
|
@@ -107,11 +109,17 @@ The following section provides a detailed explanation of all parameters availabl
|
|
|
107
109
|
| `eps` | float | `1e-8` | Small jitter value added to variances to prevent divide-by-zero errors during standardization. |
|
|
108
110
|
|
|
109
111
|
|
|
110
|
-
### Batch Effect Correction Visualization
|
|
112
|
+
### Batch Effect Correction Visualization
|
|
111
113
|
|
|
112
114
|
The `plot_transformation` method lets you visualize the effect of the **ComBat** transformation: it shows a before/after comparison of the data transformed by `ComBat`, using PCA, t-SNE, or UMAP to reduce the dimensions for visualization.
|
|
113
115
|
|
|
114
|
-
For further details see the [notebook demo](https://github.com/EttoreRocchi/combatlearn/blob/main/docs/demo/combatlearn_demo.ipynb).
|
|
116
|
+
For further details see the [Visualization Guide](https://combatlearn.readthedocs.io/en/latest/user-guide/visualization/) and the [notebook demo](https://github.com/EttoreRocchi/combatlearn/blob/main/docs/demo/combatlearn_demo.ipynb).
|
|
117
|
+
|
|
118
|
+
### Batch Effect Metrics
|
|
119
|
+
|
|
120
|
+
The `compute_batch_metrics` method provides a quantitative assessment of batch-correction quality. It computes batch-effect metrics (silhouette coefficient, Davies-Bouldin index, kBET, LISI, and variance ratio) as well as structure-preservation metrics (k-NN preservation and distance correlation).
|
|
121
|
+
|
|
122
|
+
For further details see the [Metrics Guide](https://combatlearn.readthedocs.io/en/latest/user-guide/metrics/) and the [notebook demo](https://github.com/EttoreRocchi/combatlearn/blob/main/docs/demo/combatlearn_demo.ipynb).
|
|
115
123
|
|
|
116
124
|
## Contributing
|
|
117
125
|
|
|
@@ -134,8 +142,7 @@ We gratefully acknowledge:
|
|
|
134
142
|
|
|
135
143
|
## Citation
|
|
136
144
|
|
|
137
|
-
If **combatlearn** is useful in your research, please cite the original
|
|
138
|
-
papers:
|
|
145
|
+
If **combatlearn** is useful in your research, please cite the original papers:
|
|
139
146
|
|
|
140
147
|
- Johnson WE, Li C, Rabinovic A. Adjusting batch effects in microarray expression data using empirical Bayes methods. _Biostatistics_. 2007 Jan;8(1):118-27. doi: [10.1093/biostatistics/kxj037](https://doi.org/10.1093/biostatistics/kxj037)
|
|
141
148
|
|
|
@@ -16,10 +16,14 @@ import pandas as pd
|
|
|
16
16
|
from sklearn.base import BaseEstimator, TransformerMixin
|
|
17
17
|
from sklearn.decomposition import PCA
|
|
18
18
|
from sklearn.manifold import TSNE
|
|
19
|
+
from sklearn.neighbors import NearestNeighbors
|
|
20
|
+
from sklearn.metrics import silhouette_score, davies_bouldin_score
|
|
21
|
+
from scipy.stats import levene, spearmanr, chi2
|
|
22
|
+
from scipy.spatial.distance import pdist
|
|
19
23
|
import matplotlib
|
|
20
24
|
import matplotlib.pyplot as plt
|
|
21
25
|
import matplotlib.colors as mcolors
|
|
22
|
-
from typing import Literal, Optional, Union, Dict, Tuple, Any
|
|
26
|
+
from typing import Literal, Optional, Union, Dict, Tuple, Any, List
|
|
23
27
|
import numpy.typing as npt
|
|
24
28
|
import warnings
|
|
25
29
|
import umap
|
|
@@ -29,6 +33,526 @@ from plotly.subplots import make_subplots
|
|
|
29
33
|
# Containers accepted wherever the API takes tabular data or per-sample labels.
ArrayLike = Union[
    pd.DataFrame,
    pd.Series,
    npt.NDArray[Any],
]
FloatArray = npt.NDArray[np.float64]  # dense float64 NumPy array
|
|
31
35
|
|
|
36
|
+
def _compute_pca_embedding(
    X_before: np.ndarray,
    X_after: np.ndarray,
    n_components: int,
) -> Tuple[np.ndarray, np.ndarray, PCA]:
    """
    Project original and corrected data onto one shared PCA basis.

    The PCA model is learned from ``X_before`` only and then applied to
    ``X_after``, so both embeddings live in the same coordinate system.

    Parameters
    ----------
    X_before : np.ndarray
        Original data before correction.
    X_after : np.ndarray
        Corrected data.
    n_components : int
        Requested number of PCA components. It is capped at
        ``min(n_features, n_samples - 1)`` of ``X_before``.

    Returns
    -------
    X_before_pca : np.ndarray
        PCA-transformed original data.
    X_after_pca : np.ndarray
        PCA-transformed corrected data.
    pca : PCA
        Fitted PCA model.
    """
    cap = min(X_before.shape[1], X_before.shape[0] - 1)
    projector = PCA(n_components=min(n_components, cap), random_state=42)
    embedded_before = projector.fit_transform(X_before)
    # Reuse the basis learned on the original data; never refit on X_after.
    embedded_after = projector.transform(X_after)
    return embedded_before, embedded_after, projector
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def _silhouette_batch(X: np.ndarray, batch_labels: np.ndarray) -> float:
    """
    Compute the silhouette coefficient using batch labels as cluster labels.

    Interpretation of the score:

    * Values near 1 mean batches form separate clusters (strong batch effect).
    * Values near 0 mean samples are about as close to other batches as to
      their own (batches overlap).
    * Negative values mean samples sit closer to another batch than to their
      own.

    Lower values after correction therefore indicate better batch mixing.

    Parameters
    ----------
    X : np.ndarray
        Data matrix of shape (n_samples, n_features).
    batch_labels : np.ndarray
        Batch labels for each sample.

    Returns
    -------
    float
        Silhouette coefficient in [-1, 1]. Returns 0.0 when fewer than two
        batches are present, or when scikit-learn reports the score as
        undefined for these labels/data by raising ``ValueError`` (e.g. every
        sample in its own batch).
    """
    if len(np.unique(batch_labels)) < 2:
        return 0.0
    try:
        return float(silhouette_score(X, batch_labels, metric='euclidean'))
    except ValueError:
        # Undefined score: keep the neutral fallback. Other errors (e.g.
        # MemoryError) are no longer silently reported as "well mixed".
        return 0.0
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
def _davies_bouldin_batch(X: np.ndarray, batch_labels: np.ndarray) -> float:
    """
    Compute the Davies-Bouldin index using batch labels as cluster labels.

    The index is small when batches form compact, well-separated clusters and
    grows as batch centroids move together relative to their spread. Higher
    values after correction therefore indicate better batch mixing.
    Range: [0, inf).

    Parameters
    ----------
    X : np.ndarray
        Data matrix of shape (n_samples, n_features).
    batch_labels : np.ndarray
        Batch labels for each sample.

    Returns
    -------
    float
        Davies-Bouldin index. Returns 0.0 in two cases:

        * fewer than two batches are present;
        * scikit-learn cannot compute the score and raises ``ValueError``
          (e.g. every sample in its own batch).

        NOTE(review): 0.0 is also the value for perfectly separated batches,
        so callers comparing before/after values should account for this
        fallback.
    """
    if len(np.unique(batch_labels)) < 2:
        return 0.0
    try:
        return float(davies_bouldin_score(X, batch_labels))
    except ValueError:
        # Undefined score: keep the historical fallback instead of crashing,
        # but let unexpected errors (e.g. MemoryError) propagate.
        return 0.0
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
def _kbet_score(
    X: np.ndarray,
    batch_labels: np.ndarray,
    k0: int,
    alpha: float = 0.05,
) -> Tuple[float, float]:
    """
    Compute kBET (k-nearest neighbor Batch Effect Test) acceptance rate.

    For every sample, the test compares the batch composition of its ``k0``
    nearest neighbors (the sample itself excluded) with the global batch
    proportions. It uses Pearson's chi-squared test with ``n_batches - 1``
    degrees of freedom. Higher acceptance rate = better batch mixing.

    Reference: Buttner et al. (2019) Nature Methods

    Parameters
    ----------
    X : np.ndarray
        Data matrix of shape (n_samples, n_features).
    batch_labels : np.ndarray
        Batch labels for each sample.
    k0 : int
        Neighborhood size; capped at ``n_samples - 1``.
    alpha : float
        Significance level for chi-squared test.

    Returns
    -------
    acceptance_rate : float
        Fraction of samples whose p-value exceeds ``alpha``, i.e. where H0
        (local proportions match global proportions) is accepted.
    mean_stat : float
        Mean chi-squared statistic across samples.

    ``(1.0, 0.0)`` is returned when fewer than two batches are present or
    when the (capped) neighborhood size is below 1.
    """
    batch_labels = np.asarray(batch_labels)
    n_samples = X.shape[0]
    unique_batches, codes, batch_counts = np.unique(
        batch_labels, return_inverse=True, return_counts=True
    )
    n_batches = len(unique_batches)

    if n_batches < 2:
        return 1.0, 0.0

    k0 = min(k0, n_samples - 1)
    if k0 < 1:
        return 1.0, 0.0

    # kneighbors() without a query excludes each indexed point from its own
    # neighbor list. Slicing off column 0 of kneighbors(X) is wrong when X
    # has duplicate rows, because a duplicate can be returned before self.
    nn = NearestNeighbors(n_neighbors=k0, algorithm='auto')
    nn.fit(X)
    indices = nn.kneighbors(return_distance=False)

    # observed[i, b] = number of neighbors of sample i that belong to batch b.
    neighbor_codes = codes[indices]
    observed = np.column_stack(
        [(neighbor_codes == b).sum(axis=1) for b in range(n_batches)]
    ).astype(float)

    # Every batch has at least one sample, so all expected counts are > 0.
    expected = (batch_counts / n_samples) * k0
    stats = np.sum((observed - expected) ** 2 / expected, axis=1)
    # The survival function is more accurate than 1 - cdf in the upper tail.
    p_values = chi2.sf(stats, n_batches - 1)

    return float(np.mean(p_values > alpha)), float(np.mean(stats))
|
|
205
|
+
|
|
206
|
+
|
|
207
|
+
def _find_sigma(distances: np.ndarray, target_perplexity: float, tol: float = 1e-5) -> float:
    """
    Binary search for the Gaussian bandwidth ``sigma`` reaching a target perplexity.

    Used in LISI computation. The kernel is
    ``P_j ∝ exp(-d_j**2 / (2 * sigma**2))``, normalized to sum to 1. Its
    perplexity is ``2**H``, where ``H`` is the Shannon entropy of ``P`` in bits.
    ``sigma`` is adjusted until ``H`` is within ``tol`` of
    ``log2(target_perplexity)``.

    Parameters
    ----------
    distances : np.ndarray
        1-D array of distances to neighbors.
    target_perplexity : float
        Target perplexity value.
    tol : float
        Tolerance on the entropy (in bits) for convergence.

    Returns
    -------
    float
        Sigma value, searched within [1e-10, 1e10]. Returns 1.0 for an empty
        ``distances`` array.
    """
    distances = np.asarray(distances, dtype=float)
    if distances.size == 0:
        return 1.0

    target_H = np.log2(target_perplexity + 1e-10)

    # Subtracting the smallest squared distance leaves the normalized kernel
    # unchanged, but keeps its largest entry at exp(0) = 1. Without the shift,
    # P underflows whenever every neighbor is far away relative to the spread
    # of distances (common in high-dimensional data), and the target
    # perplexity could never be reached.
    sq_dist = distances ** 2
    sq_dist = sq_dist - sq_dist.min()

    sigma_min, sigma_max = 1e-10, 1e10
    sigma = 1.0

    for _ in range(50):
        P = np.exp(-sq_dist / (2 * sigma**2 + 1e-10))
        P_sum = P.sum()
        if not np.isfinite(P_sum) or P_sum < 1e-10:
            # Degenerate kernel (e.g. non-finite distances): widen it.
            sigma_min = sigma
            sigma = np.sqrt(sigma_min * sigma_max)
            continue
        P = np.clip(P / P_sum, 1e-10, 1.0)
        H = -np.sum(P * np.log2(P))

        if abs(H - target_H) < tol:
            break
        # Entropy grows with sigma: too low -> widen, too high -> narrow.
        if H < target_H:
            sigma_min = sigma
        else:
            sigma_max = sigma
        # Bisect in log-space: the bracket spans 20 orders of magnitude.
        sigma = np.sqrt(sigma_min * sigma_max)

    return float(sigma)
|
|
251
|
+
|
|
252
|
+
|
|
253
|
+
def _lisi_score(
    X: np.ndarray,
    batch_labels: np.ndarray,
    perplexity: int = 30,
) -> float:
    """
    Compute mean Local Inverse Simpson's Index (LISI).

    For each sample, a Gaussian kernel is placed over its ``3 * perplexity``
    nearest neighbors (capped at ``n_samples - 1``; the sample itself
    excluded). The kernel bandwidth comes from ``_find_sigma``. LISI is the
    inverse Simpson index of the kernel-weighted batch proportions.

    Range: [1, n_batches], where n_batches = perfect mixing.
    Higher = better batch mixing.

    Reference: Korsunsky et al. (2019) Nature Methods (Harmony paper)

    Parameters
    ----------
    X : np.ndarray
        Data matrix of shape (n_samples, n_features).
    batch_labels : np.ndarray
        Batch labels for each sample.
    perplexity : int
        Perplexity for Gaussian kernel.

    Returns
    -------
    float
        Mean LISI score. Returns 1.0 when fewer than two batches are present
        or when the neighborhood size is below 1.
    """
    batch_labels = np.asarray(batch_labels)
    n_samples = X.shape[0]
    unique_batches, codes = np.unique(batch_labels, return_inverse=True)
    n_batches = len(unique_batches)

    if n_batches < 2:
        return 1.0

    k = min(3 * perplexity, n_samples - 1)
    if k < 1:
        return 1.0

    # Query without X so each sample is excluded from its own neighbor list
    # by index, which is correct even when X contains duplicate rows.
    nn = NearestNeighbors(n_neighbors=k, algorithm='auto')
    nn.fit(X)
    distances, indices = nn.kneighbors()

    lisi_values = np.empty(n_samples)
    for i in range(n_samples):
        sigma = _find_sigma(distances[i], perplexity)

        # Shift by the nearest squared distance: the normalized kernel is
        # unchanged, but its largest weight is exp(0) = 1, so it cannot
        # underflow when every neighbor is far away.
        sq_dist = distances[i] ** 2
        P = np.exp(-(sq_dist - sq_dist.min()) / (2 * sigma**2 + 1e-10))
        P_sum = P.sum()
        if not np.isfinite(P_sum) or P_sum < 1e-10:
            lisi_values[i] = 1.0
            continue
        P = P / P_sum

        # Kernel-weighted probability of each batch in the neighborhood.
        batch_probs = np.bincount(codes[indices[i]], weights=P, minlength=n_batches)
        simpson = np.sum(batch_probs ** 2)
        lisi_values[i] = n_batches if simpson < 1e-10 else 1.0 / simpson

    return float(np.mean(lisi_values))
|
|
322
|
+
|
|
323
|
+
|
|
324
|
+
def _variance_ratio(X: np.ndarray, batch_labels: np.ndarray) -> float:
    """
    Compute between-batch to within-batch variance ratio.

    Pooled over all features, this is the one-way ANOVA F-statistic:
    ``(SS_between / (n_batches - 1)) / (SS_within / (n_samples - n_batches))``.
    Lower ratio after correction = better batch effect removal.

    Parameters
    ----------
    X : np.ndarray
        Data matrix of shape (n_samples, n_features).
    batch_labels : np.ndarray
        Batch labels for each sample.

    Returns
    -------
    float
        Variance ratio (between/within). Special cases:

        * 0.0 for fewer than two batches, or when both variances vanish.
        * ``inf`` when batches have distinct means but no within-batch spread
          (complete batch separation).
        * ``nan`` when ``n_samples <= n_batches`` (within-batch variance has
          no degrees of freedom).
    """
    batch_labels = np.asarray(batch_labels)
    unique_batches = np.unique(batch_labels)
    n_batches = len(unique_batches)
    n_samples = X.shape[0]

    if n_batches < 2:
        return 0.0
    if n_samples <= n_batches:
        # Previously this silently produced NaN via 0/0 with a RuntimeWarning.
        return float('nan')

    grand_mean = np.mean(X, axis=0)
    ss_between = 0.0
    ss_within = 0.0
    for batch in unique_batches:
        X_batch = X[batch_labels == batch]
        batch_mean = np.mean(X_batch, axis=0)
        ss_between += X_batch.shape[0] * np.sum((batch_mean - grand_mean) ** 2)
        ss_within += np.sum((X_batch - batch_mean) ** 2)

    between_var = ss_between / (n_batches - 1)
    within_var = ss_within / (n_samples - n_batches)

    if within_var < 1e-10:
        # No within-batch spread: any between-batch difference is a perfect
        # batch effect, so report inf rather than the misleading 0.0.
        return float('inf') if between_var >= 1e-10 else 0.0
    return float(between_var / within_var)
|
|
371
|
+
|
|
372
|
+
|
|
373
|
+
def _knn_preservation(
    X_before: np.ndarray,
    X_after: np.ndarray,
    k_values: List[int],
    n_jobs: int = 1,
) -> Dict[int, float]:
    """
    Compute fraction of k-nearest neighbors preserved after correction.

    For each sample, the sets of its k nearest neighbors (itself excluded) in
    ``X_before`` and ``X_after`` are intersected. The overlap fraction is then
    averaged over samples.

    Range: [0, 1], where 1 = perfect preservation.
    Higher = better biological structure preservation.

    Parameters
    ----------
    X_before : np.ndarray
        Original data.
    X_after : np.ndarray
        Corrected data.
    k_values : list of int
        Values of k for k-NN; each must be a positive integer.
    n_jobs : int
        Number of parallel jobs.

    Returns
    -------
    dict
        Mapping from k to preservation fraction. A ``k`` larger than
        ``n_samples - 1`` cannot be evaluated and is reported as 0.0. An empty
        ``k_values`` yields an empty dict.

    Raises
    ------
    ValueError
        If any value in ``k_values`` is smaller than 1.
    """
    k_values = list(k_values)
    if not k_values:
        return {}
    if any(k < 1 for k in k_values):
        raise ValueError(f"k_values must contain positive integers, got {k_values}")

    max_k = min(max(k_values), X_before.shape[0] - 1)
    if max_k < 1:
        return {k: 0.0 for k in k_values}

    def _neighbor_indices(X: np.ndarray) -> np.ndarray:
        # Indices of the max_k nearest neighbors of every row of X.
        nn = NearestNeighbors(n_neighbors=max_k, algorithm='auto', n_jobs=n_jobs)
        nn.fit(X)
        # No query argument: each point is excluded from its own neighbor
        # list by index, which stays correct when X has duplicate rows.
        return nn.kneighbors(return_distance=False)

    indices_before = _neighbor_indices(X_before)
    indices_after = _neighbor_indices(X_after)

    results = {}
    for k in k_values:
        if k > max_k:
            results[k] = 0.0
            continue
        overlaps = [
            len(set(row_before[:k]) & set(row_after[:k])) / k
            for row_before, row_after in zip(indices_before, indices_after)
        ]
        results[k] = float(np.mean(overlaps))

    return results
|
|
428
|
+
|
|
429
|
+
|
|
430
|
+
def _pairwise_distance_correlation(
    X_before: np.ndarray,
    X_after: np.ndarray,
    subsample: int = 1000,
    random_state: int = 42,
) -> float:
    """
    Rank-correlate pairwise Euclidean distances before and after correction.

    Spearman's rho is computed between the condensed pairwise-distance vectors
    of the two datasets. A value of 1 means the ordering of all pairwise
    distances is unchanged. Range: [-1, 1]; higher means relative
    relationships are better preserved.

    Parameters
    ----------
    X_before : np.ndarray
        Original data.
    X_after : np.ndarray
        Corrected data.
    subsample : int
        Maximum number of samples used. If there are more, the same random
        subset (without replacement) is taken from both arrays.
    random_state : int
        Seed for the subsampling generator.

    Returns
    -------
    float
        Spearman correlation coefficient. Returns 1.0 when there are fewer
        than two samples or when the correlation is NaN.
    """
    total = X_before.shape[0]
    if total > subsample:
        picked = np.random.default_rng(random_state).choice(total, subsample, replace=False)
        X_before, X_after = X_before[picked], X_after[picked]

    d_before = pdist(X_before, metric='euclidean')
    d_after = pdist(X_after, metric='euclidean')

    if len(d_before) == 0:
        return 1.0

    rho = spearmanr(d_before, d_after)[0]
    return 1.0 if np.isnan(rho) else rho
|
|
478
|
+
|
|
479
|
+
|
|
480
|
+
def _mean_centroid_distance(X: np.ndarray, batch_labels: np.ndarray) -> float:
    """
    Average Euclidean distance between every pair of batch centroids.

    A smaller value after correction means the batch means were pulled
    closer together, i.e. better batch alignment.

    Parameters
    ----------
    X : np.ndarray
        Data matrix.
    batch_labels : np.ndarray
        Batch labels for each sample.

    Returns
    -------
    float
        Mean pairwise centroid distance; 0.0 when fewer than two batches exist.
    """
    batches = np.unique(batch_labels)
    if len(batches) < 2:
        return 0.0

    centroid_matrix = np.array(
        [np.mean(X[batch_labels == b], axis=0) for b in batches]
    )
    return np.mean(pdist(centroid_matrix, metric='euclidean'))
|
|
514
|
+
|
|
515
|
+
|
|
516
|
+
def _levene_median_statistic(X: np.ndarray, batch_labels: np.ndarray) -> float:
    """
    Median, over features, of the Brown-Forsythe (median-centred Levene) statistic.

    For each column of ``X``, the values are grouped by batch and
    ``scipy.stats.levene(..., center='median')`` is applied. Features whose
    test fails or yields NaN are skipped. A lower statistic means variances
    are more homogeneous across batches.

    Parameters
    ----------
    X : np.ndarray
        Data matrix of shape (n_samples, n_features).
    batch_labels : np.ndarray
        Batch labels for each sample.

    Returns
    -------
    float
        Median Levene statistic; 0.0 when fewer than two batches exist or no
        feature produced a usable statistic.
    """
    batches = np.unique(batch_labels)
    if len(batches) < 2:
        return 0.0

    collected = []
    for col in range(X.shape[1]):
        samples = [
            values
            for values in (X[batch_labels == b, col] for b in batches)
            if len(values) > 0
        ]
        if len(samples) < 2:
            continue
        try:
            value = levene(*samples, center='median')[0]
        except Exception:
            continue
        else:
            if not np.isnan(value):
                collected.append(value)

    return np.median(collected) if collected else 0.0
|
|
555
|
+
|
|
32
556
|
|
|
33
557
|
class ComBatModel:
|
|
34
558
|
"""ComBat algorithm.
|
|
@@ -643,6 +1167,7 @@ class ComBat(BaseEstimator, TransformerMixin):
|
|
|
643
1167
|
reference_batch: Optional[str] = None,
|
|
644
1168
|
eps: float = 1e-8,
|
|
645
1169
|
covbat_cov_thresh: Union[float, int] = 0.9,
|
|
1170
|
+
compute_metrics: bool = False,
|
|
646
1171
|
) -> None:
|
|
647
1172
|
self.batch = batch
|
|
648
1173
|
self.discrete_covariates = discrete_covariates
|
|
@@ -653,6 +1178,7 @@ class ComBat(BaseEstimator, TransformerMixin):
|
|
|
653
1178
|
self.reference_batch = reference_batch
|
|
654
1179
|
self.eps = eps
|
|
655
1180
|
self.covbat_cov_thresh = covbat_cov_thresh
|
|
1181
|
+
self.compute_metrics = compute_metrics
|
|
656
1182
|
self._model = ComBatModel(
|
|
657
1183
|
method=method,
|
|
658
1184
|
parametric=parametric,
|
|
@@ -710,6 +1236,221 @@ class ComBat(BaseEstimator, TransformerMixin):
|
|
|
710
1236
|
else:
|
|
711
1237
|
return pd.DataFrame(obj, index=idx)
|
|
712
1238
|
|
|
1239
|
+
@property
def metrics_(self) -> Optional[Dict[str, Any]]:
    """Metrics cached by the most recent ``fit_transform`` with ``compute_metrics=True``.

    Returns
    -------
    dict or None
        The cached metrics dictionary stored in ``_metrics_cache``, or None
        when that attribute has not been set on this instance.
    """
    try:
        return self._metrics_cache
    except AttributeError:
        return None
|
|
1249
|
+
|
|
1250
|
+
def compute_batch_metrics(
    self,
    X: ArrayLike,
    batch: Optional[ArrayLike] = None,
    *,
    pca_components: Optional[int] = None,
    k_neighbors: Optional[List[int]] = None,
    kbet_k0: Optional[int] = None,
    lisi_perplexity: int = 30,
    n_jobs: int = 1,
) -> Dict[str, Any]:
    """
    Compute batch effect metrics before and after ComBat correction.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
        Input data to evaluate.
    batch : array-like of shape (n_samples,), optional
        Batch labels. If None, uses the batch stored at construction.
        pandas objects are aligned to the rows of X by index; any other
        array-like is assumed to already be in the row order of X.
    pca_components : int, optional
        Number of PCA components for dimensionality reduction before
        computing metrics. If None (default), metrics are computed in
        the original feature space. Must be a positive integer strictly
        less than min(n_samples, n_features).
    k_neighbors : list of int, optional
        Values of k for the k-NN preservation metric. Defaults to
        [5, 10, 50] when None.
    kbet_k0 : int, optional
        Neighborhood size for kBET. Default is max(10, 10% of samples).
    lisi_perplexity : int, default=30
        Perplexity for LISI computation.
    n_jobs : int, default=1
        Number of parallel jobs for neighbor computations.

    Returns
    -------
    metrics : dict
        Dictionary with structure:
        {
            'batch_effect': {
                'silhouette': {'before': float, 'after': float},
                'davies_bouldin': {...},
                'kbet': {...},
                'lisi': {..., 'max_value': n_batches},
                'variance_ratio': {...},
            },
            'preservation': {
                'knn': {k: fraction for k in k_neighbors},
                'distance_correlation': float,
            },
            'alignment': {
                'centroid_distance': {...},
                'levene_statistic': {...},
            },
        }

    Raises
    ------
    ValueError
        If the model is not fitted or if pca_components is invalid.
    """
    if not hasattr(self._model, "_gamma_star"):
        raise ValueError(
            "This ComBat instance is not fitted yet. "
            "Call 'fit' before 'compute_batch_metrics'."
        )

    # A fresh list per call avoids sharing one mutable default object
    # across every invocation of this method.
    if k_neighbors is None:
        k_neighbors = [5, 10, 50]

    if not isinstance(X, pd.DataFrame):
        X = pd.DataFrame(X)
    idx = X.index

    if batch is None:
        batch_vec = self._subset(self.batch, idx)
    elif isinstance(batch, (pd.Series, pd.DataFrame)):
        # pandas labels are aligned to the rows of X by index.
        batch_vec = batch.loc[idx]
    else:
        # Plain arrays/lists are taken positionally in X's row order.
        batch_vec = pd.Series(batch, index=idx)

    batch_labels = np.array(batch_vec)

    X_before = X.values
    n_samples, n_features = X_before.shape

    # Validate PCA settings before running the (comparatively expensive)
    # transform so that bad arguments fail fast.
    if pca_components is not None:
        if pca_components < 1:
            raise ValueError(
                f"pca_components={pca_components} must be a positive integer."
            )
        max_components = min(n_samples, n_features)
        if pca_components >= max_components:
            raise ValueError(
                f"pca_components={pca_components} must be less than "
                f"min(n_samples, n_features)={max_components}."
            )

    X_after = self.transform(X).values

    if kbet_k0 is None:
        kbet_k0 = max(10, int(0.10 * n_samples))

    if pca_components is not None:
        X_before_pca, X_after_pca, _ = _compute_pca_embedding(
            X_before, X_after, pca_components
        )
    else:
        X_before_pca = X_before
        X_after_pca = X_after

    silhouette_before = _silhouette_batch(X_before_pca, batch_labels)
    silhouette_after = _silhouette_batch(X_after_pca, batch_labels)

    db_before = _davies_bouldin_batch(X_before_pca, batch_labels)
    db_after = _davies_bouldin_batch(X_after_pca, batch_labels)

    kbet_before, _ = _kbet_score(X_before_pca, batch_labels, kbet_k0)
    kbet_after, _ = _kbet_score(X_after_pca, batch_labels, kbet_k0)

    lisi_before = _lisi_score(X_before_pca, batch_labels, lisi_perplexity)
    lisi_after = _lisi_score(X_after_pca, batch_labels, lisi_perplexity)

    var_ratio_before = _variance_ratio(X_before_pca, batch_labels)
    var_ratio_after = _variance_ratio(X_after_pca, batch_labels)

    knn_results = _knn_preservation(X_before_pca, X_after_pca, k_neighbors, n_jobs)
    dist_corr = _pairwise_distance_correlation(X_before_pca, X_after_pca)

    centroid_before = _mean_centroid_distance(X_before_pca, batch_labels)
    centroid_after = _mean_centroid_distance(X_after_pca, batch_labels)

    # Levene is evaluated on the original features (not the PCA embedding),
    # as the code always passes X_before / X_after here.
    levene_before = _levene_median_statistic(X_before, batch_labels)
    levene_after = _levene_median_statistic(X_after, batch_labels)

    n_batches = len(np.unique(batch_labels))

    metrics = {
        'batch_effect': {
            'silhouette': {
                'before': silhouette_before,
                'after': silhouette_after,
            },
            'davies_bouldin': {
                'before': db_before,
                'after': db_after,
            },
            'kbet': {
                'before': kbet_before,
                'after': kbet_after,
            },
            'lisi': {
                'before': lisi_before,
                'after': lisi_after,
                'max_value': n_batches,
            },
            'variance_ratio': {
                'before': var_ratio_before,
                'after': var_ratio_after,
            },
        },
        'preservation': {
            'knn': knn_results,
            'distance_correlation': dist_corr,
        },
        'alignment': {
            'centroid_distance': {
                'before': centroid_before,
                'after': centroid_after,
            },
            'levene_statistic': {
                'before': levene_before,
                'after': levene_after,
            },
        },
    }

    return metrics
|
|
1422
|
+
|
|
1423
|
+
def fit_transform(
    self,
    X: ArrayLike,
    y: Optional[ArrayLike] = None
) -> pd.DataFrame:
    """
    Fit and transform the data, optionally computing metrics.

    If compute_metrics=True was set at construction, batch effect
    metrics are computed and cached in the metrics_ property. Otherwise
    any metrics cached by an earlier call are cleared, so metrics_
    never reports results that belong to a different fit.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
        Input data to fit and transform.
    y : None
        Ignored. Present for API compatibility.

    Returns
    -------
    X_transformed : pd.DataFrame
        Batch-corrected data.
    """
    # Invalidate metrics from a previous fit before refitting: they no
    # longer describe the model that is about to be produced.
    self._metrics_cache = None

    self.fit(X, y)
    X_transformed = self.transform(X)

    if self.compute_metrics:
        self._metrics_cache = self.compute_batch_metrics(X)

    return X_transformed
|
|
1453
|
+
|
|
713
1454
|
def plot_transformation(
|
|
714
1455
|
self,
|
|
715
1456
|
X: ArrayLike, *,
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: combatlearn
|
|
3
|
-
Version: 1.
|
|
3
|
+
Version: 1.1.0
|
|
4
4
|
Summary: Batch-effect harmonization for machine learning frameworks.
|
|
5
5
|
Author-email: Ettore Rocchi <ettoreroc@gmail.com>
|
|
6
6
|
License: MIT
|
|
@@ -57,8 +57,21 @@ Dynamic: license-file
|
|
|
57
57
|
pip install combatlearn
|
|
58
58
|
```
|
|
59
59
|
|
|
60
|
+
## Documentation
|
|
61
|
+
|
|
62
|
+
**Full documentation is available at [combatlearn.readthedocs.io](https://combatlearn.readthedocs.io)**
|
|
63
|
+
|
|
64
|
+
The documentation includes:
|
|
65
|
+
- [Installation Guide](https://combatlearn.readthedocs.io/en/latest/installation/)
|
|
66
|
+
- [Quick Start Tutorial](https://combatlearn.readthedocs.io/en/latest/quickstart/)
|
|
67
|
+
- [User Guide](https://combatlearn.readthedocs.io/en/latest/user-guide/overview/)
|
|
68
|
+
- [API Reference](https://combatlearn.readthedocs.io/en/latest/api/)
|
|
69
|
+
- [Examples](https://combatlearn.readthedocs.io/en/latest/examples/basic-usage/)
|
|
70
|
+
|
|
60
71
|
## Quick start
|
|
61
72
|
|
|
73
|
+
For more details, see the [Quick Start Tutorial](https://combatlearn.readthedocs.io/en/latest/quickstart/).
|
|
74
|
+
|
|
62
75
|
```python
|
|
63
76
|
import pandas as pd
|
|
64
77
|
from sklearn.pipeline import Pipeline
|
|
@@ -105,20 +118,9 @@ print(f"Best CV AUROC: {grid.best_score_:.3f}")
|
|
|
105
118
|
|
|
106
119
|
For a full example of how to use **combatlearn** see the [notebook demo](https://github.com/EttoreRocchi/combatlearn/blob/main/docs/demo/combatlearn_demo.ipynb)
|
|
107
120
|
|
|
108
|
-
## Documentation
|
|
109
|
-
|
|
110
|
-
**Full documentation is available at [combatlearn.readthedocs.io](https://combatlearn.readthedocs.io)**
|
|
111
|
-
|
|
112
|
-
The documentation includes:
|
|
113
|
-
- [Installation Guide](https://combatlearn.readthedocs.io/en/latest/installation/)
|
|
114
|
-
- [Quick Start Tutorial](https://combatlearn.readthedocs.io/en/latest/quickstart/)
|
|
115
|
-
- [User Guide](https://combatlearn.readthedocs.io/en/latest/user-guide/overview/)
|
|
116
|
-
- [API Reference](https://combatlearn.readthedocs.io/en/latest/api/)
|
|
117
|
-
- [Examples](https://combatlearn.readthedocs.io/en/latest/examples/basic-usage/)
|
|
118
|
-
|
|
119
121
|
## `ComBat` parameters
|
|
120
122
|
|
|
121
|
-
The following section provides a detailed explanation of all parameters available in the scikit-learn-compatible `ComBat` class.
|
|
123
|
+
The following section provides a detailed explanation of all parameters available in the scikit-learn-compatible `ComBat` class. For complete API documentation, see the [API Reference](https://combatlearn.readthedocs.io/en/latest/api/).
|
|
122
124
|
|
|
123
125
|
### Main Parameters
|
|
124
126
|
|
|
@@ -140,11 +142,17 @@ The following section provides a detailed explanation of all parameters availabl
|
|
|
140
142
|
| `eps` | float | `1e-8` | Small jitter value added to variances to prevent divide-by-zero errors during standardization. |
|
|
141
143
|
|
|
142
144
|
|
|
143
|
-
### Batch Effect Correction Visualization
|
|
145
|
+
### Batch Effect Correction Visualization
|
|
144
146
|
|
|
145
147
|
The `plot_transformation` method lets you visualize the effect of the **ComBat** transformation. It shows a before/after comparison of the data transformed by `ComBat`, using PCA, t-SNE, or UMAP to reduce dimensions for visualization.
|
|
146
148
|
|
|
147
|
-
For further details see the [notebook demo](https://github.com/EttoreRocchi/combatlearn/blob/main/docs/demo/combatlearn_demo.ipynb).
|
|
149
|
+
For further details see the [Visualization Guide](https://combatlearn.readthedocs.io/en/latest/user-guide/visualization/) and the [notebook demo](https://github.com/EttoreRocchi/combatlearn/blob/main/docs/demo/combatlearn_demo.ipynb).
|
|
150
|
+
|
|
151
|
+
### Batch Effect Metrics
|
|
152
|
+
|
|
153
|
+
The `compute_batch_metrics` method provides a quantitative assessment of batch correction quality. It computes metrics including Silhouette coefficient, Davies-Bouldin index, kBET, LISI, and variance ratio for batch effect quantification, as well as k-NN preservation and distance correlation for structure preservation.
|
|
154
|
+
|
|
155
|
+
For further details see the [Metrics Guide](https://combatlearn.readthedocs.io/en/latest/user-guide/metrics/) and the [notebook demo](https://github.com/EttoreRocchi/combatlearn/blob/main/docs/demo/combatlearn_demo.ipynb).
|
|
148
156
|
|
|
149
157
|
## Contributing
|
|
150
158
|
|
|
@@ -167,8 +175,7 @@ We gratefully acknowledge:
|
|
|
167
175
|
|
|
168
176
|
## Citation
|
|
169
177
|
|
|
170
|
-
If **combatlearn** is useful in your research, please cite the original
|
|
171
|
-
papers:
|
|
178
|
+
If **combatlearn** is useful in your research, please cite the original papers:
|
|
172
179
|
|
|
173
180
|
- Johnson WE, Li C, Rabinovic A. Adjusting batch effects in microarray expression data using empirical Bayes methods. _Biostatistics_. 2007 Jan;8(1):118-27. doi: [10.1093/biostatistics/kxj037](https://doi.org/10.1093/biostatistics/kxj037)
|
|
174
181
|
|
|
@@ -376,4 +376,114 @@ def test_invalid_method_raises():
|
|
|
376
376
|
"""
|
|
377
377
|
X, batch = simulate_data()
|
|
378
378
|
with pytest.raises(ValueError, match="method must be"):
|
|
379
|
-
ComBatModel(method="invalid").fit(X, batch=batch)
|
|
379
|
+
ComBatModel(method="invalid").fit(X, batch=batch)
|
|
380
|
+
|
|
381
|
+
|
|
382
|
+
def test_compute_metrics_caches_in_metrics_property():
    """
    Test that compute_metrics=True caches metrics in metrics_ property.
    """
    X, batch = simulate_data(n_samples=100, n_features=20)

    combat = ComBat(batch=batch, method="johnson", compute_metrics=True)
    X_corr = combat.fit_transform(X)

    # Computing metrics must not change the shape of the corrected output.
    assert X_corr.shape == X.shape

    metrics = combat.metrics_
    assert metrics is not None
    for section in ('batch_effect', 'preservation', 'alignment'):
        assert section in metrics
|
|
395
|
+
|
|
396
|
+
|
|
397
|
+
def test_compute_metrics_false_returns_none():
    """
    Test that metrics_ is None when compute_metrics=False.
    """
    X, batch = simulate_data(n_samples=100, n_features=20)

    combat = ComBat(batch=batch, method="johnson", compute_metrics=False)
    X_corr = combat.fit_transform(X)

    # The transform still runs normally; only the metrics cache stays empty.
    assert X_corr.shape == X.shape
    assert combat.metrics_ is None
|
|
407
|
+
|
|
408
|
+
|
|
409
|
+
def test_compute_batch_metrics_returns_correct_structure():
    """
    Check the nested layout of the dictionary returned by compute_batch_metrics.
    """
    X, batch = simulate_data(n_samples=100, n_features=20)

    model = ComBat(batch=batch, method="johnson")
    model.fit(X)

    result = model.compute_batch_metrics(X, k_neighbors=[5, 10])

    # Batch-effect section: every metric carries a before/after pair.
    assert 'batch_effect' in result
    batch_effect = result['batch_effect']
    expected_effect_metrics = [
        'silhouette', 'davies_bouldin', 'kbet', 'lisi', 'variance_ratio',
    ]
    for name in expected_effect_metrics:
        assert name in batch_effect
        assert {'before', 'after'} <= set(batch_effect[name])

    # Preservation section: one k-NN entry per requested k.
    assert 'preservation' in result
    preservation = result['preservation']
    assert 'knn' in preservation
    assert all(k in preservation['knn'] for k in (5, 10))
    assert 'distance_correlation' in preservation

    # Alignment section.
    assert 'alignment' in result
    assert {'centroid_distance', 'levene_statistic'} <= set(result['alignment'])
|
|
441
|
+
|
|
442
|
+
|
|
443
|
+
def test_compute_batch_metrics_not_fitted_raises():
    """
    An unfitted ComBat must refuse to compute metrics with a ValueError.
    """
    X, batch = simulate_data(n_samples=100, n_features=20)
    unfitted = ComBat(batch=batch, method="johnson")

    with pytest.raises(ValueError, match="not fitted"):
        unfitted.compute_batch_metrics(X)
|
|
453
|
+
|
|
454
|
+
|
|
455
|
+
def test_compute_batch_metrics_pca_components_validation():
    """
    pca_components must be strictly less than min(n_samples, n_features).
    """
    X, batch = simulate_data(n_samples=100, n_features=20)

    model = ComBat(batch=batch, method="johnson")
    model.fit(X)

    # Both the boundary value (== n_features) and a larger one are rejected.
    for too_many in (20, 50):
        with pytest.raises(ValueError, match="pca_components.*must be less than"):
            model.compute_batch_metrics(X, pca_components=too_many)

    # A value inside the valid range succeeds.
    assert model.compute_batch_metrics(X, pca_components=10) is not None
|
|
475
|
+
|
|
476
|
+
|
|
477
|
+
def test_compute_batch_metrics_no_pca_default():
    """
    Calling compute_batch_metrics without pca_components (i.e. None, no PCA)
    returns a metrics dictionary.
    """
    X, batch = simulate_data(n_samples=100, n_features=20)

    model = ComBat(batch=batch, method="johnson")
    model.fit(X)

    result = model.compute_batch_metrics(X)
    assert result is not None and 'batch_effect' in result
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|