combatlearn 1.1.1__py3-none-any.whl → 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
combatlearn/__init__.py CHANGED
@@ -1,5 +1,5 @@
- from .combat import ComBat
+ from .sklearn_api import ComBat
 
  __all__ = ["ComBat"]
- __version__ = "1.1.1"
+ __version__ = "1.2.0"
  __author__ = "Ettore Rocchi"
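
Note that the re-export above keeps the public import path stable: downstream code using from combatlearn import ComBat is unaffected, since only the defining module moved from combat to sklearn_api.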
combatlearn/core.py ADDED
@@ -0,0 +1,578 @@
+ """ComBat algorithm core.
+
+ `ComBatModel` implements three variants of the ComBat algorithm:
+ * Johnson et al. (2007) vanilla ComBat (method="johnson")
+ * Fortin et al. (2018) extension with covariates (method="fortin")
+ * Chen et al. (2022) CovBat (method="chen")
+ """
+
+ from __future__ import annotations
+
+ import warnings
+ from typing import Any, Literal
+
+ import numpy as np
+ import numpy.linalg as la
+ import numpy.typing as npt
+ import pandas as pd
+ from sklearn.decomposition import PCA
+
+ ArrayLike = pd.DataFrame | pd.Series | npt.NDArray[Any]
+ FloatArray = npt.NDArray[np.float64]
+
+
+ class ComBatModel:
+     """ComBat algorithm.
+
+     Parameters
+     ----------
+     method : {'johnson', 'fortin', 'chen'}, default='johnson'
+         * 'johnson' - classic ComBat.
+         * 'fortin' - covariate-aware ComBat.
+         * 'chen' - CovBat, PCA-based ComBat.
+     parametric : bool, default=True
+         Use the parametric empirical Bayes variant.
+     mean_only : bool, default=False
+         If True, only the mean is adjusted (`gamma_star`),
+         ignoring the variance (`delta_star`).
+     reference_batch : str, optional
+         If specified, the batch level to use as reference.
+     covbat_cov_thresh : float or int, default=0.9
+         CovBat: cumulative variance threshold (0, 1] to retain PCs, or
+         integer >= 1 specifying the number of components directly.
+     eps : float, default=1e-8
+         Numerical jitter to avoid division-by-zero.
+     """
+
+     def __init__(
+         self,
+         *,
+         method: Literal["johnson", "fortin", "chen"] = "johnson",
+         parametric: bool = True,
+         mean_only: bool = False,
+         reference_batch: str | None = None,
+         eps: float = 1e-8,
+         covbat_cov_thresh: float | int = 0.9,
+     ) -> None:
+         self.method: str = method
+         self.parametric: bool = parametric
+         self.mean_only: bool = bool(mean_only)
+         self.reference_batch: str | None = reference_batch
+         self.eps: float = float(eps)
+         self.covbat_cov_thresh: float | int = covbat_cov_thresh
+
+         self._batch_levels: pd.Index
+         self._grand_mean: pd.Series
+         self._pooled_var: pd.Series
+         self._gamma_star: FloatArray
+         self._delta_star: FloatArray
+         self._n_per_batch: dict[str, int]
+         self._reference_batch_idx: int | None
+         self._beta_hat_nonbatch: FloatArray
+         self._n_batch: int
+         self._p_design: int
+         self._covbat_pca: PCA
+         self._covbat_n_pc: int
+         self._batch_levels_pc: pd.Index
+         self._pc_gamma_star: FloatArray
+         self._pc_delta_star: FloatArray
+
+         # Validate covbat_cov_thresh
+         if isinstance(self.covbat_cov_thresh, float):
+             if not (0.0 < self.covbat_cov_thresh <= 1.0):
+                 raise ValueError("covbat_cov_thresh must be in (0, 1] when float.")
+         elif isinstance(self.covbat_cov_thresh, int):
+             if self.covbat_cov_thresh < 1:
+                 raise ValueError("covbat_cov_thresh must be >= 1 when int.")
+         else:
+             raise TypeError("covbat_cov_thresh must be float or int.")
+
+     @staticmethod
+     def _as_series(arr: ArrayLike, index: pd.Index, name: str) -> pd.Series:
+         """Convert array-like to categorical Series with validation."""
+         ser = arr.copy() if isinstance(arr, pd.Series) else pd.Series(arr, index=index, name=name)
+         if not ser.index.equals(index):
+             raise ValueError(f"`{name}` index mismatch with `X`.")
+         return ser.astype("category")
+
+     @staticmethod
+     def _to_df(arr: ArrayLike | None, index: pd.Index, name: str) -> pd.DataFrame | None:
+         """Convert array-like to DataFrame."""
+         if arr is None:
+             return None
+         if isinstance(arr, pd.Series):
+             arr = arr.to_frame()
+         if not isinstance(arr, pd.DataFrame):
+             arr = pd.DataFrame(arr, index=index)
+         if not arr.index.equals(index):
+             raise ValueError(f"`{name}` index mismatch with `X`.")
+         return arr
+
+     def fit(
+         self,
+         X: ArrayLike,
+         y: ArrayLike | None = None,
+         *,
+         batch: ArrayLike,
+         discrete_covariates: ArrayLike | None = None,
+         continuous_covariates: ArrayLike | None = None,
+     ) -> ComBatModel:
+         """Fit the ComBat model."""
+         method = self.method.lower()
+         if method not in {"johnson", "fortin", "chen"}:
+             raise ValueError("method must be 'johnson', 'fortin', or 'chen'.")
+         if not isinstance(X, pd.DataFrame):
+             X = pd.DataFrame(X)
+         idx = X.index
+         batch = self._as_series(batch, idx, "batch")
+
+         disc = self._to_df(discrete_covariates, idx, "discrete_covariates")
+         cont = self._to_df(continuous_covariates, idx, "continuous_covariates")
+
+         if self.reference_batch is not None and self.reference_batch not in batch.cat.categories:
+             raise ValueError(
+                 f"reference_batch={self.reference_batch!r} not present in the data batches."
+                 f"{list(batch.cat.categories)}"
+             )
+
+         if method == "johnson":
+             if disc is not None or cont is not None:
+                 warnings.warn("Covariates are ignored when using method='johnson'.", stacklevel=2)
+             self._fit_johnson(X, batch)
+         elif method == "fortin":
+             self._fit_fortin(X, batch, disc, cont)
+         elif method == "chen":
+             self._fit_chen(X, batch, disc, cont)
+         return self
+
+     def _fit_johnson(self, X: pd.DataFrame, batch: pd.Series) -> None:
+         """Johnson et al. (2007) ComBat."""
+         self._batch_levels = batch.cat.categories
+         pooled_var = X.var(axis=0, ddof=1) + self.eps
+         grand_mean = X.mean(axis=0)
+
+         Xs = (X - grand_mean) / np.sqrt(pooled_var)
+
+         n_per_batch: dict[str, int] = {}
+         gamma_hat: list[npt.NDArray[np.float64]] = []
+         delta_hat: list[npt.NDArray[np.float64]] = []
+
+         for lvl in self._batch_levels:
+             idx = batch == lvl
+             n_b = int(idx.sum())
+             if n_b < 2:
+                 raise ValueError(f"Batch '{lvl}' has <2 samples.")
+             n_per_batch[str(lvl)] = n_b
+             xb = Xs.loc[idx]
+             gamma_hat.append(xb.mean(axis=0).values)
+             delta_hat.append(xb.var(axis=0, ddof=1).values + self.eps)
+
+         gamma_hat_arr = np.vstack(gamma_hat)
+         delta_hat_arr = np.vstack(delta_hat)
+
+         if self.mean_only:
+             gamma_star = self._shrink_gamma(
+                 gamma_hat_arr, delta_hat_arr, n_per_batch, parametric=self.parametric
+             )
+             delta_star = np.ones_like(delta_hat_arr)
+         else:
+             gamma_star, delta_star = self._shrink_gamma_delta(
+                 gamma_hat_arr, delta_hat_arr, n_per_batch, parametric=self.parametric
+             )
+
+         if self.reference_batch is not None:
+             ref_idx = list(self._batch_levels).index(self.reference_batch)
+             gamma_ref = gamma_star[ref_idx]
+             delta_ref = delta_star[ref_idx]
+             gamma_star = gamma_star - gamma_ref
+             if not self.mean_only:
+                 delta_star = delta_star / delta_ref
+             self._reference_batch_idx = ref_idx
+         else:
+             self._reference_batch_idx = None
+
+         self._grand_mean = grand_mean
+         self._pooled_var = pooled_var
+         self._gamma_star = gamma_star
+         self._delta_star = delta_star
+         self._n_per_batch = n_per_batch
+
+     def _fit_fortin(
+         self,
+         X: pd.DataFrame,
+         batch: pd.Series,
+         disc: pd.DataFrame | None,
+         cont: pd.DataFrame | None,
+     ) -> None:
+         """Fortin et al. (2018) neuroComBat."""
+         self._batch_levels = batch.cat.categories
+         n_batch = len(self._batch_levels)
+         n_samples = len(X)
+
+         batch_dummies = pd.get_dummies(batch, drop_first=False).astype(float)
+         if self.reference_batch is not None:
+             if self.reference_batch not in self._batch_levels:
+                 raise ValueError(
+                     f"reference_batch={self.reference_batch!r} not present in batches."
+                     f"{list(self._batch_levels)}"
+                 )
+             batch_dummies.loc[:, self.reference_batch] = 1.0
+
+         parts: list[pd.DataFrame] = [batch_dummies]
+         if disc is not None:
+             parts.append(pd.get_dummies(disc.astype("category"), drop_first=True).astype(float))
+
+         if cont is not None:
+             parts.append(cont.astype(float))
+
+         design = pd.concat(parts, axis=1).values
+         p_design = design.shape[1]
+
+         X_np = X.values
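+         # Ordinary least squares on the full design (batch dummies plus
+         # covariates); column g of beta_hat holds the coefficients for
+         # feature g of X.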
+         beta_hat = la.lstsq(design, X_np, rcond=None)[0]
+
+         beta_hat_batch = beta_hat[:n_batch]
+         self._beta_hat_nonbatch = beta_hat[n_batch:]
+
+         n_per_batch = batch.value_counts().sort_index().astype(int).values
+         self._n_per_batch = dict(zip(self._batch_levels, n_per_batch, strict=True))
+
+         if self.reference_batch is not None:
+             ref_idx = list(self._batch_levels).index(self.reference_batch)
+             grand_mean = beta_hat_batch[ref_idx]
+         else:
+             grand_mean = (n_per_batch / n_samples) @ beta_hat_batch
+             ref_idx = None
+
+         self._grand_mean = pd.Series(grand_mean, index=X.columns)
+
+         if self.reference_batch is not None:
+             ref_mask = (batch == self.reference_batch).values
+             resid = X_np[ref_mask] - design[ref_mask] @ beta_hat
+             denom = int(ref_mask.sum())
+         else:
+             resid = X_np - design @ beta_hat
+             denom = n_samples
+         var_pooled = (resid**2).sum(axis=0) / denom + self.eps
+         self._pooled_var = pd.Series(var_pooled, index=X.columns)
+
+         stand_mean = grand_mean + design[:, n_batch:] @ self._beta_hat_nonbatch
+         Xs = (X_np - stand_mean) / np.sqrt(var_pooled)
+
+         gamma_hat = np.vstack([Xs[batch == lvl].mean(axis=0) for lvl in self._batch_levels])
+         delta_hat = np.vstack(
+             [Xs[batch == lvl].var(axis=0, ddof=1) + self.eps for lvl in self._batch_levels]
+         )
+
+         if self.mean_only:
+             gamma_star = self._shrink_gamma(
+                 gamma_hat, delta_hat, n_per_batch, parametric=self.parametric
+             )
+             delta_star = np.ones_like(delta_hat)
+         else:
+             gamma_star, delta_star = self._shrink_gamma_delta(
+                 gamma_hat, delta_hat, n_per_batch, parametric=self.parametric
+             )
+
+         if ref_idx is not None:
+             gamma_star[ref_idx] = 0.0
+             if not self.mean_only:
+                 delta_star[ref_idx] = 1.0
+         self._reference_batch_idx = ref_idx
+
+         self._gamma_star = gamma_star
+         self._delta_star = delta_star
+         self._n_batch = n_batch
+         self._p_design = p_design
+
+     def _fit_chen(
+         self,
+         X: pd.DataFrame,
+         batch: pd.Series,
+         disc: pd.DataFrame | None,
+         cont: pd.DataFrame | None,
+     ) -> None:
+         """Chen et al. (2022) CovBat."""
+         self._fit_fortin(X, batch, disc, cont)
+         X_meanvar_adj = self._transform_fortin(X, batch, disc, cont)
+         pca = PCA(svd_solver="full", whiten=False).fit(X_meanvar_adj)
+
+         # Determine number of components based on threshold type
+         if isinstance(self.covbat_cov_thresh, int):
+             n_pc = min(self.covbat_cov_thresh, len(pca.explained_variance_ratio_))
+         else:
+             cumulative = np.cumsum(pca.explained_variance_ratio_)
+             n_pc = int(np.searchsorted(cumulative, self.covbat_cov_thresh) + 1)
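+             # searchsorted gives the first index whose cumulative explained
+             # variance reaches the threshold, so n_pc is the smallest number
+             # of components satisfying covbat_cov_thresh.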
+
+         self._covbat_pca = pca
+         self._covbat_n_pc = n_pc
+
+         scores = pca.transform(X_meanvar_adj)[:, :n_pc]
+         scores_df = pd.DataFrame(scores, index=X.index, columns=[f"PC{i + 1}" for i in range(n_pc)])
+         self._batch_levels_pc = self._batch_levels
+         n_per_batch = self._n_per_batch
+
+         gamma_hat: list[npt.NDArray[np.float64]] = []
+         delta_hat: list[npt.NDArray[np.float64]] = []
+         for lvl in self._batch_levels_pc:
+             idx = batch == lvl
+             xb = scores_df.loc[idx]
+             gamma_hat.append(xb.mean(axis=0).values)
+             delta_hat.append(xb.var(axis=0, ddof=1).values + self.eps)
+         gamma_hat_arr = np.vstack(gamma_hat)
+         delta_hat_arr = np.vstack(delta_hat)
+
+         if self.mean_only:
+             gamma_star = self._shrink_gamma(
+                 gamma_hat_arr, delta_hat_arr, n_per_batch, parametric=self.parametric
+             )
+             delta_star = np.ones_like(delta_hat_arr)
+         else:
+             gamma_star, delta_star = self._shrink_gamma_delta(
+                 gamma_hat_arr, delta_hat_arr, n_per_batch, parametric=self.parametric
+             )
+
+         if self.reference_batch is not None:
+             ref_idx = list(self._batch_levels_pc).index(self.reference_batch)
+             gamma_ref = gamma_star[ref_idx]
+             delta_ref = delta_star[ref_idx]
+             gamma_star = gamma_star - gamma_ref
+             if not self.mean_only:
+                 delta_star = delta_star / delta_ref
+
+         self._pc_gamma_star = gamma_star
+         self._pc_delta_star = delta_star
+
+     def _shrink_gamma_delta(
+         self,
+         gamma_hat: FloatArray,
+         delta_hat: FloatArray,
+         n_per_batch: dict[str, int] | FloatArray,
+         *,
+         parametric: bool,
+         max_iter: int = 100,
+         tol: float = 1e-4,
+     ) -> tuple[FloatArray, FloatArray]:
+         """Empirical Bayes shrinkage estimation."""
+         if parametric:
+             gamma_bar = gamma_hat.mean(axis=0)
+             t2 = gamma_hat.var(axis=0, ddof=1)
+             a_prior = (delta_hat.mean(axis=0) ** 2) / delta_hat.var(axis=0, ddof=1) + 2
+             b_prior = delta_hat.mean(axis=0) * (a_prior - 1)
+
+             B, _p = gamma_hat.shape
+             gamma_star = np.empty_like(gamma_hat)
+             delta_star = np.empty_like(delta_hat)
+             n_vec = (
+                 np.array(list(n_per_batch.values()))
+                 if isinstance(n_per_batch, dict)
+                 else n_per_batch
+             )
+
+             for i in range(B):
+                 n_i = n_vec[i]
+                 g, d = gamma_hat[i], delta_hat[i]
+                 gamma_post_var = 1.0 / (n_i / d + 1.0 / t2)
+                 gamma_star[i] = gamma_post_var * (n_i * g / d + gamma_bar / t2)
+
+                 a_post = a_prior + n_i / 2.0
+                 b_post = b_prior + 0.5 * n_i * d
+                 delta_star[i] = b_post / (a_post - 1)
+             return gamma_star, delta_star
+
+         else:
+             B, _p = gamma_hat.shape
+             n_vec = (
+                 np.array(list(n_per_batch.values()))
+                 if isinstance(n_per_batch, dict)
+                 else n_per_batch
+             )
+             gamma_bar = gamma_hat.mean(axis=0)
+             t2 = gamma_hat.var(axis=0, ddof=1)
+
+             def postmean(
+                 g_hat: FloatArray,
+                 g_bar: FloatArray,
+                 n: float,
+                 d_star: FloatArray,
+                 t2_: FloatArray,
+             ) -> FloatArray:
+                 return (t2_ * n * g_hat + d_star * g_bar) / (t2_ * n + d_star)
+
+             def postvar(sum2: FloatArray, n: float, a: FloatArray, b: FloatArray) -> FloatArray:
+                 return (0.5 * sum2 + b) / (n / 2.0 + a - 1.0)
+
+             def aprior(delta: FloatArray) -> FloatArray:
+                 m, s2 = delta.mean(), delta.var()
+                 s2 = max(s2, self.eps)
+                 return (2 * s2 + m**2) / s2
+
+             def bprior(delta: FloatArray) -> FloatArray:
+                 m, s2 = delta.mean(), delta.var()
+                 s2 = max(s2, self.eps)
+                 return (m * s2 + m**3) / s2
+
+             gamma_star = np.empty_like(gamma_hat)
+             delta_star = np.empty_like(delta_hat)
+
+             for i in range(B):
+                 n_i = n_vec[i]
+                 g_hat_i = gamma_hat[i]
+                 d_hat_i = delta_hat[i]
+                 a_i = aprior(d_hat_i)
+                 b_i = bprior(d_hat_i)
+
+                 g_new, d_new = g_hat_i.copy(), d_hat_i.copy()
+                 for _ in range(max_iter):
+                     g_prev, d_prev = g_new, d_new
+                     g_new = postmean(g_hat_i, gamma_bar, n_i, d_prev, t2)
+                     sum2 = (n_i - 1) * d_hat_i + n_i * (g_hat_i - g_new) ** 2
+                     d_new = postvar(sum2, n_i, a_i, b_i)
+                     if np.max(np.abs(g_new - g_prev) / (np.abs(g_prev) + self.eps)) < tol and (
+                         self.mean_only
+                         or np.max(np.abs(d_new - d_prev) / (np.abs(d_prev) + self.eps)) < tol
+                     ):
+                         break
+                 gamma_star[i] = g_new
+                 delta_star[i] = 1.0 if self.mean_only else d_new
+             return gamma_star, delta_star
+
+     def _shrink_gamma(
+         self,
+         gamma_hat: FloatArray,
+         delta_hat: FloatArray,
+         n_per_batch: dict[str, int] | FloatArray,
+         *,
+         parametric: bool,
+     ) -> FloatArray:
+         """Convenience wrapper that returns only gamma* (for *mean-only* mode)."""
+         gamma, _ = self._shrink_gamma_delta(
+             gamma_hat, delta_hat, n_per_batch, parametric=parametric
+         )
+         return gamma
+
+     def transform(
+         self,
+         X: ArrayLike,
+         *,
+         batch: ArrayLike,
+         discrete_covariates: ArrayLike | None = None,
+         continuous_covariates: ArrayLike | None = None,
+     ) -> pd.DataFrame:
+         """Transform the data using fitted ComBat parameters."""
+         if not hasattr(self, "_gamma_star"):
+             raise ValueError(
+                 "This ComBatModel instance is not fitted yet. Call 'fit' before 'transform'."
+             )
+         if not isinstance(X, pd.DataFrame):
+             X = pd.DataFrame(X)
+         idx = X.index
+         batch = self._as_series(batch, idx, "batch")
+         unseen = set(batch.cat.categories) - set(self._batch_levels)
+         if unseen:
+             raise ValueError(f"Unseen batch levels during transform: {unseen}.")
+         disc = self._to_df(discrete_covariates, idx, "discrete_covariates")
+         cont = self._to_df(continuous_covariates, idx, "continuous_covariates")
+
+         method = self.method.lower()
+         if method == "johnson":
+             return self._transform_johnson(X, batch)
+         elif method == "fortin":
+             return self._transform_fortin(X, batch, disc, cont)
+         elif method == "chen":
+             return self._transform_chen(X, batch, disc, cont)
+         else:
+             raise ValueError(f"Unknown method: {method}.")
+
+     def _transform_johnson(self, X: pd.DataFrame, batch: pd.Series) -> pd.DataFrame:
+         """Johnson transform implementation."""
+         pooled = self._pooled_var
+         grand = self._grand_mean
+
+         Xs = (X - grand) / np.sqrt(pooled)
+         X_adj = pd.DataFrame(index=X.index, columns=X.columns, dtype=float)
+
+         for i, lvl in enumerate(self._batch_levels):
+             idx = batch == lvl
+             if not idx.any():
+                 continue
+             if self.reference_batch is not None and lvl == self.reference_batch:
+                 X_adj.loc[idx] = X.loc[idx].values
+                 continue
+
+             g = self._gamma_star[i]
+             d = self._delta_star[i]
+             Xb = Xs.loc[idx] - g if self.mean_only else (Xs.loc[idx] - g) / np.sqrt(d)
+             X_adj.loc[idx] = (Xb * np.sqrt(pooled) + grand).values
+         return X_adj
+
+     def _transform_fortin(
+         self,
+         X: pd.DataFrame,
+         batch: pd.Series,
+         disc: pd.DataFrame | None,
+         cont: pd.DataFrame | None,
+     ) -> pd.DataFrame:
+         """Fortin transform implementation."""
+         batch_dummies = pd.get_dummies(batch, drop_first=False).astype(float)[self._batch_levels]
+         if self.reference_batch is not None:
+             batch_dummies.loc[:, self.reference_batch] = 1.0
+
+         parts = [batch_dummies]
+         if disc is not None:
+             parts.append(pd.get_dummies(disc.astype("category"), drop_first=True).astype(float))
+         if cont is not None:
+             parts.append(cont.astype(float))
+
+         design = pd.concat(parts, axis=1).values
+
+         X_np = X.values
+         stand_mu = self._grand_mean.values + design[:, self._n_batch :] @ self._beta_hat_nonbatch
+         Xs = (X_np - stand_mu) / np.sqrt(self._pooled_var.values)
+
+         for i, lvl in enumerate(self._batch_levels):
+             idx = batch == lvl
+             if not idx.any():
+                 continue
+             if self.reference_batch is not None and lvl == self.reference_batch:
+                 # leave reference samples unchanged
+                 continue
+
+             g = self._gamma_star[i]
+             d = self._delta_star[i]
+             if self.mean_only:
+                 Xs[idx] = Xs[idx] - g
+             else:
+                 Xs[idx] = (Xs[idx] - g) / np.sqrt(d)
+
+         X_adj = Xs * np.sqrt(self._pooled_var.values) + stand_mu
+         return pd.DataFrame(X_adj, index=X.index, columns=X.columns, dtype=float)
+
+     def _transform_chen(
+         self,
+         X: pd.DataFrame,
+         batch: pd.Series,
+         disc: pd.DataFrame | None,
+         cont: pd.DataFrame | None,
+     ) -> pd.DataFrame:
+         """Chen transform implementation."""
+         X_meanvar_adj = self._transform_fortin(X, batch, disc, cont)
+         scores = self._covbat_pca.transform(X_meanvar_adj)
+         n_pc = self._covbat_n_pc
+         scores_adj = scores.copy()
+
+         for i, lvl in enumerate(self._batch_levels_pc):
+             idx = batch == lvl
+             if not idx.any():
+                 continue
+             if self.reference_batch is not None and lvl == self.reference_batch:
+                 continue
+             g = self._pc_gamma_star[i]
+             d = self._pc_delta_star[i]
+             if self.mean_only:
+                 scores_adj[idx, :n_pc] = scores_adj[idx, :n_pc] - g
+             else:
+                 scores_adj[idx, :n_pc] = (scores_adj[idx, :n_pc] - g) / np.sqrt(d)
+
+         X_recon = self._covbat_pca.inverse_transform(scores_adj)
+         return pd.DataFrame(X_recon, index=X.index, columns=X.columns)
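
For orientation, here is a minimal usage sketch of the new core module. The synthetic data and variable names (rng, X, batch, age) are illustrative assumptions, not part of the package, and importing ComBatModel from combatlearn.core directly is shown only to exercise the API added in this release; the re-exported combatlearn.ComBat remains the public entry point.

import numpy as np
import pandas as pd

from combatlearn.core import ComBatModel

# Synthetic example: 60 samples, 5 features, three acquisition batches.
rng = np.random.default_rng(0)
X = pd.DataFrame(rng.normal(size=(60, 5)), columns=[f"f{i}" for i in range(5)])
batch = pd.Series(["a"] * 20 + ["b"] * 20 + ["c"] * 20, index=X.index, name="batch")
age = pd.Series(rng.uniform(20, 80, size=60), index=X.index, name="age")

# Covariate-aware ComBat (method="fortin"): age enters the design matrix,
# so its effect on the features is preserved by the harmonization.
model = ComBatModel(method="fortin", parametric=True)
model.fit(X, batch=batch, continuous_covariates=age)
X_harmonized = model.transform(X, batch=batch, continuous_covariates=age)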