py2ls 0.2.4.7__py3-none-any.whl → 0.2.4.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
py2ls/plot.py CHANGED
@@ -1,25 +1,20 @@
1
- import matplotlib.pyplot as plt
2
1
  import numpy as np
3
- import pandas as pd
4
- from matplotlib.colors import to_rgba
5
- from scipy.stats import gaussian_kde
6
- import seaborn as sns
2
+ import pandas as pd
3
+ import matplotlib.pyplot as plt
7
4
  import matplotlib
8
- import matplotlib.ticker as tck
9
- from cycler import cycler
5
+ import seaborn as sns
6
+ import warnings
10
7
  import logging
11
- import os
12
- import re
13
8
  from typing import Union
14
- from .ips import fsave, fload, mkdir, listdir, figsave, strcmp, unique, get_os, ssplit,flatten,plt_font
9
+ from .ips import isa, fsave, fload, mkdir, listdir, figsave, strcmp, unique, get_os, ssplit,flatten,plt_font,run_once_within
15
10
  from .stats import *
16
- from .netfinder import get_soup, fetch
17
-
11
+ import os
18
12
  # Suppress INFO messages from fontTools
19
13
  logging.getLogger("fontTools").setLevel(logging.ERROR)
20
14
  logging.getLogger('matplotlib').setLevel(logging.ERROR)
21
15
 
22
-
16
+ warnings.simplefilter("ignore", category=pd.errors.SettingWithCopyWarning)
17
+ warnings.filterwarnings("ignore", category=pd.errors.PerformanceWarning)
23
18
 
24
19
  def add_text(ax=None, height_offset=0.5, fmt=".1f", **kwargs):
25
20
  """Adds text annotations for various types of Seaborn and Matplotlib plots.
@@ -452,7 +447,9 @@ def catplot(data, *args, **kwargs):
452
447
  Args:
453
448
  data (array): data matrix
454
449
  """
455
-
450
+ from matplotlib.colors import to_rgba
451
+ import os
452
+
456
453
  def plot_bars(data, data_m, opt_b, xloc, ax, label=None):
457
454
  if "l" in opt_b["loc"]:
458
455
  xloc_s = xloc - opt_b["x_dist"]
@@ -757,6 +754,7 @@ def catplot(data, *args, **kwargs):
757
754
  label=label[i] if label else None,
758
755
  )
759
756
  else:
757
+ from scipy.stats import gaussian_kde
760
758
  kde = gaussian_kde(ys, bw_method=opt_v["BandWidth"])
761
759
  min_val, max_val = ys.min(), ys.max()
762
760
  y_vals = np.linspace(min_val, max_val, opt_v["NumPoints"])
@@ -1814,53 +1812,9 @@ def read_mplstyle(style_file):
1814
1812
  return style_dict
1815
1813
 
1816
1814
 
1817
- def figsets(*args, **kwargs):
1818
- """
1819
- usage:
1820
- figsets(ax=axs[1],
1821
- ylim=[0, 10],
1822
- spine=2,
1823
- xticklabel=['wake','sleep'],
1824
- yticksdddd=np.arange(0,316,60),
1825
- labels_loc=['right','top'],
1826
- ticks=dict(
1827
- ax='x',
1828
- which='minor',
1829
- direction='out',
1830
- width=2,
1831
- length=2,
1832
- c_tick='m',
1833
- pad=5,
1834
- label_size=11),
1835
- grid=dict(which='minor',
1836
- ax='x',
1837
- alpha=.4,
1838
- c='b',
1839
- ls='-.',
1840
- lw=0.75,
1841
- ),
1842
- supertitleddddd=f'sleep druations\n(min)',
1843
- c_spine='r',
1844
- minor_ticks='xy',
1845
- style='paper',
1846
- box=['right','bottom'],
1847
- xrot=-45,
1848
- yangle=20,
1849
- font_sz = 12,
1850
- legend=dict(labels=['group_a','group_b'],
1851
- loc='upper left',
1852
- edgecolor='k',
1853
- facecolor='r',
1854
- title='title',
1855
- fancybox=1,
1856
- shadow=1,
1857
- ncols=4,
1858
- bbox_to_anchor=[-0.5,0.7],
1859
- alignment='left')
1860
- )
1861
- """
1815
+ def figsets(*args, **kwargs):
1862
1816
  import matplotlib
1863
-
1817
+ from cycler import cycler
1864
1818
  matplotlib.rc("text", usetex=False)
1865
1819
 
1866
1820
  fig = plt.gcf()
@@ -1897,7 +1851,16 @@ def figsets(*args, **kwargs):
1897
1851
  nonlocal fontsize, fontname
1898
1852
  if ("fo" in key) and (("size" in key) or ("sz" in key)):
1899
1853
  fontsize = value
1900
- plt.rcParams.update({"font.size": fontsize})
1854
+ plt.rcParams.update({"font.size": fontsize,
1855
+ "figure.titlesize":fontsize,
1856
+ "axes.titlesize":fontsize,
1857
+ "axes.labelsize": fontsize,
1858
+ "xtick.labelsize": fontsize,
1859
+ "ytick.labelsize": fontsize,
1860
+ "legend.fontsize": fontsize,
1861
+ "legend.title_fontsize":fontsize
1862
+ })
1863
+
1901
1864
  # Customize tick labels
1902
1865
  ax.tick_params(axis='both', which='major', labelsize=fontsize)
1903
1866
  for label in ax.get_xticklabels() + ax.get_yticklabels():
@@ -2146,6 +2109,7 @@ def figsets(*args, **kwargs):
2146
2109
  # ])
2147
2110
 
2148
2111
  if "mi" in key.lower() and "tic" in key.lower(): # minor_ticks
2112
+ import matplotlib.ticker as tck
2149
2113
  if "x" in value.lower() or "x" in key.lower():
2150
2114
  ax.xaxis.set_minor_locator(tck.AutoMinorLocator()) # ax.minorticks_on()
2151
2115
  if "y" in value.lower() or "y" in key.lower():
