py2ls 0.2.4.10.4__py3-none-any.whl → 0.2.4.10.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
py2ls/plot.py CHANGED
@@ -1,21 +1,37 @@
1
1
  import numpy as np
2
- import pandas as pd
2
+ import pandas as pd
3
3
  import matplotlib.pyplot as plt
4
4
  import matplotlib
5
5
  import seaborn as sns
6
6
  import warnings
7
7
  import logging
8
8
  from typing import Union
9
- from .ips import isa, fsave, fload, mkdir, listdir, figsave, strcmp, unique, get_os, ssplit,flatten,plt_font,run_once_within
9
+ from .ips import (
10
+ isa,
11
+ fsave,
12
+ fload,
13
+ mkdir,
14
+ listdir,
15
+ figsave,
16
+ strcmp,
17
+ unique,
18
+ get_os,
19
+ ssplit,
20
+ flatten,
21
+ plt_font,
22
+ run_once_within,
23
+ )
10
24
  from .stats import *
11
25
  import os
26
+
12
27
  # Suppress INFO messages from fontTools
13
28
  logging.getLogger("fontTools").setLevel(logging.ERROR)
14
- logging.getLogger('matplotlib').setLevel(logging.ERROR)
29
+ logging.getLogger("matplotlib").setLevel(logging.ERROR)
15
30
 
16
31
  warnings.simplefilter("ignore", category=pd.errors.SettingWithCopyWarning)
17
32
  warnings.filterwarnings("ignore", category=pd.errors.PerformanceWarning)
18
33
 
34
+
19
35
  def add_text(ax=None, height_offset=0.5, fmt=".1f", **kwargs):
20
36
  """Adds text annotations for various types of Seaborn and Matplotlib plots.
21
37
  Args:
@@ -449,7 +465,7 @@ def catplot(data, *args, **kwargs):
449
465
  """
450
466
  from matplotlib.colors import to_rgba
451
467
  import os
452
-
468
+
453
469
  def plot_bars(data, data_m, opt_b, xloc, ax, label=None):
454
470
  if "l" in opt_b["loc"]:
455
471
  xloc_s = xloc - opt_b["x_dist"]
@@ -755,6 +771,7 @@ def catplot(data, *args, **kwargs):
755
771
  )
756
772
  else:
757
773
  from scipy.stats import gaussian_kde
774
+
758
775
  kde = gaussian_kde(ys, bw_method=opt_v["BandWidth"])
759
776
  min_val, max_val = ys.min(), ys.max()
760
777
  y_vals = np.linspace(min_val, max_val, opt_v["NumPoints"])
@@ -1812,16 +1829,17 @@ def read_mplstyle(style_file):
1812
1829
  return style_dict
1813
1830
 
1814
1831
 
1815
- def figsets(*args, **kwargs):
1832
+ def figsets(*args, **kwargs):
1816
1833
  import matplotlib
1817
1834
  from cycler import cycler
1835
+
1818
1836
  matplotlib.rc("text", usetex=False)
1819
1837
 
1820
1838
  fig = plt.gcf()
1821
- fontsize = kwargs.get("fontsize",11)
1822
- plt.rcParams["font.size"]=fontsize
1823
- fontname = kwargs.pop("fontname","Arial")
1824
- fontname=plt_font(fontname) # 显示中文
1839
+ fontsize = kwargs.get("fontsize", 11)
1840
+ plt.rcParams["font.size"] = fontsize
1841
+ fontname = kwargs.pop("fontname", "Arial")
1842
+ fontname = plt_font(fontname) # 显示中文
1825
1843
 
1826
1844
  sns_themes = ["white", "whitegrid", "dark", "darkgrid", "ticks"]
1827
1845
  sns_contexts = ["notebook", "talk", "poster"] # now available "paper"
@@ -1851,18 +1869,21 @@ def figsets(*args, **kwargs):
1851
1869
  nonlocal fontsize, fontname
1852
1870
  if ("fo" in key) and (("size" in key) or ("sz" in key)):
1853
1871
  fontsize = value
1854
- plt.rcParams.update({"font.size": fontsize,
1855
- "figure.titlesize":fontsize,
1856
- "axes.titlesize":fontsize,
1857
- "axes.labelsize": fontsize,
1858
- "xtick.labelsize": fontsize,
1859
- "ytick.labelsize": fontsize,
1860
- "legend.fontsize": fontsize,
1861
- "legend.title_fontsize":fontsize
1862
- })
1872
+ plt.rcParams.update(
1873
+ {
1874
+ "font.size": fontsize,
1875
+ "figure.titlesize": fontsize,
1876
+ "axes.titlesize": fontsize,
1877
+ "axes.labelsize": fontsize,
1878
+ "xtick.labelsize": fontsize,
1879
+ "ytick.labelsize": fontsize,
1880
+ "legend.fontsize": fontsize,
1881
+ "legend.title_fontsize": fontsize,
1882
+ }
1883
+ )
1863
1884
 
1864
1885
  # Customize tick labels
1865
- ax.tick_params(axis='both', which='major', labelsize=fontsize)
1886
+ ax.tick_params(axis="both", which="major", labelsize=fontsize)
1866
1887
  for label in ax.get_xticklabels() + ax.get_yticklabels():
1867
1888
  label.set_fontname(fontname)
1868
1889
 
@@ -1910,15 +1931,15 @@ def figsets(*args, **kwargs):
1910
1931
  if ("x" in key.lower()) and (
1911
1932
  "tic" not in key.lower() and "tk" not in key.lower()
1912
1933
  ):
1913
- ax.set_xlabel(value, fontname=fontname,fontsize=fontsize)
1934
+ ax.set_xlabel(value, fontname=fontname, fontsize=fontsize)
1914
1935
  if ("y" in key.lower()) and (
1915
1936
  "tic" not in key.lower() and "tk" not in key.lower()
1916
1937
  ):
1917
- ax.set_ylabel(value, fontname=fontname,fontsize=fontsize)
1938
+ ax.set_ylabel(value, fontname=fontname, fontsize=fontsize)
1918
1939
  if ("z" in key.lower()) and (
1919
1940
  "tic" not in key.lower() and "tk" not in key.lower()
1920
1941
  ):
1921
- ax.set_zlabel(value, fontname=fontname,fontsize=fontsize)
1942
+ ax.set_zlabel(value, fontname=fontname, fontsize=fontsize)
1922
1943
  if key == "xlabel" and isinstance(value, dict):
1923
1944
  ax.set_xlabel(**value)
1924
1945
  if key == "ylabel" and isinstance(value, dict):
@@ -2110,6 +2131,7 @@ def figsets(*args, **kwargs):
2110
2131
 
2111
2132
  if "mi" in key.lower() and "tic" in key.lower(): # minor_ticks
2112
2133
  import matplotlib.ticker as tck
2134
+
2113
2135
  if "x" in value.lower() or "x" in key.lower():
2114
2136
  ax.xaxis.set_minor_locator(tck.AutoMinorLocator()) # ax.minorticks_on()
2115
2137
  if "y" in value.lower() or "y" in key.lower():
@@ -2162,9 +2184,9 @@ def figsets(*args, **kwargs):
2162
2184
  ax.grid(visible=False)
2163
2185
  if "tit" in key.lower():
2164
2186
  if "sup" in key.lower():
2165
- plt.suptitle(value,fontname=fontname,fontsize=fontsize)
2187
+ plt.suptitle(value, fontname=fontname, fontsize=fontsize)
2166
2188
  else:
2167
- ax.set_title(value,fontname=fontname,fontsize=fontsize)
2189
+ ax.set_title(value, fontname=fontname, fontsize=fontsize)
2168
2190
  if key.lower() in ["spine", "adjust", "ad", "sp", "spi", "adj", "spines"]:
2169
2191
  if isinstance(value, bool) or (value in ["go", "do", "ja", "yes"]):
2170
2192
  if value:
@@ -2185,9 +2207,14 @@ def figsets(*args, **kwargs):
2185
2207
  legend = ax.get_legend()
2186
2208
  if legend is not None:
2187
2209
  legend.remove()
2188
- if any(['colorbar' in key.lower(),'cbar' in key.lower()]) and "loc" in key.lower():
2210
+ if (
2211
+ any(["colorbar" in key.lower(), "cbar" in key.lower()])
2212
+ and "loc" in key.lower()
2213
+ ):
2189
2214
  cbar = ax.collections[0].colorbar # Access the colorbar from the plot
2190
- cbar.ax.set_position(value) # [left, bottom, width, height] [0.475, 0.15, 0.04, 0.25]
2215
+ cbar.ax.set_position(
2216
+ value
2217
+ ) # [left, bottom, width, height] [0.475, 0.15, 0.04, 0.25]
2191
2218
 
2192
2219
  for arg in args:
2193
2220
  if isinstance(arg, matplotlib.axes._axes.Axes):
@@ -2225,7 +2252,8 @@ def figsets(*args, **kwargs):
2225
2252
  if len(fig.get_axes()) > 1:
2226
2253
  plt.tight_layout()
2227
2254
 
