combatlearn 0.2.1__py3-none-any.whl → 0.2.2__py3-none-any.whl

This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the contents of those package versions as they appear in their respective public registries.
combatlearn/__init__.py CHANGED
@@ -1,4 +1,4 @@
1
1
  from .combat import ComBatModel, ComBat
2
2
 
3
3
  __all__ = ["ComBatModel", "ComBat"]
4
- __version__ = "0.2.1"
4
+ __version__ = "0.2.2"
combatlearn/combat.py CHANGED
@@ -18,6 +18,7 @@ from sklearn.utils.validation import check_is_fitted
18
18
  from sklearn.decomposition import PCA
19
19
  from sklearn.manifold import TSNE
20
20
  import matplotlib.pyplot as plt
21
+ import matplotlib.colors as mcolors
21
22
  from typing import Literal, Optional, Union, Dict, Tuple, Any, cast
22
23
  import numpy.typing as npt
23
24
  import warnings
@@ -811,7 +812,7 @@ class ComBat(BaseEstimator, TransformerMixin):
811
812
  else:
812
813
  fig = self._create_interactive_plot(
813
814
  X_embedded_orig, X_embedded_trans, batch_vec,
814
- reduction_method, n_components, title, show_legend
815
+ reduction_method, n_components, cmap, title, show_legend
815
816
  )
816
817
 
817
818
  if return_embeddings:
@@ -930,6 +931,7 @@ class ComBat(BaseEstimator, TransformerMixin):
930
931
  batch_labels: pd.Series,
931
932
  method: str,
932
933
  n_components: int,
934
+ cmap: str,
933
935
  title: Optional[str],
934
936
  show_legend: bool) -> Any:
935
937
  """Create interactive plots using plotly."""
@@ -953,43 +955,69 @@ class ComBat(BaseEstimator, TransformerMixin):
953
955
 
954
956
  unique_batches = batch_labels.drop_duplicates()
955
957
 
958
+ n_batches = len(unique_batches)
959
+ cmap_func = plt.cm.get_cmap(cmap)
960
+ color_list = [mcolors.to_hex(cmap_func(i / max(n_batches - 1, 1))) for i in range(n_batches)]
961
+
962
+ batch_to_color = dict(zip(unique_batches, color_list))
963
+
956
964
  for batch in unique_batches:
957
965
  mask = batch_labels == batch
958
966
 
959
967
  if n_components == 2:
960
968
  fig.add_trace(
961
- go.Scatter(x=X_orig[mask, 0], y=X_orig[mask, 1],
962
- mode='markers',
963
- name=f'Batch {batch}',
964
- marker=dict(size=8, line=dict(width=1, color='black')),
965
- showlegend=False),
969
+ go.Scatter(
970
+ x=X_orig[mask, 0], y=X_orig[mask, 1],
971
+ mode='markers',
972
+ name=f'Batch {batch}',
973
+ marker=dict(
974
+ size=8,
975
+ color=batch_to_color[batch],
976
+ line=dict(width=1, color='black')
977
+ ),
978
+ showlegend=False),
966
979
  row=1, col=1
967
980
  )
968
981
 
969
982
  fig.add_trace(
970
- go.Scatter(x=X_trans[mask, 0], y=X_trans[mask, 1],
971
- mode='markers',
972
- name=f'Batch {batch}',
973
- marker=dict(size=8, line=dict(width=1, color='black')),
974
- showlegend=show_legend),
983
+ go.Scatter(
984
+ x=X_trans[mask, 0], y=X_trans[mask, 1],
985
+ mode='markers',
986
+ name=f'Batch {batch}',
987
+ marker=dict(
988
+ size=8,
989
+ color=batch_to_color[batch],
990
+ line=dict(width=1, color='black')
991
+ ),
992
+ showlegend=show_legend),
975
993
  row=1, col=2
976
994
  )
977
995
  else:
