combatlearn 1.1.2__py3-none-any.whl → 1.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- combatlearn/__init__.py +2 -2
- combatlearn/core.py +578 -0
- combatlearn/metrics.py +788 -0
- combatlearn/sklearn_api.py +143 -0
- combatlearn/visualization.py +533 -0
- {combatlearn-1.1.2.dist-info → combatlearn-1.2.0.dist-info}/METADATA +24 -14
- combatlearn-1.2.0.dist-info/RECORD +10 -0
- {combatlearn-1.1.2.dist-info → combatlearn-1.2.0.dist-info}/WHEEL +1 -1
- combatlearn/combat.py +0 -1770
- combatlearn-1.1.2.dist-info/RECORD +0 -7
- {combatlearn-1.1.2.dist-info → combatlearn-1.2.0.dist-info}/licenses/LICENSE +0 -0
- {combatlearn-1.1.2.dist-info → combatlearn-1.2.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,143 @@
|
|
|
1
|
+
"""Scikit-learn compatible ComBat wrapper."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
import pandas as pd
|
|
7
|
+
from sklearn.base import BaseEstimator, TransformerMixin
|
|
8
|
+
|
|
9
|
+
from .core import ArrayLike, ComBatModel
|
|
10
|
+
from .metrics import ComBatMetricsMixin
|
|
11
|
+
from .visualization import ComBatVisualizationMixin
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class ComBat(ComBatMetricsMixin, ComBatVisualizationMixin, BaseEstimator, TransformerMixin):
    """Pipeline-friendly wrapper around `ComBatModel`.

    Stores batch (and optional covariates) passed at construction and
    appropriately uses them for separate `fit` and `transform`.

    Parameters
    ----------
    batch : array-like of shape (n_samples,)
        Batch labels for each sample.
    discrete_covariates : array-like, optional
        Categorical covariates to protect (Fortin/Chen only).
    continuous_covariates : array-like, optional
        Continuous covariates to protect (Fortin/Chen only).
    method : {'johnson', 'fortin', 'chen'}, default='johnson'
        ComBat variant to use.
    parametric : bool, default=True
        Use parametric empirical Bayes.
    mean_only : bool, default=False
        Adjust only the mean (ignore variance).
    reference_batch : str, optional
        Batch level to leave unchanged.
    eps : float, default=1e-8
        Numerical jitter for stability.
    covbat_cov_thresh : float or int, default=0.9
        CovBat variance threshold for PCs.
    compute_metrics : bool, default=False
        If True, ``fit_transform`` caches batch metrics in ``metrics_``.

    Notes
    -----
    Batch labels and covariates are aligned to the rows of ``X`` on every
    call. When ``X`` is a DataFrame its index is used; otherwise
    ``pd.RangeIndex(len(X))`` is used. pandas inputs are selected with
    ``.loc`` on that index. Plain arrays, lists and tuples are wrapped with
    that index, so they must have exactly ``len(X)`` rows. One-dimensional
    inputs become a ``pd.Series``; everything else becomes a ``pd.DataFrame``.

    Hyper-parameters changed with ``set_params`` take effect at the next
    ``fit``.
    """

    def __init__(
        self,
        batch: ArrayLike,
        *,
        discrete_covariates: ArrayLike | None = None,
        continuous_covariates: ArrayLike | None = None,
        method: str = "johnson",
        parametric: bool = True,
        mean_only: bool = False,
        reference_batch: str | None = None,
        eps: float = 1e-8,
        covbat_cov_thresh: float | int = 0.9,
        compute_metrics: bool = False,
    ) -> None:
        self.batch = batch
        self.discrete_covariates = discrete_covariates
        self.continuous_covariates = continuous_covariates
        self.method = method
        self.parametric = parametric
        self.mean_only = mean_only
        self.reference_batch = reference_batch
        self.eps = eps
        self.covbat_cov_thresh = covbat_cov_thresh
        self.compute_metrics = compute_metrics
        # Created eagerly so that code inspecting ``self._model`` before the
        # first ``fit`` (e.g. the "not fitted yet" checks) keeps working.
        # ``fit`` replaces it with a model built from the current parameters.
        self._model = self._build_model()

    def _build_model(self) -> ComBatModel:
        """Return a new, unfitted ``ComBatModel`` configured from the current parameters."""
        return ComBatModel(
            method=self.method,
            parametric=self.parametric,
            mean_only=self.mean_only,
            reference_batch=self.reference_batch,
            eps=self.eps,
            covbat_cov_thresh=self.covbat_cov_thresh,
        )

    def fit(self, X: ArrayLike, y: ArrayLike | None = None) -> ComBat:
        """Fit the ComBat model.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Data to estimate the batch-effect parameters from.
        y : None
            Ignored. Present for API compatibility.

        Returns
        -------
        self : ComBat
            The fitted estimator.
        """
        batch_vec, disc, cont = self._aligned_side_info(X)
        # Build from the *current* hyper-parameters. A model created only in
        # ``__init__`` would ignore later ``set_params`` calls, such as those
        # issued by ``Pipeline.set_params`` or ``GridSearchCV``.
        # Assign only after a successful fit so a failure keeps the previous model.
        model = self._build_model()
        model.fit(
            X,
            batch=batch_vec,
            discrete_covariates=disc,
            continuous_covariates=cont,
        )
        self._model = model
        self._fitted_batch = batch_vec
        return self

    def transform(self, X: ArrayLike) -> pd.DataFrame:
        """Transform the data using fitted ComBat parameters.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Data to batch-correct.

        Returns
        -------
        X_transformed : pd.DataFrame
            Batch-corrected data, as returned by ``ComBatModel.transform``.
        """
        batch_vec, disc, cont = self._aligned_side_info(X)
        return self._model.transform(
            X,
            batch=batch_vec,
            discrete_covariates=disc,
            continuous_covariates=cont,
        )

    def _aligned_side_info(
        self, X: ArrayLike
    ) -> tuple[
        pd.DataFrame | pd.Series | None,
        pd.DataFrame | pd.Series | None,
        pd.DataFrame | pd.Series | None,
    ]:
        """Return ``(batch, discrete_covariates, continuous_covariates)`` aligned to the rows of ``X``."""
        idx = X.index if isinstance(X, pd.DataFrame) else pd.RangeIndex(len(X))
        return (
            self._subset(self.batch, idx),
            self._subset(self.discrete_covariates, idx),
            self._subset(self.continuous_covariates, idx),
        )

    @staticmethod
    def _subset(obj: ArrayLike | None, idx: pd.Index) -> pd.DataFrame | pd.Series | None:
        """Subset array-like object by index.

        Parameters
        ----------
        obj : array-like or None
            Side information (batch labels or covariates).
        idx : pd.Index
            Row labels of the data being fitted or transformed.

        Returns
        -------
        pd.Series, pd.DataFrame or None
            ``None`` if ``obj`` is None. ``obj.loc[idx]`` for pandas inputs.
            Otherwise ``obj`` wrapped with ``idx`` as its index: a Series for
            one-dimensional input, a DataFrame for anything else.
        """
        if obj is None:
            return None
        if isinstance(obj, (pd.Series, pd.DataFrame)):
            return obj.loc[idx]
        if isinstance(obj, np.ndarray):
            one_dimensional = obj.ndim == 1
        elif isinstance(obj, (list, tuple)):
            # A flat list/tuple of scalars is 1-D, just like a 1-D ndarray.
            # Previously it became a one-column DataFrame, which downstream code
            # iterating batch labels does not expect. Rows that are themselves
            # list-like (lists, tuples, dicts) still produce a DataFrame.
            one_dimensional = not any(pd.api.types.is_list_like(v) for v in obj)
        else:
            one_dimensional = False
        if one_dimensional:
            return pd.Series(obj, index=idx)
        return pd.DataFrame(obj, index=idx)

    def fit_transform(self, X: ArrayLike, y: ArrayLike | None = None) -> pd.DataFrame:
        """
        Fit and transform the data, optionally computing metrics.

        If ``compute_metrics=True`` was set at construction, batch effect
        metrics are computed and cached in the ``metrics_`` property.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Input data to fit and transform.
        y : None
            Ignored. Present for API compatibility.

        Returns
        -------
        X_transformed : pd.DataFrame
            Batch-corrected data.
        """
        self.fit(X, y)
        X_transformed = self.transform(X)

        if self.compute_metrics:
            self._metrics_cache = self.compute_batch_metrics(X)

        return X_transformed
|
|
@@ -0,0 +1,533 @@
|
|
|
1
|
+
"""Visualization utilities for ComBat batch correction."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import Any, Literal
|
|
6
|
+
|
|
7
|
+
import matplotlib
|
|
8
|
+
import matplotlib.colors as mcolors
|
|
9
|
+
import matplotlib.pyplot as plt
|
|
10
|
+
import numpy as np
|
|
11
|
+
import pandas as pd
|
|
12
|
+
import plotly.graph_objects as go
|
|
13
|
+
import umap
|
|
14
|
+
from plotly.subplots import make_subplots
|
|
15
|
+
from sklearn.decomposition import PCA
|
|
16
|
+
from sklearn.manifold import TSNE
|
|
17
|
+
|
|
18
|
+
from .core import ArrayLike, FloatArray
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class ComBatVisualizationMixin:
    """Mixin providing visualization methods for the ComBat wrapper.

    The host class must provide ``_model``, ``batch``, ``_subset``,
    ``transform`` and ``feature_batch_importance``.
    """

    def plot_transformation(
        self,
        X: ArrayLike,
        *,
        reduction_method: Literal["pca", "tsne", "umap"] = "pca",
        n_components: Literal[2, 3] = 2,
        plot_type: Literal["static", "interactive"] = "static",
        figsize: tuple[int, int] = (12, 5),
        alpha: float = 0.7,
        point_size: int = 50,
        cmap: str = "Set1",
        title: str | None = None,
        show_legend: bool = True,
        return_embeddings: bool = False,
        **reduction_kwargs,
    ) -> Any | tuple[Any, dict[str, FloatArray]]:
        """
        Visualize the ComBat transformation effect using dimensionality reduction.

        It shows a before/after comparison of data transformed by `ComBat` using
        PCA, t-SNE, or UMAP to reduce dimensions for visualization. The original
        and the transformed data are each embedded by a separately fitted reducer.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Input data to transform and visualize.

        reduction_method : {`'pca'`, `'tsne'`, `'umap'`}, default=`'pca'`
            Dimensionality reduction method.

        n_components : {2, 3}, default=2
            Number of components for dimensionality reduction.

        plot_type : {`'static'`, `'interactive'`}, default=`'static'`
            Visualization type:
            - `'static'`: matplotlib plots (can be saved as images)
            - `'interactive'`: plotly plots (explorable, requires plotly)

        figsize : tuple of int, default=(12, 5)
            Figure size in inches (static plots only).

        alpha : float, default=0.7
            Marker opacity (static plots only).

        point_size : int, default=50
            Marker size (static plots only).

        cmap : str, default=`'Set1'`
            Matplotlib colormap used to color batches. With more than 10
            batches, `'tab20'` is used instead.

        title : str, optional
            Figure title. Defaults to a title naming the reduction method.

        show_legend : bool, default=True
            Whether to show the batch legend. Static plots omit it when there
            are more than 20 batches.

        return_embeddings : bool, default=False
            If `True`, return embeddings along with the plot.

        **reduction_kwargs : dict
            Additional parameters for reduction methods. For t-SNE and UMAP they
            override the defaults: `random_state=42`, plus for t-SNE
            `max_iter=1000` and `perplexity=min(30, n_samples - 1)`.

        Returns
        -------
        fig : matplotlib.figure.Figure or plotly.graph_objects.Figure
            The figure object containing the plots.

        embeddings : dict, optional
            If `return_embeddings=True`, dictionary with:
            - `'original'`: embedding of original data
            - `'transformed'`: embedding of ComBat-transformed data

        Raises
        ------
        ValueError
            If the model is not fitted, an option is invalid, or batch
            information is missing or not one-dimensional.
        """
        if not hasattr(self._model, "_gamma_star"):
            raise ValueError(
                "This ComBat instance is not fitted yet. Call 'fit' before 'plot_transformation'."
            )

        if n_components not in (2, 3):
            raise ValueError(f"n_components must be 2 or 3, got {n_components}")
        if reduction_method not in ("pca", "tsne", "umap"):
            raise ValueError(
                f"reduction_method must be 'pca', 'tsne', or 'umap', got '{reduction_method}'"
            )
        if plot_type not in ("static", "interactive"):
            raise ValueError(f"plot_type must be 'static' or 'interactive', got '{plot_type}'")

        if not isinstance(X, pd.DataFrame):
            X = pd.DataFrame(X)

        batch_vec = self._subset(self.batch, X.index)
        if batch_vec is None:
            raise ValueError("Batch information is required for visualization")
        batch_vec = self._viz_label_series(batch_vec)

        X_transformed = self.transform(X)

        reducer_orig = self._viz_make_reducer(
            reduction_method, n_components, len(X), reduction_kwargs
        )
        reducer_trans = self._viz_make_reducer(
            reduction_method, n_components, len(X), reduction_kwargs
        )
        X_embedded_orig = reducer_orig.fit_transform(np.asarray(X))
        X_embedded_trans = reducer_trans.fit_transform(np.asarray(X_transformed))

        if plot_type == "static":
            fig = self._create_static_plot(
                X_embedded_orig,
                X_embedded_trans,
                batch_vec,
                reduction_method,
                n_components,
                figsize,
                alpha,
                point_size,
                cmap,
                title,
                show_legend,
            )
        else:
            fig = self._create_interactive_plot(
                X_embedded_orig,
                X_embedded_trans,
                batch_vec,
                reduction_method,
                n_components,
                cmap,
                title,
                show_legend,
            )

        if return_embeddings:
            embeddings = {"original": X_embedded_orig, "transformed": X_embedded_trans}
            return fig, embeddings
        return fig

    @staticmethod
    def _viz_label_series(labels: pd.Series | pd.DataFrame) -> pd.Series:
        """Coerce aligned batch labels to a Series.

        A single-column DataFrame is unwrapped to its column. The plotting
        helpers iterate ``drop_duplicates()`` values and build boolean masks,
        which only works on a Series.
        """
        if isinstance(labels, pd.DataFrame):
            if labels.shape[1] != 1:
                raise ValueError(
                    "Batch information must be one-dimensional, "
                    f"got {labels.shape[1]} columns"
                )
            return labels.iloc[:, 0]
        return labels

    @staticmethod
    def _viz_make_reducer(
        method: str,
        n_components: int,
        n_samples: int,
        reduction_kwargs: dict[str, Any],
    ) -> Any:
        """Return a new, unfitted PCA / t-SNE / UMAP reducer."""
        if method == "pca":
            return PCA(n_components=n_components, **reduction_kwargs)
        if method == "tsne":
            # t-SNE requires perplexity < n_samples, so shrink the default for
            # small inputs. An explicit ``perplexity`` in ``reduction_kwargs`` wins.
            tsne_params = {
                "perplexity": min(30, max(n_samples - 1, 1)),
                "max_iter": 1000,
                "random_state": 42,
            }
            tsne_params.update(reduction_kwargs)
            return TSNE(n_components=n_components, **tsne_params)
        umap_params = {"random_state": 42}
        umap_params.update(reduction_kwargs)
        return umap.UMAP(n_components=n_components, **umap_params)

    @staticmethod
    def _viz_batch_colors(cmap: str, n_batches: int) -> FloatArray:
        """Return one RGBA row per batch, sampled evenly from the colormap.

        Uses ``cmap`` for up to 10 batches and ``'tab20'`` beyond that. The
        static and interactive plots share this, so they color batches alike.
        """
        name = cmap if n_batches <= 10 else "tab20"
        return matplotlib.colormaps.get_cmap(name)(np.linspace(0, 1, n_batches))

    def _create_static_plot(
        self,
        X_orig: FloatArray,
        X_trans: FloatArray,
        batch_labels: pd.Series,
        method: str,
        n_components: int,
        figsize: tuple[int, int],
        alpha: float,
        point_size: int,
        cmap: str,
        title: str | None,
        show_legend: bool,
    ) -> Any:
        """Create static side-by-side (before / after) plots using matplotlib."""
        fig = plt.figure(figsize=figsize)

        unique_batches = batch_labels.drop_duplicates()
        n_batches = len(unique_batches)
        colors = self._viz_batch_colors(cmap, n_batches)

        projection = "3d" if n_components == 3 else None
        ax1 = fig.add_subplot(1, 2, 1, projection=projection)
        ax2 = fig.add_subplot(1, 2, 2, projection=projection)

        label = method.upper()
        panels = ((ax1, X_orig, "Before"), (ax2, X_trans, "After"))
        for ax, embedding, stage in panels:
            for i, batch in enumerate(unique_batches):
                mask = (batch_labels == batch).to_numpy()
                # x, y (and z for 3-D axes) are passed positionally.
                coords = [embedding[mask, d] for d in range(n_components)]
                ax.scatter(
                    *coords,
                    c=[colors[i]],
                    s=point_size,
                    alpha=alpha,
                    label=f"Batch {batch}",
                    edgecolors="black",
                    linewidth=0.5,
                )
            ax.set_title(f"{stage} ComBat correction\n({label})")
            ax.set_xlabel(f"{label}1")
            ax.set_ylabel(f"{label}2")
            if n_components == 3:
                ax.set_zlabel(f"{label}3")

        if show_legend and n_batches <= 20:
            ax2.legend(bbox_to_anchor=(1.05, 1), loc="upper left")

        if title is None:
            title = f"ComBat correction effect visualized with {label}"
        fig.suptitle(title, fontsize=14, fontweight="bold")

        fig.tight_layout()
        return fig

    def _create_interactive_plot(
        self,
        X_orig: FloatArray,
        X_trans: FloatArray,
        batch_labels: pd.Series,
        method: str,
        n_components: int,
        cmap: str,
        title: str | None,
        show_legend: bool,
    ) -> Any:
        """Create interactive side-by-side (before / after) plots using plotly.

        Only the "after" panel contributes legend entries, so each batch is
        listed once.
        """
        label = method.upper()
        subplot_titles = (
            f"Before ComBat correction ({label})",
            f"After ComBat correction ({label})",
        )
        if n_components == 2:
            fig = make_subplots(rows=1, cols=2, subplot_titles=subplot_titles)
            trace_cls, marker_size, edge_width = go.Scatter, 8, 1
        else:
            fig = make_subplots(
                rows=1,
                cols=2,
                specs=[[{"type": "scatter3d"}, {"type": "scatter3d"}]],
                subplot_titles=subplot_titles,
            )
            trace_cls, marker_size, edge_width = go.Scatter3d, 5, 0.5

        unique_batches = batch_labels.drop_duplicates()
        # Same palette selection as the static plot, as hex strings for plotly.
        rgba = self._viz_batch_colors(cmap, len(unique_batches))
        batch_to_color = {
            batch: mcolors.to_hex(rgba[i]) for i, batch in enumerate(unique_batches)
        }
        coord_keys = ("x", "y", "z")[:n_components]

        for batch in unique_batches:
            mask = (batch_labels == batch).to_numpy()
            for col, embedding, legend in ((1, X_orig, False), (2, X_trans, show_legend)):
                coords = {key: embedding[mask, d] for d, key in enumerate(coord_keys)}
                fig.add_trace(
                    trace_cls(
                        **coords,
                        mode="markers",
                        name=f"Batch {batch}",
                        marker={
                            "size": marker_size,
                            "color": batch_to_color[batch],
                            "line": {"width": edge_width, "color": "black"},
                        },
                        showlegend=legend,
                    ),
                    row=1,
                    col=col,
                )

        if title is None:
            title = f"ComBat correction effect visualized with {label}"

        fig.update_layout(
            title=title,
            title_font_size=16,
            height=600,
            showlegend=show_legend,
            hovermode="closest",
        )

        axis_labels = [f"{label}{i + 1}" for i in range(n_components)]
        if n_components == 2:
            fig.update_xaxes(title_text=axis_labels[0])
            fig.update_yaxes(title_text=axis_labels[1])
        else:
            fig.update_scenes(
                xaxis_title=axis_labels[0],
                yaxis_title=axis_labels[1],
                zaxis_title=axis_labels[2],
            )

        return fig

    def plot_feature_importance(
        self,
        top_n: int = 20,
        kind: Literal["location", "scale", "combined"] = "combined",
        mode: Literal["magnitude", "distribution"] = "magnitude",
        figsize: tuple[int, int] = (8, 10),
    ) -> Any:
        """Plot top features affected by batch effects.

        Parameters
        ----------
        top_n : int, default=20
            Maximum number of features to display; must be >= 1. If fewer
            features exist, all are shown, and the title and summary report
            the number actually shown.
        kind : {'location', 'scale', 'combined'}, default='combined'
            - 'location': bar plot of location (mean shift) contribution only
            - 'scale': bar plot of scale (variance) contribution only
            - 'combined': grouped bar plot showing location and scale
              side-by-side for each feature, in the order returned by
              ``feature_batch_importance``.
              In magnitude mode: bars reflect Euclidean decomposition
              (combined**2 = location**2 + scale**2).
              In distribution mode: bars reflect independent normalized
              contributions (each sums to 1 separately).
        mode : {'magnitude', 'distribution'}, default='magnitude'
            - 'magnitude': bar length (x-axis) shows absolute batch effect magnitude
            - 'distribution': bar length (x-axis) shows relative contribution
              (proportion). Also prints the cumulative contribution of the
              displayed features to stdout (e.g., "Top 20 features explain
              75.0% of total batch effect").
        figsize : tuple, default=(8,10)
            Figure size (width, height) in inches.

        Returns
        -------
        matplotlib.figure.Figure
            The figure object containing the plot.

        Raises
        ------
        ValueError
            If the model is not fitted, if kind/mode is invalid, or if top_n < 1.
        """
        if not hasattr(self._model, "_gamma_star"):
            raise ValueError(
                "This ComBat instance is not fitted yet. "
                "Call 'fit' before 'plot_feature_importance'."
            )

        if kind not in ("location", "scale", "combined"):
            raise ValueError(f"kind must be 'location', 'scale', or 'combined', got '{kind}'")

        if mode not in ("magnitude", "distribution"):
            raise ValueError(f"mode must be 'magnitude' or 'distribution', got '{mode}'")

        if top_n < 1:
            # ``DataFrame.head`` with a negative n drops rows from the end
            # instead of selecting the first rows, so reject it explicitly.
            raise ValueError(f"top_n must be a positive integer, got {top_n}")

        importance_df = self.feature_batch_importance(mode=mode)
        top_features = importance_df.head(top_n)
        n_shown = len(top_features)

        # Reverse so the first row of ``importance_df`` is drawn at the top of
        # the horizontal bar plot.
        top_features = top_features.iloc[::-1]
        positions = np.arange(n_shown)

        fig, ax = plt.subplots(figsize=figsize)

        if kind == "combined":
            # Grouped horizontal bars: location above, scale below each tick.
            height = 0.35
            ax.barh(
                positions + height / 2,
                top_features["location"],
                height,
                label="Location",
                color="steelblue",
                edgecolor="black",
                linewidth=0.5,
            )
            ax.barh(
                positions - height / 2,
                top_features["scale"],
                height,
                label="Scale",
                color="coral",
                edgecolor="black",
                linewidth=0.5,
            )
            ax.set_yticks(positions)
            ax.set_yticklabels(top_features.index)
            ax.legend()
        else:
            ax.barh(
                positions,
                top_features[kind],
                color="steelblue" if kind == "location" else "coral",
                edgecolor="black",
                linewidth=0.5,
            )
            ax.set_yticks(positions)
            ax.set_yticklabels(top_features.index)

        ax.set_ylabel("Feature")
        if mode == "magnitude":
            ax.set_xlabel("Batch Effect Magnitude (RMS)")
            title = f"Top {n_shown} Features by Batch Effect"
        else:
            ax.set_xlabel("Relative Contribution")
            title = f"Top {n_shown} Features by Batch Effect (Distribution)"
        title += " (Location & Scale)" if kind == "combined" else f" ({kind.capitalize()})"

        ax.set_title(title)
        fig.tight_layout()

        if mode == "distribution":
            # The column summed matches ``kind`` ('combined', 'location' or 'scale').
            cumulative_pct = top_features[kind].sum() * 100
            effect_label = "batch effect" if kind == "combined" else f"{kind} effect"
            print(f"Top {n_shown} features explain {cumulative_pct:.1f}% of total {effect_label}")

        return fig