combatlearn 0.2.2__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
combatlearn/__init__.py CHANGED
@@ -1,4 +1,5 @@
1
- from .combat import ComBatModel, ComBat
1
+ from .combat import ComBat
2
2
 
3
- __all__ = ["ComBatModel", "ComBat"]
4
- __version__ = "0.2.2"
3
+ __all__ = ["ComBat"]
4
+ __version__ = "1.0.0"
5
+ __author__ = "Ettore Rocchi"
combatlearn/combat.py CHANGED
@@ -14,29 +14,17 @@ import numpy as np
14
14
  import numpy.linalg as la
15
15
  import pandas as pd
16
16
  from sklearn.base import BaseEstimator, TransformerMixin
17
- from sklearn.utils.validation import check_is_fitted
18
17
  from sklearn.decomposition import PCA
19
18
  from sklearn.manifold import TSNE
19
+ import matplotlib
20
20
  import matplotlib.pyplot as plt
21
21
  import matplotlib.colors as mcolors
22
- from typing import Literal, Optional, Union, Dict, Tuple, Any, cast
22
+ from typing import Literal, Optional, Union, Dict, Tuple, Any
23
23
  import numpy.typing as npt
24
24
  import warnings
25
-
26
- try:
27
- import umap
28
- UMAP_AVAILABLE = True
29
- except ImportError:
30
- UMAP_AVAILABLE = False
31
-
32
- try:
33
- import plotly.graph_objects as go
34
- from plotly.subplots import make_subplots
35
- PLOTLY_AVAILABLE = True
36
- except ImportError:
37
- PLOTLY_AVAILABLE = False
38
-
39
- __author__ = "Ettore Rocchi"
25
+ import umap
26
+ import plotly.graph_objects as go
27
+ from plotly.subplots import make_subplots
40
28
 
41
29
  ArrayLike = Union[pd.DataFrame, pd.Series, npt.NDArray[Any]]
42
30
  FloatArray = npt.NDArray[np.float64]
@@ -58,8 +46,9 @@ class ComBatModel:
58
46
  ignoring the variance (`delta_star`).
59
47
  reference_batch : str, optional
60
48
  If specified, the batch level to use as reference.
61
- covbat_cov_thresh : float, default=0.9
62
- CovBat: cumulative explained variance threshold for PCA.
49
+ covbat_cov_thresh : float or int, default=0.9
50
+ CovBat: cumulative variance threshold (0, 1] to retain PCs, or
51
+ integer >= 1 specifying the number of components directly.
63
52
  eps : float, default=1e-8
64
53
  Numerical jitter to avoid division-by-zero.
