py2ls 0.2.4.10.5__py3-none-any.whl → 0.2.4.10.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
py2ls/plot.py CHANGED
@@ -1,21 +1,37 @@
1
1
  import numpy as np
2
- import pandas as pd
2
+ import pandas as pd
3
3
  import matplotlib.pyplot as plt
4
4
  import matplotlib
5
5
  import seaborn as sns
6
6
  import warnings
7
7
  import logging
8
8
  from typing import Union
9
- from .ips import isa, fsave, fload, mkdir, listdir, figsave, strcmp, unique, get_os, ssplit,flatten,plt_font,run_once_within
9
+ from .ips import (
10
+ isa,
11
+ fsave,
12
+ fload,
13
+ mkdir,
14
+ listdir,
15
+ figsave,
16
+ strcmp,
17
+ unique,
18
+ get_os,
19
+ ssplit,
20
+ flatten,
21
+ plt_font,
22
+ run_once_within,
23
+ )
10
24
  from .stats import *
11
25
  import os
26
+
12
27
  # Suppress INFO messages from fontTools
13
28
  logging.getLogger("fontTools").setLevel(logging.ERROR)
14
- logging.getLogger('matplotlib').setLevel(logging.ERROR)
29
+ logging.getLogger("matplotlib").setLevel(logging.ERROR)
15
30
 
16
31
  warnings.simplefilter("ignore", category=pd.errors.SettingWithCopyWarning)
17
32
  warnings.filterwarnings("ignore", category=pd.errors.PerformanceWarning)
18
33
 
34
+
19
35
  def add_text(ax=None, height_offset=0.5, fmt=".1f", **kwargs):
20
36
  """Adds text annotations for various types of Seaborn and Matplotlib plots.
21
37
  Args:
@@ -449,7 +465,7 @@ def catplot(data, *args, **kwargs):
449
465
  """
450
466
  from matplotlib.colors import to_rgba
451
467
  import os
452
-
468
+
453
469
  def plot_bars(data, data_m, opt_b, xloc, ax, label=None):
454
470
  if "l" in opt_b["loc"]:
455
471
  xloc_s = xloc - opt_b["x_dist"]
@@ -755,6 +771,7 @@ def catplot(data, *args, **kwargs):
755
771
  )
756
772
  else:
757
773
  from scipy.stats import gaussian_kde
774
+
758
775
  kde = gaussian_kde(ys, bw_method=opt_v["BandWidth"])
759
776
  min_val, max_val = ys.min(), ys.max()
760
777
  y_vals = np.linspace(min_val, max_val, opt_v["NumPoints"])
@@ -1416,7 +1433,9 @@ def catplot(data, *args, **kwargs):
1416
1433
  style_load = fload(dir_style + style_use + ".json")
1417
1434
  else:
1418
1435
  style_load = fload(
1419
- listdir(dir_style, "json").path.tolist()[style_use]
1436
+ listdir(dir_style, "json", verbose=False).path.tolist()[
1437
+ style_use
1438
+ ]
1420
1439
  )
1421
1440
  style_load = remove_colors_in_dict(style_load)
1422
1441
  opt.update(style_load)
@@ -1812,16 +1831,17 @@ def read_mplstyle(style_file):
1812
1831
  return style_dict
1813
1832
 
1814
1833
 
1815
- def figsets(*args, **kwargs):
1834
+ def figsets(*args, **kwargs):
1816
1835
  import matplotlib
1817
1836
  from cycler import cycler
1837
+
1818
1838
  matplotlib.rc("text", usetex=False)
1819
1839
 
1820
1840
  fig = plt.gcf()
1821
- fontsize = kwargs.get("fontsize",11)
1822
- plt.rcParams["font.size"]=fontsize
1823
- fontname = kwargs.pop("fontname","Arial")
1824
- fontname=plt_font(fontname) # 显示中文
1841
+ fontsize = kwargs.get("fontsize", 11)
1842
+ plt.rcParams["font.size"] = fontsize
1843
+ fontname = kwargs.pop("fontname", "Arial")
1844
+ fontname = plt_font(fontname) # 显示中文
1825
1845
 
1826
1846
  sns_themes = ["white", "whitegrid", "dark", "darkgrid", "ticks"]
1827
1847
  sns_contexts = ["notebook", "talk", "poster"] # now available "paper"
@@ -1851,18 +1871,21 @@ def figsets(*args, **kwargs):
1851
1871
  nonlocal fontsize, fontname
1852
1872
  if ("fo" in key) and (("size" in key) or ("sz" in key)):
1853
1873
  fontsize = value
1854
- plt.rcParams.update({"font.size": fontsize,
1855
- "figure.titlesize":fontsize,
1856
- "axes.titlesize":fontsize,
1857
- "axes.labelsize": fontsize,
1858
- "xtick.labelsize": fontsize,
1859
- "ytick.labelsize": fontsize,
1860
- "legend.fontsize": fontsize,
1861
- "legend.title_fontsize":fontsize
1862
- })
1874
+ plt.rcParams.update(
1875
+ {
1876
+ "font.size": fontsize,
1877
+ "figure.titlesize": fontsize,
1878
+ "axes.titlesize": fontsize,
1879
+ "axes.labelsize": fontsize,
1880
+ "xtick.labelsize": fontsize,
1881
+ "ytick.labelsize": fontsize,
1882
+ "legend.fontsize": fontsize,
1883
+ "legend.title_fontsize": fontsize,
1884
+ }
1885
+ )
1863
1886
 
1864
1887
  # Customize tick labels
1865
- ax.tick_params(axis='both', which='major', labelsize=fontsize)
1888
+ ax.tick_params(axis="both", which="major", labelsize=fontsize)
1866
1889
  for label in ax.get_xticklabels() + ax.get_yticklabels():
1867
1890
  label.set_fontname(fontname)
1868
1891
 
@@ -1910,15 +1933,15 @@ def figsets(*args, **kwargs):
1910
1933
  if ("x" in key.lower()) and (
1911
1934
  "tic" not in key.lower() and "tk" not in key.lower()
1912
1935
  ):
1913
- ax.set_xlabel(value, fontname=fontname,fontsize=fontsize)
1936
+ ax.set_xlabel(value, fontname=fontname, fontsize=fontsize)
1914
1937
  if ("y" in key.lower()) and (
1915
1938
  "tic" not in key.lower() and "tk" not in key.lower()
1916
1939
  ):
1917
- ax.set_ylabel(value, fontname=fontname,fontsize=fontsize)
1940
+ ax.set_ylabel(value, fontname=fontname, fontsize=fontsize)
1918
1941
  if ("z" in key.lower()) and (
1919
1942
  "tic" not in key.lower() and "tk" not in key.lower()
1920
1943
  ):
1921
- ax.set_zlabel(value, fontname=fontname,fontsize=fontsize)
1944
+ ax.set_zlabel(value, fontname=fontname, fontsize=fontsize)
1922
1945
  if key == "xlabel" and isinstance(value, dict):
1923
1946
  ax.set_xlabel(**value)
1924
1947
  if key == "ylabel" and isinstance(value, dict):
@@ -2110,6 +2133,7 @@ def figsets(*args, **kwargs):
2110
2133
 
2111
2134
  if "mi" in key.lower() and "tic" in key.lower(): # minor_ticks
2112
2135
  import matplotlib.ticker as tck
2136
+
2113
2137
  if "x" in value.lower() or "x" in key.lower():
2114
2138
  ax.xaxis.set_minor_locator(tck.AutoMinorLocator()) # ax.minorticks_on()
2115
2139
  if "y" in value.lower() or "y" in key.lower():
@@ -2162,9 +2186,9 @@ def figsets(*args, **kwargs):
2162
2186
  ax.grid(visible=False)
2163
2187
  if "tit" in key.lower():
2164
2188
  if "sup" in key.lower():
2165
- plt.suptitle(value,fontname=fontname,fontsize=fontsize)
2189
+ plt.suptitle(value, fontname=fontname, fontsize=fontsize)
2166
2190
  else:
2167
- ax.set_title(value,fontname=fontname,fontsize=fontsize)
2191
+ ax.set_title(value, fontname=fontname, fontsize=fontsize)
2168
2192
  if key.lower() in ["spine", "adjust", "ad", "sp", "spi", "adj", "spines"]:
2169
2193
  if isinstance(value, bool) or (value in ["go", "do", "ja", "yes"]):
2170
2194
  if value:
@@ -2185,9 +2209,14 @@ def figsets(*args, **kwargs):
2185
2209
  legend = ax.get_legend()
2186
2210
  if legend is not None:
2187
2211
  legend.remove()
2188
- if any(['colorbar' in key.lower(),'cbar' in key.lower()]) and "loc" in key.lower():
2212
+ if (
2213
+ any(["colorbar" in key.lower(), "cbar" in key.lower()])
2214
+ and "loc" in key.lower()
2215
+ ):
2189
2216
  cbar = ax.collections[0].colorbar # Access the colorbar from the plot
2190
- cbar.ax.set_position(value) # [left, bottom, width, height] [0.475, 0.15, 0.04, 0.25]
2217
+ cbar.ax.set_position(
2218
+ value
2219
+ ) # [left, bottom, width, height] [0.475, 0.15, 0.04, 0.25]
2191
2220
 
2192
2221
  for arg in args:
2193
2222
  if isinstance(arg, matplotlib.axes._axes.Axes):
@@ -2225,7 +2254,8 @@ def figsets(*args, **kwargs):
2225
2254
  if len(fig.get_axes()) > 1:
2226
2255
  plt.tight_layout()
2227
2256
 
