py2ls 0.2.4.10.5__py3-none-any.whl → 0.2.4.10.7__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
py2ls/plot.py CHANGED
@@ -1,21 +1,37 @@
1
1
  import numpy as np
2
- import pandas as pd
2
+ import pandas as pd
3
3
  import matplotlib.pyplot as plt
4
4
  import matplotlib
5
5
  import seaborn as sns
6
6
  import warnings
7
7
  import logging
8
8
  from typing import Union
9
- from .ips import isa, fsave, fload, mkdir, listdir, figsave, strcmp, unique, get_os, ssplit,flatten,plt_font,run_once_within
9
+ from .ips import (
10
+ isa,
11
+ fsave,
12
+ fload,
13
+ mkdir,
14
+ listdir,
15
+ figsave,
16
+ strcmp,
17
+ unique,
18
+ get_os,
19
+ ssplit,
20
+ flatten,
21
+ plt_font,
22
+ run_once_within,
23
+ )
10
24
  from .stats import *
11
25
  import os
26
+
12
27
  # Suppress INFO messages from fontTools
13
28
  logging.getLogger("fontTools").setLevel(logging.ERROR)
14
- logging.getLogger('matplotlib').setLevel(logging.ERROR)
29
+ logging.getLogger("matplotlib").setLevel(logging.ERROR)
15
30
 
16
31
  warnings.simplefilter("ignore", category=pd.errors.SettingWithCopyWarning)
17
32
  warnings.filterwarnings("ignore", category=pd.errors.PerformanceWarning)
18
33
 
34
+
19
35
  def add_text(ax=None, height_offset=0.5, fmt=".1f", **kwargs):
20
36
  """Adds text annotations for various types of Seaborn and Matplotlib plots.
21
37
  Args:
@@ -449,7 +465,7 @@ def catplot(data, *args, **kwargs):
449
465
  """
450
466
  from matplotlib.colors import to_rgba
451
467
  import os
452
-
468
+
453
469
  def plot_bars(data, data_m, opt_b, xloc, ax, label=None):
454
470
  if "l" in opt_b["loc"]:
455
471
  xloc_s = xloc - opt_b["x_dist"]
@@ -755,6 +771,7 @@ def catplot(data, *args, **kwargs):
755
771
  )
756
772
  else:
757
773
  from scipy.stats import gaussian_kde
774
+
758
775
  kde = gaussian_kde(ys, bw_method=opt_v["BandWidth"])
759
776
  min_val, max_val = ys.min(), ys.max()
760
777
  y_vals = np.linspace(min_val, max_val, opt_v["NumPoints"])
@@ -1416,7 +1433,9 @@ def catplot(data, *args, **kwargs):
1416
1433
  style_load = fload(dir_style + style_use + ".json")
1417
1434
  else:
1418
1435
  style_load = fload(
1419
- listdir(dir_style, "json").path.tolist()[style_use]
1436
+ listdir(dir_style, "json", verbose=False).path.tolist()[
1437
+ style_use
1438
+ ]
1420
1439
  )
1421
1440
  style_load = remove_colors_in_dict(style_load)
1422
1441
  opt.update(style_load)
@@ -1812,16 +1831,17 @@ def read_mplstyle(style_file):
1812
1831
  return style_dict
1813
1832
 
1814
1833
 
1815
- def figsets(*args, **kwargs):
1834
+ def figsets(*args, **kwargs):
1816
1835
  import matplotlib
1817
1836
  from cycler import cycler
1837
+
1818
1838
  matplotlib.rc("text", usetex=False)
1819
1839
 
1820
1840
  fig = plt.gcf()
1821
- fontsize = kwargs.get("fontsize",11)
1822
- plt.rcParams["font.size"]=fontsize
1823
- fontname = kwargs.pop("fontname","Arial")
1824
- fontname=plt_font(fontname) # 显示中文
1841
+ fontsize = kwargs.get("fontsize", 11)
1842
+ plt.rcParams["font.size"] = fontsize
1843
+ fontname = kwargs.pop("fontname", "Arial")
1844
+ fontname = plt_font(fontname) # 显示中文
1825
1845
 
1826
1846
  sns_themes = ["white", "whitegrid", "dark", "darkgrid", "ticks"]
1827
1847
  sns_contexts = ["notebook", "talk", "poster"] # now available "paper"
@@ -1851,18 +1871,21 @@ def figsets(*args, **kwargs):
1851
1871
  nonlocal fontsize, fontname
1852
1872
  if ("fo" in key) and (("size" in key) or ("sz" in key)):
1853
1873
  fontsize = value
1854
- plt.rcParams.update({"font.size": fontsize,
1855
- "figure.titlesize":fontsize,
1856
- "axes.titlesize":fontsize,
1857
- "axes.labelsize": fontsize,
1858
- "xtick.labelsize": fontsize,
1859
- "ytick.labelsize": fontsize,
1860
- "legend.fontsize": fontsize,
1861
- "legend.title_fontsize":fontsize
1862
- })
1874
+ plt.rcParams.update(
1875
+ {
1876
+ "font.size": fontsize,
1877
+ "figure.titlesize": fontsize,
1878
+ "axes.titlesize": fontsize,
1879
+ "axes.labelsize": fontsize,
1880
+ "xtick.labelsize": fontsize,
1881
+ "ytick.labelsize": fontsize,
1882
+ "legend.fontsize": fontsize,
1883
+ "legend.title_fontsize": fontsize,
1884
+ }
1885
+ )
1863
1886
 
1864
1887
  # Customize tick labels
1865
- ax.tick_params(axis='both', which='major', labelsize=fontsize)
1888
+ ax.tick_params(axis="both", which="major", labelsize=fontsize)
1866
1889
  for label in ax.get_xticklabels() + ax.get_yticklabels():
1867
1890
  label.set_fontname(fontname)
1868
1891
 
@@ -1910,15 +1933,15 @@ def figsets(*args, **kwargs):
1910
1933
  if ("x" in key.lower()) and (
1911
1934
  "tic" not in key.lower() and "tk" not in key.lower()
1912
1935
  ):
1913
- ax.set_xlabel(value, fontname=fontname,fontsize=fontsize)
1936
+ ax.set_xlabel(value, fontname=fontname, fontsize=fontsize)
1914
1937
  if ("y" in key.lower()) and (
1915
1938
  "tic" not in key.lower() and "tk" not in key.lower()
1916
1939
  ):
1917
- ax.set_ylabel(value, fontname=fontname,fontsize=fontsize)
1940
+ ax.set_ylabel(value, fontname=fontname, fontsize=fontsize)
1918
1941
  if ("z" in key.lower()) and (
1919
1942
  "tic" not in key.lower() and "tk" not in key.lower()
1920
1943
  ):
1921
- ax.set_zlabel(value, fontname=fontname,fontsize=fontsize)
1944
+ ax.set_zlabel(value, fontname=fontname, fontsize=fontsize)
1922
1945
  if key == "xlabel" and isinstance(value, dict):
1923
1946
  ax.set_xlabel(**value)
1924
1947
  if key == "ylabel" and isinstance(value, dict):
@@ -2110,6 +2133,7 @@ def figsets(*args, **kwargs):
2110
2133
 
2111
2134
  if "mi" in key.lower() and "tic" in key.lower(): # minor_ticks
2112
2135
  import matplotlib.ticker as tck
2136
+
2113
2137
  if "x" in value.lower() or "x" in key.lower():
2114
2138
  ax.xaxis.set_minor_locator(tck.AutoMinorLocator()) # ax.minorticks_on()
2115
2139
  if "y" in value.lower() or "y" in key.lower():
@@ -2162,9 +2186,9 @@ def figsets(*args, **kwargs):
2162
2186
  ax.grid(visible=False)
2163
2187
  if "tit" in key.lower():
2164
2188
  if "sup" in key.lower():
2165
- plt.suptitle(value,fontname=fontname,fontsize=fontsize)
2189
+ plt.suptitle(value, fontname=fontname, fontsize=fontsize)
2166
2190
  else:
2167
- ax.set_title(value,fontname=fontname,fontsize=fontsize)
2191
+ ax.set_title(value, fontname=fontname, fontsize=fontsize)
2168
2192
  if key.lower() in ["spine", "adjust", "ad", "sp", "spi", "adj", "spines"]:
2169
2193
  if isinstance(value, bool) or (value in ["go", "do", "ja", "yes"]):
2170
2194
  if value:
@@ -2185,9 +2209,14 @@ def figsets(*args, **kwargs):
2185
2209
  legend = ax.get_legend()
2186
2210
  if legend is not None:
2187
2211
  legend.remove()
2188
- if any(['colorbar' in key.lower(),'cbar' in key.lower()]) and "loc" in key.lower():
2212
+ if (
2213
+ any(["colorbar" in key.lower(), "cbar" in key.lower()])
2214
+ and "loc" in key.lower()
2215
+ ):
2189
2216
  cbar = ax.collections[0].colorbar # Access the colorbar from the plot
2190
- cbar.ax.set_position(value) # [left, bottom, width, height] [0.475, 0.15, 0.04, 0.25]
2217
+ cbar.ax.set_position(
2218
+ value
2219
+ ) # [left, bottom, width, height] [0.475, 0.15, 0.04, 0.25]
2191
2220
 
2192
2221
  for arg in args:
2193
2222
  if isinstance(arg, matplotlib.axes._axes.Axes):
@@ -2225,7 +2254,8 @@ def figsets(*args, **kwargs):
2225
2254
  if len(fig.get_axes()) > 1:
2226
2255
  plt.tight_layout()
2227
2256
 
