py2ls 0.2.1__py3-none-any.whl → 0.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
py2ls/plot.py CHANGED
@@ -9,24 +9,32 @@ import matplotlib.ticker as tck
9
9
  from cycler import cycler
10
10
  import logging
11
11
  import os
12
+ import re
12
13
 
13
- from .ips import fsave, fload, mkdir, listdir, figsave
14
+ from .ips import fsave, fload, mkdir, listdir, figsave, strcmp, unique, get_os, ssplit
14
15
  from .stats import *
16
+ from .netfinder import get_soup, fetch
17
+
15
18
 
16
19
  # Suppress INFO messages from fontTools
17
20
  logging.getLogger("fontTools").setLevel(logging.WARNING)
18
21
 
19
22
 
20
- def df_corr(
21
- df,
22
- columns="all",
23
+ def heatmap(
24
+ data,
25
+ ax=None,
26
+ kind="corr", #'corr','direct','pivot'
27
+ columns="all", # pivot, default: coll numeric columns
28
+ index=None, # pivot
29
+ values=None, # pivot
23
30
  tri="u",
24
31
  mask=True,
25
32
  k=1,
26
33
  annot=True,
27
34
  cmap="coolwarm",
28
35
  fmt=".2f",
29
- cluster=False, # New parameter for clustermap option
36
+ cluster=False,
37
+ inplace=False,
30
38
  figsize=(10, 8),
31
39
  row_cluster=True, # Perform clustering on rows
32
40
  col_cluster=True, # Perform clustering on columns
@@ -36,24 +44,115 @@ def df_corr(
36
44
  yticklabels=True, # Show row labels
37
45
  **kwargs,
38
46
  ):
47
+ if ax is None and not cluster:
48
+ ax = plt.gca()
39
49
  # Select numeric columns or specific subset of columns
40
50
  if columns == "all":
41
- df_numeric = df.select_dtypes(include=[float, int])
51
+ df_numeric = data.select_dtypes(include=[float, int])
42
52
  else:
43
- df_numeric = df[columns]
44
-
45
- # Compute the correlation matrix
46
- correlation_matrix = df_numeric.corr()
53
+ df_numeric = data[columns]
54
+
55
+ kinds = ["corr", "direct", "pivot"]
56
+ kind = strcmp(kind, kinds)[0]
57
+ if kind == "corr":
58
+ # Compute the correlation matrix
59
+ data4heatmap = df_numeric.corr()
60
+ # Generate mask for the upper triangle if mask is True
61
+ if mask:
62
+ if "u" in tri.lower(): # upper => np.tril
63
+ mask_array = np.tril(np.ones_like(data4heatmap, dtype=bool), k=k)
64
+ else: # lower => np.triu
65
+ mask_array = np.triu(np.ones_like(data4heatmap, dtype=bool), k=k)
66
+ else:
67
+ mask_array = None
68
+
69
+ # Remove conflicting kwargs
70
+ kwargs.pop("mask", None)
71
+ kwargs.pop("annot", None)
72
+ kwargs.pop("cmap", None)
73
+ kwargs.pop("fmt", None)
74
+
75
+ kwargs.pop("clustermap", None)
76
+ kwargs.pop("row_cluster", None)
77
+ kwargs.pop("col_cluster", None)
78
+ kwargs.pop("dendrogram_ratio", None)
79
+ kwargs.pop("cbar_pos", None)
80
+ kwargs.pop("xticklabels", None)
81
+ kwargs.pop("col_cluster", None)
82
+
83
+ # Plot the heatmap or clustermap
84
+ if cluster:
85
+ # Create a clustermap
86
+ cluster_obj = sns.clustermap(
87
+ data4heatmap,
88
+ # ax=ax,
89
+ mask=mask_array,
90
+ annot=annot,
91
+ cmap=cmap,
92
+ fmt=fmt,
93
+ figsize=figsize, # Figure size, adjusted for professional display
94
+ row_cluster=row_cluster, # Perform clustering on rows
95
+ col_cluster=col_cluster, # Perform clustering on columns
96
+ dendrogram_ratio=dendrogram_ratio, # Adjust size of dendrograms
97
+ cbar_pos=cbar_pos, # Adjust colorbar position
98
+ xticklabels=xticklabels, # Show column labels
99
+ yticklabels=yticklabels, # Show row labels
100
+ **kwargs, # Pass any additional arguments to sns.clustermap
101
+ )
102
+ df_row_cluster = pd.DataFrame()
103
+ df_col_cluster = pd.DataFrame()
104
+ if row_cluster:
105
+ from scipy.cluster.hierarchy import linkage, fcluster
106
+ from scipy.spatial.distance import pdist
107
+
108
+ # Compute pairwise distances
109
+ distances = pdist(data, metric="euclidean")
110
+ # Perform hierarchical clustering
111
+ linkage_matrix = linkage(distances, method="average")
112
+ # Get cluster assignments based on the distance threshold
113
+ row_clusters_value = fcluster(
114
+ linkage_matrix, t=1.5, criterion="distance"
115
+ )
116
+ df_row_cluster["row_cluster"] = row_clusters_value
117
+ if col_cluster:
118
+ col_distances = pdist(
119
+ data4heatmap.T, metric="euclidean"
120
+ ) # Transpose for column clustering
121
+ col_linkage_matrix = linkage(col_distances, method="average")
122
+ col_clusters_value = fcluster(
123
+ col_linkage_matrix, t=1.5, criterion="distance"
124
+ )
125
+ df_col_cluster = pd.DataFrame(
126
+ {"Cluster": col_clusters_value}, index=data4heatmap.columns
127
+ )
47
128
 
48
- # Generate mask for the upper triangle if mask is True
49
- if mask:
50
- if "u" in tri.lower(): # upper => np.tril
51
- mask_array = np.tril(np.ones_like(correlation_matrix, dtype=bool), k=k)
52
- else: # lower => np.triu
53
- mask_array = np.triu(np.ones_like(correlation_matrix, dtype=bool), k=k)
129
+ return (
130
+ cluster_obj.ax_row_dendrogram,
131
+ cluster_obj.ax_col_dendrogram,
132
+ cluster_obj.ax_heatmap,
133
+ df_row_cluster,
134
+ df_col_cluster,
135
+ )
136
+ else:
137
+ # Create a standard heatmap
138
+ ax = sns.heatmap(
139
+ data4heatmap,
140
+ ax=ax,
141
+ mask=mask_array,
142
+ annot=annot,
143
+ cmap=cmap,
144
+ fmt=fmt,
145
+ **kwargs, # Pass any additional arguments to sns.heatmap
146
+ )
147
+ # Return the Axes object for further customization if needed
148
+ return ax
149
+ elif kind == "direct":
150
+ data4heatmap = df_numeric
151
+ elif kind == "pivot":
152
+ print('need 3 param: e.g., index="Task", columns="Model", values="Score"')
153
+ data4heatmap = data.pivot(index=index, columns=columns, values=values)
54
154
  else:
55
- mask_array = None
56
-
155
+ print(f'"{kind}" is not supported')
57
156
  # Remove conflicting kwargs
58
157
  kwargs.pop("mask", None)
59
158
  kwargs.pop("annot", None)
@@ -72,8 +171,9 @@ def df_corr(
72
171
  if cluster:
73
172
  # Create a clustermap
74
173
  cluster_obj = sns.clustermap(
75
- correlation_matrix,
76
- mask=mask_array,
174
+ data4heatmap,
175
+ # ax=ax,
176
+ # mask=mask_array,
77
177
  annot=annot,
78
178
  cmap=cmap,
79
179
  fmt=fmt,
@@ -86,18 +186,43 @@ def df_corr(
86
186
  yticklabels=yticklabels, # Show row labels
87
187
  **kwargs, # Pass any additional arguments to sns.clustermap
88
188
  )
189
+ df_row_cluster = pd.DataFrame()
190
+ df_col_cluster = pd.DataFrame()
191
+ if row_cluster:
192
+ from scipy.cluster.hierarchy import linkage, fcluster
193
+ from scipy.spatial.distance import pdist
194
+
195
+ # Compute pairwise distances
196
+ distances = pdist(data, metric="euclidean")
197
+ # Perform hierarchical clustering
198
+ linkage_matrix = linkage(distances, method="average")
199
+ # Get cluster assignments based on the distance threshold
200
+ row_clusters_value = fcluster(linkage_matrix, t=1.5, criterion="distance")
201
+ df_row_cluster["row_cluster"] = row_clusters_value
202
+ if col_cluster:
203
+ col_distances = pdist(
204
+ data4heatmap.T, metric="euclidean"
205
+ ) # Transpose for column clustering
206
+ col_linkage_matrix = linkage(col_distances, method="average")
207
+ col_clusters_value = fcluster(
208
+ col_linkage_matrix, t=1.5, criterion="distance"
209
+ )
210
+ df_col_cluster = pd.DataFrame(
211
+ {"Cluster": col_clusters_value}, index=data4heatmap.columns
212
+ )
89
213
 
90
214
  return (
91
215
  cluster_obj.ax_row_dendrogram,
92
216
  cluster_obj.ax_col_dendrogram,
93
217
  cluster_obj.ax_heatmap,
218
+ df_row_cluster,
219
+ df_col_cluster,
94
220
  )
95
221
  else:
96
222
  # Create a standard heatmap
97
- plt.figure(figsize=figsize)
98
223
  ax = sns.heatmap(
99
- correlation_matrix,
100
- mask=mask_array,
224
+ data4heatmap,
225
+ ax=ax,
101
226
  annot=annot,
102
227
  cmap=cmap,
103
228
  fmt=fmt,
@@ -107,6 +232,60 @@ def df_corr(
107
232
  return ax
108
233
 
109
234
 
235
+ # !usage: py2ls.plot.heatmap()
236
+ # penguins_clean = penguins.replace([np.inf, -np.inf], np.nan).dropna()
237
+ # from py2ls import plot
238
+
239
+ # _, axs = plt.subplots(2, 2, figsize=(10, 10))
240
+ # # kind='pivot'
241
+ # plot.heatmap(
242
+ # ax=axs[0][0],
243
+ # data=sns.load_dataset("glue"),
244
+ # kind="pi",
245
+ # index="Model",
246
+ # columns="Task",
247
+ # values="Score",
248
+ # fmt=".1f",
249
+ # cbar_kws=dict(shrink=1),
250
+ # annot_kws=dict(size=7),
251
+ # )
252
+ # # kind='direct'
253
+ # plot.heatmap(
254
+ # ax=axs[0][1],
255
+ # data=sns.load_dataset("penguins").iloc[:10, 2:6],
256
+ # kind="direct",
257
+ # tri="lower",
258
+ # fmt=".1f",
259
+ # k=1,
260
+ # cbar_kws=dict(shrink=1),
261
+ # annot_kws=dict(size=7),
262
+ # )
263
+
264
+ # # kind='corr'
265
+ # plot.heatmap(
266
+ # ax=axs[1][0],
267
+ # data=sns.load_dataset("penguins"),
268
+ # kind="corr",
269
+ # fmt=".1f",
270
+ # k=-1,
271
+ # cbar_kws=dict(shrink=1),
272
+ # annot_kws=dict(size=7),
273
+ # )
274
+ # # kind='direct' with cluster=True (clustermap)
275
+ # plot.heatmap(
276
+ # ax=axs[1][1],
277
+ # data=penguins_clean.iloc[:15, :10],
278
+ # kind="direct",
279
+ # tri="lower",
280
+ # fmt=".1f",
281
+ # k=1,
282
+ # annot=False,
283
+ # cluster=True,
284
+ # cbar_kws=dict(shrink=1),
285
+ # annot_kws=dict(size=7),
286
+ # )
287
+
288
+
110
289
  def catplot(data, *args, **kwargs):
111
290
  """
112
291
  catplot(data, opt=None, ax=None)
@@ -1524,6 +1703,10 @@ def figsets(*args, **kwargs):
1524
1703
  alignment='left')
1525
1704
  )
1526
1705
  """
1706
+ import matplotlib
1707
+
1708
+ matplotlib.rc("text", usetex=False)
1709
+
1527
1710
  fig = plt.gcf()
1528
1711
  fontsize = 11
1529
1712
  fontname = "Arial"
@@ -1615,6 +1798,16 @@ def figsets(*args, **kwargs):
1615
1798
  if isinstance(value, list):
1616
1799
  loc = []
1617
1800
  for i in value:
1801
+ ax.tick_params(
1802
+ axis="both",
1803
+ which="both",
1804
+ bottom=False,
1805
+ top=False,
1806
+ left=False,
1807
+ right=False,
1808
+ labelbottom=False,
1809
+ labelleft=False,
1810
+ )
1618
1811
  if ("l" in i.lower()) and ("a" not in i.lower()):
1619
1812
  ax.yaxis.set_ticks_position("left")
1620
1813
  if "r" in i.lower():
@@ -1624,12 +1817,38 @@ def figsets(*args, **kwargs):
1624
1817
  if "b" in i.lower():
1625
1818
  ax.xaxis.set_ticks_position("bottom")
1626
1819
  if i.lower() in ["a", "both", "all", "al", ":"]:
1627
- ax.xaxis.set_ticks_position("both")
1628
- ax.yaxis.set_ticks_position("both")
1820
+ ax.tick_params(
1821
+ axis="both", # Apply to both axes
1822
+ which="both", # Apply to both major and minor ticks
1823
+ bottom=True, # Show ticks at the bottom
1824
+ top=True, # Show ticks at the top
1825
+ left=True, # Show ticks on the left
1826
+ right=True, # Show ticks on the right
1827
+ labelbottom=True, # Show labels at the bottom
1828
+ labelleft=True, # Show labels on the left
1829
+ )
1629
1830
  if i.lower() in ["xnone", "xoff", "none"]:
1630
- ax.xaxis.set_ticks_position("none")
1831
+ ax.tick_params(
1832
+ axis="x",
1833
+ which="both",
1834
+ bottom=False,
1835
+ top=False,
1836
+ left=False,
1837
+ right=False,
1838
+ labelbottom=False,
1839
+ labelleft=False,
1840
+ )
1631
1841
  if i.lower() in ["ynone", "yoff", "none"]:
1632
- ax.yaxis.set_ticks_position("none")
1842
+ ax.tick_params(
1843
+ axis="y",
1844
+ which="both",
1845
+ bottom=False,
1846
+ top=False,
1847
+ left=False,
1848
+ right=False,
1849
+ labelbottom=False,
1850
+ labelleft=False,
1851
+ )
1633
1852
  # ticks / labels
1634
1853
  elif "x" in key.lower():
1635
1854
  if value is None:
@@ -1674,6 +1893,10 @@ def figsets(*args, **kwargs):
1674
1893
 
1675
1894
  if "bo" in key in key: # box setting, and ("p" in key or "l" in key):
1676
1895
  if isinstance(value, (str, list)):
1896
+ # locations = ["left", "right", "top", "bottom"]
1897
+ # for loc, spi in ax.spines.items():
1898
+ # if loc in locations:
1899
+ # spi.set_color("none") # no spine
1677
1900
  locations = []
1678
1901
  for i in value:
1679
1902
  if "l" in i.lower() and not "t" in i.lower():
@@ -1689,12 +1912,12 @@ def figsets(*args, **kwargs):
1689
1912
  locations.append(x)
1690
1913
  for x in ["left", "right", "top", "bottom"]
1691
1914
  ]
1692
- for i in value:
1693
- if i.lower() in "none":
1694
- locations = []
1915
+ if "none" in value:
1916
+ locations = [] # hide all
1695
1917
  # check spines
1696
1918
  for loc, spi in ax.spines.items():
1697
1919
  if loc in locations:
1920
+ # spi.set_color("k")
1698
1921
  spi.set_position(("outward", 0))
1699
1922
  else:
1700
1923
  spi.set_color("none") # no spine
@@ -2527,3 +2750,496 @@ def thumbnail(dir_img_list: list, figsize=(10, 10), dpi=100, show=False, usage=F
2527
2750
  plt.tight_layout()
2528
2751
  if show:
2529
2752
  plt.show()
2753
+
2754
+
2755
def get_params_from_func_usage(function_signature):
    """
    Extract keyword names from a function signature / usage string.

    Every word that is immediately followed by ``=`` is collected, in order of
    appearance. Words directly preceded by ``**`` are skipped.

    Parameters:
        function_signature (str): Text such as ``"f(a=1, b=2, **kwargs)"``.

    Returns:
        list[str]: The keyword names found (``["a", "b"]`` for the example).
    """
    # Negative look-behind rejects names glued to '**'; the single capture
    # group is the identifier that sits right before '='.
    name_before_equals = r"(?<!\*\*)\b(\w+)="
    return [
        found.group(1)
        for found in re.finditer(name_before_equals, function_signature)
    ]
2761
+
2762
+
2763
def plot_xy(
    data: pd.DataFrame = None,
    x=None,
    y=None,
    ax=None,
    kind: str = None,  # Specify the kind of plot
    usage=False,
    # kws_figsets:dict=None,
    **kwargs,
):
    """
    Draw one or several plots of ``data`` selected by ``kind``.

    e.g., plot_xy(data=data_log, x="Component_1", y="Component_2", hue="Cluster", kind="scatter")

    Parameters:
        data (pd.DataFrame): DataFrame containing the data.
        x (str): Column name for the x-axis.
        y (str): Column name for the y-axis.
        ax: Matplotlib axes to draw on. When None, ``plt.gca()`` is used.
        kind (str | list[str]): Plot kind(s). Each entry is resolved with
            ``strcmp`` against the keys of the ``usages_sns.json`` table
            (presumably a closest-match lookup - confirm against ``strcmp``).
            When None, only the usage help is printed.
        usage (bool): If True, print the stored settings for the requested
            kinds plus an example call, and return None without plotting.
        **kwargs:
            * any key containing ``"figset"``: dict forwarded to ``figsets``
              after each plot;
            * ``kws_<name>`` (e.g. ``kws_scatter``, ``kws_rug``): options for
              one specific kind. When the matching ``kws_<name>`` is absent,
              the remaining ``kwargs`` themselves are used for that kind;
            * ``hue`` (scatterplot): column used for color grouping.

    Returns:
        ``ax`` when only axes-level plots were drawn, ``(g, ax)`` when a
        figure-level plot (jointplot, lmplot, catplot_sns, displot) or the
        local ``catplot`` produced an object ``g``, or None in usage mode.
    """

    def _data_file(filename):
        # Prefer a copy that lives next to this module (presumably shipped as
        # package data - confirm); otherwise keep the original location so
        # existing setups continue to work.
        packaged = os.path.join(
            os.path.dirname(os.path.abspath(__file__)), "data", filename
        )
        if os.path.isfile(packaged):
            return packaged
        return "/Users/macjianfeng/Dropbox/github/python/py2ls/py2ls/data/" + filename

    default_settings = fload(_data_file("usages_sns.json"))
    sns_info = pd.DataFrame(fload(_data_file("sns_info.json")))
    valid_kinds = list(default_settings.keys())

    def _describe(kind_name):
        # Last column of the first sns_info row whose "Functions" entry
        # contains kind_name; None when the kind is not listed there
        # (the unguarded lookup used to raise IndexError in that case).
        hits = sns_info[sns_info["Functions"].str.contains(kind_name, regex=False)]
        return hits.iloc[0, -1] if len(hits) else None

    if kind is None:
        usage = True  # nothing to plot: fall back to printing the help
    else:
        if isinstance(kind, str):
            kind = [kind]
        kind = [strcmp(i, valid_kinds)[0] for i in kind]

    if usage:
        for k in kind or []:
            if k in valid_kinds:
                print(f"{k}:\n\t{default_settings[k]}")
                detail = _describe(k)
                if detail is not None:
                    print(detail)
                print()
        usage_str = """plot_xy(data=ranked_genes,
        x="log2(fold_change)",
        y="-log10(p-value)",
        palette=get_color(3, cmap="coolwarm"),
        kind=["scatter","rug"],
        kws_rug=dict(height=0.2),
        kws_scatter=dict(s=20, color=get_color(3)[2]),
        usage=0)
        """
        print(f"currently support to plot:\n{valid_kinds}\n\nusage:\n{usage_str}")
        return  # Do not plot, just print the usage

    # Pull the figsets options (first key containing "figset") out of kwargs
    # so they are not forwarded to the plotting functions.
    kws_figsets = {}
    for k_arg in list(kwargs):
        if "figset" in k_arg:
            kws_figsets = kwargs.pop(k_arg) or {}
            break

    sns_with_col = [
        "catplot",
        "histplot",
        "relplot",
        "lmplot",
        "pairplot",
        "displot",
        "kdeplot",
    ]
    # kind -> (kwargs key, seaborn function name, pass y as well as x?)
    figure_level = {
        "jointplot": ("kws_joint", "jointplot", True),
        "lmplot": ("kws_lm", "lmplot", True),
        "catplot_sns": ("kws_cat", "catplot", True),
        "displot": ("kws_dis", "displot", False),
    }
    # kind (== seaborn function name) -> (kwargs key, pass y as well as x?)
    axes_level = {
        "histplot": ("kws_hist", False),
        "kdeplot": ("kws_kde", False),
        "ecdfplot": ("kws_ecdf", False),
        "rugplot": ("kws_rug", False),
        "stripplot": ("kws_strip", True),
        "swarmplot": ("kws_swarm", True),
        "boxplot": ("kws_box", True),
        "violinplot": ("kws_violin", True),
        "boxenplot": ("kws_boxen", True),
        "pointplot": ("kws_point", True),
        "barplot": ("kws_bar", True),
        "countplot": ("kws_count", False),
        "regplot": ("kws_reg", True),
        "residplot": ("kws_resid", True),
        "lineplot": ("kws_line", True),
    }

    unset = object()  # marks "no figure-level object was produced"
    g = unset

    for k in kind:
        # indicate 'col' features
        col = kwargs.get("col", None)
        if col is not None and k not in sns_with_col:
            print(
                f"tips:\n'{k}' has no 'col' param, you could try with {sns_with_col}"
            )

        # (1) figure-level plots: these create their own figure and return a grid
        if k in figure_level:
            kws_key, func_name, with_y = figure_level[k]
            # kwargs.pop(key, kwargs): use the dedicated dict when given,
            # otherwise the shared kwargs (same object, as before).
            kws = kwargs.pop(kws_key, kwargs)
            xy = {"x": x, "y": y} if with_y else {"x": x}
            g = getattr(sns, func_name)(data=data, **xy, **kws)

        # (2) axes-level plots
        if ax is None:
            ax = plt.gca()

        if k == "catplot":
            kws_cat = kwargs.pop("kws_cat", kwargs)
            g = catplot(data=data, x=x, y=y, ax=ax, **kws_cat)
        elif k == "stdshade":
            kws_stdshade = kwargs.pop("kws_stdshade", kwargs)
            # Forward the selected options; previously the dedicated dict was
            # popped and then ignored in favor of the leftover kwargs.
            ax = stdshade(ax=ax, **kws_stdshade)
        elif k == "scatterplot":
            kws_scatter = kwargs.pop("kws_scatter", kwargs)
            # 'hue' is not a named parameter of plot_xy, so it has to be read
            # from the options (it was referenced without being defined).
            hue = kws_scatter.pop("hue", None)
            palette = kws_scatter.pop(
                "palette",
                (
                    sns.color_palette("tab10", data[hue].nunique())
                    if hue is not None
                    else sns.color_palette("tab10")
                ),
            )
            s = kws_scatter.pop("s", 10)
            alpha = kws_scatter.pop("alpha", 0.7)
            ax = sns.scatterplot(
                ax=ax,
                data=data,
                x=x,
                y=y,
                hue=hue,
                palette=palette,
                s=s,
                alpha=alpha,
                **kws_scatter,
            )
        elif k in axes_level:
            kws_key, with_y = axes_level[k]
            kws = kwargs.pop(kws_key, kwargs)
            xy = {"x": x, "y": y} if with_y else {"x": x}
            # residplot is always drawn with a lowess curve
            fixed = {"lowess": True} if k == "residplot" else {}
            ax = getattr(sns, k)(data=data, ax=ax, **xy, **fixed, **kws)

        figsets(**kws_figsets)
        print(k, " ⤵ ")
        print(default_settings[k])
        detail = _describe(k)
        if detail is not None:
            print("=>\t", detail)
        print()

    if g is not unset and ax is not None:
        return g, ax
    return ax
2962
+
2963
+
2964
def volcano(
    data,
    x,
    y,
    gene_col=None,
    top_genes=5,
    thr_x=np.log2(1.5),
    thr_y=-np.log10(0.05),
    colors=("#e70b0b", "#0d26e3", "#b8bbbe"),
    s=20,
    fill=True,  # plot filled scatter
    facecolor="none",
    edgecolor="none",
    edgelinewidth=0.5,
    alpha=0.8,
    legend=False,
    ax=None,
    usage=False,
    kws_arrow=None,
    kws_text=None,
    **kwargs,
):
    """
    Generates a customizable scatter plot (e.g., volcano plot).

    Parameters:
    -----------
    data : pd.DataFrame
        The DataFrame containing the data to plot. It is not modified; the
        per-point color column is added to an internal copy.
    x : str
        Column name for x-axis values (e.g., log2FoldChange).
    y : str
        Column name for y-axis values (e.g., -log10(FDR)).
    gene_col : str, optional
        Column name for gene names. If provided, the top genes are labeled.
        Default is None (no labels).
    top_genes : int, optional
        Number of genes to label per significant group (highest y first).
        Default is 5.
    thr_x : float, optional
        Threshold for x-axis values, applied symmetrically (+/-).
        Default is np.log2(1.5).
    thr_y : float, optional
        Threshold for y-axis values (e.g., significance threshold).
        Default is -np.log10(0.05).
    colors : tuple, optional
        Three colors used as (x < -thr_x, neutral, x > thr_x); the two outer
        colors apply only to points with y > thr_y.
        NOTE(review): the default tuple is ordered red/blue/gray, so with this
        mapping neutral points are drawn in blue - consider reordering the
        default; it is left untouched here to keep existing plots unchanged.
    s : int, optional
        Size of points in the plot. Default is 20.
    fill : bool, optional
        True: points are filled with their group color and outlined with
        ``edgecolor``. False: points are filled with ``facecolor`` and
        outlined with their group color.
    facecolor, edgecolor : optional
        See ``fill``.
    edgelinewidth : float, optional
        Line width of the point outlines. Default is 0.5.
    alpha : float, optional
        Transparency of the points. Default is 0.8.
    legend : bool, optional
        Whether to show a legend. Default is False.
    ax : matplotlib axes, optional
        Axes to draw on. Default is the current axes.
    usage : bool, optional
        If True, print an example call and return None without plotting.
    kws_arrow : dict, optional
        Arrow options for the label connectors: ``style``, ``color``, ``lw``,
        ``shrinkA``, ``shrinkB``; remaining entries go into ``arrowprops``.
    kws_text : dict, optional
        Label options: ``fontname``, ``color``, ``fontsize``. Passing a list
        as ``color`` colors each label like its point group.
    **kwargs
        A key containing "figset" is forwarded to ``figsets``; everything
        else is forwarded to ``sns.scatterplot``.

    Returns:
    --------
    The axes the plot was drawn on (None in usage mode).
    """
    usage_str = """
    _, axs = plt.subplots(1, 1, figsize=(4, 5))
    volcano(
        ax=axs,
        data=ranked_genes,
        x="log2(fold_change)",
        y="-log10(p-value)",
        gene_col="ID_REF",
        top_genes=6,
        thr_x=np.log2(1.2),
        # thr_y=-np.log10(0.05),
        colors=("#00BFFF", "#9d9a9a", "#FF3030"),
        fill=0,
        alpha=1,
        facecolor="none",
        s=20,
        edgelinewidth=0.5,
        edgecolor="0.5",
        kws_text=dict(fontsize=10, color="k"),
        kws_arrow=dict(style="-", color="k", lw=0.5),
        # usage=True,
        figsets=dict(ylim=[0, 10], title="df"),
    )
    """
    if usage:
        print(usage_str)
        return

    # Separate the figsets options from the scatterplot options.
    kws_figsets = {}
    for k_arg in list(kwargs):
        if "figset" in k_arg:
            kws_figsets = kwargs.pop(k_arg) or {}
            break

    color_down, color_neutral, color_up = colors[0], colors[1], colors[2]

    # Color-coding based on thresholds; assign() returns a new frame so the
    # caller's DataFrame does not silently gain a "color" column.
    significant = data[y] > thr_y
    data = data.assign(
        color=np.where(
            (data[x] > thr_x) & significant,
            color_up,
            np.where((data[x] < -thr_x) & significant, color_down, color_neutral),
        )
    )

    # Top points per significant group for labeling. The neutral group is the
    # one colored with colors[1] above, so that is the one to exclude.
    sele_gene = (
        data[data["color"] != color_neutral]
        .sort_values(y, ascending=False)
        .groupby("color", sort=False)
        .head(top_genes)
    )
    palette = {colors[0]: colors[0], colors[1]: colors[1], colors[2]: colors[2]}

    # Plot setup
    if ax is None:
        ax = plt.gca()

    # Handle fill parameter
    if fill:
        facecolors = data["color"]  # Fill with the group colors
        edgecolors = edgecolor
    else:
        facecolors = facecolor  # Uniform face, group colors on the outline
        edgecolors = data["color"]

    ax = sns.scatterplot(
        ax=ax,
        data=data,
        x=x,
        y=y,
        # hue="color",
        palette=palette,
        s=s,
        linewidths=edgelinewidth,
        color=facecolors,
        edgecolor=edgecolors,
        alpha=alpha,
        legend=legend,
        **kwargs,
    )

    # Threshold lines go on the axes that were plotted on, which is not
    # necessarily the current pyplot axes.
    ax.axhline(y=thr_y, color="black", linestyle="--")
    ax.axvline(x=-thr_x, color="black", linestyle="--")
    ax.axvline(x=thr_x, color="black", linestyle="--")

    # Add gene labels for selected significant points
    if gene_col:
        # Only needed for labeling, so plots without labels work without it.
        from adjustText import adjust_text

        # Work on copies: the defaults are None and the caller's dicts must
        # not be emptied by the pops below.
        kws_text = dict(kws_text or {})
        kws_arrow = dict(kws_arrow or {})

        fontname = kws_text.pop("fontname", "Arial")
        textcolor = kws_text.pop("color", "k")
        fontsize = kws_text.pop("fontsize", 10)
        # A list as text color means "use the color of the point's group";
        # decided once so every label (not just the first) follows the rule.
        color_by_group = isinstance(textcolor, list)

        texts = []
        for gx, gy, label in zip(sele_gene[x], sele_gene[y], sele_gene[gene_col]):
            if color_by_group:
                label_color = color_up if gx > 0 else color_down
            else:
                label_color = textcolor
            texts.append(
                ax.text(
                    x=gx,
                    y=gy,
                    s=label,
                    fontdict={
                        "fontsize": fontsize,
                        "color": label_color,
                        "fontname": fontname,
                    },
                )
            )

        arrowstyles = [
            "-",
            "->",
            "-[",
            "|->",
            "<-",
            "<->",
            "<|-",
            "<|-|>",
            "-|>",
            "-[ ",
            "fancy",
            "simple",
            "wedge",
        ]
        arrowstyle = kws_arrow.pop("style", "-")
        arrowcolor = kws_arrow.pop("color", "0.5")
        arrowlinewidth = kws_arrow.pop("lw", 0.5)
        shrinkA = kws_arrow.pop("shrinkA", 5)
        shrinkB = kws_arrow.pop("shrinkB", 5)
        arrowstyle = strcmp(arrowstyle, arrowstyles)[0]
        adjust_text(
            texts,
            ax=ax,
            expand_text=(1.05, 1.2),
            arrowprops=dict(
                arrowstyle=arrowstyle,
                color=arrowcolor,
                lw=arrowlinewidth,
                shrinkA=shrinkA,
                shrinkB=shrinkB,
                **kws_arrow,
            ),
        )

    figsets(**kws_figsets)
    return ax
3159
+
3160
+
3161
def sns_func_info(dir_save=None):
    """
    Write a small catalog of seaborn functions to ``sns_info.json``.

    The saved object is a dict with three equally long lists:
    ``"Functions"`` (function name), ``"Category"`` and ``"Detail"``
    (one-sentence description).

    Parameters:
        dir_save (str, optional): Target folder. When None, a fixed folder is
            chosen depending on whether ``get_os()`` contains ``"mac"``.
            A trailing "/" is appended when missing.
    """
    # One row per function: (name, category, description).
    catalog = [
        (
            "relplot",
            "relational",
            "A figure-level function for creating scatter plots and line plots. It combines the functionality of scatterplot and lineplot.",
        ),
        (
            "scatterplot",
            "relational",
            "A function for creating scatter plots, useful for visualizing the relationship between two continuous variables.",
        ),
        (
            "lineplot",
            "relational",
            "A function for drawing line plots, often used to visualize trends over time or ordered categories.",
        ),
        (
            "lmplot",
            "relational",
            "A figure-level function for creating linear model plots, combining regression lines with scatter plots.",
        ),
        (
            "catplot",
            "categorical",
            "A figure-level function for creating categorical plots, which can display various types of plots like box plots, violin plots, and bar plots in one function.",
        ),
        (
            "stripplot",
            "categorical",
            "A function for creating a scatter plot where one of the variables is categorical, helping visualize distribution along a categorical axis.",
        ),
        (
            "boxplot",
            "categorical",
            "A function for creating box plots, which summarize the distribution of a continuous variable based on a categorical variable.",
        ),
        (
            "violinplot",
            "categorical",
            "A function for creating violin plots, which combine box plots and KDEs to visualize the distribution of data.",
        ),
        (
            "boxenplot",
            "categorical",
            "A function for creating boxen plots, an enhanced version of box plots that better represent data distributions with more quantiles.",
        ),
        (
            "pointplot",
            "categorical",
            "A function for creating point plots, which show the mean (or another estimator) of a variable for each level of a categorical variable.",
        ),
        (
            "barplot",
            "categorical",
            "A function for creating bar plots, which represent the mean (or other estimators) of a variable with bars, typically used with categorical data.",
        ),
        (
            "countplot",
            "categorical",
            "A function for creating count plots, which show the counts of observations in each categorical bin.",
        ),
        (
            "displot",
            "distribution",
            "A figure-level function that creates distribution plots. It can visualize histograms, KDEs, and ECDFs, making it versatile for analyzing the distribution of data.",
        ),
        (
            "histplot",
            "distribution",
            "A function for creating histograms, useful for showing the frequency distribution of a continuous variable.",
        ),
        (
            "kdeplot",
            "distribution",
            "A function for creating kernel density estimate (KDE) plots, which visualize the probability density function of a continuous variable.",
        ),
        (
            "ecdfplot",
            "distribution",
            "A function for creating empirical cumulative distribution function (ECDF) plots, which show the proportion of observations below a certain value.",
        ),
        (
            "rugplot",
            "distribution",
            "A function that adds a rug plot to the axes, representing individual data points along an axis.",
        ),
        (
            "regplot",
            "regression",
            "A function for creating regression plots, which fit and visualize a regression model on scatter data.",
        ),
        (
            "residplot",
            "regression",
            "A function for creating residual plots, useful for diagnosing the fit of a regression model.",
        ),
        (
            "pairplot",
            "grid-based(fig)",
            "A figure-level function that creates a grid of scatter plots for each pair of variables in a dataset, often used for exploratory data analysis.",
        ),
        (
            "jointplot",
            "grid-based(fig)",
            "A figure-level function that combines scatter plots and histograms (or KDEs) to visualize the relationship between two variables and their distributions.",
        ),
        (
            "plotting_context",
            "context",
            "Not a plot itself, but a function that allows you to change the context (style and scaling) of your plots to fit different publication requirements or visual preferences.",
        ),
    ]
    # Transpose the rows into the three column lists that get saved.
    sns_info = {
        "Functions": [row[0] for row in catalog],
        "Category": [row[1] for row in catalog],
        "Detail": [row[2] for row in catalog],
    }

    if dir_save is None:
        on_mac = "mac" in get_os()
        if on_mac:
            dir_save = "/Users/macjianfeng/Dropbox/github/python/py2ls/py2ls/data/"
        else:
            dir_save = "Z:\\Jianfeng\\temp\\"
    if not dir_save.endswith("/"):
        dir_save = dir_save + "/"

    target = dir_save + "sns_info.json"
    fsave(target, sns_info)