65
54
  """
@@ -67,19 +56,19 @@ class ComBatModel:
67
56
  def __init__(
68
57
  self,
69
58
  *,
70
- method: Literal["johnson", "fortin", "chen"] = "johnson",
59
+ method: Literal["johnson", "fortin", "chen"] = "johnson",
71
60
  parametric: bool = True,
72
61
  mean_only: bool = False,
73
62
  reference_batch: Optional[str] = None,
74
63
  eps: float = 1e-8,
75
- covbat_cov_thresh: float = 0.9,
64
+ covbat_cov_thresh: Union[float, int] = 0.9,
76
65
  ) -> None:
77
66
  self.method: str = method
78
67
  self.parametric: bool = parametric
79
68
  self.mean_only: bool = bool(mean_only)
80
69
  self.reference_batch: Optional[str] = reference_batch
81
70
  self.eps: float = float(eps)
82
- self.covbat_cov_thresh: float = float(covbat_cov_thresh)
71
+ self.covbat_cov_thresh: Union[float, int] = covbat_cov_thresh
83
72
 
84
73
  self._batch_levels: pd.Index
85
74
  self._grand_mean: pd.Series
@@ -96,9 +85,16 @@ class ComBatModel:
96
85
  self._batch_levels_pc: pd.Index
97
86
  self._pc_gamma_star: FloatArray
98
87
  self._pc_delta_star: FloatArray
99
-
100
- if not (0.0 < self.covbat_cov_thresh <= 1.0):
101
- raise ValueError("covbat_cov_thresh must be in (0, 1].")
88
+
89
+ # Validate covbat_cov_thresh
90
+ if isinstance(self.covbat_cov_thresh, float):
91
+ if not (0.0 < self.covbat_cov_thresh <= 1.0):
92
+ raise ValueError("covbat_cov_thresh must be in (0, 1] when float.")
93
+ elif isinstance(self.covbat_cov_thresh, int):
94
+ if self.covbat_cov_thresh < 1:
95
+ raise ValueError("covbat_cov_thresh must be >= 1 when int.")
96
+ else:
97
+ raise TypeError("covbat_cov_thresh must be float or int.")
102
98
 
103
99
  @staticmethod
104
100
  def _as_series(
@@ -336,8 +332,14 @@ class ComBatModel:
336
332
  X_meanvar_adj = self._transform_fortin(X, batch, disc, cont)
337
333
  X_centered = X_meanvar_adj - X_meanvar_adj.mean(axis=0)
338
334
  pca = PCA(svd_solver="full", whiten=False).fit(X_centered)
339
- cumulative = np.cumsum(pca.explained_variance_ratio_)
340
- n_pc = int(np.searchsorted(cumulative, self.covbat_cov_thresh) + 1)
335
+
336
+ # Determine number of components based on threshold type
337
+ if isinstance(self.covbat_cov_thresh, int):
338
+ n_pc = min(self.covbat_cov_thresh, len(pca.explained_variance_ratio_))
339
+ else:
340
+ cumulative = np.cumsum(pca.explained_variance_ratio_)
341
+ n_pc = int(np.searchsorted(cumulative, self.covbat_cov_thresh) + 1)
342
+
341
343
  self._covbat_pca = pca
342
344
  self._covbat_n_pc = n_pc
343
345
 
@@ -488,7 +490,8 @@ class ComBatModel:
488
490
  continuous_covariates: Optional[ArrayLike] = None,
489
491
  ) -> pd.DataFrame:
490
492
  """Transform the data using fitted ComBat parameters."""
491
- check_is_fitted(self, ["_gamma_star"])
493
+ if not hasattr(self, "_gamma_star"):
494
+ raise ValueError("This ComBatModel instance is not fitted yet. Call 'fit' before 'transform'.")
492
495
  if not isinstance(X, pd.DataFrame):
493
496
  X = pd.DataFrame(X)
494
497
  idx = X.index
@@ -600,7 +603,7 @@ class ComBatModel:
600
603
  """Chen transform implementation."""
601
604
  X_meanvar_adj = self._transform_fortin(X, batch, disc, cont)
602
605
  X_centered = X_meanvar_adj - self._covbat_pca.mean_
603
- scores = self._covbat_pca.transform(X_centered.values)
606
+ scores = self._covbat_pca.transform(X_centered)
604
607
  n_pc = self._covbat_n_pc
605
608
  scores_adj = scores.copy()
606
609
 
@@ -639,7 +642,7 @@ class ComBat(BaseEstimator, TransformerMixin):
639
642
  mean_only: bool = False,
640
643
  reference_batch: Optional[str] = None,
641
644
  eps: float = 1e-8,
642
- covbat_cov_thresh: float = 0.9,
645
+ covbat_cov_thresh: Union[float, int] = 0.9,
643
646
  ) -> None:
644
647
  self.batch = batch
645
648
  self.discrete_covariates = discrete_covariates
@@ -759,7 +762,8 @@ class ComBat(BaseEstimator, TransformerMixin):
759
762
  - `'original'`: embedding of original data
760
763
  - `'transformed'`: embedding of ComBat-transformed data
761
764
  """
762
- check_is_fitted(self._model, ["_gamma_star"])
765
+ if not hasattr(self._model, "_gamma_star"):
766
+ raise ValueError("This ComBat instance is not fitted yet. Call 'fit' before 'plot_transformation'.")
763
767
 
764
768
  if n_components not in [2, 3]:
765
769
  raise ValueError(f"n_components must be 2 or 3, got {n_components}")
@@ -768,11 +772,6 @@ class ComBat(BaseEstimator, TransformerMixin):
768
772
  if plot_type not in ['static', 'interactive']:
769
773
  raise ValueError(f"plot_type must be 'static' or 'interactive', got '{plot_type}'")
770
774
 
