combatlearn 1.1.2__py3-none-any.whl → 1.2.0__py3-none-any.whl

This diff shows the content of publicly available package versions as released to their public registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in that registry.
@@ -0,0 +1,143 @@
+ """Scikit-learn compatible ComBat wrapper."""
+
+ from __future__ import annotations
+
+ import numpy as np
+ import pandas as pd
+ from sklearn.base import BaseEstimator, TransformerMixin
+
+ from .core import ArrayLike, ComBatModel
+ from .metrics import ComBatMetricsMixin
+ from .visualization import ComBatVisualizationMixin
+
+
+ class ComBat(ComBatMetricsMixin, ComBatVisualizationMixin, BaseEstimator, TransformerMixin):
+     """Pipeline-friendly wrapper around `ComBatModel`.
+
+     Stores the batch labels (and optional covariates) passed at construction and
+     subsets them to match the samples seen by each `fit` and `transform` call.
+
+     Parameters
+     ----------
+     batch : array-like of shape (n_samples,)
+         Batch labels for each sample.
+     discrete_covariates : array-like, optional
+         Categorical covariates to protect (Fortin/Chen only).
+     continuous_covariates : array-like, optional
+         Continuous covariates to protect (Fortin/Chen only).
+     method : {'johnson', 'fortin', 'chen'}, default='johnson'
+         ComBat variant to use.
+     parametric : bool, default=True
+         Use parametric empirical Bayes.
+     mean_only : bool, default=False
+         Adjust only the mean (ignore variance).
+     reference_batch : str, optional
+         Batch level to leave unchanged.
+     eps : float, default=1e-8
+         Numerical jitter for stability.
+     covbat_cov_thresh : float or int, default=0.9
+         CovBat variance threshold for PCs.
+     compute_metrics : bool, default=False
+         If True, ``fit_transform`` caches batch metrics in ``metrics_``.
+     """
+
+     def __init__(
+         self,
+         batch: ArrayLike,
+         *,
+         discrete_covariates: ArrayLike | None = None,
+         continuous_covariates: ArrayLike | None = None,
+         method: str = "johnson",
+         parametric: bool = True,
+         mean_only: bool = False,
+         reference_batch: str | None = None,
+         eps: float = 1e-8,
+         covbat_cov_thresh: float | int = 0.9,
+         compute_metrics: bool = False,
+     ) -> None:
+         self.batch = batch
+         self.discrete_covariates = discrete_covariates
+         self.continuous_covariates = continuous_covariates
+         self.method = method
+         self.parametric = parametric
+         self.mean_only = mean_only
+         self.reference_batch = reference_batch
+         self.eps = eps
+         self.covbat_cov_thresh = covbat_cov_thresh
+         self.compute_metrics = compute_metrics
+         self._model = ComBatModel(
+             method=method,
+             parametric=parametric,
+             mean_only=mean_only,
+             reference_batch=reference_batch,
+             eps=eps,
+             covbat_cov_thresh=covbat_cov_thresh,
+         )
+
+     def fit(self, X: ArrayLike, y: ArrayLike | None = None) -> ComBat:
+         """Fit the ComBat model."""
+         idx = X.index if isinstance(X, pd.DataFrame) else pd.RangeIndex(len(X))
+         batch_vec = self._subset(self.batch, idx)
+         disc = self._subset(self.discrete_covariates, idx)
+         cont = self._subset(self.continuous_covariates, idx)
+         self._model.fit(
+             X,
+             batch=batch_vec,
+             discrete_covariates=disc,
+             continuous_covariates=cont,
+         )
+         self._fitted_batch = batch_vec
+         return self
+
+     def transform(self, X: ArrayLike) -> pd.DataFrame:
+         """Transform the data using fitted ComBat parameters."""
+         idx = X.index if isinstance(X, pd.DataFrame) else pd.RangeIndex(len(X))
+         batch_vec = self._subset(self.batch, idx)
+         disc = self._subset(self.discrete_covariates, idx)
+         cont = self._subset(self.continuous_covariates, idx)
+         return self._model.transform(
+             X,
+             batch=batch_vec,
+             discrete_covariates=disc,
+             continuous_covariates=cont,
+         )
+
+     @staticmethod
+     def _subset(obj: ArrayLike | None, idx: pd.Index) -> pd.DataFrame | pd.Series | None:
+         """Subset array-like object by index."""
+         if obj is None:
+             return None
+         if isinstance(obj, (pd.Series, pd.DataFrame)):
+             return obj.loc[idx]
+         else:
+             if isinstance(obj, np.ndarray) and obj.ndim == 1:
+                 return pd.Series(obj, index=idx)
+             else:
+                 return pd.DataFrame(obj, index=idx)
+
+     def fit_transform(self, X: ArrayLike, y: ArrayLike | None = None) -> pd.DataFrame:
+         """
+         Fit and transform the data, optionally computing metrics.
+
+         If ``compute_metrics=True`` was set at construction, batch effect
+         metrics are computed and cached in the ``metrics_`` property.
+
+         Parameters
+         ----------
+         X : array-like of shape (n_samples, n_features)
+             Input data to fit and transform.
+         y : None
+             Ignored. Present for API compatibility.
+
+         Returns
+         -------
+         X_transformed : pd.DataFrame
+             Batch-corrected data.
+         """
+         self.fit(X, y)
+         X_transformed = self.transform(X)
+
+         if self.compute_metrics:
+             self._metrics_cache = self.compute_batch_metrics(X)
+
+         return X_transformed
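
To make the wrapper's design concrete, here is a minimal usage sketch (not part of the package diff). It assumes the class is importable as `from combatlearn import ComBat` and uses made-up data; all names (`X`, `y`, `batch`, `pipe`) are illustrative. Because the batch labels are passed at construction and subset by `X.index` in `fit`/`transform`, the estimator can sit inside a scikit-learn `Pipeline` and be cross-validated; pass `batch` as a `pandas.Series` aligned to `X.index` so the row subsets created by each fold pick up the matching labels.

    import numpy as np
    import pandas as pd
    from sklearn.linear_model import LogisticRegression
    from sklearn.model_selection import cross_val_score
    from sklearn.pipeline import make_pipeline

    from combatlearn import ComBat  # import path assumed, not shown in this diff

    rng = np.random.default_rng(0)
    X = pd.DataFrame(rng.normal(size=(120, 30)))   # illustrative data
    y = rng.integers(0, 2, size=120)               # illustrative labels
    batch = pd.Series(rng.choice(["site_a", "site_b"], size=120), index=X.index)

    # Batch labels live on the estimator, so fit/transform only need X;
    # _subset() aligns them to whichever rows each CV fold passes in.
    pipe = make_pipeline(ComBat(batch=batch, method="johnson"), LogisticRegression(max_iter=1000))
    scores = cross_val_score(pipe, X, y, cv=5)

    # Standalone use with cached metrics (``metrics_`` comes from
    # ComBatMetricsMixin, which is not part of this diff).
    combat = ComBat(batch=batch, compute_metrics=True)
    X_corrected = combat.fit_transform(X)
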
@@ -0,0 +1,533 @@
+ """Visualization utilities for ComBat batch correction."""
+
+ from __future__ import annotations
+
+ from typing import Any, Literal
+
+ import matplotlib
+ import matplotlib.colors as mcolors
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import pandas as pd
+ import plotly.graph_objects as go
+ import umap
+ from plotly.subplots import make_subplots
+ from sklearn.decomposition import PCA
+ from sklearn.manifold import TSNE
+
+ from .core import ArrayLike, FloatArray
+
+
+ class ComBatVisualizationMixin:
+     """Mixin providing visualization methods for the ComBat wrapper."""
+
+     def plot_transformation(
+         self,
+         X: ArrayLike,
+         *,
+         reduction_method: Literal["pca", "tsne", "umap"] = "pca",
+         n_components: Literal[2, 3] = 2,
+         plot_type: Literal["static", "interactive"] = "static",
+         figsize: tuple[int, int] = (12, 5),
+         alpha: float = 0.7,
+         point_size: int = 50,
+         cmap: str = "Set1",
+         title: str | None = None,
+         show_legend: bool = True,
+         return_embeddings: bool = False,
+         **reduction_kwargs,
+     ) -> Any | tuple[Any, dict[str, FloatArray]]:
+         """
+         Visualize the ComBat transformation effect using dimensionality reduction.
+
+         Shows a before/after comparison of data transformed by `ComBat`, using
+         PCA, t-SNE, or UMAP to reduce dimensions for visualization.
+
+         Parameters
+         ----------
+         X : array-like of shape (n_samples, n_features)
+             Input data to transform and visualize.
+
+         reduction_method : {`'pca'`, `'tsne'`, `'umap'`}, default=`'pca'`
+             Dimensionality reduction method.
+
+         n_components : {2, 3}, default=2
+             Number of components for dimensionality reduction.
+
+         plot_type : {`'static'`, `'interactive'`}, default=`'static'`
+             Visualization type:
+             - `'static'`: matplotlib plots (can be saved as images)
+             - `'interactive'`: plotly plots (explorable, requires plotly)
+
+         return_embeddings : bool, default=False
+             If `True`, return embeddings along with the plot.
+
+         **reduction_kwargs : dict
+             Additional parameters for reduction methods.
+
+         Returns
+         -------
+         fig : matplotlib.figure.Figure or plotly.graph_objects.Figure
+             The figure object containing the plots.
+
+         embeddings : dict, optional
+             If `return_embeddings=True`, dictionary with:
+             - `'original'`: embedding of original data
+             - `'transformed'`: embedding of ComBat-transformed data
+         """
+         if not hasattr(self._model, "_gamma_star"):
+             raise ValueError(
+                 "This ComBat instance is not fitted yet. Call 'fit' before 'plot_transformation'."
+             )
+
+         if n_components not in [2, 3]:
+             raise ValueError(f"n_components must be 2 or 3, got {n_components}")
+         if reduction_method not in ["pca", "tsne", "umap"]:
+             raise ValueError(
+                 f"reduction_method must be 'pca', 'tsne', or 'umap', got '{reduction_method}'"
+             )
+         if plot_type not in ["static", "interactive"]:
+             raise ValueError(f"plot_type must be 'static' or 'interactive', got '{plot_type}'")
+
+         if not isinstance(X, pd.DataFrame):
+             X = pd.DataFrame(X)
+
+         idx = X.index
+         batch_vec = self._subset(self.batch, idx)
+         if batch_vec is None:
+             raise ValueError("Batch information is required for visualization")
+
+         X_transformed = self.transform(X)
+
+         X_np = X.values
+         X_trans_np = X_transformed.values
+
+         if reduction_method == "pca":
+             reducer_orig = PCA(n_components=n_components, **reduction_kwargs)
+             reducer_trans = PCA(n_components=n_components, **reduction_kwargs)
+         elif reduction_method == "tsne":
+             tsne_params = {"perplexity": 30, "max_iter": 1000, "random_state": 42}
+             tsne_params.update(reduction_kwargs)
+             reducer_orig = TSNE(n_components=n_components, **tsne_params)
+             reducer_trans = TSNE(n_components=n_components, **tsne_params)
+         else:
+             umap_params = {"random_state": 42}
+             umap_params.update(reduction_kwargs)
+             reducer_orig = umap.UMAP(n_components=n_components, **umap_params)
+             reducer_trans = umap.UMAP(n_components=n_components, **umap_params)
+
+         X_embedded_orig = reducer_orig.fit_transform(X_np)
+         X_embedded_trans = reducer_trans.fit_transform(X_trans_np)
+
+         if plot_type == "static":
+             fig = self._create_static_plot(
+                 X_embedded_orig,
+                 X_embedded_trans,
+                 batch_vec,
+                 reduction_method,
+                 n_components,
+                 figsize,
+                 alpha,
+                 point_size,
+                 cmap,
+                 title,
+                 show_legend,
+             )
+         else:
+             fig = self._create_interactive_plot(
+                 X_embedded_orig,
+                 X_embedded_trans,
+                 batch_vec,
+                 reduction_method,
+                 n_components,
+                 cmap,
+                 title,
+                 show_legend,
+             )
+
+         if return_embeddings:
+             embeddings = {"original": X_embedded_orig, "transformed": X_embedded_trans}
+             return fig, embeddings
+         else:
+             return fig
+
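
A short, hypothetical example of calling `plot_transformation` on a fitted instance (continuing the sketch after the previous file; the `combat` and `X` names are illustrative, not part of the package):

    # Static before/after panels reduced with PCA; returns a matplotlib Figure.
    fig = combat.plot_transformation(X, reduction_method="pca", plot_type="static")
    fig.savefig("combat_before_after.png", dpi=150)

    # Interactive 3-D view via plotly, also returning the raw embeddings.
    fig3d, emb = combat.plot_transformation(
        X,
        reduction_method="umap",
        n_components=3,
        plot_type="interactive",
        return_embeddings=True,
    )
    fig3d.show()
    print(emb["original"].shape, emb["transformed"].shape)
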
+     def _create_static_plot(
+         self,
+         X_orig: FloatArray,
+         X_trans: FloatArray,
+         batch_labels: pd.Series,
+         method: str,
+         n_components: int,
+         figsize: tuple[int, int],
+         alpha: float,
+         point_size: int,
+         cmap: str,
+         title: str | None,
+         show_legend: bool,
+     ) -> Any:
+         """Create static plots using matplotlib."""
+
+         fig = plt.figure(figsize=figsize)
+
+         unique_batches = batch_labels.drop_duplicates()
+         n_batches = len(unique_batches)
+
+         if n_batches <= 10:
+             colors = matplotlib.colormaps.get_cmap(cmap)(np.linspace(0, 1, n_batches))
+         else:
+             colors = matplotlib.colormaps.get_cmap("tab20")(np.linspace(0, 1, n_batches))
+
+         if n_components == 2:
+             ax1 = plt.subplot(1, 2, 1)
+             ax2 = plt.subplot(1, 2, 2)
+         else:
+             ax1 = fig.add_subplot(121, projection="3d")
+             ax2 = fig.add_subplot(122, projection="3d")
+
+         for i, batch in enumerate(unique_batches):
+             mask = batch_labels == batch
+             if n_components == 2:
+                 ax1.scatter(
+                     X_orig[mask, 0],
+                     X_orig[mask, 1],
+                     c=[colors[i]],
+                     s=point_size,
+                     alpha=alpha,
+                     label=f"Batch {batch}",
+                     edgecolors="black",
+                     linewidth=0.5,
+                 )
+             else:
+                 ax1.scatter(
+                     X_orig[mask, 0],
+                     X_orig[mask, 1],
+                     X_orig[mask, 2],
+                     c=[colors[i]],
+                     s=point_size,
+                     alpha=alpha,
+                     label=f"Batch {batch}",
+                     edgecolors="black",
+                     linewidth=0.5,
+                 )
+
+         ax1.set_title(f"Before ComBat correction\n({method.upper()})")
+         ax1.set_xlabel(f"{method.upper()}1")
+         ax1.set_ylabel(f"{method.upper()}2")
+         if n_components == 3:
+             ax1.set_zlabel(f"{method.upper()}3")
+
+         for i, batch in enumerate(unique_batches):
+             mask = batch_labels == batch
+             if n_components == 2:
+                 ax2.scatter(
+                     X_trans[mask, 0],
+                     X_trans[mask, 1],
+                     c=[colors[i]],
+                     s=point_size,
+                     alpha=alpha,
+                     label=f"Batch {batch}",
+                     edgecolors="black",
+                     linewidth=0.5,
+                 )
+             else:
+                 ax2.scatter(
+                     X_trans[mask, 0],
+                     X_trans[mask, 1],
+                     X_trans[mask, 2],
+                     c=[colors[i]],
+                     s=point_size,
+                     alpha=alpha,
+                     label=f"Batch {batch}",
+                     edgecolors="black",
+                     linewidth=0.5,
+                 )
+
+         ax2.set_title(f"After ComBat correction\n({method.upper()})")
+         ax2.set_xlabel(f"{method.upper()}1")
+         ax2.set_ylabel(f"{method.upper()}2")
+         if n_components == 3:
+             ax2.set_zlabel(f"{method.upper()}3")
+
+         if show_legend and n_batches <= 20:
+             ax2.legend(bbox_to_anchor=(1.05, 1), loc="upper left")
+
+         if title is None:
+             title = f"ComBat correction effect visualized with {method.upper()}"
+         fig.suptitle(title, fontsize=14, fontweight="bold")
+
+         plt.tight_layout()
+         return fig
+
+     def _create_interactive_plot(
+         self,
+         X_orig: FloatArray,
+         X_trans: FloatArray,
+         batch_labels: pd.Series,
+         method: str,
+         n_components: int,
+         cmap: str,
+         title: str | None,
+         show_legend: bool,
+     ) -> Any:
+         """Create interactive plots using plotly."""
+         if n_components == 2:
+             fig = make_subplots(
+                 rows=1,
+                 cols=2,
+                 subplot_titles=(
+                     f"Before ComBat correction ({method.upper()})",
+                     f"After ComBat correction ({method.upper()})",
+                 ),
+             )
+         else:
+             fig = make_subplots(
+                 rows=1,
+                 cols=2,
+                 specs=[[{"type": "scatter3d"}, {"type": "scatter3d"}]],
+                 subplot_titles=(
+                     f"Before ComBat correction ({method.upper()})",
+                     f"After ComBat correction ({method.upper()})",
+                 ),
+             )
+
+         unique_batches = batch_labels.drop_duplicates()
+
+         n_batches = len(unique_batches)
+         cmap_func = matplotlib.colormaps.get_cmap(cmap)
+         color_list = [
+             mcolors.to_hex(cmap_func(i / max(n_batches - 1, 1))) for i in range(n_batches)
+         ]
+
+         batch_to_color = dict(zip(unique_batches, color_list, strict=True))
+
+         for batch in unique_batches:
+             mask = batch_labels == batch
+
+             if n_components == 2:
+                 fig.add_trace(
+                     go.Scatter(
+                         x=X_orig[mask, 0],
+                         y=X_orig[mask, 1],
+                         mode="markers",
+                         name=f"Batch {batch}",
+                         marker={
+                             "size": 8,
+                             "color": batch_to_color[batch],
+                             "line": {"width": 1, "color": "black"},
+                         },
+                         showlegend=False,
+                     ),
+                     row=1,
+                     col=1,
+                 )
+
+                 fig.add_trace(
+                     go.Scatter(
+                         x=X_trans[mask, 0],
+                         y=X_trans[mask, 1],
+                         mode="markers",
+                         name=f"Batch {batch}",
+                         marker={
+                             "size": 8,
+                             "color": batch_to_color[batch],
+                             "line": {"width": 1, "color": "black"},
+                         },
+                         showlegend=show_legend,
+                     ),
+                     row=1,
+                     col=2,
+                 )
+             else:
+                 fig.add_trace(
+                     go.Scatter3d(
+                         x=X_orig[mask, 0],
+                         y=X_orig[mask, 1],
+                         z=X_orig[mask, 2],
+                         mode="markers",
+                         name=f"Batch {batch}",
+                         marker={
+                             "size": 5,
+                             "color": batch_to_color[batch],
+                             "line": {"width": 0.5, "color": "black"},
+                         },
+                         showlegend=False,
+                     ),
+                     row=1,
+                     col=1,
+                 )
+
+                 fig.add_trace(
+                     go.Scatter3d(
+                         x=X_trans[mask, 0],
+                         y=X_trans[mask, 1],
+                         z=X_trans[mask, 2],
+                         mode="markers",
+                         name=f"Batch {batch}",
+                         marker={
+                             "size": 5,
+                             "color": batch_to_color[batch],
+                             "line": {"width": 0.5, "color": "black"},
+                         },
+                         showlegend=show_legend,
+                     ),
+                     row=1,
+                     col=2,
+                 )
+
+         if title is None:
+             title = f"ComBat correction effect visualized with {method.upper()}"
+
+         fig.update_layout(
+             title=title,
+             title_font_size=16,
+             height=600,
+             showlegend=show_legend,
+             hovermode="closest",
+         )
+
+         axis_labels = [f"{method.upper()}{i + 1}" for i in range(n_components)]
+
+         if n_components == 2:
+             fig.update_xaxes(title_text=axis_labels[0])
+             fig.update_yaxes(title_text=axis_labels[1])
+         else:
+             fig.update_scenes(
+                 xaxis_title=axis_labels[0],
+                 yaxis_title=axis_labels[1],
+                 zaxis_title=axis_labels[2],
+             )
+
+         return fig
+
+     def plot_feature_importance(
+         self,
+         top_n: int = 20,
+         kind: Literal["location", "scale", "combined"] = "combined",
+         mode: Literal["magnitude", "distribution"] = "magnitude",
+         figsize: tuple[int, int] = (8, 10),
+     ) -> Any:
+         """Plot top features affected by batch effects.
+
+         Parameters
+         ----------
+         top_n : int, default=20
+             Number of top features to display.
+         kind : {'location', 'scale', 'combined'}, default='combined'
+             - 'location': bar plot of location (mean shift) contribution only
+             - 'scale': bar plot of scale (variance) contribution only
+             - 'combined': grouped bar plot showing location and scale
+               side-by-side for each feature (sorted by Euclidean magnitude).
+               In magnitude mode: bars reflect Euclidean decomposition
+               (combined**2 = location**2 + scale**2).
+               In distribution mode: bars reflect independent normalized
+               contributions (each sums to 1 separately).
+         mode : {'magnitude', 'distribution'}, default='magnitude'
+             - 'magnitude': bars show absolute batch effect magnitude
+             - 'distribution': bars show relative contribution (proportion); the
+               cumulative contribution of the top_n features is also printed
+               (e.g., "Top 20 features explain 75% of total batch effect")
+         figsize : tuple, default=(8, 10)
+             Figure size (width, height) in inches.
+
+         Returns
+         -------
+         matplotlib.figure.Figure
+             The figure object containing the plot.
+
+         Raises
+         ------
+         ValueError
+             If the model is not fitted, or if kind/mode is invalid.
+         """
+         if not hasattr(self._model, "_gamma_star"):
+             raise ValueError(
+                 "This ComBat instance is not fitted yet. "
+                 "Call 'fit' before 'plot_feature_importance'."
+             )
+
+         if kind not in ["location", "scale", "combined"]:
+             raise ValueError(f"kind must be 'location', 'scale', or 'combined', got '{kind}'")
+
+         if mode not in ["magnitude", "distribution"]:
+             raise ValueError(f"mode must be 'magnitude' or 'distribution', got '{mode}'")
+
+         importance_df = self.feature_batch_importance(mode=mode)
+         top_features = importance_df.head(top_n)
+
+         # Reverse so highest values are at the top of the horizontal bar plot
+         top_features = top_features.iloc[::-1]
+
+         fig, ax = plt.subplots(figsize=figsize)
+
+         if kind == "combined":
+             # Grouped horizontal bar plot showing location and scale side-by-side
+             y = np.arange(len(top_features))
+             height = 0.35
+
+             ax.barh(
+                 y + height / 2,
+                 top_features["location"],
+                 height,
+                 label="Location",
+                 color="steelblue",
+                 edgecolor="black",
+                 linewidth=0.5,
+             )
+             ax.barh(
+                 y - height / 2,
+                 top_features["scale"],
+                 height,
+                 label="Scale",
+                 color="coral",
+                 edgecolor="black",
+                 linewidth=0.5,
+             )
+
+             ax.set_yticks(y)
+             ax.set_yticklabels(top_features.index)
+             ax.legend()
+         else:
+             # Single horizontal bar plot for location or scale
+             color = "steelblue" if kind == "location" else "coral"
+             ax.barh(
+                 range(len(top_features)),
+                 top_features[kind],
+                 color=color,
+                 edgecolor="black",
+                 linewidth=0.5,
+             )
+             ax.set_yticks(range(len(top_features)))
+             ax.set_yticklabels(top_features.index)
+
+         # Set labels and title
+         ax.set_ylabel("Feature")
+         if mode == "magnitude":
+             ax.set_xlabel("Batch Effect Magnitude (RMS)")
+             title = f"Top {top_n} Features by Batch Effect"
+         else:
+             ax.set_xlabel("Relative Contribution")
+             title = f"Top {top_n} Features by Batch Effect (Distribution)"
+
+         if kind == "combined":
+             title += " (Location & Scale)"
+         else:
+             title += f" ({kind.capitalize()})"
+
+         ax.set_title(title)
+         plt.tight_layout()
+
+         # For distribution mode, print cumulative contribution
+         if mode == "distribution":
+             if kind == "combined":
+                 cumulative_pct = top_features["combined"].sum() * 100
+                 effect_label = "batch effect"
+             elif kind == "location":
+                 cumulative_pct = top_features["location"].sum() * 100
+                 effect_label = "location effect"
+             else:  # scale
+                 cumulative_pct = top_features["scale"].sum() * 100
+                 effect_label = "scale effect"
+
+             print(f"Top {top_n} features explain {cumulative_pct:.1f}% of total {effect_label}")
+
+         return fig
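
And a corresponding sketch for `plot_feature_importance` (again assuming the fitted `combat` instance from the earlier sketch; `feature_batch_importance` is provided by `ComBatMetricsMixin`, which is outside this diff):

    # Grouped location/scale bars for the 15 most affected features.
    fig = combat.plot_feature_importance(top_n=15, kind="combined", mode="magnitude")
    fig.savefig("batch_effect_features.png", dpi=150)

    # Distribution mode plots relative contributions and prints the cumulative
    # share of batch effect explained by the top features.
    combat.plot_feature_importance(top_n=15, kind="location", mode="distribution")
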