2228
- def split_legend(ax, n=2, loc=None, title=None, bbox=None,ncol=1, **kwargs):
2255
+
2256
+ def split_legend(ax, n=2, loc=None, title=None, bbox=None, ncol=1, **kwargs):
2229
2257
  """
2230
2258
  split_legend(
2231
2259
  ax,
@@ -2238,15 +2266,15 @@ def split_legend(ax, n=2, loc=None, title=None, bbox=None,ncol=1, **kwargs):
2238
2266
  # Retrieve all lines and labels from the axis
2239
2267
  handles, labels = ax.get_legend_handles_labels()
2240
2268
  num_labels = len(labels)
2241
-
2269
+
2242
2270
  # Calculate the number of labels per legend part
2243
- labels_per_part = (num_labels + n - 1) // n # Round up
2271
+ labels_per_part = (num_labels + n - 1) // n # Round up
2244
2272
  # Create a list to hold each legend object
2245
2273
  legends = []
2246
2274
 
2247
2275
  # Default locations and titles if not specified
2248
2276
  if loc is None:
2249
- loc = ['best'] * n
2277
+ loc = ["best"] * n
2250
2278
  if title is None:
2251
2279
  title = [None] * n
2252
2280
  if bbox is None:
@@ -2257,7 +2285,7 @@ def split_legend(ax, n=2, loc=None, title=None, bbox=None,ncol=1, **kwargs):
2257
2285
  # Calculate the range of labels for this part
2258
2286
  start_idx = i * labels_per_part
2259
2287
  end_idx = min(start_idx + labels_per_part, num_labels)
2260
-
2288
+
2261
2289
  # Skip if no labels in this range
2262
2290
  if start_idx >= end_idx:
2263
2291
  break
@@ -2267,19 +2295,24 @@ def split_legend(ax, n=2, loc=None, title=None, bbox=None,ncol=1, **kwargs):
2267
2295
  part_labels = labels[start_idx:end_idx]
2268
2296
 
2269
2297
  # Create the legend for this part
2270
- legend = ax.legend(handles=part_handles,
2271
- labels=part_labels,
2272
- loc=loc[i],
2273
- title=title[i],
2274
- ncol=ncol,
2275
- bbox_to_anchor=bbox[i],
2276
- **kwargs)
2277
-
2298
+ legend = ax.legend(
2299
+ handles=part_handles,
2300
+ labels=part_labels,
2301
+ loc=loc[i],
2302
+ title=title[i],
2303
+ ncol=ncol,
2304
+ bbox_to_anchor=bbox[i],
2305
+ **kwargs,
2306
+ )
2307
+
2278
2308
  # Add the legend to the axis and save it to the list
2279
- ax.add_artist(legend) if i !=(n-1) else None # the lastone will be added automaticaly
2309
+ (
2310
+ ax.add_artist(legend) if i != (n - 1) else None
2311
+ ) # the lastone will be added automaticaly
2280
2312
  legends.append(legend)
2281
2313
  return legends
2282
2314
 
2315
+
2283
2316
  def get_colors(
2284
2317
  n: int = 1,
2285
2318
  cmap: str = "auto",
@@ -2289,7 +2322,9 @@ def get_colors(
2289
2322
  *args,
2290
2323
  **kwargs,
2291
2324
  ):
2292
- return get_color(n,cmap,alpha,output,*args,**kwargs)
2325
+ return get_color(n, cmap, alpha, output, *args, **kwargs)
2326
+
2327
+
2293
2328
  def get_color(
2294
2329
  n: int = 1,
2295
2330
  cmap: str = "auto",
@@ -2300,6 +2335,7 @@ def get_color(
2300
2335
  **kwargs,
2301
2336
  ):
2302
2337
  from cycler import cycler
2338
+
2303
2339
  def cmap2hex(cmap_name):
2304
2340
  cmap_ = matplotlib.pyplot.get_cmap(cmap_name)
2305
2341
  colors = [cmap_(i) for i in range(cmap_.N)]
@@ -2347,49 +2383,137 @@ def get_color(
2347
2383
  return "#{:02X}{:02X}{:02X}{:02X}".format(r, g, b, a)
2348
2384
  else:
2349
2385
  return "#{:02X}{:02X}{:02X}".format(r, g, b)
2350
-
2386
+
2351
2387
  # sc.pl.palettes.default_20
2352
- cmap_20= ['#1f77b4','#ff7f0e','#279e68','#d62728','#aa40fc','#8c564b','#e377c2','#b5bd61',
2353
- '#17becf','#aec7e8','#ffbb78','#98df8a','#ff9896','#c5b0d5','#c49c94','#f7b6d2',
2354
- '#dbdb8d','#9edae5','#ad494a','#8c6d31']
2388
+ cmap_20 = [
2389
+ "#1f77b4",
2390
+ "#ff7f0e",
2391
+ "#279e68",
2392
+ "#d62728",
2393
+ "#aa40fc",
2394
+ "#8c564b",
2395
+ "#e377c2",
2396
+ "#b5bd61",
2397
+ "#17becf",
2398
+ "#aec7e8",
2399
+ "#ffbb78",
2400
+ "#98df8a",
2401
+ "#ff9896",
2402
+ "#c5b0d5",
2403
+ "#c49c94",
2404
+ "#f7b6d2",
2405
+ "#dbdb8d",
2406
+ "#9edae5",
2407
+ "#ad494a",
2408
+ "#8c6d31",
2409
+ ]
2355
2410
  # sc.pl.palettes.zeileis_28
2356
- cmap_28 = ["#023fa5","#7d87b9","#bec1d4","#d6bcc0","#bb7784","#8e063b","#4a6fe3","#8595e1",
2357
- "#b5bbe3","#e6afb9","#e07b91","#d33f6a","#11c638","#8dd593","#c6dec7","#ead3c6",
2358
- "#f0b98d","#ef9708","#0fcfc0","#9cded6","#d5eae7","#f3e1eb","#f6c4e1","#f79cd4",
2359
- "#7f7f7f","#c7c7c7","#1CE6FF","#336600"]
2411
+ cmap_28 = [
2412
+ "#023fa5",
2413
+ "#7d87b9",
2414
+ "#bec1d4",
2415
+ "#d6bcc0",
2416
+ "#bb7784",
2417
+ "#8e063b",
2418
+ "#4a6fe3",
2419
+ "#8595e1",
2420
+ "#b5bbe3",
2421
+ "#e6afb9",
2422
+ "#e07b91",
2423
+ "#d33f6a",
2424
+ "#11c638",
2425
+ "#8dd593",
2426
+ "#c6dec7",
2427
+ "#ead3c6",
2428
+ "#f0b98d",
2429
+ "#ef9708",
2430
+ "#0fcfc0",
2431
+ "#9cded6",
2432
+ "#d5eae7",
2433
+ "#f3e1eb",
2434
+ "#f6c4e1",
2435
+ "#f79cd4",
2436
+ "#7f7f7f",
2437
+ "#c7c7c7",
2438
+ "#1CE6FF",
2439
+ "#336600",
2440
+ ]
2360
2441
  if cmap == "gray":
2361
2442
  cmap = "grey"
2362
- elif cmap=="20":
2363
- cmap=cmap_20
2364
- elif cmap=="28":
2365
- cmap=cmap_28
2443
+ elif cmap == "20":
2444
+ cmap = cmap_20
2445
+ elif cmap == "28":
2446
+ cmap = cmap_28
2366
2447
  # Determine color list based on cmap parameter
2367
- if isinstance(cmap,str):
2448
+ if isinstance(cmap, str):
2368
2449
  if "aut" in cmap:
2369
2450
  if n == 1:
2370
2451
  colorlist = ["#3A4453"]
2371
2452
  elif n == 2:
2372
2453
  colorlist = ["#3A4453", "#FF2C00"]
2373
2454
  elif n == 3:
2374
- colorlist = ["#66c2a5","#fc8d62","#8da0cb"]
2455
+ colorlist = ["#66c2a5", "#fc8d62", "#8da0cb"]
2375
2456
  elif n == 4:
2376
- colorlist = ["#FF2C00","#087cf7", "#FBAF63", "#3C898A"]
2457
+ colorlist = ["#FF2C00", "#087cf7", "#FBAF63", "#3C898A"]
2377
2458
  elif n == 5:
2378
- colorlist = ["#FF2C00","#459AA9", "#B25E9D", "#4B8C3B","#EF8632"]
2459
+ colorlist = ["#FF2C00", "#459AA9", "#B25E9D", "#4B8C3B", "#EF8632"]
2379
2460
  elif n == 6:
2380
- colorlist = ["#FF2C00","#91bfdb", "#B25E9D", "#4B8C3B","#EF8632", "#24578E"]
2381
- elif n==7:
2382
- colorlist = ["#7F7F7F", "#459AA9", "#B25E9D", "#4B8C3B","#EF8632", "#24578E" "#FF2C00"]
2383
- elif n==8:
2461
+ colorlist = [
2462
+ "#FF2C00",
2463
+ "#91bfdb",
2464
+ "#B25E9D",
2465
+ "#4B8C3B",
2466
+ "#EF8632",
2467
+ "#24578E",
2468
+ ]
2469
+ elif n == 7:
2470
+ colorlist = [
2471
+ "#7F7F7F",
2472
+ "#459AA9",
2473
+ "#B25E9D",
2474
+ "#4B8C3B",
2475
+ "#EF8632",
2476
+ "#24578E" "#FF2C00",
2477
+ ]
2478
+ elif n == 8:
2384
2479
  # colorlist = ['#1f77b4','#ff7f0e','#367B7F','#51B34F','#d62728','#aa40fc','#e377c2','#17becf']
2385
2480
  # colorlist = ["#367C7E","#51B34F","#881A11","#E9374C","#EF893C","#010072","#385DCB","#EA43E3"]
2386
- colorlist = ["#78BFDA","#D52E6F","#F7D648","#A52D28","#6B9F41","#E18330","#E18B9D","#3C88CC"]
2387
- elif n==9:
2388
- colorlist = ['#1f77b4','#ff7f0e','#367B7F','#ff9896','#d62728','#aa40fc','#e377c2','#51B34F','#17becf']
2389
- elif n==10:
2390
- colorlist = ['#1f77b4','#ff7f0e','#367B7F','#ff9896','#51B34F','#d62728''#aa40fc','#e377c2','#375FD2','#17becf']
2391
- elif 10<n<=20:
2392
- colorlist = cmap_20
2481
+ colorlist = [
2482
+ "#78BFDA",
2483
+ "#D52E6F",
2484
+ "#F7D648",
2485
+ "#A52D28",
2486
+ "#6B9F41",
2487
+ "#E18330",
2488
+ "#E18B9D",
2489
+ "#3C88CC",
2490
+ ]
2491
+ elif n == 9:
2492
+ colorlist = [
2493
+ "#1f77b4",
2494
+ "#ff7f0e",
2495
+ "#367B7F",
2496
+ "#ff9896",
2497
+ "#d62728",
2498
+ "#aa40fc",
2499
+ "#e377c2",
2500
+ "#51B34F",
2501
+ "#17becf",
2502
+ ]
2503
+ elif n == 10:
2504
+ colorlist = [
2505
+ "#1f77b4",
2506
+ "#ff7f0e",
2507
+ "#367B7F",
2508
+ "#ff9896",
2509
+ "#51B34F",
2510
+ "#d62728" "#aa40fc",
2511
+ "#e377c2",
2512
+ "#375FD2",
2513
+ "#17becf",
2514
+ ]
2515
+ elif 10 < n <= 20:
2516
+ colorlist = cmap_20
2393
2517
  else:
2394
2518
  colorlist = cmap_28
2395
2519
  by = "start"
@@ -2426,7 +2550,7 @@ def get_color(
2426
2550
  by = "linspace"
2427
2551
  colorlist = cmap2hex(cmap)
2428
2552
  elif isinstance(cmap, list):
2429
- colorlist=cmap
2553
+ colorlist = cmap
2430
2554
 
2431
2555
  # Determine method for generating color list
2432
2556
  if "st" in by.lower() or "be" in by.lower():
@@ -2446,7 +2570,7 @@ def get_color(
2446
2570
  return hue_list
2447
2571
  else:
2448
2572
  raise ValueError("Invalid output type. Choose 'rgb' or 'hue'.")
2449
-
2573
+
2450
2574
 
2451
2575
  def stdshade(ax=None, *args, **kwargs):
2452
2576
  """
@@ -2454,7 +2578,7 @@ def stdshade(ax=None, *args, **kwargs):
2454
2578
  plot.stdshade(data_array, c=clist[1], lw=2, ls="-.", alpha=0.2)