2228
- def split_legend(ax, n=2, loc=None, title=None, bbox=None,ncol=1, **kwargs):
2257
+
2258
+ def split_legend(ax, n=2, loc=None, title=None, bbox=None, ncol=1, **kwargs):
2229
2259
  """
2230
2260
  split_legend(
2231
2261
  ax,
@@ -2238,15 +2268,15 @@ def split_legend(ax, n=2, loc=None, title=None, bbox=None,ncol=1, **kwargs):
2238
2268
  # Retrieve all lines and labels from the axis
2239
2269
  handles, labels = ax.get_legend_handles_labels()
2240
2270
  num_labels = len(labels)
2241
-
2271
+
2242
2272
  # Calculate the number of labels per legend part
2243
- labels_per_part = (num_labels + n - 1) // n # Round up
2273
+ labels_per_part = (num_labels + n - 1) // n # Round up
2244
2274
  # Create a list to hold each legend object
2245
2275
  legends = []
2246
2276
 
2247
2277
  # Default locations and titles if not specified
2248
2278
  if loc is None:
2249
- loc = ['best'] * n
2279
+ loc = ["best"] * n
2250
2280
  if title is None:
2251
2281
  title = [None] * n
2252
2282
  if bbox is None:
@@ -2257,7 +2287,7 @@ def split_legend(ax, n=2, loc=None, title=None, bbox=None,ncol=1, **kwargs):
2257
2287
  # Calculate the range of labels for this part
2258
2288
  start_idx = i * labels_per_part
2259
2289
  end_idx = min(start_idx + labels_per_part, num_labels)
2260
-
2290
+
2261
2291
  # Skip if no labels in this range
2262
2292
  if start_idx >= end_idx:
2263
2293
  break
@@ -2267,19 +2297,24 @@ def split_legend(ax, n=2, loc=None, title=None, bbox=None,ncol=1, **kwargs):
2267
2297
  part_labels = labels[start_idx:end_idx]
2268
2298
 
2269
2299
  # Create the legend for this part
2270
- legend = ax.legend(handles=part_handles,
2271
- labels=part_labels,
2272
- loc=loc[i],
2273
- title=title[i],
2274
- ncol=ncol,
2275
- bbox_to_anchor=bbox[i],
2276
- **kwargs)
2277
-
2300
+ legend = ax.legend(
2301
+ handles=part_handles,
2302
+ labels=part_labels,
2303
+ loc=loc[i],
2304
+ title=title[i],
2305
+ ncol=ncol,
2306
+ bbox_to_anchor=bbox[i],
2307
+ **kwargs,
2308
+ )
2309
+
2278
2310
  # Add the legend to the axis and save it to the list
2279
- ax.add_artist(legend) if i !=(n-1) else None # the lastone will be added automaticaly
2311
+ (
2312
+ ax.add_artist(legend) if i != (n - 1) else None
2313
+ ) # the lastone will be added automaticaly
2280
2314
  legends.append(legend)
2281
2315
  return legends
2282
2316
 
2317
+
2283
2318
  def get_colors(
2284
2319
  n: int = 1,
2285
2320
  cmap: str = "auto",
@@ -2289,7 +2324,9 @@ def get_colors(
2289
2324
  *args,
2290
2325
  **kwargs,
2291
2326
  ):
2292
- return get_color(n,cmap,alpha,output,*args,**kwargs)
2327
+ return get_color(n, cmap, alpha, output, *args, **kwargs)
2328
+
2329
+
2293
2330
  def get_color(
2294
2331
  n: int = 1,
2295
2332
  cmap: str = "auto",
@@ -2300,6 +2337,7 @@ def get_color(
2300
2337
  **kwargs,
2301
2338
  ):
2302
2339
  from cycler import cycler
2340
+
2303
2341
  def cmap2hex(cmap_name):
2304
2342
  cmap_ = matplotlib.pyplot.get_cmap(cmap_name)
2305
2343
  colors = [cmap_(i) for i in range(cmap_.N)]
@@ -2347,49 +2385,137 @@ def get_color(
2347
2385
  return "#{:02X}{:02X}{:02X}{:02X}".format(r, g, b, a)
2348
2386
  else:
2349
2387
  return "#{:02X}{:02X}{:02X}".format(r, g, b)
2350
-
2388
+
2351
2389
  # sc.pl.palettes.default_20
2352
- cmap_20= ['#1f77b4','#ff7f0e','#279e68','#d62728','#aa40fc','#8c564b','#e377c2','#b5bd61',
2353
- '#17becf','#aec7e8','#ffbb78','#98df8a','#ff9896','#c5b0d5','#c49c94','#f7b6d2',
2354
- '#dbdb8d','#9edae5','#ad494a','#8c6d31']
2390
+ cmap_20 = [
2391
+ "#1f77b4",
2392
+ "#ff7f0e",
2393
+ "#279e68",
2394
+ "#d62728",
2395
+ "#aa40fc",
2396
+ "#8c564b",
2397
+ "#e377c2",
2398
+ "#b5bd61",
2399
+ "#17becf",
2400
+ "#aec7e8",
2401
+ "#ffbb78",
2402
+ "#98df8a",
2403
+ "#ff9896",
2404
+ "#c5b0d5",
2405
+ "#c49c94",
2406
+ "#f7b6d2",
2407
+ "#dbdb8d",
2408
+ "#9edae5",
2409
+ "#ad494a",
2410
+ "#8c6d31",
2411
+ ]
2355
2412
  # sc.pl.palettes.zeileis_28
2356
- cmap_28 = ["#023fa5","#7d87b9","#bec1d4","#d6bcc0","#bb7784","#8e063b","#4a6fe3","#8595e1",
2357
- "#b5bbe3","#e6afb9","#e07b91","#d33f6a","#11c638","#8dd593","#c6dec7","#ead3c6",
2358
- "#f0b98d","#ef9708","#0fcfc0","#9cded6","#d5eae7","#f3e1eb","#f6c4e1","#f79cd4",
2359
- "#7f7f7f","#c7c7c7","#1CE6FF","#336600"]
2413
+ cmap_28 = [
2414
+ "#023fa5",
2415
+ "#7d87b9",
2416
+ "#bec1d4",
2417
+ "#d6bcc0",
2418
+ "#bb7784",
2419
+ "#8e063b",
2420
+ "#4a6fe3",
2421
+ "#8595e1",
2422
+ "#b5bbe3",
2423
+ "#e6afb9",
2424
+ "#e07b91",
2425
+ "#d33f6a",
2426
+ "#11c638",
2427
+ "#8dd593",
2428
+ "#c6dec7",
2429
+ "#ead3c6",
2430
+ "#f0b98d",
2431
+ "#ef9708",
2432
+ "#0fcfc0",
2433
+ "#9cded6",
2434
+ "#d5eae7",
2435
+ "#f3e1eb",
2436
+ "#f6c4e1",
2437
+ "#f79cd4",
2438
+ "#7f7f7f",
2439
+ "#c7c7c7",
2440
+ "#1CE6FF",
2441
+ "#336600",
2442
+ ]
2360
2443
  if cmap == "gray":
2361
2444
  cmap = "grey"
2362
- elif cmap=="20":
2363
- cmap=cmap_20
2364
- elif cmap=="28":
2365
- cmap=cmap_28
2445
+ elif cmap == "20":
2446
+ cmap = cmap_20
2447
+ elif cmap == "28":
2448
+ cmap = cmap_28
2366
2449
  # Determine color list based on cmap parameter
2367
- if isinstance(cmap,str):
2450
+ if isinstance(cmap, str):
2368
2451
  if "aut" in cmap:
2369
2452
  if n == 1:
2370
2453
  colorlist = ["#3A4453"]
2371
2454
  elif n == 2:
2372
2455
  colorlist = ["#3A4453", "#FF2C00"]
2373
2456
  elif n == 3:
2374
- colorlist = ["#66c2a5","#fc8d62","#8da0cb"]
2457
+ colorlist = ["#66c2a5", "#fc8d62", "#8da0cb"]
2375
2458
  elif n == 4:
2376
- colorlist = ["#FF2C00","#087cf7", "#FBAF63", "#3C898A"]
2459
+ colorlist = ["#FF2C00", "#087cf7", "#FBAF63", "#3C898A"]
2377
2460
  elif n == 5:
2378
- colorlist = ["#FF2C00","#459AA9", "#B25E9D", "#4B8C3B","#EF8632"]
2461
+ colorlist = ["#FF2C00", "#459AA9", "#B25E9D", "#4B8C3B", "#EF8632"]
2379
2462
  elif n == 6:
2380
- colorlist = ["#FF2C00","#91bfdb", "#B25E9D", "#4B8C3B","#EF8632", "#24578E"]
2381
- elif n==7:
2382
- colorlist = ["#7F7F7F", "#459AA9", "#B25E9D", "#4B8C3B","#EF8632", "#24578E" "#FF2C00"]
2383
- elif n==8:
2463
+ colorlist = [
2464
+ "#FF2C00",
2465
+ "#91bfdb",
2466
+ "#B25E9D",
2467
+ "#4B8C3B",
2468
+ "#EF8632",
2469
+ "#24578E",
2470
+ ]
2471
+ elif n == 7:
2472
+ colorlist = [
2473
+ "#7F7F7F",
2474
+ "#459AA9",
2475
+ "#B25E9D",
2476
+ "#4B8C3B",
2477
+ "#EF8632",
2478
+ "#24578E" "#FF2C00",
2479
+ ]
2480
+ elif n == 8:
2384
2481
  # colorlist = ['#1f77b4','#ff7f0e','#367B7F','#51B34F','#d62728','#aa40fc','#e377c2','#17becf']
2385
2482
  # colorlist = ["#367C7E","#51B34F","#881A11","#E9374C","#EF893C","#010072","#385DCB","#EA43E3"]
2386
- colorlist = ["#78BFDA","#D52E6F","#F7D648","#A52D28","#6B9F41","#E18330","#E18B9D","#3C88CC"]
2387
- elif n==9:
2388
- colorlist = ['#1f77b4','#ff7f0e','#367B7F','#ff9896','#d62728','#aa40fc','#e377c2','#51B34F','#17becf']
2389
- elif n==10:
2390
- colorlist = ['#1f77b4','#ff7f0e','#367B7F','#ff9896','#51B34F','#d62728''#aa40fc','#e377c2','#375FD2','#17becf']
2391
- elif 10<n<=20:
2392
- colorlist = cmap_20
2483
+ colorlist = [
2484
+ "#78BFDA",
2485
+ "#D52E6F",
2486
+ "#F7D648",
2487
+ "#A52D28",
2488
+ "#6B9F41",
2489
+ "#E18330",
2490
+ "#E18B9D",
2491
+ "#3C88CC",
2492
+ ]
2493
+ elif n == 9:
2494
+ colorlist = [
2495
+ "#1f77b4",
2496
+ "#ff7f0e",
2497
+ "#367B7F",
2498
+ "#ff9896",
2499
+ "#d62728",
2500
+ "#aa40fc",
2501
+ "#e377c2",
2502
+ "#51B34F",
2503
+ "#17becf",
2504
+ ]
2505
+ elif n == 10:
2506
+ colorlist = [
2507
+ "#1f77b4",
2508
+ "#ff7f0e",
2509
+ "#367B7F",
2510
+ "#ff9896",
2511
+ "#51B34F",
2512
+ "#d62728" "#aa40fc",
2513
+ "#e377c2",
2514
+ "#375FD2",
2515
+ "#17becf",
2516
+ ]
2517
+ elif 10 < n <= 20:
2518
+ colorlist = cmap_20
2393
2519
  else:
2394
2520
  colorlist = cmap_28
2395
2521
  by = "start"
@@ -2426,7 +2552,7 @@ def get_color(
2426
2552
  by = "linspace"
2427
2553
  colorlist = cmap2hex(cmap)
2428
2554
  elif isinstance(cmap, list):
2429
- colorlist=cmap
2555
+ colorlist = cmap
2430
2556
 
2431
2557
  # Determine method for generating color list
2432
2558
  if "st" in by.lower() or "be" in by.lower():
@@ -2446,7 +2572,7 @@ def get_color(
2446
2572
  return hue_list
2447
2573
  else:
2448
2574
  raise ValueError("Invalid output type. Choose 'rgb' or 'hue'.")
2449
-
2575
+
2450
2576
 
2451
2577
  def stdshade(ax=None, *args, **kwargs):
2452
2578
  """