@@ -2255,24 +2219,77 @@ def figsets(*args, **kwargs):
2255
2219
  key = args[ip * 2].lower()
2256
2220
  value = args[ip * 2 + 1]
2257
2221
  set_step_2(ax, key, value)
2258
- colors = [
2259
- "#474747",
2260
- "#FF2C00",
2261
- "#0C5DA5",
2262
- "#845B97",
2263
- "#58BBCC",
2264
- "#FF9500",
2265
- "#D57DBE",
2266
- ]
2222
+
2223
+ colors = get_color(8)
2267
2224
  matplotlib.rcParams["axes.prop_cycle"] = cycler(color=colors)
2268
2225
  if len(fig.get_axes()) > 1:
2269
2226
  plt.tight_layout()
2270
- plt.gcf().align_labels()
2271
2227
 
2228
+ def split_legend(ax, n=2, loc=None, title=None, bbox=None,ncol=1, **kwargs):
2229
+ """
2230
+ split_legend(
2231
+ ax,
2232
+ n=2,
2233
+ loc=["upper left", "lower right"],
2234
+ labelcolor="k",
2235
+ fontsize=6,
2236
+ )
2237
+ """
2238
+ # Retrieve all lines and labels from the axis
2239
+ handles, labels = ax.get_legend_handles_labels()
2240
+ num_labels = len(labels)
2241
+
2242
+ # Calculate the number of labels per legend part
2243
+ labels_per_part = (num_labels + n - 1) // n # Round up
2244
+ # Create a list to hold each legend object
2245
+ legends = []
2246
+
2247
+ # Default locations and titles if not specified
2248
+ if loc is None:
2249
+ loc = ['best'] * n
2250
+ if title is None:
2251
+ title = [None] * n
2252
+ if bbox is None:
2253
+ bbox = [None] * n
2254
+
2255
+ # Loop to create each split legend
2256
+ for i in range(n):
2257
+ # Calculate the range of labels for this part
2258
+ start_idx = i * labels_per_part
2259
+ end_idx = min(start_idx + labels_per_part, num_labels)
2260
+
2261
+ # Skip if no labels in this range
2262
+ if start_idx >= end_idx:
2263
+ break
2272
2264
 
2273
- from cycler import cycler
2274
-
2265
+ # Subset handles and labels
2266
+ part_handles = handles[start_idx:end_idx]
2267
+ part_labels = labels[start_idx:end_idx]
2268
+
2269
+ # Create the legend for this part
2270
+ legend = ax.legend(handles=part_handles,
2271
+ labels=part_labels,
2272
+ loc=loc[i],
2273
+ title=title[i],
2274
+ ncol=ncol,
2275
+ bbox_to_anchor=bbox[i],
2276
+ **kwargs)
2277
+
2278
+ # Add the legend to the axis and save it to the list
2279
+ ax.add_artist(legend) if i !=(n-1) else None # the lastone will be added automaticaly
2280
+ legends.append(legend)
2281
+ return legends
2275
2282
 