2228
- def split_legend(ax, n=2, loc=None, title=None, bbox=None,ncol=1, **kwargs):
2257
+
2258
+ def split_legend(ax, n=2, loc=None, title=None, bbox=None, ncol=1, **kwargs):
2229
2259
  """
2230
2260
  split_legend(
2231
2261
  ax,
@@ -2238,15 +2268,15 @@ def split_legend(ax, n=2, loc=None, title=None, bbox=None,ncol=1, **kwargs):
2238
2268
  # Retrieve all lines and labels from the axis
2239
2269
  handles, labels = ax.get_legend_handles_labels()
2240
2270
  num_labels = len(labels)
2241
-
2271
+
2242
2272
  # Calculate the number of labels per legend part
2243
- labels_per_part = (num_labels + n - 1) // n # Round up
2273
+ labels_per_part = (num_labels + n - 1) // n # Round up
2244
2274
  # Create a list to hold each legend object
2245
2275
  legends = []
2246
2276
 
2247
2277
  # Default locations and titles if not specified
2248
2278
  if loc is None:
2249
- loc = ['best'] * n
2279
+ loc = ["best"] * n
2250
2280
  if title is None:
2251
2281
  title = [None] * n
2252
2282
  if bbox is None:
@@ -2257,7 +2287,7 @@ def split_legend(ax, n=2, loc=None, title=None, bbox=None,ncol=1, **kwargs):
2257
2287
  # Calculate the range of labels for this part
2258
2288
  start_idx = i * labels_per_part
2259
2289
  end_idx = min(start_idx + labels_per_part, num_labels)
2260
-
2290
+
2261
2291
  # Skip if no labels in this range
2262
2292
  if start_idx >= end_idx:
2263
2293
  break
@@ -2267,19 +2297,24 @@ def split_legend(ax, n=2, loc=None, title=None, bbox=None,ncol=1, **kwargs):
2267
2297
  part_labels = labels[start_idx:end_idx]
2268
2298
 
2269
2299
  # Create the legend for this part
2270
- legend = ax.legend(handles=part_handles,
2271
- labels=part_labels,
2272
- loc=loc[i],
2273
- title=title[i],
2274
- ncol=ncol,
2275
- bbox_to_anchor=bbox[i],
2276
- **kwargs)
2277
-
2300
+ legend = ax.legend(
2301
+ handles=part_handles,
2302
+ labels=part_labels,
2303
+ loc=loc[i],
2304
+ title=title[i],
2305
+ ncol=ncol,
2306
+ bbox_to_anchor=bbox[i],
2307
+ **kwargs,
2308
+ )
2309
+
2278
2310
  # Add the legend to the axis and save it to the list
2279
- ax.add_artist(legend) if i !=(n-1) else None # the lastone will be added automaticaly
2311
+ (
2312
+ ax.add_artist(legend) if i != (n - 1) else None
2313
+ ) # the lastone will be added automaticaly
2280
2314
  legends.append(legend)
2281
2315
  return legends
2282
2316
 
2317
+
2283
2318
  def get_colors(
2284
2319
  n: int = 1,
2285
2320
  cmap: str = "auto",
@@ -2289,7 +2324,9 @@ def get_colors(
2289
2324
  *args,
2290
2325
  **kwargs,
2291
2326
  ):
2292
- return get_color(n,cmap,alpha,output,*args,**kwargs)
2327
+ return get_color(n, cmap, alpha, output, *args, **kwargs)
2328
+
2329
+
2293
2330
  def get_color(
2294
2331
  n: int = 1,
2295
2332
  cmap: str = "auto",
@@ -2300,6 +2337,7 @@ def get_color(
2300
2337
  **kwargs,
2301
2338
  ):
2302
2339
  from cycler import cycler
2340
+
2303
2341
  def cmap2hex(cmap_name):
2304
2342
  cmap_ = matplotlib.pyplot.get_cmap(cmap_name)
2305
2343
  colors = [cmap_(i) for i in range(cmap_.N)]
@@ -2347,49 +2385,137 @@ def get_color(
2347
2385
  return "#{:02X}{:02X}{:02X}{:02X}".format(r, g, b, a)
2348
2386
  else:
2349
2387
  return "#{:02X}{:02X}{:02X}".format(r, g, b)
2350
-
2388
+
2351
2389
  # sc.pl.palettes.default_20
2352
- cmap_20= ['#1f77b4','#ff7f0e','#279e68','#d62728','#aa40fc','#8c564b','#e377c2','#b5bd61',
2353
- '#17becf','#aec7e8','#ffbb78','#98df8a','#ff9896','#c5b0d5','#c49c94','#f7b6d2',
2354
- '#dbdb8d','#9edae5','#ad494a','#8c6d31']
2390
+ cmap_20 = [
2391
+ "#1f77b4",
2392
+ "#ff7f0e",
2393
+ "#279e68",
2394
+ "#d62728",
2395
+ "#aa40fc",
2396
+ "#8c564b",
2397
+ "#e377c2",
2398
+ "#b5bd61",
2399
+ "#17becf",
2400
+ "#aec7e8",
2401
+ "#ffbb78",
2402
+ "#98df8a",
2403
+ "#ff9896",
2404
+ "#c5b0d5",
2405
+ "#c49c94",
2406
+ "#f7b6d2",
2407
+ "#dbdb8d",
2408
+ "#9edae5",
2409
+ "#ad494a",
2410
+ "#8c6d31",
2411
+ ]
2355
2412
  # sc.pl.palettes.zeileis_28
2356
- cmap_28 = ["#023fa5","#7d87b9","#bec1d4","#d6bcc0","#bb7784","#8e063b","#4a6fe3","#8595e1",
2357
- "#b5bbe3","#e6afb9","#e07b91","#d33f6a","#11c638","#8dd593","#c6dec7","#ead3c6",
2358
- "#f0b98d","#ef9708","#0fcfc0","#9cded6","#d5eae7","#f3e1eb","#f6c4e1","#f79cd4",
2359
- "#7f7f7f","#c7c7c7","#1CE6FF","#336600"]
2413
+ cmap_28 = [
2414
+ "#023fa5",
2415
+ "#7d87b9",
2416
+ "#bec1d4",
2417
+ "#d6bcc0",
2418
+ "#bb7784",
2419
+ "#8e063b",
2420
+ "#4a6fe3",
2421
+ "#8595e1",
2422
+ "#b5bbe3",
2423
+ "#e6afb9",
2424
+ "#e07b91",
2425
+ "#d33f6a",
2426
+ "#11c638",
2427
+ "#8dd593",
2428
+ "#c6dec7",
2429
+ "#ead3c6",
2430
+ "#f0b98d",
2431
+ "#ef9708",
2432
+ "#0fcfc0",
2433
+ "#9cded6",
2434
+ "#d5eae7",
2435
+ "#f3e1eb",
2436
+ "#f6c4e1",
2437
+ "#f79cd4",
2438
+ "#7f7f7f",
2439
+ "#c7c7c7",
2440
+ "#1CE6FF",
2441
+ "#336600",
2442
+ ]
2360
2443
  if cmap == "gray":
2361
2444
  cmap = "grey"
2362
- elif cmap=="20":
2363
- cmap=cmap_20
2364
- elif cmap=="28":
2365
- cmap=cmap_28
2445
+ elif cmap == "20":
2446
+ cmap = cmap_20
2447
+ elif cmap == "28":
2448
+ cmap = cmap_28
2366
2449
  # Determine color list based on cmap parameter
2367
- if isinstance(cmap,str):
2450
+ if isinstance(cmap, str):
2368
2451
  if "aut" in cmap:
2369
2452
  if n == 1:
2370
2453
  colorlist = ["#3A4453"]
2371
2454
  elif n == 2:
2372
2455
  colorlist = ["#3A4453", "#FF2C00"]
2373
2456
  elif n == 3:
2374
- colorlist = ["#66c2a5","#fc8d62","#8da0cb"]
2457
+ colorlist = ["#66c2a5", "#fc8d62", "#8da0cb"]
2375
2458
  elif n == 4:
2376
- colorlist = ["#FF2C00","#087cf7", "#FBAF63", "#3C898A"]
2459
+ colorlist = ["#FF2C00", "#087cf7", "#FBAF63", "#3C898A"]
2377
2460
  elif n == 5:
2378
- colorlist = ["#FF2C00","#459AA9", "#B25E9D", "#4B8C3B","#EF8632"]
2461
+ colorlist = ["#FF2C00", "#459AA9", "#B25E9D", "#4B8C3B", "#EF8632"]
2379
2462
  elif n == 6:
2380
- colorlist = ["#FF2C00","#91bfdb", "#B25E9D", "#4B8C3B","#EF8632", "#24578E"]
2381
- elif n==7:
2382
- colorlist = ["#7F7F7F", "#459AA9", "#B25E9D", "#4B8C3B","#EF8632", "#24578E" "#FF2C00"]
2383
- elif n==8:
2463
+ colorlist = [
2464
+ "#FF2C00",
2465
+ "#91bfdb",
2466
+ "#B25E9D",
2467
+ "#4B8C3B",
2468
+ "#EF8632",
2469
+ "#24578E",
2470
+ ]
2471
+ elif n == 7:
2472
+ colorlist = [
2473
+ "#7F7F7F",
2474
+ "#459AA9",
2475
+ "#B25E9D",
2476
+ "#4B8C3B",
2477
+ "#EF8632",
2478
+ "#24578E" "#FF2C00",
2479
+ ]
2480
+ elif n == 8:
2384
2481
  # colorlist = ['#1f77b4','#ff7f0e','#367B7F','#51B34F','#d62728','#aa40fc','#e377c2','#17becf']
2385
2482
  # colorlist = ["#367C7E","#51B34F","#881A11","#E9374C","#EF893C","#010072","#385DCB","#EA43E3"]
2386
- colorlist = ["#78BFDA","#D52E6F","#F7D648","#A52D28","#6B9F41","#E18330","#E18B9D","#3C88CC"]
2387
- elif n==9:
2388
- colorlist = ['#1f77b4','#ff7f0e','#367B7F','#ff9896','#d62728','#aa40fc','#e377c2','#51B34F','#17becf']
2389
- elif n==10:
2390
- colorlist = ['#1f77b4','#ff7f0e','#367B7F','#ff9896','#51B34F','#d62728''#aa40fc','#e377c2','#375FD2','#17becf']
2391
- elif 10<n<=20:
2392
- colorlist = cmap_20
2483
+ colorlist = [
2484
+ "#78BFDA",
2485
+ "#D52E6F",
2486
+ "#F7D648",
2487
+ "#A52D28",
2488
+ "#6B9F41",
2489
+ "#E18330",
2490
+ "#E18B9D",
2491
+ "#3C88CC",
2492
+ ]
2493
+ elif n == 9:
2494
+ colorlist = [
2495
+ "#1f77b4",
2496
+ "#ff7f0e",
2497
+ "#367B7F",
2498
+ "#ff9896",
2499
+ "#d62728",
2500
+ "#aa40fc",
2501
+ "#e377c2",
2502
+ "#51B34F",
2503
+ "#17becf",
2504
+ ]
2505
+ elif n == 10:
2506
+ colorlist = [
2507
+ "#1f77b4",
2508
+ "#ff7f0e",
2509
+ "#367B7F",
2510
+ "#ff9896",
2511
+ "#51B34F",
2512
+ "#d62728" "#aa40fc",
2513
+ "#e377c2",
2514
+ "#375FD2",
2515
+ "#17becf",
2516
+ ]
2517
+ elif 10 < n <= 20:
2518
+ colorlist = cmap_20
2393
2519
  else:
2394
2520
  colorlist = cmap_28
2395
2521
  by = "start"
@@ -2426,7 +2552,7 @@ def get_color(
2426
2552
  by = "linspace"
2427
2553
  colorlist = cmap2hex(cmap)
2428
2554
  elif isinstance(cmap, list):
2429
- colorlist=cmap
2555
+ colorlist = cmap
2430
2556
 
2431
2557
  # Determine method for generating color list
2432
2558
  if "st" in by.lower() or "be" in by.lower():
@@ -2446,7 +2572,7 @@ def get_color(
2446
2572
  return hue_list
2447
2573
  else:
2448
2574
  raise ValueError("Invalid output type. Choose 'rgb' or 'hue'.")
2449
-
2575
+
2450
2576
 
2451
2577
  def stdshade(ax=None, *args, **kwargs):
2452
2578
  """
