combatlearn 0.2.0__py3-none-any.whl → 0.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
combatlearn/__init__.py CHANGED
@@ -1,4 +1,4 @@
1
1
  from .combat import ComBatModel, ComBat
2
2
 
3
3
  __all__ = ["ComBatModel", "ComBat"]
4
- __version__ = "0.2.0"
4
+ __version__ = "0.2.2"
combatlearn/combat.py CHANGED
@@ -18,6 +18,7 @@ from sklearn.utils.validation import check_is_fitted
18
18
  from sklearn.decomposition import PCA
19
19
  from sklearn.manifold import TSNE
20
20
  import matplotlib.pyplot as plt
21
+ import matplotlib.colors as mcolors
21
22
  from typing import Literal, Optional, Union, Dict, Tuple, Any, cast
22
23
  import numpy.typing as npt
23
24
  import warnings
@@ -811,7 +812,7 @@ class ComBat(BaseEstimator, TransformerMixin):
811
812
  else:
812
813
  fig = self._create_interactive_plot(
813
814
  X_embedded_orig, X_embedded_trans, batch_vec,
814
- reduction_method, n_components, title, show_legend
815
+ reduction_method, n_components, cmap, title, show_legend
815
816
  )
816
817
 
817
818
  if return_embeddings:
@@ -930,6 +931,7 @@ class ComBat(BaseEstimator, TransformerMixin):
930
931
  batch_labels: pd.Series,
931
932
  method: str,
932
933
  n_components: int,
934
+ cmap: str,
933
935
  title: Optional[str],
934
936
  show_legend: bool) -> Any:
935
937
  """Create interactive plots using plotly."""
@@ -953,43 +955,69 @@ class ComBat(BaseEstimator, TransformerMixin):
953
955
 
954
956
  unique_batches = batch_labels.drop_duplicates()
955
957
 
958
+ n_batches = len(unique_batches)
959
+ cmap_func = plt.cm.get_cmap(cmap)
960
+ color_list = [mcolors.to_hex(cmap_func(i / max(n_batches - 1, 1))) for i in range(n_batches)]
961
+
962
+ batch_to_color = dict(zip(unique_batches, color_list))
963
+
956
964
  for batch in unique_batches:
957
965
  mask = batch_labels == batch
958
966
 
959
967
  if n_components == 2:
960
968
  fig.add_trace(
961
- go.Scatter(x=X_orig[mask, 0], y=X_orig[mask, 1],
962
- mode='markers',
963
- name=f'Batch {batch}',
964
- marker=dict(size=8, line=dict(width=1, color='black')),
965
- showlegend=False),
969
+ go.Scatter(
970
+ x=X_orig[mask, 0], y=X_orig[mask, 1],
971
+ mode='markers',
972
+ name=f'Batch {batch}',
973
+ marker=dict(
974
+ size=8,
975
+ color=batch_to_color[batch],
976
+ line=dict(width=1, color='black')
977
+ ),
978
+ showlegend=False),
966
979
  row=1, col=1
967
980
  )
968
981
 
969
982
  fig.add_trace(
970
- go.Scatter(x=X_trans[mask, 0], y=X_trans[mask, 1],
971
- mode='markers',
972
- name=f'Batch {batch}',
973
- marker=dict(size=8, line=dict(width=1, color='black')),
974
- showlegend=show_legend),
983
+ go.Scatter(
984
+ x=X_trans[mask, 0], y=X_trans[mask, 1],
985
+ mode='markers',
986
+ name=f'Batch {batch}',
987
+ marker=dict(
988
+ size=8,
989
+ color=batch_to_color[batch],
990
+ line=dict(width=1, color='black')
991
+ ),
992
+ showlegend=show_legend),
975
993
  row=1, col=2
976
994
  )
977
995
  else:
978
996
  fig.add_trace(
979
- go.Scatter3d(x=X_orig[mask, 0], y=X_orig[mask, 1], z=X_orig[mask, 2],
980
- mode='markers',
981
- name=f'Batch {batch}',
982
- marker=dict(size=5, line=dict(width=0.5, color='black')),
983
- showlegend=False),
997
+ go.Scatter3d(
998
+ x=X_orig[mask, 0], y=X_orig[mask, 1], z=X_orig[mask, 2],
999
+ mode='markers',
1000
+ name=f'Batch {batch}',
1001
+ marker=dict(
1002
+ size=5,
1003
+ color=batch_to_color[batch],
1004
+ line=dict(width=0.5, color='black')
1005
+ ),
1006
+ showlegend=False),
984
1007
  row=1, col=1
985
1008
  )
986
1009
 