2283
+ def get_colors(
2284
+ n: int = 1,
2285
+ cmap: str = "auto",
2286
+ by: str = "start",
2287
+ alpha: float = 1.0,
2288
+ output: str = "hue",
2289
+ *args,
2290
+ **kwargs,
2291
+ ):
2292
+ return get_color(n,cmap,alpha,output,*args,**kwargs)
2276
2293
  def get_color(
2277
2294
  n: int = 1,
2278
2295
  cmap: str = "auto",
@@ -2282,6 +2299,7 @@ def get_color(
2282
2299
  *args,
2283
2300
  **kwargs,
2284
2301
  ):
2302
+ from cycler import cycler
2285
2303
  def cmap2hex(cmap_name):
2286
2304
  cmap_ = matplotlib.pyplot.get_cmap(cmap_name)
2287
2305
  colors = [cmap_(i) for i in range(cmap_.N)]
@@ -2329,61 +2347,86 @@ def get_color(
2329
2347
  return "#{:02X}{:02X}{:02X}{:02X}".format(r, g, b, a)
2330
2348
  else:
2331
2349
  return "#{:02X}{:02X}{:02X}".format(r, g, b)
2332
-
2350
+
2351
+ # sc.pl.palettes.default_20
2352
+ cmap_20= ['#1f77b4','#ff7f0e','#279e68','#d62728','#aa40fc','#8c564b','#e377c2','#b5bd61',
2353
+ '#17becf','#aec7e8','#ffbb78','#98df8a','#ff9896','#c5b0d5','#c49c94','#f7b6d2',
2354
+ '#dbdb8d','#9edae5','#ad494a','#8c6d31']
2355
+ # sc.pl.palettes.zeileis_28
2356
+ cmap_28 = ["#023fa5","#7d87b9","#bec1d4","#d6bcc0","#bb7784","#8e063b","#4a6fe3","#8595e1",
2357
+ "#b5bbe3","#e6afb9","#e07b91","#d33f6a","#11c638","#8dd593","#c6dec7","#ead3c6",
2358
+ "#f0b98d","#ef9708","#0fcfc0","#9cded6","#d5eae7","#f3e1eb","#f6c4e1","#f79cd4",
2359
+ "#7f7f7f","#c7c7c7","#1CE6FF","#336600"]
2333
2360
  if cmap == "gray":
2334
2361
  cmap = "grey"
2362
+ elif cmap=="20":
2363
+ cmap=cmap_20
2364
+ elif cmap=="28":
2365
+ cmap=cmap_28
2335
2366
  # Determine color list based on cmap parameter
2336
- if "aut" in cmap:
2337
- if n == 1:
2338
- colorlist = ["#3A4453"]
2339
- elif n == 2:
2340
- colorlist = ["#3A4453", "#FBAF63"]
2341
- elif n == 3:
2342
- colorlist = ["#3A4453", "#FBAF63", "#299D8F"]
2343
- elif n == 4:
2344
- colorlist = ["#087cf7", "#FBAF63", "#3C898A","#FF2C00"]
2345
- elif n == 5:
2346
- colorlist = ["#459AA9", "#B25E9D", "#4B8C3B","#EF8632","#FF2C00"]
2347
- elif n == 6:
2348
- colorlist = ["#459AA9", "#B25E9D", "#4B8C3B","#EF8632", "#24578E","#FF2C00"]
2349
- elif n==7:
2350
- colorlist = [ "#7F7F7F", "#459AA9", "#B25E9D", "#4B8C3B","#EF8632", "#24578E" "#FF2C00"]
2351
- else:
2352
- colorlist = ['#FF7F0E','#2CA02C','#D62728','#9467BD','#E377C2','#7F7F7F','#7BB8CC','#06daf2']
2353
-
2354
- by = "start"
2355
- elif any(["cub" in cmap.lower(), "sns" in cmap.lower()]):
2356
- if kwargs:
2357
- colorlist = sns.cubehelix_palette(n, **kwargs)
2358
- else:
2359
- colorlist = sns.cubehelix_palette(
2360
- n, start=0.5, rot=-0.75, light=0.85, dark=0.15, as_cmap=False
2367
+ if isinstance(cmap,str):
2368
+ if "aut" in cmap:
2369
+ if n == 1:
2370
+ colorlist = ["#3A4453"]
2371
+ elif n == 2:
2372
+ colorlist = ["#3A4453", "#FF2C00"]
2373
+ elif n == 3:
2374
+ colorlist = ["#66c2a5","#fc8d62","#8da0cb"]
2375
+ elif n == 4:
2376
+ colorlist = ["#FF2C00","#087cf7", "#FBAF63", "#3C898A"]
2377
+ elif n == 5:
2378
+ colorlist = ["#FF2C00","#459AA9", "#B25E9D", "#4B8C3B","#EF8632"]
2379
+ elif n == 6:
2380
+ colorlist = ["#FF2C00","#91bfdb", "#B25E9D", "#4B8C3B","#EF8632", "#24578E"]
2381
+ elif n==7:
2382
+ colorlist = ["#7F7F7F", "#459AA9", "#B25E9D", "#4B8C3B","#EF8632", "#24578E" "#FF2C00"]
2383
+ elif n==8:
2384
+ # colorlist = ['#1f77b4','#ff7f0e','#367B7F','#51B34F','#d62728','#aa40fc','#e377c2','#17becf']
2385
+ # colorlist = ["#367C7E","#51B34F","#881A11","#E9374C","#EF893C","#010072","#385DCB","#EA43E3"]
2386
+ colorlist = ["#78BFDA","#D52E6F","#F7D648","#A52D28","#6B9F41","#E18330","#E18B9D","#3C88CC"]
2387
+ elif n==9:
2388
+ colorlist = ['#1f77b4','#ff7f0e','#367B7F','#ff9896','#d62728','#aa40fc','#e377c2','#51B34F','#17becf']
2389
+ elif n==10:
2390
+ colorlist = ['#1f77b4','#ff7f0e','#367B7F','#ff9896','#51B34F','#d62728''#aa40fc','#e377c2','#375FD2','#17becf']
2391
+ elif 10<n<=20:
2392
+ colorlist = cmap_20
2393
+ else:
2394
+ colorlist = cmap_28
2395
+ by = "start"
2396
+ elif any(["cub" in cmap.lower(), "sns" in cmap.lower()]):
2397
+ if kwargs:
2398
+ colorlist = sns.cubehelix_palette(n, **kwargs)
2399
+ else:
2400
+ colorlist = sns.cubehelix_palette(
2401
+ n, start=0.5, rot=-0.75, light=0.85, dark=0.15, as_cmap=False
2402
+ )
2403
+ colorlist = [matplotlib.colors.rgb2hex(color) for color in colorlist]
2404
+ elif any(["hls" in cmap.lower(), "hsl" in cmap.lower()]):
2405
+ if kwargs:
2406
+ colorlist = sns.hls_palette(n, **kwargs)
2407
+ else:
2408
+ colorlist = sns.hls_palette(n)
2409
+ colorlist = [matplotlib.colors.rgb2hex(color) for color in colorlist]
2410
+ elif any(["col" in cmap.lower(), "pal" in cmap.lower()]):
2411
+ palette, desat, as_cmap = None, None, False
2412
+ if kwargs:
2413
+ for k, v in kwargs.items():
2414
+ if "p" in k:
2415
+ palette = v
2416
+ elif "d" in k:
2417
+ desat = v
2418
+ elif "a" in k:
2419
+ as_cmap = v
2420
+ colorlist = sns.color_palette(
2421
+ palette=palette, n_colors=n, desat=desat, as_cmap=as_cmap
2361
2422
  )
2362
- colorlist = [matplotlib.colors.rgb2hex(color) for color in colorlist]
2363
- elif any(["hls" in cmap.lower(), "hsl" in cmap.lower()]):
2364
- if kwargs:
2365
- colorlist = sns.hls_palette(n, **kwargs)
2423
+ colorlist = [matplotlib.colors.rgb2hex(color) for color in colorlist]
2366
2424
  else:
2367
- colorlist = sns.hls_palette(n)
2368
- colorlist = [matplotlib.colors.rgb2hex(color) for color in colorlist]
2369
- elif any(["col" in cmap.lower(), "pal" in cmap.lower()]):
2370
- palette, desat, as_cmap = None, None, False
2371
- if kwargs:
2372
- for k, v in kwargs.items():
2373
- if "p" in k:
2374
- palette = v
2375
- elif "d" in k:
2376
- desat = v
2377
- elif "a" in k:
2378
- as_cmap = v
2379
- colorlist = sns.color_palette(
2380
- palette=palette, n_colors=n, desat=desat, as_cmap=as_cmap
2381
- )
2382
- colorlist = [matplotlib.colors.rgb2hex(color) for color in colorlist]
2383
- else:
2384
- if by == "start":
2385
- by = "linspace"
2386
- colorlist = cmap2hex(cmap)
2425
+ if by == "start":
2426
+ by = "linspace"
2427
+ colorlist = cmap2hex(cmap)
2428
+ elif isinstance(cmap, list):
2429
+ colorlist=cmap
2387
2430
 
2388
2431
  # Determine method for generating color list
2389
2432
  if "st" in by.lower() or "be" in by.lower():
@@ -2403,35 +2446,15 @@ def get_color(
2403
2446
  return hue_list
2404
2447
  else:
2405
2448
  raise ValueError("Invalid output type. Choose 'rgb' or 'hue'.")
2406
-
2407
- # # Example usage
2408
- # colors = get_color(n=5, cmap="viridis", by="linear", alpha=0.5,output='rgb')
2409
- # print(colors)
2410
-
2411
-
2412
- """
2413
- # n = 7
2414
- # clist = get_color(n, cmap="auto", by="linspace") # get_color(100)
2415
- # plt.figure(figsize=[8, 5], dpi=100)
2416
- # x = np.linspace(0, 2 * np.pi, 50) * 100
2417
- # y = np.sin(x)
2418
- # for i in range(1, n + 1):
2419
- # plt.plot(x, y + i, c=clist[i - 1], lw=5, label=str(i))
2420
- # plt.legend()
2421
- # plt.ylim(-2, 20)
2422
- # figsets(plt.gca(), {"style": "whitegrid"}) """
2423
-
2424
-
2425
- from scipy.signal import savgol_filter
2426
- import numpy as np
2427
- import matplotlib.pyplot as plt
2428
-
2449
+
2429
2450
 
2430
2451
  def stdshade(ax=None, *args, **kwargs):
2431
2452
  """
2432
2453
  usage:
2433
2454
  plot.stdshade(data_array, c=clist[1], lw=2, ls="-.", alpha=0.2)
2434
2455
  """
2456
+ from scipy.signal import savgol_filter
2457
+
2435
2458
  # Separate kws_line and kws_fill if necessary
2436
2459
  kws_line = kwargs.pop("kws_line", {})
2437
2460
  kws_fill = kwargs.pop("kws_fill", {})
@@ -2937,6 +2960,7 @@ def thumbnail(dir_img_list: list, figsize=(10, 10), dpi=100, show=False, verbose
2937
2960
 
2938
2961
 
2939
2962
  def get_params_from_func_usage(function_signature):
2963
+ import re
2940
2964
  # Regular expression to match parameter names, ignoring '*' and '**kwargs'
2941
2965
  keys_pattern = r"(?<!\*\*)\b(\w+)="
2942
2966
  # Find all matches
@@ -2963,7 +2987,7 @@ def plotxy(
2963
2987
  x=None,
2964
2988
  y=None,
2965
2989
  ax=None,
2966
- kind: Union[str, list] = None, # Specify the kind of plot
2990
+ kind: Union[str, list] = 'scatter', # Specify the kind of plot
2967
2991
  verbose=False,
2968
2992
  **kwargs,
2969
2993
  ):
@@ -2986,34 +3010,30 @@ def plotxy(
2986
3010
  """
2987
3011
  # Check for valid plot kind
2988
3012
  # Default arguments for various plot types
2989
- default_settings = fload(
2990
- "/Users/macjianfeng/Dropbox/github/python/py2ls/py2ls/data/usages_sns.json"
2991
- )
2992
- sns_info = pd.DataFrame(
2993
- fload(
2994
- "/Users/macjianfeng/Dropbox/github/python/py2ls/py2ls/data/sns_info.json",
2995
- )
2996
- )
3013
+ from pathlib import Path
3014
+ # Get the current script's directory as a Path object
3015
+ current_directory = Path(__file__).resolve().parent
3016
+
3017
+ if not 'default_settings' in locals():
3018
+ default_settings = fload(current_directory / 'data' / 'usages_sns.json')
3019
+ if not 'sns_info' in locals():
3020
+ sns_info = pd.DataFrame(fload(current_directory / 'data' / 'sns_info.json'))
3021
+
2997
3022
  valid_kinds = list(default_settings.keys())
2998
- print(valid_kinds)
3023
+ # print(valid_kinds)
2999
3024
  if kind is not None:
3000
3025
  if isinstance(kind, str):
3001
3026
  kind = [kind]
3002
3027
  kind = [strcmp(i, valid_kinds)[0] for i in kind]
3003
3028
  else:
3004
3029
  verbose = True
3030
+
3005
3031
  if verbose:
3006
3032
  if kind is not None:
3007
3033
  for k in kind:
3008
3034
  if k in valid_kinds:
3009
- print(f"{k}:\n\t{default_settings[k]}")
3010
- print(
3011
- sns_info[sns_info["Functions"].str.contains(k)]
3012
- .iloc[:, -1]
3013
- .tolist()[0]
3014
- )
3015
- print()
3016
- usage_str = """plot_xy(data=ranked_genes,
3035
+ print(f"{k}:\n\t{default_settings[k]}")
3036
+ usage_str = """plotxy(data=ranked_genes,
3017
3037
  x="log2(fold_change)",
3018
3038
  y="-log10(p-value)",
3019
3039
  palette=get_color(3, cmap="coolwarm"),
@@ -3037,24 +3057,15 @@ def plotxy(
3037
3057
  kws_add_text = v_arg
3038
3058
  kwargs.pop(k_arg, None)
3039
3059
  break
3040
-
3060
+ zorder=0
3041
3061
  for k in kind:
3062
+ zorder+=1
3042
3063
  # indicate 'col' features
3043
3064
  col = kwargs.get("col", None)
3044
- sns_with_col = [
3045
- "catplot",
3046
- "histplot",
3047
- "relplot",
3048
- "lmplot",
3049
- "pairplot",
3050
- "displot",
3051
- "kdeplot",
3052
- ]
3065
+ sns_with_col = ["catplot","histplot","relplot","lmplot","pairplot","displot","kdeplot"]
3053
3066
  if col is not None:
3054
3067
  if not k in sns_with_col:
3055
- print(
3056
- f"tips:\n'{k}' has no 'col' param, you could try with {sns_with_col}"
3057
- )
3068
+ print(f"tips:\n'{k}' has no 'col' param, you could try with {sns_with_col}")
3058
3069
  # (1) return FcetGrid
3059
3070
  if k == "jointplot":
3060
3071
  kws_joint = kwargs.pop("kws_joint", kwargs)
@@ -3073,7 +3084,6 @@ def plotxy(
3073
3084
  # (2) return axis
3074
3085
  if ax is None:
3075
3086
  ax = plt.gca()
3076
-
3077
3087
  if k == "catplot":
3078
3088
  kws_cat = kwargs.pop("kws_cat", kwargs)
3079
3089
  g = catplot(data=data, x=x, y=y, ax=ax, **kws_cat)
@@ -3082,88 +3092,85 @@ def plotxy(
3082
3092
  ax = stdshade(ax=ax, **kwargs)
3083
3093
  elif k == "scatterplot":
3084
3094
  kws_scatter = kwargs.pop("kws_scatter", kwargs)
3095
+ kws_scatter={k: v for k, v in kws_scatter.items() if not k.startswith("kws_")}
3085
3096
  hue = kwargs.pop("hue", None)
3086
- palette = kws_scatter.pop(
3087
- "palette",
3088
- (
3089
- sns.color_palette("tab20", data[hue].nunique())
3090
- if hue is not None
3091
- else sns.color_palette("tab20")
3092
- ),
3093
- )
3097
+ if isinstance(kws_scatter, dict): # Check if kws_scatter is a dictionary
3098
+ kws_scatter.pop("hue", None) # Safely remove 'hue' if it exists
3099
+ palette = kws_scatter.pop("palette",get_color(data[hue].nunique()) if hue is not None else None)
3094
3100
  s = kws_scatter.pop("s", 10)
3095
3101
  alpha = kws_scatter.pop("alpha", 0.7)
3096
- ax = sns.scatterplot(
3097
- ax=ax,
3098
- data=data,
3099
- x=x,
3100
- y=y,
3101
- hue=hue,
3102
- palette=palette,
3103
- s=s,
3104
- alpha=alpha,
3105
- **kws_scatter,
3106
- )
3102
+ ax = sns.scatterplot(ax=ax,data=data,x=x,y=y,hue=hue,palette=palette,s=s,alpha=alpha,zorder=zorder,**kws_scatter)
3107
3103
  elif k == "histplot":
3108
3104
  kws_hist = kwargs.pop("kws_hist", kwargs)
3109
- ax = sns.histplot(data=data, x=x, ax=ax, **kws_hist)
3105
+ kws_hist={k: v for k, v in kws_hist.items() if not k.startswith("kws_")}
3106
+ ax = sns.histplot(data=data, x=x, ax=ax,zorder=zorder, **kws_hist)
3110
3107
  elif k == "kdeplot":
3111
3108
  kws_kde = kwargs.pop("kws_kde", kwargs)
3112
- ax = sns.kdeplot(data=data, x=x, ax=ax, **kws_kde)
3109
+ kws_kde={k: v for k, v in kws_kde.items() if not k.startswith("kws_")}
3110
+ ax = sns.kdeplot(data=data, x=x, ax=ax,zorder=zorder, **kws_kde)
3113
3111
  elif k == "ecdfplot":
3114
3112
  kws_ecdf = kwargs.pop("kws_ecdf", kwargs)
3115
- ax = sns.ecdfplot(data=data, x=x, ax=ax, **kws_ecdf)
3113
+ kws_ecdf={k: v for k, v in kws_ecdf.items() if not k.startswith("kws_")}
3114
+ ax = sns.ecdfplot(data=data, x=x, ax=ax,zorder=zorder, **kws_ecdf)
3116
3115
  elif k == "rugplot":
3117
3116
  kws_rug = kwargs.pop("kws_rug", kwargs)
3118
- print(kws_rug)
3119
- ax = sns.rugplot(data=data, x=x, ax=ax, **kws_rug)
3117
+ kws_rug={k: v for k, v in kws_rug.items() if not k.startswith("kws_")}
3118
+ ax = sns.rugplot(data=data, x=x, ax=ax,zorder=zorder, **kws_rug)
3120
3119
  elif k == "stripplot":
3121
3120
  kws_strip = kwargs.pop("kws_strip", kwargs)
3121
+ kws_strip={k: v for k, v in kws_strip.items() if not k.startswith("kws_")}
3122
3122
  dodge = kws_strip.pop("dodge", True)
3123
- ax = sns.stripplot(data=data, x=x, y=y, ax=ax, dodge=dodge, **kws_strip)
3123
+ ax = sns.stripplot(data=data, x=x, y=y, ax=ax,zorder=zorder, dodge=dodge, **kws_strip)
3124
3124
  elif k == "swarmplot":
3125
3125
  kws_swarm = kwargs.pop("kws_swarm", kwargs)
3126
- ax = sns.swarmplot(data=data, x=x, y=y, ax=ax, **kws_swarm)
3126
+ kws_swarm={k: v for k, v in kws_swarm.items() if not k.startswith("kws_")}
3127
+ ax = sns.swarmplot(data=data, x=x, y=y, ax=ax,zorder=zorder, **kws_swarm)
3127
3128
  elif k == "boxplot":
3128
3129
  kws_box = kwargs.pop("kws_box", kwargs)
3129
- ax = sns.boxplot(data=data, x=x, y=y, ax=ax, **kws_box)
3130
+ kws_box={k: v for k, v in kws_box.items() if not k.startswith("kws_")}
3131
+ ax = sns.boxplot(data=data, x=x, y=y, ax=ax,zorder=zorder, **kws_box)
3130
3132
  elif k == "violinplot":
3131
3133
  kws_violin = kwargs.pop("kws_violin", kwargs)
3132
- ax = sns.violinplot(data=data, x=x, y=y, ax=ax, **kws_violin)
3134
+ kws_violin={k: v for k, v in kws_violin.items() if not k.startswith("kws_")}
3135
+ ax = sns.violinplot(data=data, x=x, y=y, ax=ax,zorder=zorder, **kws_violin)
3133
3136
  elif k == "boxenplot":
3134
3137
  kws_boxen = kwargs.pop("kws_boxen", kwargs)
3135
- ax = sns.boxenplot(data=data, x=x, y=y, ax=ax, **kws_boxen)
3138
+ kws_boxen={k: v for k, v in kws_boxen.items() if not k.startswith("kws_")}
3139
+ ax = sns.boxenplot(data=data, x=x, y=y, ax=ax,zorder=zorder, **kws_boxen)
3136
3140
  elif k == "pointplot":
3137
3141
  kws_point = kwargs.pop("kws_point", kwargs)
3138
- ax = sns.pointplot(data=data, x=x, y=y, ax=ax, **kws_point)
3142
+ kws_point={k: v for k, v in kws_point.items() if not k.startswith("kws_")}
3143
+ ax = sns.pointplot(data=data, x=x, y=y, ax=ax,zorder=zorder, **kws_point)
3139
3144
  elif k == "barplot":
3140
3145
  kws_bar = kwargs.pop("kws_bar", kwargs)
3141
- ax = sns.barplot(data=data, x=x, y=y, ax=ax, **kws_bar)
3146
+ kws_bar={k: v for k, v in kws_bar.items() if not k.startswith("kws_")}
3147
+ ax = sns.barplot(data=data, x=x, y=y, ax=ax,zorder=zorder, **kws_bar)
3142
3148
  elif k == "countplot":
3143
3149
  kws_count = kwargs.pop("kws_count", kwargs)
3150
+ kws_count={k: v for k, v in kws_count.items() if not k.startswith("kws_")}
3144
3151
  if not kws_count.get("hue",None):
3145
3152
  kws_count.pop("palette",None)
3146
- ax = sns.countplot(data=data, x=x,y=y, ax=ax, **kws_count)
3153
+ ax = sns.countplot(data=data, x=x,y=y, ax=ax,zorder=zorder, **kws_count)
3147
3154
  elif k == "regplot":
3148
3155
  kws_reg = kwargs.pop("kws_reg", kwargs)
3149
- ax = sns.regplot(data=data, x=x, y=y, ax=ax, **kws_reg)
3156
+ kws_reg={k: v for k, v in kws_reg.items() if not k.startswith("kws_")}
3157
+ ax = sns.regplot(data=data, x=x, y=y, ax=ax,zorder=zorder, **kws_reg)
3150
3158
  elif k == "residplot":
3151
3159
  kws_resid = kwargs.pop("kws_resid", kwargs)
3152
- ax = sns.residplot(data=data, x=x, y=y, lowess=True, ax=ax, **kws_resid)
3160
+ kws_resid={k: v for k, v in kws_resid.items() if not k.startswith("kws_")}
3161
+ ax = sns.residplot(data=data, x=x, y=y, lowess=True,zorder=zorder, ax=ax, **kws_resid)
3153
3162
  elif k == "lineplot":
3154
3163
  kws_line = kwargs.pop("kws_line", kwargs)
3155
- ax = sns.lineplot(ax=ax, data=data, x=x, y=y, **kws_line)
3164
+ kws_line={k: v for k, v in kws_line.items() if not k.startswith("kws_")}
3165
+ ax = sns.lineplot(ax=ax, data=data, x=x, y=y,zorder=zorder, **kws_line)
3156
3166
 
3157
3167
  figsets(ax=ax, **kws_figsets)
3158
- print(kws_add_text)
3159
- add_text(ax=ax, **kws_add_text) if kws_add_text else None
3160
- print(k, " ⤵ ")
3161
- print(default_settings[k])
3162
- print(
3163
- "=>\t",
3164
- sns_info[sns_info["Functions"].str.contains(k)].iloc[:, -1].tolist()[0],
3165
- )
3166
- print()
3168
+ if kws_add_text:
3169
+ add_text(ax=ax, **kws_add_text) if kws_add_text else None
3170
+ if run_once_within(60):
3171
+ print(f"\n{k}⤵ ")
3172
+ print(default_settings[k])
3173
+ # print("=>\t",sns_info[sns_info["Functions"].str.contains(k)].iloc[:, -1].tolist()[0],"\n")
3167
3174
  if "g" in locals():
3168
3175
  if ax is not None:
3169
3176
  return g, ax
@@ -3815,4 +3822,275 @@ def venn(
3815
3822
  patch.set_alpha(alpha)
3816
3823
  if 'none' in edgecolor or 0 in linewidth:
3817
3824
  patch.set_edgecolor("none")
3825
+ return ax
3826
+
3827
+ #! subplots, support automatic extend new axis
3828
+ def subplot(rows:int=2,
3829
+ cols:int=2,
3830
+ figsize:Union[tuple,list]=[8, 8],
3831
+ sharex=False,
3832
+ sharey=False,
3833
+ **kwargs):
3834
+ """
3835
+ nexttile = subplot(
3836
+ 8,
3837
+ 2,
3838
+ figsize=(8, 9),
3839
+ sharey=True,
3840
+ sharex=True,
3841
+ )
3842
+
3843
+ for i in range(8):
3844
+ ax = nexttile()
3845
+ x = np.linspace(0, 10, 100) + i
3846
+ ax.plot(x, np.sin(x + i) + i, label=f"Plot {i + 1}")
3847
+ ax.legend()
3848
+ ax.set_title(f"Tile {i + 1}")
3849
+ ax.set_ylabel(f"Tile {i + 1}")
3850
+ ax.set_xlabel(f"Tile {i + 1}")
3851
+ """
3852
+ from matplotlib.gridspec import GridSpec
3853
+ if run_once_within():
3854
+ print(f"usage:\n\tnexttile = subplot(2, 2, figsize=(5, 5), sharex=True, sharey=True)\n\tax = nexttile()")
3855
+ fig = plt.figure(figsize=figsize)
3856
+ grid_spec = GridSpec(rows, cols, figure=fig)
3857
+ occupied = set()
3858
+ row_first_axes = [None] * rows # Track the first axis in each row (for sharey)
3859
+ col_first_axes = [None] * cols # Track the first axis in each column (for sharex)
3860
+
3861
+ def expand_ax():
3862
+ nonlocal rows, grid_spec
3863
+ rows += 1 # Expands by adding a row
3864
+ grid_spec = GridSpec(rows, cols, figure=fig)
3865
+ def nexttile(rowspan=1, colspan=1, **kwargs):
3866
+ nonlocal rows, cols, occupied, grid_spec
3867
+ for row in range(rows):
3868
+ for col in range(cols):
3869
+ if all(
3870
+ (row + r, col + c) not in occupied
3871
+ for r in range(rowspan)
3872
+ for c in range(colspan)
3873
+ ):
3874
+ break
3875
+ else:
3876
+ continue
3877
+ break
3878
+ else:
3879
+ expand_ax()
3880
+ return nexttile(rowspan=rowspan, colspan=colspan, **kwargs)
3881
+
3882
+ sharex_ax,sharey_ax = None, None
3883
+
3884
+ if sharex:
3885
+ sharex_ax = col_first_axes[col]
3886
+
3887
+ if sharey:
3888
+ sharey_ax = row_first_axes[row]
3889
+ ax = fig.add_subplot(
3890
+ grid_spec[row : row + rowspan, col : col + colspan],
3891
+ sharex=sharex_ax,
3892
+ sharey=sharey_ax,
3893
+ **kwargs
3894
+ )
3895
+ if row_first_axes[row] is None:
3896
+ row_first_axes[row] = ax
3897
+ if col_first_axes[col] is None:
3898
+ col_first_axes[col] = ax
3899
+ for r in range(row, row + rowspan):
3900
+ for c in range(col, col + colspan):
3901
+ occupied.add((r, c))
3902
+
3903
+ return ax
3904
+
3905
+ return nexttile
3906
+
3907
+
3908
+ #! radar chart
3909
+ def radar(
3910
+ data: pd.DataFrame,
3911
+ ylim=(0,100),
3912
+ color=get_color(5),
3913
+ fontsize=10,
3914
+ fontcolor='k',
3915
+ size=6,
3916
+ linewidth=1,
3917
+ linestyle="-",
3918
+ alpha=0.5,
3919
+ marker="o",
3920
+ edgecolor='none',
3921
+ edge_linewidth=0,
3922
+ bg_color="0.8",
3923
+ bg_alpha=None,
3924
+ grid_interval_ratio=0.2,
3925
+ title="Radar Chart",
3926
+ cmap=None,
3927
+ legend_loc="upper right",
3928
+ legend_fontsize=10,
3929
+ grid_color="gray",
3930
+ grid_alpha=0.5,
3931
+ grid_linestyle="--",grid_linewidth=0.5,
3932
+ circular: bool = False,
3933
+ tick_fontsize=None,
3934
+ tick_fontcolor="0.65",
3935
+ tick_loc = None,# label position
3936
+ turning = None,
3937
+ ax=None,
3938
+ sp=2,
3939
+ **kwargs
3940
+ ):
3941
+ """
3942
+ Example DATA:
3943
+ df = pd.DataFrame(
3944
+ data=[
3945
+ [80, 80, 80, 80, 80, 80, 80],
3946
+ [90, 20, 95, 95, 30, 30, 80],
3947
+ [60, 90, 20, 20, 100, 90, 50],
3948
+ ],
3949
+ index=["Hero", "Warrior", "Wizard"],
3950
+ columns=["HP", "MP", "ATK", "DEF", "SP.ATK", "SP.DEF", "SPD"])
3951
+
3952
+ Parameters:
3953
+ - data (pd.DataFrame): The data to plot. Each column corresponds to a variable, and each row represents a data point.
3954
+ - ylim (tuple): The limits of the radial axis (y-axis). Default is (0, 100).
3955
+ - color: The color(s) for the plot. Can be a single color or a list of colors.
3956
+ - fontsize (int): Font size for the angular labels (x-axis).
3957
+ - fontcolor (str): Color for the angular labels.
3958
+ - size (int): The size of the markers for each data point.
3959
+ - linewidth (int): Line width for the plot lines.
3960
+ - linestyle (str): Line style for the plot lines.
3961
+ - alpha (float): The transparency level for the filled area.
3962
+ - marker (str): The marker style for the data points.
3963
+ - edgecolor (str): The color for the marker edges.
3964
+ - edge_linewidth (int): Line width for the marker edges.
3965
+ - bg_color (str): Background color for the radar chart.
3966
+ - grid_interval_ratio (float): Determines the intervals for the grid lines as a fraction of the y-limit.
3967
+ - title (str): The title of the radar chart.
3968
+ - cmap (str): The colormap to use if `color` is a list.
3969
+ - legend_loc (str): The location of the legend.
3970
+ - legend_fontsize (int): Font size for the legend.
3971
+ - grid_color (str): Color for the grid lines.
3972
+ - grid_alpha (float): Transparency of the grid lines.
3973
+ - grid_linestyle (str): Style of the grid lines.
3974
+ - grid_linewidth (int): Line width of the grid lines.
3975
+ - circular (bool): If True, use circular grid lines. If False, use spider-style grid lines (straight lines).
3976
+ - tick_fontsize (int): Font size for the radial (y-axis) labels.
3977
+ - tick_fontcolor (str): Font color for the radial (y-axis) labels.
3978
+ - tick_loc (float or None): The location of the radial tick labels (between 0 and 1). If None, it is automatically calculated.
3979
+ - turning (float or None): Rotation of the radar chart. If None, it is not applied.
3980
+ - ax (matplotlib.axes.Axes or None): The axis on which to plot the radar chart. If None, a new axis will be created.
3981
+ - sp (int): Padding for the ticks from the plot area.
3982
+ - **kwargs: Additional arguments for customization.
3983
+ """
3984
+ if circular:
3985
+ from matplotlib.colors import to_rgba
3986
+ kws_figsets = {}
3987
+ for k_arg, v_arg in kwargs.items():
3988
+ if "figset" in k_arg:
3989
+ kws_figsets = v_arg
3990
+ kwargs.pop(k_arg, None)
3991
+ break
3992
+ categories = list(data.columns)
3993
+ num_vars = len(categories)
3994
+
3995
+ # Set up angle for each category on radar chart
3996
+ angles = np.linspace(0, 2 * np.pi, num_vars, endpoint=False).tolist()
3997
+ angles += angles[:1] # Complete the loop to ensure straight-line connections
3998
+
3999
+ # If no axis is provided, create a new one
4000
+ if ax is None:
4001
+ fig, ax = plt.subplots(figsize=(6, 6), subplot_kw=dict(polar=True))
4002
+
4003
+ # bg_color
4004
+ if bg_alpha is None:
4005
+ bg_alpha=alpha
4006
+ ax.set_facecolor(to_rgba(bg_color,alpha=bg_alpha)) if circular else ax.set_facecolor('none')
4007
+ # Set up the radar chart with straight-line connections
4008
+ ax.set_theta_offset(np.pi / 2)
4009
+ ax.set_theta_direction(-1)
4010
+
4011
+ # Draw one axis per variable and add labels
4012
+ ax.set_xticks(angles[:-1])
4013
+ ax.set_xticklabels(categories)
4014
+
4015
+ # Set y-axis limits and grid intervals
4016
+ vmin, vmax = ylim
4017
+ if circular:
4018
+ #* cicular style
4019
+ ax.yaxis.set_ticks(np.arange(vmin, vmax+1, vmax * grid_interval_ratio))
4020
+ ax.grid(axis='both',
4021
+ color=grid_color,
4022
+ linestyle=grid_linestyle,
4023
+ alpha=grid_alpha,
4024
+ linewidth=grid_linewidth,
4025
+ dash_capstyle='round',
4026
+ dash_joinstyle='round',
4027
+ )
4028
+ ax.spines["polar"].set_color(grid_color)
4029
+ ax.spines["polar"].set_linewidth(grid_linewidth)
4030
+ ax.spines["polar"].set_linestyle('-')
4031
+ ax.spines["polar"].set_alpha(grid_alpha)
4032
+ ax.spines["polar"].set_capstyle('round')
4033
+ ax.spines["polar"].set_joinstyle('round')
4034
+
4035
+ else:
4036
+ #* spider style: spider-style grid (straight lines, not circles)
4037
+ # Create the spider-style grid (straight lines, not circles)
4038
+ for i in range(1, int(vmax * grid_interval_ratio) + 1):
4039
+ ax.plot(
4040
+ angles + [angles[0]], # Closing the loop
4041
+ [i * vmax * grid_interval_ratio] * (num_vars+1) + [i * vmax * grid_interval_ratio],
4042
+ color=grid_color, linestyle=grid_linestyle, alpha=grid_alpha,linewidth=grid_linewidth
4043
+ )
4044
+ # set bg_color
4045
+ ax.fill(angles, [vmax]*(data.shape[1]+1), color=bg_color, alpha=bg_alpha)
4046
+ ax.yaxis.grid(False)
4047
+ # Move radial labels away from plotted line
4048
+ if tick_loc is None:
4049
+ tick_loc = np.mean([angles[0],angles[1]])/(2*np.pi)*360 if circular else 0
4050
+
4051
+ ax.set_rlabel_position(tick_loc)
4052
+ ax.set_theta_offset(turning) if turning is not None else None
4053
+ ax.tick_params(axis='x', labelsize=fontsize, colors=fontcolor) # Optional: for angular labels
4054
+ tick_fontsize = fontsize-2 if fontsize is None else tick_fontsize
4055
+ ax.tick_params(axis='y', labelsize=tick_fontsize, colors=tick_fontcolor) # For radial labels
4056
+ if not circular:
4057
+ ax.spines['polar'].set_visible(False)
4058
+ ax.tick_params(axis='x', pad=sp) # move spines outward
4059
+ ax.tick_params(axis='y', pad=sp) # move spines outward
4060
+ # colors
4061
+ colors = get_color(data.shape[0]) if cmap is None else plt.get_cmap(cmap)(np.linspace(0, 1, data.shape[0]))
4062
+ # Plot each row with straight lines
4063
+ for i, (index, row) in enumerate(data.iterrows()):
4064
+ values = row.tolist()
4065
+ values += values[:1] # Close the loop
4066
+ ax.plot(
4067
+ angles,
4068
+ values,
4069
+ color=colors[i],
4070
+ linewidth=linewidth,
4071
+ linestyle=linestyle,
4072
+ label=index,
4073
+ clip_on=False
4074
+ )
4075
+ ax.fill(angles, values, color=colors[i], alpha=alpha)
4076
+
4077
+ ax.set_ylim(ylim)
4078
+ # Add markers for each data point
4079
+ for i, row in enumerate(data.values):
4080
+ ax.plot(
4081
+ angles,
4082
+ list(row) + [row[0]], # Close the loop for markers
4083
+ color=colors[i],
4084
+ marker=marker,
4085
+ markersize=size,
4086
+ markeredgecolor=edgecolor,
4087
+ markeredgewidth = edge_linewidth, zorder=10,clip_on=False
4088
+ )
4089
+ # ax.tick_params(axis='y', labelleft=False, left=False)
4090
+ if 'legend' in kws_figsets:
4091
+ figsets(ax=ax, **kws_figsets)
4092
+ else:
4093
+
4094
+ figsets(ax=ax,legend=dict(loc=legend_loc,fontsize=legend_fontsize,
4095
+ bbox_to_anchor=[1.1,1.4],ncols=2),**kws_figsets)
3818
4096
  return ax