@@ -2454,7 +2580,7 @@ def stdshade(ax=None, *args, **kwargs):
2454
2580
  plot.stdshade(data_array, c=clist[1], lw=2, ls="-.", alpha=0.2)
2455
2581
  """
2456
2582
  from scipy.signal import savgol_filter
2457
-
2583
+
2458
2584
  # Separate kws_line and kws_fill if necessary
2459
2585
  kws_line = kwargs.pop("kws_line", {})
2460
2586
  kws_fill = kwargs.pop("kws_fill", {})
@@ -2724,37 +2850,47 @@ def adjust_spines(ax=None, spines=["left", "bottom"], distance=2):
2724
2850
  # cax = fig.add_axes([l + w + pad, b, width, h]) # define cbar Axes
2725
2851
  # return fig.colorbar(im, cax=cax, **kwargs) # draw cbar
2726
2852
 
2727
- def add_colorbar(im,
2728
- cmap="viridis",
2729
- vmin=-1,
2730
- vmax=1,
2731
- orientation='vertical',
2732
- width_ratio=0.05,
2733
- pad_ratio=0.02,
2734
- shrink=1.0,
2735
- **kwargs):
2853
+
2854
+ def add_colorbar(
2855
+ im,
2856
+ cmap="viridis",
2857
+ vmin=-1,
2858
+ vmax=1,
2859
+ orientation="vertical",
2860
+ width_ratio=0.05,
2861
+ pad_ratio=0.02,
2862
+ shrink=1.0,
2863
+ **kwargs,
2864
+ ):
2736
2865
  import matplotlib as mpl
2866
+
2737
2867
  if all([cmap, vmin, vmax]):
2738
2868
  norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
2739
2869
  else:
2740
- norm=False
2870
+ norm = False
2741
2871
  sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
2742
2872
  sm.set_array([])
2743
2873
  l, b, w, h = im.axes.get_position().bounds # position: left, bottom, width, height
2744
- if orientation == 'vertical':
2874
+ if orientation == "vertical":
2745
2875
  width = width_ratio * w
2746
2876
  pad = pad_ratio * w
2747
- cax = im.figure.add_axes([l + w + pad, b, width, h * shrink]) # Right of the image
2877
+ cax = im.figure.add_axes(
2878
+ [l + w + pad, b, width, h * shrink]
2879
+ ) # Right of the image
2748
2880
  else:
2749
2881
  height = width_ratio * h
2750
2882
  pad = pad_ratio * h
2751
- cax = im.figure.add_axes([l, b - height - pad, w * shrink, height]) # Below the image
2752
- cbar=im.figure.colorbar(sm, cax=cax, orientation=orientation, **kwargs)
2753
- return cbar
2883
+ cax = im.figure.add_axes(
2884
+ [l, b - height - pad, w * shrink, height]
2885
+ ) # Below the image
2886
+ cbar = im.figure.colorbar(sm, cax=cax, orientation=orientation, **kwargs)
2887
+ return cbar
2888
+
2754
2889
 
2755
2890
  # Usage:
2756
2891
  # add_colorbar(im, width_ratio=0.03, pad_ratio=0.01, orientation='horizontal', label="PSD (dB)")
2757
2892
 
2893
+
2758
2894
  def generate_xticks_with_gap(x_len, hue_len):
2759
2895
  """
2760
2896
  Generate a concatenated array based on x_len and hue_len,