987
1010
  fig.add_trace(
988
- go.Scatter3d(x=X_trans[mask, 0], y=X_trans[mask, 1], z=X_trans[mask, 2],
989
- mode='markers',
990
- name=f'Batch {batch}',
991
- marker=dict(size=5, line=dict(width=0.5, color='black')),
992
- showlegend=show_legend),
1011
+ go.Scatter3d(
1012
+ x=X_trans[mask, 0], y=X_trans[mask, 1], z=X_trans[mask, 2],
1013
+ mode='markers',
1014
+ name=f'Batch {batch}',
1015
+ marker=dict(
1016
+ size=5,
1017
+ color=batch_to_color[batch],
1018
+ line=dict(width=0.5, color='black')
1019
+ ),
1020
+ showlegend=show_legend),
993
1021
  row=1, col=2
994
1022
  )
995
1023
 
@@ -1,34 +1,12 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: combatlearn
3
- Version: 0.2.0
3
+ Version: 0.2.2
4
4
  Summary: Batch-effect harmonization for machine learning frameworks.
5
5
  Author-email: Ettore Rocchi <ettoreroc@gmail.com>
6
- License: MIT License
7
-
8
- Copyright (c) 2025 Ettore Rocchi
9
-
10
- Permission is hereby granted, free of charge, to any person obtaining a copy
11
- of this software and associated documentation files (the "Software"), to deal
12
- in the Software without restriction, including without limitation the rights
13
- to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
14
- copies of the Software, and to permit persons to whom the Software is
15
- furnished to do so, subject to the following conditions:
16
-
17
- The above copyright notice and this permission notice shall be included in all
18
- copies or substantial portions of the Software.
19
-
20
- THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
21
- IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
22
- FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
23
- AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
24
- LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
25
- OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
26
- SOFTWARE.
27
-
6
+ License-Expression: MIT
28
7
  Keywords: machine-learning,harmonization,combat,preprocessing
29
8
  Classifier: Development Status :: 3 - Alpha
30
9
  Classifier: Intended Audience :: Science/Research
31
- Classifier: License :: OSI Approved :: MIT License
32
10
  Classifier: Operating System :: OS Independent
33
11
  Classifier: Programming Language :: Python :: 3
34
12
  Requires-Python: >=3.10
@@ -37,6 +15,7 @@ License-File: LICENSE
37
15
  Requires-Dist: pandas>=1.3
38
16
  Requires-Dist: numpy>=1.21
39
17
  Requires-Dist: scikit-learn>=1.2
18
+ Requires-Dist: matplotlib>=3.4
40
19
  Requires-Dist: plotly>=5.0
41
20
  Requires-Dist: nbformat>=4.2
42
21
  Requires-Dist: umap-learn>=0.5
@@ -0,0 +1,7 @@
1
+ combatlearn/__init__.py,sha256=qZK8xAUibzM9TQJ-xho1cjMYmTGkdWvpFRTXOokNvMY,98
2
+ combatlearn/combat.py,sha256=pVauFEgZ7wiYRimGZe7ZhBWZN7sGQ67A3o_SrBUtoJ8,38126
3
+ combatlearn-0.2.2.dist-info/licenses/LICENSE,sha256=O34CBRTmdL59PxDYOa6nq1N0-2A9xyXGkBXKbsL1NeY,1070
4
+ combatlearn-0.2.2.dist-info/METADATA,sha256=CNm0pbXPVVWORk4pI97WS1DohjWOu7fB88JS1JZ-3-A,7491
5
+ combatlearn-0.2.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
6
+ combatlearn-0.2.2.dist-info/top_level.txt,sha256=3cFQv4oj2sh_NKra45cPy8Go0v8W9x9-zkkUibqZCMk,12
7
+ combatlearn-0.2.2.dist-info/RECORD,,
@@ -1,7 +0,0 @@
1
- combatlearn/__init__.py,sha256=wJ5E-Nrz6s7KLCHDY_p1kpUwMws-Q6Xd_1cK3JksNxU,98
2
- combatlearn/combat.py,sha256=g6YnCVWq40j_fMU2OcXrJ1O0MCSyt2owCaZ4gfyF-Pw,37268
3
- combatlearn-0.2.0.dist-info/licenses/LICENSE,sha256=O34CBRTmdL59PxDYOa6nq1N0-2A9xyXGkBXKbsL1NeY,1070
4
- combatlearn-0.2.0.dist-info/METADATA,sha256=oEL_LK1fJUUeacf0k09I5HlEOVejeUWEGu3i-QJhL3Y,8735
5
- combatlearn-0.2.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
6
- combatlearn-0.2.0.dist-info/top_level.txt,sha256=3cFQv4oj2sh_NKra45cPy8Go0v8W9x9-zkkUibqZCMk,12
7
- combatlearn-0.2.0.dist-info/RECORD,,