py2ls 0.2.4.10.4__py3-none-any.whl → 0.2.4.10.6__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
py2ls/plot.py CHANGED
@@ -1,21 +1,37 @@
1
1
  import numpy as np
2
- import pandas as pd
2
+ import pandas as pd
3
3
  import matplotlib.pyplot as plt
4
4
  import matplotlib
5
5
  import seaborn as sns
6
6
  import warnings
7
7
  import logging
8
8
  from typing import Union
9
- from .ips import isa, fsave, fload, mkdir, listdir, figsave, strcmp, unique, get_os, ssplit,flatten,plt_font,run_once_within
9
+ from .ips import (
10
+ isa,
11
+ fsave,
12
+ fload,
13
+ mkdir,
14
+ listdir,
15
+ figsave,
16
+ strcmp,
17
+ unique,
18
+ get_os,
19
+ ssplit,
20
+ flatten,
21
+ plt_font,
22
+ run_once_within,
23
+ )
10
24
  from .stats import *
11
25
  import os
26
+
12
27
  # Suppress INFO messages from fontTools
13
28
  logging.getLogger("fontTools").setLevel(logging.ERROR)
14
- logging.getLogger('matplotlib').setLevel(logging.ERROR)
29
+ logging.getLogger("matplotlib").setLevel(logging.ERROR)
15
30
 
16
31
  warnings.simplefilter("ignore", category=pd.errors.SettingWithCopyWarning)
17
32
  warnings.filterwarnings("ignore", category=pd.errors.PerformanceWarning)
18
33
 
34
+
19
35
  def add_text(ax=None, height_offset=0.5, fmt=".1f", **kwargs):
20
36
  """Adds text annotations for various types of Seaborn and Matplotlib plots.
21
37
  Args:
@@ -449,7 +465,7 @@ def catplot(data, *args, **kwargs):
449
465
  """
450
466
  from matplotlib.colors import to_rgba
451
467
  import os
452
-
468
+
453
469
  def plot_bars(data, data_m, opt_b, xloc, ax, label=None):
454
470
  if "l" in opt_b["loc"]:
455
471
  xloc_s = xloc - opt_b["x_dist"]
@@ -755,6 +771,7 @@ def catplot(data, *args, **kwargs):
755
771
  )
756
772
  else:
757
773
  from scipy.stats import gaussian_kde
774
+
758
775
  kde = gaussian_kde(ys, bw_method=opt_v["BandWidth"])
759
776
  min_val, max_val = ys.min(), ys.max()
760
777
  y_vals = np.linspace(min_val, max_val, opt_v["NumPoints"])
@@ -1812,16 +1829,17 @@ def read_mplstyle(style_file):
1812
1829
  return style_dict
1813
1830
 
1814
1831
 
1815
- def figsets(*args, **kwargs):
1832
+ def figsets(*args, **kwargs):
1816
1833
  import matplotlib
1817
1834
  from cycler import cycler
1835
+
1818
1836
  matplotlib.rc("text", usetex=False)
1819
1837
 
1820
1838
  fig = plt.gcf()
1821
- fontsize = kwargs.get("fontsize",11)
1822
- plt.rcParams["font.size"]=fontsize
1823
- fontname = kwargs.pop("fontname","Arial")
1824
- fontname=plt_font(fontname) # 显示中文
1839
+ fontsize = kwargs.get("fontsize", 11)
1840
+ plt.rcParams["font.size"] = fontsize
1841
+ fontname = kwargs.pop("fontname", "Arial")
1842
+ fontname = plt_font(fontname) # 显示中文
1825
1843
 
1826
1844
  sns_themes = ["white", "whitegrid", "dark", "darkgrid", "ticks"]
1827
1845
  sns_contexts = ["notebook", "talk", "poster"] # now available "paper"
@@ -1851,18 +1869,21 @@ def figsets(*args, **kwargs):
1851
1869
  nonlocal fontsize, fontname
1852
1870
  if ("fo" in key) and (("size" in key) or ("sz" in key)):
1853
1871
  fontsize = value
1854
- plt.rcParams.update({"font.size": fontsize,
1855
- "figure.titlesize":fontsize,
1856
- "axes.titlesize":fontsize,
1857
- "axes.labelsize": fontsize,
1858
- "xtick.labelsize": fontsize,
1859
- "ytick.labelsize": fontsize,
1860
- "legend.fontsize": fontsize,
1861
- "legend.title_fontsize":fontsize
1862
- })
1872
+ plt.rcParams.update(
1873
+ {
1874
+ "font.size": fontsize,
1875
+ "figure.titlesize": fontsize,
1876
+ "axes.titlesize": fontsize,
1877
+ "axes.labelsize": fontsize,
1878
+ "xtick.labelsize": fontsize,
1879
+ "ytick.labelsize": fontsize,
1880
+ "legend.fontsize": fontsize,
1881
+ "legend.title_fontsize": fontsize,
1882
+ }
1883
+ )
1863
1884
 
1864
1885
  # Customize tick labels
1865
- ax.tick_params(axis='both', which='major', labelsize=fontsize)
1886
+ ax.tick_params(axis="both", which="major", labelsize=fontsize)
1866
1887
  for label in ax.get_xticklabels() + ax.get_yticklabels():
1867
1888
  label.set_fontname(fontname)
1868
1889
 
@@ -1910,15 +1931,15 @@ def figsets(*args, **kwargs):
1910
1931
  if ("x" in key.lower()) and (
1911
1932
  "tic" not in key.lower() and "tk" not in key.lower()
1912
1933
  ):
1913
- ax.set_xlabel(value, fontname=fontname,fontsize=fontsize)
1934
+ ax.set_xlabel(value, fontname=fontname, fontsize=fontsize)
1914
1935
  if ("y" in key.lower()) and (
1915
1936
  "tic" not in key.lower() and "tk" not in key.lower()
1916
1937
  ):
1917
- ax.set_ylabel(value, fontname=fontname,fontsize=fontsize)
1938
+ ax.set_ylabel(value, fontname=fontname, fontsize=fontsize)
1918
1939
  if ("z" in key.lower()) and (
1919
1940
  "tic" not in key.lower() and "tk" not in key.lower()
1920
1941
  ):
1921
- ax.set_zlabel(value, fontname=fontname,fontsize=fontsize)
1942
+ ax.set_zlabel(value, fontname=fontname, fontsize=fontsize)
1922
1943
  if key == "xlabel" and isinstance(value, dict):
1923
1944
  ax.set_xlabel(**value)
1924
1945
  if key == "ylabel" and isinstance(value, dict):
@@ -2110,6 +2131,7 @@ def figsets(*args, **kwargs):
2110
2131
 
2111
2132
  if "mi" in key.lower() and "tic" in key.lower(): # minor_ticks
2112
2133
  import matplotlib.ticker as tck
2134
+
2113
2135
  if "x" in value.lower() or "x" in key.lower():
2114
2136
  ax.xaxis.set_minor_locator(tck.AutoMinorLocator()) # ax.minorticks_on()
2115
2137
  if "y" in value.lower() or "y" in key.lower():
@@ -2162,9 +2184,9 @@ def figsets(*args, **kwargs):
2162
2184
  ax.grid(visible=False)
2163
2185
  if "tit" in key.lower():
2164
2186
  if "sup" in key.lower():
2165
- plt.suptitle(value,fontname=fontname,fontsize=fontsize)
2187
+ plt.suptitle(value, fontname=fontname, fontsize=fontsize)
2166
2188
  else:
2167
- ax.set_title(value,fontname=fontname,fontsize=fontsize)
2189
+ ax.set_title(value, fontname=fontname, fontsize=fontsize)
2168
2190
  if key.lower() in ["spine", "adjust", "ad", "sp", "spi", "adj", "spines"]:
2169
2191
  if isinstance(value, bool) or (value in ["go", "do", "ja", "yes"]):
2170
2192
  if value:
@@ -2185,9 +2207,14 @@ def figsets(*args, **kwargs):
2185
2207
  legend = ax.get_legend()
2186
2208
  if legend is not None:
2187
2209
  legend.remove()
2188
- if any(['colorbar' in key.lower(),'cbar' in key.lower()]) and "loc" in key.lower():
2210
+ if (
2211
+ any(["colorbar" in key.lower(), "cbar" in key.lower()])
2212
+ and "loc" in key.lower()
2213
+ ):
2189
2214
  cbar = ax.collections[0].colorbar # Access the colorbar from the plot
2190
- cbar.ax.set_position(value) # [left, bottom, width, height] [0.475, 0.15, 0.04, 0.25]
2215
+ cbar.ax.set_position(
2216
+ value
2217
+ ) # [left, bottom, width, height] [0.475, 0.15, 0.04, 0.25]
2191
2218
 
2192
2219
  for arg in args:
2193
2220
  if isinstance(arg, matplotlib.axes._axes.Axes):
@@ -2225,7 +2252,8 @@ def figsets(*args, **kwargs):
2225
2252
  if len(fig.get_axes()) > 1:
2226
2253
  plt.tight_layout()
2227
2254
 
