combatlearn 0.1.2__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
combatlearn/__init__.py CHANGED
@@ -1,4 +1,4 @@
1
1
  from .combat import ComBatModel, ComBat
2
2
 
3
3
  __all__ = ["ComBatModel", "ComBat"]
4
- __version__ = "0.1.0"
4
+ __version__ = "0.2.0"
combatlearn/combat.py CHANGED
@@ -16,10 +16,25 @@ import pandas as pd
16
16
  from sklearn.base import BaseEstimator, TransformerMixin
17
17
  from sklearn.utils.validation import check_is_fitted
18
18
  from sklearn.decomposition import PCA
19
+ from sklearn.manifold import TSNE
20
+ import matplotlib.pyplot as plt
19
21
  from typing import Literal, Optional, Union, Dict, Tuple, Any, cast
20
22
  import numpy.typing as npt
21
23
  import warnings
22
24
 
25
+ try:
26
+ import umap
27
+ UMAP_AVAILABLE = True
28
+ except ImportError:
29
+ UMAP_AVAILABLE = False
30
+
31
+ try:
32
+ import plotly.graph_objects as go
33
+ from plotly.subplots import make_subplots
34
+ PLOTLY_AVAILABLE = True
35
+ except ImportError:
36
+ PLOTLY_AVAILABLE = False
37
+
23
38
  __author__ = "Ettore Rocchi"
24
39
 
25
40
  ArrayLike = Union[pd.DataFrame, pd.Series, npt.NDArray[Any]]
@@ -659,6 +674,7 @@ class ComBat(BaseEstimator, TransformerMixin):
659
674
  discrete_covariates=disc,
660
675
  continuous_covariates=cont,
661
676
  )
677
+ self._fitted_batch = batch_vec
662
678
  return self
663
679
 
664
680
  def transform(self, X: ArrayLike) -> pd.DataFrame:
@@ -689,3 +705,315 @@ class ComBat(BaseEstimator, TransformerMixin):
689
705
  return pd.Series(obj, index=idx)
690
706
  else:
691
707
  return pd.DataFrame(obj, index=idx)
