py2ls 0.2.1__py3-none-any.whl → 0.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- py2ls/data/sns_info.json +74 -0
- py2ls/data/usages_sns.json +25 -0
- py2ls/ips.py +1204 -505
- py2ls/plot.py +808 -30
- py2ls/stats.py +18 -9
- {py2ls-0.2.1.dist-info → py2ls-0.2.2.dist-info}/METADATA +1 -1
- {py2ls-0.2.1.dist-info → py2ls-0.2.2.dist-info}/RECORD +8 -6
- {py2ls-0.2.1.dist-info → py2ls-0.2.2.dist-info}/WHEEL +0 -0
py2ls/plot.py
CHANGED
@@ -9,24 +9,97 @@ import matplotlib.ticker as tck
|
|
9
9
|
from cycler import cycler
|
10
10
|
import logging
|
11
11
|
import os
|
12
|
+
import re
|
12
13
|
|
13
|
-
from .ips import fsave, fload, mkdir, listdir, figsave
|
14
|
+
from .ips import fsave, fload, mkdir, listdir, figsave, strcmp, unique, get_os, ssplit
|
14
15
|
from .stats import *
|
16
|
+
from .netfinder import get_soup, fetch
|
17
|
+
|
15
18
|
|
16
19
|
# Suppress INFO messages from fontTools
|
17
20
|
logging.getLogger("fontTools").setLevel(logging.WARNING)
|
18
21
|
|
19
22
|
|
20
|
-
def
|
21
|
-
|
22
|
-
|
23
|
+
def update_sns_usages(
    url="https://seaborn.pydata.org/generated/seaborn.swarmplot.html",
    dir_save=None,
):
    """
    Scrape the signature ("usage") of seaborn plotting functions from the
    seaborn API documentation and save them as ``usages_sns.json``.

    All internal reference links on the start page are collected. Links to
    ``seaborn.JointGrid``, ``seaborn.PairGrid`` and ``seaborn.objects`` are
    discarded, as is any link whose name does not contain "plot". Each
    remaining page is fetched and its first ``<dt>`` element is recorded.

    Parameters:
    - url (str): seaborn documentation page used as the starting point for
      link discovery (default: the swarmplot page).
    - dir_save (str, optional): directory in which ``usages_sns.json`` is
      written. Defaults to the ``data`` directory next to this module.

    Saves:
    - ``usages_sns.json``: mapping of function name (``seaborn.`` prefix
      removed) to its usage string.

    Returns:
    - None
    """

    def get_usage(page_url):
        # The first <dt> element of a seaborn API page holds the signature.
        sp_page = get_soup(page_url, driver="se")
        return fetch(sp_page, where="dt")[0]

    if dir_save is None:
        # Formerly hard-coded to the author's personal folders; the package's
        # own data directory works on every machine.
        dir_save = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data")

    sp = get_soup(url, driver="se")
    links_all = fetch(sp, where="a", get="href", class_="reference internal")
    excluded_prefixes = ("seaborn.JointGrid", "seaborn.PairGrid", "seaborn.objects")
    filtered_links = unique(
        [
            link
            for link in links_all
            if not link.startswith(excluded_prefixes) and "plot" in link
        ]
    )
    links = ["https://seaborn.pydata.org/generated/" + i for i in filtered_links]

    usages = [get_usage(link) for link in links]
    dict_usage = {}
    for usage in usages:
        func_name = ssplit(usage, by="(")[0].replace("seaborn.", "")
        # The last character of the <dt> text is dropped (presumably the
        # trailing permalink symbol - TODO confirm).
        dict_usage[func_name] = usage[:-1]

    # save to local; os.path.join avoids mixing "\\" and "/" separators
    fsave(os.path.join(dir_save, "usages_sns.json"), dict_usage)