771
- if reduction_method == 'umap' and not UMAP_AVAILABLE:
772
- raise ImportError("UMAP is not installed. Install with: pip install umap-learn")
773
- if plot_type == 'interactive' and not PLOTLY_AVAILABLE:
774
- raise ImportError("Plotly is not installed. Install with: pip install plotly")
775
-
776
775
  if not isinstance(X, pd.DataFrame):
777
776
  X = pd.DataFrame(X)
778
777
 
@@ -797,8 +796,8 @@ class ComBat(BaseEstimator, TransformerMixin):
797
796
  else:
798
797
  umap_params = {'random_state': 42}
799
798
  umap_params.update(reduction_kwargs)
800
- reducer_orig = umap.UMAP(n_components=n_components, **reduction_kwargs)
801
- reducer_trans = umap.UMAP(n_components=n_components, **reduction_kwargs)
799
+ reducer_orig = umap.UMAP(n_components=n_components, **umap_params)
800
+ reducer_trans = umap.UMAP(n_components=n_components, **umap_params)
802
801
 
803
802
  X_embedded_orig = reducer_orig.fit_transform(X_np)
804
803
  X_embedded_trans = reducer_trans.fit_transform(X_trans_np)
@@ -845,9 +844,9 @@ class ComBat(BaseEstimator, TransformerMixin):
845
844
  n_batches = len(unique_batches)
846
845
 
847
846
  if n_batches <= 10:
848
- colors = plt.cm.get_cmap(cmap)(np.linspace(0, 1, n_batches))
847
+ colors = matplotlib.colormaps.get_cmap(cmap)(np.linspace(0, 1, n_batches))
849
848
  else:
850
- colors = plt.cm.get_cmap('tab20')(np.linspace(0, 1, n_batches))
849
+ colors = matplotlib.colormaps.get_cmap('tab20')(np.linspace(0, 1, n_batches))
851
850
 
852
851
  if n_components == 2:
853
852
  ax1 = plt.subplot(1, 2, 1)
@@ -956,7 +955,7 @@ class ComBat(BaseEstimator, TransformerMixin):
956
955
  unique_batches = batch_labels.drop_duplicates()
957
956
 
958
957
  n_batches = len(unique_batches)
959
- cmap_func = plt.cm.get_cmap(cmap)
958
+ cmap_func = matplotlib.colormaps.get_cmap(cmap)
960
959
  color_list = [mcolors.to_hex(cmap_func(i / max(n_batches - 1, 1))) for i in range(n_batches)]
961
960
 
962
961
  batch_to_color = dict(zip(unique_batches, color_list))
@@ -1,9 +1,9 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: combatlearn
3
- Version: 0.2.2
3
+ Version: 1.0.0
4
4
  Summary: Batch-effect harmonization for machine learning frameworks.
5
5
  Author-email: Ettore Rocchi <ettoreroc@gmail.com>
6
- License-Expression: MIT
6
+ License: MIT
7
7
  Keywords: machine-learning,harmonization,combat,preprocessing
8
8
  Classifier: Development Status :: 3 - Alpha
9
9
  Classifier: Intended Audience :: Science/Research
@@ -19,13 +19,23 @@ Requires-Dist: matplotlib>=3.4
19
19
  Requires-Dist: plotly>=5.0
20
20
  Requires-Dist: nbformat>=4.2
21
21
  Requires-Dist: umap-learn>=0.5
22
- Requires-Dist: pytest>=7
22
+ Provides-Extra: dev
23
+ Requires-Dist: pytest>=7; extra == "dev"
24
+ Requires-Dist: pytest-cov>=4.0; extra == "dev"
25
+ Requires-Dist: ruff>=0.1; extra == "dev"
26
+ Requires-Dist: mypy>=1.0; extra == "dev"
27
+ Provides-Extra: docs
28
+ Requires-Dist: mkdocs>=1.5.0; extra == "docs"
29
+ Requires-Dist: mkdocs-material>=9.0.0; extra == "docs"
30
+ Requires-Dist: mkdocstrings[python]>=0.24.0; extra == "docs"
31
+ Requires-Dist: pymdown-extensions>=10.0; extra == "docs"
23
32
  Dynamic: license-file
24
33
 
25
34
  # **combatlearn**
26
35
 