@@ -2454,7 +2580,7 @@ def stdshade(ax=None, *args, **kwargs):
2454
2580
  plot.stdshade(data_array, c=clist[1], lw=2, ls="-.", alpha=0.2)
2455
2581
  """
2456
2582
  from scipy.signal import savgol_filter
2457
-
2583
+
2458
2584
  # Separate kws_line and kws_fill if necessary
2459
2585
  kws_line = kwargs.pop("kws_line", {})
2460
2586
  kws_fill = kwargs.pop("kws_fill", {})
@@ -2724,37 +2850,47 @@ def adjust_spines(ax=None, spines=["left", "bottom"], distance=2):
2724
2850
  # cax = fig.add_axes([l + w + pad, b, width, h]) # define cbar Axes
2725
2851
  # return fig.colorbar(im, cax=cax, **kwargs) # draw cbar
2726
2852
 
2727
- def add_colorbar(im,
2728
- cmap="viridis",
2729
- vmin=-1,
2730
- vmax=1,
2731
- orientation='vertical',
2732
- width_ratio=0.05,
2733
- pad_ratio=0.02,
2734
- shrink=1.0,
2735
- **kwargs):
2853
+
2854
+ def add_colorbar(
2855
+ im,
2856
+ cmap="viridis",
2857
+ vmin=-1,
2858
+ vmax=1,
2859
+ orientation="vertical",
2860
+ width_ratio=0.05,
2861
+ pad_ratio=0.02,
2862
+ shrink=1.0,
2863
+ **kwargs,
2864
+ ):
2736
2865
  import matplotlib as mpl
2866
+
2737
2867
  if all([cmap, vmin, vmax]):
2738
2868
  norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
2739
2869
  else:
2740
- norm=False
2870
+ norm = False
2741
2871
  sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
2742
2872
  sm.set_array([])
2743
2873
  l, b, w, h = im.axes.get_position().bounds # position: left, bottom, width, height
2744
- if orientation == 'vertical':
2874
+ if orientation == "vertical":
2745
2875
  width = width_ratio * w
2746
2876
  pad = pad_ratio * w
2747
- cax = im.figure.add_axes([l + w + pad, b, width, h * shrink]) # Right of the image
2877
+ cax = im.figure.add_axes(
2878
+ [l + w + pad, b, width, h * shrink]
2879
+ ) # Right of the image
2748
2880
  else:
2749
2881
  height = width_ratio * h
2750
2882
  pad = pad_ratio * h
2751
- cax = im.figure.add_axes([l, b - height - pad, w * shrink, height]) # Below the image
2752
- cbar=im.figure.colorbar(sm, cax=cax, orientation=orientation, **kwargs)
2753
- return cbar
2883
+ cax = im.figure.add_axes(
2884
+ [l, b - height - pad, w * shrink, height]
2885
+ ) # Below the image
2886
+ cbar = im.figure.colorbar(sm, cax=cax, orientation=orientation, **kwargs)
2887
+ return cbar
2888
+
2754
2889
 
2755
2890
  # Usage:
2756
2891
  # add_colorbar(im, width_ratio=0.03, pad_ratio=0.01, orientation='horizontal', label="PSD (dB)")
2757
2892
 
2893
+
2758
2894
  def generate_xticks_with_gap(x_len, hue_len):
2759
2895
  """
2760
2896
  Generate a concatenated array based on x_len and hue_len,