708
+
709
def plot_transformation(
    self,
    X: "ArrayLike", *,
    reduction_method: Literal['pca', 'tsne', 'umap'] = 'pca',
    n_components: Literal[2, 3] = 2,
    plot_type: Literal['static', 'interactive'] = 'static',
    figsize: Tuple[int, int] = (12, 5),
    alpha: float = 0.7,
    point_size: int = 50,
    cmap: str = 'Set1',
    title: Optional[str] = None,
    show_legend: bool = True,
    return_embeddings: bool = False,
    **reduction_kwargs) -> Union[Any, Tuple[Any, Dict[str, "FloatArray"]]]:
    """
    Visualize the ComBat transformation effect using dimensionality reduction.

    It shows a before/after comparison of data transformed by `ComBat` using
    PCA, t-SNE, or UMAP to reduce dimensions for visualization. The original
    and the transformed data are embedded by two independently fitted
    reducers.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
        Input data to transform and visualize. Its index is used to look up
        the batch labels of each sample.

    reduction_method : {`'pca'`, `'tsne'`, `'umap'`}, default=`'pca'`
        Dimensionality reduction method.

    n_components : {2, 3}, default=2
        Number of components for dimensionality reduction.

    plot_type : {`'static'`, `'interactive'`}, default=`'static'`
        Visualization type:
        - `'static'`: matplotlib plots (can be saved as images)
        - `'interactive'`: plotly plots (explorable, requires plotly)

    figsize : tuple of int, default=(12, 5)
        Figure size (static plots only).

    alpha : float, default=0.7
        Marker transparency (static plots only).

    point_size : int, default=50
        Marker size (static plots only).

    cmap : str, default=`'Set1'`
        Matplotlib colormap used when there are at most 10 batches
        (static plots only).

    title : str, optional
        Figure title. A default title mentioning the method is used if `None`.

    show_legend : bool, default=True
        Whether to show the batch legend.

    return_embeddings : bool, default=False
        If `True`, return embeddings along with the plot.

    **reduction_kwargs : dict
        Additional parameters for the reduction method. They override the
        built-in defaults (`random_state=42` for t-SNE/UMAP and
        `perplexity=min(30, n_samples - 1)` for t-SNE).

    Returns
    -------
    fig : matplotlib.figure.Figure or plotly.graph_objects.Figure
        The figure object containing the plots.

    embeddings : dict, optional
        If `return_embeddings=True`, dictionary with:
        - `'original'`: embedding of original data
        - `'transformed'`: embedding of ComBat-transformed data

    Raises
    ------
    ValueError
        If an option is invalid or no batch information is available.
    ImportError
        If UMAP or plotly is requested but not installed.
    """
    check_is_fitted(self._model, ["_gamma_star"])

    if n_components not in [2, 3]:
        raise ValueError(f"n_components must be 2 or 3, got {n_components}")
    if reduction_method not in ['pca', 'tsne', 'umap']:
        raise ValueError(f"reduction_method must be 'pca', 'tsne', or 'umap', got '{reduction_method}'")
    if plot_type not in ['static', 'interactive']:
        raise ValueError(f"plot_type must be 'static' or 'interactive', got '{plot_type}'")

    if reduction_method == 'umap' and not UMAP_AVAILABLE:
        raise ImportError("UMAP is not installed. Install with: pip install umap-learn")
    if plot_type == 'interactive' and not PLOTLY_AVAILABLE:
        raise ImportError("Plotly is not installed. Install with: pip install plotly")

    if not isinstance(X, pd.DataFrame):
        X = pd.DataFrame(X)

    idx = X.index
    batch_vec = self._subset(self.batch, idx)
    if batch_vec is None:
        raise ValueError("Batch information is required for visualization")

    X_transformed = self.transform(X)

    X_np = X.values
    X_trans_np = X_transformed.values

    params = self._reduction_params(reduction_method, X_np.shape[0], reduction_kwargs)
    if reduction_method == 'pca':
        reducer_cls = PCA
    elif reduction_method == 'tsne':
        reducer_cls = TSNE
    else:
        reducer_cls = umap.UMAP

    # Two separate reducers: each embedding is fitted on its own data so the
    # "before" view is not shaped by the corrected data (and vice versa).
    X_embedded_orig = reducer_cls(n_components=n_components, **params).fit_transform(X_np)
    X_embedded_trans = reducer_cls(n_components=n_components, **params).fit_transform(X_trans_np)

    if plot_type == 'static':
        fig = self._create_static_plot(
            X_embedded_orig, X_embedded_trans, batch_vec,
            reduction_method, n_components, figsize, alpha,
            point_size, cmap, title, show_legend
        )
    else:
        fig = self._create_interactive_plot(
            X_embedded_orig, X_embedded_trans, batch_vec,
            reduction_method, n_components, title, show_legend
        )

    if return_embeddings:
        embeddings = {
            'original': X_embedded_orig,
            'transformed': X_embedded_trans
        }
        return fig, embeddings
    return fig