|
86
|
+
|
87
|
+
|
88
|
+
def heatmap(
|
89
|
+
data,
|
90
|
+
ax=None,
|
91
|
+
kind="corr", #'corr','direct','pivot'
|
92
|
+
columns="all", # pivot, default: coll numeric columns
|
93
|
+
index=None, # pivot
|
94
|
+
values=None, # pivot
|
23
95
|
tri="u",
|
24
96
|
mask=True,
|
25
97
|
k=1,
|
26
98
|
annot=True,
|
27
99
|
cmap="coolwarm",
|
28
100
|
fmt=".2f",
|
29
|
-
cluster=False,
|
101
|
+
cluster=False,
|
102
|
+
inplace=False,
|
30
103
|
figsize=(10, 8),
|
31
104
|
row_cluster=True, # Perform clustering on rows
|
32
105
|
col_cluster=True, # Perform clustering on columns
|
@@ -36,24 +109,115 @@ def df_corr(
|
|
36
109
|
yticklabels=True, # Show row labels
|
37
110
|
**kwargs,
|
38
111
|
):
|
112
|
+
if ax is None and not cluster:
|
113
|
+
ax = plt.gca()
|
39
114
|
# Select numeric columns or specific subset of columns
|
40
115
|
if columns == "all":
|
41
|
-
df_numeric =
|
116
|
+
df_numeric = data.select_dtypes(include=[float, int])
|
42
117
|
else:
|
43
|
-
df_numeric =
|
44
|
-
|
45
|
-
|
46
|
-
|
118
|
+
df_numeric = data[columns]
|
119
|
+
|
120
|
+
kinds = ["corr", "direct", "pivot"]
|
121
|
+
kind = strcmp(kind, kinds)[0]
|
122
|
+
if kind == "corr":
|
123
|
+
# Compute the correlation matrix
|
124
|
+
data4heatmap = df_numeric.corr()
|
125
|
+
# Generate mask for the upper triangle if mask is True
|
126
|
+
if mask:
|
127
|
+
if "u" in tri.lower(): # upper => np.tril
|
128
|
+
mask_array = np.tril(np.ones_like(data4heatmap, dtype=bool), k=k)
|
129
|
+
else: # lower => np.triu
|
130
|
+
mask_array = np.triu(np.ones_like(data4heatmap, dtype=bool), k=k)
|
131
|
+
else:
|
132
|
+
mask_array = None
|
133
|
+
|
134
|
+
# Remove conflicting kwargs
|
135
|
+
kwargs.pop("mask", None)
|
136
|
+
kwargs.pop("annot", None)
|
137
|
+
kwargs.pop("cmap", None)
|
138
|
+
kwargs.pop("fmt", None)
|
139
|
+
|
140
|
+
kwargs.pop("clustermap", None)
|
141
|
+
kwargs.pop("row_cluster", None)
|
142
|
+
kwargs.pop("col_cluster", None)
|
143
|
+
kwargs.pop("dendrogram_ratio", None)
|
144
|
+
kwargs.pop("cbar_pos", None)
|
145
|
+
kwargs.pop("xticklabels", None)
|
146
|
+
kwargs.pop("col_cluster", None)
|
147
|
+
|
148
|
+
# Plot the heatmap or clustermap
|
149
|
+
if cluster:
|
150
|
+
# Create a clustermap
|
151
|
+
cluster_obj = sns.clustermap(
|
152
|
+
data4heatmap,
|
153
|
+
# ax=ax,
|
154
|
+
mask=mask_array,
|
155
|
+
annot=annot,
|
156
|
+
cmap=cmap,
|
157
|
+
fmt=fmt,
|
158
|
+
figsize=figsize, # Figure size, adjusted for professional display
|
159
|
+
row_cluster=row_cluster, # Perform clustering on rows
|
160
|
+
col_cluster=col_cluster, # Perform clustering on columns
|
161
|
+
dendrogram_ratio=dendrogram_ratio, # Adjust size of dendrograms
|
162
|
+
cbar_pos=cbar_pos, # Adjust colorbar position
|
163
|
+
xticklabels=xticklabels, # Show column labels
|
164
|
+
yticklabels=yticklabels, # Show row labels
|
165
|
+
**kwargs, # Pass any additional arguments to sns.clustermap
|
166
|
+
)
|
167
|
+
df_row_cluster = pd.DataFrame()
|
168
|
+
df_col_cluster = pd.DataFrame()
|
169
|
+
if row_cluster:
|
170
|
+
from scipy.cluster.hierarchy import linkage, fcluster
|
171
|
+
from scipy.spatial.distance import pdist
|
172
|
+
|
173
|
+
# Compute pairwise distances
|
174
|
+
distances = pdist(data, metric="euclidean")
|
175
|
+
# Perform hierarchical clustering
|
176
|
+
linkage_matrix = linkage(distances, method="average")
|
177
|
+
# Get cluster assignments based on the distance threshold
|
178
|
+
row_clusters_value = fcluster(
|
179
|
+
linkage_matrix, t=1.5, criterion="distance"
|
180
|
+
)
|
181
|
+
df_row_cluster["row_cluster"] = row_clusters_value
|
182
|
+
if col_cluster:
|
183
|
+
col_distances = pdist(
|
184
|
+
data4heatmap.T, metric="euclidean"
|
185
|
+
) # Transpose for column clustering
|
186
|
+
col_linkage_matrix = linkage(col_distances, method="average")
|
187
|
+
col_clusters_value = fcluster(
|
188
|
+
col_linkage_matrix, t=1.5, criterion="distance"
|
189
|
+
)
|
190
|
+
df_col_cluster = pd.DataFrame(
|
191
|
+
{"Cluster": col_clusters_value}, index=data4heatmap.columns
|
192
|
+
)
|
47
193
|
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
194
|
+
return (
|
195
|
+
cluster_obj.ax_row_dendrogram,
|
196
|
+
cluster_obj.ax_col_dendrogram,
|
197
|
+
cluster_obj.ax_heatmap,
|
198
|
+
df_row_cluster,
|
199
|
+
df_col_cluster,
|
200
|
+
)
|
201
|
+
else:
|
202
|
+
# Create a standard heatmap
|
203
|
+
ax = sns.heatmap(
|
204
|
+
data4heatmap,
|
205
|
+
ax=ax,
|
206
|
+
mask=mask_array,
|
207
|
+
annot=annot,
|
208
|
+
cmap=cmap,
|
209
|
+
fmt=fmt,
|
210
|
+
**kwargs, # Pass any additional arguments to sns.heatmap
|
211
|
+
)
|
212
|
+
# Return the Axes object for further customization if needed
|
213
|
+
return ax
|
214
|
+
elif kind == "direct":
|
215
|
+
data4heatmap = df_numeric
|
216
|
+
elif kind == "pivot":
|
217
|
+
print('need 3 param: e.g., index="Task", columns="Model", values="Score"')
|
218
|
+
data4heatmap = data.pivot(index=index, columns=columns, values=values)
|
54
219
|
else:
|
55
|
-
|
56
|
-
|
220
|
+
print(f'"{kind}" is not supported')
|
57
221
|
# Remove conflicting kwargs
|
58
222
|
kwargs.pop("mask", None)
|
59
223
|
kwargs.pop("annot", None)
|
@@ -72,8 +236,9 @@ def df_corr(
|
|
72
236
|
if cluster:
|
73
237
|
# Create a clustermap
|
74
238
|
cluster_obj = sns.clustermap(
|
75
|
-
|
76
|
-
|
239
|
+
data4heatmap,
|
240
|
+
# ax=ax,
|
241
|
+
# mask=mask_array,
|
77
242
|
annot=annot,
|
78
243
|
cmap=cmap,
|
79
244
|
fmt=fmt,
|
@@ -86,18 +251,43 @@ def df_corr(
|
|
86
251
|
yticklabels=yticklabels, # Show row labels
|
87
252
|
**kwargs, # Pass any additional arguments to sns.clustermap
|
88
253
|
)
|
254
|
+
df_row_cluster = pd.DataFrame()
|
255
|
+
df_col_cluster = pd.DataFrame()
|
256
|
+
if row_cluster:
|
257
|
+
from scipy.cluster.hierarchy import linkage, fcluster
|
258
|
+
from scipy.spatial.distance import pdist
|
259
|
+
|
260
|
+
# Compute pairwise distances
|
261
|
+
distances = pdist(data, metric="euclidean")
|
262
|
+
# Perform hierarchical clustering
|
263
|
+
linkage_matrix = linkage(distances, method="average")
|
264
|
+
# Get cluster assignments based on the distance threshold
|
265
|
+
row_clusters_value = fcluster(linkage_matrix, t=1.5, criterion="distance")
|
266
|
+
df_row_cluster["row_cluster"] = row_clusters_value
|
267
|
+
if col_cluster:
|
268
|
+
col_distances = pdist(
|
269
|
+
data4heatmap.T, metric="euclidean"
|
270
|
+
) # Transpose for column clustering
|
271
|
+
col_linkage_matrix = linkage(col_distances, method="average")
|
272
|
+
col_clusters_value = fcluster(
|
273
|
+
col_linkage_matrix, t=1.5, criterion="distance"
|
274
|
+
)
|
275
|
+
df_col_cluster = pd.DataFrame(
|
276
|
+
{"Cluster": col_clusters_value}, index=data4heatmap.columns
|
277
|
+
)
|
89
278
|
|
90
279
|
return (
|
91
280
|
cluster_obj.ax_row_dendrogram,
|
92
281
|
cluster_obj.ax_col_dendrogram,
|
93
282
|
cluster_obj.ax_heatmap,
|
283
|
+
df_row_cluster,
|
284
|
+
df_col_cluster,
|
94
285
|
)
|
95
286
|
else:
|
96
287
|
# Create a standard heatmap
|
97
|
-
plt.figure(figsize=figsize)
|
98
288
|
ax = sns.heatmap(
|
99
|
-
|
100
|
-
|
289
|
+
data4heatmap,
|
290
|
+
ax=ax,
|
101
291
|
annot=annot,
|
102
292
|
cmap=cmap,
|
103
293
|
fmt=fmt,
|
@@ -107,6 +297,60 @@ def df_corr(
|
|
107
297
|
return ax
|
108
298
|
|
109
299
|
|
300
|
+
# !usage: py2ls.plot.heatmap()
|
301
|
+
# penguins_clean = penguins.replace([np.inf, -np.inf], np.nan).dropna()
|
302
|
+
# from py2ls import plot
|
303
|
+
|
304
|
+
# _, axs = plt.subplots(2, 2, figsize=(10, 10))
|
305
|
+
# # kind='pivot'
|
306
|
+
# plot.heatmap(
|
307
|
+
# ax=axs[0][0],
|
308
|
+
# data=sns.load_dataset("glue"),
|
309
|
+
# kind="pi",
|
310
|
+
# index="Model",
|
311
|
+
# columns="Task",
|
312
|
+
# values="Score",
|
313
|
+
# fmt=".1f",
|
314
|
+
# cbar_kws=dict(shrink=1),
|
315
|
+
# annot_kws=dict(size=7),
|
316
|
+
# )
|
317
|
+
# # kind='direct'
|
318
|
+
# plot.heatmap(
|
319
|
+
# ax=axs[0][1],
|
320
|
+
# data=sns.load_dataset("penguins").iloc[:10, 2:6],
|
321
|
+
# kind="direct",
|
322
|
+
# tri="lower",
|
323
|
+
# fmt=".1f",
|
324
|
+
# k=1,
|
325
|
+
# cbar_kws=dict(shrink=1),
|
326
|
+
# annot_kws=dict(size=7),
|
327
|
+
# )
|
328
|
+
|
329
|
+
# # kind='corr'
|
330
|
+
# plot.heatmap(
|
331
|
+
# ax=axs[1][0],
|
332
|
+
# data=sns.load_dataset("penguins"),
|
333
|
+
# kind="corr",
|
334
|
+
# fmt=".1f",
|
335
|
+
# k=-1,
|
336
|
+
# cbar_kws=dict(shrink=1),
|
337
|
+
# annot_kws=dict(size=7),
|
338
|
+
# )
|
339
|
+
# # kind='corr'
|
340
|
+
# plot.heatmap(
|
341
|
+
# ax=axs[1][1],
|
342
|
+
# data=penguins_clean.iloc[:15, :10],
|
343
|
+
# kind="direct",
|
344
|
+
# tri="lower",
|
345
|
+
# fmt=".1f",
|
346
|
+
# k=1,
|
347
|
+
# annot=False,
|
348
|
+
# cluster=True,
|
349
|
+
# cbar_kws=dict(shrink=1),
|
350
|
+
# annot_kws=dict(size=7),
|
351
|
+
# )
|
352
|
+
|
353
|
+
|
110
354
|
def catplot(data, *args, **kwargs):
|
111
355
|
"""
|
112
356
|
catplot(data, opt=None, ax=None)
|
@@ -1524,6 +1768,10 @@ def figsets(*args, **kwargs):
|
|
1524
1768
|
alignment='left')
|
1525
1769
|
)
|
1526
1770
|
"""
|
1771
|
+
import matplotlib
|
1772
|
+
|
1773
|
+
matplotlib.rc("text", usetex=False)
|
1774
|
+
|
1527
1775
|
fig = plt.gcf()
|
1528
1776
|
fontsize = 11
|
1529
1777
|
fontname = "Arial"
|
@@ -1615,6 +1863,16 @@ def figsets(*args, **kwargs):
|
|
1615
1863
|
if isinstance(value, list):
|
1616
1864
|
loc = []
|
1617
1865
|
for i in value:
|
1866
|
+
ax.tick_params(
|
1867
|
+
axis="both",
|
1868
|
+
which="both",
|
1869
|
+
bottom=False,
|
1870
|
+
top=False,
|
1871
|
+
left=False,
|
1872
|
+
right=False,
|
1873
|
+
labelbottom=False,
|
1874
|
+
labelleft=False,
|
1875
|
+
)
|
1618
1876
|
if ("l" in i.lower()) and ("a" not in i.lower()):
|
1619
1877
|
ax.yaxis.set_ticks_position("left")
|
1620
1878
|
if "r" in i.lower():
|
@@ -1624,12 +1882,38 @@ def figsets(*args, **kwargs):
|
|
1624
1882
|
if "b" in i.lower():
|
1625
1883
|
ax.xaxis.set_ticks_position("bottom")
|
1626
1884
|
if i.lower() in ["a", "both", "all", "al", ":"]:
|
1627
|
-
ax.
|
1628
|
-
|
1885
|
+
ax.tick_params(
|
1886
|
+
axis="both", # Apply to both axes
|
1887
|
+
which="both", # Apply to both major and minor ticks
|
1888
|
+
bottom=True, # Show ticks at the bottom
|
1889
|
+
top=True, # Show ticks at the top
|
1890
|
+
left=True, # Show ticks on the left
|
1891
|
+
right=True, # Show ticks on the right
|
1892
|
+
labelbottom=True, # Show labels at the bottom
|
1893
|
+
labelleft=True, # Show labels on the left
|
1894
|
+
)
|
1629
1895
|
if i.lower() in ["xnone", "xoff", "none"]:
|
1630
|
-
ax.
|
1896
|
+
ax.tick_params(
|
1897
|
+
axis="x",
|
1898
|
+
which="both",
|
1899
|
+
bottom=False,
|
1900
|
+
top=False,
|
1901
|
+
left=False,
|
1902
|
+
right=False,
|
1903
|
+
labelbottom=False,
|
1904
|
+
labelleft=False,
|
1905
|
+
)
|
1631
1906
|
if i.lower() in ["ynone", "yoff", "none"]:
|
1632
|
-
ax.
|
1907
|
+
ax.tick_params(
|
1908
|
+
axis="y",
|
1909
|
+
which="both",
|
1910
|
+
bottom=False,
|
1911
|
+
top=False,
|
1912
|
+
left=False,
|
1913
|
+
right=False,
|
1914
|
+
labelbottom=False,
|
1915
|
+
labelleft=False,
|
1916
|
+
)
|
1633
1917
|
# ticks / labels
|
1634
1918
|
elif "x" in key.lower():
|
1635
1919
|
if value is None:
|
@@ -1674,6 +1958,10 @@ def figsets(*args, **kwargs):
|
|
1674
1958
|
|
1675
1959
|
if "bo" in key in key: # box setting, and ("p" in key or "l" in key):
|
1676
1960
|
if isinstance(value, (str, list)):
|
1961
|
+
# locations = ["left", "right", "top", "bottom"]
|
1962
|
+
# for loc, spi in ax.spines.items():
|
1963
|
+
# if loc in locations:
|
1964
|
+
# spi.set_color("none") # no spine
|
1677
1965
|
locations = []
|
1678
1966
|
for i in value:
|
1679
1967
|
if "l" in i.lower() and not "t" in i.lower():
|
@@ -1689,12 +1977,12 @@ def figsets(*args, **kwargs):
|
|
1689
1977
|
locations.append(x)
|
1690
1978
|
for x in ["left", "right", "top", "bottom"]
|
1691
1979
|
]
|
1692
|
-
|
1693
|
-
|
1694
|
-
locations = []
|
1980
|
+
if "none" in value:
|
1981
|
+
locations = [] # hide all
|
1695
1982
|
# check spines
|
1696
1983
|
for loc, spi in ax.spines.items():
|
1697
1984
|
if loc in locations:
|
1985
|
+
# spi.set_color("k")
|
1698
1986
|
spi.set_position(("outward", 0))
|
1699
1987
|
else:
|
1700
1988
|
spi.set_color("none") # no spine
|
@@ -2527,3 +2815,493 @@ def thumbnail(dir_img_list: list, figsize=(10, 10), dpi=100, show=False, usage=F
|
|
2527
2815
|
plt.tight_layout()
|
2528
2816
|
if show:
|
2529
2817
|
plt.show()
|
2818
|
+
|
2819
|
+
|
2820
|
+
def get_params_from_func_usage(function_signature):
    """
    Extract keyword-parameter names from a function signature string.

    Every ``name=`` occurrence is returned in order of appearance, e.g.
    ``"scatterplot(data=None, *, x=None, y=None, **kwargs)"`` gives
    ``["data", "x", "y"]``. Bare ``*`` and ``**kwargs`` entries carry no
    ``=`` and are therefore skipped; equality comparisons such as ``a==b``
    are not mistaken for parameters.

    Parameters:
    - function_signature (str): signature text, e.g. a value from usages_sns.json.

    Returns:
    - list[str]: parameter names, duplicates kept. Note that a ``name=``
      inside a default value (e.g. ``kws=dict(s=1)``) is matched as well.
    """
    # \b(\w+)= : a word immediately followed by "="; (?!=) rejects "==".
    keys_pattern = r"(?<!\*\*)\b(\w+)=(?!=)"
    return re.findall(keys_pattern, function_signature)
|
2826
|
+
|
2827
|
+
|
2828
|
+
def plot_xy(
    data: pd.DataFrame = None,
    x=None,
    y=None,
    ax=None,
    kind: str = None,  # Specify the kind of plot
    usage=False,
    # kws_figsets:dict=None,
    **kwargs,
):
    """
    Create one or more seaborn plots of ``data``, selected by ``kind``.

    e.g., plot_xy(data=data_log, x="Component_1", y="Component_2", hue="Cluster", kind="scatter")

    Parameters:
    data (pd.DataFrame): DataFrame containing the data.
    x (str): Column name for the x-axis.
    y (str): Column name for the y-axis.
    ax: Matplotlib axes object for the plot (``plt.gca()`` when None).
    kind (str or list of str): plot type(s). Each entry is matched with
        ``strcmp`` against the function names stored in usages_sns.json
        (e.g. 'scatterplot', 'lineplot', 'displot', 'kdeplot'). When None,
        the usage help is printed instead of plotting.
    usage (bool): If True, print the stored signatures instead of plotting.
    **kwargs:
        hue (str): Column name for the hue (color) grouping.
        kws_<name> (dict): keyword arguments for one specific plot, e.g.
            ``kws_scatter`` or ``kws_rug``. When a plot's ``kws_<name>`` is
            absent, the remaining kwargs are passed to it instead.
        any key containing "figset" (dict): forwarded to ``figsets``.

    Returns:
    ax, or (g, ax) when a figure-level function (jointplot, lmplot, displot)
    or catplot produced ``g``; None when only usage is printed.
    """
    # Reference tables ship in the package's data/ directory (they were read
    # from absolute paths on the author's machine before).
    dir_data = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data")
    default_settings = fload(os.path.join(dir_data, "usages_sns.json"))
    sns_info = pd.DataFrame(fload(os.path.join(dir_data, "sns_info.json")))
    valid_kinds = list(default_settings.keys())

    def _describe(func_name):
        # Description (last column of sns_info) for func_name; "" instead of an
        # IndexError when the function is not catalogued in sns_info.json.
        matched = (
            sns_info[sns_info["Functions"].str.contains(func_name, regex=False)]
            .iloc[:, -1]
            .tolist()
        )
        return matched[0] if matched else ""

    if kind is not None:
        if isinstance(kind, str):
            kind = [kind]
        kind = [strcmp(i, valid_kinds)[0] for i in kind]
    else:
        usage = True
    if usage:
        if kind is not None:
            for k in kind:
                if k in valid_kinds:
                    print(f"{k}:\n\t{default_settings[k]}")
                    print(_describe(k))
                    print()
        usage_str = """plot_xy(data=ranked_genes,
        x="log2(fold_change)",
        y="-log10(p-value)",
        palette=get_color(3, cmap="coolwarm"),
        kind=["scatter","rug"],
        kws_rug=dict(height=0.2),
        kws_scatter=dict(s=20, color=get_color(3)[2]),
        usage=0)
        """
        print(f"currently support to plot:\n{valid_kinds}\n\nusage:\n{usage_str}")
        return  # Do not plot, just print the usage

    # figsets options: the first kwarg whose name contains "figset"
    kws_figsets = {}
    figset_key = next((key for key in kwargs if "figset" in key), None)
    if figset_key is not None:
        kws_figsets = kwargs.pop(figset_key)

    # 'hue' was documented but never bound, so every plot raised NameError.
    # Pop it so it is not also forwarded a second time through **kwargs.
    hue = kwargs.pop("hue", None)

    sns_with_col = [
        "catplot",
        "histplot",
        "relplot",
        "lmplot",
        "pairplot",
        "displot",
        "kdeplot",
    ]
    g = None  # grid object returned by figure-level functions / catplot
    for k in kind:
        # indicate 'col' features
        col = kwargs.get("col", None)
        if col is not None and k not in sns_with_col:
            print(
                f"tips:\n'{k}' has no 'col' param, you could try with {sns_with_col}"
            )
        # (1) figure-level functions return a FacetGrid / JointGrid
        if k == "jointplot":
            kws_joint = kwargs.pop("kws_joint", kwargs)
            g = sns.jointplot(data=data, x=x, y=y, hue=hue, **kws_joint)
        elif k == "lmplot":
            kws_lm = kwargs.pop("kws_lm", kwargs)
            g = sns.lmplot(data=data, x=x, y=y, hue=hue, **kws_lm)
        elif k == "catplot_sns":
            kws_cat = kwargs.pop("kws_cat", kwargs)
            g = sns.catplot(data=data, x=x, y=y, hue=hue, **kws_cat)
        elif k == "displot":
            kws_dis = kwargs.pop("kws_dis", kwargs)
            # displot creates a new figure and returns a FacetGrid
            g = sns.displot(data=data, x=x, hue=hue, **kws_dis)

        # (2) axes-level functions draw on `ax`
        if ax is None:
            ax = plt.gca()

        if k == "catplot":
            kws_cat = kwargs.pop("kws_cat", kwargs)
            g = catplot(data=data, x=x, y=y, ax=ax, **kws_cat)
        elif k == "scatterplot":
            kws_scatter = kwargs.pop("kws_scatter", kwargs)
            palette = kws_scatter.pop(
                "palette",
                (
                    sns.color_palette("tab10", data[hue].nunique())
                    if hue is not None
                    else sns.color_palette("tab10")
                ),
            )
            s = kws_scatter.pop("s", 10)
            alpha = kws_scatter.pop("alpha", 0.7)
            ax = sns.scatterplot(
                ax=ax,
                data=data,
                x=x,
                y=y,
                hue=hue,
                palette=palette,
                s=s,
                alpha=alpha,
                **kws_scatter,
            )
        elif k == "histplot":
            kws_hist = kwargs.pop("kws_hist", kwargs)
            ax = sns.histplot(data=data, x=x, hue=hue, ax=ax, **kws_hist)
        elif k == "kdeplot":
            kws_kde = kwargs.pop("kws_kde", kwargs)
            ax = sns.kdeplot(data=data, x=x, hue=hue, ax=ax, **kws_kde)
        elif k == "ecdfplot":
            kws_ecdf = kwargs.pop("kws_ecdf", kwargs)
            ax = sns.ecdfplot(data=data, x=x, hue=hue, ax=ax, **kws_ecdf)
        elif k == "rugplot":
            kws_rug = kwargs.pop("kws_rug", kwargs)
            ax = sns.rugplot(data=data, x=x, hue=hue, ax=ax, **kws_rug)
        elif k == "stripplot":
            kws_strip = kwargs.pop("kws_strip", kwargs)
            ax = sns.stripplot(data=data, x=x, y=y, hue=hue, ax=ax, **kws_strip)
        elif k == "swarmplot":
            kws_swarm = kwargs.pop("kws_swarm", kwargs)
            ax = sns.swarmplot(data=data, x=x, y=y, hue=hue, ax=ax, **kws_swarm)
        elif k == "boxplot":
            kws_box = kwargs.pop("kws_box", kwargs)
            ax = sns.boxplot(data=data, x=x, y=y, hue=hue, ax=ax, **kws_box)
        elif k == "violinplot":
            kws_violin = kwargs.pop("kws_violin", kwargs)
            ax = sns.violinplot(data=data, x=x, y=y, hue=hue, ax=ax, **kws_violin)
        elif k == "boxenplot":
            kws_boxen = kwargs.pop("kws_boxen", kwargs)
            ax = sns.boxenplot(data=data, x=x, y=y, hue=hue, ax=ax, **kws_boxen)
        elif k == "pointplot":
            kws_point = kwargs.pop("kws_point", kwargs)
            ax = sns.pointplot(data=data, x=x, y=y, hue=hue, ax=ax, **kws_point)
        elif k == "barplot":
            kws_bar = kwargs.pop("kws_bar", kwargs)
            ax = sns.barplot(data=data, x=x, y=y, hue=hue, ax=ax, **kws_bar)
        elif k == "countplot":
            kws_count = kwargs.pop("kws_count", kwargs)
            ax = sns.countplot(data=data, x=x, hue=hue, ax=ax, **kws_count)
        elif k == "regplot":
            kws_reg = kwargs.pop("kws_reg", kwargs)
            ax = sns.regplot(data=data, x=x, y=y, ax=ax, **kws_reg)
        elif k == "residplot":
            kws_resid = kwargs.pop("kws_resid", kwargs)
            ax = sns.residplot(data=data, x=x, y=y, lowess=True, ax=ax, **kws_resid)
        elif k == "lineplot":
            kws_line = kwargs.pop("kws_line", kwargs)
            ax = sns.lineplot(ax=ax, data=data, x=x, y=y, hue=hue, **kws_line)

        figsets(**kws_figsets)
        print(k, " ⤵ ")
        print(default_settings[k])
        print("=>\t", _describe(k))
        print()
    if g is not None:
        return g, ax
    return ax
|
3024
|
+
|
3025
|
+
|
3026
|
+
def volcano(
    data,
    x,
    y,
    gene_col=None,
    top_genes=5,
    thr_x=np.log2(1.5),
    thr_y=-np.log10(0.05),
    colors=("#e70b0b", "#0d26e3", "#b8bbbe"),
    s=20,
    fill=True,  # plot filled scatter
    facecolor="none",
    edgecolor="none",
    edgelinewidth=0.5,
    alpha=0.8,
    legend=False,
    ax=None,
    usage=False,
    kws_arrow=None,
    kws_text=None,
    **kwargs,
):
    """
    Generates a customizable scatter plot (e.g., volcano plot).

    Parameters:
    -----------
    data : pd.DataFrame
        The DataFrame containing the data to plot. It is not modified.
    x : str
        Column name for x-axis values (e.g., log2FoldChange).
    y : str
        Column name for y-axis values (e.g., -log10(FDR)).
    gene_col : str, optional
        Column holding gene names. If given, the top points on each
        significant side are labelled (requires the adjustText package).
    top_genes : int, optional
        Number of points labelled per significant side, ranked by y. Default is 5.
    thr_x : float, optional
        Threshold on |x|. Default is log2(1.5) (~0.585).
    thr_y : float, optional
        Threshold on y (e.g., significance). Default is -np.log10(0.05).
    colors : tuple, optional
        Three colors used as follows:
        colors[0] -> x < -thr_x and y > thr_y (left/down side),
        colors[1] -> all other (non-significant) points,
        colors[2] -> x > thr_x and y > thr_y (right/up side).
        NOTE(review): with this order the default tuple paints the down side
        red, neutral points blue and the up side gray - confirm intent.
    s : int, optional
        Size of points in the plot. Default is 20.
    fill : bool, optional
        True: faces take the category colors and edges use ``edgecolor``.
        False: faces use ``facecolor`` and edges take the category colors.
    facecolor, edgecolor : color, optional
        See ``fill``.
    edgelinewidth : float, optional
        Marker edge line width. Default is 0.5.
    alpha : float, optional
        Transparency of the points. Default is 0.8.
    legend : bool, optional
        Whether to show a legend. Default is False.
    ax : matplotlib Axes, optional
        Axes to draw on; ``plt.gca()`` when None.
    usage : bool, optional
        If True, only print an example call and return.
    kws_text : dict, optional
        Label options: ``fontname`` (default "Arial"), ``color`` (default
        "k"; pass a list to color each label like its dot), ``fontsize``
        (default 10).
    kws_arrow : dict, optional
        Arrow options for adjust_text: ``style`` (matched against known
        arrow styles, default "-"), ``color`` (default "0.5"), ``lw``
        (default 0.5), ``shrinkA``/``shrinkB`` (default 5); other keys are
        passed into ``arrowprops``.
    **kwargs
        Passed to ``sns.scatterplot``; a key containing "figset" is passed
        to ``figsets`` instead.

    Returns:
    --------
    The matplotlib Axes, or None when ``usage`` is True.
    """
    usage_str = """
    _, axs = plt.subplots(1, 1, figsize=(4, 5))
    volcano(
        ax=axs,
        data=ranked_genes,
        x="log2(fold_change)",
        y="-log10(p-value)",
        gene_col="ID_REF",
        top_genes=6,
        thr_x=np.log2(1.2),
        # thr_y=-np.log10(0.05),
        colors=("#00BFFF", "#9d9a9a", "#FF3030"),
        fill=0,
        alpha=1,
        facecolor="none",
        s=20,
        edgelinewidth=0.5,
        edgecolor="0.5",
        kws_text=dict(fontsize=10, color="k"),
        kws_arrow=dict(style="-", color="k", lw=0.5),
        # usage=True,
        figsets=dict(ylim=[0, 10], title="df"),
    )
    """
    if usage:
        print(usage_str)
        return

    # Copies: defaults of None used to crash the label code, and popping from
    # the caller's dicts silently emptied them for the next call.
    kws_text = dict(kws_text) if kws_text else {}
    kws_arrow = dict(kws_arrow) if kws_arrow else {}

    # figsets options: the first kwarg whose name contains "figset"
    kws_figsets = {}
    figset_key = next((key for key in kwargs if "figset" in key), None)
    if figset_key is not None:
        kws_figsets = kwargs.pop(figset_key)

    # Work on a copy so the caller's DataFrame does not gain a "color" column.
    data = data.copy()
    is_up = (data[x] > thr_x) & (data[y] > thr_y)
    is_down = (data[x] < -thr_x) & (data[y] > thr_y)
    data["color"] = np.where(
        is_up,
        colors[2],
        np.where(is_down, colors[0], colors[1]),
    )

    # Top points (by y) on each significant side, chosen by category rather
    # than by color: the old filter excluded colors[2] (the up side) and
    # labelled non-significant points instead.
    sele_gene = pd.concat(
        [
            data.loc[is_down].nlargest(top_genes, y),
            data.loc[is_up].nlargest(top_genes, y),
        ]
    )

    palette = {colors[0]: colors[0], colors[1]: colors[1], colors[2]: colors[2]}
    # Plot setup
    if ax is None:
        ax = plt.gca()

    # Handle fill parameter
    if fill:
        facecolors = data["color"]  # Fill with colors
        edgecolors = edgecolor
    else:
        facecolors = facecolor  # No fill; the category color goes on the edge
        edgecolors = data["color"]

    ax = sns.scatterplot(
        ax=ax,
        data=data,
        x=x,
        y=y,
        # hue="color",
        palette=palette,
        s=s,
        linewidths=edgelinewidth,
        color=facecolors,
        edgecolor=edgecolors,
        alpha=alpha,
        legend=legend,
        **kwargs,
    )

    # Threshold lines go on `ax`, not whatever axes happens to be current.
    ax.axhline(y=thr_y, color="black", linestyle="--")
    ax.axvline(x=-thr_x, color="black", linestyle="--")
    ax.axvline(x=thr_x, color="black", linestyle="--")

    # Add gene labels for selected significant points
    if gene_col:
        from adjustText import adjust_text

        fontname = kws_text.pop("fontname", "Arial")
        textcolor = kws_text.pop("color", "k")
        fontsize = kws_text.pop("fontsize", 10)
        texts = []
        for _, row in sele_gene.iterrows():
            # A list for `color` means "use the same color as the dot".
            label_color = row["color"] if isinstance(textcolor, list) else textcolor
            texts.append(
                ax.text(
                    x=row[x],
                    y=row[y],
                    s=row[gene_col],
                    fontdict={
                        "fontsize": fontsize,
                        "color": label_color,
                        "fontname": fontname,
                    },
                )
            )

        arrowstyles = [
            "-",
            "->",
            "-[",
            "|->",
            "<-",
            "<->",
            "<|-",
            "<|-|>",
            "-|>",
            "-[ ",
            "fancy",
            "simple",
            "wedge",
        ]
        arrowstyle = strcmp(kws_arrow.pop("style", "-"), arrowstyles)[0]
        arrowcolor = kws_arrow.pop("color", "0.5")
        arrowlinewidth = kws_arrow.pop("lw", 0.5)
        shrinkA = kws_arrow.pop("shrinkA", 5)
        shrinkB = kws_arrow.pop("shrinkB", 5)
        if texts:
            adjust_text(
                texts,
                ax=ax,
                expand_text=(1.05, 1.2),
                arrowprops=dict(
                    arrowstyle=arrowstyle,
                    color=arrowcolor,
                    lw=arrowlinewidth,
                    shrinkA=shrinkA,
                    shrinkB=shrinkB,
                    **kws_arrow,
                ),
            )

    figsets(**kws_figsets)
    return ax
|
3221
|
+
|
3222
|
+
|
3223
|
+
def sns_func_info(dir_save=None):
    """
    Save ``sns_info.json``, a short catalogue of seaborn plotting functions.

    The saved object is a dict of three parallel lists, "Functions",
    "Category" and "Detail", with one entry per function. ``plot_xy`` loads
    it into a DataFrame to print a one-line description of each plot.

    Parameters:
    - dir_save (str, optional): directory for ``sns_info.json``. Defaults to
      the ``data`` directory next to this module (formerly hard-coded to the
      author's personal folders).

    Returns:
    - None
    """
    # One (function, category, detail) row per function, so the three output
    # lists cannot drift out of alignment when entries are added or removed.
    rows = [
        ("relplot", "relational", "A figure-level function for creating scatter plots and line plots. It combines the functionality of scatterplot and lineplot."),
        ("scatterplot", "relational", "A function for creating scatter plots, useful for visualizing the relationship between two continuous variables."),
        ("lineplot", "relational", "A function for drawing line plots, often used to visualize trends over time or ordered categories."),
        ("lmplot", "relational", "A figure-level function for creating linear model plots, combining regression lines with scatter plots."),
        ("catplot", "categorical", "A figure-level function for creating categorical plots, which can display various types of plots like box plots, violin plots, and bar plots in one function."),
        ("stripplot", "categorical", "A function for creating a scatter plot where one of the variables is categorical, helping visualize distribution along a categorical axis."),
        ("swarmplot", "categorical", "A function for creating a categorical scatter plot with non-overlapping points, showing the distribution of values along a categorical axis."),
        ("boxplot", "categorical", "A function for creating box plots, which summarize the distribution of a continuous variable based on a categorical variable."),
        ("violinplot", "categorical", "A function for creating violin plots, which combine box plots and KDEs to visualize the distribution of data."),
        ("boxenplot", "categorical", "A function for creating boxen plots, an enhanced version of box plots that better represent data distributions with more quantiles."),
        ("pointplot", "categorical", "A function for creating point plots, which show the mean (or another estimator) of a variable for each level of a categorical variable."),
        ("barplot", "categorical", "A function for creating bar plots, which represent the mean (or other estimators) of a variable with bars, typically used with categorical data."),
        ("countplot", "categorical", "A function for creating count plots, which show the counts of observations in each categorical bin."),
        ("displot", "distribution", "A figure-level function that creates distribution plots. It can visualize histograms, KDEs, and ECDFs, making it versatile for analyzing the distribution of data."),
        ("histplot", "distribution", "A function for creating histograms, useful for showing the frequency distribution of a continuous variable."),
        ("kdeplot", "distribution", "A function for creating kernel density estimate (KDE) plots, which visualize the probability density function of a continuous variable."),
        ("ecdfplot", "distribution", "A function for creating empirical cumulative distribution function (ECDF) plots, which show the proportion of observations below a certain value."),
        ("rugplot", "distribution", "A function that adds a rug plot to the axes, representing individual data points along an axis."),
        ("regplot", "regression", "A function for creating regression plots, which fit and visualize a regression model on scatter data."),
        ("residplot", "regression", "A function for creating residual plots, useful for diagnosing the fit of a regression model."),
        ("pairplot", "grid-based(fig)", "A figure-level function that creates a grid of scatter plots for each pair of variables in a dataset, often used for exploratory data analysis."),
        ("jointplot", "grid-based(fig)", "A figure-level function that combines scatter plots and histograms (or KDEs) to visualize the relationship between two variables and their distributions."),
        ("plotting_context", "context", "Not a plot itself, but a function that allows you to change the context (style and scaling) of your plots to fit different publication requirements or visual preferences."),
    ]
    sns_info = {
        "Functions": [func for func, _, _ in rows],
        "Category": [category for _, category, _ in rows],
        "Detail": [detail for _, _, detail in rows],
    }
    if dir_save is None:
        dir_save = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data")
    # os.path.join avoids mixing "\\" and "/" separators on Windows
    fsave(os.path.join(dir_save, "sns_info.json"), sns_info)
|