27
36
  [![Python versions](https://img.shields.io/badge/python-%3E%3D3.10-blue?logo=python)](https://www.python.org/)
28
37
  [![Test](https://github.com/EttoreRocchi/combatlearn/actions/workflows/test.yaml/badge.svg)](https://github.com/EttoreRocchi/combatlearn/actions/workflows/test.yaml)
38
+ [![Documentation](https://readthedocs.org/projects/combatlearn/badge/?version=latest)](https://combatlearn.readthedocs.io)
29
39
  [![PyPI Downloads](https://static.pepy.tech/badge/combatlearn)](https://pepy.tech/projects/combatlearn)
30
40
  [![PyPI Version](https://img.shields.io/pypi/v/combatlearn?cacheSeconds=300)](https://pypi.org/project/combatlearn/)
31
41
  [![License](https://img.shields.io/github/license/EttoreRocchi/combatlearn)](https://github.com/EttoreRocchi/combatlearn/blob/main/LICENSE)
@@ -95,6 +105,17 @@ print(f"Best CV AUROC: {grid.best_score_:.3f}")
95
105
 
96
106
  For a full example of how to use **combatlearn** see the [notebook demo](https://github.com/EttoreRocchi/combatlearn/blob/main/docs/demo/combatlearn_demo.ipynb)
97
107
 
108
+ ## Documentation
109
+
110
+ **Full documentation is available at [combatlearn.readthedocs.io](https://combatlearn.readthedocs.io)**
111
+
112
+ The documentation includes:
113
+ - [Installation Guide](https://combatlearn.readthedocs.io/en/latest/installation/)
114
+ - [Quick Start Tutorial](https://combatlearn.readthedocs.io/en/latest/quickstart/)
115
+ - [User Guide](https://combatlearn.readthedocs.io/en/latest/user-guide/overview/)
116
+ - [API Reference](https://combatlearn.readthedocs.io/en/latest/api/)
117
+ - [Examples](https://combatlearn.readthedocs.io/en/latest/examples/basic-usage/)
118
+
98
119
  ## `ComBat` parameters
99
120
 
100
121
  The following section provides a detailed explanation of all parameters available in the scikit-learn-compatible `ComBat` class.
@@ -0,0 +1,7 @@
1
+ combatlearn/__init__.py,sha256=ck_EGW8iqLGUebg2wc-h794lwG3uAkHn9GaWjHgUIX4,99
2
+ combatlearn/combat.py,sha256=Hri1XwnfSXWLoC1KD2VkqtNLkZpixI5ax0UrT1HtjyU,38505
3
+ combatlearn-1.0.0.dist-info/licenses/LICENSE,sha256=O34CBRTmdL59PxDYOa6nq1N0-2A9xyXGkBXKbsL1NeY,1070
4
+ combatlearn-1.0.0.dist-info/METADATA,sha256=hJvZEiA_ekTq06wzfOf2p6M_4vwNXGOdoS-K5MvT4P0,8558
5
+ combatlearn-1.0.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
6
+ combatlearn-1.0.0.dist-info/top_level.txt,sha256=3cFQv4oj2sh_NKra45cPy8Go0v8W9x9-zkkUibqZCMk,12
7
+ combatlearn-1.0.0.dist-info/RECORD,,
@@ -1,7 +0,0 @@
1
- combatlearn/__init__.py,sha256=qZK8xAUibzM9TQJ-xho1cjMYmTGkdWvpFRTXOokNvMY,98
2
- combatlearn/combat.py,sha256=pVauFEgZ7wiYRimGZe7ZhBWZN7sGQ67A3o_SrBUtoJ8,38126
3
- combatlearn-0.2.2.dist-info/licenses/LICENSE,sha256=O34CBRTmdL59PxDYOa6nq1N0-2A9xyXGkBXKbsL1NeY,1070
4
- combatlearn-0.2.2.dist-info/METADATA,sha256=CNm0pbXPVVWORk4pI97WS1DohjWOu7fB88JS1JZ-3-A,7491
5
- combatlearn-0.2.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
6
- combatlearn-0.2.2.dist-info/top_level.txt,sha256=3cFQv4oj2sh_NKra45cPy8Go0v8W9x9-zkkUibqZCMk,12
7
- combatlearn-0.2.2.dist-info/RECORD,,