@@ -2884,6 +3020,7 @@ def style_examples(
2884
3020
  f = listdir(
2885
3021
  "/Users/macjianfeng/Dropbox/github/python/py2ls/.venv/lib/python3.12/site-packages/py2ls/data/styles/",
2886
3022
  kind=".json",
3023
+ verbose=False,
2887
3024
  )
2888
3025
  display(f.sample(2))
2889
3026
  # def style_example(dir_save,)
@@ -2961,6 +3098,7 @@ def thumbnail(dir_img_list: list, figsize=(10, 10), dpi=100, show=False, verbose
2961
3098
 
2962
3099
  def get_params_from_func_usage(function_signature):
2963
3100
  import re
3101
+
2964
3102
  # Regular expression to match parameter names, ignoring '*' and '**kwargs'
2965
3103
  keys_pattern = r"(?<!\*\*)\b(\w+)="
2966
3104
  # Find all matches
@@ -2987,7 +3125,7 @@ def plotxy(
2987
3125
  x=None,
2988
3126
  y=None,
2989
3127
  ax=None,
2990
- kind: Union[str, list] = 'scatter', # Specify the kind of plot
3128
+ kind: Union[str, list] = "scatter", # Specify the kind of plot
2991
3129
  verbose=False,
2992
3130
  **kwargs,
2993
3131
  ):
@@ -3011,14 +3149,15 @@ def plotxy(
3011
3149
  # Check for valid plot kind
3012
3150
  # Default arguments for various plot types
3013
3151
  from pathlib import Path
3152
+
3014
3153
  # Get the current script's directory as a Path object
3015
- current_directory = Path(__file__).resolve().parent
3154
+ current_directory = Path(__file__).resolve().parent
3155
+
3156
+ if not "default_settings" in locals():
3157
+ default_settings = fload(current_directory / "data" / "usages_sns.json")
3158
+ if not "sns_info" in locals():
3159
+ sns_info = pd.DataFrame(fload(current_directory / "data" / "sns_info.json"))
3016
3160
 
3017
- if not 'default_settings' in locals():
3018
- default_settings = fload(current_directory / 'data' / 'usages_sns.json')
3019
- if not 'sns_info' in locals():
3020
- sns_info = pd.DataFrame(fload(current_directory / 'data' / 'sns_info.json'))
3021
-
3022
3161
  valid_kinds = list(default_settings.keys())
3023
3162
  # print(valid_kinds)
3024
3163
  if kind is not None:
@@ -3032,7 +3171,7 @@ def plotxy(
3032
3171
  if kind is not None:
3033
3172
  for k in kind:
3034
3173
  if k in valid_kinds:
3035
- print(f"{k}:\n\t{default_settings[k]}")
3174
+ print(f"{k}:\n\t{default_settings[k]}")
3036
3175
  usage_str = """plotxy(data=ranked_genes,
3037
3176
  x="log2(fold_change)",
3038
3177
  y="-log10(p-value)",
@@ -3057,15 +3196,25 @@ def plotxy(
3057
3196
  kws_add_text = v_arg
3058
3197
  kwargs.pop(k_arg, None)
3059
3198
  break
3060
- zorder=0
3199
+ zorder = 0
3061
3200
  for k in kind:
3062
- zorder+=1
3201
+ zorder += 1
3063
3202
  # indicate 'col' features
3064
3203
  col = kwargs.get("col", None)
3065
- sns_with_col = ["catplot","histplot","relplot","lmplot","pairplot","displot","kdeplot"]
3204
+ sns_with_col = [
3205
+ "catplot",
3206
+ "histplot",
3207
+ "relplot",
3208
+ "lmplot",
3209
+ "pairplot",
3210
+ "displot",
3211
+ "kdeplot",
3212
+ ]
3066
3213
  if col is not None:
3067
3214
  if not k in sns_with_col:
3068
- print(f"tips:\n'{k}' has no 'col' param, you could try with {sns_with_col}")
3215
+ print(
3216
+ f"tips:\n'{k}' has no 'col' param, you could try with {sns_with_col}"
3217
+ )
3069
3218
  # (1) return FcetGrid
3070
3219
  if k == "jointplot":
3071
3220
  kws_joint = kwargs.pop("kws_joint", kwargs)
@@ -3092,77 +3241,98 @@ def plotxy(
3092
3241
  ax = stdshade(ax=ax, **kwargs)
3093
3242
  elif k == "scatterplot":
3094
3243
  kws_scatter = kwargs.pop("kws_scatter", kwargs)
3095
- kws_scatter={k: v for k, v in kws_scatter.items() if not k.startswith("kws_")}
3244
+ kws_scatter = {
3245
+ k: v for k, v in kws_scatter.items() if not k.startswith("kws_")
3246
+ }
3096
3247
  hue = kwargs.pop("hue", None)
3097
3248
  if isinstance(kws_scatter, dict): # Check if kws_scatter is a dictionary
3098
3249
  kws_scatter.pop("hue", None) # Safely remove 'hue' if it exists
3099
- palette = kws_scatter.pop("palette",get_color(data[hue].nunique()) if hue is not None else None)
3250
+ palette = kws_scatter.pop(
3251
+ "palette", get_color(data[hue].nunique()) if hue is not None else None
3252
+ )
3100
3253
  s = kws_scatter.pop("s", 10)
3101
3254
  alpha = kws_scatter.pop("alpha", 0.7)
3102
- ax = sns.scatterplot(ax=ax,data=data,x=x,y=y,hue=hue,palette=palette,s=s,alpha=alpha,zorder=zorder,**kws_scatter)
3255
+ ax = sns.scatterplot(
3256
+ ax=ax,
3257
+ data=data,
3258
+ x=x,
3259
+ y=y,
3260
+ hue=hue,
3261
+ palette=palette,
3262
+ s=s,
3263
+ alpha=alpha,
3264
+ zorder=zorder,
3265
+ **kws_scatter,
3266
+ )
3103
3267
  elif k == "histplot":
3104
3268
  kws_hist = kwargs.pop("kws_hist", kwargs)
3105
- kws_hist={k: v for k, v in kws_hist.items() if not k.startswith("kws_")}
3106
- ax = sns.histplot(data=data, x=x, ax=ax,zorder=zorder, **kws_hist)
3269
+ kws_hist = {k: v for k, v in kws_hist.items() if not k.startswith("kws_")}
3270
+ ax = sns.histplot(data=data, x=x, ax=ax, zorder=zorder, **kws_hist)
3107
3271
  elif k == "kdeplot":
3108
3272
  kws_kde = kwargs.pop("kws_kde", kwargs)
3109
- kws_kde={k: v for k, v in kws_kde.items() if not k.startswith("kws_")}
3110
- ax = sns.kdeplot(data=data, x=x, ax=ax,zorder=zorder, **kws_kde)
3273
+ kws_kde = {k: v for k, v in kws_kde.items() if not k.startswith("kws_")}
3274
+ ax = sns.kdeplot(data=data, x=x, ax=ax, zorder=zorder, **kws_kde)
3111
3275
  elif k == "ecdfplot":
3112
3276
  kws_ecdf = kwargs.pop("kws_ecdf", kwargs)
3113
- kws_ecdf={k: v for k, v in kws_ecdf.items() if not k.startswith("kws_")}
3114
- ax = sns.ecdfplot(data=data, x=x, ax=ax,zorder=zorder, **kws_ecdf)
3277
+ kws_ecdf = {k: v for k, v in kws_ecdf.items() if not k.startswith("kws_")}
3278
+ ax = sns.ecdfplot(data=data, x=x, ax=ax, zorder=zorder, **kws_ecdf)
3115
3279
  elif k == "rugplot":
3116
3280
  kws_rug = kwargs.pop("kws_rug", kwargs)
3117
- kws_rug={k: v for k, v in kws_rug.items() if not k.startswith("kws_")}
3118
- ax = sns.rugplot(data=data, x=x, ax=ax,zorder=zorder, **kws_rug)
3281
+ kws_rug = {k: v for k, v in kws_rug.items() if not k.startswith("kws_")}
3282
+ ax = sns.rugplot(data=data, x=x, ax=ax, zorder=zorder, **kws_rug)
3119
3283
  elif k == "stripplot":
3120
3284
  kws_strip = kwargs.pop("kws_strip", kwargs)
3121
- kws_strip={k: v for k, v in kws_strip.items() if not k.startswith("kws_")}
3285
+ kws_strip = {k: v for k, v in kws_strip.items() if not k.startswith("kws_")}
3122
3286
  dodge = kws_strip.pop("dodge", True)
3123
- ax = sns.stripplot(data=data, x=x, y=y, ax=ax,zorder=zorder, dodge=dodge, **kws_strip)
3287
+ ax = sns.stripplot(
3288
+ data=data, x=x, y=y, ax=ax, zorder=zorder, dodge=dodge, **kws_strip
3289
+ )
3124
3290
  elif k == "swarmplot":
3125
3291
  kws_swarm = kwargs.pop("kws_swarm", kwargs)
3126
- kws_swarm={k: v for k, v in kws_swarm.items() if not k.startswith("kws_")}
3127
- ax = sns.swarmplot(data=data, x=x, y=y, ax=ax,zorder=zorder, **kws_swarm)
3292
+ kws_swarm = {k: v for k, v in kws_swarm.items() if not k.startswith("kws_")}
3293
+ ax = sns.swarmplot(data=data, x=x, y=y, ax=ax, zorder=zorder, **kws_swarm)
3128
3294
  elif k == "boxplot":
3129
3295
  kws_box = kwargs.pop("kws_box", kwargs)
3130
- kws_box={k: v for k, v in kws_box.items() if not k.startswith("kws_")}
3131
- ax = sns.boxplot(data=data, x=x, y=y, ax=ax,zorder=zorder, **kws_box)
3296
+ kws_box = {k: v for k, v in kws_box.items() if not k.startswith("kws_")}
3297
+ ax = sns.boxplot(data=data, x=x, y=y, ax=ax, zorder=zorder, **kws_box)
3132
3298
  elif k == "violinplot":
3133
3299
  kws_violin = kwargs.pop("kws_violin", kwargs)
3134
- kws_violin={k: v for k, v in kws_violin.items() if not k.startswith("kws_")}
3135
- ax = sns.violinplot(data=data, x=x, y=y, ax=ax,zorder=zorder, **kws_violin)
3300
+ kws_violin = {
3301
+ k: v for k, v in kws_violin.items() if not k.startswith("kws_")
3302
+ }
3303
+ ax = sns.violinplot(data=data, x=x, y=y, ax=ax, zorder=zorder, **kws_violin)
3136
3304
  elif k == "boxenplot":
3137
3305
  kws_boxen = kwargs.pop("kws_boxen", kwargs)
3138
- kws_boxen={k: v for k, v in kws_boxen.items() if not k.startswith("kws_")}
3139
- ax = sns.boxenplot(data=data, x=x, y=y, ax=ax,zorder=zorder, **kws_boxen)
3306
+ kws_boxen = {k: v for k, v in kws_boxen.items() if not k.startswith("kws_")}
3307
+ ax = sns.boxenplot(data=data, x=x, y=y, ax=ax, zorder=zorder, **kws_boxen)
3140
3308
  elif k == "pointplot":
3141
3309
  kws_point = kwargs.pop("kws_point", kwargs)
3142
- kws_point={k: v for k, v in kws_point.items() if not k.startswith("kws_")}
3143
- ax = sns.pointplot(data=data, x=x, y=y, ax=ax,zorder=zorder, **kws_point)
3310
+ kws_point = {k: v for k, v in kws_point.items() if not k.startswith("kws_")}
3311
+ ax = sns.pointplot(data=data, x=x, y=y, ax=ax, zorder=zorder, **kws_point)
3144
3312
  elif k == "barplot":
3145
3313
  kws_bar = kwargs.pop("kws_bar", kwargs)
3146
- kws_bar={k: v for k, v in kws_bar.items() if not k.startswith("kws_")}
3147
- ax = sns.barplot(data=data, x=x, y=y, ax=ax,zorder=zorder, **kws_bar)
3314
+ kws_bar = {k: v for k, v in kws_bar.items() if not k.startswith("kws_")}
3315
+ ax = sns.barplot(data=data, x=x, y=y, ax=ax, zorder=zorder, **kws_bar)
3148
3316
  elif k == "countplot":
3149
3317
  kws_count = kwargs.pop("kws_count", kwargs)
3150
- kws_count={k: v for k, v in kws_count.items() if not k.startswith("kws_")}
3151
- if not kws_count.get("hue",None):
3152
- kws_count.pop("palette",None)
3153
- ax = sns.countplot(data=data, x=x,y=y, ax=ax,zorder=zorder, **kws_count)
3318
+ kws_count = {k: v for k, v in kws_count.items() if not k.startswith("kws_")}
3319
+ if not kws_count.get("hue", None):
3320
+ kws_count.pop("palette", None)
3321
+ ax = sns.countplot(data=data, x=x, y=y, ax=ax, zorder=zorder, **kws_count)
3154
3322
  elif k == "regplot":
3155
3323
  kws_reg = kwargs.pop("kws_reg", kwargs)
3156
- kws_reg={k: v for k, v in kws_reg.items() if not k.startswith("kws_")}
3157
- ax = sns.regplot(data=data, x=x, y=y, ax=ax,zorder=zorder, **kws_reg)
3324
+ kws_reg = {k: v for k, v in kws_reg.items() if not k.startswith("kws_")}
3325
+ ax = sns.regplot(data=data, x=x, y=y, ax=ax, zorder=zorder, **kws_reg)
3158
3326
  elif k == "residplot":
3159
3327
  kws_resid = kwargs.pop("kws_resid", kwargs)
3160
- kws_resid={k: v for k, v in kws_resid.items() if not k.startswith("kws_")}
3161
- ax = sns.residplot(data=data, x=x, y=y, lowess=True,zorder=zorder, ax=ax, **kws_resid)
3328
+ kws_resid = {k: v for k, v in kws_resid.items() if not k.startswith("kws_")}
3329
+ ax = sns.residplot(
3330
+ data=data, x=x, y=y, lowess=True, zorder=zorder, ax=ax, **kws_resid
3331
+ )
3162
3332
  elif k == "lineplot":
3163
3333
  kws_line = kwargs.pop("kws_line", kwargs)
3164
- kws_line={k: v for k, v in kws_line.items() if not k.startswith("kws_")}
3165
- ax = sns.lineplot(ax=ax, data=data, x=x, y=y,zorder=zorder, **kws_line)
3334
+ kws_line = {k: v for k, v in kws_line.items() if not k.startswith("kws_")}
3335
+ ax = sns.lineplot(ax=ax, data=data, x=x, y=y, zorder=zorder, **kws_line)
3166
3336
 
3167
3337
  figsets(ax=ax, **kws_figsets)
3168
3338
  if kws_add_text:
@@ -3176,20 +3346,22 @@ def plotxy(
3176
3346
  return g, ax
3177
3347
  return ax
3178
3348
 
3349
+
3179
3350
  def norm_cmap(data, cmap="coolwarm", min_max=[0, 1]):
3180
3351
  norm_ = plt.Normalize(min_max[0], min_max[1])
3181
3352
  colormap = plt.get_cmap(cmap)
3182
3353
  return colormap(norm_(data))
3183
3354
 
3355
+
3184
3356
  def volcano(
3185
- data:pd.DataFrame,
3186
- x:str,
3187
- y:str,
3188
- gene_col:str=None,
3189
- top_genes=[5, 5], # [down-regulated, up-regulated]
3190
- thr_x=np.log2(1.5), # default: 0.585
3357
+ data: pd.DataFrame,
3358
+ x: str,
3359
+ y: str,
3360
+ gene_col: str = None,
3361
+ top_genes=[5, 5], # [down-regulated, up-regulated]
3362
+ thr_x=np.log2(1.5), # default: 0.585
3191
3363
  thr_y=-np.log10(0.05),
3192
- sort_xy="x", #'y', 'xy'
3364
+ sort_xy="x", #'y', 'xy'
3193
3365
  colors=("#00BFFF", "#9d9a9a", "#FF3030"),
3194
3366
  s=20,
3195
3367
  fill=True, # plot filled scatter
@@ -3201,11 +3373,10 @@ def volcano(
3201
3373
  ax=None,
3202
3374
  verbose=False,
3203
3375
  kws_text=dict(fontsize=10, color="k"),
3204
- kws_bbox=dict(facecolor='none',
3205
- alpha=0.5,
3206
- edgecolor='black',
3207
- boxstyle='round,pad=0.3'),# '{}' to hide
3208
- kws_arrow=dict(color="k", lw=0.5),# '{}' to hide
3376
+ kws_bbox=dict(
3377
+ facecolor="none", alpha=0.5, edgecolor="black", boxstyle="round,pad=0.3"
3378
+ ), # '{}' to hide
3379
+ kws_arrow=dict(color="k", lw=0.5), # '{}' to hide
3209
3380
  **kwargs,
3210
3381
  ):
3211
3382
  """
@@ -3275,39 +3446,35 @@ def volcano(
3275
3446
  kws_figsets = v_arg
3276
3447
  kwargs.pop(k_arg, None)
3277
3448
  break
3278
-
3279
- data=data.copy()
3449
+
3450
+ data = data.copy()
3280
3451
  # filter nan
3281
3452
  data = data.dropna(subset=[x, y]) # Drop rows with NaN in x or y
3282
- data.loc[:,"color"] = np.where(
3453
+ data.loc[:, "color"] = np.where(
3283
3454
  (data[x] > thr_x) & (data[y] > thr_y),
3284
3455
  colors[2],
3285
- np.where((data[x] < -thr_x) & (data[y] > thr_y),
3286
- colors[0],
3287
- colors[1]),
3456
+ np.where((data[x] < -thr_x) & (data[y] > thr_y), colors[0], colors[1]),
3288
3457
  )
3289
- top_genes=[top_genes, top_genes] if isinstance(top_genes,int) else top_genes
3290
-
3458
+ top_genes = [top_genes, top_genes] if isinstance(top_genes, int) else top_genes
3459
+
3291
3460
  # could custom how to select the top genes, x: x has priority
3292
- sort_by_x_y=[x,y] if sort_xy=="x" else [y,x]
3293
- ascending_up=[True, True] if sort_xy=="x" else [False, True]
3294
- ascending_down=[False, True] if sort_xy=="x" else [False, False]
3295
-
3296
- down_reg_genes = data[
3297
- (data["color"] == colors[0]) &
3298
- (data[x].abs() > thr_x) &
3299
- (data[y] > thr_y)
3300
- ].sort_values(by=sort_by_x_y, ascending=ascending_up).head(top_genes[0])
3301
- up_reg_genes = data[
3302
- (data["color"] == colors[2]) &
3303
- (data[x].abs() > thr_x) &
3304
- (data[y] > thr_y)
3305
- ].sort_values(by=sort_by_x_y, ascending=ascending_down).head(top_genes[1])
3461
+ sort_by_x_y = [x, y] if sort_xy == "x" else [y, x]
3462
+ ascending_up = [True, True] if sort_xy == "x" else [False, True]
3463
+ ascending_down = [False, True] if sort_xy == "x" else [False, False]
3464
+
3465
+ down_reg_genes = (
3466
+ data[(data["color"] == colors[0]) & (data[x].abs() > thr_x) & (data[y] > thr_y)]
3467
+ .sort_values(by=sort_by_x_y, ascending=ascending_up)
3468
+ .head(top_genes[0])
3469
+ )
3470
+ up_reg_genes = (
3471
+ data[(data["color"] == colors[2]) & (data[x].abs() > thr_x) & (data[y] > thr_y)]
3472
+ .sort_values(by=sort_by_x_y, ascending=ascending_down)
3473
+ .head(top_genes[1])
3474
+ )
3306
3475
  sele_gene = pd.concat([down_reg_genes, up_reg_genes])
3307
-
3308
- palette = {colors[0]: colors[0],
3309
- colors[1]: colors[1],
3310
- colors[2]: colors[2]}
3476
+
3477
+ palette = {colors[0]: colors[0], colors[1]: colors[1], colors[2]: colors[2]}
3311
3478
  # Plot setup
3312
3479
  if ax is None:
3313
3480
  ax = plt.gca()
@@ -3337,9 +3504,9 @@ def volcano(
3337
3504
  )
3338
3505
 
3339
3506
  # Add threshold lines for x and y axes
3340
- ax.axhline(y=thr_y, color="black", linestyle="--",lw=1)
3341
- ax.axvline(x=-thr_x, color="black", linestyle="--",lw=1)
3342
- ax.axvline(x=thr_x, color="black", linestyle="--",lw=1)
3507
+ ax.axhline(y=thr_y, color="black", linestyle="--", lw=1)
3508
+ ax.axvline(x=-thr_x, color="black", linestyle="--", lw=1)
3509
+ ax.axvline(x=thr_x, color="black", linestyle="--", lw=1)
3343
3510
 
3344
3511
  # Add gene labels for selected significant points
3345
3512
  if gene_col:
@@ -3349,19 +3516,28 @@ def volcano(
3349
3516
  textcolor = kws_text.pop("color", "k")
3350
3517
  fontsize = kws_text.pop("fontsize", 10)
3351
3518
  arrowstyles = [
3352
- "->","<-","<->","<|-","-|>","<|-|>",
3353
- "-","-[","-[",
3354
- "fancy","simple","wedge",
3519
+ "->",
3520
+ "<-",
3521
+ "<->",
3522
+ "<|-",
3523
+ "-|>",
3524
+ "<|-|>",
3525
+ "-",
3526
+ "-[",
3527
+ "-[",
3528
+ "fancy",
3529
+ "simple",
3530
+ "wedge",
3355
3531
  ]
3356
3532
  arrowstyle = kws_arrow.pop("style", "<|-")
3357
- arrowstyle = strcmp(arrowstyle, arrowstyles,scorer='strict')[0]
3358
- expand=kws_arrow.pop("expand",(1.05,1.1))
3533
+ arrowstyle = strcmp(arrowstyle, arrowstyles, scorer="strict")[0]
3534
+ expand = kws_arrow.pop("expand", (1.05, 1.1))
3359
3535
  arrowcolor = kws_arrow.pop("color", "0.4")
3360
3536
  arrowlinewidth = kws_arrow.pop("lw", 0.75)
3361
3537
  shrinkA = kws_arrow.pop("shrinkA", 0)
3362
3538
  shrinkB = kws_arrow.pop("shrinkB", 0)
3363
3539
  mutation_scale = kws_arrow.pop("head", 10)
3364
- arrow_fill=kws_arrow.pop("fill", False)
3540
+ arrow_fill = kws_arrow.pop("fill", False)
3365
3541
  for i in range(sele_gene.shape[0]):
3366
3542
  if isinstance(textcolor, list): # be consistant with dots's color
3367
3543
  textcolor = colors[0] if sele_gene[x].iloc[i] > 0 else colors[1]
@@ -3382,7 +3558,7 @@ def volcano(
3382
3558
  adjust_text(
3383
3559
  texts,
3384
3560
  expand=expand,
3385
- min_arrow_len=5,
3561
+ min_arrow_len=5,
3386
3562
  ax=ax,
3387
3563
  arrowprops=dict(
3388
3564
  arrowstyle=arrowstyle,
@@ -3393,7 +3569,7 @@ def volcano(
3393
3569
  shrinkB=shrinkB,
3394
3570
  mutation_scale=mutation_scale,
3395
3571
  **kws_arrow,
3396
- )
3572
+ ),
3397
3573
  )
3398
3574
 
3399
3575
  figsets(**kws_figsets)
@@ -3499,6 +3675,7 @@ def desaturate_color(color, saturation_factor=0.5):
3499
3675
  """Reduce the saturation of a color by a given factor (between 0 and 1)."""
3500
3676
  import matplotlib.colors as mcolors
3501
3677
  import colorsys
3678
+
3502
3679
  # Convert the color to RGB
3503
3680
  rgb = mcolors.to_rgb(color)
3504
3681
  # Convert RGB to HLS (Hue, Lightness, Saturation)
@@ -3508,8 +3685,19 @@ def desaturate_color(color, saturation_factor=0.5):
3508
3685
  # Convert back to RGB
3509
3686
  return colorsys.hls_to_rgb(h, l, s)
3510
3687
 
3511
- def textsets(text, fontname='Arial', fontsize=11, fontweight="normal", fontstyle="normal",
3512
- fontcolor='k', backgroundcolor=None, shadow=False, ha="center", va="center"):
3688
+
3689
+ def textsets(
3690
+ text,
3691
+ fontname="Arial",
3692
+ fontsize=11,
3693
+ fontweight="normal",
3694
+ fontstyle="normal",
3695
+ fontcolor="k",
3696
+ backgroundcolor=None,
3697
+ shadow=False,
3698
+ ha="center",
3699
+ va="center",
3700
+ ):
3513
3701
  if text: # Ensure text exists
3514
3702
  if fontname:
3515
3703
  text.set_fontname(plt_font(fontname))
@@ -3522,26 +3710,28 @@ def textsets(text, fontname='Arial', fontsize=11, fontweight="normal", fontstyle
3522
3710
  if fontcolor:
3523
3711
  text.set_color(fontcolor)
3524
3712
  if backgroundcolor:
3525
- text.set_backgroundcolor(backgroundcolor)
3713
+ text.set_backgroundcolor(backgroundcolor)
3526
3714
  text.set_horizontalalignment(ha)
3527
- text.set_verticalalignment(va)
3715
+ text.set_verticalalignment(va)
3528
3716
  if shadow:
3529
- text.set_path_effects([
3530
- matplotlib.patheffects.withStroke(linewidth=3, foreground="gray")
3531
- ])
3717
+ text.set_path_effects(
3718
+ [matplotlib.patheffects.withStroke(linewidth=3, foreground="gray")]
3719
+ )
3720
+
3721
+
3532
3722
  def venn(
3533
- lists:list,
3534
- labels:list=None,
3723
+ lists: list,
3724
+ labels: list = None,
3535
3725
  ax=None,
3536
3726
  colors=None,
3537
3727
  edgecolor=None,
3538
3728
  alpha=0.5,
3539
- saturation=.75,
3540
- linewidth=0, # default no edge
3729
+ saturation=0.75,
3730
+ linewidth=0, # default no edge
3541
3731
  linestyle="-",
3542
- fontname='Arial',
3732
+ fontname="Arial",
3543
3733
  fontsize=10,
3544
- fontcolor='k',
3734
+ fontcolor="k",
3545
3735
  fontweight="normal",
3546
3736
  fontstyle="normal",
3547
3737
  ha="center",
@@ -3550,18 +3740,18 @@ def venn(
3550
3740
  subset_fontsize=10,
3551
3741
  subset_fontweight="normal",
3552
3742
  subset_fontstyle="normal",
3553
- subset_fontcolor='k',
3743
+ subset_fontcolor="k",
3554
3744
  backgroundcolor=None,
3555
3745
  custom_texts=None,
3556
- show_percentages=True, # display percentage
3746
+ show_percentages=True, # display percentage
3557
3747
  fmt="{:.1%}",
3558
3748
  ellipse_shape=False, # 椭圆形
3559
- ellipse_scale=[1.5, 1], #not perfect, 椭圆形的形状
3560
- **kwargs
3749
+ ellipse_scale=[1.5, 1], # not perfect, 椭圆形的形状
3750
+ **kwargs,
3561
3751
  ):
3562
3752
  """
3563
3753
  Advanced Venn diagram plotting function with extensive customization options.
3564
- Usage:
3754
+ Usage:
3565
3755
  # Define the two sets
3566
3756
  set1 = [1, 2, 3, 4, 5]
3567
3757
  set2 = [4, 5, 6, 7, 8]
@@ -3612,56 +3802,70 @@ def venn(
3612
3802
  """
3613
3803
  if ax is None:
3614
3804
  ax = plt.gca()
3615
- lists=[set(flatten(i, verbose=False)) for i in lists]
3805
+ lists = [set(flatten(i, verbose=False)) for i in lists]
3616
3806
  # Function to apply text styles to labels
3617
3807
  if colors is None:
3618
- colors=["r","b"] if len(lists)==2 else ["r","g","b"]
3808
+ colors = ["r", "b"] if len(lists) == 2 else ["r", "g", "b"]
3619
3809
  if labels is None:
3620
- labels=["set1","set2"] if len(lists)==2 else ["set1","set2","set3"]
3810
+ labels = ["set1", "set2"] if len(lists) == 2 else ["set1", "set2", "set3"]
3621
3811
  if edgecolor is None:
3622
- edgecolor=colors
3812
+ edgecolor = colors
3623
3813
  colors = [desaturate_color(color, saturation) for color in colors]
3624
3814
  # Check colors and auto-calculate overlaps
3625
3815
  if len(lists) == 2:
3816
+
3626
3817
  def get_count_and_percentage(set_count, subset_count):
3627
3818
  percent = subset_count / set_count if set_count > 0 else 0
3628
- return f"{subset_count}\n({fmt.format(percent)})" if show_percentages else f"{subset_count}"
3819
+ return (
3820
+ f"{subset_count}\n({fmt.format(percent)})"
3821
+ if show_percentages
3822
+ else f"{subset_count}"
3823
+ )
3629
3824
 
3630
3825
  from matplotlib_venn import venn2, venn2_circles
3826
+
3631
3827
  # Auto-calculate overlap color for 2-set Venn diagram
3632
3828
  overlap_color = get_color_overlap(colors[0], colors[1]) if colors else None
3633
-
3829
+
3634
3830
  # Draw the venn diagram
3635
3831
  v = venn2(subsets=lists, set_labels=labels, ax=ax, **kwargs)
3636
3832
  venn_circles = venn2_circles(subsets=lists, ax=ax)
3637
- set1,set2=lists[0],lists[1]
3638
- v.get_patch_by_id('10').set_color(colors[0])
3639
- v.get_patch_by_id('01').set_color(colors[1])
3640
- v.get_patch_by_id('11').set_color(get_color_overlap(colors[0], colors[1]) if colors else None)
3833
+ set1, set2 = lists[0], lists[1]
3834
+ v.get_patch_by_id("10").set_color(colors[0])
3835
+ v.get_patch_by_id("01").set_color(colors[1])
3836
+ v.get_patch_by_id("11").set_color(
3837
+ get_color_overlap(colors[0], colors[1]) if colors else None
3838
+ )
3641
3839
  # v.get_label_by_id('10').set_text(len(set1 - set2))
3642
3840
  # v.get_label_by_id('01').set_text(len(set2 - set1))
3643
3841
  # v.get_label_by_id('11').set_text(len(set1 & set2))
3644
- v.get_label_by_id('10').set_text(get_count_and_percentage(len(set1), len(set1 - set2)))
3645
- v.get_label_by_id('01').set_text(get_count_and_percentage(len(set2), len(set2 - set1)))
3646
- v.get_label_by_id('11').set_text(get_count_and_percentage(len(set1 | set2), len(set1 & set2)))
3647
-
3842
+ v.get_label_by_id("10").set_text(
3843
+ get_count_and_percentage(len(set1), len(set1 - set2))
3844
+ )
3845
+ v.get_label_by_id("01").set_text(
3846
+ get_count_and_percentage(len(set2), len(set2 - set1))
3847
+ )
3848
+ v.get_label_by_id("11").set_text(
3849
+ get_count_and_percentage(len(set1 | set2), len(set1 & set2))
3850
+ )
3648
3851
 
3649
- if not isinstance(linewidth,list):
3650
- linewidth=[linewidth]
3651
- if isinstance(linestyle,str):
3652
- linestyle=[linestyle]
3852
+ if not isinstance(linewidth, list):
3853
+ linewidth = [linewidth]
3854
+ if isinstance(linestyle, str):
3855
+ linestyle = [linestyle]
3653
3856
  if not isinstance(edgecolor, list):
3654
- edgecolor=[edgecolor]
3655
- linewidth=linewidth*2 if len(linewidth)==1 else linewidth
3656
- linestyle=linestyle*2 if len(linestyle)==1 else linestyle
3657
- edgecolor=edgecolor*2 if len(edgecolor)==1 else edgecolor
3857
+ edgecolor = [edgecolor]
3858
+ linewidth = linewidth * 2 if len(linewidth) == 1 else linewidth
3859
+ linestyle = linestyle * 2 if len(linestyle) == 1 else linestyle
3860
+ edgecolor = edgecolor * 2 if len(edgecolor) == 1 else edgecolor
3658
3861
  for i in range(2):
3659
3862
  venn_circles[i].set_lw(linewidth[i])
3660
3863
  venn_circles[i].set_ls(linestyle[i])
3661
3864
  venn_circles[i].set_edgecolor(edgecolor[i])
3662
3865
  # 椭圆
3663
3866
  if ellipse_shape:
3664
- import matplotlib.patches as patches
3867
+ import matplotlib.patches as patches
3868
+
3665
3869
  for patch in v.patches:
3666
3870
  patch.set_visible(False) # Hide original patches if using ellipses
3667
3871
  center1 = v.get_circle_center(0)
@@ -3672,9 +3876,13 @@ def venn(
3672
3876
  height=ellipse_scale[1],
3673
3877
  edgecolor=edgecolor[0] if edgecolor else colors[0],
3674
3878
  facecolor=colors[0],
3675
- lw=linewidth if isinstance(linewidth, (int, float)) else 1.0, # Ensure lw is a number
3879
+ lw=(
3880
+ linewidth if isinstance(linewidth, (int, float)) else 1.0
3881
+ ), # Ensure lw is a number
3676
3882
  ls=linestyle[0],
3677
- alpha=alpha if isinstance(alpha, (int, float)) else 0.5 # Ensure alpha is a number
3883
+ alpha=(
3884
+ alpha if isinstance(alpha, (int, float)) else 0.5
3885
+ ), # Ensure alpha is a number
3678
3886
  )
3679
3887
  ellipse2 = patches.Ellipse(
3680
3888
  (center2.x, center2.y),
@@ -3682,48 +3890,78 @@ def venn(
3682
3890
  height=ellipse_scale[1],
3683
3891
  edgecolor=edgecolor[1] if edgecolor else colors[1],
3684
3892
  facecolor=colors[1],
3685
- lw=linewidth if isinstance(linewidth, (int, float)) else 1.0, # Ensure lw is a number
3893
+ lw=(
3894
+ linewidth if isinstance(linewidth, (int, float)) else 1.0
3895
+ ), # Ensure lw is a number
3686
3896
  ls=linestyle[0],
3687
- alpha=alpha if isinstance(alpha, (int, float)) else 0.5 # Ensure alpha is a number
3897
+ alpha=(
3898
+ alpha if isinstance(alpha, (int, float)) else 0.5
3899
+ ), # Ensure alpha is a number
3688
3900
  )
3689
3901
  ax.add_patch(ellipse1)
3690
3902
  ax.add_patch(ellipse2)
3691
3903
  # Apply styles to set labels
3692
3904
  for i, text in enumerate(v.set_labels):
3693
- textsets(text, fontname=fontname, fontsize=fontsize, fontweight=fontweight, fontstyle=fontstyle,
3694
- fontcolor=fontcolor,ha=ha,va=va,shadow=shadow)
3905
+ textsets(
3906
+ text,
3907
+ fontname=fontname,
3908
+ fontsize=fontsize,
3909
+ fontweight=fontweight,
3910
+ fontstyle=fontstyle,
3911
+ fontcolor=fontcolor,
3912
+ ha=ha,
3913
+ va=va,
3914
+ shadow=shadow,
3915
+ )
3695
3916
 
3696
3917
  # Apply styles to subset labels
3697
3918
  for i, text in enumerate(v.subset_labels):
3698
3919
  if text: # Ensure text exists
3699
3920
  if custom_texts: # Custom text handling
3700
- text.set_text(custom_texts[i])
3701
- textsets(text, fontname=fontname, fontsize=subset_fontsize, fontweight=subset_fontweight, fontstyle=subset_fontstyle,
3702
- fontcolor=subset_fontcolor,ha=ha,va=va,shadow=shadow)
3921
+ text.set_text(custom_texts[i])
3922
+ textsets(
3923
+ text,
3924
+ fontname=fontname,
3925
+ fontsize=subset_fontsize,
3926
+ fontweight=subset_fontweight,
3927
+ fontstyle=subset_fontstyle,
3928
+ fontcolor=subset_fontcolor,
3929
+ ha=ha,
3930
+ va=va,
3931
+ shadow=shadow,
3932
+ )
3703
3933
 
3704
3934
  elif len(lists) == 3:
3935
+
3705
3936
  def get_label(set_count, subset_count):
3706
3937
  percent = subset_count / set_count if set_count > 0 else 0
3707
- return f"{subset_count}\n({fmt.format(percent)})" if show_percentages else f"{subset_count}"
3708
-
3938
+ return (
3939
+ f"{subset_count}\n({fmt.format(percent)})"
3940
+ if show_percentages
3941
+ else f"{subset_count}"
3942
+ )
3943
+
3709
3944
  from matplotlib_venn import venn3, venn3_circles
3945
+
3710
3946
  # Auto-calculate overlap colors for 3-set Venn diagram
3711
3947
  colorAB = get_color_overlap(colors[0], colors[1]) if colors else None
3712
3948
  colorAC = get_color_overlap(colors[0], colors[2]) if colors else None
3713
3949
  colorBC = get_color_overlap(colors[1], colors[2]) if colors else None
3714
- colorABC = get_color_overlap(colors[0], colors[1], colors[2]) if colors else None
3715
- set1,set2,set3=lists[0],lists[1],lists[2]
3950
+ colorABC = (
3951
+ get_color_overlap(colors[0], colors[1], colors[2]) if colors else None
3952
+ )
3953
+ set1, set2, set3 = lists[0], lists[1], lists[2]
3716
3954
 
3717
3955
  # Draw the venn diagram
3718
- v = venn3(subsets=lists, set_labels=labels, ax=ax,**kwargs)
3719
- v.get_patch_by_id('100').set_color(colors[0])
3720
- v.get_patch_by_id('010').set_color(colors[1])
3721
- v.get_patch_by_id('001').set_color(colors[2])
3722
- v.get_patch_by_id('110').set_color(colorAB)
3723
- v.get_patch_by_id('101').set_color(colorAC)
3724
- v.get_patch_by_id('011').set_color(colorBC)
3725
- v.get_patch_by_id('111').set_color(colorABC)
3726
-
3956
+ v = venn3(subsets=lists, set_labels=labels, ax=ax, **kwargs)
3957
+ v.get_patch_by_id("100").set_color(colors[0])
3958
+ v.get_patch_by_id("010").set_color(colors[1])
3959
+ v.get_patch_by_id("001").set_color(colors[2])
3960
+ v.get_patch_by_id("110").set_color(colorAB)
3961
+ v.get_patch_by_id("101").set_color(colorAC)
3962
+ v.get_patch_by_id("011").set_color(colorBC)
3963
+ v.get_patch_by_id("111").set_color(colorABC)
3964
+
3727
3965
  # Correctly labeling subset sizes
3728
3966
  # v.get_label_by_id('100').set_text(len(set1 - set2 - set3))
3729
3967
  # v.get_label_by_id('010').set_text(len(set2 - set1 - set3))
@@ -3732,63 +3970,93 @@ def venn(
3732
3970
  # v.get_label_by_id('101').set_text(len(set1 & set3 - set2))
3733
3971
  # v.get_label_by_id('011').set_text(len(set2 & set3 - set1))
3734
3972
  # v.get_label_by_id('111').set_text(len(set1 & set2 & set3))
3735
- v.get_label_by_id('100').set_text(get_label(len(set1), len(set1 - set2 - set3)))
3736
- v.get_label_by_id('010').set_text(get_label(len(set2), len(set2 - set1 - set3)))
3737
- v.get_label_by_id('001').set_text(get_label(len(set3), len(set3 - set1 - set2)))
3738
- v.get_label_by_id('110').set_text(get_label(len(set1 | set2), len(set1 & set2 - set3)))
3739
- v.get_label_by_id('101').set_text(get_label(len(set1 | set3), len(set1 & set3 - set2)))
3740
- v.get_label_by_id('011').set_text(get_label(len(set2 | set3), len(set2 & set3 - set1)))
3741
- v.get_label_by_id('111').set_text(get_label(len(set1 | set2 | set3), len(set1 & set2 & set3)))
3742
-
3973
+ v.get_label_by_id("100").set_text(get_label(len(set1), len(set1 - set2 - set3)))
3974
+ v.get_label_by_id("010").set_text(get_label(len(set2), len(set2 - set1 - set3)))
3975
+ v.get_label_by_id("001").set_text(get_label(len(set3), len(set3 - set1 - set2)))
3976
+ v.get_label_by_id("110").set_text(
3977
+ get_label(len(set1 | set2), len(set1 & set2 - set3))
3978
+ )
3979
+ v.get_label_by_id("101").set_text(
3980
+ get_label(len(set1 | set3), len(set1 & set3 - set2))
3981
+ )
3982
+ v.get_label_by_id("011").set_text(
3983
+ get_label(len(set2 | set3), len(set2 & set3 - set1))
3984
+ )
3985
+ v.get_label_by_id("111").set_text(
3986
+ get_label(len(set1 | set2 | set3), len(set1 & set2 & set3))
3987
+ )
3743
3988
 
3744
3989
  # Apply styles to set labels
3745
3990
  for i, text in enumerate(v.set_labels):
3746
- textsets(text, fontname=fontname, fontsize=fontsize, fontweight=fontweight, fontstyle=fontstyle,
3747
- fontcolor=fontcolor,ha=ha,va=va,shadow=shadow)
3991
+ textsets(
3992
+ text,
3993
+ fontname=fontname,
3994
+ fontsize=fontsize,
3995
+ fontweight=fontweight,
3996
+ fontstyle=fontstyle,
3997
+ fontcolor=fontcolor,
3998
+ ha=ha,
3999
+ va=va,
4000
+ shadow=shadow,
4001
+ )
3748
4002
 
3749
4003
  # Apply styles to subset labels
3750
4004
  for i, text in enumerate(v.subset_labels):
3751
4005
  if text: # Ensure text exists
3752
4006
  if custom_texts: # Custom text handling
3753
4007
  text.set_text(custom_texts[i])
3754
- textsets(text, fontname=fontname, fontsize=subset_fontsize, fontweight=subset_fontweight, fontstyle=subset_fontstyle,
3755
- fontcolor=subset_fontcolor,ha=ha,va=va,shadow=shadow)
4008
+ textsets(
4009
+ text,
4010
+ fontname=fontname,
4011
+ fontsize=subset_fontsize,
4012
+ fontweight=subset_fontweight,
4013
+ fontstyle=subset_fontstyle,
4014
+ fontcolor=subset_fontcolor,
4015
+ ha=ha,
4016
+ va=va,
4017
+ shadow=shadow,
4018
+ )
3756
4019
 
3757
4020
  venn_circles = venn3_circles(subsets=lists, ax=ax)
3758
- if not isinstance(linewidth,list):
3759
- linewidth=[linewidth]
3760
- if isinstance(linestyle,str):
3761
- linestyle=[linestyle]
4021
+ if not isinstance(linewidth, list):
4022
+ linewidth = [linewidth]
4023
+ if isinstance(linestyle, str):
4024
+ linestyle = [linestyle]
3762
4025
  if not isinstance(edgecolor, list):
3763
- edgecolor=[edgecolor]
3764
- linewidth=linewidth*3 if len(linewidth)==1 else linewidth
3765
- linestyle=linestyle*3 if len(linestyle)==1 else linestyle
3766
- edgecolor=edgecolor*3 if len(edgecolor)==1 else edgecolor
4026
+ edgecolor = [edgecolor]
4027
+ linewidth = linewidth * 3 if len(linewidth) == 1 else linewidth
4028
+ linestyle = linestyle * 3 if len(linestyle) == 1 else linestyle
4029
+ edgecolor = edgecolor * 3 if len(edgecolor) == 1 else edgecolor
3767
4030
 
3768
4031
  # edgecolor=[to_rgba(i) for i in edgecolor]
3769
4032
 
3770
- for i in range(3):
4033
+ for i in range(3):
3771
4034
  venn_circles[i].set_lw(linewidth[i])
3772
- venn_circles[i].set_ls(linestyle[i])
3773
- venn_circles[i].set_edgecolor(edgecolor[i])
4035
+ venn_circles[i].set_ls(linestyle[i])
4036
+ venn_circles[i].set_edgecolor(edgecolor[i])
3774
4037
 
3775
- #椭圆形
4038
+ # 椭圆形
3776
4039
  if ellipse_shape:
3777
- import matplotlib.patches as patches
4040
+ import matplotlib.patches as patches
4041
+
3778
4042
  for patch in v.patches:
3779
4043
  patch.set_visible(False) # Hide original patches if using ellipses
3780
4044
  center1 = v.get_circle_center(0)
3781
4045
  center2 = v.get_circle_center(1)
3782
- center3 = v.get_circle_center(2)
4046
+ center3 = v.get_circle_center(2)
3783
4047
  ellipse1 = patches.Ellipse(
3784
4048
  (center1.x, center1.y),
3785
4049
  width=ellipse_scale[0],
3786
4050
  height=ellipse_scale[1],
3787
4051
  edgecolor=edgecolor[0] if edgecolor else colors[0],
3788
4052
  facecolor=colors[0],
3789
- lw=linewidth if isinstance(linewidth, (int, float)) else 1.0, # Ensure lw is a number
4053
+ lw=(
4054
+ linewidth if isinstance(linewidth, (int, float)) else 1.0
4055
+ ), # Ensure lw is a number
3790
4056
  ls=linestyle[0],
3791
- alpha=alpha if isinstance(alpha, (int, float)) else 0.5 # Ensure alpha is a number
4057
+ alpha=(
4058
+ alpha if isinstance(alpha, (int, float)) else 0.5
4059
+ ), # Ensure alpha is a number
3792
4060
  )
3793
4061
  ellipse2 = patches.Ellipse(
3794
4062
  (center2.x, center2.y),
@@ -3796,9 +4064,13 @@ def venn(
3796
4064
  height=ellipse_scale[1],
3797
4065
  edgecolor=edgecolor[1] if edgecolor else colors[1],
3798
4066
  facecolor=colors[1],
3799
- lw=linewidth if isinstance(linewidth, (int, float)) else 1.0, # Ensure lw is a number
4067
+ lw=(
4068
+ linewidth if isinstance(linewidth, (int, float)) else 1.0
4069
+ ), # Ensure lw is a number
3800
4070
  ls=linestyle[0],
3801
- alpha=alpha if isinstance(alpha, (int, float)) else 0.5 # Ensure alpha is a number
4071
+ alpha=(
4072
+ alpha if isinstance(alpha, (int, float)) else 0.5
4073
+ ), # Ensure alpha is a number
3802
4074
  )
3803
4075
  ellipse3 = patches.Ellipse(
3804
4076
  (center3.x, center3.y),
@@ -3806,9 +4078,13 @@ def venn(
3806
4078
  height=ellipse_scale[1],
3807
4079
  edgecolor=edgecolor[1] if edgecolor else colors[1],
3808
4080
  facecolor=colors[1],
3809
- lw=linewidth if isinstance(linewidth, (int, float)) else 1.0, # Ensure lw is a number
4081
+ lw=(
4082
+ linewidth if isinstance(linewidth, (int, float)) else 1.0
4083
+ ), # Ensure lw is a number
3810
4084
  ls=linestyle[0],
3811
- alpha=alpha if isinstance(alpha, (int, float)) else 0.5 # Ensure alpha is a number
4085
+ alpha=(
4086
+ alpha if isinstance(alpha, (int, float)) else 0.5
4087
+ ), # Ensure alpha is a number
3812
4088
  )
3813
4089
  ax.add_patch(ellipse1)
3814
4090
  ax.add_patch(ellipse2)
@@ -3820,17 +4096,20 @@ def venn(
3820
4096
  for patch in v.patches:
3821
4097
  if patch:
3822
4098
  patch.set_alpha(alpha)
3823
- if 'none' in edgecolor or 0 in linewidth:
4099
+ if "none" in edgecolor or 0 in linewidth:
3824
4100
  patch.set_edgecolor("none")
3825
4101
  return ax
3826
4102
 
4103
+
3827
4104
  #! subplots, support automatic extend new axis
3828
- def subplot(rows:int=2,
3829
- cols:int=2,
3830
- figsize:Union[tuple,list]=[8, 8],
3831
- sharex=False,
3832
- sharey=False,
3833
- **kwargs):
4105
+ def subplot(
4106
+ rows: int = 2,
4107
+ cols: int = 2,
4108
+ figsize: Union[tuple, list] = [8, 8],
4109
+ sharex=False,
4110
+ sharey=False,
4111
+ **kwargs,
4112
+ ):
3834
4113
  """
3835
4114
  nexttile = subplot(
3836
4115
  8,
@@ -3850,11 +4129,14 @@ def subplot(rows:int=2,
3850
4129
  ax.set_xlabel(f"Tile {i + 1}")
3851
4130
  """
3852
4131
  from matplotlib.gridspec import GridSpec
4132
+
3853
4133
  if run_once_within():
3854
- print(f"usage:\n\tnexttile = subplot(2, 2, figsize=(5, 5), sharex=True, sharey=True)\n\tax = nexttile()")
4134
+ print(
4135
+ f"usage:\n\tnexttile = subplot(2, 2, figsize=(5, 5), sharex=True, sharey=True)\n\tax = nexttile()"
4136
+ )
3855
4137
  fig = plt.figure(figsize=figsize)
3856
4138
  grid_spec = GridSpec(rows, cols, figure=fig)
3857
- occupied = set()
4139
+ occupied = set()
3858
4140
  row_first_axes = [None] * rows # Track the first axis in each row (for sharey)
3859
4141
  col_first_axes = [None] * cols # Track the first axis in each column (for sharex)
3860
4142
 
@@ -3862,8 +4144,9 @@ def subplot(rows:int=2,
3862
4144
  nonlocal rows, grid_spec
3863
4145
  rows += 1 # Expands by adding a row
3864
4146
  grid_spec = GridSpec(rows, cols, figure=fig)
4147
+
3865
4148
  def nexttile(rowspan=1, colspan=1, **kwargs):
3866
- nonlocal rows, cols, occupied, grid_spec
4149
+ nonlocal rows, cols, occupied, grid_spec
3867
4150
  for row in range(rows):
3868
4151
  for col in range(cols):
3869
4152
  if all(
@@ -3873,29 +4156,29 @@ def subplot(rows:int=2,
3873
4156
  ):
3874
4157
  break
3875
4158
  else:
3876
- continue
4159
+ continue
3877
4160
  break
3878
- else:
4161
+ else:
3879
4162
  expand_ax()
3880
4163
  return nexttile(rowspan=rowspan, colspan=colspan, **kwargs)
3881
4164
 
3882
- sharex_ax,sharey_ax = None, None
4165
+ sharex_ax, sharey_ax = None, None
3883
4166
 
3884
- if sharex:
4167
+ if sharex:
3885
4168
  sharex_ax = col_first_axes[col]
3886
4169
 
3887
- if sharey:
3888
- sharey_ax = row_first_axes[row]
4170
+ if sharey:
4171
+ sharey_ax = row_first_axes[row]
3889
4172
  ax = fig.add_subplot(
3890
4173
  grid_spec[row : row + rowspan, col : col + colspan],
3891
4174
  sharex=sharex_ax,
3892
4175
  sharey=sharey_ax,
3893
- **kwargs
3894
- )
4176
+ **kwargs,
4177
+ )
3895
4178
  if row_first_axes[row] is None:
3896
4179
  row_first_axes[row] = ax
3897
4180
  if col_first_axes[col] is None:
3898
- col_first_axes[col] = ax
4181
+ col_first_axes[col] = ax
3899
4182
  for r in range(row, row + rowspan):
3900
4183
  for c in range(col, col + colspan):
3901
4184
  occupied.add((r, c))
@@ -3907,17 +4190,17 @@ def subplot(rows:int=2,
3907
4190
 
3908
4191
  #! radar chart
3909
4192
  def radar(
3910
- data: pd.DataFrame,
3911
- ylim=(0,100),
4193
+ data: pd.DataFrame,
4194
+ ylim=(0, 100),
3912
4195
  color=get_color(5),
3913
4196
  fontsize=10,
3914
- fontcolor='k',
4197
+ fontcolor="k",
3915
4198
  size=6,
3916
4199
  linewidth=1,
3917
4200
  linestyle="-",
3918
4201
  alpha=0.5,
3919
4202
  marker="o",
3920
- edgecolor='none',
4203
+ edgecolor="none",
3921
4204
  edge_linewidth=0,
3922
4205
  bg_color="0.8",
3923
4206
  bg_alpha=None,
@@ -3928,18 +4211,19 @@ def radar(
3928
4211
  legend_fontsize=10,
3929
4212
  grid_color="gray",
3930
4213
  grid_alpha=0.5,
3931
- grid_linestyle="--",grid_linewidth=0.5,
4214
+ grid_linestyle="--",
4215
+ grid_linewidth=0.5,
3932
4216
  circular: bool = False,
3933
4217
  tick_fontsize=None,
3934
4218
  tick_fontcolor="0.65",
3935
- tick_loc = None,# label position
3936
- turning = None,
4219
+ tick_loc=None, # label position
4220
+ turning=None,
3937
4221
  ax=None,
3938
4222
  sp=2,
3939
- **kwargs
4223
+ **kwargs,
3940
4224
  ):
3941
4225
  """
3942
- Example DATA:
4226
+ Example DATA:
3943
4227
  df = pd.DataFrame(
3944
4228
  data=[
3945
4229
  [80, 80, 80, 80, 80, 80, 80],
@@ -3948,7 +4232,7 @@ def radar(
3948
4232
  ],
3949
4233
  index=["Hero", "Warrior", "Wizard"],
3950
4234
  columns=["HP", "MP", "ATK", "DEF", "SP.ATK", "SP.DEF", "SPD"])
3951
-
4235
+
3952
4236
  Parameters:
3953
4237
  - data (pd.DataFrame): The data to plot. Each column corresponds to a variable, and each row represents a data point.
3954
4238
  - ylim (tuple): The limits of the radial axis (y-axis). Default is (0, 100).
@@ -4002,8 +4286,12 @@ def radar(
4002
4286
 
4003
4287
  # bg_color
4004
4288
  if bg_alpha is None:
4005
- bg_alpha=alpha
4006
- ax.set_facecolor(to_rgba(bg_color,alpha=bg_alpha)) if circular else ax.set_facecolor('none')
4289
+ bg_alpha = alpha
4290
+ (
4291
+ ax.set_facecolor(to_rgba(bg_color, alpha=bg_alpha))
4292
+ if circular
4293
+ else ax.set_facecolor("none")
4294
+ )
4007
4295
  # Set up the radar chart with straight-line connections
4008
4296
  ax.set_theta_offset(np.pi / 2)
4009
4297
  ax.set_theta_direction(-1)
@@ -4012,53 +4300,68 @@ def radar(
4012
4300
  ax.set_xticks(angles[:-1])
4013
4301
  ax.set_xticklabels(categories)
4014
4302
 
4015
- # Set y-axis limits and grid intervals
4303
+ # Set y-axis limits and grid intervals
4016
4304
  vmin, vmax = ylim
4017
- if circular:
4018
- #* cicular style
4019
- ax.yaxis.set_ticks(np.arange(vmin, vmax+1, vmax * grid_interval_ratio))
4020
- ax.grid(axis='both',
4021
- color=grid_color,
4022
- linestyle=grid_linestyle,
4023
- alpha=grid_alpha,
4024
- linewidth=grid_linewidth,
4025
- dash_capstyle='round',
4026
- dash_joinstyle='round',
4027
- )
4028
- ax.spines["polar"].set_color(grid_color)
4305
+ if circular:
4306
+ # * cicular style
4307
+ ax.yaxis.set_ticks(np.arange(vmin, vmax + 1, vmax * grid_interval_ratio))
4308
+ ax.grid(
4309
+ axis="both",
4310
+ color=grid_color,
4311
+ linestyle=grid_linestyle,
4312
+ alpha=grid_alpha,
4313
+ linewidth=grid_linewidth,
4314
+ dash_capstyle="round",
4315
+ dash_joinstyle="round",
4316
+ )
4317
+ ax.spines["polar"].set_color(grid_color)
4029
4318
  ax.spines["polar"].set_linewidth(grid_linewidth)
4030
- ax.spines["polar"].set_linestyle('-')
4319
+ ax.spines["polar"].set_linestyle("-")
4031
4320
  ax.spines["polar"].set_alpha(grid_alpha)
4032
- ax.spines["polar"].set_capstyle('round')
4033
- ax.spines["polar"].set_joinstyle('round')
4321
+ ax.spines["polar"].set_capstyle("round")
4322
+ ax.spines["polar"].set_joinstyle("round")
4034
4323
 
4035
4324
  else:
4036
- #* spider style: spider-style grid (straight lines, not circles)
4325
+ # * spider style: spider-style grid (straight lines, not circles)
4037
4326
  # Create the spider-style grid (straight lines, not circles)
4038
- for i in range(1, int(vmax * grid_interval_ratio) + 1):
4327
+ for i in range(1, int(vmax * grid_interval_ratio) + 1):
4039
4328
  ax.plot(
4040
4329
  angles + [angles[0]], # Closing the loop
4041
- [i * vmax * grid_interval_ratio] * (num_vars+1) + [i * vmax * grid_interval_ratio],
4042
- color=grid_color, linestyle=grid_linestyle, alpha=grid_alpha,linewidth=grid_linewidth
4330
+ [i * vmax * grid_interval_ratio] * (num_vars + 1)
4331
+ + [i * vmax * grid_interval_ratio],
4332
+ color=grid_color,
4333
+ linestyle=grid_linestyle,
4334
+ alpha=grid_alpha,
4335
+ linewidth=grid_linewidth,
4043
4336
  )
4044
- # set bg_color
4045
- ax.fill(angles, [vmax]*(data.shape[1]+1), color=bg_color, alpha=bg_alpha)
4337
+ # set bg_color
4338
+ ax.fill(angles, [vmax] * (data.shape[1] + 1), color=bg_color, alpha=bg_alpha)
4046
4339
  ax.yaxis.grid(False)
4047
4340
  # Move radial labels away from plotted line
4048
4341
  if tick_loc is None:
4049
- tick_loc = np.mean([angles[0],angles[1]])/(2*np.pi)*360 if circular else 0
4342
+ tick_loc = (
4343
+ np.mean([angles[0], angles[1]]) / (2 * np.pi) * 360 if circular else 0
4344
+ )
4050
4345
 
4051
4346
  ax.set_rlabel_position(tick_loc)
4052
4347
  ax.set_theta_offset(turning) if turning is not None else None
4053
- ax.tick_params(axis='x', labelsize=fontsize, colors=fontcolor) # Optional: for angular labels
4054
- tick_fontsize = fontsize-2 if fontsize is None else tick_fontsize
4055
- ax.tick_params(axis='y', labelsize=tick_fontsize, colors=tick_fontcolor) # For radial labels
4348
+ ax.tick_params(
4349
+ axis="x", labelsize=fontsize, colors=fontcolor
4350
+ ) # Optional: for angular labels
4351
+ tick_fontsize = fontsize - 2 if fontsize is None else tick_fontsize
4352
+ ax.tick_params(
4353
+ axis="y", labelsize=tick_fontsize, colors=tick_fontcolor
4354
+ ) # For radial labels
4056
4355
  if not circular:
4057
- ax.spines['polar'].set_visible(False)
4058
- ax.tick_params(axis='x', pad=sp) # move spines outward
4059
- ax.tick_params(axis='y', pad=sp) # move spines outward
4356
+ ax.spines["polar"].set_visible(False)
4357
+ ax.tick_params(axis="x", pad=sp) # move spines outward
4358
+ ax.tick_params(axis="y", pad=sp) # move spines outward
4060
4359
  # colors
4061
- colors = get_color(data.shape[0]) if cmap is None else plt.get_cmap(cmap)(np.linspace(0, 1, data.shape[0]))
4360
+ colors = (
4361
+ get_color(data.shape[0])
4362
+ if cmap is None
4363
+ else plt.get_cmap(cmap)(np.linspace(0, 1, data.shape[0]))
4364
+ )
4062
4365
  # Plot each row with straight lines
4063
4366
  for i, (index, row) in enumerate(data.iterrows()):
4064
4367
  values = row.tolist()
@@ -4070,7 +4373,7 @@ def radar(
4070
4373
  linewidth=linewidth,
4071
4374
  linestyle=linestyle,
4072
4375
  label=index,
4073
- clip_on=False
4376
+ clip_on=False,
4074
4377
  )
4075
4378
  ax.fill(angles, values, color=colors[i], alpha=alpha)
4076
4379
 
@@ -4084,13 +4387,23 @@ def radar(
4084
4387
  marker=marker,
4085
4388
  markersize=size,
4086
4389
  markeredgecolor=edgecolor,
4087
- markeredgewidth = edge_linewidth, zorder=10,clip_on=False
4390
+ markeredgewidth=edge_linewidth,
4391
+ zorder=10,
4392
+ clip_on=False,
4088
4393
  )
4089
4394
  # ax.tick_params(axis='y', labelleft=False, left=False)
4090
- if 'legend' in kws_figsets:
4395
+ if "legend" in kws_figsets:
4091
4396
  figsets(ax=ax, **kws_figsets)
4092
4397
  else:
4093
-
4094
- figsets(ax=ax,legend=dict(loc=legend_loc,fontsize=legend_fontsize,
4095
- bbox_to_anchor=[1.1,1.4],ncols=2),**kws_figsets)
4096
- return ax
4398
+
4399
+ figsets(
4400
+ ax=ax,
4401
+ legend=dict(
4402
+ loc=legend_loc,
4403
+ fontsize=legend_fontsize,
4404
+ bbox_to_anchor=[1.1, 1.4],
4405
+ ncols=2,
4406
+ ),
4407
+ **kws_figsets,
4408
+ )
4409
+ return ax