py2ls 0.2.1__py3-none-any.whl → 0.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- py2ls/.DS_Store +0 -0
- py2ls/data/sns_info.json +74 -0
- py2ls/data/usages_pd.json +56 -0
- py2ls/data/usages_sns.json +25 -0
- py2ls/ips.py +1253 -517
- py2ls/plot.py +746 -30
- py2ls/stats.py +18 -9
- py2ls/update2usage.py +126 -0
- {py2ls-0.2.1.dist-info → py2ls-0.2.3.dist-info}/METADATA +1 -1
- {py2ls-0.2.1.dist-info → py2ls-0.2.3.dist-info}/RECORD +11 -7
- {py2ls-0.2.1.dist-info → py2ls-0.2.3.dist-info}/WHEEL +0 -0
py2ls/plot.py
CHANGED
@@ -9,24 +9,32 @@ import matplotlib.ticker as tck
|
|
9
9
|
from cycler import cycler
|
10
10
|
import logging
|
11
11
|
import os
|
12
|
+
import re
|
12
13
|
|
13
|
-
from .ips import fsave, fload, mkdir, listdir, figsave
|
14
|
+
from .ips import fsave, fload, mkdir, listdir, figsave, strcmp, unique, get_os, ssplit
|
14
15
|
from .stats import *
|
16
|
+
from .netfinder import get_soup, fetch
|
17
|
+
|
15
18
|
|
16
19
|
# Suppress INFO messages from fontTools
|
17
20
|
logging.getLogger("fontTools").setLevel(logging.WARNING)
|
18
21
|
|
19
22
|
|
20
|
-
def
|
21
|
-
|
22
|
-
|
23
|
+
def heatmap(
|
24
|
+
data,
|
25
|
+
ax=None,
|
26
|
+
kind="corr", #'corr','direct','pivot'
|
27
|
+
columns="all", # pivot, default: coll numeric columns
|
28
|
+
index=None, # pivot
|
29
|
+
values=None, # pivot
|
23
30
|
tri="u",
|
24
31
|
mask=True,
|
25
32
|
k=1,
|
26
33
|
annot=True,
|
27
34
|
cmap="coolwarm",
|
28
35
|
fmt=".2f",
|
29
|
-
cluster=False,
|
36
|
+
cluster=False,
|
37
|
+
inplace=False,
|
30
38
|
figsize=(10, 8),
|
31
39
|
row_cluster=True, # Perform clustering on rows
|
32
40
|
col_cluster=True, # Perform clustering on columns
|
@@ -36,24 +44,115 @@ def df_corr(
|
|
36
44
|
yticklabels=True, # Show row labels
|
37
45
|
**kwargs,
|
38
46
|
):
|
47
|
+
if ax is None and not cluster:
|
48
|
+
ax = plt.gca()
|
39
49
|
# Select numeric columns or specific subset of columns
|
40
50
|
if columns == "all":
|
41
|
-
df_numeric =
|
51
|
+
df_numeric = data.select_dtypes(include=[float, int])
|
42
52
|
else:
|
43
|
-
df_numeric =
|
44
|
-
|
45
|
-
|
46
|
-
|
53
|
+
df_numeric = data[columns]
|
54
|
+
|
55
|
+
kinds = ["corr", "direct", "pivot"]
|
56
|
+
kind = strcmp(kind, kinds)[0]
|
57
|
+
if kind == "corr":
|
58
|
+
# Compute the correlation matrix
|
59
|
+
data4heatmap = df_numeric.corr()
|
60
|
+
# Generate mask for the upper triangle if mask is True
|
61
|
+
if mask:
|
62
|
+
if "u" in tri.lower(): # upper => np.tril
|
63
|
+
mask_array = np.tril(np.ones_like(data4heatmap, dtype=bool), k=k)
|
64
|
+
else: # lower => np.triu
|
65
|
+
mask_array = np.triu(np.ones_like(data4heatmap, dtype=bool), k=k)
|
66
|
+
else:
|
67
|
+
mask_array = None
|
68
|
+
|
69
|
+
# Remove conflicting kwargs
|
70
|
+
kwargs.pop("mask", None)
|
71
|
+
kwargs.pop("annot", None)
|
72
|
+
kwargs.pop("cmap", None)
|
73
|
+
kwargs.pop("fmt", None)
|
74
|
+
|
75
|
+
kwargs.pop("clustermap", None)
|
76
|
+
kwargs.pop("row_cluster", None)
|
77
|
+
kwargs.pop("col_cluster", None)
|
78
|
+
kwargs.pop("dendrogram_ratio", None)
|
79
|
+
kwargs.pop("cbar_pos", None)
|
80
|
+
kwargs.pop("xticklabels", None)
|
81
|
+
kwargs.pop("col_cluster", None)
|
82
|
+
|
83
|
+
# Plot the heatmap or clustermap
|
84
|
+
if cluster:
|
85
|
+
# Create a clustermap
|
86
|
+
cluster_obj = sns.clustermap(
|
87
|
+
data4heatmap,
|
88
|
+
# ax=ax,
|
89
|
+
mask=mask_array,
|
90
|
+
annot=annot,
|
91
|
+
cmap=cmap,
|
92
|
+
fmt=fmt,
|
93
|
+
figsize=figsize, # Figure size, adjusted for professional display
|
94
|
+
row_cluster=row_cluster, # Perform clustering on rows
|
95
|
+
col_cluster=col_cluster, # Perform clustering on columns
|
96
|
+
dendrogram_ratio=dendrogram_ratio, # Adjust size of dendrograms
|
97
|
+
cbar_pos=cbar_pos, # Adjust colorbar position
|
98
|
+
xticklabels=xticklabels, # Show column labels
|
99
|
+
yticklabels=yticklabels, # Show row labels
|
100
|
+
**kwargs, # Pass any additional arguments to sns.clustermap
|
101
|
+
)
|
102
|
+
df_row_cluster = pd.DataFrame()
|
103
|
+
df_col_cluster = pd.DataFrame()
|
104
|
+
if row_cluster:
|
105
|
+
from scipy.cluster.hierarchy import linkage, fcluster
|
106
|
+
from scipy.spatial.distance import pdist
|
107
|
+
|
108
|
+
# Compute pairwise distances
|
109
|
+
distances = pdist(data, metric="euclidean")
|
110
|
+
# Perform hierarchical clustering
|
111
|
+
linkage_matrix = linkage(distances, method="average")
|
112
|
+
# Get cluster assignments based on the distance threshold
|
113
|
+
row_clusters_value = fcluster(
|
114
|
+
linkage_matrix, t=1.5, criterion="distance"
|
115
|
+
)
|
116
|
+
df_row_cluster["row_cluster"] = row_clusters_value
|
117
|
+
if col_cluster:
|
118
|
+
col_distances = pdist(
|
119
|
+
data4heatmap.T, metric="euclidean"
|
120
|
+
) # Transpose for column clustering
|
121
|
+
col_linkage_matrix = linkage(col_distances, method="average")
|
122
|
+
col_clusters_value = fcluster(
|
123
|
+
col_linkage_matrix, t=1.5, criterion="distance"
|
124
|
+
)
|
125
|
+
df_col_cluster = pd.DataFrame(
|
126
|
+
{"Cluster": col_clusters_value}, index=data4heatmap.columns
|
127
|
+
)
|
47
128
|
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
129
|
+
return (
|
130
|
+
cluster_obj.ax_row_dendrogram,
|
131
|
+
cluster_obj.ax_col_dendrogram,
|
132
|
+
cluster_obj.ax_heatmap,
|
133
|
+
df_row_cluster,
|
134
|
+
df_col_cluster,
|
135
|
+
)
|
136
|
+
else:
|
137
|
+
# Create a standard heatmap
|
138
|
+
ax = sns.heatmap(
|
139
|
+
data4heatmap,
|
140
|
+
ax=ax,
|
141
|
+
mask=mask_array,
|
142
|
+
annot=annot,
|
143
|
+
cmap=cmap,
|
144
|
+
fmt=fmt,
|
145
|
+
**kwargs, # Pass any additional arguments to sns.heatmap
|
146
|
+
)
|
147
|
+
# Return the Axes object for further customization if needed
|
148
|
+
return ax
|
149
|
+
elif kind == "direct":
|
150
|
+
data4heatmap = df_numeric
|
151
|
+
elif kind == "pivot":
|
152
|
+
print('need 3 param: e.g., index="Task", columns="Model", values="Score"')
|
153
|
+
data4heatmap = data.pivot(index=index, columns=columns, values=values)
|
54
154
|
else:
|
55
|
-
|
56
|
-
|
155
|
+
print(f'"{kind}" is not supported')
|
57
156
|
# Remove conflicting kwargs
|
58
157
|
kwargs.pop("mask", None)
|
59
158
|
kwargs.pop("annot", None)
|
@@ -72,8 +171,9 @@ def df_corr(
|
|
72
171
|
if cluster:
|
73
172
|
# Create a clustermap
|
74
173
|
cluster_obj = sns.clustermap(
|
75
|
-
|
76
|
-
|
174
|
+
data4heatmap,
|
175
|
+
# ax=ax,
|
176
|
+
# mask=mask_array,
|
77
177
|
annot=annot,
|
78
178
|
cmap=cmap,
|
79
179
|
fmt=fmt,
|
@@ -86,18 +186,43 @@ def df_corr(
|
|
86
186
|
yticklabels=yticklabels, # Show row labels
|
87
187
|
**kwargs, # Pass any additional arguments to sns.clustermap
|
88
188
|
)
|
189
|
+
df_row_cluster = pd.DataFrame()
|
190
|
+
df_col_cluster = pd.DataFrame()
|
191
|
+
if row_cluster:
|
192
|
+
from scipy.cluster.hierarchy import linkage, fcluster
|
193
|
+
from scipy.spatial.distance import pdist
|
194
|
+
|
195
|
+
# Compute pairwise distances
|
196
|
+
distances = pdist(data, metric="euclidean")
|
197
|
+
# Perform hierarchical clustering
|
198
|
+
linkage_matrix = linkage(distances, method="average")
|
199
|
+
# Get cluster assignments based on the distance threshold
|
200
|
+
row_clusters_value = fcluster(linkage_matrix, t=1.5, criterion="distance")
|
201
|
+
df_row_cluster["row_cluster"] = row_clusters_value
|
202
|
+
if col_cluster:
|
203
|
+
col_distances = pdist(
|
204
|
+
data4heatmap.T, metric="euclidean"
|
205
|
+
) # Transpose for column clustering
|
206
|
+
col_linkage_matrix = linkage(col_distances, method="average")
|
207
|
+
col_clusters_value = fcluster(
|
208
|
+
col_linkage_matrix, t=1.5, criterion="distance"
|
209
|
+
)
|
210
|
+
df_col_cluster = pd.DataFrame(
|
211
|
+
{"Cluster": col_clusters_value}, index=data4heatmap.columns
|
212
|
+
)
|
89
213
|
|
90
214
|
return (
|
91
215
|
cluster_obj.ax_row_dendrogram,
|
92
216
|
cluster_obj.ax_col_dendrogram,
|
93
217
|
cluster_obj.ax_heatmap,
|
218
|
+
df_row_cluster,
|
219
|
+
df_col_cluster,
|
94
220
|
)
|
95
221
|
else:
|
96
222
|
# Create a standard heatmap
|
97
|
-
plt.figure(figsize=figsize)
|
98
223
|
ax = sns.heatmap(
|
99
|
-
|
100
|
-
|
224
|
+
data4heatmap,
|
225
|
+
ax=ax,
|
101
226
|
annot=annot,
|
102
227
|
cmap=cmap,
|
103
228
|
fmt=fmt,
|
@@ -107,6 +232,60 @@ def df_corr(
|
|
107
232
|
return ax
|
108
233
|
|
109
234
|
|
235
|
+
# !usage: py2ls.plot.heatmap()
|
236
|
+
# penguins_clean = penguins.replace([np.inf, -np.inf], np.nan).dropna()
|
237
|
+
# from py2ls import plot
|
238
|
+
|
239
|
+
# _, axs = plt.subplots(2, 2, figsize=(10, 10))
|
240
|
+
# # kind='pivot'
|
241
|
+
# plot.heatmap(
|
242
|
+
# ax=axs[0][0],
|
243
|
+
# data=sns.load_dataset("glue"),
|
244
|
+
# kind="pi",
|
245
|
+
# index="Model",
|
246
|
+
# columns="Task",
|
247
|
+
# values="Score",
|
248
|
+
# fmt=".1f",
|
249
|
+
# cbar_kws=dict(shrink=1),
|
250
|
+
# annot_kws=dict(size=7),
|
251
|
+
# )
|
252
|
+
# # kind='direct'
|
253
|
+
# plot.heatmap(
|
254
|
+
# ax=axs[0][1],
|
255
|
+
# data=sns.load_dataset("penguins").iloc[:10, 2:6],
|
256
|
+
# kind="direct",
|
257
|
+
# tri="lower",
|
258
|
+
# fmt=".1f",
|
259
|
+
# k=1,
|
260
|
+
# cbar_kws=dict(shrink=1),
|
261
|
+
# annot_kws=dict(size=7),
|
262
|
+
# )
|
263
|
+
|
264
|
+
# # kind='corr'
|
265
|
+
# plot.heatmap(
|
266
|
+
# ax=axs[1][0],
|
267
|
+
# data=sns.load_dataset("penguins"),
|
268
|
+
# kind="corr",
|
269
|
+
# fmt=".1f",
|
270
|
+
# k=-1,
|
271
|
+
# cbar_kws=dict(shrink=1),
|
272
|
+
# annot_kws=dict(size=7),
|
273
|
+
# )
|
274
|
+
# # kind='direct' with cluster=True (clustermap)
|
275
|
+
# plot.heatmap(
|
276
|
+
# ax=axs[1][1],
|
277
|
+
# data=penguins_clean.iloc[:15, :10],
|
278
|
+
# kind="direct",
|
279
|
+
# tri="lower",
|
280
|
+
# fmt=".1f",
|
281
|
+
# k=1,
|
282
|
+
# annot=False,
|
283
|
+
# cluster=True,
|
284
|
+
# cbar_kws=dict(shrink=1),
|
285
|
+
# annot_kws=dict(size=7),
|
286
|
+
# )
|
287
|
+
|
288
|
+
|
110
289
|
def catplot(data, *args, **kwargs):
|
111
290
|
"""
|
112
291
|
catplot(data, opt=None, ax=None)
|
@@ -1524,6 +1703,10 @@ def figsets(*args, **kwargs):
|
|
1524
1703
|
alignment='left')
|
1525
1704
|
)
|
1526
1705
|
"""
|
1706
|
+
import matplotlib
|
1707
|
+
|
1708
|
+
matplotlib.rc("text", usetex=False)
|
1709
|
+
|
1527
1710
|
fig = plt.gcf()
|
1528
1711
|
fontsize = 11
|
1529
1712
|
fontname = "Arial"
|
@@ -1615,6 +1798,16 @@ def figsets(*args, **kwargs):
|
|
1615
1798
|
if isinstance(value, list):
|
1616
1799
|
loc = []
|
1617
1800
|
for i in value:
|
1801
|
+
ax.tick_params(
|
1802
|
+
axis="both",
|
1803
|
+
which="both",
|
1804
|
+
bottom=False,
|
1805
|
+
top=False,
|
1806
|
+
left=False,
|
1807
|
+
right=False,
|
1808
|
+
labelbottom=False,
|
1809
|
+
labelleft=False,
|
1810
|
+
)
|
1618
1811
|
if ("l" in i.lower()) and ("a" not in i.lower()):
|
1619
1812
|
ax.yaxis.set_ticks_position("left")
|
1620
1813
|
if "r" in i.lower():
|
@@ -1624,12 +1817,38 @@ def figsets(*args, **kwargs):
|
|
1624
1817
|
if "b" in i.lower():
|
1625
1818
|
ax.xaxis.set_ticks_position("bottom")
|
1626
1819
|
if i.lower() in ["a", "both", "all", "al", ":"]:
|
1627
|
-
ax.
|
1628
|
-
|
1820
|
+
ax.tick_params(
|
1821
|
+
axis="both", # Apply to both axes
|
1822
|
+
which="both", # Apply to both major and minor ticks
|
1823
|
+
bottom=True, # Show ticks at the bottom
|
1824
|
+
top=True, # Show ticks at the top
|
1825
|
+
left=True, # Show ticks on the left
|
1826
|
+
right=True, # Show ticks on the right
|
1827
|
+
labelbottom=True, # Show labels at the bottom
|
1828
|
+
labelleft=True, # Show labels on the left
|
1829
|
+
)
|
1629
1830
|
if i.lower() in ["xnone", "xoff", "none"]:
|
1630
|
-
ax.
|
1831
|
+
ax.tick_params(
|
1832
|
+
axis="x",
|
1833
|
+
which="both",
|
1834
|
+
bottom=False,
|
1835
|
+
top=False,
|
1836
|
+
left=False,
|
1837
|
+
right=False,
|
1838
|
+
labelbottom=False,
|
1839
|
+
labelleft=False,
|
1840
|
+
)
|
1631
1841
|
if i.lower() in ["ynone", "yoff", "none"]:
|
1632
|
-
ax.
|
1842
|
+
ax.tick_params(
|
1843
|
+
axis="y",
|
1844
|
+
which="both",
|
1845
|
+
bottom=False,
|
1846
|
+
top=False,
|
1847
|
+
left=False,
|
1848
|
+
right=False,
|
1849
|
+
labelbottom=False,
|
1850
|
+
labelleft=False,
|
1851
|
+
)
|
1633
1852
|
# ticks / labels
|
1634
1853
|
elif "x" in key.lower():
|
1635
1854
|
if value is None:
|
@@ -1674,6 +1893,10 @@ def figsets(*args, **kwargs):
|
|
1674
1893
|
|
1675
1894
|
if "bo" in key in key: # box setting, and ("p" in key or "l" in key):
|
1676
1895
|
if isinstance(value, (str, list)):
|
1896
|
+
# locations = ["left", "right", "top", "bottom"]
|
1897
|
+
# for loc, spi in ax.spines.items():
|
1898
|
+
# if loc in locations:
|
1899
|
+
# spi.set_color("none") # no spine
|
1677
1900
|
locations = []
|
1678
1901
|
for i in value:
|
1679
1902
|
if "l" in i.lower() and not "t" in i.lower():
|
@@ -1689,12 +1912,12 @@ def figsets(*args, **kwargs):
|
|
1689
1912
|
locations.append(x)
|
1690
1913
|
for x in ["left", "right", "top", "bottom"]
|
1691
1914
|
]
|
1692
|
-
|
1693
|
-
|
1694
|
-
locations = []
|
1915
|
+
if "none" in value:
|
1916
|
+
locations = [] # hide all
|
1695
1917
|
# check spines
|
1696
1918
|
for loc, spi in ax.spines.items():
|
1697
1919
|
if loc in locations:
|
1920
|
+
# spi.set_color("k")
|
1698
1921
|
spi.set_position(("outward", 0))
|
1699
1922
|
else:
|
1700
1923
|
spi.set_color("none") # no spine
|
@@ -2527,3 +2750,496 @@ def thumbnail(dir_img_list: list, figsize=(10, 10), dpi=100, show=False, usage=F
|
|
2527
2750
|
plt.tight_layout()
|
2528
2751
|
if show:
|
2529
2752
|
plt.show()
|
2753
|
+
|
2754
|
+
|
2755
|
+
def get_params_from_func_usage(function_signature):
    """Extract keyword-argument names from a signature or call string.

    A name is reported when it is immediately followed by a single ``=``,
    e.g. ``"f(a=1, b=2, **kwargs)"`` gives ``["a", "b"]``. Names introduced
    by ``**`` are skipped, and the ``==`` comparison operator is not mistaken
    for a keyword assignment.

    Parameters:
        function_signature (str): Text such as ``"plot(x=1, y=2, **kwargs)"``.

    Returns:
        list[str]: Matched names in order of appearance (duplicates are kept).
    """
    # (?<!\*\*) : skip names that directly follow '**'
    # =(?!=)    : a lone '=' only, so 'a==b' does not yield 'a'
    keys_pattern = r"(?<!\*\*)\b(\w+)=(?!=)"
    return re.findall(keys_pattern, function_signature)
|
2761
|
+
|
2762
|
+
|
2763
|
+
def plot_xy(
    data: pd.DataFrame = None,
    x=None,
    y=None,
    ax=None,
    kind: str = None,  # Specify the kind of plot
    usage=False,
    # kws_figsets:dict=None,
    **kwargs,
):
    """
    Draw one or several seaborn-style plots selected by ``kind``.

    e.g., plot_xy(data=data_log, x="Component_1", y="Component_2",
                  kind="scatter", kws_scatter=dict(hue="Cluster"))

    Parameters:
        data (pd.DataFrame): DataFrame containing the data.
        x (str): Column name for the x-axis.
        y (str): Column name for the y-axis.
        ax: Matplotlib axes to draw on; the current axes are used when None.
        kind (str | list[str]): Plot kind(s). Each entry is resolved against
            the keys of the bundled ``usages_sns.json`` with ``strcmp``.
            When None, only the usage text is printed.
        usage (bool): If True, print the default settings and an example call
            instead of plotting.
        **kwargs: Keyword arguments shared by all kinds. A per-kind dict
            (``kws_scatter``, ``kws_rug``, ``kws_hist``, ...) takes precedence
            over the shared arguments for that kind. Any key containing
            ``"figset"`` is forwarded to ``figsets``. For scatter plots the
            ``hue`` column is read from these arguments.

    Returns:
        ``(g, ax)`` when a grid/figure-level object was created (jointplot,
        lmplot, catplot_sns, displot, catplot), otherwise ``ax``. Returns None
        when only usage information is printed.
    """
    # The reference tables ship inside the package (py2ls/data/*.json); resolve
    # them relative to this module so the function works on any machine.
    dir_data = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data")
    default_settings = fload(os.path.join(dir_data, "usages_sns.json"))
    sns_info = pd.DataFrame(fload(os.path.join(dir_data, "sns_info.json")))
    valid_kinds = list(default_settings.keys())

    def _describe(name):
        # Last column of the catalogue holds the description; some kinds
        # (e.g. "swarmplot", "stdshade") have no entry, so return "" for them.
        hits = sns_info[sns_info["Functions"].str.contains(name)].iloc[:, -1].tolist()
        return hits[0] if hits else ""

    if kind is None:
        usage = True
    else:
        if isinstance(kind, str):
            kind = [kind]
        kind = [strcmp(i, valid_kinds)[0] for i in kind]

    if usage:
        for k in kind or []:
            if k in valid_kinds:
                print(f"{k}:\n\t{default_settings[k]}")
                print(_describe(k))
                print()
        usage_str = """plot_xy(data=ranked_genes,
        x="log2(fold_change)",
        y="-log10(p-value)",
        palette=get_color(3, cmap="coolwarm"),
        kind=["scatter","rug"],
        kws_rug=dict(height=0.2),
        kws_scatter=dict(s=20, color=get_color(3)[2]),
        usage=0)
        """
        print(f"currently support to plot:\n{valid_kinds}\n\nusage:\n{usage_str}")
        return  # Do not plot, just print the usage

    # Pull out the figsets options (first key containing "figset").
    figset_key = next((key for key in kwargs if "figset" in key), None)
    kws_figsets = kwargs.pop(figset_key) if figset_key is not None else {}

    # Arguments shared by every kind; the per-kind "kws_*" dicts are excluded so
    # that one kind's options are never forwarded to another plotting function.
    shared = {key: val for key, val in kwargs.items() if not key.startswith("kws_")}

    def _kws(name):
        # Always hand out a copy: the caller's dicts must not be mutated.
        specific = kwargs.get(name)
        return dict(shared) if specific is None else dict(specific)

    sns_with_col = [
        "catplot",
        "histplot",
        "relplot",
        "lmplot",
        "pairplot",
        "displot",
        "kdeplot",
    ]
    # Axes-level seaborn functions with a uniform call: kind -> (kws name, uses y)
    axes_level = {
        "histplot": ("kws_hist", False),
        "kdeplot": ("kws_kde", False),
        "ecdfplot": ("kws_ecdf", False),
        "rugplot": ("kws_rug", False),
        "countplot": ("kws_count", False),
        "stripplot": ("kws_strip", True),
        "swarmplot": ("kws_swarm", True),
        "boxplot": ("kws_box", True),
        "violinplot": ("kws_violin", True),
        "boxenplot": ("kws_boxen", True),
        "pointplot": ("kws_point", True),
        "barplot": ("kws_bar", True),
        "regplot": ("kws_reg", True),
        "lineplot": ("kws_line", True),
    }

    col = kwargs.get("col", None)
    _unset = object()
    g = _unset
    for k in kind:
        # indicate 'col' features
        if col is not None and k not in sns_with_col:
            print(f"tips:\n'{k}' has no 'col' param, you could try with {sns_with_col}")

        # (1) figure-level functions: they create their own figure
        if k == "jointplot":
            g = sns.jointplot(data=data, x=x, y=y, **_kws("kws_joint"))
        elif k == "lmplot":
            g = sns.lmplot(data=data, x=x, y=y, **_kws("kws_lm"))
        elif k == "catplot_sns":
            g = sns.catplot(data=data, x=x, y=y, **_kws("kws_cat"))
        elif k == "displot":
            # displot creates a new figure and returns a FacetGrid
            g = sns.displot(data=data, x=x, **_kws("kws_dis"))

        # (2) axes-level functions: draw on `ax`
        if ax is None:
            ax = plt.gca()

        if k == "catplot":
            g = catplot(data=data, x=x, y=y, ax=ax, **_kws("kws_cat"))
        elif k == "stdshade":
            ax = stdshade(ax=ax, **_kws("kws_stdshade"))
        elif k == "scatterplot":
            kws_scatter = _kws("kws_scatter")
            hue = kws_scatter.pop("hue", None)
            if hue is not None:
                # one tab10 colour per hue level unless a palette was supplied
                kws_scatter.setdefault(
                    "palette", sns.color_palette("tab10", data[hue].nunique())
                )
            kws_scatter.setdefault("s", 10)
            kws_scatter.setdefault("alpha", 0.7)
            ax = sns.scatterplot(ax=ax, data=data, x=x, y=y, hue=hue, **kws_scatter)
        elif k == "residplot":
            kws_resid = _kws("kws_resid")
            kws_resid.setdefault("lowess", True)
            ax = sns.residplot(data=data, x=x, y=y, ax=ax, **kws_resid)
        elif k in axes_level:
            kws_name, uses_y = axes_level[k]
            plotter = getattr(sns, k)
            if uses_y:
                ax = plotter(data=data, x=x, y=y, ax=ax, **_kws(kws_name))
            else:
                ax = plotter(data=data, x=x, ax=ax, **_kws(kws_name))

        figsets(**kws_figsets)
        print(k, " ⤵ ")
        print(default_settings[k])
        print("=>\t", _describe(k))
        print()

    if g is not _unset and ax is not None:
        return g, ax
    return ax
|
2962
|
+
|
2963
|
+
|
2964
|
+
def volcano(
    data,
    x,
    y,
    gene_col=None,
    top_genes=5,
    thr_x=np.log2(1.5),
    thr_y=-np.log10(0.05),
    colors=("#e70b0b", "#0d26e3", "#b8bbbe"),
    s=20,
    fill=True,  # plot filled scatter
    facecolor="none",
    edgecolor="none",
    edgelinewidth=0.5,
    alpha=0.8,
    legend=False,
    ax=None,
    usage=False,
    kws_arrow=None,
    kws_text=None,
    **kwargs,
):
    """
    Generates a customizable scatter plot (e.g., volcano plot).

    Note: a ``"color"`` column is written into ``data`` in place.

    Parameters:
    -----------
    data : pd.DataFrame
        The DataFrame containing the data to plot.
    x : str
        Column name for x-axis values (e.g., log2FoldChange).
    y : str
        Column name for y-axis values (e.g., -log10(FDR)).
    gene_col : str, optional
        Column name for gene names. If provided, the top points are labelled.
        Default is None.
    top_genes : int, optional
        Number of points to label per colour group (largest y first). Default is 5.
    thr_x : float, optional
        Threshold for x-axis values, applied symmetrically (x > thr_x,
        x < -thr_x). Default is np.log2(1.5).
    thr_y : float, optional
        Threshold for y-axis values (e.g., significance threshold).
        Default is -np.log10(0.05).
    colors : tuple, optional
        Three colours, in this order: points above both thresholds
        (x > thr_x), points below -thr_x, and all remaining (neutral) points.
        Default is (red, blue, gray).
    s : int, optional
        Size of points in the plot. Default is 20.
    fill : bool, optional
        If True the markers are filled with the group colour and outlined with
        ``edgecolor``; otherwise the face is ``facecolor`` and the outline
        carries the group colour. Default is True.
    facecolor, edgecolor : optional
        Marker face / edge colour used as described under ``fill``.
    edgelinewidth : float, optional
        Marker edge line width. Default is 0.5.
    alpha : float, optional
        Transparency of the points. Default is 0.8.
    legend : bool, optional
        Whether to show a legend. Default is False.
    ax : matplotlib Axes, optional
        Axes to draw on; the current axes are used when None.
    usage : bool, optional
        If True, print an example call and return without plotting.
    kws_arrow : dict, optional
        Arrow options for the label connectors: ``style``, ``color``, ``lw``,
        ``shrinkA``, ``shrinkB``; other keys are forwarded to ``arrowprops``.
    kws_text : dict, optional
        Label options: ``fontname``, ``color`` (a list means "use the point's
        group colour"), ``fontsize``; other keys are forwarded to ``ax.text``.
    **kwargs
        Forwarded to ``sns.scatterplot``. A key containing ``"figset"`` is
        forwarded to ``figsets`` instead.
    """
    usage_str = """
    _, axs = plt.subplots(1, 1, figsize=(4, 5))
    volcano(
        ax=axs,
        data=ranked_genes,
        x="log2(fold_change)",
        y="-log10(p-value)",
        gene_col="ID_REF",
        top_genes=6,
        thr_x=np.log2(1.2),
        # thr_y=-np.log10(0.05),
        colors=("#FF3030", "#00BFFF", "#9d9a9a"),
        fill=0,
        alpha=1,
        facecolor="none",
        s=20,
        edgelinewidth=0.5,
        edgecolor="0.5",
        kws_text=dict(fontsize=10, color="k"),
        kws_arrow=dict(style="-", color="k", lw=0.5),
        # usage=True,
        figsets=dict(ylim=[0, 10], title="df"),
    )
    """
    if usage:
        print(usage_str)
        return

    # Pull out the figsets options (first key containing "figset").
    figset_key = next((key for key in kwargs if "figset" in key), None)
    kws_figsets = kwargs.pop(figset_key) if figset_key is not None else {}

    # Colour-code each point: colors[0] = up, colors[1] = down, colors[2] = neutral.
    # This ordering matches the default tuple, the neutral filter and the
    # label-colour rule below.
    above_y = data[y] > thr_y
    data["color"] = np.where(
        (data[x] > thr_x) & above_y,
        colors[0],
        np.where((data[x] < -thr_x) & above_y, colors[1], colors[2]),
    )

    # Top `top_genes` non-neutral points per colour group, largest y first.
    sele_gene = (
        data[data["color"] != colors[2]]
        .sort_values(y, ascending=False)
        .groupby("color", sort=False)
        .head(top_genes)
    )

    # Plot setup
    if ax is None:
        ax = plt.gca()

    # Handle fill parameter
    if fill:
        facecolors = data["color"]  # Fill with the group colours
        edgecolors = edgecolor
    else:
        facecolors = facecolor  # Hollow markers: the edge carries the group colour
        edgecolors = data["color"]

    ax = sns.scatterplot(
        ax=ax,
        data=data,
        x=x,
        y=y,
        s=s,
        linewidths=edgelinewidth,
        color=facecolors,
        edgecolor=edgecolors,
        alpha=alpha,
        legend=legend,
        **kwargs,
    )

    # Threshold lines go on the same axes as the scatter, not on whatever
    # happens to be the current axes.
    ax.axhline(y=thr_y, color="black", linestyle="--")
    ax.axvline(x=-thr_x, color="black", linestyle="--")
    ax.axvline(x=thr_x, color="black", linestyle="--")

    # Add gene labels for selected significant points
    if gene_col:
        from adjustText import adjust_text  # only needed when labelling

        # Work on copies so the caller's dicts are left untouched.
        text_opts = dict(kws_text or {})
        fontname = text_opts.pop("fontname", "Arial")
        textcolor = text_opts.pop("color", "k")
        fontsize = text_opts.pop("fontsize", 10)

        texts = []
        for x_val, y_val, label in zip(
            sele_gene[x], sele_gene[y], sele_gene[gene_col]
        ):
            if isinstance(textcolor, list):  # be consistent with the dot's colour
                label_color = colors[0] if x_val > 0 else colors[1]
            else:
                label_color = textcolor
            texts.append(
                ax.text(
                    x=x_val,
                    y=y_val,
                    s=label,
                    fontdict={
                        "fontsize": fontsize,
                        "color": label_color,
                        "fontname": fontname,
                    },
                    **text_opts,
                )
            )

        arrowstyles = [
            "-",
            "->",
            "-[",
            "|->",
            "<-",
            "<->",
            "<|-",
            "<|-|>",
            "-|>",
            "-[ ",
            "fancy",
            "simple",
            "wedge",
        ]
        arrow_opts = dict(kws_arrow or {})
        arrowstyle = arrow_opts.pop("style", "-")
        arrowcolor = arrow_opts.pop("color", "0.5")
        arrowlinewidth = arrow_opts.pop("lw", 0.5)
        shrinkA = arrow_opts.pop("shrinkA", 5)
        shrinkB = arrow_opts.pop("shrinkB", 5)
        arrowstyle = strcmp(arrowstyle, arrowstyles)[0]
        adjust_text(
            texts,
            ax=ax,
            expand_text=(1.05, 1.2),
            arrowprops=dict(
                arrowstyle=arrowstyle,
                color=arrowcolor,
                lw=arrowlinewidth,
                shrinkA=shrinkA,
                shrinkB=shrinkB,
                **arrow_opts,
            ),
        )

    # NOTE(review): figsets presumably acts on the current axes/figure — confirm
    # it targets `ax` when a non-current axes is passed in.
    figsets(**kws_figsets)
|
3159
|
+
|
3160
|
+
|
3161
|
+
def sns_func_info(dir_save=None):
    """Write the seaborn function catalogue to ``sns_info.json``.

    The catalogue has three parallel lists: ``Functions`` (seaborn function
    name), ``Category`` and ``Detail`` (one-sentence description).

    Parameters:
        dir_save (str, optional): Directory to write ``sns_info.json`` into.
            Defaults to the ``data`` directory next to this module.
    """
    sns_info = {
        "Functions": [
            "relplot",
            "scatterplot",
            "lineplot",
            "lmplot",
            "catplot",
            "stripplot",
            "boxplot",
            "violinplot",
            "boxenplot",
            "pointplot",
            "barplot",
            "countplot",
            "displot",
            "histplot",
            "kdeplot",
            "ecdfplot",
            "rugplot",
            "regplot",
            "residplot",
            "pairplot",
            "jointplot",
            "plotting_context",
        ],
        "Category": [
            "relational",
            "relational",
            "relational",
            "relational",
            "categorical",
            "categorical",
            "categorical",
            "categorical",
            "categorical",
            "categorical",
            "categorical",
            "categorical",
            "distribution",
            "distribution",
            "distribution",
            "distribution",
            "distribution",
            "regression",
            "regression",
            "grid-based(fig)",
            "grid-based(fig)",
            "context",
        ],
        "Detail": [
            "A figure-level function for creating scatter plots and line plots. It combines the functionality of scatterplot and lineplot.",
            "A function for creating scatter plots, useful for visualizing the relationship between two continuous variables.",
            "A function for drawing line plots, often used to visualize trends over time or ordered categories.",
            "A figure-level function for creating linear model plots, combining regression lines with scatter plots.",
            "A figure-level function for creating categorical plots, which can display various types of plots like box plots, violin plots, and bar plots in one function.",
            "A function for creating a scatter plot where one of the variables is categorical, helping visualize distribution along a categorical axis.",
            "A function for creating box plots, which summarize the distribution of a continuous variable based on a categorical variable.",
            "A function for creating violin plots, which combine box plots and KDEs to visualize the distribution of data.",
            "A function for creating boxen plots, an enhanced version of box plots that better represent data distributions with more quantiles.",
            "A function for creating point plots, which show the mean (or another estimator) of a variable for each level of a categorical variable.",
            "A function for creating bar plots, which represent the mean (or other estimators) of a variable with bars, typically used with categorical data.",
            "A function for creating count plots, which show the counts of observations in each categorical bin.",
            "A figure-level function that creates distribution plots. It can visualize histograms, KDEs, and ECDFs, making it versatile for analyzing the distribution of data.",
            "A function for creating histograms, useful for showing the frequency distribution of a continuous variable.",
            "A function for creating kernel density estimate (KDE) plots, which visualize the probability density function of a continuous variable.",
            "A function for creating empirical cumulative distribution function (ECDF) plots, which show the proportion of observations below a certain value.",
            "A function that adds a rug plot to the axes, representing individual data points along an axis.",
            "A function for creating regression plots, which fit and visualize a regression model on scatter data.",
            "A function for creating residual plots, useful for diagnosing the fit of a regression model.",
            "A figure-level function that creates a grid of scatter plots for each pair of variables in a dataset, often used for exploratory data analysis.",
            "A figure-level function that combines scatter plots and histograms (or KDEs) to visualize the relationship between two variables and their distributions.",
            "Not a plot itself, but a function that allows you to change the context (style and scaling) of your plots to fit different publication requirements or visual preferences.",
        ],
    }
    if dir_save is None:
        # Default to the package's own data directory instead of a
        # machine-specific absolute path.
        dir_save = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data")
    # os.path.join handles a missing or present trailing separator on any OS.
    fsave(os.path.join(dir_save, "sns_info.json"), sns_info)
|