combatlearn 1.1.2__py3-none-any.whl → 1.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- combatlearn/__init__.py +2 -2
- combatlearn/core.py +578 -0
- combatlearn/metrics.py +788 -0
- combatlearn/sklearn_api.py +143 -0
- combatlearn/visualization.py +533 -0
- {combatlearn-1.1.2.dist-info → combatlearn-1.2.0.dist-info}/METADATA +24 -14
- combatlearn-1.2.0.dist-info/RECORD +10 -0
- {combatlearn-1.1.2.dist-info → combatlearn-1.2.0.dist-info}/WHEEL +1 -1
- combatlearn/combat.py +0 -1770
- combatlearn-1.1.2.dist-info/RECORD +0 -7
- {combatlearn-1.1.2.dist-info → combatlearn-1.2.0.dist-info}/licenses/LICENSE +0 -0
- {combatlearn-1.1.2.dist-info → combatlearn-1.2.0.dist-info}/top_level.txt +0 -0
combatlearn/__init__.py
CHANGED
combatlearn/core.py
ADDED
@@ -0,0 +1,578 @@
"""ComBat algorithm core.

`ComBatModel` implements three variants of the ComBat algorithm:

* Johnson et al. (2007) vanilla ComBat (method="johnson")
* Fortin et al. (2018) extension with covariates (method="fortin")
* Chen et al. (2022) CovBat (method="chen")
"""

from __future__ import annotations

import warnings
from typing import Any, Literal

import numpy as np
import numpy.linalg as la
import numpy.typing as npt
import pandas as pd
from sklearn.decomposition import PCA

ArrayLike = pd.DataFrame | pd.Series | npt.NDArray[Any]
FloatArray = npt.NDArray[np.float64]

class ComBatModel:
    """ComBat algorithm.

    Parameters
    ----------
    method : {'johnson', 'fortin', 'chen'}, default='johnson'
        * 'johnson' - classic ComBat.
        * 'fortin' - covariate-aware ComBat.
        * 'chen' - CovBat, PCA-based ComBat.
    parametric : bool, default=True
        Use the parametric empirical Bayes variant.
    mean_only : bool, default=False
        If True, only the mean is adjusted (`gamma_star`),
        ignoring the variance (`delta_star`).
    reference_batch : str, optional
        If specified, the batch level to use as reference.
    eps : float, default=1e-8
        Numerical jitter to avoid division by zero.
    covbat_cov_thresh : float or int, default=0.9
        CovBat: cumulative variance threshold in (0, 1] used to retain PCs,
        or an integer >= 1 specifying the number of components directly.
    """

    def __init__(
        self,
        *,
        method: Literal["johnson", "fortin", "chen"] = "johnson",
        parametric: bool = True,
        mean_only: bool = False,
        reference_batch: str | None = None,
        eps: float = 1e-8,
        covbat_cov_thresh: float | int = 0.9,
    ) -> None:
        self.method: str = method
        self.parametric: bool = parametric
        self.mean_only: bool = bool(mean_only)
        self.reference_batch: str | None = reference_batch
        self.eps: float = float(eps)
        self.covbat_cov_thresh: float | int = covbat_cov_thresh

        self._batch_levels: pd.Index
        self._grand_mean: pd.Series
        self._pooled_var: pd.Series
        self._gamma_star: FloatArray
        self._delta_star: FloatArray
        self._n_per_batch: dict[str, int]
        self._reference_batch_idx: int | None
        self._beta_hat_nonbatch: FloatArray
        self._n_batch: int
        self._p_design: int
        self._covbat_pca: PCA
        self._covbat_n_pc: int
        self._batch_levels_pc: pd.Index
        self._pc_gamma_star: FloatArray
        self._pc_delta_star: FloatArray

        # Validate covbat_cov_thresh
        if isinstance(self.covbat_cov_thresh, float):
            if not (0.0 < self.covbat_cov_thresh <= 1.0):
                raise ValueError("covbat_cov_thresh must be in (0, 1] when float.")
        elif isinstance(self.covbat_cov_thresh, int):
            if self.covbat_cov_thresh < 1:
                raise ValueError("covbat_cov_thresh must be >= 1 when int.")
        else:
            raise TypeError("covbat_cov_thresh must be float or int.")
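
Editorial aside: all constructor arguments are keyword-only, and the check above distinguishes the two `covbat_cov_thresh` modes. A few illustrative calls (not from the package):

    ComBatModel(method="chen", covbat_cov_thresh=0.95)  # float: PCs up to 95% cumulative variance
    ComBatModel(method="chen", covbat_cov_thresh=5)     # int: five leading PCs (capped by the PCA fit)
    ComBatModel(method="chen", covbat_cov_thresh=1.5)   # ValueError: must be in (0, 1] when float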

    @staticmethod
    def _as_series(arr: ArrayLike, index: pd.Index, name: str) -> pd.Series:
        """Convert array-like to categorical Series with validation."""
        ser = arr.copy() if isinstance(arr, pd.Series) else pd.Series(arr, index=index, name=name)
        if not ser.index.equals(index):
            raise ValueError(f"`{name}` index mismatch with `X`.")
        return ser.astype("category")

    @staticmethod
    def _to_df(arr: ArrayLike | None, index: pd.Index, name: str) -> pd.DataFrame | None:
        """Convert array-like to DataFrame."""
        if arr is None:
            return None
        if isinstance(arr, pd.Series):
            arr = arr.to_frame()
        if not isinstance(arr, pd.DataFrame):
            arr = pd.DataFrame(arr, index=index)
        if not arr.index.equals(index):
            raise ValueError(f"`{name}` index mismatch with `X`.")
        return arr

    def fit(
        self,
        X: ArrayLike,
        y: ArrayLike | None = None,
        *,
        batch: ArrayLike,
        discrete_covariates: ArrayLike | None = None,
        continuous_covariates: ArrayLike | None = None,
    ) -> ComBatModel:
        """Fit the ComBat model."""
        method = self.method.lower()
        if method not in {"johnson", "fortin", "chen"}:
            raise ValueError("method must be 'johnson', 'fortin', or 'chen'.")
        if not isinstance(X, pd.DataFrame):
            X = pd.DataFrame(X)
        idx = X.index
        batch = self._as_series(batch, idx, "batch")

        disc = self._to_df(discrete_covariates, idx, "discrete_covariates")
        cont = self._to_df(continuous_covariates, idx, "continuous_covariates")

        if self.reference_batch is not None and self.reference_batch not in batch.cat.categories:
            raise ValueError(
                f"reference_batch={self.reference_batch!r} not present in the data batches: "
                f"{list(batch.cat.categories)}"
            )

        if method == "johnson":
            if disc is not None or cont is not None:
                warnings.warn("Covariates are ignored when using method='johnson'.", stacklevel=2)
            self._fit_johnson(X, batch)
        elif method == "fortin":
            self._fit_fortin(X, batch, disc, cont)
        elif method == "chen":
            self._fit_chen(X, batch, disc, cont)
        return self
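
Editorial aside, for orientation: a minimal usage sketch of the `fit` entry point on synthetic data, assuming `ComBatModel` is imported from this module (the toy frame, batch labels, and `age` covariate are illustrative, not part of the package):

    import numpy as np
    import pandas as pd

    rng = np.random.default_rng(0)
    X = pd.DataFrame(rng.normal(size=(40, 5)))                 # samples x features
    batch = pd.Series(["a"] * 20 + ["b"] * 20)                 # site/scanner labels
    age = pd.Series(rng.uniform(20, 80, size=40), name="age")  # continuous covariate

    # batch and covariates are keyword-only; their indexes must match X's
    model = ComBatModel(method="fortin").fit(
        X, batch=batch, continuous_covariates=age
    )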

    def _fit_johnson(self, X: pd.DataFrame, batch: pd.Series) -> None:
        """Johnson et al. (2007) ComBat."""
        self._batch_levels = batch.cat.categories
        pooled_var = X.var(axis=0, ddof=1) + self.eps
        grand_mean = X.mean(axis=0)

        Xs = (X - grand_mean) / np.sqrt(pooled_var)

        n_per_batch: dict[str, int] = {}
        gamma_hat: list[npt.NDArray[np.float64]] = []
        delta_hat: list[npt.NDArray[np.float64]] = []

        for lvl in self._batch_levels:
            idx = batch == lvl
            n_b = int(idx.sum())
            if n_b < 2:
                raise ValueError(f"Batch '{lvl}' has <2 samples.")
            n_per_batch[str(lvl)] = n_b
            xb = Xs.loc[idx]
            gamma_hat.append(xb.mean(axis=0).values)
            delta_hat.append(xb.var(axis=0, ddof=1).values + self.eps)

        gamma_hat_arr = np.vstack(gamma_hat)
        delta_hat_arr = np.vstack(delta_hat)

        if self.mean_only:
            gamma_star = self._shrink_gamma(
                gamma_hat_arr, delta_hat_arr, n_per_batch, parametric=self.parametric
            )
            delta_star = np.ones_like(delta_hat_arr)
        else:
            gamma_star, delta_star = self._shrink_gamma_delta(
                gamma_hat_arr, delta_hat_arr, n_per_batch, parametric=self.parametric
            )

        if self.reference_batch is not None:
            ref_idx = list(self._batch_levels).index(self.reference_batch)
            gamma_ref = gamma_star[ref_idx]
            delta_ref = delta_star[ref_idx]
            gamma_star = gamma_star - gamma_ref
            if not self.mean_only:
                delta_star = delta_star / delta_ref
            self._reference_batch_idx = ref_idx
        else:
            self._reference_batch_idx = None

        self._grand_mean = grand_mean
        self._pooled_var = pooled_var
        self._gamma_star = gamma_star
        self._delta_star = delta_star
        self._n_per_batch = n_per_batch
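
Editorial aside: in the Johnson et al. (2007) location-scale model that this method estimates, each feature is standardized by its grand mean and pooled variance, the per-batch effects are shrunk, and the transform later inverts the standardization. In the code's terms (`grand_mean` is α̂_g, `pooled_var` is σ̂²_g; notation mine):

    z_{ijg} = \frac{x_{ijg} - \hat\alpha_g}{\hat\sigma_g},
    \qquad
    x^{*}_{ijg} = \hat\sigma_g \, \frac{z_{ijg} - \gamma^{*}_{ig}}{\sqrt{\delta^{*}_{ig}}} + \hat\alpha_g

for sample j of batch i and feature g. With `mean_only=True` the √δ* divisor is skipped, and with a `reference_batch` the γ*, δ* are re-expressed relative to that batch, so its samples pass through unchanged.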

    def _fit_fortin(
        self,
        X: pd.DataFrame,
        batch: pd.Series,
        disc: pd.DataFrame | None,
        cont: pd.DataFrame | None,
    ) -> None:
        """Fortin et al. (2018) neuroComBat."""
        self._batch_levels = batch.cat.categories
        n_batch = len(self._batch_levels)
        n_samples = len(X)

        batch_dummies = pd.get_dummies(batch, drop_first=False).astype(float)
        if self.reference_batch is not None:
            if self.reference_batch not in self._batch_levels:
                raise ValueError(
                    f"reference_batch={self.reference_batch!r} not present in batches: "
                    f"{list(self._batch_levels)}"
                )
            batch_dummies.loc[:, self.reference_batch] = 1.0

        parts: list[pd.DataFrame] = [batch_dummies]
        if disc is not None:
            parts.append(pd.get_dummies(disc.astype("category"), drop_first=True).astype(float))

        if cont is not None:
            parts.append(cont.astype(float))

        design = pd.concat(parts, axis=1).values
        p_design = design.shape[1]

        X_np = X.values
        beta_hat = la.lstsq(design, X_np, rcond=None)[0]

        beta_hat_batch = beta_hat[:n_batch]
        self._beta_hat_nonbatch = beta_hat[n_batch:]

        n_per_batch = batch.value_counts().sort_index().astype(int).values
        self._n_per_batch = dict(zip(self._batch_levels, n_per_batch, strict=True))

        if self.reference_batch is not None:
            ref_idx = list(self._batch_levels).index(self.reference_batch)
            grand_mean = beta_hat_batch[ref_idx]
        else:
            grand_mean = (n_per_batch / n_samples) @ beta_hat_batch
            ref_idx = None

        self._grand_mean = pd.Series(grand_mean, index=X.columns)

        if self.reference_batch is not None:
            ref_mask = (batch == self.reference_batch).values
            resid = X_np[ref_mask] - design[ref_mask] @ beta_hat
            denom = int(ref_mask.sum())
        else:
            resid = X_np - design @ beta_hat
            denom = n_samples
        var_pooled = (resid**2).sum(axis=0) / denom + self.eps
        self._pooled_var = pd.Series(var_pooled, index=X.columns)

        stand_mean = grand_mean + design[:, n_batch:] @ self._beta_hat_nonbatch
        Xs = (X_np - stand_mean) / np.sqrt(var_pooled)

        gamma_hat = np.vstack([Xs[batch == lvl].mean(axis=0) for lvl in self._batch_levels])
        delta_hat = np.vstack(
            [Xs[batch == lvl].var(axis=0, ddof=1) + self.eps for lvl in self._batch_levels]
        )

        if self.mean_only:
            gamma_star = self._shrink_gamma(
                gamma_hat, delta_hat, n_per_batch, parametric=self.parametric
            )
            delta_star = np.ones_like(delta_hat)
        else:
            gamma_star, delta_star = self._shrink_gamma_delta(
                gamma_hat, delta_hat, n_per_batch, parametric=self.parametric
            )

        if ref_idx is not None:
            gamma_star[ref_idx] = 0.0
            if not self.mean_only:
                delta_star[ref_idx] = 1.0
        self._reference_batch_idx = ref_idx

        self._gamma_star = gamma_star
        self._delta_star = delta_star
        self._n_batch = n_batch
        self._p_design = p_design
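
Editorial aside: a minimal sketch of the design matrix this method assembles — one dummy column per batch, then dummy-coded discrete covariates, then continuous covariates (toy data; the names are illustrative):

    import pandas as pd

    batch = pd.Series(["a", "a", "b", "b"], name="batch").astype("category")
    sex = pd.DataFrame({"sex": ["M", "F", "F", "M"]})
    age = pd.DataFrame({"age": [30.0, 41.0, 35.0, 28.0]})

    parts = [
        pd.get_dummies(batch, drop_first=False).astype(float),              # one column per batch
        pd.get_dummies(sex.astype("category"), drop_first=True).astype(float),
        age.astype(float),
    ]
    design = pd.concat(parts, axis=1)
    print(design.columns.tolist())  # ['a', 'b', 'sex_M', 'age']
    print(design.shape)             # (4, 4)

The least-squares fit against this matrix makes the first `n_batch` coefficient rows the per-batch intercepts, which is why `beta_hat` is split at `n_batch` just after the `lstsq` call.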

    def _fit_chen(
        self,
        X: pd.DataFrame,
        batch: pd.Series,
        disc: pd.DataFrame | None,
        cont: pd.DataFrame | None,
    ) -> None:
        """Chen et al. (2022) CovBat."""
        self._fit_fortin(X, batch, disc, cont)
        X_meanvar_adj = self._transform_fortin(X, batch, disc, cont)
        pca = PCA(svd_solver="full", whiten=False).fit(X_meanvar_adj)

        # Determine number of components based on threshold type
        if isinstance(self.covbat_cov_thresh, int):
            n_pc = min(self.covbat_cov_thresh, len(pca.explained_variance_ratio_))
        else:
            cumulative = np.cumsum(pca.explained_variance_ratio_)
            # +1 turns the insertion index into a component count; clamp so a
            # threshold of 1.0 cannot overshoot the number of fitted components.
            n_pc = min(
                int(np.searchsorted(cumulative, self.covbat_cov_thresh) + 1),
                len(cumulative),
            )

        self._covbat_pca = pca
        self._covbat_n_pc = n_pc

        scores = pca.transform(X_meanvar_adj)[:, :n_pc]
        scores_df = pd.DataFrame(scores, index=X.index, columns=[f"PC{i + 1}" for i in range(n_pc)])
        self._batch_levels_pc = self._batch_levels
        n_per_batch = self._n_per_batch

        gamma_hat: list[npt.NDArray[np.float64]] = []
        delta_hat: list[npt.NDArray[np.float64]] = []
        for lvl in self._batch_levels_pc:
            idx = batch == lvl
            xb = scores_df.loc[idx]
            gamma_hat.append(xb.mean(axis=0).values)
            delta_hat.append(xb.var(axis=0, ddof=1).values + self.eps)
        gamma_hat_arr = np.vstack(gamma_hat)
        delta_hat_arr = np.vstack(delta_hat)

        if self.mean_only:
            gamma_star = self._shrink_gamma(
                gamma_hat_arr, delta_hat_arr, n_per_batch, parametric=self.parametric
            )
            delta_star = np.ones_like(delta_hat_arr)
        else:
            gamma_star, delta_star = self._shrink_gamma_delta(
                gamma_hat_arr, delta_hat_arr, n_per_batch, parametric=self.parametric
            )

        if self.reference_batch is not None:
            ref_idx = list(self._batch_levels_pc).index(self.reference_batch)
            gamma_ref = gamma_star[ref_idx]
            delta_ref = delta_star[ref_idx]
            gamma_star = gamma_star - gamma_ref
            if not self.mean_only:
                delta_star = delta_star / delta_ref

        self._pc_gamma_star = gamma_star
        self._pc_delta_star = delta_star
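
Editorial aside: how `covbat_cov_thresh` turns into a component count, checked on made-up variance ratios:

    import numpy as np

    explained = np.array([0.55, 0.25, 0.12, 0.05, 0.03])  # a PCA's explained_variance_ratio_
    cumulative = np.cumsum(explained)                     # [0.55, 0.80, 0.92, 0.97, 1.00]

    # float threshold: smallest k whose cumulative variance reaches 0.9
    n_pc = int(np.searchsorted(cumulative, 0.9) + 1)      # -> 3

    # int threshold: that many components, capped at what the PCA produced
    n_pc_int = min(4, len(explained))                     # -> 4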

    def _shrink_gamma_delta(
        self,
        gamma_hat: FloatArray,
        delta_hat: FloatArray,
        n_per_batch: dict[str, int] | FloatArray,
        *,
        parametric: bool,
        max_iter: int = 100,
        tol: float = 1e-4,
    ) -> tuple[FloatArray, FloatArray]:
        """Empirical Bayes shrinkage estimation."""
        if parametric:
            gamma_bar = gamma_hat.mean(axis=0)
            t2 = gamma_hat.var(axis=0, ddof=1)
            a_prior = (delta_hat.mean(axis=0) ** 2) / delta_hat.var(axis=0, ddof=1) + 2
            b_prior = delta_hat.mean(axis=0) * (a_prior - 1)

            B, _p = gamma_hat.shape
            gamma_star = np.empty_like(gamma_hat)
            delta_star = np.empty_like(delta_hat)
            n_vec = (
                np.array(list(n_per_batch.values()))
                if isinstance(n_per_batch, dict)
                else n_per_batch
            )

            for i in range(B):
                n_i = n_vec[i]
                g, d = gamma_hat[i], delta_hat[i]
                gamma_post_var = 1.0 / (n_i / d + 1.0 / t2)
                gamma_star[i] = gamma_post_var * (n_i * g / d + gamma_bar / t2)

                a_post = a_prior + n_i / 2.0
                b_post = b_prior + 0.5 * n_i * d
                delta_star[i] = b_post / (a_post - 1)
            return gamma_star, delta_star

        else:
            B, _p = gamma_hat.shape
            n_vec = (
                np.array(list(n_per_batch.values()))
                if isinstance(n_per_batch, dict)
                else n_per_batch
            )
            gamma_bar = gamma_hat.mean(axis=0)
            t2 = gamma_hat.var(axis=0, ddof=1)

            def postmean(
                g_hat: FloatArray,
                g_bar: FloatArray,
                n: float,
                d_star: FloatArray,
                t2_: FloatArray,
            ) -> FloatArray:
                return (t2_ * n * g_hat + d_star * g_bar) / (t2_ * n + d_star)

            def postvar(sum2: FloatArray, n: float, a: float, b: float) -> FloatArray:
                return (0.5 * sum2 + b) / (n / 2.0 + a - 1.0)

            def aprior(delta: FloatArray) -> float:
                m, s2 = delta.mean(), delta.var()
                s2 = max(s2, self.eps)
                return (2 * s2 + m**2) / s2

            def bprior(delta: FloatArray) -> float:
                m, s2 = delta.mean(), delta.var()
                s2 = max(s2, self.eps)
                return (m * s2 + m**3) / s2

            gamma_star = np.empty_like(gamma_hat)
            delta_star = np.empty_like(delta_hat)

            for i in range(B):
                n_i = n_vec[i]
                g_hat_i = gamma_hat[i]
                d_hat_i = delta_hat[i]
                a_i = aprior(d_hat_i)
                b_i = bprior(d_hat_i)

                g_new, d_new = g_hat_i.copy(), d_hat_i.copy()
                for _ in range(max_iter):
                    g_prev, d_prev = g_new, d_new
                    g_new = postmean(g_hat_i, gamma_bar, n_i, d_prev, t2)
                    sum2 = (n_i - 1) * d_hat_i + n_i * (g_hat_i - g_new) ** 2
                    d_new = postvar(sum2, n_i, a_i, b_i)
                    if np.max(np.abs(g_new - g_prev) / (np.abs(g_prev) + self.eps)) < tol and (
                        self.mean_only
                        or np.max(np.abs(d_new - d_prev) / (np.abs(d_prev) + self.eps)) < tol
                    ):
                        break
                gamma_star[i] = g_new
                delta_star[i] = 1.0 if self.mean_only else d_new
            return gamma_star, delta_star
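
Editorial aside: the parametric branch is the moment-matched normal/inverse-gamma conjugate update, written per batch i (vectorized over features in the code):

    \gamma^{*}_{i} = \frac{n_i \hat\gamma_i / \hat\delta_i + \bar\gamma / \tau^2}
                          {n_i / \hat\delta_i + 1 / \tau^2},
    \qquad
    \delta^{*}_{i} = \frac{b + \tfrac{1}{2} n_i \hat\delta_i}{a + \tfrac{n_i}{2} - 1},
    \qquad
    a = \frac{m^2}{s^2} + 2, \quad b = m (a - 1)

where γ̄ and τ² are the across-batch mean and variance of the γ̂, and m, s² those of the δ̂; this matches `gamma_post_var`, `a_prior`, `b_prior`, `a_post`, and `b_post` above. The non-parametric branch instead iterates `postmean`/`postvar` to a fixed point under relative tolerance `tol`.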

    def _shrink_gamma(
        self,
        gamma_hat: FloatArray,
        delta_hat: FloatArray,
        n_per_batch: dict[str, int] | FloatArray,
        *,
        parametric: bool,
    ) -> FloatArray:
        """Convenience wrapper that returns only gamma* (for *mean-only* mode)."""
        gamma, _ = self._shrink_gamma_delta(
            gamma_hat, delta_hat, n_per_batch, parametric=parametric
        )
        return gamma

    def transform(
        self,
        X: ArrayLike,
        *,
        batch: ArrayLike,
        discrete_covariates: ArrayLike | None = None,
        continuous_covariates: ArrayLike | None = None,
    ) -> pd.DataFrame:
        """Transform the data using fitted ComBat parameters."""
        if not hasattr(self, "_gamma_star"):
            raise ValueError(
                "This ComBatModel instance is not fitted yet. Call 'fit' before 'transform'."
            )
        if not isinstance(X, pd.DataFrame):
            X = pd.DataFrame(X)
        idx = X.index
        batch = self._as_series(batch, idx, "batch")
        unseen = set(batch.cat.categories) - set(self._batch_levels)
        if unseen:
            raise ValueError(f"Unseen batch levels during transform: {unseen}.")
        disc = self._to_df(discrete_covariates, idx, "discrete_covariates")
        cont = self._to_df(continuous_covariates, idx, "continuous_covariates")

        method = self.method.lower()
        if method == "johnson":
            return self._transform_johnson(X, batch)
        elif method == "fortin":
            return self._transform_fortin(X, batch, disc, cont)
        elif method == "chen":
            return self._transform_chen(X, batch, disc, cont)
        else:
            raise ValueError(f"Unknown method: {method}.")
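
Editorial aside: continuing the toy data from the `fit` sketch above, a transform round trip; batch levels not seen during `fit` are rejected:

    X_adj = model.transform(X, batch=batch)  # harmonized copy of X
    assert X_adj.shape == X.shape

    # model.transform(X, batch=pd.Series(["c"] * 40))
    # -> ValueError: Unseen batch levels during transform: {'c'}.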

    def _transform_johnson(self, X: pd.DataFrame, batch: pd.Series) -> pd.DataFrame:
        """Johnson transform implementation."""
        pooled = self._pooled_var
        grand = self._grand_mean

        Xs = (X - grand) / np.sqrt(pooled)
        X_adj = pd.DataFrame(index=X.index, columns=X.columns, dtype=float)

        for i, lvl in enumerate(self._batch_levels):
            idx = batch == lvl
            if not idx.any():
                continue
            if self.reference_batch is not None and lvl == self.reference_batch:
                # leave reference samples unchanged
                X_adj.loc[idx] = X.loc[idx].values
                continue

            g = self._gamma_star[i]
            d = self._delta_star[i]
            Xb = Xs.loc[idx] - g if self.mean_only else (Xs.loc[idx] - g) / np.sqrt(d)
            X_adj.loc[idx] = (Xb * np.sqrt(pooled) + grand).values
        return X_adj

    def _transform_fortin(
        self,
        X: pd.DataFrame,
        batch: pd.Series,
        disc: pd.DataFrame | None,
        cont: pd.DataFrame | None,
    ) -> pd.DataFrame:
        """Fortin transform implementation."""
        # reindex (rather than column-select) so transform-time data that lacks
        # some training batches still yields the full dummy block
        batch_dummies = (
            pd.get_dummies(batch, drop_first=False)
            .astype(float)
            .reindex(columns=self._batch_levels, fill_value=0.0)
        )
        if self.reference_batch is not None:
            batch_dummies.loc[:, self.reference_batch] = 1.0

        parts = [batch_dummies]
        if disc is not None:
            parts.append(pd.get_dummies(disc.astype("category"), drop_first=True).astype(float))
        if cont is not None:
            parts.append(cont.astype(float))

        design = pd.concat(parts, axis=1).values

        X_np = X.values
        stand_mu = self._grand_mean.values + design[:, self._n_batch :] @ self._beta_hat_nonbatch
        Xs = (X_np - stand_mu) / np.sqrt(self._pooled_var.values)

        for i, lvl in enumerate(self._batch_levels):
            idx = batch == lvl
            if not idx.any():
                continue
            if self.reference_batch is not None and lvl == self.reference_batch:
                # leave reference samples unchanged
                continue

            g = self._gamma_star[i]
            d = self._delta_star[i]
            if self.mean_only:
                Xs[idx] = Xs[idx] - g
            else:
                Xs[idx] = (Xs[idx] - g) / np.sqrt(d)

        X_adj = Xs * np.sqrt(self._pooled_var.values) + stand_mu
        return pd.DataFrame(X_adj, index=X.index, columns=X.columns, dtype=float)

    def _transform_chen(
        self,
        X: pd.DataFrame,
        batch: pd.Series,
        disc: pd.DataFrame | None,
        cont: pd.DataFrame | None,
    ) -> pd.DataFrame:
        """Chen transform implementation."""
        X_meanvar_adj = self._transform_fortin(X, batch, disc, cont)
        scores = self._covbat_pca.transform(X_meanvar_adj)
        n_pc = self._covbat_n_pc
        scores_adj = scores.copy()

        for i, lvl in enumerate(self._batch_levels_pc):
            idx = (batch == lvl).values
            if not idx.any():
                continue
            if self.reference_batch is not None and lvl == self.reference_batch:
                continue
            g = self._pc_gamma_star[i]
            d = self._pc_delta_star[i]
            if self.mean_only:
                scores_adj[idx, :n_pc] = scores_adj[idx, :n_pc] - g
            else:
                scores_adj[idx, :n_pc] = (scores_adj[idx, :n_pc] - g) / np.sqrt(d)

        X_recon = self._covbat_pca.inverse_transform(scores_adj)
        return pd.DataFrame(X_recon, index=X.index, columns=X.columns)
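
Editorial aside: with a `reference_batch`, reference samples are left unchanged (exactly for 'johnson'/'fortin', up to PCA round-trip error for 'chen') while the other batches are mapped onto the reference scale. Again using the toy data from the sketches above:

    model_ref = ComBatModel(method="johnson", reference_batch="a").fit(X, batch=batch)
    X_ref = model_ref.transform(X, batch=batch)

    mask = (batch == "a").values
    assert np.allclose(X_ref.values[mask], X.values[mask])  # reference rows untouched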