978
996
  fig.add_trace(
979
- go.Scatter3d(x=X_orig[mask, 0], y=X_orig[mask, 1], z=X_orig[mask, 2],
980
- mode='markers',
981
- name=f'Batch {batch}',
982
- marker=dict(size=5, line=dict(width=0.5, color='black')),
983
- showlegend=False),
997
+ go.Scatter3d(
998
+ x=X_orig[mask, 0], y=X_orig[mask, 1], z=X_orig[mask, 2],
999
+ mode='markers',
1000
+ name=f'Batch {batch}',
1001
+ marker=dict(
1002
+ size=5,
1003
+ color=batch_to_color[batch],
1004
+ line=dict(width=0.5, color='black')
1005
+ ),
1006
+ showlegend=False),
984
1007
  row=1, col=1
985
1008
  )
986
1009
 
987
1010
  fig.add_trace(
988
- go.Scatter3d(x=X_trans[mask, 0], y=X_trans[mask, 1], z=X_trans[mask, 2],
989
- mode='markers',
990
- name=f'Batch {batch}',
991
- marker=dict(size=5, line=dict(width=0.5, color='black')),
992
- showlegend=show_legend),
1011
+ go.Scatter3d(
1012
+ x=X_trans[mask, 0], y=X_trans[mask, 1], z=X_trans[mask, 2],
1013
+ mode='markers',
1014
+ name=f'Batch {batch}',
1015
+ marker=dict(
1016
+ size=5,
1017
+ color=batch_to_color[batch],
1018
+ line=dict(width=0.5, color='black')
1019
+ ),
1020
+ showlegend=show_legend),
993
1021
  row=1, col=2
994
1022
  )
995
1023
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: combatlearn
3
- Version: 0.2.1
3
+ Version: 0.2.2
4
4
  Summary: Batch-effect harmonization for machine learning frameworks.
5
5
  Author-email: Ettore Rocchi <ettoreroc@gmail.com>
6
6
  License-Expression: MIT
@@ -0,0 +1,7 @@
1
+ combatlearn/__init__.py,sha256=qZK8xAUibzM9TQJ-xho1cjMYmTGkdWvpFRTXOokNvMY,98
2
+ combatlearn/combat.py,sha256=pVauFEgZ7wiYRimGZe7ZhBWZN7sGQ67A3o_SrBUtoJ8,38126
3
+ combatlearn-0.2.2.dist-info/licenses/LICENSE,sha256=O34CBRTmdL59PxDYOa6nq1N0-2A9xyXGkBXKbsL1NeY,1070
4
+ combatlearn-0.2.2.dist-info/METADATA,sha256=CNm0pbXPVVWORk4pI97WS1DohjWOu7fB88JS1JZ-3-A,7491
5
+ combatlearn-0.2.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
6
+ combatlearn-0.2.2.dist-info/top_level.txt,sha256=3cFQv4oj2sh_NKra45cPy8Go0v8W9x9-zkkUibqZCMk,12
7
+ combatlearn-0.2.2.dist-info/RECORD,,
@@ -1,7 +0,0 @@
1
- combatlearn/__init__.py,sha256=UzqGt-P5ZVBfK6SXGTi-OOgG5Ae5ZJO7ugZhFp3EHCM,98
2
- combatlearn/combat.py,sha256=g6YnCVWq40j_fMU2OcXrJ1O0MCSyt2owCaZ4gfyF-Pw,37268
3
- combatlearn-0.2.1.dist-info/licenses/LICENSE,sha256=O34CBRTmdL59PxDYOa6nq1N0-2A9xyXGkBXKbsL1NeY,1070
4
- combatlearn-0.2.1.dist-info/METADATA,sha256=zYMV3IEi0vgrGuu6dwYwkLH-cCXxQTr9GekUjUGwTgc,7491
5
- combatlearn-0.2.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
6
- combatlearn-0.2.1.dist-info/top_level.txt,sha256=3cFQv4oj2sh_NKra45cPy8Go0v8W9x9-zkkUibqZCMk,12
7
- combatlearn-0.2.1.dist-info/RECORD,,