py2ls 0.2.1__py3-none-any.whl → 0.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
py2ls/plot.py CHANGED
@@ -9,24 +9,97 @@ import matplotlib.ticker as tck
9
9
  from cycler import cycler
10
10
  import logging
11
11
  import os
12
+ import re
12
13
 
13
- from .ips import fsave, fload, mkdir, listdir, figsave
14
+ from .ips import fsave, fload, mkdir, listdir, figsave, strcmp, unique, get_os, ssplit
14
15
  from .stats import *
16
+ from .netfinder import get_soup, fetch
17
+
15
18
 
16
19
  # Suppress INFO messages from fontTools
17
20
  logging.getLogger("fontTools").setLevel(logging.WARNING)
18
21
 
19
22
 
20
- def df_corr(
21
- df,
22
- columns="all",
23
def update_sns_usages(
    url="https://seaborn.pydata.org/generated/seaborn.swarmplot.html",
    dir_save=None,
):
    """
    Fetch the call signatures of Seaborn plotting functions and save them as JSON.

    Starting from one page of the Seaborn API reference, collect the
    "reference internal" links whose target contains "plot" (skipping the
    JointGrid, PairGrid and seaborn.objects pages), load each linked page and
    take its first ``<dt>`` entry as that function's signature.

    Parameters:
    - url (str): Seaborn API page used to discover the other plotting pages
      (default is the swarmplot page).
    - dir_save (str): Directory in which to write ``usages_sns.json``.
      Defaults to the ``data`` folder next to this module.

    Saves:
    - ``usages_sns.json``: a dict mapping each function name (without the
      ``seaborn.`` prefix) to its signature string.

    Returns:
    - None
    """

    def get_usage(page_url):
        # The first <dt> of an API page holds the signature. Return None when
        # there is none so that one unexpected page does not abort the scrape.
        sp = get_soup(page_url, driver="se")
        found = fetch(sp, where="dt")
        return found[0] if found else None

    if dir_save is None:
        # Use the data folder next to this module instead of a user-specific path.
        dir_save = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data")

    excluded_prefixes = ("seaborn.JointGrid", "seaborn.PairGrid", "seaborn.objects")
    sp = get_soup(url, driver="se")
    links_all = fetch(sp, where="a", get="href", class_="reference internal")
    filtered_links = unique(
        [
            link
            for link in links_all
            if "plot" in link and not link.startswith(excluded_prefixes)
        ]
    )
    links = ["https://seaborn.pydata.org/generated/" + i for i in filtered_links]

    dict_usage = {}
    for link in links:
        usage = get_usage(link)
        if not usage:
            print(f"no usage found on {link}, skipped")
            continue
        # Key: function name before "(" without the "seaborn." prefix.
        # Value: the signature minus its last character (presumably the
        # page's permalink symbol; TODO confirm against the live page).
        dict_usage[ssplit(usage, by="(")[0].replace("seaborn.", "")] = usage[:-1]

    # save to local
    fsave(os.path.join(dir_save, "usages_sns.json"), dict_usage)
86
+
87
+
88
+ def heatmap(
89
+ data,
90
+ ax=None,
91
+ kind="corr", #'corr','direct','pivot'
92
+ columns="all", # pivot, default: coll numeric columns
93
+ index=None, # pivot
94
+ values=None, # pivot
23
95
  tri="u",
24
96
  mask=True,
25
97
  k=1,
26
98
  annot=True,
27
99
  cmap="coolwarm",
28
100
  fmt=".2f",
29
- cluster=False, # New parameter for clustermap option
101
+ cluster=False,
102
+ inplace=False,
30
103
  figsize=(10, 8),
31
104
  row_cluster=True, # Perform clustering on rows
32
105
  col_cluster=True, # Perform clustering on columns
@@ -36,24 +109,115 @@ def df_corr(
36
109
  yticklabels=True, # Show row labels
37
110
  **kwargs,
38
111
  ):
112
+ if ax is None and not cluster:
113
+ ax = plt.gca()
39
114
  # Select numeric columns or specific subset of columns
40
115
  if columns == "all":
41
- df_numeric = df.select_dtypes(include=[float, int])
116
+ df_numeric = data.select_dtypes(include=[float, int])
42
117
  else:
43
- df_numeric = df[columns]
44
-
45
- # Compute the correlation matrix
46
- correlation_matrix = df_numeric.corr()
118
+ df_numeric = data[columns]
119
+
120
+ kinds = ["corr", "direct", "pivot"]
121
+ kind = strcmp(kind, kinds)[0]
122
+ if kind == "corr":
123
+ # Compute the correlation matrix
124
+ data4heatmap = df_numeric.corr()
125
+ # Generate mask for the upper triangle if mask is True
126
+ if mask:
127
+ if "u" in tri.lower(): # upper => np.tril
128
+ mask_array = np.tril(np.ones_like(data4heatmap, dtype=bool), k=k)
129
+ else: # lower => np.triu
130
+ mask_array = np.triu(np.ones_like(data4heatmap, dtype=bool), k=k)
131
+ else:
132
+ mask_array = None
133
+
134
+ # Remove conflicting kwargs
135
+ kwargs.pop("mask", None)
136
+ kwargs.pop("annot", None)
137
+ kwargs.pop("cmap", None)
138
+ kwargs.pop("fmt", None)
139
+
140
+ kwargs.pop("clustermap", None)
141
+ kwargs.pop("row_cluster", None)
142
+ kwargs.pop("col_cluster", None)
143
+ kwargs.pop("dendrogram_ratio", None)
144
+ kwargs.pop("cbar_pos", None)
145
+ kwargs.pop("xticklabels", None)
146
+ kwargs.pop("col_cluster", None)
147
+
148
+ # Plot the heatmap or clustermap
149
+ if cluster:
150
+ # Create a clustermap
151
+ cluster_obj = sns.clustermap(
152
+ data4heatmap,
153
+ # ax=ax,
154
+ mask=mask_array,
155
+ annot=annot,
156
+ cmap=cmap,
157
+ fmt=fmt,
158
+ figsize=figsize, # Figure size, adjusted for professional display
159
+ row_cluster=row_cluster, # Perform clustering on rows
160
+ col_cluster=col_cluster, # Perform clustering on columns
161
+ dendrogram_ratio=dendrogram_ratio, # Adjust size of dendrograms
162
+ cbar_pos=cbar_pos, # Adjust colorbar position
163
+ xticklabels=xticklabels, # Show column labels
164
+ yticklabels=yticklabels, # Show row labels
165
+ **kwargs, # Pass any additional arguments to sns.clustermap
166
+ )
167
+ df_row_cluster = pd.DataFrame()
168
+ df_col_cluster = pd.DataFrame()
169
+ if row_cluster:
170
+ from scipy.cluster.hierarchy import linkage, fcluster
171
+ from scipy.spatial.distance import pdist
172
+
173
+ # Compute pairwise distances
174
+ distances = pdist(data, metric="euclidean")
175
+ # Perform hierarchical clustering
176
+ linkage_matrix = linkage(distances, method="average")
177
+ # Get cluster assignments based on the distance threshold
178
+ row_clusters_value = fcluster(
179
+ linkage_matrix, t=1.5, criterion="distance"
180
+ )
181
+ df_row_cluster["row_cluster"] = row_clusters_value
182
+ if col_cluster:
183
+ col_distances = pdist(
184
+ data4heatmap.T, metric="euclidean"
185
+ ) # Transpose for column clustering
186
+ col_linkage_matrix = linkage(col_distances, method="average")
187
+ col_clusters_value = fcluster(
188
+ col_linkage_matrix, t=1.5, criterion="distance"
189
+ )
190
+ df_col_cluster = pd.DataFrame(
191
+ {"Cluster": col_clusters_value}, index=data4heatmap.columns
192
+ )
47
193
 
48
- # Generate mask for the upper triangle if mask is True
49
- if mask:
50
- if "u" in tri.lower(): # upper => np.tril
51
- mask_array = np.tril(np.ones_like(correlation_matrix, dtype=bool), k=k)
52
- else: # lower => np.triu
53
- mask_array = np.triu(np.ones_like(correlation_matrix, dtype=bool), k=k)
194
+ return (
195
+ cluster_obj.ax_row_dendrogram,
196
+ cluster_obj.ax_col_dendrogram,
197
+ cluster_obj.ax_heatmap,
198
+ df_row_cluster,
199
+ df_col_cluster,
200
+ )
201
+ else:
202
+ # Create a standard heatmap
203
+ ax = sns.heatmap(
204
+ data4heatmap,
205
+ ax=ax,
206
+ mask=mask_array,
207
+ annot=annot,
208
+ cmap=cmap,
209
+ fmt=fmt,
210
+ **kwargs, # Pass any additional arguments to sns.heatmap
211
+ )
212
+ # Return the Axes object for further customization if needed
213
+ return ax
214
+ elif kind == "direct":
215
+ data4heatmap = df_numeric
216
+ elif kind == "pivot":
217
+ print('need 3 param: e.g., index="Task", columns="Model", values="Score"')
218
+ data4heatmap = data.pivot(index=index, columns=columns, values=values)
54
219
  else:
55
- mask_array = None
56
-
220
+ print(f'"{kind}" is not supported')
57
221
  # Remove conflicting kwargs
58
222
  kwargs.pop("mask", None)
59
223
  kwargs.pop("annot", None)
@@ -72,8 +236,9 @@ def df_corr(
72
236
  if cluster:
73
237
  # Create a clustermap
74
238
  cluster_obj = sns.clustermap(
75
- correlation_matrix,
76
- mask=mask_array,
239
+ data4heatmap,
240
+ # ax=ax,
241
+ # mask=mask_array,
77
242
  annot=annot,
78
243
  cmap=cmap,
79
244
  fmt=fmt,
@@ -86,18 +251,43 @@ def df_corr(
86
251
  yticklabels=yticklabels, # Show row labels
87
252
  **kwargs, # Pass any additional arguments to sns.clustermap
88
253
  )
254
+ df_row_cluster = pd.DataFrame()
255
+ df_col_cluster = pd.DataFrame()
256
+ if row_cluster:
257
+ from scipy.cluster.hierarchy import linkage, fcluster
258
+ from scipy.spatial.distance import pdist
259
+
260
+ # Compute pairwise distances
261
+ distances = pdist(data, metric="euclidean")
262
+ # Perform hierarchical clustering
263
+ linkage_matrix = linkage(distances, method="average")
264
+ # Get cluster assignments based on the distance threshold
265
+ row_clusters_value = fcluster(linkage_matrix, t=1.5, criterion="distance")
266
+ df_row_cluster["row_cluster"] = row_clusters_value
267
+ if col_cluster:
268
+ col_distances = pdist(
269
+ data4heatmap.T, metric="euclidean"
270
+ ) # Transpose for column clustering
271
+ col_linkage_matrix = linkage(col_distances, method="average")
272
+ col_clusters_value = fcluster(
273
+ col_linkage_matrix, t=1.5, criterion="distance"
274
+ )
275
+ df_col_cluster = pd.DataFrame(
276
+ {"Cluster": col_clusters_value}, index=data4heatmap.columns
277
+ )
89
278
 
90
279
  return (
91
280
  cluster_obj.ax_row_dendrogram,
92
281
  cluster_obj.ax_col_dendrogram,
93
282
  cluster_obj.ax_heatmap,
283
+ df_row_cluster,
284
+ df_col_cluster,
94
285
  )
95
286
  else:
96
287
  # Create a standard heatmap
97
- plt.figure(figsize=figsize)
98
288
  ax = sns.heatmap(
99
- correlation_matrix,
100
- mask=mask_array,
289
+ data4heatmap,
290
+ ax=ax,
101
291
  annot=annot,
102
292
  cmap=cmap,
103
293
  fmt=fmt,
@@ -107,6 +297,60 @@ def df_corr(
107
297
  return ax
108
298
 
109
299
 
300
+ # !usage: py2ls.plot.heatmap()
301
+ # penguins_clean = penguins.replace([np.inf, -np.inf], np.nan).dropna()
302
+ # from py2ls import plot
303
+
304
+ # _, axs = plt.subplots(2, 2, figsize=(10, 10))
305
+ # # kind='pivot'
306
+ # plot.heatmap(
307
+ # ax=axs[0][0],
308
+ # data=sns.load_dataset("glue"),
309
+ # kind="pi",
310
+ # index="Model",
311
+ # columns="Task",
312
+ # values="Score",
313
+ # fmt=".1f",
314
+ # cbar_kws=dict(shrink=1),
315
+ # annot_kws=dict(size=7),
316
+ # )
317
+ # # kind='direct'
318
+ # plot.heatmap(
319
+ # ax=axs[0][1],
320
+ # data=sns.load_dataset("penguins").iloc[:10, 2:6],
321
+ # kind="direct",
322
+ # tri="lower",
323
+ # fmt=".1f",
324
+ # k=1,
325
+ # cbar_kws=dict(shrink=1),
326
+ # annot_kws=dict(size=7),
327
+ # )
328
+
329
+ # # kind='corr'
330
+ # plot.heatmap(
331
+ # ax=axs[1][0],
332
+ # data=sns.load_dataset("penguins"),
333
+ # kind="corr",
334
+ # fmt=".1f",
335
+ # k=-1,
336
+ # cbar_kws=dict(shrink=1),
337
+ # annot_kws=dict(size=7),
338
+ # )
339
+ # # kind='direct', clustered (cluster=True)
340
+ # plot.heatmap(
341
+ # ax=axs[1][1],
342
+ # data=penguins_clean.iloc[:15, :10],
343
+ # kind="direct",
344
+ # tri="lower",
345
+ # fmt=".1f",
346
+ # k=1,
347
+ # annot=False,
348
+ # cluster=True,
349
+ # cbar_kws=dict(shrink=1),
350
+ # annot_kws=dict(size=7),
351
+ # )
352
+
353
+
110
354
  def catplot(data, *args, **kwargs):
111
355
  """
112
356
  catplot(data, opt=None, ax=None)
@@ -1524,6 +1768,10 @@ def figsets(*args, **kwargs):
1524
1768
  alignment='left')
1525
1769
  )
1526
1770
  """
1771
+ import matplotlib
1772
+
1773
+ matplotlib.rc("text", usetex=False)
1774
+
1527
1775
  fig = plt.gcf()
1528
1776
  fontsize = 11
1529
1777
  fontname = "Arial"
@@ -1615,6 +1863,16 @@ def figsets(*args, **kwargs):
1615
1863
  if isinstance(value, list):
1616
1864
  loc = []
1617
1865
  for i in value:
1866
+ ax.tick_params(
1867
+ axis="both",
1868
+ which="both",
1869
+ bottom=False,
1870
+ top=False,
1871
+ left=False,
1872
+ right=False,
1873
+ labelbottom=False,
1874
+ labelleft=False,
1875
+ )
1618
1876
  if ("l" in i.lower()) and ("a" not in i.lower()):
1619
1877
  ax.yaxis.set_ticks_position("left")
1620
1878
  if "r" in i.lower():
@@ -1624,12 +1882,38 @@ def figsets(*args, **kwargs):
1624
1882
  if "b" in i.lower():
1625
1883
  ax.xaxis.set_ticks_position("bottom")
1626
1884
  if i.lower() in ["a", "both", "all", "al", ":"]:
1627
- ax.xaxis.set_ticks_position("both")
1628
- ax.yaxis.set_ticks_position("both")
1885
+ ax.tick_params(
1886
+ axis="both", # Apply to both axes
1887
+ which="both", # Apply to both major and minor ticks
1888
+ bottom=True, # Show ticks at the bottom
1889
+ top=True, # Show ticks at the top
1890
+ left=True, # Show ticks on the left
1891
+ right=True, # Show ticks on the right
1892
+ labelbottom=True, # Show labels at the bottom
1893
+ labelleft=True, # Show labels on the left
1894
+ )
1629
1895
  if i.lower() in ["xnone", "xoff", "none"]:
1630
- ax.xaxis.set_ticks_position("none")
1896
+ ax.tick_params(
1897
+ axis="x",
1898
+ which="both",
1899
+ bottom=False,
1900
+ top=False,
1901
+ left=False,
1902
+ right=False,
1903
+ labelbottom=False,
1904
+ labelleft=False,
1905
+ )
1631
1906
  if i.lower() in ["ynone", "yoff", "none"]:
1632
- ax.yaxis.set_ticks_position("none")
1907
+ ax.tick_params(
1908
+ axis="y",
1909
+ which="both",
1910
+ bottom=False,
1911
+ top=False,
1912
+ left=False,
1913
+ right=False,
1914
+ labelbottom=False,
1915
+ labelleft=False,
1916
+ )
1633
1917
  # ticks / labels
1634
1918
  elif "x" in key.lower():
1635
1919
  if value is None:
@@ -1674,6 +1958,10 @@ def figsets(*args, **kwargs):
1674
1958
 
1675
1959
  if "bo" in key in key: # box setting, and ("p" in key or "l" in key):
1676
1960
  if isinstance(value, (str, list)):
1961
+ # locations = ["left", "right", "top", "bottom"]
1962
+ # for loc, spi in ax.spines.items():
1963
+ # if loc in locations:
1964
+ # spi.set_color("none") # no spine
1677
1965
  locations = []
1678
1966
  for i in value:
1679
1967
  if "l" in i.lower() and not "t" in i.lower():
@@ -1689,12 +1977,12 @@ def figsets(*args, **kwargs):
1689
1977
  locations.append(x)
1690
1978
  for x in ["left", "right", "top", "bottom"]
1691
1979
  ]
1692
- for i in value:
1693
- if i.lower() in "none":
1694
- locations = []
1980
+ if "none" in value:
1981
+ locations = [] # hide all
1695
1982
  # check spines
1696
1983
  for loc, spi in ax.spines.items():
1697
1984
  if loc in locations:
1985
+ # spi.set_color("k")
1698
1986
  spi.set_position(("outward", 0))
1699
1987
  else:
1700
1988
  spi.set_color("none") # no spine
@@ -2527,3 +2815,493 @@ def thumbnail(dir_img_list: list, figsize=(10, 10), dpi=100, show=False, usage=F
2527
2815
  plt.tight_layout()
2528
2816
  if show:
2529
2817
  plt.show()
2818
+
2819
+
2820
def get_params_from_func_usage(function_signature):
    """
    Return the keyword-parameter names that appear in a function signature string.

    Every ``name=`` (or ``name =``) occurrence is collected in order of
    appearance, e.g.
    ``"seaborn.swarmplot(data=None, *, x=None, **kwargs)"`` -> ``["data", "x"]``.
    ``*args`` / ``**kwargs`` carry no ``=`` and are therefore not returned, and
    comparisons such as ``a==b`` are not mistaken for assignments.

    Parameters:
    - function_signature (str): signature text, e.g. a seaborn usage string.

    Returns:
    - list of str: parameter names; empty for an empty or None input.
    """
    if not function_signature:
        return []
    # A word followed by "=" that is not the start of "=="; the look-behind
    # skips a name written directly after "**".
    keys_pattern = r"(?<!\*\*)\b(\w+)\s*=(?!=)"
    return re.findall(keys_pattern, function_signature)
2826
+
2827
+
2828
def plot_xy(
    data: pd.DataFrame = None,
    x=None,
    y=None,
    ax=None,
    kind: str = None,  # Specify the kind of plot
    usage=False,
    # kws_figsets:dict=None,
    **kwargs,
):
    """
    Create one or more seaborn plots of ``data``, selected by ``kind``.

    e.g., plot_xy(data=data_log, x="Component_1", y="Component_2", hue="Cluster", kind="scatterplot")

    Parameters:
    data (pd.DataFrame): DataFrame containing the data.
    x (str): Column name for the x-axis.
    y (str): Column name for the y-axis.
    hue (str, keyword): Column name for the hue (color) grouping.
    ax: Matplotlib axes to draw on (defaults to ``plt.gca()``).
    kind (str or list of str): Plot type(s), fuzzy-matched against the seaborn
        function names stored in ``data/usages_sns.json`` (e.g. 'scatterplot',
        'histplot', 'kdeplot', 'boxplot', 'displot', ...). None prints usage.
    usage (bool): If True, print default settings instead of plotting.
    **kwargs: Per-kind options as ``kws_<name>=dict(...)`` (e.g. ``kws_scatter``,
        ``kws_rug``, ``kws_box``). A key containing "figset" is forwarded to
        ``figsets``. The remaining keyword arguments are used for every kind
        that has no dedicated ``kws_<name>`` dict.

    Returns:
    ax, or (g, ax) when a jointplot, lmplot, displot or catplot was drawn;
    None when only the usage is printed.
    """
    dir_data = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data")
    # Default arguments (signatures) for the supported seaborn functions
    default_settings = fload(os.path.join(dir_data, "usages_sns.json"))
    sns_info = pd.DataFrame(fload(os.path.join(dir_data, "sns_info.json")))
    valid_kinds = list(default_settings.keys())

    def _describe(name):
        # Last column of the sns_info row for `name`; "" when the function is not
        # listed there (instead of raising IndexError).
        hits = (
            sns_info[sns_info["Functions"].str.contains(name, regex=False)]
            .iloc[:, -1]
            .tolist()
        )
        return hits[0] if hits else ""

    if kind is not None:
        if isinstance(kind, str):
            kind = [kind]
        kind = [strcmp(i, valid_kinds)[0] for i in kind]
    else:
        usage = True
    if usage:
        if kind is not None:
            for k in kind:
                if k in valid_kinds:
                    print(f"{k}:\n\t{default_settings[k]}")
                    print(_describe(k))
                    print()
        usage_str = """plot_xy(data=ranked_genes,
    x="log2(fold_change)",
    y="-log10(p-value)",
    palette=get_color(3, cmap="coolwarm"),
    kind=["scatter","rug"],
    kws_rug=dict(height=0.2),
    kws_scatter=dict(s=20, color=get_color(3)[2]),
    usage=0)
    """
        print(f"currently support to plot:\n{valid_kinds}\n\nusage:\n{usage_str}")
        return  # Do not plot, just print the usage

    # 'hue' arrives through **kwargs. Pop it so it is never passed twice when
    # the shared kwargs are used as a kind's options.
    hue = kwargs.pop("hue", None)

    kws_figsets = {}
    for k_arg in list(kwargs):
        if "figset" in k_arg:
            kws_figsets = kwargs.pop(k_arg)
            break

    # Shared options for kinds without their own kws_<name>; other kinds'
    # kws_* dicts are excluded so they never leak into an unrelated call.
    shared_kws = {k: v for k, v in kwargs.items() if not k.startswith("kws_")}

    def _kws(key):
        # Fresh copy, so popping defaults below never mutates the caller's dict
        # or affects the next kind.
        opts = kwargs.pop(key, None)
        return dict(shared_kws if opts is None else opts)

    sns_with_col = [
        "catplot",
        "histplot",
        "relplot",
        "lmplot",
        "pairplot",
        "displot",
        "kdeplot",
    ]
    g = None  # set when a grid/figure-level plot is drawn
    for k in kind:
        # indicate 'col' features
        if kwargs.get("col", None) is not None and k not in sns_with_col:
            print(f"tips:\n'{k}' has no 'col' param, you could try with {sns_with_col}")

        # (1) return FacetGrid / JointGrid
        if k == "jointplot":
            g = sns.jointplot(data=data, x=x, y=y, hue=hue, **_kws("kws_joint"))
        elif k == "lmplot":
            g = sns.lmplot(data=data, x=x, y=y, hue=hue, **_kws("kws_lm"))
        elif k == "catplot_sns":
            g = sns.catplot(data=data, x=x, y=y, hue=hue, **_kws("kws_cat"))
        elif k == "displot":
            # displot creates a new figure and returns a FacetGrid
            g = sns.displot(data=data, x=x, hue=hue, **_kws("kws_dis"))

        # (2) return axis
        if ax is None:
            ax = plt.gca()

        if k == "catplot":
            kws_cat = _kws("kws_cat")
            if hue is not None:
                kws_cat.setdefault("hue", hue)
            g = catplot(data=data, x=x, y=y, ax=ax, **kws_cat)
        elif k == "scatterplot":
            kws_scatter = _kws("kws_scatter")
            palette = kws_scatter.pop(
                "palette",
                (
                    sns.color_palette("tab10", data[hue].nunique())
                    if hue is not None
                    else sns.color_palette("tab10")
                ),
            )
            s = kws_scatter.pop("s", 10)
            alpha = kws_scatter.pop("alpha", 0.7)
            ax = sns.scatterplot(
                ax=ax,
                data=data,
                x=x,
                y=y,
                hue=hue,
                palette=palette,
                s=s,
                alpha=alpha,
                **kws_scatter,
            )
        elif k == "histplot":
            ax = sns.histplot(data=data, x=x, hue=hue, ax=ax, **_kws("kws_hist"))
        elif k == "kdeplot":
            ax = sns.kdeplot(data=data, x=x, hue=hue, ax=ax, **_kws("kws_kde"))
        elif k == "ecdfplot":
            ax = sns.ecdfplot(data=data, x=x, hue=hue, ax=ax, **_kws("kws_ecdf"))
        elif k == "rugplot":
            ax = sns.rugplot(data=data, x=x, hue=hue, ax=ax, **_kws("kws_rug"))
        elif k == "stripplot":
            ax = sns.stripplot(data=data, x=x, y=y, hue=hue, ax=ax, **_kws("kws_strip"))
        elif k == "swarmplot":
            ax = sns.swarmplot(data=data, x=x, y=y, hue=hue, ax=ax, **_kws("kws_swarm"))
        elif k == "boxplot":
            ax = sns.boxplot(data=data, x=x, y=y, hue=hue, ax=ax, **_kws("kws_box"))
        elif k == "violinplot":
            ax = sns.violinplot(data=data, x=x, y=y, hue=hue, ax=ax, **_kws("kws_violin"))
        elif k == "boxenplot":
            ax = sns.boxenplot(data=data, x=x, y=y, hue=hue, ax=ax, **_kws("kws_boxen"))
        elif k == "pointplot":
            ax = sns.pointplot(data=data, x=x, y=y, hue=hue, ax=ax, **_kws("kws_point"))
        elif k == "barplot":
            ax = sns.barplot(data=data, x=x, y=y, hue=hue, ax=ax, **_kws("kws_bar"))
        elif k == "countplot":
            ax = sns.countplot(data=data, x=x, hue=hue, ax=ax, **_kws("kws_count"))
        elif k == "regplot":
            ax = sns.regplot(data=data, x=x, y=y, ax=ax, **_kws("kws_reg"))
        elif k == "residplot":
            kws_resid = _kws("kws_resid")
            kws_resid.setdefault("lowess", True)  # default on, user may override
            ax = sns.residplot(data=data, x=x, y=y, ax=ax, **kws_resid)
        elif k == "lineplot":
            ax = sns.lineplot(ax=ax, data=data, x=x, y=y, hue=hue, **_kws("kws_line"))

        figsets(**kws_figsets)
        print(k, " ⤵ ")
        print(default_settings[k])
        print("=>\t", _describe(k))
        print()

    if g is not None:
        return g, ax
    return ax
3024
+
3025
+
3026
def volcano(
    data,
    x,
    y,
    gene_col=None,
    top_genes=5,
    thr_x=np.log2(1.5),
    thr_y=-np.log10(0.05),
    colors=("#e70b0b", "#0d26e3", "#b8bbbe"),
    s=20,
    fill=True,  # plot filled scatter
    facecolor="none",
    edgecolor="none",
    edgelinewidth=0.5,
    alpha=0.8,
    legend=False,
    ax=None,
    usage=False,
    kws_arrow=None,
    kws_text=None,
    **kwargs,
):
    """
    Generate a customizable volcano-style scatter plot.

    Points with ``x > thr_x`` and ``y > thr_y`` get ``colors[0]`` (up).
    Points with ``x < -thr_x`` and ``y > thr_y`` get ``colors[1]`` (down).
    All other points get ``colors[2]`` (not significant).
    Dashed lines mark the thresholds. When ``gene_col`` is given, the
    ``top_genes`` highest-y points of each significant group are labelled, and
    the labels are spread out with ``adjustText``.

    Parameters:
    -----------
    data : pd.DataFrame
        The DataFrame containing the data to plot (not modified).
    x : str
        Column name for x-axis values (e.g., log2FoldChange).
    y : str
        Column name for y-axis values (e.g., -log10(FDR)).
    gene_col : str, optional
        Column with the labels (e.g. gene names) for the top points.
        Default None (no labels).
    top_genes : int, optional
        Number of points labelled per significant group, ranked by y. Default is 5.
    thr_x : float, optional
        Threshold on |x|. Default is log2(1.5) ~= 0.585.
    thr_y : float, optional
        Threshold for y-axis values (e.g. significance). Default is -np.log10(0.05).
    colors : tuple, optional
        (up, down, not-significant) colors. Default is (red, blue, gray).
    s : int, optional
        Size of points in the plot. Default is 20.
    fill : bool, optional
        True: markers are filled with their group color and outlined with
        ``edgecolor``. False: markers use ``facecolor`` and are outlined with
        their group color.
    facecolor, edgecolor : str, optional
        Marker face color (used when fill is False) and edge color (used when
        fill is True). Default "none".
    edgelinewidth : float, optional
        Marker edge line width. Default is 0.5.
    alpha : float, optional
        Transparency of the points. Default is 0.8.
    legend : bool, optional
        Passed to sns.scatterplot. Default is False.
    ax : matplotlib Axes, optional
        Axes to draw on. Defaults to plt.gca().
    usage : bool, optional
        If True, print an example call and return without plotting.
    kws_arrow : dict, optional
        Label arrows: "style" (default "-"), "color" (default "0.5"), "lw"
        (default 0.5), "shrinkA"/"shrinkB" (default 5). Other keys go to arrowprops.
    kws_text : dict, optional
        Labels: "fontname" (default "Arial"), "color" (default "k"; a list means
        "color each label like its point") and "fontsize" (default 10). Other
        keys are passed to Axes.text.
    **kwargs :
        A key containing "figset" is forwarded to figsets; the rest go to
        sns.scatterplot.

    Returns:
    --------
    ax : matplotlib Axes (None when usage is True).
    """
    usage_str = """
    _, axs = plt.subplots(1, 1, figsize=(4, 5))
    volcano(
        ax=axs,
        data=ranked_genes,
        x="log2(fold_change)",
        y="-log10(p-value)",
        gene_col="ID_REF",
        top_genes=6,
        thr_x=np.log2(1.2),
        # thr_y=-np.log10(0.05),
        colors=("#FF3030", "#00BFFF", "#9d9a9a"),
        fill=0,
        alpha=1,
        facecolor="none",
        s=20,
        edgelinewidth=0.5,
        edgecolor="0.5",
        kws_text=dict(fontsize=10, color="k"),
        kws_arrow=dict(style="-", color="k", lw=0.5),
        # usage=True,
        figsets=dict(ylim=[0, 10], title="df"),
    )
    """
    if usage:
        print(usage_str)
        return

    kws_figsets = {}
    for k_arg in list(kwargs):
        if "figset" in k_arg:
            kws_figsets = kwargs.pop(k_arg)
            break

    # Work on a copy so the caller's DataFrame does not gain a "color" column.
    data = data.copy()
    is_up = (data[x] > thr_x) & (data[y] > thr_y)
    is_down = (data[x] < -thr_x) & (data[y] > thr_y)
    data["color"] = np.where(is_up, colors[0], np.where(is_down, colors[1], colors[2]))

    # Top `top_genes` points (highest y) of each significant group
    sele_gene = (
        data[data["color"] != colors[2]]  # exclude non-significant points
        .sort_values(y, ascending=False)
        .groupby("color")
        .head(top_genes)
    )
    palette = {colors[0]: colors[0], colors[1]: colors[1], colors[2]: colors[2]}
    if ax is None:
        ax = plt.gca()

    # Handle fill parameter
    if fill:
        facecolors = data["color"]  # fill with the group colors
        edgecolors = edgecolor
    else:
        facecolors = facecolor  # no fill; the group color goes to the edges
        edgecolors = data["color"]

    ax = sns.scatterplot(
        ax=ax,
        data=data,
        x=x,
        y=y,
        palette=palette,
        s=s,
        linewidths=edgelinewidth,
        color=facecolors,
        edgecolor=edgecolors,
        alpha=alpha,
        legend=legend,
        **kwargs,
    )

    # Threshold lines on the target axes (not whatever plt considers current)
    ax.axhline(y=thr_y, color="black", linestyle="--")
    ax.axvline(x=-thr_x, color="black", linestyle="--")
    ax.axvline(x=thr_x, color="black", linestyle="--")

    if gene_col:
        from adjustText import adjust_text

        # Copies: the defaults are popped below without touching the caller's dicts.
        kws_text = dict(kws_text or {})
        kws_arrow = dict(kws_arrow or {})
        fontname = kws_text.pop("fontname", "Arial")
        textcolor = kws_text.pop("color", "k")
        fontsize = kws_text.pop("fontsize", 10)
        match_point_color = isinstance(textcolor, list)

        texts = []
        for _, row in sele_gene.iterrows():
            label_color = row["color"] if match_point_color else textcolor
            texts.append(
                ax.text(
                    x=row[x],
                    y=row[y],
                    s=row[gene_col],
                    fontdict={
                        "fontsize": fontsize,
                        "color": label_color,
                        "fontname": fontname,
                    },
                    **kws_text,
                )
            )

        arrowstyles = [
            "-",
            "->",
            "-[",
            "|->",
            "<-",
            "<->",
            "<|-",
            "<|-|>",
            "-|>",
            "-[ ",
            "fancy",
            "simple",
            "wedge",
        ]
        arrowstyle = strcmp(kws_arrow.pop("style", "-"), arrowstyles)[0]
        arrowcolor = kws_arrow.pop("color", "0.5")
        arrowlinewidth = kws_arrow.pop("lw", 0.5)
        shrinkA = kws_arrow.pop("shrinkA", 5)
        shrinkB = kws_arrow.pop("shrinkB", 5)
        if texts:  # nothing to arrange when no point passed the thresholds
            adjust_text(
                texts,
                ax=ax,
                expand_text=(1.05, 1.2),
                arrowprops=dict(
                    arrowstyle=arrowstyle,
                    color=arrowcolor,
                    lw=arrowlinewidth,
                    shrinkA=shrinkA,
                    shrinkB=shrinkB,
                    **kws_arrow,
                ),
            )

    figsets(**kws_figsets)
    return ax
3221
+
3222
+
3223
def sns_func_info(dir_save=None):
    """
    Save a short description table of common Seaborn functions to ``sns_info.json``.

    The saved dict has three parallel lists:
    - "Functions": the Seaborn function name.
    - "Category": relational, categorical, distribution, regression,
      grid-based(fig) or context.
    - "Detail": a one-sentence description.

    Parameters:
    - dir_save (str): Directory in which to write ``sns_info.json``.
      Defaults to the ``data`` folder next to this module.

    Returns:
    - dict: the saved table.
    """
    # One (function, category, detail) tuple per row keeps the three saved
    # columns aligned when entries are added or removed.
    rows = [
        ("relplot", "relational", "A figure-level function for creating scatter plots and line plots. It combines the functionality of scatterplot and lineplot."),
        ("scatterplot", "relational", "A function for creating scatter plots, useful for visualizing the relationship between two continuous variables."),
        ("lineplot", "relational", "A function for drawing line plots, often used to visualize trends over time or ordered categories."),
        ("lmplot", "relational", "A figure-level function for creating linear model plots, combining regression lines with scatter plots."),
        ("catplot", "categorical", "A figure-level function for creating categorical plots, which can display various types of plots like box plots, violin plots, and bar plots in one function."),
        ("stripplot", "categorical", "A function for creating a scatter plot where one of the variables is categorical, helping visualize distribution along a categorical axis."),
        ("swarmplot", "categorical", "A function for creating categorical scatter plots whose points are adjusted to avoid overlap, showing every observation along a categorical axis."),
        ("boxplot", "categorical", "A function for creating box plots, which summarize the distribution of a continuous variable based on a categorical variable."),
        ("violinplot", "categorical", "A function for creating violin plots, which combine box plots and KDEs to visualize the distribution of data."),
        ("boxenplot", "categorical", "A function for creating boxen plots, an enhanced version of box plots that better represent data distributions with more quantiles."),
        ("pointplot", "categorical", "A function for creating point plots, which show the mean (or another estimator) of a variable for each level of a categorical variable."),
        ("barplot", "categorical", "A function for creating bar plots, which represent the mean (or other estimators) of a variable with bars, typically used with categorical data."),
        ("countplot", "categorical", "A function for creating count plots, which show the counts of observations in each categorical bin."),
        ("displot", "distribution", "A figure-level function that creates distribution plots. It can visualize histograms, KDEs, and ECDFs, making it versatile for analyzing the distribution of data."),
        ("histplot", "distribution", "A function for creating histograms, useful for showing the frequency distribution of a continuous variable."),
        ("kdeplot", "distribution", "A function for creating kernel density estimate (KDE) plots, which visualize the probability density function of a continuous variable."),
        ("ecdfplot", "distribution", "A function for creating empirical cumulative distribution function (ECDF) plots, which show the proportion of observations below a certain value."),
        ("rugplot", "distribution", "A function that adds a rug plot to the axes, representing individual data points along an axis."),
        ("regplot", "regression", "A function for creating regression plots, which fit and visualize a regression model on scatter data."),
        ("residplot", "regression", "A function for creating residual plots, useful for diagnosing the fit of a regression model."),
        ("pairplot", "grid-based(fig)", "A figure-level function that creates a grid of scatter plots for each pair of variables in a dataset, often used for exploratory data analysis."),
        ("jointplot", "grid-based(fig)", "A figure-level function that combines scatter plots and histograms (or KDEs) to visualize the relationship between two variables and their distributions."),
        ("plotting_context", "context", "Not a plot itself, but a function that allows you to change the context (style and scaling) of your plots to fit different publication requirements or visual preferences."),
    ]
    sns_info = {
        "Functions": [name for name, _, _ in rows],
        "Category": [category for _, category, _ in rows],
        "Detail": [detail for _, _, detail in rows],
    }
    if dir_save is None:
        # Use the data folder next to this module instead of a user-specific path.
        dir_save = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data")
    fsave(os.path.join(dir_save, "sns_info.json"), sns_info)
    return sns_info