combatlearn 0.2.1__py3-none-any.whl → 0.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- combatlearn/__init__.py +1 -1
- combatlearn/combat.py +49 -21
- {combatlearn-0.2.1.dist-info → combatlearn-0.2.2.dist-info}/METADATA +1 -1
- combatlearn-0.2.2.dist-info/RECORD +7 -0
- combatlearn-0.2.1.dist-info/RECORD +0 -7
- {combatlearn-0.2.1.dist-info → combatlearn-0.2.2.dist-info}/WHEEL +0 -0
- {combatlearn-0.2.1.dist-info → combatlearn-0.2.2.dist-info}/licenses/LICENSE +0 -0
- {combatlearn-0.2.1.dist-info → combatlearn-0.2.2.dist-info}/top_level.txt +0 -0
combatlearn/__init__.py
CHANGED
combatlearn/combat.py
CHANGED
|
@@ -18,6 +18,7 @@ from sklearn.utils.validation import check_is_fitted
|
|
|
18
18
|
from sklearn.decomposition import PCA
|
|
19
19
|
from sklearn.manifold import TSNE
|
|
20
20
|
import matplotlib.pyplot as plt
|
|
21
|
+
import matplotlib.colors as mcolors
|
|
21
22
|
from typing import Literal, Optional, Union, Dict, Tuple, Any, cast
|
|
22
23
|
import numpy.typing as npt
|
|
23
24
|
import warnings
|
|
@@ -811,7 +812,7 @@ class ComBat(BaseEstimator, TransformerMixin):
|
|
|
811
812
|
else:
|
|
812
813
|
fig = self._create_interactive_plot(
|
|
813
814
|
X_embedded_orig, X_embedded_trans, batch_vec,
|
|
814
|
-
reduction_method, n_components, title, show_legend
|
|
815
|
+
reduction_method, n_components, cmap, title, show_legend
|
|
815
816
|
)
|
|
816
817
|
|
|
817
818
|
if return_embeddings:
|
|
@@ -930,6 +931,7 @@ class ComBat(BaseEstimator, TransformerMixin):
|
|
|
930
931
|
batch_labels: pd.Series,
|
|
931
932
|
method: str,
|
|
932
933
|
n_components: int,
|
|
934
|
+
cmap: str,
|
|
933
935
|
title: Optional[str],
|
|
934
936
|
show_legend: bool) -> Any:
|
|
935
937
|
"""Create interactive plots using plotly."""
|
|
@@ -953,43 +955,69 @@ class ComBat(BaseEstimator, TransformerMixin):
|
|
|
953
955
|
|
|
954
956
|
unique_batches = batch_labels.drop_duplicates()
|
|
955
957
|
|
|
958
|
+
n_batches = len(unique_batches)
|
|
959
|
+
cmap_func = plt.cm.get_cmap(cmap)
|
|
960
|
+
color_list = [mcolors.to_hex(cmap_func(i / max(n_batches - 1, 1))) for i in range(n_batches)]
|
|
961
|
+
|
|
962
|
+
batch_to_color = dict(zip(unique_batches, color_list))
|
|
963
|
+
|
|
956
964
|
for batch in unique_batches:
|
|
957
965
|
mask = batch_labels == batch
|
|
958
966
|
|
|
959
967
|
if n_components == 2:
|
|
960
968
|
fig.add_trace(
|
|
961
|
-
go.Scatter(
|
|
962
|
-
|
|
963
|
-
|
|
964
|
-
|
|
965
|
-
|
|
969
|
+
go.Scatter(
|
|
970
|
+
x=X_orig[mask, 0], y=X_orig[mask, 1],
|
|
971
|
+
mode='markers',
|
|
972
|
+
name=f'Batch {batch}',
|
|
973
|
+
marker=dict(
|
|
974
|
+
size=8,
|
|
975
|
+
color=batch_to_color[batch],
|
|
976
|
+
line=dict(width=1, color='black')
|
|
977
|
+
),
|
|
978
|
+
showlegend=False),
|
|
966
979
|
row=1, col=1
|
|
967
980
|
)
|
|
968
981
|
|
|
969
982
|
fig.add_trace(
|
|
970
|
-
go.Scatter(
|
|
971
|
-
|
|
972
|
-
|
|
973
|
-
|
|
974
|
-
|
|
983
|
+
go.Scatter(
|
|
984
|
+
x=X_trans[mask, 0], y=X_trans[mask, 1],
|
|
985
|
+
mode='markers',
|
|
986
|
+
name=f'Batch {batch}',
|
|
987
|
+
marker=dict(
|
|
988
|
+
size=8,
|
|
989
|
+
color=batch_to_color[batch],
|
|
990
|
+
line=dict(width=1, color='black')
|
|
991
|
+
),
|
|
992
|
+
showlegend=show_legend),
|
|
975
993
|
row=1, col=2
|
|
976
994
|
)
|
|
977
995
|
else:
|
|
978
996
|
fig.add_trace(
|
|
979
|
-
go.Scatter3d(
|
|
980
|
-
|
|
981
|
-
|
|
982
|
-
|
|
983
|
-
|
|
997
|
+
go.Scatter3d(
|
|
998
|
+
x=X_orig[mask, 0], y=X_orig[mask, 1], z=X_orig[mask, 2],
|
|
999
|
+
mode='markers',
|
|
1000
|
+
name=f'Batch {batch}',
|
|
1001
|
+
marker=dict(
|
|
1002
|
+
size=5,
|
|
1003
|
+
color=batch_to_color[batch],
|
|
1004
|
+
line=dict(width=0.5, color='black')
|
|
1005
|
+
),
|
|
1006
|
+
showlegend=False),
|
|
984
1007
|
row=1, col=1
|
|
985
1008
|
)
|
|
986
1009
|
|
|
987
1010
|
fig.add_trace(
|
|
988
|
-
go.Scatter3d(
|
|
989
|
-
|
|
990
|
-
|
|
991
|
-
|
|
992
|
-
|
|
1011
|
+
go.Scatter3d(
|
|
1012
|
+
x=X_trans[mask, 0], y=X_trans[mask, 1], z=X_trans[mask, 2],
|
|
1013
|
+
mode='markers',
|
|
1014
|
+
name=f'Batch {batch}',
|
|
1015
|
+
marker=dict(
|
|
1016
|
+
size=5,
|
|
1017
|
+
color=batch_to_color[batch],
|
|
1018
|
+
line=dict(width=0.5, color='black')
|
|
1019
|
+
),
|
|
1020
|
+
showlegend=show_legend),
|
|
993
1021
|
row=1, col=2
|
|
994
1022
|
)
|
|
995
1023
|
|
|
@@ -0,0 +1,7 @@
|
|
|
1
|
+
combatlearn/__init__.py,sha256=qZK8xAUibzM9TQJ-xho1cjMYmTGkdWvpFRTXOokNvMY,98
|
|
2
|
+
combatlearn/combat.py,sha256=pVauFEgZ7wiYRimGZe7ZhBWZN7sGQ67A3o_SrBUtoJ8,38126
|
|
3
|
+
combatlearn-0.2.2.dist-info/licenses/LICENSE,sha256=O34CBRTmdL59PxDYOa6nq1N0-2A9xyXGkBXKbsL1NeY,1070
|
|
4
|
+
combatlearn-0.2.2.dist-info/METADATA,sha256=CNm0pbXPVVWORk4pI97WS1DohjWOu7fB88JS1JZ-3-A,7491
|
|
5
|
+
combatlearn-0.2.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
6
|
+
combatlearn-0.2.2.dist-info/top_level.txt,sha256=3cFQv4oj2sh_NKra45cPy8Go0v8W9x9-zkkUibqZCMk,12
|
|
7
|
+
combatlearn-0.2.2.dist-info/RECORD,,
|
|
@@ -1,7 +0,0 @@
|
|
|
1
|
-
combatlearn/__init__.py,sha256=UzqGt-P5ZVBfK6SXGTi-OOgG5Ae5ZJO7ugZhFp3EHCM,98
|
|
2
|
-
combatlearn/combat.py,sha256=g6YnCVWq40j_fMU2OcXrJ1O0MCSyt2owCaZ4gfyF-Pw,37268
|
|
3
|
-
combatlearn-0.2.1.dist-info/licenses/LICENSE,sha256=O34CBRTmdL59PxDYOa6nq1N0-2A9xyXGkBXKbsL1NeY,1070
|
|
4
|
-
combatlearn-0.2.1.dist-info/METADATA,sha256=zYMV3IEi0vgrGuu6dwYwkLH-cCXxQTr9GekUjUGwTgc,7491
|
|
5
|
-
combatlearn-0.2.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
6
|
-
combatlearn-0.2.1.dist-info/top_level.txt,sha256=3cFQv4oj2sh_NKra45cPy8Go0v8W9x9-zkkUibqZCMk,12
|
|
7
|
-
combatlearn-0.2.1.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|