2228
- def split_legend(ax, n=2, loc=None, title=None, bbox=None,ncol=1, **kwargs):
2255
+
2256
+ def split_legend(ax, n=2, loc=None, title=None, bbox=None, ncol=1, **kwargs):
2229
2257
  """
2230
2258
  split_legend(
2231
2259
  ax,
@@ -2238,15 +2266,15 @@ def split_legend(ax, n=2, loc=None, title=None, bbox=None,ncol=1, **kwargs):
2238
2266
  # Retrieve all lines and labels from the axis
2239
2267
  handles, labels = ax.get_legend_handles_labels()
2240
2268
  num_labels = len(labels)
2241
-
2269
+
2242
2270
  # Calculate the number of labels per legend part
2243
- labels_per_part = (num_labels + n - 1) // n # Round up
2271
+ labels_per_part = (num_labels + n - 1) // n # Round up
2244
2272
  # Create a list to hold each legend object
2245
2273
  legends = []
2246
2274
 
2247
2275
  # Default locations and titles if not specified
2248
2276
  if loc is None:
2249
- loc = ['best'] * n
2277
+ loc = ["best"] * n
2250
2278
  if title is None:
2251
2279
  title = [None] * n
2252
2280
  if bbox is None:
@@ -2257,7 +2285,7 @@ def split_legend(ax, n=2, loc=None, title=None, bbox=None,ncol=1, **kwargs):
2257
2285
  # Calculate the range of labels for this part
2258
2286
  start_idx = i * labels_per_part
2259
2287
  end_idx = min(start_idx + labels_per_part, num_labels)
2260
-
2288
+
2261
2289
  # Skip if no labels in this range
2262
2290
  if start_idx >= end_idx:
2263
2291
  break
@@ -2267,19 +2295,24 @@ def split_legend(ax, n=2, loc=None, title=None, bbox=None,ncol=1, **kwargs):
2267
2295
  part_labels = labels[start_idx:end_idx]
2268
2296
 
2269
2297
  # Create the legend for this part
2270
- legend = ax.legend(handles=part_handles,
2271
- labels=part_labels,
2272
- loc=loc[i],
2273
- title=title[i],
2274
- ncol=ncol,
2275
- bbox_to_anchor=bbox[i],
2276
- **kwargs)
2277
-
2298
+ legend = ax.legend(
2299
+ handles=part_handles,
2300
+ labels=part_labels,
2301
+ loc=loc[i],
2302
+ title=title[i],
2303
+ ncol=ncol,
2304
+ bbox_to_anchor=bbox[i],
2305
+ **kwargs,
2306
+ )
2307
+
2278
2308
  # Add the legend to the axis and save it to the list
2279
- ax.add_artist(legend) if i !=(n-1) else None # the lastone will be added automaticaly
2309
+ (
2310
+ ax.add_artist(legend) if i != (n - 1) else None
2311
+ ) # the lastone will be added automaticaly
2280
2312
  legends.append(legend)
2281
2313
  return legends
2282
2314
 
2315
+
2283
2316
  def get_colors(
2284
2317
  n: int = 1,
2285
2318
  cmap: str = "auto",
@@ -2289,7 +2322,9 @@ def get_colors(
2289
2322
  *args,
2290
2323
  **kwargs,
2291
2324
  ):
2292
- return get_color(n,cmap,alpha,output,*args,**kwargs)
2325
+ return get_color(n, cmap, alpha, output, *args, **kwargs)
2326
+
2327
+
2293
2328
  def get_color(
2294
2329
  n: int = 1,
2295
2330
  cmap: str = "auto",
@@ -2300,6 +2335,7 @@ def get_color(
2300
2335
  **kwargs,
2301
2336
  ):
2302
2337
  from cycler import cycler
2338
+
2303
2339
  def cmap2hex(cmap_name):
2304
2340
  cmap_ = matplotlib.pyplot.get_cmap(cmap_name)
2305
2341
  colors = [cmap_(i) for i in range(cmap_.N)]
@@ -2347,49 +2383,137 @@ def get_color(
2347
2383
  return "#{:02X}{:02X}{:02X}{:02X}".format(r, g, b, a)
2348
2384
  else:
2349
2385
  return "#{:02X}{:02X}{:02X}".format(r, g, b)
2350
-
2386
+
2351
2387
  # sc.pl.palettes.default_20
2352
- cmap_20= ['#1f77b4','#ff7f0e','#279e68','#d62728','#aa40fc','#8c564b','#e377c2','#b5bd61',
2353
- '#17becf','#aec7e8','#ffbb78','#98df8a','#ff9896','#c5b0d5','#c49c94','#f7b6d2',
2354
- '#dbdb8d','#9edae5','#ad494a','#8c6d31']
2388
+ cmap_20 = [
2389
+ "#1f77b4",
2390
+ "#ff7f0e",
2391
+ "#279e68",
2392
+ "#d62728",
2393
+ "#aa40fc",
2394
+ "#8c564b",
2395
+ "#e377c2",
2396
+ "#b5bd61",
2397
+ "#17becf",
2398
+ "#aec7e8",
2399
+ "#ffbb78",
2400
+ "#98df8a",
2401
+ "#ff9896",
2402
+ "#c5b0d5",
2403
+ "#c49c94",
2404
+ "#f7b6d2",
2405
+ "#dbdb8d",
2406
+ "#9edae5",
2407
+ "#ad494a",
2408
+ "#8c6d31",
2409
+ ]
2355
2410
  # sc.pl.palettes.zeileis_28
2356
- cmap_28 = ["#023fa5","#7d87b9","#bec1d4","#d6bcc0","#bb7784","#8e063b","#4a6fe3","#8595e1",
2357
- "#b5bbe3","#e6afb9","#e07b91","#d33f6a","#11c638","#8dd593","#c6dec7","#ead3c6",
2358
- "#f0b98d","#ef9708","#0fcfc0","#9cded6","#d5eae7","#f3e1eb","#f6c4e1","#f79cd4",
2359
- "#7f7f7f","#c7c7c7","#1CE6FF","#336600"]
2411
+ cmap_28 = [
2412
+ "#023fa5",
2413
+ "#7d87b9",
2414
+ "#bec1d4",
2415
+ "#d6bcc0",
2416
+ "#bb7784",
2417
+ "#8e063b",
2418
+ "#4a6fe3",
2419
+ "#8595e1",
2420
+ "#b5bbe3",
2421
+ "#e6afb9",
2422
+ "#e07b91",
2423
+ "#d33f6a",
2424
+ "#11c638",
2425
+ "#8dd593",
2426
+ "#c6dec7",
2427
+ "#ead3c6",
2428
+ "#f0b98d",
2429
+ "#ef9708",
2430
+ "#0fcfc0",
2431
+ "#9cded6",
2432
+ "#d5eae7",
2433
+ "#f3e1eb",
2434
+ "#f6c4e1",
2435
+ "#f79cd4",
2436
+ "#7f7f7f",
2437
+ "#c7c7c7",
2438
+ "#1CE6FF",
2439
+ "#336600",
2440
+ ]
2360
2441
  if cmap == "gray":
2361
2442
  cmap = "grey"
2362
- elif cmap=="20":
2363
- cmap=cmap_20
2364
- elif cmap=="28":
2365
- cmap=cmap_28
2443
+ elif cmap == "20":
2444
+ cmap = cmap_20
2445
+ elif cmap == "28":
2446
+ cmap = cmap_28
2366
2447
  # Determine color list based on cmap parameter
2367
- if isinstance(cmap,str):
2448
+ if isinstance(cmap, str):
2368
2449
  if "aut" in cmap:
2369
2450
  if n == 1:
2370
2451
  colorlist = ["#3A4453"]
2371
2452
  elif n == 2:
2372
2453
  colorlist = ["#3A4453", "#FF2C00"]
2373
2454
  elif n == 3:
2374
- colorlist = ["#66c2a5","#fc8d62","#8da0cb"]
2455
+ colorlist = ["#66c2a5", "#fc8d62", "#8da0cb"]
2375
2456
  elif n == 4:
2376
- colorlist = ["#FF2C00","#087cf7", "#FBAF63", "#3C898A"]
2457
+ colorlist = ["#FF2C00", "#087cf7", "#FBAF63", "#3C898A"]
2377
2458
  elif n == 5:
2378
- colorlist = ["#FF2C00","#459AA9", "#B25E9D", "#4B8C3B","#EF8632"]
2459
+ colorlist = ["#FF2C00", "#459AA9", "#B25E9D", "#4B8C3B", "#EF8632"]
2379
2460
  elif n == 6:
2380
- colorlist = ["#FF2C00","#91bfdb", "#B25E9D", "#4B8C3B","#EF8632", "#24578E"]
2381
- elif n==7:
2382
- colorlist = ["#7F7F7F", "#459AA9", "#B25E9D", "#4B8C3B","#EF8632", "#24578E" "#FF2C00"]
2383
- elif n==8:
2461
+ colorlist = [
2462
+ "#FF2C00",
2463
+ "#91bfdb",
2464
+ "#B25E9D",
2465
+ "#4B8C3B",
2466
+ "#EF8632",
2467
+ "#24578E",
2468
+ ]
2469
+ elif n == 7:
2470
+ colorlist = [
2471
+ "#7F7F7F",
2472
+ "#459AA9",
2473
+ "#B25E9D",
2474
+ "#4B8C3B",
2475
+ "#EF8632",
2476
+ "#24578E" "#FF2C00",
2477
+ ]
2478
+ elif n == 8:
2384
2479
  # colorlist = ['#1f77b4','#ff7f0e','#367B7F','#51B34F','#d62728','#aa40fc','#e377c2','#17becf']
2385
2480
  # colorlist = ["#367C7E","#51B34F","#881A11","#E9374C","#EF893C","#010072","#385DCB","#EA43E3"]
2386
- colorlist = ["#78BFDA","#D52E6F","#F7D648","#A52D28","#6B9F41","#E18330","#E18B9D","#3C88CC"]
2387
- elif n==9:
2388
- colorlist = ['#1f77b4','#ff7f0e','#367B7F','#ff9896','#d62728','#aa40fc','#e377c2','#51B34F','#17becf']
2389
- elif n==10:
2390
- colorlist = ['#1f77b4','#ff7f0e','#367B7F','#ff9896','#51B34F','#d62728''#aa40fc','#e377c2','#375FD2','#17becf']
2391
- elif 10<n<=20:
2392
- colorlist = cmap_20
2481
+ colorlist = [
2482
+ "#78BFDA",
2483
+ "#D52E6F",
2484
+ "#F7D648",
2485
+ "#A52D28",
2486
+ "#6B9F41",
2487
+ "#E18330",
2488
+ "#E18B9D",
2489
+ "#3C88CC",
2490
+ ]
2491
+ elif n == 9:
2492
+ colorlist = [
2493
+ "#1f77b4",
2494
+ "#ff7f0e",
2495
+ "#367B7F",
2496
+ "#ff9896",
2497
+ "#d62728",
2498
+ "#aa40fc",
2499
+ "#e377c2",
2500
+ "#51B34F",
2501
+ "#17becf",
2502
+ ]
2503
+ elif n == 10:
2504
+ colorlist = [
2505
+ "#1f77b4",
2506
+ "#ff7f0e",
2507
+ "#367B7F",
2508
+ "#ff9896",
2509
+ "#51B34F",
2510
+ "#d62728" "#aa40fc",
2511
+ "#e377c2",
2512
+ "#375FD2",
2513
+ "#17becf",
2514
+ ]
2515
+ elif 10 < n <= 20:
2516
+ colorlist = cmap_20
2393
2517
  else:
2394
2518
  colorlist = cmap_28
2395
2519
  by = "start"
@@ -2426,7 +2550,7 @@ def get_color(
2426
2550
  by = "linspace"
2427
2551
  colorlist = cmap2hex(cmap)
2428
2552
  elif isinstance(cmap, list):
2429
- colorlist=cmap
2553
+ colorlist = cmap
2430
2554
 
2431
2555
  # Determine method for generating color list
2432
2556
  if "st" in by.lower() or "be" in by.lower():
@@ -2446,7 +2570,7 @@ def get_color(
2446
2570
  return hue_list
2447
2571
  else:
2448
2572
  raise ValueError("Invalid output type. Choose 'rgb' or 'hue'.")
2449
-
2573
+
2450
2574
 
2451
2575
  def stdshade(ax=None, *args, **kwargs):
2452
2576
  """
@@ -2454,7 +2578,7 @@ def stdshade(ax=None, *args, **kwargs):
2454
2578
  plot.stdshade(data_array, c=clist[1], lw=2, ls="-.", alpha=0.2)