2455
2579
  """
2456
2580
  from scipy.signal import savgol_filter
2457
-
2581
+
2458
2582
  # Separate kws_line and kws_fill if necessary
2459
2583
  kws_line = kwargs.pop("kws_line", {})
2460
2584
  kws_fill = kwargs.pop("kws_fill", {})
@@ -2724,37 +2848,47 @@ def adjust_spines(ax=None, spines=["left", "bottom"], distance=2):
2724
2848
  # cax = fig.add_axes([l + w + pad, b, width, h]) # define cbar Axes
2725
2849
  # return fig.colorbar(im, cax=cax, **kwargs) # draw cbar
2726
2850
 
2727
- def add_colorbar(im,
2728
- cmap="viridis",
2729
- vmin=-1,
2730
- vmax=1,
2731
- orientation='vertical',
2732
- width_ratio=0.05,
2733
- pad_ratio=0.02,
2734
- shrink=1.0,
2735
- **kwargs):
2851
+
2852
+ def add_colorbar(
2853
+ im,
2854
+ cmap="viridis",
2855
+ vmin=-1,
2856
+ vmax=1,
2857
+ orientation="vertical",
2858
+ width_ratio=0.05,
2859
+ pad_ratio=0.02,
2860
+ shrink=1.0,
2861
+ **kwargs,
2862
+ ):
2736
2863
  import matplotlib as mpl
2864
+
2737
2865
  if all([cmap, vmin, vmax]):
2738
2866
  norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
2739
2867
  else:
2740
- norm=False
2868
+ norm = False
2741
2869
  sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
2742
2870
  sm.set_array([])
2743
2871
  l, b, w, h = im.axes.get_position().bounds # position: left, bottom, width, height
2744
- if orientation == 'vertical':
2872
+ if orientation == "vertical":
2745
2873
  width = width_ratio * w
2746
2874
  pad = pad_ratio * w
2747
- cax = im.figure.add_axes([l + w + pad, b, width, h * shrink]) # Right of the image
2875
+ cax = im.figure.add_axes(
2876
+ [l + w + pad, b, width, h * shrink]
2877
+ ) # Right of the image
2748
2878
  else:
2749
2879
  height = width_ratio * h
2750
2880
  pad = pad_ratio * h
2751
- cax = im.figure.add_axes([l, b - height - pad, w * shrink, height]) # Below the image
2752
- cbar=im.figure.colorbar(sm, cax=cax, orientation=orientation, **kwargs)
2753
- return cbar
2881
+ cax = im.figure.add_axes(
2882
+ [l, b - height - pad, w * shrink, height]
2883
+ ) # Below the image
2884
+ cbar = im.figure.colorbar(sm, cax=cax, orientation=orientation, **kwargs)
2885
+ return cbar
2886
+
2754
2887
 
2755
2888
  # Usage:
2756
2889
  # add_colorbar(im, width_ratio=0.03, pad_ratio=0.01, orientation='horizontal', label="PSD (dB)")
2757
2890
 
2891
+
2758
2892
  def generate_xticks_with_gap(x_len, hue_len):
2759
2893
  """
2760
2894
  Generate a concatenated array based on x_len and hue_len,
