combatlearn 0.2.2__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- combatlearn/__init__.py +4 -3
- combatlearn/combat.py +40 -41
- {combatlearn-0.2.2.dist-info → combatlearn-1.0.0.dist-info}/METADATA +24 -3
- combatlearn-1.0.0.dist-info/RECORD +7 -0
- combatlearn-0.2.2.dist-info/RECORD +0 -7
- {combatlearn-0.2.2.dist-info → combatlearn-1.0.0.dist-info}/WHEEL +0 -0
- {combatlearn-0.2.2.dist-info → combatlearn-1.0.0.dist-info}/licenses/LICENSE +0 -0
- {combatlearn-0.2.2.dist-info → combatlearn-1.0.0.dist-info}/top_level.txt +0 -0
combatlearn/__init__.py
CHANGED
combatlearn/combat.py
CHANGED
|
@@ -14,29 +14,17 @@ import numpy as np
|
|
|
14
14
|
import numpy.linalg as la
|
|
15
15
|
import pandas as pd
|
|
16
16
|
from sklearn.base import BaseEstimator, TransformerMixin
|
|
17
|
-
from sklearn.utils.validation import check_is_fitted
|
|
18
17
|
from sklearn.decomposition import PCA
|
|
19
18
|
from sklearn.manifold import TSNE
|
|
19
|
+
import matplotlib
|
|
20
20
|
import matplotlib.pyplot as plt
|
|
21
21
|
import matplotlib.colors as mcolors
|
|
22
|
-
from typing import Literal, Optional, Union, Dict, Tuple, Any
|
|
22
|
+
from typing import Literal, Optional, Union, Dict, Tuple, Any
|
|
23
23
|
import numpy.typing as npt
|
|
24
24
|
import warnings
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
UMAP_AVAILABLE = True
|
|
29
|
-
except ImportError:
|
|
30
|
-
UMAP_AVAILABLE = False
|
|
31
|
-
|
|
32
|
-
try:
|
|
33
|
-
import plotly.graph_objects as go
|
|
34
|
-
from plotly.subplots import make_subplots
|
|
35
|
-
PLOTLY_AVAILABLE = True
|
|
36
|
-
except ImportError:
|
|
37
|
-
PLOTLY_AVAILABLE = False
|
|
38
|
-
|
|
39
|
-
__author__ = "Ettore Rocchi"
|
|
25
|
+
import umap
|
|
26
|
+
import plotly.graph_objects as go
|
|
27
|
+
from plotly.subplots import make_subplots
|
|
40
28
|
|
|
41
29
|
ArrayLike = Union[pd.DataFrame, pd.Series, npt.NDArray[Any]]
|
|
42
30
|
FloatArray = npt.NDArray[np.float64]
|
|
@@ -58,8 +46,9 @@ class ComBatModel:
|
|
|
58
46
|
ignoring the variance (`delta_star`).
|
|
59
47
|
reference_batch : str, optional
|
|
60
48
|
If specified, the batch level to use as reference.
|
|
61
|
-
covbat_cov_thresh : float, default=0.9
|
|
62
|
-
CovBat: cumulative
|
|
49
|
+
covbat_cov_thresh : float or int, default=0.9
|
|
50
|
+
CovBat: cumulative variance threshold (0, 1] to retain PCs, or
|
|
51
|
+
integer >= 1 specifying the number of components directly.
|
|
63
52
|
eps : float, default=1e-8
|
|
64
53
|
Numerical jitter to avoid division-by-zero.
|
|
65
54
|
"""
|
|
@@ -67,19 +56,19 @@ class ComBatModel:
|
|
|
67
56
|
def __init__(
|
|
68
57
|
self,
|
|
69
58
|
*,
|
|
70
|
-
method: Literal["johnson", "fortin", "chen"] = "johnson",
|
|
59
|
+
method: Literal["johnson", "fortin", "chen"] = "johnson",
|
|
71
60
|
parametric: bool = True,
|
|
72
61
|
mean_only: bool = False,
|
|
73
62
|
reference_batch: Optional[str] = None,
|
|
74
63
|
eps: float = 1e-8,
|
|
75
|
-
covbat_cov_thresh: float = 0.9,
|
|
64
|
+
covbat_cov_thresh: Union[float, int] = 0.9,
|
|
76
65
|
) -> None:
|
|
77
66
|
self.method: str = method
|
|
78
67
|
self.parametric: bool = parametric
|
|
79
68
|
self.mean_only: bool = bool(mean_only)
|
|
80
69
|
self.reference_batch: Optional[str] = reference_batch
|
|
81
70
|
self.eps: float = float(eps)
|
|
82
|
-
self.covbat_cov_thresh: float =
|
|
71
|
+
self.covbat_cov_thresh: Union[float, int] = covbat_cov_thresh
|
|
83
72
|
|
|
84
73
|
self._batch_levels: pd.Index
|
|
85
74
|
self._grand_mean: pd.Series
|
|
@@ -96,9 +85,16 @@ class ComBatModel:
|
|
|
96
85
|
self._batch_levels_pc: pd.Index
|
|
97
86
|
self._pc_gamma_star: FloatArray
|
|
98
87
|
self._pc_delta_star: FloatArray
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
88
|
+
|
|
89
|
+
# Validate covbat_cov_thresh
|
|
90
|
+
if isinstance(self.covbat_cov_thresh, float):
|
|
91
|
+
if not (0.0 < self.covbat_cov_thresh <= 1.0):
|
|
92
|
+
raise ValueError("covbat_cov_thresh must be in (0, 1] when float.")
|
|
93
|
+
elif isinstance(self.covbat_cov_thresh, int):
|
|
94
|
+
if self.covbat_cov_thresh < 1:
|
|
95
|
+
raise ValueError("covbat_cov_thresh must be >= 1 when int.")
|
|
96
|
+
else:
|
|
97
|
+
raise TypeError("covbat_cov_thresh must be float or int.")
|
|
102
98
|
|
|
103
99
|
@staticmethod
|
|
104
100
|
def _as_series(
|
|
@@ -336,8 +332,14 @@ class ComBatModel:
|
|
|
336
332
|
X_meanvar_adj = self._transform_fortin(X, batch, disc, cont)
|
|
337
333
|
X_centered = X_meanvar_adj - X_meanvar_adj.mean(axis=0)
|
|
338
334
|
pca = PCA(svd_solver="full", whiten=False).fit(X_centered)
|
|
339
|
-
|
|
340
|
-
|
|
335
|
+
|
|
336
|
+
# Determine number of components based on threshold type
|
|
337
|
+
if isinstance(self.covbat_cov_thresh, int):
|
|
338
|
+
n_pc = min(self.covbat_cov_thresh, len(pca.explained_variance_ratio_))
|
|
339
|
+
else:
|
|
340
|
+
cumulative = np.cumsum(pca.explained_variance_ratio_)
|
|
341
|
+
n_pc = int(np.searchsorted(cumulative, self.covbat_cov_thresh) + 1)
|
|
342
|
+
|
|
341
343
|
self._covbat_pca = pca
|
|
342
344
|
self._covbat_n_pc = n_pc
|
|
343
345
|
|
|
@@ -488,7 +490,8 @@ class ComBatModel:
|
|
|
488
490
|
continuous_covariates: Optional[ArrayLike] = None,
|
|
489
491
|
) -> pd.DataFrame:
|
|
490
492
|
"""Transform the data using fitted ComBat parameters."""
|
|
491
|
-
|
|
493
|
+
if not hasattr(self, "_gamma_star"):
|
|
494
|
+
raise ValueError("This ComBatModel instance is not fitted yet. Call 'fit' before 'transform'.")
|
|
492
495
|
if not isinstance(X, pd.DataFrame):
|
|
493
496
|
X = pd.DataFrame(X)
|
|
494
497
|
idx = X.index
|
|
@@ -600,7 +603,7 @@ class ComBatModel:
|
|
|
600
603
|
"""Chen transform implementation."""
|
|
601
604
|
X_meanvar_adj = self._transform_fortin(X, batch, disc, cont)
|
|
602
605
|
X_centered = X_meanvar_adj - self._covbat_pca.mean_
|
|
603
|
-
scores = self._covbat_pca.transform(X_centered
|
|
606
|
+
scores = self._covbat_pca.transform(X_centered)
|
|
604
607
|
n_pc = self._covbat_n_pc
|
|
605
608
|
scores_adj = scores.copy()
|
|
606
609
|
|
|
@@ -639,7 +642,7 @@ class ComBat(BaseEstimator, TransformerMixin):
|
|
|
639
642
|
mean_only: bool = False,
|
|
640
643
|
reference_batch: Optional[str] = None,
|
|
641
644
|
eps: float = 1e-8,
|
|
642
|
-
covbat_cov_thresh: float = 0.9,
|
|
645
|
+
covbat_cov_thresh: Union[float, int] = 0.9,
|
|
643
646
|
) -> None:
|
|
644
647
|
self.batch = batch
|
|
645
648
|
self.discrete_covariates = discrete_covariates
|
|
@@ -759,7 +762,8 @@ class ComBat(BaseEstimator, TransformerMixin):
|
|
|
759
762
|
- `'original'`: embedding of original data
|
|
760
763
|
- `'transformed'`: embedding of ComBat-transformed data
|
|
761
764
|
"""
|
|
762
|
-
|
|
765
|
+
if not hasattr(self._model, "_gamma_star"):
|
|
766
|
+
raise ValueError("This ComBat instance is not fitted yet. Call 'fit' before 'plot_transformation'.")
|
|
763
767
|
|
|
764
768
|
if n_components not in [2, 3]:
|
|
765
769
|
raise ValueError(f"n_components must be 2 or 3, got {n_components}")
|
|
@@ -768,11 +772,6 @@ class ComBat(BaseEstimator, TransformerMixin):
|
|
|
768
772
|
if plot_type not in ['static', 'interactive']:
|
|
769
773
|
raise ValueError(f"plot_type must be 'static' or 'interactive', got '{plot_type}'")
|
|
770
774
|
|
|
771
|
-
if reduction_method == 'umap' and not UMAP_AVAILABLE:
|
|
772
|
-
raise ImportError("UMAP is not installed. Install with: pip install umap-learn")
|
|
773
|
-
if plot_type == 'interactive' and not PLOTLY_AVAILABLE:
|
|
774
|
-
raise ImportError("Plotly is not installed. Install with: pip install plotly")
|
|
775
|
-
|
|
776
775
|
if not isinstance(X, pd.DataFrame):
|
|
777
776
|
X = pd.DataFrame(X)
|
|
778
777
|
|
|
@@ -797,8 +796,8 @@ class ComBat(BaseEstimator, TransformerMixin):
|
|
|
797
796
|
else:
|
|
798
797
|
umap_params = {'random_state': 42}
|
|
799
798
|
umap_params.update(reduction_kwargs)
|
|
800
|
-
reducer_orig = umap.UMAP(n_components=n_components, **
|
|
801
|
-
reducer_trans = umap.UMAP(n_components=n_components, **
|
|
799
|
+
reducer_orig = umap.UMAP(n_components=n_components, **umap_params)
|
|
800
|
+
reducer_trans = umap.UMAP(n_components=n_components, **umap_params)
|
|
802
801
|
|
|
803
802
|
X_embedded_orig = reducer_orig.fit_transform(X_np)
|
|
804
803
|
X_embedded_trans = reducer_trans.fit_transform(X_trans_np)
|
|
@@ -845,9 +844,9 @@ class ComBat(BaseEstimator, TransformerMixin):
|
|
|
845
844
|
n_batches = len(unique_batches)
|
|
846
845
|
|
|
847
846
|
if n_batches <= 10:
|
|
848
|
-
colors =
|
|
847
|
+
colors = matplotlib.colormaps.get_cmap(cmap)(np.linspace(0, 1, n_batches))
|
|
849
848
|
else:
|
|
850
|
-
colors =
|
|
849
|
+
colors = matplotlib.colormaps.get_cmap('tab20')(np.linspace(0, 1, n_batches))
|
|
851
850
|
|
|
852
851
|
if n_components == 2:
|
|
853
852
|
ax1 = plt.subplot(1, 2, 1)
|
|
@@ -956,7 +955,7 @@ class ComBat(BaseEstimator, TransformerMixin):
|
|
|
956
955
|
unique_batches = batch_labels.drop_duplicates()
|
|
957
956
|
|
|
958
957
|
n_batches = len(unique_batches)
|
|
959
|
-
cmap_func =
|
|
958
|
+
cmap_func = matplotlib.colormaps.get_cmap(cmap)
|
|
960
959
|
color_list = [mcolors.to_hex(cmap_func(i / max(n_batches - 1, 1))) for i in range(n_batches)]
|
|
961
960
|
|
|
962
961
|
batch_to_color = dict(zip(unique_batches, color_list))
|
|
@@ -1,9 +1,9 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: combatlearn
|
|
3
|
-
Version: 0.
|
|
3
|
+
Version: 1.0.0
|
|
4
4
|
Summary: Batch-effect harmonization for machine learning frameworks.
|
|
5
5
|
Author-email: Ettore Rocchi <ettoreroc@gmail.com>
|
|
6
|
-
License
|
|
6
|
+
License: MIT
|
|
7
7
|
Keywords: machine-learning,harmonization,combat,preprocessing
|
|
8
8
|
Classifier: Development Status :: 3 - Alpha
|
|
9
9
|
Classifier: Intended Audience :: Science/Research
|
|
@@ -19,13 +19,23 @@ Requires-Dist: matplotlib>=3.4
|
|
|
19
19
|
Requires-Dist: plotly>=5.0
|
|
20
20
|
Requires-Dist: nbformat>=4.2
|
|
21
21
|
Requires-Dist: umap-learn>=0.5
|
|
22
|
-
|
|
22
|
+
Provides-Extra: dev
|
|
23
|
+
Requires-Dist: pytest>=7; extra == "dev"
|
|
24
|
+
Requires-Dist: pytest-cov>=4.0; extra == "dev"
|
|
25
|
+
Requires-Dist: ruff>=0.1; extra == "dev"
|
|
26
|
+
Requires-Dist: mypy>=1.0; extra == "dev"
|
|
27
|
+
Provides-Extra: docs
|
|
28
|
+
Requires-Dist: mkdocs>=1.5.0; extra == "docs"
|
|
29
|
+
Requires-Dist: mkdocs-material>=9.0.0; extra == "docs"
|
|
30
|
+
Requires-Dist: mkdocstrings[python]>=0.24.0; extra == "docs"
|
|
31
|
+
Requires-Dist: pymdown-extensions>=10.0; extra == "docs"
|
|
23
32
|
Dynamic: license-file
|
|
24
33
|
|
|
25
34
|
# **combatlearn**
|
|
26
35
|
|
|
27
36
|
[](https://www.python.org/)
|
|
28
37
|
[](https://github.com/EttoreRocchi/combatlearn/actions/workflows/test.yaml)
|
|
38
|
+
[](https://combatlearn.readthedocs.io)
|
|
29
39
|
[](https://pepy.tech/projects/combatlearn)
|
|
30
40
|
[](https://pypi.org/project/combatlearn/)
|
|
31
41
|
[](https://github.com/EttoreRocchi/combatlearn/blob/main/LICENSE)
|
|
@@ -95,6 +105,17 @@ print(f"Best CV AUROC: {grid.best_score_:.3f}")
|
|
|
95
105
|
|
|
96
106
|
For a full example of how to use **combatlearn** see the [notebook demo](https://github.com/EttoreRocchi/combatlearn/blob/main/docs/demo/combatlearn_demo.ipynb)
|
|
97
107
|
|
|
108
|
+
## Documentation
|
|
109
|
+
|
|
110
|
+
**Full documentation is available at [combatlearn.readthedocs.io](https://combatlearn.readthedocs.io)**
|
|
111
|
+
|
|
112
|
+
The documentation includes:
|
|
113
|
+
- [Installation Guide](https://combatlearn.readthedocs.io/en/latest/installation/)
|
|
114
|
+
- [Quick Start Tutorial](https://combatlearn.readthedocs.io/en/latest/quickstart/)
|
|
115
|
+
- [User Guide](https://combatlearn.readthedocs.io/en/latest/user-guide/overview/)
|
|
116
|
+
- [API Reference](https://combatlearn.readthedocs.io/en/latest/api/)
|
|
117
|
+
- [Examples](https://combatlearn.readthedocs.io/en/latest/examples/basic-usage/)
|
|
118
|
+
|
|
98
119
|
## `ComBat` parameters
|
|
99
120
|
|
|
100
121
|
The following section provides a detailed explanation of all parameters available in the scikit-learn-compatible `ComBat` class.
|
|
@@ -0,0 +1,7 @@
|
|
|
1
|
+
combatlearn/__init__.py,sha256=ck_EGW8iqLGUebg2wc-h794lwG3uAkHn9GaWjHgUIX4,99
|
|
2
|
+
combatlearn/combat.py,sha256=Hri1XwnfSXWLoC1KD2VkqtNLkZpixI5ax0UrT1HtjyU,38505
|
|
3
|
+
combatlearn-1.0.0.dist-info/licenses/LICENSE,sha256=O34CBRTmdL59PxDYOa6nq1N0-2A9xyXGkBXKbsL1NeY,1070
|
|
4
|
+
combatlearn-1.0.0.dist-info/METADATA,sha256=hJvZEiA_ekTq06wzfOf2p6M_4vwNXGOdoS-K5MvT4P0,8558
|
|
5
|
+
combatlearn-1.0.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
6
|
+
combatlearn-1.0.0.dist-info/top_level.txt,sha256=3cFQv4oj2sh_NKra45cPy8Go0v8W9x9-zkkUibqZCMk,12
|
|
7
|
+
combatlearn-1.0.0.dist-info/RECORD,,
|
|
@@ -1,7 +0,0 @@
|
|
|
1
|
-
combatlearn/__init__.py,sha256=qZK8xAUibzM9TQJ-xho1cjMYmTGkdWvpFRTXOokNvMY,98
|
|
2
|
-
combatlearn/combat.py,sha256=pVauFEgZ7wiYRimGZe7ZhBWZN7sGQ67A3o_SrBUtoJ8,38126
|
|
3
|
-
combatlearn-0.2.2.dist-info/licenses/LICENSE,sha256=O34CBRTmdL59PxDYOa6nq1N0-2A9xyXGkBXKbsL1NeY,1070
|
|
4
|
-
combatlearn-0.2.2.dist-info/METADATA,sha256=CNm0pbXPVVWORk4pI97WS1DohjWOu7fB88JS1JZ-3-A,7491
|
|
5
|
-
combatlearn-0.2.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
6
|
-
combatlearn-0.2.2.dist-info/top_level.txt,sha256=3cFQv4oj2sh_NKra45cPy8Go0v8W9x9-zkkUibqZCMk,12
|
|
7
|
-
combatlearn-0.2.2.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|