@@ -2884,6 +3020,7 @@ def style_examples(
2884
3020
  f = listdir(
2885
3021
  "/Users/macjianfeng/Dropbox/github/python/py2ls/.venv/lib/python3.12/site-packages/py2ls/data/styles/",
2886
3022
  kind=".json",
3023
+ verbose=False,
2887
3024
  )
2888
3025
  display(f.sample(2))
2889
3026
  # def style_example(dir_save,)
@@ -2961,6 +3098,7 @@ def thumbnail(dir_img_list: list, figsize=(10, 10), dpi=100, show=False, verbose
2961
3098
 
2962
3099
  def get_params_from_func_usage(function_signature):
2963
3100
  import re
3101
+
2964
3102
  # Regular expression to match parameter names, ignoring '*' and '**kwargs'
2965
3103
  keys_pattern = r"(?<!\*\*)\b(\w+)="
2966
3104
  # Find all matches
@@ -2987,7 +3125,7 @@ def plotxy(
2987
3125
  x=None,
2988
3126
  y=None,
2989
3127
  ax=None,
2990
- kind: Union[str, list] = 'scatter', # Specify the kind of plot
3128
+ kind: Union[str, list] = "scatter", # Specify the kind of plot
2991
3129
  verbose=False,
2992
3130
  **kwargs,
2993
3131
  ):
@@ -3011,14 +3149,15 @@ def plotxy(
3011
3149
  # Check for valid plot kind
3012
3150
  # Default arguments for various plot types
3013
3151
  from pathlib import Path
3152
+
3014
3153
  # Get the current script's directory as a Path object
3015
- current_directory = Path(__file__).resolve().parent
3154
+ current_directory = Path(__file__).resolve().parent
3155
+
3156
+ if not "default_settings" in locals():
3157
+ default_settings = fload(current_directory / "data" / "usages_sns.json")
3158
+ if not "sns_info" in locals():
3159
+ sns_info = pd.DataFrame(fload(current_directory / "data" / "sns_info.json"))
3016
3160
 
3017
- if not 'default_settings' in locals():
3018
- default_settings = fload(current_directory / 'data' / 'usages_sns.json')
3019
- if not 'sns_info' in locals():
3020
- sns_info = pd.DataFrame(fload(current_directory / 'data' / 'sns_info.json'))
3021
-
3022
3161
  valid_kinds = list(default_settings.keys())
3023
3162
  # print(valid_kinds)
3024
3163
  if kind is not None:
@@ -3032,7 +3171,7 @@ def plotxy(
3032
3171
  if kind is not None:
3033
3172
  for k in kind:
3034
3173
  if k in valid_kinds:
3035
- print(f"{k}:\n\t{default_settings[k]}")
3174
+ print(f"{k}:\n\t{default_settings[k]}")
3036
3175
  usage_str = """plotxy(data=ranked_genes,
3037
3176
  x="log2(fold_change)",
3038
3177
  y="-log10(p-value)",
@@ -3057,15 +3196,25 @@ def plotxy(
3057
3196
  kws_add_text = v_arg
3058
3197
  kwargs.pop(k_arg, None)
3059
3198
  break
3060
- zorder=0
3199
+ zorder = 0
3061
3200
  for k in kind:
3062
- zorder+=1
3201
+ zorder += 1
3063
3202
  # indicate 'col' features
3064
3203
  col = kwargs.get("col", None)
3065
- sns_with_col = ["catplot","histplot","relplot","lmplot","pairplot","displot","kdeplot"]
3204
+ sns_with_col = [
3205
+ "catplot",
3206
+ "histplot",
3207
+ "relplot",
3208
+ "lmplot",
3209
+ "pairplot",
3210
+ "displot",
3211
+ "kdeplot",
3212
+ ]
3066
3213
  if col is not None:
3067
3214
  if not k in sns_with_col:
3068
- print(f"tips:\n'{k}' has no 'col' param, you could try with {sns_with_col}")
3215
+ print(
3216
+ f"tips:\n'{k}' has no 'col' param, you could try with {sns_with_col}"
3217
+ )
3069
3218
  # (1) return FcetGrid
3070
3219
  if k == "jointplot":
3071
3220
  kws_joint = kwargs.pop("kws_joint", kwargs)
@@ -3092,77 +3241,98 @@ def plotxy(
3092
3241
  ax = stdshade(ax=ax, **kwargs)
3093
3242
  elif k == "scatterplot":
3094
3243
  kws_scatter = kwargs.pop("kws_scatter", kwargs)
3095
- kws_scatter={k: v for k, v in kws_scatter.items() if not k.startswith("kws_")}
3244
+ kws_scatter = {
3245
+ k: v for k, v in kws_scatter.items() if not k.startswith("kws_")
3246
+ }
3096
3247
  hue = kwargs.pop("hue", None)
3097
3248
  if isinstance(kws_scatter, dict): # Check if kws_scatter is a dictionary
3098
3249
  kws_scatter.pop("hue", None) # Safely remove 'hue' if it exists
3099
- palette = kws_scatter.pop("palette",get_color(data[hue].nunique()) if hue is not None else None)
3250
+ palette = kws_scatter.pop(
3251
+ "palette", get_color(data[hue].nunique()) if hue is not None else None
3252
+ )
3100
3253
  s = kws_scatter.pop("s", 10)
3101
3254
  alpha = kws_scatter.pop("alpha", 0.7)
3102
- ax = sns.scatterplot(ax=ax,data=data,x=x,y=y,hue=hue,palette=palette,s=s,alpha=alpha,zorder=zorder,**kws_scatter)
3255
+ ax = sns.scatterplot(
3256
+ ax=ax,
3257
+ data=data,
3258
+ x=x,
3259
+ y=y,
3260
+ hue=hue,
3261
+ palette=palette,
3262
+ s=s,
3263
+ alpha=alpha,
3264
+ zorder=zorder,
3265
+ **kws_scatter,
3266
+ )
3103
3267
  elif k == "histplot":
3104
3268
  kws_hist = kwargs.pop("kws_hist", kwargs)
3105
- kws_hist={k: v for k, v in kws_hist.items() if not k.startswith("kws_")}
3106
- ax = sns.histplot(data=data, x=x, ax=ax,zorder=zorder, **kws_hist)
3269
+ kws_hist = {k: v for k, v in kws_hist.items() if not k.startswith("kws_")}
3270
+ ax = sns.histplot(data=data, x=x, ax=ax, zorder=zorder, **kws_hist)
3107
3271
  elif k == "kdeplot":
3108
3272
  kws_kde = kwargs.pop("kws_kde", kwargs)
3109
- kws_kde={k: v for k, v in kws_kde.items() if not k.startswith("kws_")}
3110
- ax = sns.kdeplot(data=data, x=x, ax=ax,zorder=zorder, **kws_kde)
3273
+ kws_kde = {k: v for k, v in kws_kde.items() if not k.startswith("kws_")}
3274
+ ax = sns.kdeplot(data=data, x=x, ax=ax, zorder=zorder, **kws_kde)
3111
3275
  elif k == "ecdfplot":
3112
3276
  kws_ecdf = kwargs.pop("kws_ecdf", kwargs)
3113
- kws_ecdf={k: v for k, v in kws_ecdf.items() if not k.startswith("kws_")}
3114
- ax = sns.ecdfplot(data=data, x=x, ax=ax,zorder=zorder, **kws_ecdf)
3277
+ kws_ecdf = {k: v for k, v in kws_ecdf.items() if not k.startswith("kws_")}
3278
+ ax = sns.ecdfplot(data=data, x=x, ax=ax, zorder=zorder, **kws_ecdf)
3115
3279
  elif k == "rugplot":
3116
3280
  kws_rug = kwargs.pop("kws_rug", kwargs)
3117
- kws_rug={k: v for k, v in kws_rug.items() if not k.startswith("kws_")}
3118
- ax = sns.rugplot(data=data, x=x, ax=ax,zorder=zorder, **kws_rug)
3281
+ kws_rug = {k: v for k, v in kws_rug.items() if not k.startswith("kws_")}
3282
+ ax = sns.rugplot(data=data, x=x, ax=ax, zorder=zorder, **kws_rug)
3119
3283
  elif k == "stripplot":
3120
3284
  kws_strip = kwargs.pop("kws_strip", kwargs)
3121
- kws_strip={k: v for k, v in kws_strip.items() if not k.startswith("kws_")}
3285
+ kws_strip = {k: v for k, v in kws_strip.items() if not k.startswith("kws_")}
3122
3286
  dodge = kws_strip.pop("dodge", True)
3123
- ax = sns.stripplot(data=data, x=x, y=y, ax=ax,zorder=zorder, dodge=dodge, **kws_strip)
3287
+ ax = sns.stripplot(
3288
+ data=data, x=x, y=y, ax=ax, zorder=zorder, dodge=dodge, **kws_strip
3289
+ )
3124
3290
  elif k == "swarmplot":
3125
3291
  kws_swarm = kwargs.pop("kws_swarm", kwargs)
3126
- kws_swarm={k: v for k, v in kws_swarm.items() if not k.startswith("kws_")}
3127
- ax = sns.swarmplot(data=data, x=x, y=y, ax=ax,zorder=zorder, **kws_swarm)
3292
+ kws_swarm = {k: v for k, v in kws_swarm.items() if not k.startswith("kws_")}
3293
+ ax = sns.swarmplot(data=data, x=x, y=y, ax=ax, zorder=zorder, **kws_swarm)
3128
3294
  elif k == "boxplot":
3129
3295
  kws_box = kwargs.pop("kws_box", kwargs)
3130
- kws_box={k: v for k, v in kws_box.items() if not k.startswith("kws_")}
3131
- ax = sns.boxplot(data=data, x=x, y=y, ax=ax,zorder=zorder, **kws_box)
3296
+ kws_box = {k: v for k, v in kws_box.items() if not k.startswith("kws_")}
3297
+ ax = sns.boxplot(data=data, x=x, y=y, ax=ax, zorder=zorder, **kws_box)
3132
3298
  elif k == "violinplot":
3133
3299
  kws_violin = kwargs.pop("kws_violin", kwargs)
3134
- kws_violin={k: v for k, v in kws_violin.items() if not k.startswith("kws_")}
3135
- ax = sns.violinplot(data=data, x=x, y=y, ax=ax,zorder=zorder, **kws_violin)
3300
+ kws_violin = {
3301
+ k: v for k, v in kws_violin.items() if not k.startswith("kws_")
3302
+ }
3303
+ ax = sns.violinplot(data=data, x=x, y=y, ax=ax, zorder=zorder, **kws_violin)
3136
3304
  elif k == "boxenplot":
3137
3305
  kws_boxen = kwargs.pop("kws_boxen", kwargs)
3138
- kws_boxen={k: v for k, v in kws_boxen.items() if not k.startswith("kws_")}
3139
- ax = sns.boxenplot(data=data, x=x, y=y, ax=ax,zorder=zorder, **kws_boxen)
3306
+ kws_boxen = {k: v for k, v in kws_boxen.items() if not k.startswith("kws_")}
3307
+ ax = sns.boxenplot(data=data, x=x, y=y, ax=ax, zorder=zorder, **kws_boxen)
3140
3308
  elif k == "pointplot":
3141
3309
  kws_point = kwargs.pop("kws_point", kwargs)
3142
- kws_point={k: v for k, v in kws_point.items() if not k.startswith("kws_")}
3143
- ax = sns.pointplot(data=data, x=x, y=y, ax=ax,zorder=zorder, **kws_point)
3310
+ kws_point = {k: v for k, v in kws_point.items() if not k.startswith("kws_")}
3311
+ ax = sns.pointplot(data=data, x=x, y=y, ax=ax, zorder=zorder, **kws_point)
3144
3312
  elif k == "barplot":
3145
3313
  kws_bar = kwargs.pop("kws_bar", kwargs)
3146
- kws_bar={k: v for k, v in kws_bar.items() if not k.startswith("kws_")}
3147
- ax = sns.barplot(data=data, x=x, y=y, ax=ax,zorder=zorder, **kws_bar)
3314
+ kws_bar = {k: v for k, v in kws_bar.items() if not k.startswith("kws_")}
3315
+ ax = sns.barplot(data=data, x=x, y=y, ax=ax, zorder=zorder, **kws_bar)
3148
3316
  elif k == "countplot":
3149
3317
  kws_count = kwargs.pop("kws_count", kwargs)
3150
- kws_count={k: v for k, v in kws_count.items() if not k.startswith("kws_")}
3151
- if not kws_count.get("hue",None):
3152
- kws_count.pop("palette",None)
3153
- ax = sns.countplot(data=data, x=x,y=y, ax=ax,zorder=zorder, **kws_count)
3318
+ kws_count = {k: v for k, v in kws_count.items() if not k.startswith("kws_")}
3319
+ if not kws_count.get("hue", None):
3320
+ kws_count.pop("palette", None)
3321
+ ax = sns.countplot(data=data, x=x, y=y, ax=ax, zorder=zorder, **kws_count)
3154
3322
  elif k == "regplot":
3155
3323
  kws_reg = kwargs.pop("kws_reg", kwargs)
3156
- kws_reg={k: v for k, v in kws_reg.items() if not k.startswith("kws_")}
3157
- ax = sns.regplot(data=data, x=x, y=y, ax=ax,zorder=zorder, **kws_reg)
3324
+ kws_reg = {k: v for k, v in kws_reg.items() if not k.startswith("kws_")}
3325
+ ax = sns.regplot(data=data, x=x, y=y, ax=ax, zorder=zorder, **kws_reg)
3158
3326
  elif k == "residplot":
3159
3327
  kws_resid = kwargs.pop("kws_resid", kwargs)
3160
- kws_resid={k: v for k, v in kws_resid.items() if not k.startswith("kws_")}
3161
- ax = sns.residplot(data=data, x=x, y=y, lowess=True,zorder=zorder, ax=ax, **kws_resid)
3328
+ kws_resid = {k: v for k, v in kws_resid.items() if not k.startswith("kws_")}
3329
+ ax = sns.residplot(
3330
+ data=data, x=x, y=y, lowess=True, zorder=zorder, ax=ax, **kws_resid
3331
+ )
3162
3332
  elif k == "lineplot":
3163
3333
  kws_line = kwargs.pop("kws_line", kwargs)
3164
- kws_line={k: v for k, v in kws_line.items() if not k.startswith("kws_")}
3165
- ax = sns.lineplot(ax=ax, data=data, x=x, y=y,zorder=zorder, **kws_line)
3334
+ kws_line = {k: v for k, v in kws_line.items() if not k.startswith("kws_")}
3335
+ ax = sns.lineplot(ax=ax, data=data, x=x, y=y, zorder=zorder, **kws_line)
3166
3336
 
3167
3337
  figsets(ax=ax, **kws_figsets)
3168
3338
  if kws_add_text:
@@ -3176,20 +3346,22 @@ def plotxy(
3176
3346
  return g, ax
3177
3347
  return ax
3178
3348
 
3349
+
3179
3350
  def norm_cmap(data, cmap="coolwarm", min_max=[0, 1]):
3180
3351
  norm_ = plt.Normalize(min_max[0], min_max[1])
3181
3352
  colormap = plt.get_cmap(cmap)
3182
3353
  return colormap(norm_(data))
3183
3354
 
3355
+
3184
3356
  def volcano(
3185
- data:pd.DataFrame,
3186
- x:str,
3187
- y:str,
3188
- gene_col:str=None,
3189
- top_genes=[5, 5], # [down-regulated, up-regulated]
3190
- thr_x=np.log2(1.5), # default: 0.585
3357
+ data: pd.DataFrame,
3358
+ x: str,
3359
+ y: str,
3360
+ gene_col: str = None,
3361
+ top_genes=[5, 5], # [down-regulated, up-regulated]
3362
+ thr_x=np.log2(1.5), # default: 0.585
3191
3363
  thr_y=-np.log10(0.05),
3192
- sort_xy="x", #'y', 'xy'
3364
+ sort_xy="x", #'y', 'xy'
3193
3365
  colors=("#00BFFF", "#9d9a9a", "#FF3030"),
3194
3366
  s=20,
3195
3367
  fill=True, # plot filled scatter
@@ -3201,11 +3373,10 @@ def volcano(
3201
3373
  ax=None,
3202
3374
  verbose=False,
3203
3375
  kws_text=dict(fontsize=10, color="k"),
3204
- kws_bbox=dict(facecolor='none',
3205
- alpha=0.5,
3206
- edgecolor='black',
3207
- boxstyle='round,pad=0.3'),# '{}' to hide
3208
- kws_arrow=dict(color="k", lw=0.5),# '{}' to hide
3376
+ kws_bbox=dict(
3377
+ facecolor="none", alpha=0.5, edgecolor="black", boxstyle="round,pad=0.3"
3378
+ ), # '{}' to hide
3379
+ kws_arrow=dict(color="k", lw=0.5), # '{}' to hide
3209
3380
  **kwargs,
3210
3381
  ):
3211
3382
  """
@@ -3275,39 +3446,35 @@ def volcano(
3275
3446
  kws_figsets = v_arg
3276
3447
  kwargs.pop(k_arg, None)
3277
3448
  break
3278
-
3279
- data=data.copy()
3449
+
3450
+ data = data.copy()
3280
3451
  # filter nan
3281
3452
  data = data.dropna(subset=[x, y]) # Drop rows with NaN in x or y
3282
- data.loc[:,"color"] = np.where(
3453
+ data.loc[:, "color"] = np.where(
3283
3454
  (data[x] > thr_x) & (data[y] > thr_y),
3284
3455
  colors[2],
3285
- np.where((data[x] < -thr_x) & (data[y] > thr_y),
3286
- colors[0],
3287
- colors[1]),
3456
+ np.where((data[x] < -thr_x) & (data[y] > thr_y), colors[0], colors[1]),
3288
3457
  )
3289
- top_genes=[top_genes, top_genes] if isinstance(top_genes,int) else top_genes
3290
-
3458
+ top_genes = [top_genes, top_genes] if isinstance(top_genes, int) else top_genes
3459
+
3291
3460
  # could custom how to select the top genes, x: x has priority
3292
- sort_by_x_y=[x,y] if sort_xy=="x" else [y,x]
3293
- ascending_up=[True, True] if sort_xy=="x" else [False, True]
3294
- ascending_down=[False, True] if sort_xy=="x" else [False, False]
3295
-
3296
- down_reg_genes = data[
3297
- (data["color"] == colors[0]) &
3298
- (data[x].abs() > thr_x) &
3299
- (data[y] > thr_y)
3300
- ].sort_values(by=sort_by_x_y, ascending=ascending_up).head(top_genes[0])
3301
- up_reg_genes = data[
3302
- (data["color"] == colors[2]) &
3303
- (data[x].abs() > thr_x) &
3304
- (data[y] > thr_y)
3305
- ].sort_values(by=sort_by_x_y, ascending=ascending_down).head(top_genes[1])
3461
+ sort_by_x_y = [x, y] if sort_xy == "x" else [y, x]
3462
+ ascending_up = [True, True] if sort_xy == "x" else [False, True]
3463
+ ascending_down = [False, True] if sort_xy == "x" else [False, False]
3464
+
3465
+ down_reg_genes = (
3466
+ data[(data["color"] == colors[0]) & (data[x].abs() > thr_x) & (data[y] > thr_y)]
3467
+ .sort_values(by=sort_by_x_y, ascending=ascending_up)
3468
+ .head(top_genes[0])
3469
+ )
3470
+ up_reg_genes = (
3471
+ data[(data["color"] == colors[2]) & (data[x].abs() > thr_x) & (data[y] > thr_y)]
3472
+ .sort_values(by=sort_by_x_y, ascending=ascending_down)
3473
+ .head(top_genes[1])
3474
+ )
3306
3475
  sele_gene = pd.concat([down_reg_genes, up_reg_genes])
3307
-
3308
- palette = {colors[0]: colors[0],
3309
- colors[1]: colors[1],
3310
- colors[2]: colors[2]}
3476
+
3477
+ palette = {colors[0]: colors[0], colors[1]: colors[1], colors[2]: colors[2]}
3311
3478
  # Plot setup
3312
3479
  if ax is None:
3313
3480
  ax = plt.gca()
@@ -3337,9 +3504,9 @@ def volcano(
3337
3504
  )
3338
3505
 
3339
3506
  # Add threshold lines for x and y axes
3340
- ax.axhline(y=thr_y, color="black", linestyle="--",lw=1)
3341
- ax.axvline(x=-thr_x, color="black", linestyle="--",lw=1)
3342
- ax.axvline(x=thr_x, color="black", linestyle="--",lw=1)
3507
+ ax.axhline(y=thr_y, color="black", linestyle="--", lw=1)
3508
+ ax.axvline(x=-thr_x, color="black", linestyle="--", lw=1)
3509
+ ax.axvline(x=thr_x, color="black", linestyle="--", lw=1)
3343
3510
 
3344
3511
  # Add gene labels for selected significant points
3345
3512
  if gene_col:
@@ -3349,19 +3516,28 @@ def volcano(
3349
3516
  textcolor = kws_text.pop("color", "k")
3350
3517
  fontsize = kws_text.pop("fontsize", 10)
3351
3518
  arrowstyles = [
3352
- "->","<-","<->","<|-","-|>","<|-|>",
3353
- "-","-[","-[",
3354
- "fancy","simple","wedge",
3519
+ "->",
3520
+ "<-",
3521
+ "<->",
3522
+ "<|-",
3523
+ "-|>",
3524
+ "<|-|>",
3525
+ "-",
3526
+ "-[",
3527
+ "-[",
3528
+ "fancy",
3529
+ "simple",
3530
+ "wedge",
3355
3531
  ]
3356
3532
  arrowstyle = kws_arrow.pop("style", "<|-")
3357
- arrowstyle = strcmp(arrowstyle, arrowstyles,scorer='strict')[0]
3358
- expand=kws_arrow.pop("expand",(1.05,1.1))
3533
+ arrowstyle = strcmp(arrowstyle, arrowstyles, scorer="strict")[0]
3534
+ expand = kws_arrow.pop("expand", (1.05, 1.1))
3359
3535
  arrowcolor = kws_arrow.pop("color", "0.4")
3360
3536
  arrowlinewidth = kws_arrow.pop("lw", 0.75)
3361
3537
  shrinkA = kws_arrow.pop("shrinkA", 0)
3362
3538
  shrinkB = kws_arrow.pop("shrinkB", 0)
3363
3539
  mutation_scale = kws_arrow.pop("head", 10)
3364
- arrow_fill=kws_arrow.pop("fill", False)
3540
+ arrow_fill = kws_arrow.pop("fill", False)
3365
3541
  for i in range(sele_gene.shape[0]):
3366
3542
  if isinstance(textcolor, list): # be consistant with dots's color
3367
3543
  textcolor = colors[0] if sele_gene[x].iloc[i] > 0 else colors[1]
@@ -3382,7 +3558,7 @@ def volcano(
3382
3558
  adjust_text(
3383
3559
  texts,
3384
3560
  expand=expand,
3385
- min_arrow_len=5,
3561
+ min_arrow_len=5,
3386
3562
  ax=ax,
3387
3563
  arrowprops=dict(
3388
3564
  arrowstyle=arrowstyle,
@@ -3393,7 +3569,7 @@ def volcano(
3393
3569
  shrinkB=shrinkB,
3394
3570
  mutation_scale=mutation_scale,
3395
3571
  **kws_arrow,
3396
- )
3572
+ ),
3397
3573
  )
3398
3574
 
3399
3575
  figsets(**kws_figsets)
@@ -3499,6 +3675,7 @@ def desaturate_color(color, saturation_factor=0.5):
3499
3675
  """Reduce the saturation of a color by a given factor (between 0 and 1)."""
3500
3676
  import matplotlib.colors as mcolors
3501
3677
  import colorsys
3678
+
3502
3679
  # Convert the color to RGB
3503
3680
  rgb = mcolors.to_rgb(color)
3504
3681
  # Convert RGB to HLS (Hue, Lightness, Saturation)
@@ -3508,8 +3685,19 @@ def desaturate_color(color, saturation_factor=0.5):
3508
3685
  # Convert back to RGB
3509
3686
  return colorsys.hls_to_rgb(h, l, s)
3510
3687
 
3511
- def textsets(text, fontname='Arial', fontsize=11, fontweight="normal", fontstyle="normal",
3512
- fontcolor='k', backgroundcolor=None, shadow=False, ha="center", va="center"):
3688
+
3689
+ def textsets(
3690
+ text,
3691
+ fontname="Arial",
3692
+ fontsize=11,
3693
+ fontweight="normal",
3694
+ fontstyle="normal",
3695
+ fontcolor="k",
3696
+ backgroundcolor=None,
3697
+ shadow=False,
3698
+ ha="center",
3699
+ va="center",
3700
+ ):
3513
3701
  if text: # Ensure text exists
3514
3702
  if fontname:
3515
3703
  text.set_fontname(plt_font(fontname))
@@ -3522,26 +3710,28 @@ def textsets(text, fontname='Arial', fontsize=11, fontweight="normal", fontstyle
3522
3710
  if fontcolor:
3523
3711
  text.set_color(fontcolor)
3524
3712
  if backgroundcolor:
3525
- text.set_backgroundcolor(backgroundcolor)
3713
+ text.set_backgroundcolor(backgroundcolor)
3526
3714
  text.set_horizontalalignment(ha)
3527
- text.set_verticalalignment(va)
3715
+ text.set_verticalalignment(va)
3528
3716
  if shadow:
3529
- text.set_path_effects([
3530
- matplotlib.patheffects.withStroke(linewidth=3, foreground="gray")
3531
- ])
3717
+ text.set_path_effects(
3718
+ [matplotlib.patheffects.withStroke(linewidth=3, foreground="gray")]
3719
+ )
3720
+
3721
+
3532
3722
  def venn(
3533
- lists:list,
3534
- labels:list=None,
3723
+ lists: list,
3724
+ labels: list = None,
3535
3725
  ax=None,
3536
3726
  colors=None,
3537
3727
  edgecolor=None,
3538
3728
  alpha=0.5,
3539
- saturation=.75,
3540
- linewidth=0, # default no edge
3729
+ saturation=0.75,
3730
+ linewidth=0, # default no edge
3541
3731
  linestyle="-",
3542
- fontname='Arial',
3732
+ fontname="Arial",
3543
3733
  fontsize=10,
3544
- fontcolor='k',
3734
+ fontcolor="k",
3545
3735
  fontweight="normal",
3546
3736
  fontstyle="normal",
3547
3737
  ha="center",
@@ -3550,18 +3740,18 @@ def venn(
3550
3740
  subset_fontsize=10,
3551
3741
  subset_fontweight="normal",
3552
3742
  subset_fontstyle="normal",
3553
- subset_fontcolor='k',
3743
+ subset_fontcolor="k",
3554
3744
  backgroundcolor=None,
3555
3745
  custom_texts=None,
3556
- show_percentages=True, # display percentage
3746
+ show_percentages=True, # display percentage
3557
3747
  fmt="{:.1%}",
3558
3748
  ellipse_shape=False, # 椭圆形
3559
- ellipse_scale=[1.5, 1], #not perfect, 椭圆形的形状
3560
- **kwargs
3749
+ ellipse_scale=[1.5, 1], # not perfect, 椭圆形的形状
3750
+ **kwargs,
3561
3751
  ):
3562
3752
  """
3563
3753
  Advanced Venn diagram plotting function with extensive customization options.
3564
- Usage:
3754
+ Usage:
3565
3755
  # Define the two sets
3566
3756
  set1 = [1, 2, 3, 4, 5]
3567
3757
  set2 = [4, 5, 6, 7, 8]
@@ -3612,56 +3802,70 @@ def venn(
3612
3802
  """
3613
3803
  if ax is None:
3614
3804
  ax = plt.gca()
3615
- lists=[set(flatten(i, verbose=False)) for i in lists]
3805
+ lists = [set(flatten(i, verbose=False)) for i in lists]
3616
3806
  # Function to apply text styles to labels
3617
3807
  if colors is None:
3618
- colors=["r","b"] if len(lists)==2 else ["r","g","b"]
3808
+ colors = ["r", "b"] if len(lists) == 2 else ["r", "g", "b"]
3619
3809
  if labels is None:
3620
- labels=["set1","set2"] if len(lists)==2 else ["set1","set2","set3"]
3810
+ labels = ["set1", "set2"] if len(lists) == 2 else ["set1", "set2", "set3"]
3621
3811
  if edgecolor is None:
3622
- edgecolor=colors
3812
+ edgecolor = colors
3623
3813
  colors = [desaturate_color(color, saturation) for color in colors]
3624
3814
  # Check colors and auto-calculate overlaps
3625
3815
  if len(lists) == 2:
3816
+
3626
3817
  def get_count_and_percentage(set_count, subset_count):
3627
3818
  percent = subset_count / set_count if set_count > 0 else 0
3628
- return f"{subset_count}\n({fmt.format(percent)})" if show_percentages else f"{subset_count}"
3819
+ return (
3820
+ f"{subset_count}\n({fmt.format(percent)})"
3821
+ if show_percentages
3822
+ else f"{subset_count}"
3823
+ )
3629
3824
 
3630
3825
  from matplotlib_venn import venn2, venn2_circles
3826
+
3631
3827
  # Auto-calculate overlap color for 2-set Venn diagram
3632
3828
  overlap_color = get_color_overlap(colors[0], colors[1]) if colors else None
3633
-
3829
+
3634
3830
  # Draw the venn diagram
3635
3831
  v = venn2(subsets=lists, set_labels=labels, ax=ax, **kwargs)
3636
3832
  venn_circles = venn2_circles(subsets=lists, ax=ax)
3637
- set1,set2=lists[0],lists[1]
3638
- v.get_patch_by_id('10').set_color(colors[0])
3639
- v.get_patch_by_id('01').set_color(colors[1])
3640
- v.get_patch_by_id('11').set_color(get_color_overlap(colors[0], colors[1]) if colors else None)
3833
+ set1, set2 = lists[0], lists[1]
3834
+ v.get_patch_by_id("10").set_color(colors[0])
3835
+ v.get_patch_by_id("01").set_color(colors[1])
3836
+ v.get_patch_by_id("11").set_color(
3837
+ get_color_overlap(colors[0], colors[1]) if colors else None
3838
+ )
3641
3839
  # v.get_label_by_id('10').set_text(len(set1 - set2))
3642
3840
  # v.get_label_by_id('01').set_text(len(set2 - set1))
3643
3841
  # v.get_label_by_id('11').set_text(len(set1 & set2))
3644
- v.get_label_by_id('10').set_text(get_count_and_percentage(len(set1), len(set1 - set2)))
3645
- v.get_label_by_id('01').set_text(get_count_and_percentage(len(set2), len(set2 - set1)))
3646
- v.get_label_by_id('11').set_text(get_count_and_percentage(len(set1 | set2), len(set1 & set2)))
3647
-
3842
+ v.get_label_by_id("10").set_text(
3843
+ get_count_and_percentage(len(set1), len(set1 - set2))
3844
+ )
3845
+ v.get_label_by_id("01").set_text(
3846
+ get_count_and_percentage(len(set2), len(set2 - set1))
3847
+ )
3848
+ v.get_label_by_id("11").set_text(
3849
+ get_count_and_percentage(len(set1 | set2), len(set1 & set2))
3850
+ )
3648
3851
 
3649
- if not isinstance(linewidth,list):
3650
- linewidth=[linewidth]
3651
- if isinstance(linestyle,str):
3652
- linestyle=[linestyle]
3852
+ if not isinstance(linewidth, list):
3853
+ linewidth = [linewidth]
3854
+ if isinstance(linestyle, str):
3855
+ linestyle = [linestyle]
3653
3856
  if not isinstance(edgecolor, list):
3654
- edgecolor=[edgecolor]
3655
- linewidth=linewidth*2 if len(linewidth)==1 else linewidth
3656
- linestyle=linestyle*2 if len(linestyle)==1 else linestyle
3657
- edgecolor=edgecolor*2 if len(edgecolor)==1 else edgecolor
3857
+ edgecolor = [edgecolor]
3858
+ linewidth = linewidth * 2 if len(linewidth) == 1 else linewidth
3859
+ linestyle = linestyle * 2 if len(linestyle) == 1 else linestyle
3860
+ edgecolor = edgecolor * 2 if len(edgecolor) == 1 else edgecolor
3658
3861
  for i in range(2):
3659
3862
  venn_circles[i].set_lw(linewidth[i])
3660
3863
  venn_circles[i].set_ls(linestyle[i])
3661
3864
  venn_circles[i].set_edgecolor(edgecolor[i])
3662
3865
  # 椭圆
3663
3866
  if ellipse_shape:
3664
- import matplotlib.patches as patches
3867
+ import matplotlib.patches as patches
3868
+
3665
3869
  for patch in v.patches:
3666
3870
  patch.set_visible(False) # Hide original patches if using ellipses
3667
3871
  center1 = v.get_circle_center(0)
@@ -3672,9 +3876,13 @@ def venn(
3672
3876
  height=ellipse_scale[1],
3673
3877
  edgecolor=edgecolor[0] if edgecolor else colors[0],
3674
3878
  facecolor=colors[0],
3675
- lw=linewidth if isinstance(linewidth, (int, float)) else 1.0, # Ensure lw is a number
3879
+ lw=(
3880
+ linewidth if isinstance(linewidth, (int, float)) else 1.0
3881
+ ), # Ensure lw is a number
3676
3882
  ls=linestyle[0],
3677
- alpha=alpha if isinstance(alpha, (int, float)) else 0.5 # Ensure alpha is a number
3883
+ alpha=(
3884
+ alpha if isinstance(alpha, (int, float)) else 0.5
3885
+ ), # Ensure alpha is a number
3678
3886
  )
3679
3887
  ellipse2 = patches.Ellipse(
3680
3888
  (center2.x, center2.y),
@@ -3682,48 +3890,78 @@ def venn(
3682
3890
  height=ellipse_scale[1],
3683
3891
  edgecolor=edgecolor[1] if edgecolor else colors[1],
3684
3892
  facecolor=colors[1],
3685
- lw=linewidth if isinstance(linewidth, (int, float)) else 1.0, # Ensure lw is a number
3893
+ lw=(
3894
+ linewidth if isinstance(linewidth, (int, float)) else 1.0
3895
+ ), # Ensure lw is a number
3686
3896
  ls=linestyle[0],
3687
- alpha=alpha if isinstance(alpha, (int, float)) else 0.5 # Ensure alpha is a number
3897
+ alpha=(
3898
+ alpha if isinstance(alpha, (int, float)) else 0.5
3899
+ ), # Ensure alpha is a number
3688
3900
  )
3689
3901
  ax.add_patch(ellipse1)
3690
3902
  ax.add_patch(ellipse2)
3691
3903
  # Apply styles to set labels
3692
3904
  for i, text in enumerate(v.set_labels):
3693
- textsets(text, fontname=fontname, fontsize=fontsize, fontweight=fontweight, fontstyle=fontstyle,
3694
- fontcolor=fontcolor,ha=ha,va=va,shadow=shadow)
3905
+ textsets(
3906
+ text,
3907
+ fontname=fontname,
3908
+ fontsize=fontsize,
3909
+ fontweight=fontweight,
3910
+ fontstyle=fontstyle,
3911
+ fontcolor=fontcolor,
3912
+ ha=ha,
3913
+ va=va,
3914
+ shadow=shadow,
3915
+ )
3695
3916
 
3696
3917
  # Apply styles to subset labels
3697
3918
  for i, text in enumerate(v.subset_labels):
3698
3919
  if text: # Ensure text exists
3699
3920
  if custom_texts: # Custom text handling
3700
- text.set_text(custom_texts[i])
3701
- textsets(text, fontname=fontname, fontsize=subset_fontsize, fontweight=subset_fontweight, fontstyle=subset_fontstyle,
3702
- fontcolor=subset_fontcolor,ha=ha,va=va,shadow=shadow)
3921
+ text.set_text(custom_texts[i])
3922
+ textsets(
3923
+ text,
3924
+ fontname=fontname,
3925
+ fontsize=subset_fontsize,
3926
+ fontweight=subset_fontweight,
3927
+ fontstyle=subset_fontstyle,
3928
+ fontcolor=subset_fontcolor,
3929
+ ha=ha,
3930
+ va=va,
3931
+ shadow=shadow,
3932
+ )
3703
3933
 
3704
3934
  elif len(lists) == 3:
3935
+
3705
3936
  def get_label(set_count, subset_count):
3706
3937
  percent = subset_count / set_count if set_count > 0 else 0
3707
- return f"{subset_count}\n({fmt.format(percent)})" if show_percentages else f"{subset_count}"
3708
-
3938
+ return (
3939
+ f"{subset_count}\n({fmt.format(percent)})"
3940
+ if show_percentages
3941
+ else f"{subset_count}"
3942
+ )
3943
+
3709
3944
  from matplotlib_venn import venn3, venn3_circles
3945
+
3710
3946
  # Auto-calculate overlap colors for 3-set Venn diagram
3711
3947
  colorAB = get_color_overlap(colors[0], colors[1]) if colors else None
3712
3948
  colorAC = get_color_overlap(colors[0], colors[2]) if colors else None
3713
3949
  colorBC = get_color_overlap(colors[1], colors[2]) if colors else None
3714
- colorABC = get_color_overlap(colors[0], colors[1], colors[2]) if colors else None
3715
- set1,set2,set3=lists[0],lists[1],lists[2]
3950
+ colorABC = (
3951
+ get_color_overlap(colors[0], colors[1], colors[2]) if colors else None
3952
+ )
3953
+ set1, set2, set3 = lists[0], lists[1], lists[2]
3716
3954
 
3717
3955
  # Draw the venn diagram
3718
- v = venn3(subsets=lists, set_labels=labels, ax=ax,**kwargs)
3719
- v.get_patch_by_id('100').set_color(colors[0])
3720
- v.get_patch_by_id('010').set_color(colors[1])
3721
- v.get_patch_by_id('001').set_color(colors[2])
3722
- v.get_patch_by_id('110').set_color(colorAB)
3723
- v.get_patch_by_id('101').set_color(colorAC)
3724
- v.get_patch_by_id('011').set_color(colorBC)
3725
- v.get_patch_by_id('111').set_color(colorABC)
3726
-
3956
+ v = venn3(subsets=lists, set_labels=labels, ax=ax, **kwargs)
3957
+ v.get_patch_by_id("100").set_color(colors[0])
3958
+ v.get_patch_by_id("010").set_color(colors[1])
3959
+ v.get_patch_by_id("001").set_color(colors[2])
3960
+ v.get_patch_by_id("110").set_color(colorAB)
3961
+ v.get_patch_by_id("101").set_color(colorAC)
3962
+ v.get_patch_by_id("011").set_color(colorBC)
3963
+ v.get_patch_by_id("111").set_color(colorABC)
3964
+
3727
3965
  # Correctly labeling subset sizes
3728
3966
  # v.get_label_by_id('100').set_text(len(set1 - set2 - set3))
3729
3967
  # v.get_label_by_id('010').set_text(len(set2 - set1 - set3))
@@ -3732,63 +3970,93 @@ def venn(
3732
3970
  # v.get_label_by_id('101').set_text(len(set1 & set3 - set2))
3733
3971
  # v.get_label_by_id('011').set_text(len(set2 & set3 - set1))
3734
3972
  # v.get_label_by_id('111').set_text(len(set1 & set2 & set3))
3735
- v.get_label_by_id('100').set_text(get_label(len(set1), len(set1 - set2 - set3)))
3736
- v.get_label_by_id('010').set_text(get_label(len(set2), len(set2 - set1 - set3)))
3737
- v.get_label_by_id('001').set_text(get_label(len(set3), len(set3 - set1 - set2)))
3738
- v.get_label_by_id('110').set_text(get_label(len(set1 | set2), len(set1 & set2 - set3)))
3739
- v.get_label_by_id('101').set_text(get_label(len(set1 | set3), len(set1 & set3 - set2)))
3740
- v.get_label_by_id('011').set_text(get_label(len(set2 | set3), len(set2 & set3 - set1)))
3741
- v.get_label_by_id('111').set_text(get_label(len(set1 | set2 | set3), len(set1 & set2 & set3)))
3742
-
3973
+ v.get_label_by_id("100").set_text(get_label(len(set1), len(set1 - set2 - set3)))
3974
+ v.get_label_by_id("010").set_text(get_label(len(set2), len(set2 - set1 - set3)))
3975
+ v.get_label_by_id("001").set_text(get_label(len(set3), len(set3 - set1 - set2)))
3976
+ v.get_label_by_id("110").set_text(
3977
+ get_label(len(set1 | set2), len(set1 & set2 - set3))
3978
+ )
3979
+ v.get_label_by_id("101").set_text(
3980
+ get_label(len(set1 | set3), len(set1 & set3 - set2))
3981
+ )
3982
+ v.get_label_by_id("011").set_text(
3983
+ get_label(len(set2 | set3), len(set2 & set3 - set1))
3984
+ )
3985
+ v.get_label_by_id("111").set_text(
3986
+ get_label(len(set1 | set2 | set3), len(set1 & set2 & set3))
3987
+ )
3743
3988
 
3744
3989
  # Apply styles to set labels
3745
3990
  for i, text in enumerate(v.set_labels):
3746
- textsets(text, fontname=fontname, fontsize=fontsize, fontweight=fontweight, fontstyle=fontstyle,
3747
- fontcolor=fontcolor,ha=ha,va=va,shadow=shadow)
3991
+ textsets(
3992
+ text,
3993
+ fontname=fontname,
3994
+ fontsize=fontsize,
3995
+ fontweight=fontweight,
3996
+ fontstyle=fontstyle,
3997
+ fontcolor=fontcolor,
3998
+ ha=ha,
3999
+ va=va,
4000
+ shadow=shadow,
4001
+ )
3748
4002
 
3749
4003
  # Apply styles to subset labels
3750
4004
  for i, text in enumerate(v.subset_labels):
3751
4005
  if text: # Ensure text exists
3752
4006
  if custom_texts: # Custom text handling
3753
4007
  text.set_text(custom_texts[i])
3754
- textsets(text, fontname=fontname, fontsize=subset_fontsize, fontweight=subset_fontweight, fontstyle=subset_fontstyle,
3755
- fontcolor=subset_fontcolor,ha=ha,va=va,shadow=shadow)
4008
+ textsets(
4009
+ text,
4010
+ fontname=fontname,
4011
+ fontsize=subset_fontsize,
4012
+ fontweight=subset_fontweight,
4013
+ fontstyle=subset_fontstyle,
4014
+ fontcolor=subset_fontcolor,
4015
+ ha=ha,
4016
+ va=va,
4017
+ shadow=shadow,
4018
+ )
3756
4019
 
3757
4020
  venn_circles = venn3_circles(subsets=lists, ax=ax)
3758
- if not isinstance(linewidth,list):
3759
- linewidth=[linewidth]
3760
- if isinstance(linestyle,str):
3761
- linestyle=[linestyle]
4021
+ if not isinstance(linewidth, list):
4022
+ linewidth = [linewidth]
4023
+ if isinstance(linestyle, str):
4024
+ linestyle = [linestyle]
3762
4025
  if not isinstance(edgecolor, list):
3763
- edgecolor=[edgecolor]
3764
- linewidth=linewidth*3 if len(linewidth)==1 else linewidth
3765
- linestyle=linestyle*3 if len(linestyle)==1 else linestyle
3766
- edgecolor=edgecolor*3 if len(edgecolor)==1 else edgecolor
4026
+ edgecolor = [edgecolor]
4027
+ linewidth = linewidth * 3 if len(linewidth) == 1 else linewidth
4028
+ linestyle = linestyle * 3 if len(linestyle) == 1 else linestyle
4029
+ edgecolor = edgecolor * 3 if len(edgecolor) == 1 else edgecolor
3767
4030
 
3768
4031
  # edgecolor=[to_rgba(i) for i in edgecolor]
3769
4032
 
3770
- for i in range(3):
4033
+ for i in range(3):
3771
4034
  venn_circles[i].set_lw(linewidth[i])
3772
- venn_circles[i].set_ls(linestyle[i])
3773
- venn_circles[i].set_edgecolor(edgecolor[i])
4035
+ venn_circles[i].set_ls(linestyle[i])
4036
+ venn_circles[i].set_edgecolor(edgecolor[i])
3774
4037
 
3775
- #椭圆形
4038
+ # 椭圆形
3776
4039
  if ellipse_shape:
3777
- import matplotlib.patches as patches
4040
+ import matplotlib.patches as patches
4041
+
3778
4042
  for patch in v.patches:
3779
4043
  patch.set_visible(False) # Hide original patches if using ellipses
3780
4044
  center1 = v.get_circle_center(0)
3781
4045
  center2 = v.get_circle_center(1)
3782
- center3 = v.get_circle_center(2)
4046
+ center3 = v.get_circle_center(2)
3783
4047
  ellipse1 = patches.Ellipse(
3784
4048
  (center1.x, center1.y),
3785
4049
  width=ellipse_scale[0],
3786
4050
  height=ellipse_scale[1],
3787
4051
  edgecolor=edgecolor[0] if edgecolor else colors[0],
3788
4052
  facecolor=colors[0],
3789
- lw=linewidth if isinstance(linewidth, (int, float)) else 1.0, # Ensure lw is a number
4053
+ lw=(
4054
+ linewidth if isinstance(linewidth, (int, float)) else 1.0
4055
+ ), # Ensure lw is a number
3790
4056
  ls=linestyle[0],
3791
- alpha=alpha if isinstance(alpha, (int, float)) else 0.5 # Ensure alpha is a number
4057
+ alpha=(
4058
+ alpha if isinstance(alpha, (int, float)) else 0.5
4059
+ ), # Ensure alpha is a number
3792
4060
  )
3793
4061
  ellipse2 = patches.Ellipse(
3794
4062
  (center2.x, center2.y),
@@ -3796,9 +4064,13 @@ def venn(
3796
4064
  height=ellipse_scale[1],
3797
4065
  edgecolor=edgecolor[1] if edgecolor else colors[1],
3798
4066
  facecolor=colors[1],
3799
- lw=linewidth if isinstance(linewidth, (int, float)) else 1.0, # Ensure lw is a number
4067
+ lw=(
4068
+ linewidth if isinstance(linewidth, (int, float)) else 1.0
4069
+ ), # Ensure lw is a number
3800
4070
  ls=linestyle[0],
3801
- alpha=alpha if isinstance(alpha, (int, float)) else 0.5 # Ensure alpha is a number
4071
+ alpha=(
4072
+ alpha if isinstance(alpha, (int, float)) else 0.5
4073
+ ), # Ensure alpha is a number
3802
4074
  )
3803
4075
  ellipse3 = patches.Ellipse(
3804
4076
  (center3.x, center3.y),
@@ -3806,9 +4078,13 @@ def venn(
3806
4078
  height=ellipse_scale[1],
3807
4079
  edgecolor=edgecolor[1] if edgecolor else colors[1],
3808
4080
  facecolor=colors[1],
3809
- lw=linewidth if isinstance(linewidth, (int, float)) else 1.0, # Ensure lw is a number
4081
+ lw=(
4082
+ linewidth if isinstance(linewidth, (int, float)) else 1.0
4083
+ ), # Ensure lw is a number
3810
4084
  ls=linestyle[0],
3811
- alpha=alpha if isinstance(alpha, (int, float)) else 0.5 # Ensure alpha is a number
4085
+ alpha=(
4086
+ alpha if isinstance(alpha, (int, float)) else 0.5
4087
+ ), # Ensure alpha is a number
3812
4088
  )
3813
4089
  ax.add_patch(ellipse1)
3814
4090
  ax.add_patch(ellipse2)
@@ -3820,17 +4096,20 @@ def venn(
3820
4096
  for patch in v.patches:
3821
4097
  if patch:
3822
4098
  patch.set_alpha(alpha)
3823
- if 'none' in edgecolor or 0 in linewidth:
4099
+ if "none" in edgecolor or 0 in linewidth:
3824
4100
  patch.set_edgecolor("none")
3825
4101
  return ax
3826
4102
 
4103
+
3827
4104
  #! subplots, support automatic extend new axis
3828
- def subplot(rows:int=2,
3829
- cols:int=2,
3830
- figsize:Union[tuple,list]=[8, 8],
3831
- sharex=False,
3832
- sharey=False,
3833
- **kwargs):
4105
+ def subplot(
4106
+ rows: int = 2,
4107
+ cols: int = 2,
4108
+ figsize: Union[tuple, list] = [8, 8],
4109
+ sharex=False,
4110
+ sharey=False,
4111
+ **kwargs,
4112
+ ):
3834
4113
  """
3835
4114
  nexttile = subplot(
3836
4115
  8,
@@ -3850,11 +4129,14 @@ def subplot(rows:int=2,
3850
4129
  ax.set_xlabel(f"Tile {i + 1}")
3851
4130
  """
3852
4131
  from matplotlib.gridspec import GridSpec
4132
+
3853
4133
  if run_once_within():
3854
- print(f"usage:\n\tnexttile = subplot(2, 2, figsize=(5, 5), sharex=True, sharey=True)\n\tax = nexttile()")
4134
+ print(
4135
+ f"usage:\n\tnexttile = subplot(2, 2, figsize=(5, 5), sharex=True, sharey=True)\n\tax = nexttile()"
4136
+ )
3855
4137
  fig = plt.figure(figsize=figsize)
3856
4138
  grid_spec = GridSpec(rows, cols, figure=fig)
3857
- occupied = set()
4139
+ occupied = set()
3858
4140
  row_first_axes = [None] * rows # Track the first axis in each row (for sharey)
3859
4141
  col_first_axes = [None] * cols # Track the first axis in each column (for sharex)
3860
4142
 
@@ -3862,8 +4144,9 @@ def subplot(rows:int=2,
3862
4144
  nonlocal rows, grid_spec
3863
4145
  rows += 1 # Expands by adding a row
3864
4146
  grid_spec = GridSpec(rows, cols, figure=fig)
4147
+
3865
4148
  def nexttile(rowspan=1, colspan=1, **kwargs):
3866
- nonlocal rows, cols, occupied, grid_spec
4149
+ nonlocal rows, cols, occupied, grid_spec
3867
4150
  for row in range(rows):
3868
4151
  for col in range(cols):
3869
4152
  if all(
@@ -3873,29 +4156,29 @@ def subplot(rows:int=2,
3873
4156
  ):
3874
4157
  break
3875
4158
  else:
3876
- continue
4159
+ continue
3877
4160
  break
3878
- else:
4161
+ else:
3879
4162
  expand_ax()
3880
4163
  return nexttile(rowspan=rowspan, colspan=colspan, **kwargs)
3881
4164
 
3882
- sharex_ax,sharey_ax = None, None
4165
+ sharex_ax, sharey_ax = None, None
3883
4166
 
3884
- if sharex:
4167
+ if sharex:
3885
4168
  sharex_ax = col_first_axes[col]
3886
4169
 
3887
- if sharey:
3888
- sharey_ax = row_first_axes[row]
4170
+ if sharey:
4171
+ sharey_ax = row_first_axes[row]
3889
4172
  ax = fig.add_subplot(
3890
4173
  grid_spec[row : row + rowspan, col : col + colspan],
3891
4174
  sharex=sharex_ax,
3892
4175
  sharey=sharey_ax,
3893
- **kwargs
3894
- )
4176
+ **kwargs,
4177
+ )
3895
4178
  if row_first_axes[row] is None:
3896
4179
  row_first_axes[row] = ax
3897
4180
  if col_first_axes[col] is None:
3898
- col_first_axes[col] = ax
4181
+ col_first_axes[col] = ax
3899
4182
  for r in range(row, row + rowspan):
3900
4183
  for c in range(col, col + colspan):
3901
4184
  occupied.add((r, c))
@@ -3907,17 +4190,17 @@ def subplot(rows:int=2,
3907
4190
 
3908
4191
  #! radar chart
3909
4192
  def radar(
3910
- data: pd.DataFrame,
3911
- ylim=(0,100),
4193
+ data: pd.DataFrame,
4194
+ ylim=(0, 100),
3912
4195
  color=get_color(5),
3913
4196
  fontsize=10,
3914
- fontcolor='k',
4197
+ fontcolor="k",
3915
4198
  size=6,
3916
4199
  linewidth=1,
3917
4200
  linestyle="-",
3918
4201
  alpha=0.5,
3919
4202
  marker="o",
3920
- edgecolor='none',
4203
+ edgecolor="none",
3921
4204
  edge_linewidth=0,
3922
4205
  bg_color="0.8",
3923
4206
  bg_alpha=None,
@@ -3928,18 +4211,19 @@ def radar(
3928
4211
  legend_fontsize=10,
3929
4212
  grid_color="gray",
3930
4213
  grid_alpha=0.5,
3931
- grid_linestyle="--",grid_linewidth=0.5,
4214
+ grid_linestyle="--",
4215
+ grid_linewidth=0.5,
3932
4216
  circular: bool = False,
3933
4217
  tick_fontsize=None,
3934
4218
  tick_fontcolor="0.65",
3935
- tick_loc = None,# label position
3936
- turning = None,
4219
+ tick_loc=None, # label position
4220
+ turning=None,
3937
4221
  ax=None,
3938
4222
  sp=2,
3939
- **kwargs
4223
+ **kwargs,
3940
4224
  ):
3941
4225
  """
3942
- Example DATA:
4226
+ Example DATA:
3943
4227
  df = pd.DataFrame(
3944
4228
  data=[
3945
4229
  [80, 80, 80, 80, 80, 80, 80],
@@ -3948,7 +4232,7 @@ def radar(
3948
4232
  ],
3949
4233
  index=["Hero", "Warrior", "Wizard"],
3950
4234
  columns=["HP", "MP", "ATK", "DEF", "SP.ATK", "SP.DEF", "SPD"])
3951
-
4235
+
3952
4236
  Parameters:
3953
4237
  - data (pd.DataFrame): The data to plot. Each column corresponds to a variable, and each row represents a data point.
3954
4238
  - ylim (tuple): The limits of the radial axis (y-axis). Default is (0, 100).
@@ -4002,8 +4286,12 @@ def radar(
4002
4286
 
4003
4287
  # bg_color
4004
4288
  if bg_alpha is None:
4005
- bg_alpha=alpha
4006
- ax.set_facecolor(to_rgba(bg_color,alpha=bg_alpha)) if circular else ax.set_facecolor('none')
4289
+ bg_alpha = alpha
4290
+ (
4291
+ ax.set_facecolor(to_rgba(bg_color, alpha=bg_alpha))
4292
+ if circular
4293
+ else ax.set_facecolor("none")
4294
+ )
4007
4295
  # Set up the radar chart with straight-line connections
4008
4296
  ax.set_theta_offset(np.pi / 2)
4009
4297
  ax.set_theta_direction(-1)
@@ -4012,53 +4300,68 @@ def radar(
4012
4300
  ax.set_xticks(angles[:-1])
4013
4301
  ax.set_xticklabels(categories)
4014
4302
 
4015
- # Set y-axis limits and grid intervals
4303
+ # Set y-axis limits and grid intervals
4016
4304
  vmin, vmax = ylim
4017
- if circular:
4018
- #* cicular style
4019
- ax.yaxis.set_ticks(np.arange(vmin, vmax+1, vmax * grid_interval_ratio))
4020
- ax.grid(axis='both',
4021
- color=grid_color,
4022
- linestyle=grid_linestyle,
4023
- alpha=grid_alpha,
4024
- linewidth=grid_linewidth,
4025
- dash_capstyle='round',
4026
- dash_joinstyle='round',
4027
- )
4028
- ax.spines["polar"].set_color(grid_color)
4305
+ if circular:
4306
+ # * cicular style
4307
+ ax.yaxis.set_ticks(np.arange(vmin, vmax + 1, vmax * grid_interval_ratio))
4308
+ ax.grid(
4309
+ axis="both",
4310
+ color=grid_color,
4311
+ linestyle=grid_linestyle,
4312
+ alpha=grid_alpha,
4313
+ linewidth=grid_linewidth,
4314
+ dash_capstyle="round",
4315
+ dash_joinstyle="round",
4316
+ )
4317
+ ax.spines["polar"].set_color(grid_color)
4029
4318
  ax.spines["polar"].set_linewidth(grid_linewidth)
4030
- ax.spines["polar"].set_linestyle('-')
4319
+ ax.spines["polar"].set_linestyle("-")
4031
4320
  ax.spines["polar"].set_alpha(grid_alpha)
4032
- ax.spines["polar"].set_capstyle('round')
4033
- ax.spines["polar"].set_joinstyle('round')
4321
+ ax.spines["polar"].set_capstyle("round")
4322
+ ax.spines["polar"].set_joinstyle("round")
4034
4323
 
4035
4324
  else:
4036
- #* spider style: spider-style grid (straight lines, not circles)
4325
+ # * spider style: spider-style grid (straight lines, not circles)
4037
4326
  # Create the spider-style grid (straight lines, not circles)
4038
- for i in range(1, int(vmax * grid_interval_ratio) + 1):
4327
+ for i in range(1, int(vmax * grid_interval_ratio) + 1):
4039
4328
  ax.plot(
4040
4329
  angles + [angles[0]], # Closing the loop
4041
- [i * vmax * grid_interval_ratio] * (num_vars+1) + [i * vmax * grid_interval_ratio],
4042
- color=grid_color, linestyle=grid_linestyle, alpha=grid_alpha,linewidth=grid_linewidth
4330
+ [i * vmax * grid_interval_ratio] * (num_vars + 1)
4331
+ + [i * vmax * grid_interval_ratio],
4332
+ color=grid_color,
4333
+ linestyle=grid_linestyle,
4334
+ alpha=grid_alpha,
4335
+ linewidth=grid_linewidth,
4043
4336
  )
4044
- # set bg_color
4045
- ax.fill(angles, [vmax]*(data.shape[1]+1), color=bg_color, alpha=bg_alpha)
4337
+ # set bg_color
4338
+ ax.fill(angles, [vmax] * (data.shape[1] + 1), color=bg_color, alpha=bg_alpha)
4046
4339
  ax.yaxis.grid(False)
4047
4340
  # Move radial labels away from plotted line
4048
4341
  if tick_loc is None:
4049
- tick_loc = np.mean([angles[0],angles[1]])/(2*np.pi)*360 if circular else 0
4342
+ tick_loc = (
4343
+ np.mean([angles[0], angles[1]]) / (2 * np.pi) * 360 if circular else 0
4344
+ )
4050
4345
 
4051
4346
  ax.set_rlabel_position(tick_loc)
4052
4347
  ax.set_theta_offset(turning) if turning is not None else None
4053
- ax.tick_params(axis='x', labelsize=fontsize, colors=fontcolor) # Optional: for angular labels
4054
- tick_fontsize = fontsize-2 if fontsize is None else tick_fontsize
4055
- ax.tick_params(axis='y', labelsize=tick_fontsize, colors=tick_fontcolor) # For radial labels
4348
+ ax.tick_params(
4349
+ axis="x", labelsize=fontsize, colors=fontcolor
4350
+ ) # Optional: for angular labels
4351
+ tick_fontsize = fontsize - 2 if fontsize is None else tick_fontsize
4352
+ ax.tick_params(
4353
+ axis="y", labelsize=tick_fontsize, colors=tick_fontcolor
4354
+ ) # For radial labels
4056
4355
  if not circular:
4057
- ax.spines['polar'].set_visible(False)
4058
- ax.tick_params(axis='x', pad=sp) # move spines outward
4059
- ax.tick_params(axis='y', pad=sp) # move spines outward
4356
+ ax.spines["polar"].set_visible(False)
4357
+ ax.tick_params(axis="x", pad=sp) # move spines outward
4358
+ ax.tick_params(axis="y", pad=sp) # move spines outward
4060
4359
  # colors
4061
- colors = get_color(data.shape[0]) if cmap is None else plt.get_cmap(cmap)(np.linspace(0, 1, data.shape[0]))
4360
+ colors = (
4361
+ get_color(data.shape[0])
4362
+ if cmap is None
4363
+ else plt.get_cmap(cmap)(np.linspace(0, 1, data.shape[0]))
4364
+ )
4062
4365
  # Plot each row with straight lines
4063
4366
  for i, (index, row) in enumerate(data.iterrows()):
4064
4367
  values = row.tolist()
@@ -4070,7 +4373,7 @@ def radar(
4070
4373
  linewidth=linewidth,
4071
4374
  linestyle=linestyle,
4072
4375
  label=index,
4073
- clip_on=False
4376
+ clip_on=False,
4074
4377
  )
4075
4378
  ax.fill(angles, values, color=colors[i], alpha=alpha)
4076
4379
 
@@ -4084,13 +4387,23 @@ def radar(
4084
4387
  marker=marker,
4085
4388
  markersize=size,
4086
4389
  markeredgecolor=edgecolor,
4087
- markeredgewidth = edge_linewidth, zorder=10,clip_on=False
4390
+ markeredgewidth=edge_linewidth,
4391
+ zorder=10,
4392
+ clip_on=False,
4088
4393
  )
4089
4394
  # ax.tick_params(axis='y', labelleft=False, left=False)
4090
- if 'legend' in kws_figsets:
4395
+ if "legend" in kws_figsets:
4091
4396
  figsets(ax=ax, **kws_figsets)
4092
4397
  else:
4093
-
4094
- figsets(ax=ax,legend=dict(loc=legend_loc,fontsize=legend_fontsize,
4095
- bbox_to_anchor=[1.1,1.4],ncols=2),**kws_figsets)
4096
- return ax
4398
+
4399
+ figsets(
4400
+ ax=ax,
4401
+ legend=dict(
4402
+ loc=legend_loc,
4403
+ fontsize=legend_fontsize,
4404
+ bbox_to_anchor=[1.1, 1.4],
4405
+ ncols=2,
4406
+ ),
4407
+ **kws_figsets,
4408
+ )
4409
+ return ax