@@ -2884,6 +3018,7 @@ def style_examples(
2884
3018
  f = listdir(
2885
3019
  "/Users/macjianfeng/Dropbox/github/python/py2ls/.venv/lib/python3.12/site-packages/py2ls/data/styles/",
2886
3020
  kind=".json",
3021
+ verbose=False,
2887
3022
  )
2888
3023
  display(f.sample(2))
2889
3024
  # def style_example(dir_save,)
@@ -2961,6 +3096,7 @@ def thumbnail(dir_img_list: list, figsize=(10, 10), dpi=100, show=False, verbose
2961
3096
 
2962
3097
  def get_params_from_func_usage(function_signature):
2963
3098
  import re
3099
+
2964
3100
  # Regular expression to match parameter names, ignoring '*' and '**kwargs'
2965
3101
  keys_pattern = r"(?<!\*\*)\b(\w+)="
2966
3102
  # Find all matches
@@ -2987,7 +3123,7 @@ def plotxy(
2987
3123
  x=None,
2988
3124
  y=None,
2989
3125
  ax=None,
2990
- kind: Union[str, list] = 'scatter', # Specify the kind of plot
3126
+ kind: Union[str, list] = "scatter", # Specify the kind of plot
2991
3127
  verbose=False,
2992
3128
  **kwargs,
2993
3129
  ):
@@ -3011,14 +3147,15 @@ def plotxy(
3011
3147
  # Check for valid plot kind
3012
3148
  # Default arguments for various plot types
3013
3149
  from pathlib import Path
3150
+
3014
3151
  # Get the current script's directory as a Path object
3015
- current_directory = Path(__file__).resolve().parent
3152
+ current_directory = Path(__file__).resolve().parent
3153
+
3154
+ if not "default_settings" in locals():
3155
+ default_settings = fload(current_directory / "data" / "usages_sns.json")
3156
+ if not "sns_info" in locals():
3157
+ sns_info = pd.DataFrame(fload(current_directory / "data" / "sns_info.json"))
3016
3158
 
3017
- if not 'default_settings' in locals():
3018
- default_settings = fload(current_directory / 'data' / 'usages_sns.json')
3019
- if not 'sns_info' in locals():
3020
- sns_info = pd.DataFrame(fload(current_directory / 'data' / 'sns_info.json'))
3021
-
3022
3159
  valid_kinds = list(default_settings.keys())
3023
3160
  # print(valid_kinds)
3024
3161
  if kind is not None:
@@ -3032,7 +3169,7 @@ def plotxy(
3032
3169
  if kind is not None:
3033
3170
  for k in kind:
3034
3171
  if k in valid_kinds:
3035
- print(f"{k}:\n\t{default_settings[k]}")
3172
+ print(f"{k}:\n\t{default_settings[k]}")
3036
3173
  usage_str = """plotxy(data=ranked_genes,
3037
3174
  x="log2(fold_change)",
3038
3175
  y="-log10(p-value)",
@@ -3057,15 +3194,25 @@ def plotxy(
3057
3194
  kws_add_text = v_arg
3058
3195
  kwargs.pop(k_arg, None)
3059
3196
  break
3060
- zorder=0
3197
+ zorder = 0
3061
3198
  for k in kind:
3062
- zorder+=1
3199
+ zorder += 1
3063
3200
  # indicate 'col' features
3064
3201
  col = kwargs.get("col", None)
3065
- sns_with_col = ["catplot","histplot","relplot","lmplot","pairplot","displot","kdeplot"]
3202
+ sns_with_col = [
3203
+ "catplot",
3204
+ "histplot",
3205
+ "relplot",
3206
+ "lmplot",
3207
+ "pairplot",
3208
+ "displot",
3209
+ "kdeplot",
3210
+ ]
3066
3211
  if col is not None:
3067
3212
  if not k in sns_with_col:
3068
- print(f"tips:\n'{k}' has no 'col' param, you could try with {sns_with_col}")
3213
+ print(
3214
+ f"tips:\n'{k}' has no 'col' param, you could try with {sns_with_col}"
3215
+ )
3069
3216
  # (1) return FcetGrid
3070
3217
  if k == "jointplot":
3071
3218
  kws_joint = kwargs.pop("kws_joint", kwargs)
@@ -3092,77 +3239,98 @@ def plotxy(
3092
3239
  ax = stdshade(ax=ax, **kwargs)
3093
3240
  elif k == "scatterplot":
3094
3241
  kws_scatter = kwargs.pop("kws_scatter", kwargs)
3095
- kws_scatter={k: v for k, v in kws_scatter.items() if not k.startswith("kws_")}
3242
+ kws_scatter = {
3243
+ k: v for k, v in kws_scatter.items() if not k.startswith("kws_")
3244
+ }
3096
3245
  hue = kwargs.pop("hue", None)
3097
3246
  if isinstance(kws_scatter, dict): # Check if kws_scatter is a dictionary
3098
3247
  kws_scatter.pop("hue", None) # Safely remove 'hue' if it exists
3099
- palette = kws_scatter.pop("palette",get_color(data[hue].nunique()) if hue is not None else None)
3248
+ palette = kws_scatter.pop(
3249
+ "palette", get_color(data[hue].nunique()) if hue is not None else None
3250
+ )
3100
3251
  s = kws_scatter.pop("s", 10)
3101
3252
  alpha = kws_scatter.pop("alpha", 0.7)
3102
- ax = sns.scatterplot(ax=ax,data=data,x=x,y=y,hue=hue,palette=palette,s=s,alpha=alpha,zorder=zorder,**kws_scatter)
3253
+ ax = sns.scatterplot(
3254
+ ax=ax,
3255
+ data=data,
3256
+ x=x,
3257
+ y=y,
3258
+ hue=hue,
3259
+ palette=palette,
3260
+ s=s,
3261
+ alpha=alpha,
3262
+ zorder=zorder,
3263
+ **kws_scatter,
3264
+ )
3103
3265
  elif k == "histplot":
3104
3266
  kws_hist = kwargs.pop("kws_hist", kwargs)
3105
- kws_hist={k: v for k, v in kws_hist.items() if not k.startswith("kws_")}
3106
- ax = sns.histplot(data=data, x=x, ax=ax,zorder=zorder, **kws_hist)
3267
+ kws_hist = {k: v for k, v in kws_hist.items() if not k.startswith("kws_")}
3268
+ ax = sns.histplot(data=data, x=x, ax=ax, zorder=zorder, **kws_hist)
3107
3269
  elif k == "kdeplot":
3108
3270
  kws_kde = kwargs.pop("kws_kde", kwargs)
3109
- kws_kde={k: v for k, v in kws_kde.items() if not k.startswith("kws_")}
3110
- ax = sns.kdeplot(data=data, x=x, ax=ax,zorder=zorder, **kws_kde)
3271
+ kws_kde = {k: v for k, v in kws_kde.items() if not k.startswith("kws_")}
3272
+ ax = sns.kdeplot(data=data, x=x, ax=ax, zorder=zorder, **kws_kde)
3111
3273
  elif k == "ecdfplot":
3112
3274
  kws_ecdf = kwargs.pop("kws_ecdf", kwargs)
3113
- kws_ecdf={k: v for k, v in kws_ecdf.items() if not k.startswith("kws_")}
3114
- ax = sns.ecdfplot(data=data, x=x, ax=ax,zorder=zorder, **kws_ecdf)
3275
+ kws_ecdf = {k: v for k, v in kws_ecdf.items() if not k.startswith("kws_")}
3276
+ ax = sns.ecdfplot(data=data, x=x, ax=ax, zorder=zorder, **kws_ecdf)
3115
3277
  elif k == "rugplot":
3116
3278
  kws_rug = kwargs.pop("kws_rug", kwargs)
3117
- kws_rug={k: v for k, v in kws_rug.items() if not k.startswith("kws_")}
3118
- ax = sns.rugplot(data=data, x=x, ax=ax,zorder=zorder, **kws_rug)
3279
+ kws_rug = {k: v for k, v in kws_rug.items() if not k.startswith("kws_")}
3280
+ ax = sns.rugplot(data=data, x=x, ax=ax, zorder=zorder, **kws_rug)
3119
3281
  elif k == "stripplot":
3120
3282
  kws_strip = kwargs.pop("kws_strip", kwargs)
3121
- kws_strip={k: v for k, v in kws_strip.items() if not k.startswith("kws_")}
3283
+ kws_strip = {k: v for k, v in kws_strip.items() if not k.startswith("kws_")}
3122
3284
  dodge = kws_strip.pop("dodge", True)
3123
- ax = sns.stripplot(data=data, x=x, y=y, ax=ax,zorder=zorder, dodge=dodge, **kws_strip)
3285
+ ax = sns.stripplot(
3286
+ data=data, x=x, y=y, ax=ax, zorder=zorder, dodge=dodge, **kws_strip
3287
+ )
3124
3288
  elif k == "swarmplot":
3125
3289
  kws_swarm = kwargs.pop("kws_swarm", kwargs)
3126
- kws_swarm={k: v for k, v in kws_swarm.items() if not k.startswith("kws_")}
3127
- ax = sns.swarmplot(data=data, x=x, y=y, ax=ax,zorder=zorder, **kws_swarm)
3290
+ kws_swarm = {k: v for k, v in kws_swarm.items() if not k.startswith("kws_")}
3291
+ ax = sns.swarmplot(data=data, x=x, y=y, ax=ax, zorder=zorder, **kws_swarm)
3128
3292
  elif k == "boxplot":
3129
3293
  kws_box = kwargs.pop("kws_box", kwargs)
3130
- kws_box={k: v for k, v in kws_box.items() if not k.startswith("kws_")}
3131
- ax = sns.boxplot(data=data, x=x, y=y, ax=ax,zorder=zorder, **kws_box)
3294
+ kws_box = {k: v for k, v in kws_box.items() if not k.startswith("kws_")}
3295
+ ax = sns.boxplot(data=data, x=x, y=y, ax=ax, zorder=zorder, **kws_box)
3132
3296
  elif k == "violinplot":
3133
3297
  kws_violin = kwargs.pop("kws_violin", kwargs)
3134
- kws_violin={k: v for k, v in kws_violin.items() if not k.startswith("kws_")}
3135
- ax = sns.violinplot(data=data, x=x, y=y, ax=ax,zorder=zorder, **kws_violin)
3298
+ kws_violin = {
3299
+ k: v for k, v in kws_violin.items() if not k.startswith("kws_")
3300
+ }
3301
+ ax = sns.violinplot(data=data, x=x, y=y, ax=ax, zorder=zorder, **kws_violin)
3136
3302
  elif k == "boxenplot":
3137
3303
  kws_boxen = kwargs.pop("kws_boxen", kwargs)
3138
- kws_boxen={k: v for k, v in kws_boxen.items() if not k.startswith("kws_")}
3139
- ax = sns.boxenplot(data=data, x=x, y=y, ax=ax,zorder=zorder, **kws_boxen)
3304
+ kws_boxen = {k: v for k, v in kws_boxen.items() if not k.startswith("kws_")}
3305
+ ax = sns.boxenplot(data=data, x=x, y=y, ax=ax, zorder=zorder, **kws_boxen)
3140
3306
  elif k == "pointplot":
3141
3307
  kws_point = kwargs.pop("kws_point", kwargs)
3142
- kws_point={k: v for k, v in kws_point.items() if not k.startswith("kws_")}
3143
- ax = sns.pointplot(data=data, x=x, y=y, ax=ax,zorder=zorder, **kws_point)
3308
+ kws_point = {k: v for k, v in kws_point.items() if not k.startswith("kws_")}
3309
+ ax = sns.pointplot(data=data, x=x, y=y, ax=ax, zorder=zorder, **kws_point)
3144
3310
  elif k == "barplot":
3145
3311
  kws_bar = kwargs.pop("kws_bar", kwargs)
3146
- kws_bar={k: v for k, v in kws_bar.items() if not k.startswith("kws_")}
3147
- ax = sns.barplot(data=data, x=x, y=y, ax=ax,zorder=zorder, **kws_bar)
3312
+ kws_bar = {k: v for k, v in kws_bar.items() if not k.startswith("kws_")}
3313
+ ax = sns.barplot(data=data, x=x, y=y, ax=ax, zorder=zorder, **kws_bar)
3148
3314
  elif k == "countplot":
3149
3315
  kws_count = kwargs.pop("kws_count", kwargs)
3150
- kws_count={k: v for k, v in kws_count.items() if not k.startswith("kws_")}
3151
- if not kws_count.get("hue",None):
3152
- kws_count.pop("palette",None)
3153
- ax = sns.countplot(data=data, x=x,y=y, ax=ax,zorder=zorder, **kws_count)
3316
+ kws_count = {k: v for k, v in kws_count.items() if not k.startswith("kws_")}
3317
+ if not kws_count.get("hue", None):
3318
+ kws_count.pop("palette", None)
3319
+ ax = sns.countplot(data=data, x=x, y=y, ax=ax, zorder=zorder, **kws_count)
3154
3320
  elif k == "regplot":
3155
3321
  kws_reg = kwargs.pop("kws_reg", kwargs)
3156
- kws_reg={k: v for k, v in kws_reg.items() if not k.startswith("kws_")}
3157
- ax = sns.regplot(data=data, x=x, y=y, ax=ax,zorder=zorder, **kws_reg)
3322
+ kws_reg = {k: v for k, v in kws_reg.items() if not k.startswith("kws_")}
3323
+ ax = sns.regplot(data=data, x=x, y=y, ax=ax, zorder=zorder, **kws_reg)
3158
3324
  elif k == "residplot":
3159
3325
  kws_resid = kwargs.pop("kws_resid", kwargs)
3160
- kws_resid={k: v for k, v in kws_resid.items() if not k.startswith("kws_")}
3161
- ax = sns.residplot(data=data, x=x, y=y, lowess=True,zorder=zorder, ax=ax, **kws_resid)
3326
+ kws_resid = {k: v for k, v in kws_resid.items() if not k.startswith("kws_")}
3327
+ ax = sns.residplot(
3328
+ data=data, x=x, y=y, lowess=True, zorder=zorder, ax=ax, **kws_resid
3329
+ )
3162
3330
  elif k == "lineplot":
3163
3331
  kws_line = kwargs.pop("kws_line", kwargs)
3164
- kws_line={k: v for k, v in kws_line.items() if not k.startswith("kws_")}
3165
- ax = sns.lineplot(ax=ax, data=data, x=x, y=y,zorder=zorder, **kws_line)
3332
+ kws_line = {k: v for k, v in kws_line.items() if not k.startswith("kws_")}
3333
+ ax = sns.lineplot(ax=ax, data=data, x=x, y=y, zorder=zorder, **kws_line)
3166
3334
 
3167
3335
  figsets(ax=ax, **kws_figsets)
3168
3336
  if kws_add_text:
@@ -3176,20 +3344,22 @@ def plotxy(
3176
3344
  return g, ax
3177
3345
  return ax
3178
3346
 
3347
+
3179
3348
  def norm_cmap(data, cmap="coolwarm", min_max=[0, 1]):
3180
3349
  norm_ = plt.Normalize(min_max[0], min_max[1])
3181
3350
  colormap = plt.get_cmap(cmap)
3182
3351
  return colormap(norm_(data))
3183
3352
 
3353
+
3184
3354
  def volcano(
3185
- data:pd.DataFrame,
3186
- x:str,
3187
- y:str,
3188
- gene_col:str=None,
3189
- top_genes=[5, 5], # [down-regulated, up-regulated]
3190
- thr_x=np.log2(1.5), # default: 0.585
3355
+ data: pd.DataFrame,
3356
+ x: str,
3357
+ y: str,
3358
+ gene_col: str = None,
3359
+ top_genes=[5, 5], # [down-regulated, up-regulated]
3360
+ thr_x=np.log2(1.5), # default: 0.585
3191
3361
  thr_y=-np.log10(0.05),
3192
- sort_xy="x", #'y', 'xy'
3362
+ sort_xy="x", #'y', 'xy'
3193
3363
  colors=("#00BFFF", "#9d9a9a", "#FF3030"),
3194
3364
  s=20,
3195
3365
  fill=True, # plot filled scatter
@@ -3201,11 +3371,10 @@ def volcano(
3201
3371
  ax=None,
3202
3372
  verbose=False,
3203
3373
  kws_text=dict(fontsize=10, color="k"),
3204
- kws_bbox=dict(facecolor='none',
3205
- alpha=0.5,
3206
- edgecolor='black',
3207
- boxstyle='round,pad=0.3'),# '{}' to hide
3208
- kws_arrow=dict(color="k", lw=0.5),# '{}' to hide
3374
+ kws_bbox=dict(
3375
+ facecolor="none", alpha=0.5, edgecolor="black", boxstyle="round,pad=0.3"
3376
+ ), # '{}' to hide
3377
+ kws_arrow=dict(color="k", lw=0.5), # '{}' to hide
3209
3378
  **kwargs,
3210
3379
  ):
3211
3380
  """
@@ -3275,39 +3444,35 @@ def volcano(
3275
3444
  kws_figsets = v_arg
3276
3445
  kwargs.pop(k_arg, None)
3277
3446
  break
3278
-
3279
- data=data.copy()
3447
+
3448
+ data = data.copy()
3280
3449
  # filter nan
3281
3450
  data = data.dropna(subset=[x, y]) # Drop rows with NaN in x or y
3282
- data.loc[:,"color"] = np.where(
3451
+ data.loc[:, "color"] = np.where(
3283
3452
  (data[x] > thr_x) & (data[y] > thr_y),
3284
3453
  colors[2],
3285
- np.where((data[x] < -thr_x) & (data[y] > thr_y),
3286
- colors[0],
3287
- colors[1]),
3454
+ np.where((data[x] < -thr_x) & (data[y] > thr_y), colors[0], colors[1]),
3288
3455
  )
3289
- top_genes=[top_genes, top_genes] if isinstance(top_genes,int) else top_genes
3290
-
3456
+ top_genes = [top_genes, top_genes] if isinstance(top_genes, int) else top_genes
3457
+
3291
3458
  # could custom how to select the top genes, x: x has priority
3292
- sort_by_x_y=[x,y] if sort_xy=="x" else [y,x]
3293
- ascending_up=[True, True] if sort_xy=="x" else [False, True]
3294
- ascending_down=[False, True] if sort_xy=="x" else [False, False]
3295
-
3296
- down_reg_genes = data[
3297
- (data["color"] == colors[0]) &
3298
- (data[x].abs() > thr_x) &
3299
- (data[y] > thr_y)
3300
- ].sort_values(by=sort_by_x_y, ascending=ascending_up).head(top_genes[0])
3301
- up_reg_genes = data[
3302
- (data["color"] == colors[2]) &
3303
- (data[x].abs() > thr_x) &
3304
- (data[y] > thr_y)
3305
- ].sort_values(by=sort_by_x_y, ascending=ascending_down).head(top_genes[1])
3459
+ sort_by_x_y = [x, y] if sort_xy == "x" else [y, x]
3460
+ ascending_up = [True, True] if sort_xy == "x" else [False, True]
3461
+ ascending_down = [False, True] if sort_xy == "x" else [False, False]
3462
+
3463
+ down_reg_genes = (
3464
+ data[(data["color"] == colors[0]) & (data[x].abs() > thr_x) & (data[y] > thr_y)]
3465
+ .sort_values(by=sort_by_x_y, ascending=ascending_up)
3466
+ .head(top_genes[0])
3467
+ )
3468
+ up_reg_genes = (
3469
+ data[(data["color"] == colors[2]) & (data[x].abs() > thr_x) & (data[y] > thr_y)]
3470
+ .sort_values(by=sort_by_x_y, ascending=ascending_down)
3471
+ .head(top_genes[1])
3472
+ )
3306
3473
  sele_gene = pd.concat([down_reg_genes, up_reg_genes])
3307
-
3308
- palette = {colors[0]: colors[0],
3309
- colors[1]: colors[1],
3310
- colors[2]: colors[2]}
3474
+
3475
+ palette = {colors[0]: colors[0], colors[1]: colors[1], colors[2]: colors[2]}
3311
3476
  # Plot setup
3312
3477
  if ax is None:
3313
3478
  ax = plt.gca()
@@ -3337,9 +3502,9 @@ def volcano(
3337
3502
  )
3338
3503
 
3339
3504
  # Add threshold lines for x and y axes
3340
- ax.axhline(y=thr_y, color="black", linestyle="--",lw=1)
3341
- ax.axvline(x=-thr_x, color="black", linestyle="--",lw=1)
3342
- ax.axvline(x=thr_x, color="black", linestyle="--",lw=1)
3505
+ ax.axhline(y=thr_y, color="black", linestyle="--", lw=1)
3506
+ ax.axvline(x=-thr_x, color="black", linestyle="--", lw=1)
3507
+ ax.axvline(x=thr_x, color="black", linestyle="--", lw=1)
3343
3508
 
3344
3509
  # Add gene labels for selected significant points
3345
3510
  if gene_col:
@@ -3349,19 +3514,28 @@ def volcano(
3349
3514
  textcolor = kws_text.pop("color", "k")
3350
3515
  fontsize = kws_text.pop("fontsize", 10)
3351
3516
  arrowstyles = [
3352
- "->","<-","<->","<|-","-|>","<|-|>",
3353
- "-","-[","-[",
3354
- "fancy","simple","wedge",
3517
+ "->",
3518
+ "<-",
3519
+ "<->",
3520
+ "<|-",
3521
+ "-|>",
3522
+ "<|-|>",
3523
+ "-",
3524
+ "-[",
3525
+ "-[",
3526
+ "fancy",
3527
+ "simple",
3528
+ "wedge",
3355
3529
  ]
3356
3530
  arrowstyle = kws_arrow.pop("style", "<|-")
3357
- arrowstyle = strcmp(arrowstyle, arrowstyles,scorer='strict')[0]
3358
- expand=kws_arrow.pop("expand",(1.05,1.1))
3531
+ arrowstyle = strcmp(arrowstyle, arrowstyles, scorer="strict")[0]
3532
+ expand = kws_arrow.pop("expand", (1.05, 1.1))
3359
3533
  arrowcolor = kws_arrow.pop("color", "0.4")
3360
3534
  arrowlinewidth = kws_arrow.pop("lw", 0.75)
3361
3535
  shrinkA = kws_arrow.pop("shrinkA", 0)
3362
3536
  shrinkB = kws_arrow.pop("shrinkB", 0)
3363
3537
  mutation_scale = kws_arrow.pop("head", 10)
3364
- arrow_fill=kws_arrow.pop("fill", False)
3538
+ arrow_fill = kws_arrow.pop("fill", False)
3365
3539
  for i in range(sele_gene.shape[0]):
3366
3540
  if isinstance(textcolor, list): # be consistant with dots's color
3367
3541
  textcolor = colors[0] if sele_gene[x].iloc[i] > 0 else colors[1]
@@ -3382,7 +3556,7 @@ def volcano(
3382
3556
  adjust_text(
3383
3557
  texts,
3384
3558
  expand=expand,
3385
- min_arrow_len=5,
3559
+ min_arrow_len=5,
3386
3560
  ax=ax,
3387
3561
  arrowprops=dict(
3388
3562
  arrowstyle=arrowstyle,
@@ -3393,7 +3567,7 @@ def volcano(
3393
3567
  shrinkB=shrinkB,
3394
3568
  mutation_scale=mutation_scale,
3395
3569
  **kws_arrow,
3396
- )
3570
+ ),
3397
3571
  )
3398
3572
 
3399
3573
  figsets(**kws_figsets)
@@ -3499,6 +3673,7 @@ def desaturate_color(color, saturation_factor=0.5):
3499
3673
  """Reduce the saturation of a color by a given factor (between 0 and 1)."""
3500
3674
  import matplotlib.colors as mcolors
3501
3675
  import colorsys
3676
+
3502
3677
  # Convert the color to RGB
3503
3678
  rgb = mcolors.to_rgb(color)
3504
3679
  # Convert RGB to HLS (Hue, Lightness, Saturation)
@@ -3508,8 +3683,19 @@ def desaturate_color(color, saturation_factor=0.5):
3508
3683
  # Convert back to RGB
3509
3684
  return colorsys.hls_to_rgb(h, l, s)
3510
3685
 
3511
- def textsets(text, fontname='Arial', fontsize=11, fontweight="normal", fontstyle="normal",
3512
- fontcolor='k', backgroundcolor=None, shadow=False, ha="center", va="center"):
3686
+
3687
+ def textsets(
3688
+ text,
3689
+ fontname="Arial",
3690
+ fontsize=11,
3691
+ fontweight="normal",
3692
+ fontstyle="normal",
3693
+ fontcolor="k",
3694
+ backgroundcolor=None,
3695
+ shadow=False,
3696
+ ha="center",
3697
+ va="center",
3698
+ ):
3513
3699
  if text: # Ensure text exists
3514
3700
  if fontname:
3515
3701
  text.set_fontname(plt_font(fontname))
@@ -3522,26 +3708,28 @@ def textsets(text, fontname='Arial', fontsize=11, fontweight="normal", fontstyle
3522
3708
  if fontcolor:
3523
3709
  text.set_color(fontcolor)
3524
3710
  if backgroundcolor:
3525
- text.set_backgroundcolor(backgroundcolor)
3711
+ text.set_backgroundcolor(backgroundcolor)
3526
3712
  text.set_horizontalalignment(ha)
3527
- text.set_verticalalignment(va)
3713
+ text.set_verticalalignment(va)
3528
3714
  if shadow:
3529
- text.set_path_effects([
3530
- matplotlib.patheffects.withStroke(linewidth=3, foreground="gray")
3531
- ])
3715
+ text.set_path_effects(
3716
+ [matplotlib.patheffects.withStroke(linewidth=3, foreground="gray")]
3717
+ )
3718
+
3719
+
3532
3720
  def venn(
3533
- lists:list,
3534
- labels:list=None,
3721
+ lists: list,
3722
+ labels: list = None,
3535
3723
  ax=None,
3536
3724
  colors=None,
3537
3725
  edgecolor=None,
3538
3726
  alpha=0.5,
3539
- saturation=.75,
3540
- linewidth=0, # default no edge
3727
+ saturation=0.75,
3728
+ linewidth=0, # default no edge
3541
3729
  linestyle="-",
3542
- fontname='Arial',
3730
+ fontname="Arial",
3543
3731
  fontsize=10,
3544
- fontcolor='k',
3732
+ fontcolor="k",
3545
3733
  fontweight="normal",
3546
3734
  fontstyle="normal",
3547
3735
  ha="center",
@@ -3550,18 +3738,18 @@ def venn(
3550
3738
  subset_fontsize=10,
3551
3739
  subset_fontweight="normal",
3552
3740
  subset_fontstyle="normal",
3553
- subset_fontcolor='k',
3741
+ subset_fontcolor="k",
3554
3742
  backgroundcolor=None,
3555
3743
  custom_texts=None,
3556
- show_percentages=True, # display percentage
3744
+ show_percentages=True, # display percentage
3557
3745
  fmt="{:.1%}",
3558
3746
  ellipse_shape=False, # 椭圆形
3559
- ellipse_scale=[1.5, 1], #not perfect, 椭圆形的形状
3560
- **kwargs
3747
+ ellipse_scale=[1.5, 1], # not perfect, 椭圆形的形状
3748
+ **kwargs,
3561
3749
  ):
3562
3750
  """
3563
3751
  Advanced Venn diagram plotting function with extensive customization options.
3564
- Usage:
3752
+ Usage:
3565
3753
  # Define the two sets
3566
3754
  set1 = [1, 2, 3, 4, 5]
3567
3755
  set2 = [4, 5, 6, 7, 8]
@@ -3612,56 +3800,70 @@ def venn(
3612
3800
  """
3613
3801
  if ax is None:
3614
3802
  ax = plt.gca()
3615
- lists=[set(flatten(i, verbose=False)) for i in lists]
3803
+ lists = [set(flatten(i, verbose=False)) for i in lists]
3616
3804
  # Function to apply text styles to labels
3617
3805
  if colors is None:
3618
- colors=["r","b"] if len(lists)==2 else ["r","g","b"]
3806
+ colors = ["r", "b"] if len(lists) == 2 else ["r", "g", "b"]
3619
3807
  if labels is None:
3620
- labels=["set1","set2"] if len(lists)==2 else ["set1","set2","set3"]
3808
+ labels = ["set1", "set2"] if len(lists) == 2 else ["set1", "set2", "set3"]
3621
3809
  if edgecolor is None:
3622
- edgecolor=colors
3810
+ edgecolor = colors
3623
3811
  colors = [desaturate_color(color, saturation) for color in colors]
3624
3812
  # Check colors and auto-calculate overlaps
3625
3813
  if len(lists) == 2:
3814
+
3626
3815
  def get_count_and_percentage(set_count, subset_count):
3627
3816
  percent = subset_count / set_count if set_count > 0 else 0
3628
- return f"{subset_count}\n({fmt.format(percent)})" if show_percentages else f"{subset_count}"
3817
+ return (
3818
+ f"{subset_count}\n({fmt.format(percent)})"
3819
+ if show_percentages
3820
+ else f"{subset_count}"
3821
+ )
3629
3822
 
3630
3823
  from matplotlib_venn import venn2, venn2_circles
3824
+
3631
3825
  # Auto-calculate overlap color for 2-set Venn diagram
3632
3826
  overlap_color = get_color_overlap(colors[0], colors[1]) if colors else None
3633
-
3827
+
3634
3828
  # Draw the venn diagram
3635
3829
  v = venn2(subsets=lists, set_labels=labels, ax=ax, **kwargs)
3636
3830
  venn_circles = venn2_circles(subsets=lists, ax=ax)
3637
- set1,set2=lists[0],lists[1]
3638
- v.get_patch_by_id('10').set_color(colors[0])
3639
- v.get_patch_by_id('01').set_color(colors[1])
3640
- v.get_patch_by_id('11').set_color(get_color_overlap(colors[0], colors[1]) if colors else None)
3831
+ set1, set2 = lists[0], lists[1]
3832
+ v.get_patch_by_id("10").set_color(colors[0])
3833
+ v.get_patch_by_id("01").set_color(colors[1])
3834
+ v.get_patch_by_id("11").set_color(
3835
+ get_color_overlap(colors[0], colors[1]) if colors else None
3836
+ )
3641
3837
  # v.get_label_by_id('10').set_text(len(set1 - set2))
3642
3838
  # v.get_label_by_id('01').set_text(len(set2 - set1))
3643
3839
  # v.get_label_by_id('11').set_text(len(set1 & set2))
3644
- v.get_label_by_id('10').set_text(get_count_and_percentage(len(set1), len(set1 - set2)))
3645
- v.get_label_by_id('01').set_text(get_count_and_percentage(len(set2), len(set2 - set1)))
3646
- v.get_label_by_id('11').set_text(get_count_and_percentage(len(set1 | set2), len(set1 & set2)))
3647
-
3840
+ v.get_label_by_id("10").set_text(
3841
+ get_count_and_percentage(len(set1), len(set1 - set2))
3842
+ )
3843
+ v.get_label_by_id("01").set_text(
3844
+ get_count_and_percentage(len(set2), len(set2 - set1))
3845
+ )
3846
+ v.get_label_by_id("11").set_text(
3847
+ get_count_and_percentage(len(set1 | set2), len(set1 & set2))
3848
+ )
3648
3849
 
3649
- if not isinstance(linewidth,list):
3650
- linewidth=[linewidth]
3651
- if isinstance(linestyle,str):
3652
- linestyle=[linestyle]
3850
+ if not isinstance(linewidth, list):
3851
+ linewidth = [linewidth]
3852
+ if isinstance(linestyle, str):
3853
+ linestyle = [linestyle]
3653
3854
  if not isinstance(edgecolor, list):
3654
- edgecolor=[edgecolor]
3655
- linewidth=linewidth*2 if len(linewidth)==1 else linewidth
3656
- linestyle=linestyle*2 if len(linestyle)==1 else linestyle
3657
- edgecolor=edgecolor*2 if len(edgecolor)==1 else edgecolor
3855
+ edgecolor = [edgecolor]
3856
+ linewidth = linewidth * 2 if len(linewidth) == 1 else linewidth
3857
+ linestyle = linestyle * 2 if len(linestyle) == 1 else linestyle
3858
+ edgecolor = edgecolor * 2 if len(edgecolor) == 1 else edgecolor
3658
3859
  for i in range(2):
3659
3860
  venn_circles[i].set_lw(linewidth[i])
3660
3861
  venn_circles[i].set_ls(linestyle[i])
3661
3862
  venn_circles[i].set_edgecolor(edgecolor[i])
3662
3863
  # 椭圆
3663
3864
  if ellipse_shape:
3664
- import matplotlib.patches as patches
3865
+ import matplotlib.patches as patches
3866
+
3665
3867
  for patch in v.patches:
3666
3868
  patch.set_visible(False) # Hide original patches if using ellipses
3667
3869
  center1 = v.get_circle_center(0)
@@ -3672,9 +3874,13 @@ def venn(
3672
3874
  height=ellipse_scale[1],
3673
3875
  edgecolor=edgecolor[0] if edgecolor else colors[0],
3674
3876
  facecolor=colors[0],
3675
- lw=linewidth if isinstance(linewidth, (int, float)) else 1.0, # Ensure lw is a number
3877
+ lw=(
3878
+ linewidth if isinstance(linewidth, (int, float)) else 1.0
3879
+ ), # Ensure lw is a number
3676
3880
  ls=linestyle[0],
3677
- alpha=alpha if isinstance(alpha, (int, float)) else 0.5 # Ensure alpha is a number
3881
+ alpha=(
3882
+ alpha if isinstance(alpha, (int, float)) else 0.5
3883
+ ), # Ensure alpha is a number
3678
3884
  )
3679
3885
  ellipse2 = patches.Ellipse(
3680
3886
  (center2.x, center2.y),
@@ -3682,48 +3888,78 @@ def venn(
3682
3888
  height=ellipse_scale[1],
3683
3889
  edgecolor=edgecolor[1] if edgecolor else colors[1],
3684
3890
  facecolor=colors[1],
3685
- lw=linewidth if isinstance(linewidth, (int, float)) else 1.0, # Ensure lw is a number
3891
+ lw=(
3892
+ linewidth if isinstance(linewidth, (int, float)) else 1.0
3893
+ ), # Ensure lw is a number
3686
3894
  ls=linestyle[0],
3687
- alpha=alpha if isinstance(alpha, (int, float)) else 0.5 # Ensure alpha is a number
3895
+ alpha=(
3896
+ alpha if isinstance(alpha, (int, float)) else 0.5
3897
+ ), # Ensure alpha is a number
3688
3898
  )
3689
3899
  ax.add_patch(ellipse1)
3690
3900
  ax.add_patch(ellipse2)
3691
3901
  # Apply styles to set labels
3692
3902
  for i, text in enumerate(v.set_labels):
3693
- textsets(text, fontname=fontname, fontsize=fontsize, fontweight=fontweight, fontstyle=fontstyle,
3694
- fontcolor=fontcolor,ha=ha,va=va,shadow=shadow)
3903
+ textsets(
3904
+ text,
3905
+ fontname=fontname,
3906
+ fontsize=fontsize,
3907
+ fontweight=fontweight,
3908
+ fontstyle=fontstyle,
3909
+ fontcolor=fontcolor,
3910
+ ha=ha,
3911
+ va=va,
3912
+ shadow=shadow,
3913
+ )
3695
3914
 
3696
3915
  # Apply styles to subset labels
3697
3916
  for i, text in enumerate(v.subset_labels):
3698
3917
  if text: # Ensure text exists
3699
3918
  if custom_texts: # Custom text handling
3700
- text.set_text(custom_texts[i])
3701
- textsets(text, fontname=fontname, fontsize=subset_fontsize, fontweight=subset_fontweight, fontstyle=subset_fontstyle,
3702
- fontcolor=subset_fontcolor,ha=ha,va=va,shadow=shadow)
3919
+ text.set_text(custom_texts[i])
3920
+ textsets(
3921
+ text,
3922
+ fontname=fontname,
3923
+ fontsize=subset_fontsize,
3924
+ fontweight=subset_fontweight,
3925
+ fontstyle=subset_fontstyle,
3926
+ fontcolor=subset_fontcolor,
3927
+ ha=ha,
3928
+ va=va,
3929
+ shadow=shadow,
3930
+ )
3703
3931
 
3704
3932
  elif len(lists) == 3:
3933
+
3705
3934
  def get_label(set_count, subset_count):
3706
3935
  percent = subset_count / set_count if set_count > 0 else 0
3707
- return f"{subset_count}\n({fmt.format(percent)})" if show_percentages else f"{subset_count}"
3708
-
3936
+ return (
3937
+ f"{subset_count}\n({fmt.format(percent)})"
3938
+ if show_percentages
3939
+ else f"{subset_count}"
3940
+ )
3941
+
3709
3942
  from matplotlib_venn import venn3, venn3_circles
3943
+
3710
3944
  # Auto-calculate overlap colors for 3-set Venn diagram
3711
3945
  colorAB = get_color_overlap(colors[0], colors[1]) if colors else None
3712
3946
  colorAC = get_color_overlap(colors[0], colors[2]) if colors else None
3713
3947
  colorBC = get_color_overlap(colors[1], colors[2]) if colors else None
3714
- colorABC = get_color_overlap(colors[0], colors[1], colors[2]) if colors else None
3715
- set1,set2,set3=lists[0],lists[1],lists[2]
3948
+ colorABC = (
3949
+ get_color_overlap(colors[0], colors[1], colors[2]) if colors else None
3950
+ )
3951
+ set1, set2, set3 = lists[0], lists[1], lists[2]
3716
3952
 
3717
3953
  # Draw the venn diagram
3718
- v = venn3(subsets=lists, set_labels=labels, ax=ax,**kwargs)
3719
- v.get_patch_by_id('100').set_color(colors[0])
3720
- v.get_patch_by_id('010').set_color(colors[1])
3721
- v.get_patch_by_id('001').set_color(colors[2])
3722
- v.get_patch_by_id('110').set_color(colorAB)
3723
- v.get_patch_by_id('101').set_color(colorAC)
3724
- v.get_patch_by_id('011').set_color(colorBC)
3725
- v.get_patch_by_id('111').set_color(colorABC)
3726
-
3954
+ v = venn3(subsets=lists, set_labels=labels, ax=ax, **kwargs)
3955
+ v.get_patch_by_id("100").set_color(colors[0])
3956
+ v.get_patch_by_id("010").set_color(colors[1])
3957
+ v.get_patch_by_id("001").set_color(colors[2])
3958
+ v.get_patch_by_id("110").set_color(colorAB)
3959
+ v.get_patch_by_id("101").set_color(colorAC)
3960
+ v.get_patch_by_id("011").set_color(colorBC)
3961
+ v.get_patch_by_id("111").set_color(colorABC)
3962
+
3727
3963
  # Correctly labeling subset sizes
3728
3964
  # v.get_label_by_id('100').set_text(len(set1 - set2 - set3))
3729
3965
  # v.get_label_by_id('010').set_text(len(set2 - set1 - set3))
@@ -3732,63 +3968,93 @@ def venn(
3732
3968
  # v.get_label_by_id('101').set_text(len(set1 & set3 - set2))
3733
3969
  # v.get_label_by_id('011').set_text(len(set2 & set3 - set1))
3734
3970
  # v.get_label_by_id('111').set_text(len(set1 & set2 & set3))
3735
- v.get_label_by_id('100').set_text(get_label(len(set1), len(set1 - set2 - set3)))
3736
- v.get_label_by_id('010').set_text(get_label(len(set2), len(set2 - set1 - set3)))
3737
- v.get_label_by_id('001').set_text(get_label(len(set3), len(set3 - set1 - set2)))
3738
- v.get_label_by_id('110').set_text(get_label(len(set1 | set2), len(set1 & set2 - set3)))
3739
- v.get_label_by_id('101').set_text(get_label(len(set1 | set3), len(set1 & set3 - set2)))
3740
- v.get_label_by_id('011').set_text(get_label(len(set2 | set3), len(set2 & set3 - set1)))
3741
- v.get_label_by_id('111').set_text(get_label(len(set1 | set2 | set3), len(set1 & set2 & set3)))
3742
-
3971
+ v.get_label_by_id("100").set_text(get_label(len(set1), len(set1 - set2 - set3)))
3972
+ v.get_label_by_id("010").set_text(get_label(len(set2), len(set2 - set1 - set3)))
3973
+ v.get_label_by_id("001").set_text(get_label(len(set3), len(set3 - set1 - set2)))
3974
+ v.get_label_by_id("110").set_text(
3975
+ get_label(len(set1 | set2), len(set1 & set2 - set3))
3976
+ )
3977
+ v.get_label_by_id("101").set_text(
3978
+ get_label(len(set1 | set3), len(set1 & set3 - set2))
3979
+ )
3980
+ v.get_label_by_id("011").set_text(
3981
+ get_label(len(set2 | set3), len(set2 & set3 - set1))
3982
+ )
3983
+ v.get_label_by_id("111").set_text(
3984
+ get_label(len(set1 | set2 | set3), len(set1 & set2 & set3))
3985
+ )
3743
3986
 
3744
3987
  # Apply styles to set labels
3745
3988
  for i, text in enumerate(v.set_labels):
3746
- textsets(text, fontname=fontname, fontsize=fontsize, fontweight=fontweight, fontstyle=fontstyle,
3747
- fontcolor=fontcolor,ha=ha,va=va,shadow=shadow)
3989
+ textsets(
3990
+ text,
3991
+ fontname=fontname,
3992
+ fontsize=fontsize,
3993
+ fontweight=fontweight,
3994
+ fontstyle=fontstyle,
3995
+ fontcolor=fontcolor,
3996
+ ha=ha,
3997
+ va=va,
3998
+ shadow=shadow,
3999
+ )
3748
4000
 
3749
4001
  # Apply styles to subset labels
3750
4002
  for i, text in enumerate(v.subset_labels):
3751
4003
  if text: # Ensure text exists
3752
4004
  if custom_texts: # Custom text handling
3753
4005
  text.set_text(custom_texts[i])
3754
- textsets(text, fontname=fontname, fontsize=subset_fontsize, fontweight=subset_fontweight, fontstyle=subset_fontstyle,
3755
- fontcolor=subset_fontcolor,ha=ha,va=va,shadow=shadow)
4006
+ textsets(
4007
+ text,
4008
+ fontname=fontname,
4009
+ fontsize=subset_fontsize,
4010
+ fontweight=subset_fontweight,
4011
+ fontstyle=subset_fontstyle,
4012
+ fontcolor=subset_fontcolor,
4013
+ ha=ha,
4014
+ va=va,
4015
+ shadow=shadow,
4016
+ )
3756
4017
 
3757
4018
  venn_circles = venn3_circles(subsets=lists, ax=ax)
3758
- if not isinstance(linewidth,list):
3759
- linewidth=[linewidth]
3760
- if isinstance(linestyle,str):
3761
- linestyle=[linestyle]
4019
+ if not isinstance(linewidth, list):
4020
+ linewidth = [linewidth]
4021
+ if isinstance(linestyle, str):
4022
+ linestyle = [linestyle]
3762
4023
  if not isinstance(edgecolor, list):
3763
- edgecolor=[edgecolor]
3764
- linewidth=linewidth*3 if len(linewidth)==1 else linewidth
3765
- linestyle=linestyle*3 if len(linestyle)==1 else linestyle
3766
- edgecolor=edgecolor*3 if len(edgecolor)==1 else edgecolor
4024
+ edgecolor = [edgecolor]
4025
+ linewidth = linewidth * 3 if len(linewidth) == 1 else linewidth
4026
+ linestyle = linestyle * 3 if len(linestyle) == 1 else linestyle
4027
+ edgecolor = edgecolor * 3 if len(edgecolor) == 1 else edgecolor
3767
4028
 
3768
4029
  # edgecolor=[to_rgba(i) for i in edgecolor]
3769
4030
 
3770
- for i in range(3):
4031
+ for i in range(3):
3771
4032
  venn_circles[i].set_lw(linewidth[i])
3772
- venn_circles[i].set_ls(linestyle[i])
3773
- venn_circles[i].set_edgecolor(edgecolor[i])
4033
+ venn_circles[i].set_ls(linestyle[i])
4034
+ venn_circles[i].set_edgecolor(edgecolor[i])
3774
4035
 
3775
- #椭圆形
4036
+ # 椭圆形
3776
4037
  if ellipse_shape:
3777
- import matplotlib.patches as patches
4038
+ import matplotlib.patches as patches
4039
+
3778
4040
  for patch in v.patches:
3779
4041
  patch.set_visible(False) # Hide original patches if using ellipses
3780
4042
  center1 = v.get_circle_center(0)
3781
4043
  center2 = v.get_circle_center(1)
3782
- center3 = v.get_circle_center(2)
4044
+ center3 = v.get_circle_center(2)
3783
4045
  ellipse1 = patches.Ellipse(
3784
4046
  (center1.x, center1.y),
3785
4047
  width=ellipse_scale[0],
3786
4048
  height=ellipse_scale[1],
3787
4049
  edgecolor=edgecolor[0] if edgecolor else colors[0],
3788
4050
  facecolor=colors[0],
3789
- lw=linewidth if isinstance(linewidth, (int, float)) else 1.0, # Ensure lw is a number
4051
+ lw=(
4052
+ linewidth if isinstance(linewidth, (int, float)) else 1.0
4053
+ ), # Ensure lw is a number
3790
4054
  ls=linestyle[0],
3791
- alpha=alpha if isinstance(alpha, (int, float)) else 0.5 # Ensure alpha is a number
4055
+ alpha=(
4056
+ alpha if isinstance(alpha, (int, float)) else 0.5
4057
+ ), # Ensure alpha is a number
3792
4058
  )
3793
4059
  ellipse2 = patches.Ellipse(
3794
4060
  (center2.x, center2.y),
@@ -3796,9 +4062,13 @@ def venn(
3796
4062
  height=ellipse_scale[1],
3797
4063
  edgecolor=edgecolor[1] if edgecolor else colors[1],
3798
4064
  facecolor=colors[1],
3799
- lw=linewidth if isinstance(linewidth, (int, float)) else 1.0, # Ensure lw is a number
4065
+ lw=(
4066
+ linewidth if isinstance(linewidth, (int, float)) else 1.0
4067
+ ), # Ensure lw is a number
3800
4068
  ls=linestyle[0],
3801
- alpha=alpha if isinstance(alpha, (int, float)) else 0.5 # Ensure alpha is a number
4069
+ alpha=(
4070
+ alpha if isinstance(alpha, (int, float)) else 0.5
4071
+ ), # Ensure alpha is a number
3802
4072
  )
3803
4073
  ellipse3 = patches.Ellipse(
3804
4074
  (center3.x, center3.y),
@@ -3806,9 +4076,13 @@ def venn(
3806
4076
  height=ellipse_scale[1],
3807
4077
  edgecolor=edgecolor[1] if edgecolor else colors[1],
3808
4078
  facecolor=colors[1],
3809
- lw=linewidth if isinstance(linewidth, (int, float)) else 1.0, # Ensure lw is a number
4079
+ lw=(
4080
+ linewidth if isinstance(linewidth, (int, float)) else 1.0
4081
+ ), # Ensure lw is a number
3810
4082
  ls=linestyle[0],
3811
- alpha=alpha if isinstance(alpha, (int, float)) else 0.5 # Ensure alpha is a number
4083
+ alpha=(
4084
+ alpha if isinstance(alpha, (int, float)) else 0.5
4085
+ ), # Ensure alpha is a number
3812
4086
  )
3813
4087
  ax.add_patch(ellipse1)
3814
4088
  ax.add_patch(ellipse2)
@@ -3820,17 +4094,20 @@ def venn(
3820
4094
  for patch in v.patches:
3821
4095
  if patch:
3822
4096
  patch.set_alpha(alpha)
3823
- if 'none' in edgecolor or 0 in linewidth:
4097
+ if "none" in edgecolor or 0 in linewidth:
3824
4098
  patch.set_edgecolor("none")
3825
4099
  return ax
3826
4100
 
4101
+
3827
4102
  #! subplots, support automatic extend new axis
3828
- def subplot(rows:int=2,
3829
- cols:int=2,
3830
- figsize:Union[tuple,list]=[8, 8],
3831
- sharex=False,
3832
- sharey=False,
3833
- **kwargs):
4103
+ def subplot(
4104
+ rows: int = 2,
4105
+ cols: int = 2,
4106
+ figsize: Union[tuple, list] = [8, 8],
4107
+ sharex=False,
4108
+ sharey=False,
4109
+ **kwargs,
4110
+ ):
3834
4111
  """
3835
4112
  nexttile = subplot(
3836
4113
  8,
@@ -3850,11 +4127,14 @@ def subplot(rows:int=2,
3850
4127
  ax.set_xlabel(f"Tile {i + 1}")
3851
4128
  """
3852
4129
  from matplotlib.gridspec import GridSpec
4130
+
3853
4131
  if run_once_within():
3854
- print(f"usage:\n\tnexttile = subplot(2, 2, figsize=(5, 5), sharex=True, sharey=True)\n\tax = nexttile()")
4132
+ print(
4133
+ f"usage:\n\tnexttile = subplot(2, 2, figsize=(5, 5), sharex=True, sharey=True)\n\tax = nexttile()"
4134
+ )
3855
4135
  fig = plt.figure(figsize=figsize)
3856
4136
  grid_spec = GridSpec(rows, cols, figure=fig)
3857
- occupied = set()
4137
+ occupied = set()
3858
4138
  row_first_axes = [None] * rows # Track the first axis in each row (for sharey)
3859
4139
  col_first_axes = [None] * cols # Track the first axis in each column (for sharex)
3860
4140
 
@@ -3862,8 +4142,9 @@ def subplot(rows:int=2,
3862
4142
  nonlocal rows, grid_spec
3863
4143
  rows += 1 # Expands by adding a row
3864
4144
  grid_spec = GridSpec(rows, cols, figure=fig)
4145
+
3865
4146
  def nexttile(rowspan=1, colspan=1, **kwargs):
3866
- nonlocal rows, cols, occupied, grid_spec
4147
+ nonlocal rows, cols, occupied, grid_spec
3867
4148
  for row in range(rows):
3868
4149
  for col in range(cols):
3869
4150
  if all(
@@ -3873,29 +4154,29 @@ def subplot(rows:int=2,
3873
4154
  ):
3874
4155
  break
3875
4156
  else:
3876
- continue
4157
+ continue
3877
4158
  break
3878
- else:
4159
+ else:
3879
4160
  expand_ax()
3880
4161
  return nexttile(rowspan=rowspan, colspan=colspan, **kwargs)
3881
4162
 
3882
- sharex_ax,sharey_ax = None, None
4163
+ sharex_ax, sharey_ax = None, None
3883
4164
 
3884
- if sharex:
4165
+ if sharex:
3885
4166
  sharex_ax = col_first_axes[col]
3886
4167
 
3887
- if sharey:
3888
- sharey_ax = row_first_axes[row]
4168
+ if sharey:
4169
+ sharey_ax = row_first_axes[row]
3889
4170
  ax = fig.add_subplot(
3890
4171
  grid_spec[row : row + rowspan, col : col + colspan],
3891
4172
  sharex=sharex_ax,
3892
4173
  sharey=sharey_ax,
3893
- **kwargs
3894
- )
4174
+ **kwargs,
4175
+ )
3895
4176
  if row_first_axes[row] is None:
3896
4177
  row_first_axes[row] = ax
3897
4178
  if col_first_axes[col] is None:
3898
- col_first_axes[col] = ax
4179
+ col_first_axes[col] = ax
3899
4180
  for r in range(row, row + rowspan):
3900
4181
  for c in range(col, col + colspan):
3901
4182
  occupied.add((r, c))
@@ -3907,17 +4188,17 @@ def subplot(rows:int=2,
3907
4188
 
3908
4189
  #! radar chart
3909
4190
  def radar(
3910
- data: pd.DataFrame,
3911
- ylim=(0,100),
4191
+ data: pd.DataFrame,
4192
+ ylim=(0, 100),
3912
4193
  color=get_color(5),
3913
4194
  fontsize=10,
3914
- fontcolor='k',
4195
+ fontcolor="k",
3915
4196
  size=6,
3916
4197
  linewidth=1,
3917
4198
  linestyle="-",
3918
4199
  alpha=0.5,
3919
4200
  marker="o",
3920
- edgecolor='none',
4201
+ edgecolor="none",
3921
4202
  edge_linewidth=0,
3922
4203
  bg_color="0.8",
3923
4204
  bg_alpha=None,
@@ -3928,18 +4209,19 @@ def radar(
3928
4209
  legend_fontsize=10,
3929
4210
  grid_color="gray",
3930
4211
  grid_alpha=0.5,
3931
- grid_linestyle="--",grid_linewidth=0.5,
4212
+ grid_linestyle="--",
4213
+ grid_linewidth=0.5,
3932
4214
  circular: bool = False,
3933
4215
  tick_fontsize=None,
3934
4216
  tick_fontcolor="0.65",
3935
- tick_loc = None,# label position
3936
- turning = None,
4217
+ tick_loc=None, # label position
4218
+ turning=None,
3937
4219
  ax=None,
3938
4220
  sp=2,
3939
- **kwargs
4221
+ **kwargs,
3940
4222
  ):
3941
4223
  """
3942
- Example DATA:
4224
+ Example DATA:
3943
4225
  df = pd.DataFrame(
3944
4226
  data=[
3945
4227
  [80, 80, 80, 80, 80, 80, 80],
@@ -3948,7 +4230,7 @@ def radar(
3948
4230
  ],
3949
4231
  index=["Hero", "Warrior", "Wizard"],
3950
4232
  columns=["HP", "MP", "ATK", "DEF", "SP.ATK", "SP.DEF", "SPD"])
3951
-
4233
+
3952
4234
  Parameters:
3953
4235
  - data (pd.DataFrame): The data to plot. Each column corresponds to a variable, and each row represents a data point.
3954
4236
  - ylim (tuple): The limits of the radial axis (y-axis). Default is (0, 100).
@@ -4002,8 +4284,12 @@ def radar(
4002
4284
 
4003
4285
  # bg_color
4004
4286
  if bg_alpha is None:
4005
- bg_alpha=alpha
4006
- ax.set_facecolor(to_rgba(bg_color,alpha=bg_alpha)) if circular else ax.set_facecolor('none')
4287
+ bg_alpha = alpha
4288
+ (
4289
+ ax.set_facecolor(to_rgba(bg_color, alpha=bg_alpha))
4290
+ if circular
4291
+ else ax.set_facecolor("none")
4292
+ )
4007
4293
  # Set up the radar chart with straight-line connections
4008
4294
  ax.set_theta_offset(np.pi / 2)
4009
4295
  ax.set_theta_direction(-1)
@@ -4012,53 +4298,68 @@ def radar(
4012
4298
  ax.set_xticks(angles[:-1])
4013
4299
  ax.set_xticklabels(categories)
4014
4300
 
4015
- # Set y-axis limits and grid intervals
4301
+ # Set y-axis limits and grid intervals
4016
4302
  vmin, vmax = ylim
4017
- if circular:
4018
- #* cicular style
4019
- ax.yaxis.set_ticks(np.arange(vmin, vmax+1, vmax * grid_interval_ratio))
4020
- ax.grid(axis='both',
4021
- color=grid_color,
4022
- linestyle=grid_linestyle,
4023
- alpha=grid_alpha,
4024
- linewidth=grid_linewidth,
4025
- dash_capstyle='round',
4026
- dash_joinstyle='round',
4027
- )
4028
- ax.spines["polar"].set_color(grid_color)
4303
+ if circular:
4304
+ # * cicular style
4305
+ ax.yaxis.set_ticks(np.arange(vmin, vmax + 1, vmax * grid_interval_ratio))
4306
+ ax.grid(
4307
+ axis="both",
4308
+ color=grid_color,
4309
+ linestyle=grid_linestyle,
4310
+ alpha=grid_alpha,
4311
+ linewidth=grid_linewidth,
4312
+ dash_capstyle="round",
4313
+ dash_joinstyle="round",
4314
+ )
4315
+ ax.spines["polar"].set_color(grid_color)
4029
4316
  ax.spines["polar"].set_linewidth(grid_linewidth)
4030
- ax.spines["polar"].set_linestyle('-')
4317
+ ax.spines["polar"].set_linestyle("-")
4031
4318
  ax.spines["polar"].set_alpha(grid_alpha)
4032
- ax.spines["polar"].set_capstyle('round')
4033
- ax.spines["polar"].set_joinstyle('round')
4319
+ ax.spines["polar"].set_capstyle("round")
4320
+ ax.spines["polar"].set_joinstyle("round")
4034
4321
 
4035
4322
  else:
4036
- #* spider style: spider-style grid (straight lines, not circles)
4323
+ # * spider style: spider-style grid (straight lines, not circles)
4037
4324
  # Create the spider-style grid (straight lines, not circles)
4038
- for i in range(1, int(vmax * grid_interval_ratio) + 1):
4325
+ for i in range(1, int(vmax * grid_interval_ratio) + 1):
4039
4326
  ax.plot(
4040
4327
  angles + [angles[0]], # Closing the loop
4041
- [i * vmax * grid_interval_ratio] * (num_vars+1) + [i * vmax * grid_interval_ratio],
4042
- color=grid_color, linestyle=grid_linestyle, alpha=grid_alpha,linewidth=grid_linewidth
4328
+ [i * vmax * grid_interval_ratio] * (num_vars + 1)
4329
+ + [i * vmax * grid_interval_ratio],
4330
+ color=grid_color,
4331
+ linestyle=grid_linestyle,
4332
+ alpha=grid_alpha,
4333
+ linewidth=grid_linewidth,
4043
4334
  )
4044
- # set bg_color
4045
- ax.fill(angles, [vmax]*(data.shape[1]+1), color=bg_color, alpha=bg_alpha)
4335
+ # set bg_color
4336
+ ax.fill(angles, [vmax] * (data.shape[1] + 1), color=bg_color, alpha=bg_alpha)
4046
4337
  ax.yaxis.grid(False)
4047
4338
  # Move radial labels away from plotted line
4048
4339
  if tick_loc is None:
4049
- tick_loc = np.mean([angles[0],angles[1]])/(2*np.pi)*360 if circular else 0
4340
+ tick_loc = (
4341
+ np.mean([angles[0], angles[1]]) / (2 * np.pi) * 360 if circular else 0
4342
+ )
4050
4343
 
4051
4344
  ax.set_rlabel_position(tick_loc)
4052
4345
  ax.set_theta_offset(turning) if turning is not None else None
4053
- ax.tick_params(axis='x', labelsize=fontsize, colors=fontcolor) # Optional: for angular labels
4054
- tick_fontsize = fontsize-2 if fontsize is None else tick_fontsize
4055
- ax.tick_params(axis='y', labelsize=tick_fontsize, colors=tick_fontcolor) # For radial labels
4346
+ ax.tick_params(
4347
+ axis="x", labelsize=fontsize, colors=fontcolor
4348
+ ) # Optional: for angular labels
4349
+ tick_fontsize = fontsize - 2 if fontsize is None else tick_fontsize
4350
+ ax.tick_params(
4351
+ axis="y", labelsize=tick_fontsize, colors=tick_fontcolor
4352
+ ) # For radial labels
4056
4353
  if not circular:
4057
- ax.spines['polar'].set_visible(False)
4058
- ax.tick_params(axis='x', pad=sp) # move spines outward
4059
- ax.tick_params(axis='y', pad=sp) # move spines outward
4354
+ ax.spines["polar"].set_visible(False)
4355
+ ax.tick_params(axis="x", pad=sp) # move spines outward
4356
+ ax.tick_params(axis="y", pad=sp) # move spines outward
4060
4357
  # colors
4061
- colors = get_color(data.shape[0]) if cmap is None else plt.get_cmap(cmap)(np.linspace(0, 1, data.shape[0]))
4358
+ colors = (
4359
+ get_color(data.shape[0])
4360
+ if cmap is None
4361
+ else plt.get_cmap(cmap)(np.linspace(0, 1, data.shape[0]))
4362
+ )
4062
4363
  # Plot each row with straight lines
4063
4364
  for i, (index, row) in enumerate(data.iterrows()):
4064
4365
  values = row.tolist()
@@ -4070,7 +4371,7 @@ def radar(
4070
4371
  linewidth=linewidth,
4071
4372
  linestyle=linestyle,
4072
4373
  label=index,
4073
- clip_on=False
4374
+ clip_on=False,
4074
4375
  )
4075
4376
  ax.fill(angles, values, color=colors[i], alpha=alpha)
4076
4377
 
@@ -4084,13 +4385,23 @@ def radar(
4084
4385
  marker=marker,
4085
4386
  markersize=size,
4086
4387
  markeredgecolor=edgecolor,
4087
- markeredgewidth = edge_linewidth, zorder=10,clip_on=False
4388
+ markeredgewidth=edge_linewidth,
4389
+ zorder=10,
4390
+ clip_on=False,
4088
4391
  )
4089
4392
  # ax.tick_params(axis='y', labelleft=False, left=False)
4090
- if 'legend' in kws_figsets:
4393
+ if "legend" in kws_figsets:
4091
4394
  figsets(ax=ax, **kws_figsets)
4092
4395
  else:
4093
-
4094
- figsets(ax=ax,legend=dict(loc=legend_loc,fontsize=legend_fontsize,
4095
- bbox_to_anchor=[1.1,1.4],ncols=2),**kws_figsets)
4096
- return ax
4396
+
4397
+ figsets(
4398
+ ax=ax,
4399
+ legend=dict(
4400
+ loc=legend_loc,
4401
+ fontsize=legend_fontsize,
4402
+ bbox_to_anchor=[1.1, 1.4],
4403
+ ncols=2,
4404
+ ),
4405
+ **kws_figsets,
4406
+ )
4407
+ return ax