2455
2579
  """
2456
2580
  from scipy.signal import savgol_filter
2457
-
2581
+
2458
2582
  # Separate kws_line and kws_fill if necessary
2459
2583
  kws_line = kwargs.pop("kws_line", {})
2460
2584
  kws_fill = kwargs.pop("kws_fill", {})
@@ -2724,37 +2848,47 @@ def adjust_spines(ax=None, spines=["left", "bottom"], distance=2):
2724
2848
  # cax = fig.add_axes([l + w + pad, b, width, h]) # define cbar Axes
2725
2849
  # return fig.colorbar(im, cax=cax, **kwargs) # draw cbar
2726
2850
 
2727
- def add_colorbar(im,
2728
- cmap="viridis",
2729
- vmin=-1,
2730
- vmax=1,
2731
- orientation='vertical',
2732
- width_ratio=0.05,
2733
- pad_ratio=0.02,
2734
- shrink=1.0,
2735
- **kwargs):
2851
+
2852
+ def add_colorbar(
2853
+ im,
2854
+ cmap="viridis",
2855
+ vmin=-1,
2856
+ vmax=1,
2857
+ orientation="vertical",
2858
+ width_ratio=0.05,
2859
+ pad_ratio=0.02,
2860
+ shrink=1.0,
2861
+ **kwargs,
2862
+ ):
2736
2863
  import matplotlib as mpl
2864
+
2737
2865
  if all([cmap, vmin, vmax]):
2738
2866
  norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
2739
2867
  else:
2740
- norm=False
2868
+ norm = False
2741
2869
  sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
2742
2870
  sm.set_array([])
2743
2871
  l, b, w, h = im.axes.get_position().bounds # position: left, bottom, width, height
2744
- if orientation == 'vertical':
2872
+ if orientation == "vertical":
2745
2873
  width = width_ratio * w
2746
2874
  pad = pad_ratio * w
2747
- cax = im.figure.add_axes([l + w + pad, b, width, h * shrink]) # Right of the image
2875
+ cax = im.figure.add_axes(
2876
+ [l + w + pad, b, width, h * shrink]
2877
+ ) # Right of the image
2748
2878
  else:
2749
2879
  height = width_ratio * h
2750
2880
  pad = pad_ratio * h
2751
- cax = im.figure.add_axes([l, b - height - pad, w * shrink, height]) # Below the image
2752
- cbar=im.figure.colorbar(sm, cax=cax, orientation=orientation, **kwargs)
2753
- return cbar
2881
+ cax = im.figure.add_axes(
2882
+ [l, b - height - pad, w * shrink, height]
2883
+ ) # Below the image
2884
+ cbar = im.figure.colorbar(sm, cax=cax, orientation=orientation, **kwargs)
2885
+ return cbar
2886
+
2754
2887
 
2755
2888
  # Usage:
2756
2889
  # add_colorbar(im, width_ratio=0.03, pad_ratio=0.01, orientation='horizontal', label="PSD (dB)")
2757
2890
 
2891
+
2758
2892
  def generate_xticks_with_gap(x_len, hue_len):
2759
2893
  """
2760
2894
  Generate a concatenated array based on x_len and hue_len,