def _reduction_params(
    self,
    reduction_method: str,
    n_samples: int,
    reduction_kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """
    Build the keyword arguments for the chosen dimensionality reducer.

    Method-specific defaults are merged with `reduction_kwargs`; user values
    always win. The input mapping is not modified.

    Parameters
    ----------
    reduction_method : str
        One of `'pca'`, `'tsne'`, `'umap'`.
    n_samples : int
        Number of samples that will be embedded.
    reduction_kwargs : dict
        User-supplied keyword arguments.

    Returns
    -------
    dict
        Keyword arguments to pass to the reducer constructor (without
        `n_components`).
    """
    if reduction_method == 'tsne':
        # t-SNE requires perplexity < n_samples, so the default shrinks for
        # small inputs. The iteration count is left at scikit-learn's default
        # (1000) because its name changed (`n_iter` -> `max_iter` in 1.5).
        defaults: Dict[str, Any] = {
            'perplexity': min(30, max(n_samples - 1, 1)),
            'random_state': 42,
        }
    elif reduction_method == 'umap':
        defaults = {'random_state': 42}
    else:
        defaults = {}
    return {**defaults, **reduction_kwargs}
825
+
826
def _create_static_plot(
    self,
    X_orig: FloatArray,
    X_trans: FloatArray,
    batch_labels: pd.Series,
    method: str,
    n_components: int,
    figsize: Tuple[int, int],
    alpha: float,
    point_size: int,
    cmap: str,
    title: Optional[str],
    show_legend: bool) -> Any:
    """
    Create static before/after scatter plots using matplotlib.

    Parameters
    ----------
    X_orig, X_trans : ndarray of shape (n_samples, n_components)
        Embeddings of the original and the ComBat-transformed data.
    batch_labels : pd.Series
        Batch label of each row of the embeddings (same row order).
    method : str
        Reduction method name, used in titles and axis labels.
    n_components : {2, 3}
        Dimensionality of the embeddings; 3 produces 3D axes.
    figsize, alpha, point_size, cmap, title, show_legend
        Appearance options forwarded from `plot_transformation`.

    Returns
    -------
    matplotlib.figure.Figure
    """
    fig = plt.figure(figsize=figsize)

    unique_batches = batch_labels.drop_duplicates()
    n_batches = len(unique_batches)

    # `matplotlib.cm.get_cmap` (reached via `plt.cm.get_cmap`) was removed in
    # Matplotlib 3.9; `pyplot.get_cmap` is available across versions.
    palette = cmap if n_batches <= 10 else 'tab20'
    colors = plt.get_cmap(palette)(np.linspace(0, 1, n_batches))

    projection = '3d' if n_components == 3 else None
    ax1 = fig.add_subplot(1, 2, 1, projection=projection)
    ax2 = fig.add_subplot(1, 2, 2, projection=projection)

    axis_name = method.upper()
    for ax, embedding, stage in ((ax1, X_orig, 'Before'), (ax2, X_trans, 'After')):
        for color, batch in zip(colors, unique_batches):
            # Plain boolean array so NumPy indexing does not depend on the
            # Series index.
            mask = (batch_labels == batch).to_numpy()
            coords = [embedding[mask, dim] for dim in range(n_components)]
            ax.scatter(
                *coords,
                c=[color],
                s=point_size,
                alpha=alpha,
                label=f'Batch {batch}',
                edgecolors='black',
                linewidth=0.5
            )
        ax.set_title(f'{stage} ComBat correction\n({axis_name})')
        ax.set_xlabel(f'{axis_name}1')
        ax.set_ylabel(f'{axis_name}2')
        if n_components == 3:
            ax.set_zlabel(f'{axis_name}3')

    if show_legend and n_batches <= 20:
        ax2.legend(bbox_to_anchor=(1.05, 1), loc='upper left')

    if title is None:
        title = f'ComBat correction effect visualized with {axis_name}'
    fig.suptitle(title, fontsize=14, fontweight='bold')

    plt.tight_layout()
    return fig
925
+
926
def _create_interactive_plot(
    self,
    X_orig: FloatArray,
    X_trans: FloatArray,
    batch_labels: pd.Series,
    method: str,
    n_components: int,
    title: Optional[str],
    show_legend: bool) -> Any:
    """
    Create interactive before/after scatter plots using plotly.

    Each batch gets one explicit color, shared by both panels, and both of
    its traces belong to the same legend group. Toggling a batch in the
    legend therefore hides or shows it in both panels.

    Parameters
    ----------
    X_orig, X_trans : ndarray of shape (n_samples, n_components)
        Embeddings of the original and the ComBat-transformed data.
    batch_labels : pd.Series
        Batch label of each row of the embeddings (same row order).
    method : str
        Reduction method name, used in titles and axis labels.
    n_components : {2, 3}
        Dimensionality of the embeddings; 3 produces 3D scenes.
    title : str, optional
        Figure title; a default is used if `None`.
    show_legend : bool
        Whether to show the batch legend.

    Returns
    -------
    plotly.graph_objects.Figure
    """
    # Only reached when plotly is installed (checked by plot_transformation).
    from plotly.colors import qualitative

    axis_name = method.upper()
    specs = [[{'type': 'scatter3d'}, {'type': 'scatter3d'}]] if n_components == 3 else None
    fig = make_subplots(
        rows=1, cols=2,
        specs=specs,
        subplot_titles=(
            f'Before ComBat correction ({axis_name})',
            f'After ComBat correction ({axis_name})'
        )
    )

    # Without explicit colors plotly cycles its colorway per trace, so the
    # same batch would be drawn in different colors in the two panels.
    palette = qualitative.Plotly
    unique_batches = batch_labels.drop_duplicates()

    for i, batch in enumerate(unique_batches):
        mask = (batch_labels == batch).to_numpy()
        color = palette[i % len(palette)]
        name = f'Batch {batch}'
        # The legend entry comes from the right-hand ("after") panel only.
        for col, embedding, legend in ((1, X_orig, False), (2, X_trans, show_legend)):
            if n_components == 2:
                trace = go.Scatter(
                    x=embedding[mask, 0], y=embedding[mask, 1],
                    mode='markers',
                    name=name,
                    legendgroup=name,
                    marker=dict(size=8, color=color, line=dict(width=1, color='black')),
                    showlegend=legend
                )
            else:
                trace = go.Scatter3d(
                    x=embedding[mask, 0], y=embedding[mask, 1], z=embedding[mask, 2],
                    mode='markers',
                    name=name,
                    legendgroup=name,
                    marker=dict(size=5, color=color, line=dict(width=0.5, color='black')),
                    showlegend=legend
                )
            fig.add_trace(trace, row=1, col=col)

    if title is None:
        title = f'ComBat correction effect visualized with {axis_name}'

    fig.update_layout(
        title=title,
        title_font_size=16,
        height=600,
        showlegend=show_legend,
        hovermode='closest'
    )

    axis_labels = [f'{axis_name}{i + 1}' for i in range(n_components)]

    if n_components == 2:
        fig.update_xaxes(title_text=axis_labels[0])
        fig.update_yaxes(title_text=axis_labels[1])
    else:
        fig.update_scenes(
            xaxis_title=axis_labels[0],
            yaxis_title=axis_labels[1],
            zaxis_title=axis_labels[2]
        )

    return fig
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: combatlearn
3
- Version: 0.1.2
3
+ Version: 0.2.0
4
4
  Summary: Batch-effect harmonization for machine learning frameworks.
5
5
  Author-email: Ettore Rocchi <ettoreroc@gmail.com>
6
6
  License: MIT License
@@ -37,6 +37,9 @@ License-File: LICENSE
37
37
  Requires-Dist: pandas>=1.3
38
38
  Requires-Dist: numpy>=1.21
39
39
  Requires-Dist: scikit-learn>=1.2
40
+ Requires-Dist: plotly>=5.0
41
+ Requires-Dist: nbformat>=4.2
42
+ Requires-Dist: umap-learn>=0.5
40
43
  Requires-Dist: pytest>=7
41
44
  Dynamic: license-file
42
45
 
@@ -111,7 +114,7 @@ print("Best parameters:", grid.best_params_)
111
114
  print(f"Best CV AUROC: {grid.best_score_:.3f}")
112
115
  ```
113
116
 
114
- For a full example of how to use **combatlearn** see the [notebook demo](https://github.com/EttoreRocchi/combatlearn/blob/main/demo/combatlearn_demo.ipynb)
117
+ For a full example of how to use **combatlearn** see the [notebook demo](https://github.com/EttoreRocchi/combatlearn/blob/main/docs/demo/combatlearn_demo.ipynb)
115
118
 
116
119
  ## `ComBat` parameters
117
120
 
@@ -136,6 +139,13 @@ The following section provides a detailed explanation of all parameters availabl
136
139
  | `covbat_cov_thresh` | float, int | `0.9` | For `"chen"` method only: Cumulative variance threshold $]0,1[$ to retain PCs in PCA space (e.g., 0.9 = retain 90% explained variance). If an integer is provided, it represents the number of principal components to use. |
137
140
  | `eps` | float | `1e-8` | Small jitter value added to variances to prevent divide-by-zero errors during standardization. |
138
141
 
142
+
143
+ ### Batch Effect Correction Visualization
144
+
145
+ The `plot_transformation` method lets you visualize the effect of the **ComBat** transformation: it reduces the data to two or three dimensions with PCA, t-SNE, or UMAP and shows a side-by-side before/after comparison of the original and the `ComBat`-corrected data.
146
+
147
+ For further details see the [notebook demo](https://github.com/EttoreRocchi/combatlearn/blob/main/docs/demo/combatlearn_demo.ipynb).
148
+
139
149
  ## Contributing
140
150
 
141
151
  Pull requests, bug reports, and feature ideas are welcome: feel free to open a PR!
@@ -144,7 +154,7 @@ Pull requests, bug reports, and feature ideas are welcome: feel free to open a P
144
154
 
145
155
  [**Ettore Rocchi**](https://github.com/ettorerocchi) @ University of Bologna
146
156
 
147
- [Google Scholar](https://scholar.google.com/citations?user=MKHoGnQAAAAJ) $\cdot$ [Scopus](https://www.scopus.com/authid/detail.uri?authorId=57220152522)
157
+ [Google Scholar](https://scholar.google.com/citations?user=MKHoGnQAAAAJ) | [Scopus](https://www.scopus.com/authid/detail.uri?authorId=57220152522)
148
158
 
149
159
  ## Acknowledgements
150
160
 
@@ -0,0 +1,7 @@
1
+ combatlearn/__init__.py,sha256=wJ5E-Nrz6s7KLCHDY_p1kpUwMws-Q6Xd_1cK3JksNxU,98
2
+ combatlearn/combat.py,sha256=g6YnCVWq40j_fMU2OcXrJ1O0MCSyt2owCaZ4gfyF-Pw,37268
3
+ combatlearn-0.2.0.dist-info/licenses/LICENSE,sha256=O34CBRTmdL59PxDYOa6nq1N0-2A9xyXGkBXKbsL1NeY,1070
4
+ combatlearn-0.2.0.dist-info/METADATA,sha256=oEL_LK1fJUUeacf0k09I5HlEOVejeUWEGu3i-QJhL3Y,8735
5
+ combatlearn-0.2.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
6
+ combatlearn-0.2.0.dist-info/top_level.txt,sha256=3cFQv4oj2sh_NKra45cPy8Go0v8W9x9-zkkUibqZCMk,12
7
+ combatlearn-0.2.0.dist-info/RECORD,,
@@ -1,7 +0,0 @@
1
- combatlearn/__init__.py,sha256=PHezKTkdkd2fnyqihhayxRN8hducHCXug7iQ5-UsfSc,98
2
- combatlearn/combat.py,sha256=ghc83DTLC4ukLJN_xqpoWZTPPTxFa4DVtT6C5SVUjFA,25024
3
- combatlearn-0.1.2.dist-info/licenses/LICENSE,sha256=O34CBRTmdL59PxDYOa6nq1N0-2A9xyXGkBXKbsL1NeY,1070
4
- combatlearn-0.1.2.dist-info/METADATA,sha256=VxQpyJAwOSQqw8ypiSUxq4dmszCDRW3AsO_0XBQq6pk,8213
5
- combatlearn-0.1.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
6
- combatlearn-0.1.2.dist-info/top_level.txt,sha256=3cFQv4oj2sh_NKra45cPy8Go0v8W9x9-zkkUibqZCMk,12
7
- combatlearn-0.1.2.dist-info/RECORD,,