@@ -2884,6 +3018,7 @@ def style_examples(
2884
3018
  f = listdir(
2885
3019
  "/Users/macjianfeng/Dropbox/github/python/py2ls/.venv/lib/python3.12/site-packages/py2ls/data/styles/",
2886
3020
  kind=".json",
3021
+ verbose=False,
2887
3022
  )
2888
3023
  display(f.sample(2))
2889
3024
  # def style_example(dir_save,)
@@ -2961,6 +3096,7 @@ def thumbnail(dir_img_list: list, figsize=(10, 10), dpi=100, show=False, verbose
2961
3096
 
2962
3097
  def get_params_from_func_usage(function_signature):
2963
3098
  import re
3099
+
2964
3100
  # Regular expression to match parameter names, ignoring '*' and '**kwargs'
2965
3101
  keys_pattern = r"(?<!\*\*)\b(\w+)="
2966
3102
  # Find all matches
@@ -2987,7 +3123,7 @@ def plotxy(
2987
3123
  x=None,
2988
3124
  y=None,
2989
3125
  ax=None,
2990
- kind: Union[str, list] = 'scatter', # Specify the kind of plot
3126
+ kind: Union[str, list] = "scatter", # Specify the kind of plot
2991
3127
  verbose=False,
2992
3128
  **kwargs,
2993
3129
  ):
@@ -3011,14 +3147,15 @@ def plotxy(
3011
3147
  # Check for valid plot kind
3012
3148
  # Default arguments for various plot types
3013
3149
  from pathlib import Path
3150
+
3014
3151
  # Get the current script's directory as a Path object
3015
- current_directory = Path(__file__).resolve().parent
3152
+ current_directory = Path(__file__).resolve().parent
3153
+
3154
+ if not "default_settings" in locals():
3155
+ default_settings = fload(current_directory / "data" / "usages_sns.json")
3156
+ if not "sns_info" in locals():
3157
+ sns_info = pd.DataFrame(fload(current_directory / "data" / "sns_info.json"))
3016
3158
 
3017
- if not 'default_settings' in locals():
3018
- default_settings = fload(current_directory / 'data' / 'usages_sns.json')
3019
- if not 'sns_info' in locals():
3020
- sns_info = pd.DataFrame(fload(current_directory / 'data' / 'sns_info.json'))
3021
-
3022
3159
  valid_kinds = list(default_settings.keys())
3023
3160
  # print(valid_kinds)
3024
3161
  if kind is not None:
@@ -3032,7 +3169,7 @@ def plotxy(
3032
3169
  if kind is not None:
3033
3170
  for k in kind:
3034
3171
  if k in valid_kinds:
3035
- print(f"{k}:\n\t{default_settings[k]}")
3172
+ print(f"{k}:\n\t{default_settings[k]}")
3036
3173
  usage_str = """plotxy(data=ranked_genes,
3037
3174
  x="log2(fold_change)",
3038
3175
  y="-log10(p-value)",
@@ -3057,15 +3194,25 @@ def plotxy(
3057
3194
  kws_add_text = v_arg
3058
3195
  kwargs.pop(k_arg, None)
3059
3196
  break
3060
- zorder=0
3197
+ zorder = 0
3061
3198
  for k in kind:
3062
- zorder+=1
3199
+ zorder += 1
3063
3200
  # indicate 'col' features
3064
3201
  col = kwargs.get("col", None)
3065
- sns_with_col = ["catplot","histplot","relplot","lmplot","pairplot","displot","kdeplot"]
3202
+ sns_with_col = [
3203
+ "catplot",
3204
+ "histplot",
3205
+ "relplot",
3206
+ "lmplot",
3207
+ "pairplot",
3208
+ "displot",
3209
+ "kdeplot",
3210
+ ]
3066
3211
  if col is not None:
3067
3212
  if not k in sns_with_col:
3068
- print(f"tips:\n'{k}' has no 'col' param, you could try with {sns_with_col}")
3213
+ print(
3214
+ f"tips:\n'{k}' has no 'col' param, you could try with {sns_with_col}"
3215
+ )
3069
3216
  # (1) return FcetGrid
3070
3217
  if k == "jointplot":
3071
3218
  kws_joint = kwargs.pop("kws_joint", kwargs)
@@ -3092,77 +3239,98 @@ def plotxy(
3092
3239
  ax = stdshade(ax=ax, **kwargs)
3093
3240
  elif k == "scatterplot":
3094
3241
  kws_scatter = kwargs.pop("kws_scatter", kwargs)
3095
- kws_scatter={k: v for k, v in kws_scatter.items() if not k.startswith("kws_")}
3242
+ kws_scatter = {
3243
+ k: v for k, v in kws_scatter.items() if not k.startswith("kws_")
3244
+ }
3096
3245
  hue = kwargs.pop("hue", None)
3097
3246
  if isinstance(kws_scatter, dict): # Check if kws_scatter is a dictionary
3098
3247
  kws_scatter.pop("hue", None) # Safely remove 'hue' if it exists
3099
- palette = kws_scatter.pop("palette",get_color(data[hue].nunique()) if hue is not None else None)
3248
+ palette = kws_scatter.pop(
3249
+ "palette", get_color(data[hue].nunique()) if hue is not None else None
3250
+ )
3100
3251
  s = kws_scatter.pop("s", 10)
3101
3252
  alpha = kws_scatter.pop("alpha", 0.7)
3102
- ax = sns.scatterplot(ax=ax,data=data,x=x,y=y,hue=hue,palette=palette,s=s,alpha=alpha,zorder=zorder,**kws_scatter)
3253
+ ax = sns.scatterplot(
3254
+ ax=ax,
3255
+ data=data,
3256
+ x=x,
3257
+ y=y,
3258
+ hue=hue,
3259
+ palette=palette,
3260
+ s=s,
3261
+ alpha=alpha,
3262
+ zorder=zorder,
3263
+ **kws_scatter,
3264
+ )
3103
3265
  elif k == "histplot":
3104
3266
  kws_hist = kwargs.pop("kws_hist", kwargs)
3105
- kws_hist={k: v for k, v in kws_hist.items() if not k.startswith("kws_")}
3106
- ax = sns.histplot(data=data, x=x, ax=ax,zorder=zorder, **kws_hist)
3267
+ kws_hist = {k: v for k, v in kws_hist.items() if not k.startswith("kws_")}
3268
+ ax = sns.histplot(data=data, x=x, ax=ax, zorder=zorder, **kws_hist)
3107
3269
  elif k == "kdeplot":
3108
3270
  kws_kde = kwargs.pop("kws_kde", kwargs)
3109
- kws_kde={k: v for k, v in kws_kde.items() if not k.startswith("kws_")}
3110
- ax = sns.kdeplot(data=data, x=x, ax=ax,zorder=zorder, **kws_kde)
3271
+ kws_kde = {k: v for k, v in kws_kde.items() if not k.startswith("kws_")}
3272
+ ax = sns.kdeplot(data=data, x=x, ax=ax, zorder=zorder, **kws_kde)
3111
3273
  elif k == "ecdfplot":
3112
3274
  kws_ecdf = kwargs.pop("kws_ecdf", kwargs)
3113
- kws_ecdf={k: v for k, v in kws_ecdf.items() if not k.startswith("kws_")}
3114
- ax = sns.ecdfplot(data=data, x=x, ax=ax,zorder=zorder, **kws_ecdf)
3275
+ kws_ecdf = {k: v for k, v in kws_ecdf.items() if not k.startswith("kws_")}
3276
+ ax = sns.ecdfplot(data=data, x=x, ax=ax, zorder=zorder, **kws_ecdf)
3115
3277
  elif k == "rugplot":
3116
3278
  kws_rug = kwargs.pop("kws_rug", kwargs)
3117
- kws_rug={k: v for k, v in kws_rug.items() if not k.startswith("kws_")}
3118
- ax = sns.rugplot(data=data, x=x, ax=ax,zorder=zorder, **kws_rug)
3279
+ kws_rug = {k: v for k, v in kws_rug.items() if not k.startswith("kws_")}
3280
+ ax = sns.rugplot(data=data, x=x, ax=ax, zorder=zorder, **kws_rug)
3119
3281
  elif k == "stripplot":
3120
3282
  kws_strip = kwargs.pop("kws_strip", kwargs)
3121
- kws_strip={k: v for k, v in kws_strip.items() if not k.startswith("kws_")}
3283
+ kws_strip = {k: v for k, v in kws_strip.items() if not k.startswith("kws_")}
3122
3284
  dodge = kws_strip.pop("dodge", True)
3123
- ax = sns.stripplot(data=data, x=x, y=y, ax=ax,zorder=zorder, dodge=dodge, **kws_strip)
3285
+ ax = sns.stripplot(
3286
+ data=data, x=x, y=y, ax=ax, zorder=zorder, dodge=dodge, **kws_strip
3287
+ )
3124
3288
  elif k == "swarmplot":
3125
3289
  kws_swarm = kwargs.pop("kws_swarm", kwargs)
3126
- kws_swarm={k: v for k, v in kws_swarm.items() if not k.startswith("kws_")}
3127
- ax = sns.swarmplot(data=data, x=x, y=y, ax=ax,zorder=zorder, **kws_swarm)
3290
+ kws_swarm = {k: v for k, v in kws_swarm.items() if not k.startswith("kws_")}
3291
+ ax = sns.swarmplot(data=data, x=x, y=y, ax=ax, zorder=zorder, **kws_swarm)
3128
3292
  elif k == "boxplot":
3129
3293
  kws_box = kwargs.pop("kws_box", kwargs)
3130
- kws_box={k: v for k, v in kws_box.items() if not k.startswith("kws_")}
3131
- ax = sns.boxplot(data=data, x=x, y=y, ax=ax,zorder=zorder, **kws_box)
3294
+ kws_box = {k: v for k, v in kws_box.items() if not k.startswith("kws_")}
3295
+ ax = sns.boxplot(data=data, x=x, y=y, ax=ax, zorder=zorder, **kws_box)
3132
3296
  elif k == "violinplot":
3133
3297
  kws_violin = kwargs.pop("kws_violin", kwargs)
3134
- kws_violin={k: v for k, v in kws_violin.items() if not k.startswith("kws_")}
3135
- ax = sns.violinplot(data=data, x=x, y=y, ax=ax,zorder=zorder, **kws_violin)
3298
+ kws_violin = {
3299
+ k: v for k, v in kws_violin.items() if not k.startswith("kws_")
3300
+ }
3301
+ ax = sns.violinplot(data=data, x=x, y=y, ax=ax, zorder=zorder, **kws_violin)
3136
3302
  elif k == "boxenplot":
3137
3303
  kws_boxen = kwargs.pop("kws_boxen", kwargs)
3138
- kws_boxen={k: v for k, v in kws_boxen.items() if not k.startswith("kws_")}
3139
- ax = sns.boxenplot(data=data, x=x, y=y, ax=ax,zorder=zorder, **kws_boxen)
3304
+ kws_boxen = {k: v for k, v in kws_boxen.items() if not k.startswith("kws_")}
3305
+ ax = sns.boxenplot(data=data, x=x, y=y, ax=ax, zorder=zorder, **kws_boxen)
3140
3306
  elif k == "pointplot":
3141
3307
  kws_point = kwargs.pop("kws_point", kwargs)
3142
- kws_point={k: v for k, v in kws_point.items() if not k.startswith("kws_")}
3143
- ax = sns.pointplot(data=data, x=x, y=y, ax=ax,zorder=zorder, **kws_point)
3308
+ kws_point = {k: v for k, v in kws_point.items() if not k.startswith("kws_")}
3309
+ ax = sns.pointplot(data=data, x=x, y=y, ax=ax, zorder=zorder, **kws_point)
3144
3310
  elif k == "barplot":
3145
3311
  kws_bar = kwargs.pop("kws_bar", kwargs)
3146
- kws_bar={k: v for k, v in kws_bar.items() if not k.startswith("kws_")}
3147
- ax = sns.barplot(data=data, x=x, y=y, ax=ax,zorder=zorder, **kws_bar)
3312
+ kws_bar = {k: v for k, v in kws_bar.items() if not k.startswith("kws_")}
3313
+ ax = sns.barplot(data=data, x=x, y=y, ax=ax, zorder=zorder, **kws_bar)
3148
3314
  elif k == "countplot":
3149
3315
  kws_count = kwargs.pop("kws_count", kwargs)
3150
- kws_count={k: v for k, v in kws_count.items() if not k.startswith("kws_")}
3151
- if not kws_count.get("hue",None):
3152
- kws_count.pop("palette",None)
3153
- ax = sns.countplot(data=data, x=x,y=y, ax=ax,zorder=zorder, **kws_count)
3316
+ kws_count = {k: v for k, v in kws_count.items() if not k.startswith("kws_")}
3317
+ if not kws_count.get("hue", None):
3318
+ kws_count.pop("palette", None)
3319
+ ax = sns.countplot(data=data, x=x, y=y, ax=ax, zorder=zorder, **kws_count)
3154
3320
  elif k == "regplot":
3155
3321
  kws_reg = kwargs.pop("kws_reg", kwargs)
3156
- kws_reg={k: v for k, v in kws_reg.items() if not k.startswith("kws_")}
3157
- ax = sns.regplot(data=data, x=x, y=y, ax=ax,zorder=zorder, **kws_reg)
3322
+ kws_reg = {k: v for k, v in kws_reg.items() if not k.startswith("kws_")}
3323
+ ax = sns.regplot(data=data, x=x, y=y, ax=ax, zorder=zorder, **kws_reg)
3158
3324
  elif k == "residplot":
3159
3325
  kws_resid = kwargs.pop("kws_resid", kwargs)
3160
- kws_resid={k: v for k, v in kws_resid.items() if not k.startswith("kws_")}
3161
- ax = sns.residplot(data=data, x=x, y=y, lowess=True,zorder=zorder, ax=ax, **kws_resid)
3326
+ kws_resid = {k: v for k, v in kws_resid.items() if not k.startswith("kws_")}
3327
+ ax = sns.residplot(
3328
+ data=data, x=x, y=y, lowess=True, zorder=zorder, ax=ax, **kws_resid
3329
+ )
3162
3330
  elif k == "lineplot":
3163
3331
  kws_line = kwargs.pop("kws_line", kwargs)
3164
- kws_line={k: v for k, v in kws_line.items() if not k.startswith("kws_")}
3165
- ax = sns.lineplot(ax=ax, data=data, x=x, y=y,zorder=zorder, **kws_line)
3332
+ kws_line = {k: v for k, v in kws_line.items() if not k.startswith("kws_")}
3333
+ ax = sns.lineplot(ax=ax, data=data, x=x, y=y, zorder=zorder, **kws_line)
3166
3334
 
3167
3335
  figsets(ax=ax, **kws_figsets)
3168
3336
  if kws_add_text:
@@ -3176,20 +3344,22 @@ def plotxy(
3176
3344
  return g, ax
3177
3345
  return ax
3178
3346
 
3347
+
3179
3348
  def norm_cmap(data, cmap="coolwarm", min_max=[0, 1]):
3180
3349
  norm_ = plt.Normalize(min_max[0], min_max[1])
3181
3350
  colormap = plt.get_cmap(cmap)
3182
3351
  return colormap(norm_(data))
3183
3352
 
3353
+
3184
3354
  def volcano(
3185
- data:pd.DataFrame,
3186
- x:str,
3187
- y:str,
3188
- gene_col:str=None,
3189
- top_genes=[5, 5], # [down-regulated, up-regulated]
3190
- thr_x=np.log2(1.5), # default: 0.585
3355
+ data: pd.DataFrame,
3356
+ x: str,
3357
+ y: str,
3358
+ gene_col: str = None,
3359
+ top_genes=[5, 5], # [down-regulated, up-regulated]
3360
+ thr_x=np.log2(1.5), # default: 0.585
3191
3361
  thr_y=-np.log10(0.05),
3192
- sort_xy="x", #'y', 'xy'
3362
+ sort_xy="x", #'y', 'xy'
3193
3363
  colors=("#00BFFF", "#9d9a9a", "#FF3030"),
3194
3364
  s=20,
3195
3365
  fill=True, # plot filled scatter
@@ -3201,11 +3371,10 @@ def volcano(
3201
3371
  ax=None,
3202
3372
  verbose=False,
3203
3373
  kws_text=dict(fontsize=10, color="k"),
3204
- kws_bbox=dict(facecolor='none',
3205
- alpha=0.5,
3206
- edgecolor='black',
3207
- boxstyle='round,pad=0.3'),# '{}' to hide
3208
- kws_arrow=dict(color="k", lw=0.5),# '{}' to hide
3374
+ kws_bbox=dict(
3375
+ facecolor="none", alpha=0.5, edgecolor="black", boxstyle="round,pad=0.3"
3376
+ ), # '{}' to hide
3377
+ kws_arrow=dict(color="k", lw=0.5), # '{}' to hide
3209
3378
  **kwargs,
3210
3379
  ):
3211
3380
  """
@@ -3275,39 +3444,35 @@ def volcano(
3275
3444
  kws_figsets = v_arg
3276
3445
  kwargs.pop(k_arg, None)
3277
3446
  break
3278
-
3279
- data=data.copy()
3447
+
3448
+ data = data.copy()
3280
3449
  # filter nan
3281
3450
  data = data.dropna(subset=[x, y]) # Drop rows with NaN in x or y
3282
- data.loc[:,"color"] = np.where(
3451
+ data.loc[:, "color"] = np.where(
3283
3452
  (data[x] > thr_x) & (data[y] > thr_y),
3284
3453
  colors[2],
3285
- np.where((data[x] < -thr_x) & (data[y] > thr_y),
3286
- colors[0],
3287
- colors[1]),
3454
+ np.where((data[x] < -thr_x) & (data[y] > thr_y), colors[0], colors[1]),
3288
3455
  )
3289
- top_genes=[top_genes, top_genes] if isinstance(top_genes,int) else top_genes
3290
-
3456
+ top_genes = [top_genes, top_genes] if isinstance(top_genes, int) else top_genes
3457
+
3291
3458
  # could custom how to select the top genes, x: x has priority
3292
- sort_by_x_y=[x,y] if sort_xy=="x" else [y,x]
3293
- ascending_up=[True, True] if sort_xy=="x" else [False, True]
3294
- ascending_down=[False, True] if sort_xy=="x" else [False, False]
3295
-
3296
- down_reg_genes = data[
3297
- (data["color"] == colors[0]) &
3298
- (data[x].abs() > thr_x) &
3299
- (data[y] > thr_y)
3300
- ].sort_values(by=sort_by_x_y, ascending=ascending_up).head(top_genes[0])
3301
- up_reg_genes = data[
3302
- (data["color"] == colors[2]) &
3303
- (data[x].abs() > thr_x) &
3304
- (data[y] > thr_y)
3305
- ].sort_values(by=sort_by_x_y, ascending=ascending_down).head(top_genes[1])
3459
+ sort_by_x_y = [x, y] if sort_xy == "x" else [y, x]
3460
+ ascending_up = [True, True] if sort_xy == "x" else [False, True]
3461
+ ascending_down = [False, True] if sort_xy == "x" else [False, False]
3462
+
3463
+ down_reg_genes = (
3464
+ data[(data["color"] == colors[0]) & (data[x].abs() > thr_x) & (data[y] > thr_y)]
3465
+ .sort_values(by=sort_by_x_y, ascending=ascending_up)
3466
+ .head(top_genes[0])
3467
+ )
3468
+ up_reg_genes = (
3469
+ data[(data["color"] == colors[2]) & (data[x].abs() > thr_x) & (data[y] > thr_y)]
3470
+ .sort_values(by=sort_by_x_y, ascending=ascending_down)
3471
+ .head(top_genes[1])
3472
+ )
3306
3473
  sele_gene = pd.concat([down_reg_genes, up_reg_genes])
3307
-
3308
- palette = {colors[0]: colors[0],
3309
- colors[1]: colors[1],
3310
- colors[2]: colors[2]}
3474
+
3475
+ palette = {colors[0]: colors[0], colors[1]: colors[1], colors[2]: colors[2]}
3311
3476
  # Plot setup
3312
3477
  if ax is None:
3313
3478
  ax = plt.gca()
@@ -3337,9 +3502,9 @@ def volcano(
3337
3502
  )
3338
3503
 
3339
3504
  # Add threshold lines for x and y axes
3340
- ax.axhline(y=thr_y, color="black", linestyle="--",lw=1)
3341
- ax.axvline(x=-thr_x, color="black", linestyle="--",lw=1)
3342
- ax.axvline(x=thr_x, color="black", linestyle="--",lw=1)
3505
+ ax.axhline(y=thr_y, color="black", linestyle="--", lw=1)
3506
+ ax.axvline(x=-thr_x, color="black", linestyle="--", lw=1)
3507
+ ax.axvline(x=thr_x, color="black", linestyle="--", lw=1)
3343
3508
 
3344
3509
  # Add gene labels for selected significant points
3345
3510
  if gene_col:
@@ -3349,19 +3514,28 @@ def volcano(
3349
3514
  textcolor = kws_text.pop("color", "k")
3350
3515
  fontsize = kws_text.pop("fontsize", 10)
3351
3516
  arrowstyles = [
3352
- "->","<-","<->","<|-","-|>","<|-|>",
3353
- "-","-[","-[",
3354
- "fancy","simple","wedge",
3517
+ "->",
3518
+ "<-",
3519
+ "<->",
3520
+ "<|-",
3521
+ "-|>",
3522
+ "<|-|>",
3523
+ "-",
3524
+ "-[",
3525
+ "-[",
3526
+ "fancy",
3527
+ "simple",
3528
+ "wedge",
3355
3529
  ]
3356
3530
  arrowstyle = kws_arrow.pop("style", "<|-")
3357
- arrowstyle = strcmp(arrowstyle, arrowstyles,scorer='strict')[0]
3358
- expand=kws_arrow.pop("expand",(1.05,1.1))
3531
+ arrowstyle = strcmp(arrowstyle, arrowstyles, scorer="strict")[0]
3532
+ expand = kws_arrow.pop("expand", (1.05, 1.1))
3359
3533
  arrowcolor = kws_arrow.pop("color", "0.4")
3360
3534
  arrowlinewidth = kws_arrow.pop("lw", 0.75)
3361
3535
  shrinkA = kws_arrow.pop("shrinkA", 0)
3362
3536
  shrinkB = kws_arrow.pop("shrinkB", 0)
3363
3537
  mutation_scale = kws_arrow.pop("head", 10)
3364
- arrow_fill=kws_arrow.pop("fill", False)
3538
+ arrow_fill = kws_arrow.pop("fill", False)
3365
3539
  for i in range(sele_gene.shape[0]):
3366
3540
  if isinstance(textcolor, list): # be consistant with dots's color
3367
3541
  textcolor = colors[0] if sele_gene[x].iloc[i] > 0 else colors[1]
@@ -3382,7 +3556,7 @@ def volcano(
3382
3556
  adjust_text(
3383
3557
  texts,
3384
3558
  expand=expand,
3385
- min_arrow_len=5,
3559
+ min_arrow_len=5,
3386
3560
  ax=ax,
3387
3561
  arrowprops=dict(
3388
3562
  arrowstyle=arrowstyle,
@@ -3393,7 +3567,7 @@ def volcano(
3393
3567
  shrinkB=shrinkB,
3394
3568
  mutation_scale=mutation_scale,
3395
3569
  **kws_arrow,
3396
- )
3570
+ ),
3397
3571
  )
3398
3572
 
3399
3573
  figsets(**kws_figsets)
@@ -3499,6 +3673,7 @@ def desaturate_color(color, saturation_factor=0.5):
3499
3673
  """Reduce the saturation of a color by a given factor (between 0 and 1)."""
3500
3674
  import matplotlib.colors as mcolors
3501
3675
  import colorsys
3676
+
3502
3677
  # Convert the color to RGB
3503
3678
  rgb = mcolors.to_rgb(color)
3504
3679
  # Convert RGB to HLS (Hue, Lightness, Saturation)
@@ -3508,8 +3683,19 @@ def desaturate_color(color, saturation_factor=0.5):
3508
3683
  # Convert back to RGB
3509
3684
  return colorsys.hls_to_rgb(h, l, s)
3510
3685
 
3511
- def textsets(text, fontname='Arial', fontsize=11, fontweight="normal", fontstyle="normal",
3512
- fontcolor='k', backgroundcolor=None, shadow=False, ha="center", va="center"):
3686
+
3687
+ def textsets(
3688
+ text,
3689
+ fontname="Arial",
3690
+ fontsize=11,
3691
+ fontweight="normal",
3692
+ fontstyle="normal",
3693
+ fontcolor="k",
3694
+ backgroundcolor=None,
3695
+ shadow=False,
3696
+ ha="center",
3697
+ va="center",
3698
+ ):
3513
3699
  if text: # Ensure text exists
3514
3700
  if fontname:
3515
3701
  text.set_fontname(plt_font(fontname))
@@ -3522,26 +3708,28 @@ def textsets(text, fontname='Arial', fontsize=11, fontweight="normal", fontstyle
3522
3708
  if fontcolor:
3523
3709
  text.set_color(fontcolor)
3524
3710
  if backgroundcolor:
3525
- text.set_backgroundcolor(backgroundcolor)
3711
+ text.set_backgroundcolor(backgroundcolor)
3526
3712
  text.set_horizontalalignment(ha)
3527
- text.set_verticalalignment(va)
3713
+ text.set_verticalalignment(va)
3528
3714
  if shadow:
3529
- text.set_path_effects([
3530
- matplotlib.patheffects.withStroke(linewidth=3, foreground="gray")
3531
- ])
3715
+ text.set_path_effects(
3716
+ [matplotlib.patheffects.withStroke(linewidth=3, foreground="gray")]
3717
+ )
3718
+
3719
+
3532
3720
  def venn(
3533
- lists:list,
3534
- labels:list=None,
3721
+ lists: list,
3722
+ labels: list = None,
3535
3723
  ax=None,
3536
3724
  colors=None,
3537
3725
  edgecolor=None,
3538
3726
  alpha=0.5,
3539
- saturation=.75,
3540
- linewidth=0, # default no edge
3727
+ saturation=0.75,
3728
+ linewidth=0, # default no edge
3541
3729
  linestyle="-",
3542
- fontname='Arial',
3730
+ fontname="Arial",
3543
3731
  fontsize=10,
3544
- fontcolor='k',
3732
+ fontcolor="k",
3545
3733
  fontweight="normal",
3546
3734
  fontstyle="normal",
3547
3735
  ha="center",
@@ -3550,18 +3738,18 @@ def venn(
3550
3738
  subset_fontsize=10,
3551
3739
  subset_fontweight="normal",
3552
3740
  subset_fontstyle="normal",
3553
- subset_fontcolor='k',
3741
+ subset_fontcolor="k",
3554
3742
  backgroundcolor=None,
3555
3743
  custom_texts=None,
3556
- show_percentages=True, # display percentage
3744
+ show_percentages=True, # display percentage
3557
3745
  fmt="{:.1%}",
3558
3746
  ellipse_shape=False, # 椭圆形
3559
- ellipse_scale=[1.5, 1], #not perfect, 椭圆形的形状
3560
- **kwargs
3747
+ ellipse_scale=[1.5, 1], # not perfect, 椭圆形的形状
3748
+ **kwargs,
3561
3749
  ):
3562
3750
  """
3563
3751
  Advanced Venn diagram plotting function with extensive customization options.
3564
- Usage:
3752
+ Usage:
3565
3753
  # Define the two sets
3566
3754
  set1 = [1, 2, 3, 4, 5]
3567
3755
  set2 = [4, 5, 6, 7, 8]
@@ -3612,56 +3800,70 @@ def venn(
3612
3800
  """
3613
3801
  if ax is None:
3614
3802
  ax = plt.gca()
3615
- lists=[set(flatten(i, verbose=False)) for i in lists]
3803
+ lists = [set(flatten(i, verbose=False)) for i in lists]
3616
3804
  # Function to apply text styles to labels
3617
3805
  if colors is None:
3618
- colors=["r","b"] if len(lists)==2 else ["r","g","b"]
3806
+ colors = ["r", "b"] if len(lists) == 2 else ["r", "g", "b"]
3619
3807
  if labels is None:
3620
- labels=["set1","set2"] if len(lists)==2 else ["set1","set2","set3"]
3808
+ labels = ["set1", "set2"] if len(lists) == 2 else ["set1", "set2", "set3"]
3621
3809
  if edgecolor is None:
3622
- edgecolor=colors
3810
+ edgecolor = colors
3623
3811
  colors = [desaturate_color(color, saturation) for color in colors]
3624
3812
  # Check colors and auto-calculate overlaps
3625
3813
  if len(lists) == 2:
3814
+
3626
3815
  def get_count_and_percentage(set_count, subset_count):
3627
3816
  percent = subset_count / set_count if set_count > 0 else 0
3628
- return f"{subset_count}\n({fmt.format(percent)})" if show_percentages else f"{subset_count}"
3817
+ return (
3818
+ f"{subset_count}\n({fmt.format(percent)})"
3819
+ if show_percentages
3820
+ else f"{subset_count}"
3821
+ )
3629
3822
 
3630
3823
  from matplotlib_venn import venn2, venn2_circles
3824
+
3631
3825
  # Auto-calculate overlap color for 2-set Venn diagram
3632
3826
  overlap_color = get_color_overlap(colors[0], colors[1]) if colors else None
3633
-
3827
+
3634
3828
  # Draw the venn diagram
3635
3829
  v = venn2(subsets=lists, set_labels=labels, ax=ax, **kwargs)
3636
3830
  venn_circles = venn2_circles(subsets=lists, ax=ax)
3637
- set1,set2=lists[0],lists[1]
3638
- v.get_patch_by_id('10').set_color(colors[0])
3639
- v.get_patch_by_id('01').set_color(colors[1])
3640
- v.get_patch_by_id('11').set_color(get_color_overlap(colors[0], colors[1]) if colors else None)
3831
+ set1, set2 = lists[0], lists[1]
3832
+ v.get_patch_by_id("10").set_color(colors[0])
3833
+ v.get_patch_by_id("01").set_color(colors[1])
3834
+ v.get_patch_by_id("11").set_color(
3835
+ get_color_overlap(colors[0], colors[1]) if colors else None
3836
+ )
3641
3837
  # v.get_label_by_id('10').set_text(len(set1 - set2))
3642
3838
  # v.get_label_by_id('01').set_text(len(set2 - set1))
3643
3839
  # v.get_label_by_id('11').set_text(len(set1 & set2))
3644
- v.get_label_by_id('10').set_text(get_count_and_percentage(len(set1), len(set1 - set2)))
3645
- v.get_label_by_id('01').set_text(get_count_and_percentage(len(set2), len(set2 - set1)))
3646
- v.get_label_by_id('11').set_text(get_count_and_percentage(len(set1 | set2), len(set1 & set2)))
3647
-
3840
+ v.get_label_by_id("10").set_text(
3841
+ get_count_and_percentage(len(set1), len(set1 - set2))
3842
+ )
3843
+ v.get_label_by_id("01").set_text(
3844
+ get_count_and_percentage(len(set2), len(set2 - set1))
3845
+ )
3846
+ v.get_label_by_id("11").set_text(
3847
+ get_count_and_percentage(len(set1 | set2), len(set1 & set2))
3848
+ )
3648
3849
 
3649
- if not isinstance(linewidth,list):
3650
- linewidth=[linewidth]
3651
- if isinstance(linestyle,str):
3652
- linestyle=[linestyle]
3850
+ if not isinstance(linewidth, list):
3851
+ linewidth = [linewidth]
3852
+ if isinstance(linestyle, str):
3853
+ linestyle = [linestyle]
3653
3854
  if not isinstance(edgecolor, list):
3654
- edgecolor=[edgecolor]
3655
- linewidth=linewidth*2 if len(linewidth)==1 else linewidth
3656
- linestyle=linestyle*2 if len(linestyle)==1 else linestyle
3657
- edgecolor=edgecolor*2 if len(edgecolor)==1 else edgecolor
3855
+ edgecolor = [edgecolor]
3856
+ linewidth = linewidth * 2 if len(linewidth) == 1 else linewidth
3857
+ linestyle = linestyle * 2 if len(linestyle) == 1 else linestyle
3858
+ edgecolor = edgecolor * 2 if len(edgecolor) == 1 else edgecolor
3658
3859
  for i in range(2):
3659
3860
  venn_circles[i].set_lw(linewidth[i])
3660
3861
  venn_circles[i].set_ls(linestyle[i])
3661
3862
  venn_circles[i].set_edgecolor(edgecolor[i])
3662
3863
  # 椭圆
3663
3864
  if ellipse_shape:
3664
- import matplotlib.patches as patches
3865
+ import matplotlib.patches as patches
3866
+
3665
3867
  for patch in v.patches:
3666
3868
  patch.set_visible(False) # Hide original patches if using ellipses
3667
3869
  center1 = v.get_circle_center(0)
@@ -3672,9 +3874,13 @@ def venn(
3672
3874
  height=ellipse_scale[1],
3673
3875
  edgecolor=edgecolor[0] if edgecolor else colors[0],
3674
3876
  facecolor=colors[0],
3675
- lw=linewidth if isinstance(linewidth, (int, float)) else 1.0, # Ensure lw is a number
3877
+ lw=(
3878
+ linewidth if isinstance(linewidth, (int, float)) else 1.0
3879
+ ), # Ensure lw is a number
3676
3880
  ls=linestyle[0],
3677
- alpha=alpha if isinstance(alpha, (int, float)) else 0.5 # Ensure alpha is a number
3881
+ alpha=(
3882
+ alpha if isinstance(alpha, (int, float)) else 0.5
3883
+ ), # Ensure alpha is a number
3678
3884
  )
3679
3885
  ellipse2 = patches.Ellipse(
3680
3886
  (center2.x, center2.y),
@@ -3682,48 +3888,78 @@ def venn(
3682
3888
  height=ellipse_scale[1],
3683
3889
  edgecolor=edgecolor[1] if edgecolor else colors[1],
3684
3890
  facecolor=colors[1],
3685
- lw=linewidth if isinstance(linewidth, (int, float)) else 1.0, # Ensure lw is a number
3891
+ lw=(
3892
+ linewidth if isinstance(linewidth, (int, float)) else 1.0
3893
+ ), # Ensure lw is a number
3686
3894
  ls=linestyle[0],
3687
- alpha=alpha if isinstance(alpha, (int, float)) else 0.5 # Ensure alpha is a number
3895
+ alpha=(
3896
+ alpha if isinstance(alpha, (int, float)) else 0.5
3897
+ ), # Ensure alpha is a number
3688
3898
  )
3689
3899
  ax.add_patch(ellipse1)
3690
3900
  ax.add_patch(ellipse2)
3691
3901
  # Apply styles to set labels
3692
3902
  for i, text in enumerate(v.set_labels):
3693
- textsets(text, fontname=fontname, fontsize=fontsize, fontweight=fontweight, fontstyle=fontstyle,
3694
- fontcolor=fontcolor,ha=ha,va=va,shadow=shadow)
3903
+ textsets(
3904
+ text,
3905
+ fontname=fontname,
3906
+ fontsize=fontsize,
3907
+ fontweight=fontweight,
3908
+ fontstyle=fontstyle,
3909
+ fontcolor=fontcolor,
3910
+ ha=ha,
3911
+ va=va,
3912
+ shadow=shadow,
3913
+ )
3695
3914
 
3696
3915
  # Apply styles to subset labels
3697
3916
  for i, text in enumerate(v.subset_labels):
3698
3917
  if text: # Ensure text exists
3699
3918
  if custom_texts: # Custom text handling
3700
- text.set_text(custom_texts[i])
3701
- textsets(text, fontname=fontname, fontsize=subset_fontsize, fontweight=subset_fontweight, fontstyle=subset_fontstyle,
3702
- fontcolor=subset_fontcolor,ha=ha,va=va,shadow=shadow)
3919
+ text.set_text(custom_texts[i])
3920
+ textsets(
3921
+ text,
3922
+ fontname=fontname,
3923
+ fontsize=subset_fontsize,
3924
+ fontweight=subset_fontweight,
3925
+ fontstyle=subset_fontstyle,
3926
+ fontcolor=subset_fontcolor,
3927
+ ha=ha,
3928
+ va=va,
3929
+ shadow=shadow,
3930
+ )
3703
3931
 
3704
3932
  elif len(lists) == 3:
3933
+
3705
3934
  def get_label(set_count, subset_count):
3706
3935
  percent = subset_count / set_count if set_count > 0 else 0
3707
- return f"{subset_count}\n({fmt.format(percent)})" if show_percentages else f"{subset_count}"
3708
-
3936
+ return (
3937
+ f"{subset_count}\n({fmt.format(percent)})"
3938
+ if show_percentages
3939
+ else f"{subset_count}"
3940
+ )
3941
+
3709
3942
  from matplotlib_venn import venn3, venn3_circles
3943
+
3710
3944
  # Auto-calculate overlap colors for 3-set Venn diagram
3711
3945
  colorAB = get_color_overlap(colors[0], colors[1]) if colors else None
3712
3946
  colorAC = get_color_overlap(colors[0], colors[2]) if colors else None
3713
3947
  colorBC = get_color_overlap(colors[1], colors[2]) if colors else None
3714
- colorABC = get_color_overlap(colors[0], colors[1], colors[2]) if colors else None
3715
- set1,set2,set3=lists[0],lists[1],lists[2]
3948
+ colorABC = (
3949
+ get_color_overlap(colors[0], colors[1], colors[2]) if colors else None
3950
+ )
3951
+ set1, set2, set3 = lists[0], lists[1], lists[2]
3716
3952
 
3717
3953
  # Draw the venn diagram
3718
- v = venn3(subsets=lists, set_labels=labels, ax=ax,**kwargs)
3719
- v.get_patch_by_id('100').set_color(colors[0])
3720
- v.get_patch_by_id('010').set_color(colors[1])
3721
- v.get_patch_by_id('001').set_color(colors[2])
3722
- v.get_patch_by_id('110').set_color(colorAB)
3723
- v.get_patch_by_id('101').set_color(colorAC)
3724
- v.get_patch_by_id('011').set_color(colorBC)
3725
- v.get_patch_by_id('111').set_color(colorABC)
3726
-
3954
+ v = venn3(subsets=lists, set_labels=labels, ax=ax, **kwargs)
3955
+ v.get_patch_by_id("100").set_color(colors[0])
3956
+ v.get_patch_by_id("010").set_color(colors[1])
3957
+ v.get_patch_by_id("001").set_color(colors[2])
3958
+ v.get_patch_by_id("110").set_color(colorAB)
3959
+ v.get_patch_by_id("101").set_color(colorAC)
3960
+ v.get_patch_by_id("011").set_color(colorBC)
3961
+ v.get_patch_by_id("111").set_color(colorABC)
3962
+
3727
3963
  # Correctly labeling subset sizes
3728
3964
  # v.get_label_by_id('100').set_text(len(set1 - set2 - set3))
3729
3965
  # v.get_label_by_id('010').set_text(len(set2 - set1 - set3))
@@ -3732,63 +3968,93 @@ def venn(
3732
3968
  # v.get_label_by_id('101').set_text(len(set1 & set3 - set2))
3733
3969
  # v.get_label_by_id('011').set_text(len(set2 & set3 - set1))
3734
3970
  # v.get_label_by_id('111').set_text(len(set1 & set2 & set3))
3735
- v.get_label_by_id('100').set_text(get_label(len(set1), len(set1 - set2 - set3)))
3736
- v.get_label_by_id('010').set_text(get_label(len(set2), len(set2 - set1 - set3)))
3737
- v.get_label_by_id('001').set_text(get_label(len(set3), len(set3 - set1 - set2)))
3738
- v.get_label_by_id('110').set_text(get_label(len(set1 | set2), len(set1 & set2 - set3)))
3739
- v.get_label_by_id('101').set_text(get_label(len(set1 | set3), len(set1 & set3 - set2)))
3740
- v.get_label_by_id('011').set_text(get_label(len(set2 | set3), len(set2 & set3 - set1)))
3741
- v.get_label_by_id('111').set_text(get_label(len(set1 | set2 | set3), len(set1 & set2 & set3)))
3742
-
3971
+ v.get_label_by_id("100").set_text(get_label(len(set1), len(set1 - set2 - set3)))
3972
+ v.get_label_by_id("010").set_text(get_label(len(set2), len(set2 - set1 - set3)))
3973
+ v.get_label_by_id("001").set_text(get_label(len(set3), len(set3 - set1 - set2)))
3974
+ v.get_label_by_id("110").set_text(
3975
+ get_label(len(set1 | set2), len(set1 & set2 - set3))
3976
+ )
3977
+ v.get_label_by_id("101").set_text(
3978
+ get_label(len(set1 | set3), len(set1 & set3 - set2))
3979
+ )
3980
+ v.get_label_by_id("011").set_text(
3981
+ get_label(len(set2 | set3), len(set2 & set3 - set1))
3982
+ )
3983
+ v.get_label_by_id("111").set_text(
3984
+ get_label(len(set1 | set2 | set3), len(set1 & set2 & set3))
3985
+ )
3743
3986
 
3744
3987
  # Apply styles to set labels
3745
3988
  for i, text in enumerate(v.set_labels):
3746
- textsets(text, fontname=fontname, fontsize=fontsize, fontweight=fontweight, fontstyle=fontstyle,
3747
- fontcolor=fontcolor,ha=ha,va=va,shadow=shadow)
3989
+ textsets(
3990
+ text,
3991
+ fontname=fontname,
3992
+ fontsize=fontsize,
3993
+ fontweight=fontweight,
3994
+ fontstyle=fontstyle,
3995
+ fontcolor=fontcolor,
3996
+ ha=ha,
3997
+ va=va,
3998
+ shadow=shadow,
3999
+ )
3748
4000
 
3749
4001
  # Apply styles to subset labels
3750
4002
  for i, text in enumerate(v.subset_labels):
3751
4003
  if text: # Ensure text exists
3752
4004
  if custom_texts: # Custom text handling
3753
4005
  text.set_text(custom_texts[i])
3754
- textsets(text, fontname=fontname, fontsize=subset_fontsize, fontweight=subset_fontweight, fontstyle=subset_fontstyle,
3755
- fontcolor=subset_fontcolor,ha=ha,va=va,shadow=shadow)
4006
+ textsets(
4007
+ text,
4008
+ fontname=fontname,
4009
+ fontsize=subset_fontsize,
4010
+ fontweight=subset_fontweight,
4011
+ fontstyle=subset_fontstyle,
4012
+ fontcolor=subset_fontcolor,
4013
+ ha=ha,
4014
+ va=va,
4015
+ shadow=shadow,
4016
+ )
3756
4017
 
3757
4018
  venn_circles = venn3_circles(subsets=lists, ax=ax)
3758
- if not isinstance(linewidth,list):
3759
- linewidth=[linewidth]
3760
- if isinstance(linestyle,str):
3761
- linestyle=[linestyle]
4019
+ if not isinstance(linewidth, list):
4020
+ linewidth = [linewidth]
4021
+ if isinstance(linestyle, str):
4022
+ linestyle = [linestyle]
3762
4023
  if not isinstance(edgecolor, list):
3763
- edgecolor=[edgecolor]
3764
- linewidth=linewidth*3 if len(linewidth)==1 else linewidth
3765
- linestyle=linestyle*3 if len(linestyle)==1 else linestyle
3766
- edgecolor=edgecolor*3 if len(edgecolor)==1 else edgecolor
4024
+ edgecolor = [edgecolor]
4025
+ linewidth = linewidth * 3 if len(linewidth) == 1 else linewidth
4026
+ linestyle = linestyle * 3 if len(linestyle) == 1 else linestyle
4027
+ edgecolor = edgecolor * 3 if len(edgecolor) == 1 else edgecolor
3767
4028
 
3768
4029
  # edgecolor=[to_rgba(i) for i in edgecolor]
3769
4030
 
3770
- for i in range(3):
4031
+ for i in range(3):
3771
4032
  venn_circles[i].set_lw(linewidth[i])
3772
- venn_circles[i].set_ls(linestyle[i])
3773
- venn_circles[i].set_edgecolor(edgecolor[i])
4033
+ venn_circles[i].set_ls(linestyle[i])
4034
+ venn_circles[i].set_edgecolor(edgecolor[i])
3774
4035
 
3775
- #椭圆形
4036
+ # 椭圆形
3776
4037
  if ellipse_shape:
3777
- import matplotlib.patches as patches
4038
+ import matplotlib.patches as patches
4039
+
3778
4040
  for patch in v.patches:
3779
4041
  patch.set_visible(False) # Hide original patches if using ellipses
3780
4042
  center1 = v.get_circle_center(0)
3781
4043
  center2 = v.get_circle_center(1)
3782
- center3 = v.get_circle_center(2)
4044
+ center3 = v.get_circle_center(2)
3783
4045
  ellipse1 = patches.Ellipse(
3784
4046
  (center1.x, center1.y),
3785
4047
  width=ellipse_scale[0],
3786
4048
  height=ellipse_scale[1],
3787
4049
  edgecolor=edgecolor[0] if edgecolor else colors[0],
3788
4050
  facecolor=colors[0],
3789
- lw=linewidth if isinstance(linewidth, (int, float)) else 1.0, # Ensure lw is a number
4051
+ lw=(
4052
+ linewidth if isinstance(linewidth, (int, float)) else 1.0
4053
+ ), # Ensure lw is a number
3790
4054
  ls=linestyle[0],
3791
- alpha=alpha if isinstance(alpha, (int, float)) else 0.5 # Ensure alpha is a number
4055
+ alpha=(
4056
+ alpha if isinstance(alpha, (int, float)) else 0.5
4057
+ ), # Ensure alpha is a number
3792
4058
  )
3793
4059
  ellipse2 = patches.Ellipse(
3794
4060
  (center2.x, center2.y),
@@ -3796,9 +4062,13 @@ def venn(
3796
4062
  height=ellipse_scale[1],
3797
4063
  edgecolor=edgecolor[1] if edgecolor else colors[1],
3798
4064
  facecolor=colors[1],
3799
- lw=linewidth if isinstance(linewidth, (int, float)) else 1.0, # Ensure lw is a number
4065
+ lw=(
4066
+ linewidth if isinstance(linewidth, (int, float)) else 1.0
4067
+ ), # Ensure lw is a number
3800
4068
  ls=linestyle[0],
3801
- alpha=alpha if isinstance(alpha, (int, float)) else 0.5 # Ensure alpha is a number
4069
+ alpha=(
4070
+ alpha if isinstance(alpha, (int, float)) else 0.5
4071
+ ), # Ensure alpha is a number
3802
4072
  )
3803
4073
  ellipse3 = patches.Ellipse(
3804
4074
  (center3.x, center3.y),
@@ -3806,9 +4076,13 @@ def venn(
3806
4076
  height=ellipse_scale[1],
3807
4077
  edgecolor=edgecolor[1] if edgecolor else colors[1],
3808
4078
  facecolor=colors[1],
3809
- lw=linewidth if isinstance(linewidth, (int, float)) else 1.0, # Ensure lw is a number
4079
+ lw=(
4080
+ linewidth if isinstance(linewidth, (int, float)) else 1.0
4081
+ ), # Ensure lw is a number
3810
4082
  ls=linestyle[0],
3811
- alpha=alpha if isinstance(alpha, (int, float)) else 0.5 # Ensure alpha is a number
4083
+ alpha=(
4084
+ alpha if isinstance(alpha, (int, float)) else 0.5
4085
+ ), # Ensure alpha is a number
3812
4086
  )
3813
4087
  ax.add_patch(ellipse1)
3814
4088
  ax.add_patch(ellipse2)
@@ -3820,17 +4094,20 @@ def venn(
3820
4094
  for patch in v.patches:
3821
4095
  if patch:
3822
4096
  patch.set_alpha(alpha)
3823
- if 'none' in edgecolor or 0 in linewidth:
4097
+ if "none" in edgecolor or 0 in linewidth:
3824
4098
  patch.set_edgecolor("none")
3825
4099
  return ax
3826
4100
 
4101
+
3827
4102
  #! subplots, support automatic extend new axis
3828
- def subplot(rows:int=2,
3829
- cols:int=2,
3830
- figsize:Union[tuple,list]=[8, 8],
3831
- sharex=False,
3832
- sharey=False,
3833
- **kwargs):
4103
+ def subplot(
4104
+ rows: int = 2,
4105
+ cols: int = 2,
4106
+ figsize: Union[tuple, list] = [8, 8],
4107
+ sharex=False,
4108
+ sharey=False,
4109
+ **kwargs,
4110
+ ):
3834
4111
  """
3835
4112
  nexttile = subplot(
3836
4113
  8,
@@ -3850,11 +4127,14 @@ def subplot(rows:int=2,
3850
4127
  ax.set_xlabel(f"Tile {i + 1}")
3851
4128
  """
3852
4129
  from matplotlib.gridspec import GridSpec
4130
+
3853
4131
  if run_once_within():
3854
- print(f"usage:\n\tnexttile = subplot(2, 2, figsize=(5, 5), sharex=True, sharey=True)\n\tax = nexttile()")
4132
+ print(
4133
+ f"usage:\n\tnexttile = subplot(2, 2, figsize=(5, 5), sharex=True, sharey=True)\n\tax = nexttile()"
4134
+ )
3855
4135
  fig = plt.figure(figsize=figsize)
3856
4136
  grid_spec = GridSpec(rows, cols, figure=fig)
3857
- occupied = set()
4137
+ occupied = set()
3858
4138
  row_first_axes = [None] * rows # Track the first axis in each row (for sharey)
3859
4139
  col_first_axes = [None] * cols # Track the first axis in each column (for sharex)
3860
4140
 
@@ -3862,8 +4142,9 @@ def subplot(rows:int=2,
3862
4142
  nonlocal rows, grid_spec
3863
4143
  rows += 1 # Expands by adding a row
3864
4144
  grid_spec = GridSpec(rows, cols, figure=fig)
4145
+
3865
4146
  def nexttile(rowspan=1, colspan=1, **kwargs):
3866
- nonlocal rows, cols, occupied, grid_spec
4147
+ nonlocal rows, cols, occupied, grid_spec
3867
4148
  for row in range(rows):
3868
4149
  for col in range(cols):
3869
4150
  if all(
@@ -3873,29 +4154,29 @@ def subplot(rows:int=2,
3873
4154
  ):
3874
4155
  break
3875
4156
  else:
3876
- continue
4157
+ continue
3877
4158
  break
3878
- else:
4159
+ else:
3879
4160
  expand_ax()
3880
4161
  return nexttile(rowspan=rowspan, colspan=colspan, **kwargs)
3881
4162
 
3882
- sharex_ax,sharey_ax = None, None
4163
+ sharex_ax, sharey_ax = None, None
3883
4164
 
3884
- if sharex:
4165
+ if sharex:
3885
4166
  sharex_ax = col_first_axes[col]
3886
4167
 
3887
- if sharey:
3888
- sharey_ax = row_first_axes[row]
4168
+ if sharey:
4169
+ sharey_ax = row_first_axes[row]
3889
4170
  ax = fig.add_subplot(
3890
4171
  grid_spec[row : row + rowspan, col : col + colspan],
3891
4172
  sharex=sharex_ax,
3892
4173
  sharey=sharey_ax,
3893
- **kwargs
3894
- )
4174
+ **kwargs,
4175
+ )
3895
4176
  if row_first_axes[row] is None:
3896
4177
  row_first_axes[row] = ax
3897
4178
  if col_first_axes[col] is None:
3898
- col_first_axes[col] = ax
4179
+ col_first_axes[col] = ax
3899
4180
  for r in range(row, row + rowspan):
3900
4181
  for c in range(col, col + colspan):
3901
4182
  occupied.add((r, c))
@@ -3907,17 +4188,17 @@ def subplot(rows:int=2,
3907
4188
 
3908
4189
  #! radar chart
3909
4190
  def radar(
3910
- data: pd.DataFrame,
3911
- ylim=(0,100),
4191
+ data: pd.DataFrame,
4192
+ ylim=(0, 100),
3912
4193
  color=get_color(5),
3913
4194
  fontsize=10,
3914
- fontcolor='k',
4195
+ fontcolor="k",
3915
4196
  size=6,
3916
4197
  linewidth=1,
3917
4198
  linestyle="-",
3918
4199
  alpha=0.5,
3919
4200
  marker="o",
3920
- edgecolor='none',
4201
+ edgecolor="none",
3921
4202
  edge_linewidth=0,
3922
4203
  bg_color="0.8",
3923
4204
  bg_alpha=None,
@@ -3928,18 +4209,19 @@ def radar(
3928
4209
  legend_fontsize=10,
3929
4210
  grid_color="gray",
3930
4211
  grid_alpha=0.5,
3931
- grid_linestyle="--",grid_linewidth=0.5,
4212
+ grid_linestyle="--",
4213
+ grid_linewidth=0.5,
3932
4214
  circular: bool = False,
3933
4215
  tick_fontsize=None,
3934
4216
  tick_fontcolor="0.65",
3935
- tick_loc = None,# label position
3936
- turning = None,
4217
+ tick_loc=None, # label position
4218
+ turning=None,
3937
4219
  ax=None,
3938
4220
  sp=2,
3939
- **kwargs
4221
+ **kwargs,
3940
4222
  ):
3941
4223
  """
3942
- Example DATA:
4224
+ Example DATA:
3943
4225
  df = pd.DataFrame(
3944
4226
  data=[
3945
4227
  [80, 80, 80, 80, 80, 80, 80],
@@ -3948,7 +4230,7 @@ def radar(
3948
4230
  ],
3949
4231
  index=["Hero", "Warrior", "Wizard"],
3950
4232
  columns=["HP", "MP", "ATK", "DEF", "SP.ATK", "SP.DEF", "SPD"])
3951
-
4233
+
3952
4234
  Parameters:
3953
4235
  - data (pd.DataFrame): The data to plot. Each column corresponds to a variable, and each row represents a data point.
3954
4236
  - ylim (tuple): The limits of the radial axis (y-axis). Default is (0, 100).
@@ -4002,8 +4284,12 @@ def radar(
4002
4284
 
4003
4285
  # bg_color
4004
4286
  if bg_alpha is None:
4005
- bg_alpha=alpha
4006
- ax.set_facecolor(to_rgba(bg_color,alpha=bg_alpha)) if circular else ax.set_facecolor('none')
4287
+ bg_alpha = alpha
4288
+ (
4289
+ ax.set_facecolor(to_rgba(bg_color, alpha=bg_alpha))
4290
+ if circular
4291
+ else ax.set_facecolor("none")
4292
+ )
4007
4293
  # Set up the radar chart with straight-line connections
4008
4294
  ax.set_theta_offset(np.pi / 2)
4009
4295
  ax.set_theta_direction(-1)
@@ -4012,53 +4298,68 @@ def radar(
4012
4298
  ax.set_xticks(angles[:-1])
4013
4299
  ax.set_xticklabels(categories)
4014
4300
 
4015
- # Set y-axis limits and grid intervals
4301
+ # Set y-axis limits and grid intervals
4016
4302
  vmin, vmax = ylim
4017
- if circular:
4018
- #* cicular style
4019
- ax.yaxis.set_ticks(np.arange(vmin, vmax+1, vmax * grid_interval_ratio))
4020
- ax.grid(axis='both',
4021
- color=grid_color,
4022
- linestyle=grid_linestyle,
4023
- alpha=grid_alpha,
4024
- linewidth=grid_linewidth,
4025
- dash_capstyle='round',
4026
- dash_joinstyle='round',
4027
- )
4028
- ax.spines["polar"].set_color(grid_color)
4303
+ if circular:
4304
+ # * cicular style
4305
+ ax.yaxis.set_ticks(np.arange(vmin, vmax + 1, vmax * grid_interval_ratio))
4306
+ ax.grid(
4307
+ axis="both",
4308
+ color=grid_color,
4309
+ linestyle=grid_linestyle,
4310
+ alpha=grid_alpha,
4311
+ linewidth=grid_linewidth,
4312
+ dash_capstyle="round",
4313
+ dash_joinstyle="round",
4314
+ )
4315
+ ax.spines["polar"].set_color(grid_color)
4029
4316
  ax.spines["polar"].set_linewidth(grid_linewidth)
4030
- ax.spines["polar"].set_linestyle('-')
4317
+ ax.spines["polar"].set_linestyle("-")
4031
4318
  ax.spines["polar"].set_alpha(grid_alpha)
4032
- ax.spines["polar"].set_capstyle('round')
4033
- ax.spines["polar"].set_joinstyle('round')
4319
+ ax.spines["polar"].set_capstyle("round")
4320
+ ax.spines["polar"].set_joinstyle("round")
4034
4321
 
4035
4322
  else:
4036
- #* spider style: spider-style grid (straight lines, not circles)
4323
+ # * spider style: spider-style grid (straight lines, not circles)
4037
4324
  # Create the spider-style grid (straight lines, not circles)
4038
- for i in range(1, int(vmax * grid_interval_ratio) + 1):
4325
+ for i in range(1, int(vmax * grid_interval_ratio) + 1):
4039
4326
  ax.plot(
4040
4327
  angles + [angles[0]], # Closing the loop
4041
- [i * vmax * grid_interval_ratio] * (num_vars+1) + [i * vmax * grid_interval_ratio],
4042
- color=grid_color, linestyle=grid_linestyle, alpha=grid_alpha,linewidth=grid_linewidth
4328
+ [i * vmax * grid_interval_ratio] * (num_vars + 1)
4329
+ + [i * vmax * grid_interval_ratio],
4330
+ color=grid_color,
4331
+ linestyle=grid_linestyle,
4332
+ alpha=grid_alpha,
4333
+ linewidth=grid_linewidth,
4043
4334
  )
4044
- # set bg_color
4045
- ax.fill(angles, [vmax]*(data.shape[1]+1), color=bg_color, alpha=bg_alpha)
4335
+ # set bg_color
4336
+ ax.fill(angles, [vmax] * (data.shape[1] + 1), color=bg_color, alpha=bg_alpha)
4046
4337
  ax.yaxis.grid(False)
4047
4338
  # Move radial labels away from plotted line
4048
4339
  if tick_loc is None:
4049
- tick_loc = np.mean([angles[0],angles[1]])/(2*np.pi)*360 if circular else 0
4340
+ tick_loc = (
4341
+ np.mean([angles[0], angles[1]]) / (2 * np.pi) * 360 if circular else 0
4342
+ )
4050
4343
 
4051
4344
  ax.set_rlabel_position(tick_loc)
4052
4345
  ax.set_theta_offset(turning) if turning is not None else None
4053
- ax.tick_params(axis='x', labelsize=fontsize, colors=fontcolor) # Optional: for angular labels
4054
- tick_fontsize = fontsize-2 if fontsize is None else tick_fontsize
4055
- ax.tick_params(axis='y', labelsize=tick_fontsize, colors=tick_fontcolor) # For radial labels
4346
+ ax.tick_params(
4347
+ axis="x", labelsize=fontsize, colors=fontcolor
4348
+ ) # Optional: for angular labels
4349
+ tick_fontsize = fontsize - 2 if fontsize is None else tick_fontsize
4350
+ ax.tick_params(
4351
+ axis="y", labelsize=tick_fontsize, colors=tick_fontcolor
4352
+ ) # For radial labels
4056
4353
  if not circular:
4057
- ax.spines['polar'].set_visible(False)
4058
- ax.tick_params(axis='x', pad=sp) # move spines outward
4059
- ax.tick_params(axis='y', pad=sp) # move spines outward
4354
+ ax.spines["polar"].set_visible(False)
4355
+ ax.tick_params(axis="x", pad=sp) # move spines outward
4356
+ ax.tick_params(axis="y", pad=sp) # move spines outward
4060
4357
  # colors
4061
- colors = get_color(data.shape[0]) if cmap is None else plt.get_cmap(cmap)(np.linspace(0, 1, data.shape[0]))
4358
+ colors = (
4359
+ get_color(data.shape[0])
4360
+ if cmap is None
4361
+ else plt.get_cmap(cmap)(np.linspace(0, 1, data.shape[0]))
4362
+ )
4062
4363
  # Plot each row with straight lines
4063
4364
  for i, (index, row) in enumerate(data.iterrows()):
4064
4365
  values = row.tolist()
@@ -4070,7 +4371,7 @@ def radar(
4070
4371
  linewidth=linewidth,
4071
4372
  linestyle=linestyle,
4072
4373
  label=index,
4073
- clip_on=False
4374
+ clip_on=False,
4074
4375
  )
4075
4376
  ax.fill(angles, values, color=colors[i], alpha=alpha)
4076
4377
 
@@ -4084,13 +4385,23 @@ def radar(
4084
4385
  marker=marker,
4085
4386
  markersize=size,
4086
4387
  markeredgecolor=edgecolor,
4087
- markeredgewidth = edge_linewidth, zorder=10,clip_on=False
4388
+ markeredgewidth=edge_linewidth,
4389
+ zorder=10,
4390
+ clip_on=False,
4088
4391
  )
4089
4392
  # ax.tick_params(axis='y', labelleft=False, left=False)
4090
- if 'legend' in kws_figsets:
4393
+ if "legend" in kws_figsets:
4091
4394
  figsets(ax=ax, **kws_figsets)
4092
4395
  else:
4093
-
4094
- figsets(ax=ax,legend=dict(loc=legend_loc,fontsize=legend_fontsize,
4095
- bbox_to_anchor=[1.1,1.4],ncols=2),**kws_figsets)
4096
- return ax
4396
+
4397
+ figsets(
4398
+ ax=ax,
4399
+ legend=dict(
4400
+ loc=legend_loc,
4401
+ fontsize=legend_fontsize,
4402
+ bbox_to_anchor=[1.1, 1.4],
4403
+ ncols=2,
4404
+ ),
4405
+ **kws_figsets,
4406
+ )
4407
+ return ax