py2ls 0.2.4.10.4__py3-none-any.whl → 0.2.4.10.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
py2ls/plot.py
CHANGED
@@ -1,21 +1,37 @@
|
|
1
1
|
import numpy as np
|
2
|
-
import pandas as pd
|
2
|
+
import pandas as pd
|
3
3
|
import matplotlib.pyplot as plt
|
4
4
|
import matplotlib
|
5
5
|
import seaborn as sns
|
6
6
|
import warnings
|
7
7
|
import logging
|
8
8
|
from typing import Union
|
9
|
-
from .ips import
|
9
|
+
from .ips import (
|
10
|
+
isa,
|
11
|
+
fsave,
|
12
|
+
fload,
|
13
|
+
mkdir,
|
14
|
+
listdir,
|
15
|
+
figsave,
|
16
|
+
strcmp,
|
17
|
+
unique,
|
18
|
+
get_os,
|
19
|
+
ssplit,
|
20
|
+
flatten,
|
21
|
+
plt_font,
|
22
|
+
run_once_within,
|
23
|
+
)
|
10
24
|
from .stats import *
|
11
25
|
import os
|
26
|
+
|
12
27
|
# Suppress INFO messages from fontTools
|
13
28
|
logging.getLogger("fontTools").setLevel(logging.ERROR)
|
14
|
-
logging.getLogger(
|
29
|
+
logging.getLogger("matplotlib").setLevel(logging.ERROR)
|
15
30
|
|
16
31
|
warnings.simplefilter("ignore", category=pd.errors.SettingWithCopyWarning)
|
17
32
|
warnings.filterwarnings("ignore", category=pd.errors.PerformanceWarning)
|
18
33
|
|
34
|
+
|
19
35
|
def add_text(ax=None, height_offset=0.5, fmt=".1f", **kwargs):
|
20
36
|
"""Adds text annotations for various types of Seaborn and Matplotlib plots.
|
21
37
|
Args:
|
@@ -449,7 +465,7 @@ def catplot(data, *args, **kwargs):
|
|
449
465
|
"""
|
450
466
|
from matplotlib.colors import to_rgba
|
451
467
|
import os
|
452
|
-
|
468
|
+
|
453
469
|
def plot_bars(data, data_m, opt_b, xloc, ax, label=None):
|
454
470
|
if "l" in opt_b["loc"]:
|
455
471
|
xloc_s = xloc - opt_b["x_dist"]
|
@@ -755,6 +771,7 @@ def catplot(data, *args, **kwargs):
|
|
755
771
|
)
|
756
772
|
else:
|
757
773
|
from scipy.stats import gaussian_kde
|
774
|
+
|
758
775
|
kde = gaussian_kde(ys, bw_method=opt_v["BandWidth"])
|
759
776
|
min_val, max_val = ys.min(), ys.max()
|
760
777
|
y_vals = np.linspace(min_val, max_val, opt_v["NumPoints"])
|
@@ -1812,16 +1829,17 @@ def read_mplstyle(style_file):
|
|
1812
1829
|
return style_dict
|
1813
1830
|
|
1814
1831
|
|
1815
|
-
def figsets(*args, **kwargs):
|
1832
|
+
def figsets(*args, **kwargs):
|
1816
1833
|
import matplotlib
|
1817
1834
|
from cycler import cycler
|
1835
|
+
|
1818
1836
|
matplotlib.rc("text", usetex=False)
|
1819
1837
|
|
1820
1838
|
fig = plt.gcf()
|
1821
|
-
fontsize = kwargs.get("fontsize",11)
|
1822
|
-
plt.rcParams["font.size"]=fontsize
|
1823
|
-
fontname = kwargs.pop("fontname","Arial")
|
1824
|
-
fontname=plt_font(fontname)
|
1839
|
+
fontsize = kwargs.get("fontsize", 11)
|
1840
|
+
plt.rcParams["font.size"] = fontsize
|
1841
|
+
fontname = kwargs.pop("fontname", "Arial")
|
1842
|
+
fontname = plt_font(fontname) # 显示中文
|
1825
1843
|
|
1826
1844
|
sns_themes = ["white", "whitegrid", "dark", "darkgrid", "ticks"]
|
1827
1845
|
sns_contexts = ["notebook", "talk", "poster"] # now available "paper"
|
@@ -1851,18 +1869,21 @@ def figsets(*args, **kwargs):
|
|
1851
1869
|
nonlocal fontsize, fontname
|
1852
1870
|
if ("fo" in key) and (("size" in key) or ("sz" in key)):
|
1853
1871
|
fontsize = value
|
1854
|
-
plt.rcParams.update(
|
1855
|
-
|
1856
|
-
|
1857
|
-
|
1858
|
-
|
1859
|
-
|
1860
|
-
|
1861
|
-
|
1862
|
-
|
1872
|
+
plt.rcParams.update(
|
1873
|
+
{
|
1874
|
+
"font.size": fontsize,
|
1875
|
+
"figure.titlesize": fontsize,
|
1876
|
+
"axes.titlesize": fontsize,
|
1877
|
+
"axes.labelsize": fontsize,
|
1878
|
+
"xtick.labelsize": fontsize,
|
1879
|
+
"ytick.labelsize": fontsize,
|
1880
|
+
"legend.fontsize": fontsize,
|
1881
|
+
"legend.title_fontsize": fontsize,
|
1882
|
+
}
|
1883
|
+
)
|
1863
1884
|
|
1864
1885
|
# Customize tick labels
|
1865
|
-
ax.tick_params(axis=
|
1886
|
+
ax.tick_params(axis="both", which="major", labelsize=fontsize)
|
1866
1887
|
for label in ax.get_xticklabels() + ax.get_yticklabels():
|
1867
1888
|
label.set_fontname(fontname)
|
1868
1889
|
|
@@ -1910,15 +1931,15 @@ def figsets(*args, **kwargs):
|
|
1910
1931
|
if ("x" in key.lower()) and (
|
1911
1932
|
"tic" not in key.lower() and "tk" not in key.lower()
|
1912
1933
|
):
|
1913
|
-
ax.set_xlabel(value, fontname=fontname,fontsize=fontsize)
|
1934
|
+
ax.set_xlabel(value, fontname=fontname, fontsize=fontsize)
|
1914
1935
|
if ("y" in key.lower()) and (
|
1915
1936
|
"tic" not in key.lower() and "tk" not in key.lower()
|
1916
1937
|
):
|
1917
|
-
ax.set_ylabel(value, fontname=fontname,fontsize=fontsize)
|
1938
|
+
ax.set_ylabel(value, fontname=fontname, fontsize=fontsize)
|
1918
1939
|
if ("z" in key.lower()) and (
|
1919
1940
|
"tic" not in key.lower() and "tk" not in key.lower()
|
1920
1941
|
):
|
1921
|
-
ax.set_zlabel(value, fontname=fontname,fontsize=fontsize)
|
1942
|
+
ax.set_zlabel(value, fontname=fontname, fontsize=fontsize)
|
1922
1943
|
if key == "xlabel" and isinstance(value, dict):
|
1923
1944
|
ax.set_xlabel(**value)
|
1924
1945
|
if key == "ylabel" and isinstance(value, dict):
|
@@ -2110,6 +2131,7 @@ def figsets(*args, **kwargs):
|
|
2110
2131
|
|
2111
2132
|
if "mi" in key.lower() and "tic" in key.lower(): # minor_ticks
|
2112
2133
|
import matplotlib.ticker as tck
|
2134
|
+
|
2113
2135
|
if "x" in value.lower() or "x" in key.lower():
|
2114
2136
|
ax.xaxis.set_minor_locator(tck.AutoMinorLocator()) # ax.minorticks_on()
|
2115
2137
|
if "y" in value.lower() or "y" in key.lower():
|
@@ -2162,9 +2184,9 @@ def figsets(*args, **kwargs):
|
|
2162
2184
|
ax.grid(visible=False)
|
2163
2185
|
if "tit" in key.lower():
|
2164
2186
|
if "sup" in key.lower():
|
2165
|
-
plt.suptitle(value,fontname=fontname,fontsize=fontsize)
|
2187
|
+
plt.suptitle(value, fontname=fontname, fontsize=fontsize)
|
2166
2188
|
else:
|
2167
|
-
ax.set_title(value,fontname=fontname,fontsize=fontsize)
|
2189
|
+
ax.set_title(value, fontname=fontname, fontsize=fontsize)
|
2168
2190
|
if key.lower() in ["spine", "adjust", "ad", "sp", "spi", "adj", "spines"]:
|
2169
2191
|
if isinstance(value, bool) or (value in ["go", "do", "ja", "yes"]):
|
2170
2192
|
if value:
|
@@ -2185,9 +2207,14 @@ def figsets(*args, **kwargs):
|
|
2185
2207
|
legend = ax.get_legend()
|
2186
2208
|
if legend is not None:
|
2187
2209
|
legend.remove()
|
2188
|
-
if
|
2210
|
+
if (
|
2211
|
+
any(["colorbar" in key.lower(), "cbar" in key.lower()])
|
2212
|
+
and "loc" in key.lower()
|
2213
|
+
):
|
2189
2214
|
cbar = ax.collections[0].colorbar # Access the colorbar from the plot
|
2190
|
-
cbar.ax.set_position(
|
2215
|
+
cbar.ax.set_position(
|
2216
|
+
value
|
2217
|
+
) # [left, bottom, width, height] [0.475, 0.15, 0.04, 0.25]
|
2191
2218
|
|
2192
2219
|
for arg in args:
|
2193
2220
|
if isinstance(arg, matplotlib.axes._axes.Axes):
|
@@ -2225,7 +2252,8 @@ def figsets(*args, **kwargs):
|
|
2225
2252
|
if len(fig.get_axes()) > 1:
|
2226
2253
|
plt.tight_layout()
|
2227
2254
|
|
2228
|
-
|
2255
|
+
|
2256
|
+
def split_legend(ax, n=2, loc=None, title=None, bbox=None, ncol=1, **kwargs):
|
2229
2257
|
"""
|
2230
2258
|
split_legend(
|
2231
2259
|
ax,
|
@@ -2238,15 +2266,15 @@ def split_legend(ax, n=2, loc=None, title=None, bbox=None,ncol=1, **kwargs):
|
|
2238
2266
|
# Retrieve all lines and labels from the axis
|
2239
2267
|
handles, labels = ax.get_legend_handles_labels()
|
2240
2268
|
num_labels = len(labels)
|
2241
|
-
|
2269
|
+
|
2242
2270
|
# Calculate the number of labels per legend part
|
2243
|
-
labels_per_part = (num_labels + n - 1) // n # Round up
|
2271
|
+
labels_per_part = (num_labels + n - 1) // n # Round up
|
2244
2272
|
# Create a list to hold each legend object
|
2245
2273
|
legends = []
|
2246
2274
|
|
2247
2275
|
# Default locations and titles if not specified
|
2248
2276
|
if loc is None:
|
2249
|
-
loc = [
|
2277
|
+
loc = ["best"] * n
|
2250
2278
|
if title is None:
|
2251
2279
|
title = [None] * n
|
2252
2280
|
if bbox is None:
|
@@ -2257,7 +2285,7 @@ def split_legend(ax, n=2, loc=None, title=None, bbox=None,ncol=1, **kwargs):
|
|
2257
2285
|
# Calculate the range of labels for this part
|
2258
2286
|
start_idx = i * labels_per_part
|
2259
2287
|
end_idx = min(start_idx + labels_per_part, num_labels)
|
2260
|
-
|
2288
|
+
|
2261
2289
|
# Skip if no labels in this range
|
2262
2290
|
if start_idx >= end_idx:
|
2263
2291
|
break
|
@@ -2267,19 +2295,24 @@ def split_legend(ax, n=2, loc=None, title=None, bbox=None,ncol=1, **kwargs):
|
|
2267
2295
|
part_labels = labels[start_idx:end_idx]
|
2268
2296
|
|
2269
2297
|
# Create the legend for this part
|
2270
|
-
legend = ax.legend(
|
2271
|
-
|
2272
|
-
|
2273
|
-
|
2274
|
-
|
2275
|
-
|
2276
|
-
|
2277
|
-
|
2298
|
+
legend = ax.legend(
|
2299
|
+
handles=part_handles,
|
2300
|
+
labels=part_labels,
|
2301
|
+
loc=loc[i],
|
2302
|
+
title=title[i],
|
2303
|
+
ncol=ncol,
|
2304
|
+
bbox_to_anchor=bbox[i],
|
2305
|
+
**kwargs,
|
2306
|
+
)
|
2307
|
+
|
2278
2308
|
# Add the legend to the axis and save it to the list
|
2279
|
-
|
2309
|
+
(
|
2310
|
+
ax.add_artist(legend) if i != (n - 1) else None
|
2311
|
+
) # the lastone will be added automaticaly
|
2280
2312
|
legends.append(legend)
|
2281
2313
|
return legends
|
2282
2314
|
|
2315
|
+
|
2283
2316
|
def get_colors(
|
2284
2317
|
n: int = 1,
|
2285
2318
|
cmap: str = "auto",
|
@@ -2289,7 +2322,9 @@ def get_colors(
|
|
2289
2322
|
*args,
|
2290
2323
|
**kwargs,
|
2291
2324
|
):
|
2292
|
-
return get_color(n,cmap,alpha,output
|
2325
|
+
return get_color(n, cmap, alpha, output, *args, **kwargs)
|
2326
|
+
|
2327
|
+
|
2293
2328
|
def get_color(
|
2294
2329
|
n: int = 1,
|
2295
2330
|
cmap: str = "auto",
|
@@ -2300,6 +2335,7 @@ def get_color(
|
|
2300
2335
|
**kwargs,
|
2301
2336
|
):
|
2302
2337
|
from cycler import cycler
|
2338
|
+
|
2303
2339
|
def cmap2hex(cmap_name):
|
2304
2340
|
cmap_ = matplotlib.pyplot.get_cmap(cmap_name)
|
2305
2341
|
colors = [cmap_(i) for i in range(cmap_.N)]
|
@@ -2347,49 +2383,137 @@ def get_color(
|
|
2347
2383
|
return "#{:02X}{:02X}{:02X}{:02X}".format(r, g, b, a)
|
2348
2384
|
else:
|
2349
2385
|
return "#{:02X}{:02X}{:02X}".format(r, g, b)
|
2350
|
-
|
2386
|
+
|
2351
2387
|
# sc.pl.palettes.default_20
|
2352
|
-
cmap_20= [
|
2353
|
-
|
2354
|
-
|
2388
|
+
cmap_20 = [
|
2389
|
+
"#1f77b4",
|
2390
|
+
"#ff7f0e",
|
2391
|
+
"#279e68",
|
2392
|
+
"#d62728",
|
2393
|
+
"#aa40fc",
|
2394
|
+
"#8c564b",
|
2395
|
+
"#e377c2",
|
2396
|
+
"#b5bd61",
|
2397
|
+
"#17becf",
|
2398
|
+
"#aec7e8",
|
2399
|
+
"#ffbb78",
|
2400
|
+
"#98df8a",
|
2401
|
+
"#ff9896",
|
2402
|
+
"#c5b0d5",
|
2403
|
+
"#c49c94",
|
2404
|
+
"#f7b6d2",
|
2405
|
+
"#dbdb8d",
|
2406
|
+
"#9edae5",
|
2407
|
+
"#ad494a",
|
2408
|
+
"#8c6d31",
|
2409
|
+
]
|
2355
2410
|
# sc.pl.palettes.zeileis_28
|
2356
|
-
cmap_28 = [
|
2357
|
-
|
2358
|
-
|
2359
|
-
|
2411
|
+
cmap_28 = [
|
2412
|
+
"#023fa5",
|
2413
|
+
"#7d87b9",
|
2414
|
+
"#bec1d4",
|
2415
|
+
"#d6bcc0",
|
2416
|
+
"#bb7784",
|
2417
|
+
"#8e063b",
|
2418
|
+
"#4a6fe3",
|
2419
|
+
"#8595e1",
|
2420
|
+
"#b5bbe3",
|
2421
|
+
"#e6afb9",
|
2422
|
+
"#e07b91",
|
2423
|
+
"#d33f6a",
|
2424
|
+
"#11c638",
|
2425
|
+
"#8dd593",
|
2426
|
+
"#c6dec7",
|
2427
|
+
"#ead3c6",
|
2428
|
+
"#f0b98d",
|
2429
|
+
"#ef9708",
|
2430
|
+
"#0fcfc0",
|
2431
|
+
"#9cded6",
|
2432
|
+
"#d5eae7",
|
2433
|
+
"#f3e1eb",
|
2434
|
+
"#f6c4e1",
|
2435
|
+
"#f79cd4",
|
2436
|
+
"#7f7f7f",
|
2437
|
+
"#c7c7c7",
|
2438
|
+
"#1CE6FF",
|
2439
|
+
"#336600",
|
2440
|
+
]
|
2360
2441
|
if cmap == "gray":
|
2361
2442
|
cmap = "grey"
|
2362
|
-
elif cmap=="20":
|
2363
|
-
cmap=cmap_20
|
2364
|
-
elif cmap=="28":
|
2365
|
-
cmap=cmap_28
|
2443
|
+
elif cmap == "20":
|
2444
|
+
cmap = cmap_20
|
2445
|
+
elif cmap == "28":
|
2446
|
+
cmap = cmap_28
|
2366
2447
|
# Determine color list based on cmap parameter
|
2367
|
-
if isinstance(cmap,str):
|
2448
|
+
if isinstance(cmap, str):
|
2368
2449
|
if "aut" in cmap:
|
2369
2450
|
if n == 1:
|
2370
2451
|
colorlist = ["#3A4453"]
|
2371
2452
|
elif n == 2:
|
2372
2453
|
colorlist = ["#3A4453", "#FF2C00"]
|
2373
2454
|
elif n == 3:
|
2374
|
-
colorlist = ["#66c2a5","#fc8d62","#8da0cb"]
|
2455
|
+
colorlist = ["#66c2a5", "#fc8d62", "#8da0cb"]
|
2375
2456
|
elif n == 4:
|
2376
|
-
colorlist = ["#FF2C00","#087cf7", "#FBAF63", "#3C898A"]
|
2457
|
+
colorlist = ["#FF2C00", "#087cf7", "#FBAF63", "#3C898A"]
|
2377
2458
|
elif n == 5:
|
2378
|
-
colorlist = ["#FF2C00","#459AA9", "#B25E9D", "#4B8C3B","#EF8632"]
|
2459
|
+
colorlist = ["#FF2C00", "#459AA9", "#B25E9D", "#4B8C3B", "#EF8632"]
|
2379
2460
|
elif n == 6:
|
2380
|
-
colorlist = [
|
2381
|
-
|
2382
|
-
|
2383
|
-
|
2461
|
+
colorlist = [
|
2462
|
+
"#FF2C00",
|
2463
|
+
"#91bfdb",
|
2464
|
+
"#B25E9D",
|
2465
|
+
"#4B8C3B",
|
2466
|
+
"#EF8632",
|
2467
|
+
"#24578E",
|
2468
|
+
]
|
2469
|
+
elif n == 7:
|
2470
|
+
colorlist = [
|
2471
|
+
"#7F7F7F",
|
2472
|
+
"#459AA9",
|
2473
|
+
"#B25E9D",
|
2474
|
+
"#4B8C3B",
|
2475
|
+
"#EF8632",
|
2476
|
+
"#24578E" "#FF2C00",
|
2477
|
+
]
|
2478
|
+
elif n == 8:
|
2384
2479
|
# colorlist = ['#1f77b4','#ff7f0e','#367B7F','#51B34F','#d62728','#aa40fc','#e377c2','#17becf']
|
2385
2480
|
# colorlist = ["#367C7E","#51B34F","#881A11","#E9374C","#EF893C","#010072","#385DCB","#EA43E3"]
|
2386
|
-
colorlist = [
|
2387
|
-
|
2388
|
-
|
2389
|
-
|
2390
|
-
|
2391
|
-
|
2392
|
-
|
2481
|
+
colorlist = [
|
2482
|
+
"#78BFDA",
|
2483
|
+
"#D52E6F",
|
2484
|
+
"#F7D648",
|
2485
|
+
"#A52D28",
|
2486
|
+
"#6B9F41",
|
2487
|
+
"#E18330",
|
2488
|
+
"#E18B9D",
|
2489
|
+
"#3C88CC",
|
2490
|
+
]
|
2491
|
+
elif n == 9:
|
2492
|
+
colorlist = [
|
2493
|
+
"#1f77b4",
|
2494
|
+
"#ff7f0e",
|
2495
|
+
"#367B7F",
|
2496
|
+
"#ff9896",
|
2497
|
+
"#d62728",
|
2498
|
+
"#aa40fc",
|
2499
|
+
"#e377c2",
|
2500
|
+
"#51B34F",
|
2501
|
+
"#17becf",
|
2502
|
+
]
|
2503
|
+
elif n == 10:
|
2504
|
+
colorlist = [
|
2505
|
+
"#1f77b4",
|
2506
|
+
"#ff7f0e",
|
2507
|
+
"#367B7F",
|
2508
|
+
"#ff9896",
|
2509
|
+
"#51B34F",
|
2510
|
+
"#d62728" "#aa40fc",
|
2511
|
+
"#e377c2",
|
2512
|
+
"#375FD2",
|
2513
|
+
"#17becf",
|
2514
|
+
]
|
2515
|
+
elif 10 < n <= 20:
|
2516
|
+
colorlist = cmap_20
|
2393
2517
|
else:
|
2394
2518
|
colorlist = cmap_28
|
2395
2519
|
by = "start"
|
@@ -2426,7 +2550,7 @@ def get_color(
|
|
2426
2550
|
by = "linspace"
|
2427
2551
|
colorlist = cmap2hex(cmap)
|
2428
2552
|
elif isinstance(cmap, list):
|
2429
|
-
colorlist=cmap
|
2553
|
+
colorlist = cmap
|
2430
2554
|
|
2431
2555
|
# Determine method for generating color list
|
2432
2556
|
if "st" in by.lower() or "be" in by.lower():
|
@@ -2446,7 +2570,7 @@ def get_color(
|
|
2446
2570
|
return hue_list
|
2447
2571
|
else:
|
2448
2572
|
raise ValueError("Invalid output type. Choose 'rgb' or 'hue'.")
|
2449
|
-
|
2573
|
+
|
2450
2574
|
|
2451
2575
|
def stdshade(ax=None, *args, **kwargs):
|
2452
2576
|
"""
|
@@ -2454,7 +2578,7 @@ def stdshade(ax=None, *args, **kwargs):
|
|
2454
2578
|
plot.stdshade(data_array, c=clist[1], lw=2, ls="-.", alpha=0.2)
|
2455
2579
|
"""
|
2456
2580
|
from scipy.signal import savgol_filter
|
2457
|
-
|
2581
|
+
|
2458
2582
|
# Separate kws_line and kws_fill if necessary
|
2459
2583
|
kws_line = kwargs.pop("kws_line", {})
|
2460
2584
|
kws_fill = kwargs.pop("kws_fill", {})
|
@@ -2724,37 +2848,47 @@ def adjust_spines(ax=None, spines=["left", "bottom"], distance=2):
|
|
2724
2848
|
# cax = fig.add_axes([l + w + pad, b, width, h]) # define cbar Axes
|
2725
2849
|
# return fig.colorbar(im, cax=cax, **kwargs) # draw cbar
|
2726
2850
|
|
2727
|
-
|
2728
|
-
|
2729
|
-
|
2730
|
-
|
2731
|
-
|
2732
|
-
|
2733
|
-
|
2734
|
-
|
2735
|
-
|
2851
|
+
|
2852
|
+
def add_colorbar(
|
2853
|
+
im,
|
2854
|
+
cmap="viridis",
|
2855
|
+
vmin=-1,
|
2856
|
+
vmax=1,
|
2857
|
+
orientation="vertical",
|
2858
|
+
width_ratio=0.05,
|
2859
|
+
pad_ratio=0.02,
|
2860
|
+
shrink=1.0,
|
2861
|
+
**kwargs,
|
2862
|
+
):
|
2736
2863
|
import matplotlib as mpl
|
2864
|
+
|
2737
2865
|
if all([cmap, vmin, vmax]):
|
2738
2866
|
norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
|
2739
2867
|
else:
|
2740
|
-
norm=False
|
2868
|
+
norm = False
|
2741
2869
|
sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
|
2742
2870
|
sm.set_array([])
|
2743
2871
|
l, b, w, h = im.axes.get_position().bounds # position: left, bottom, width, height
|
2744
|
-
if orientation ==
|
2872
|
+
if orientation == "vertical":
|
2745
2873
|
width = width_ratio * w
|
2746
2874
|
pad = pad_ratio * w
|
2747
|
-
cax = im.figure.add_axes(
|
2875
|
+
cax = im.figure.add_axes(
|
2876
|
+
[l + w + pad, b, width, h * shrink]
|
2877
|
+
) # Right of the image
|
2748
2878
|
else:
|
2749
2879
|
height = width_ratio * h
|
2750
2880
|
pad = pad_ratio * h
|
2751
|
-
cax = im.figure.add_axes(
|
2752
|
-
|
2753
|
-
|
2881
|
+
cax = im.figure.add_axes(
|
2882
|
+
[l, b - height - pad, w * shrink, height]
|
2883
|
+
) # Below the image
|
2884
|
+
cbar = im.figure.colorbar(sm, cax=cax, orientation=orientation, **kwargs)
|
2885
|
+
return cbar
|
2886
|
+
|
2754
2887
|
|
2755
2888
|
# Usage:
|
2756
2889
|
# add_colorbar(im, width_ratio=0.03, pad_ratio=0.01, orientation='horizontal', label="PSD (dB)")
|
2757
2890
|
|
2891
|
+
|
2758
2892
|
def generate_xticks_with_gap(x_len, hue_len):
|
2759
2893
|
"""
|
2760
2894
|
Generate a concatenated array based on x_len and hue_len,
|
@@ -2884,6 +3018,7 @@ def style_examples(
|
|
2884
3018
|
f = listdir(
|
2885
3019
|
"/Users/macjianfeng/Dropbox/github/python/py2ls/.venv/lib/python3.12/site-packages/py2ls/data/styles/",
|
2886
3020
|
kind=".json",
|
3021
|
+
verbose=False,
|
2887
3022
|
)
|
2888
3023
|
display(f.sample(2))
|
2889
3024
|
# def style_example(dir_save,)
|
@@ -2961,6 +3096,7 @@ def thumbnail(dir_img_list: list, figsize=(10, 10), dpi=100, show=False, verbose
|
|
2961
3096
|
|
2962
3097
|
def get_params_from_func_usage(function_signature):
|
2963
3098
|
import re
|
3099
|
+
|
2964
3100
|
# Regular expression to match parameter names, ignoring '*' and '**kwargs'
|
2965
3101
|
keys_pattern = r"(?<!\*\*)\b(\w+)="
|
2966
3102
|
# Find all matches
|
@@ -2987,7 +3123,7 @@ def plotxy(
|
|
2987
3123
|
x=None,
|
2988
3124
|
y=None,
|
2989
3125
|
ax=None,
|
2990
|
-
kind: Union[str, list] =
|
3126
|
+
kind: Union[str, list] = "scatter", # Specify the kind of plot
|
2991
3127
|
verbose=False,
|
2992
3128
|
**kwargs,
|
2993
3129
|
):
|
@@ -3011,14 +3147,15 @@ def plotxy(
|
|
3011
3147
|
# Check for valid plot kind
|
3012
3148
|
# Default arguments for various plot types
|
3013
3149
|
from pathlib import Path
|
3150
|
+
|
3014
3151
|
# Get the current script's directory as a Path object
|
3015
|
-
current_directory = Path(__file__).resolve().parent
|
3152
|
+
current_directory = Path(__file__).resolve().parent
|
3153
|
+
|
3154
|
+
if not "default_settings" in locals():
|
3155
|
+
default_settings = fload(current_directory / "data" / "usages_sns.json")
|
3156
|
+
if not "sns_info" in locals():
|
3157
|
+
sns_info = pd.DataFrame(fload(current_directory / "data" / "sns_info.json"))
|
3016
3158
|
|
3017
|
-
if not 'default_settings' in locals():
|
3018
|
-
default_settings = fload(current_directory / 'data' / 'usages_sns.json')
|
3019
|
-
if not 'sns_info' in locals():
|
3020
|
-
sns_info = pd.DataFrame(fload(current_directory / 'data' / 'sns_info.json'))
|
3021
|
-
|
3022
3159
|
valid_kinds = list(default_settings.keys())
|
3023
3160
|
# print(valid_kinds)
|
3024
3161
|
if kind is not None:
|
@@ -3032,7 +3169,7 @@ def plotxy(
|
|
3032
3169
|
if kind is not None:
|
3033
3170
|
for k in kind:
|
3034
3171
|
if k in valid_kinds:
|
3035
|
-
print(f"{k}:\n\t{default_settings[k]}")
|
3172
|
+
print(f"{k}:\n\t{default_settings[k]}")
|
3036
3173
|
usage_str = """plotxy(data=ranked_genes,
|
3037
3174
|
x="log2(fold_change)",
|
3038
3175
|
y="-log10(p-value)",
|
@@ -3057,15 +3194,25 @@ def plotxy(
|
|
3057
3194
|
kws_add_text = v_arg
|
3058
3195
|
kwargs.pop(k_arg, None)
|
3059
3196
|
break
|
3060
|
-
zorder=0
|
3197
|
+
zorder = 0
|
3061
3198
|
for k in kind:
|
3062
|
-
zorder+=1
|
3199
|
+
zorder += 1
|
3063
3200
|
# indicate 'col' features
|
3064
3201
|
col = kwargs.get("col", None)
|
3065
|
-
sns_with_col = [
|
3202
|
+
sns_with_col = [
|
3203
|
+
"catplot",
|
3204
|
+
"histplot",
|
3205
|
+
"relplot",
|
3206
|
+
"lmplot",
|
3207
|
+
"pairplot",
|
3208
|
+
"displot",
|
3209
|
+
"kdeplot",
|
3210
|
+
]
|
3066
3211
|
if col is not None:
|
3067
3212
|
if not k in sns_with_col:
|
3068
|
-
print(
|
3213
|
+
print(
|
3214
|
+
f"tips:\n'{k}' has no 'col' param, you could try with {sns_with_col}"
|
3215
|
+
)
|
3069
3216
|
# (1) return FcetGrid
|
3070
3217
|
if k == "jointplot":
|
3071
3218
|
kws_joint = kwargs.pop("kws_joint", kwargs)
|
@@ -3092,77 +3239,98 @@ def plotxy(
|
|
3092
3239
|
ax = stdshade(ax=ax, **kwargs)
|
3093
3240
|
elif k == "scatterplot":
|
3094
3241
|
kws_scatter = kwargs.pop("kws_scatter", kwargs)
|
3095
|
-
kws_scatter={
|
3242
|
+
kws_scatter = {
|
3243
|
+
k: v for k, v in kws_scatter.items() if not k.startswith("kws_")
|
3244
|
+
}
|
3096
3245
|
hue = kwargs.pop("hue", None)
|
3097
3246
|
if isinstance(kws_scatter, dict): # Check if kws_scatter is a dictionary
|
3098
3247
|
kws_scatter.pop("hue", None) # Safely remove 'hue' if it exists
|
3099
|
-
palette = kws_scatter.pop(
|
3248
|
+
palette = kws_scatter.pop(
|
3249
|
+
"palette", get_color(data[hue].nunique()) if hue is not None else None
|
3250
|
+
)
|
3100
3251
|
s = kws_scatter.pop("s", 10)
|
3101
3252
|
alpha = kws_scatter.pop("alpha", 0.7)
|
3102
|
-
ax = sns.scatterplot(
|
3253
|
+
ax = sns.scatterplot(
|
3254
|
+
ax=ax,
|
3255
|
+
data=data,
|
3256
|
+
x=x,
|
3257
|
+
y=y,
|
3258
|
+
hue=hue,
|
3259
|
+
palette=palette,
|
3260
|
+
s=s,
|
3261
|
+
alpha=alpha,
|
3262
|
+
zorder=zorder,
|
3263
|
+
**kws_scatter,
|
3264
|
+
)
|
3103
3265
|
elif k == "histplot":
|
3104
3266
|
kws_hist = kwargs.pop("kws_hist", kwargs)
|
3105
|
-
kws_hist={k: v for k, v in kws_hist.items() if not k.startswith("kws_")}
|
3106
|
-
ax = sns.histplot(data=data, x=x, ax=ax,zorder=zorder, **kws_hist)
|
3267
|
+
kws_hist = {k: v for k, v in kws_hist.items() if not k.startswith("kws_")}
|
3268
|
+
ax = sns.histplot(data=data, x=x, ax=ax, zorder=zorder, **kws_hist)
|
3107
3269
|
elif k == "kdeplot":
|
3108
3270
|
kws_kde = kwargs.pop("kws_kde", kwargs)
|
3109
|
-
kws_kde={k: v for k, v in kws_kde.items() if not k.startswith("kws_")}
|
3110
|
-
ax = sns.kdeplot(data=data, x=x, ax=ax,zorder=zorder, **kws_kde)
|
3271
|
+
kws_kde = {k: v for k, v in kws_kde.items() if not k.startswith("kws_")}
|
3272
|
+
ax = sns.kdeplot(data=data, x=x, ax=ax, zorder=zorder, **kws_kde)
|
3111
3273
|
elif k == "ecdfplot":
|
3112
3274
|
kws_ecdf = kwargs.pop("kws_ecdf", kwargs)
|
3113
|
-
kws_ecdf={k: v for k, v in kws_ecdf.items() if not k.startswith("kws_")}
|
3114
|
-
ax = sns.ecdfplot(data=data, x=x, ax=ax,zorder=zorder, **kws_ecdf)
|
3275
|
+
kws_ecdf = {k: v for k, v in kws_ecdf.items() if not k.startswith("kws_")}
|
3276
|
+
ax = sns.ecdfplot(data=data, x=x, ax=ax, zorder=zorder, **kws_ecdf)
|
3115
3277
|
elif k == "rugplot":
|
3116
3278
|
kws_rug = kwargs.pop("kws_rug", kwargs)
|
3117
|
-
kws_rug={k: v for k, v in kws_rug.items() if not k.startswith("kws_")}
|
3118
|
-
ax = sns.rugplot(data=data, x=x, ax=ax,zorder=zorder, **kws_rug)
|
3279
|
+
kws_rug = {k: v for k, v in kws_rug.items() if not k.startswith("kws_")}
|
3280
|
+
ax = sns.rugplot(data=data, x=x, ax=ax, zorder=zorder, **kws_rug)
|
3119
3281
|
elif k == "stripplot":
|
3120
3282
|
kws_strip = kwargs.pop("kws_strip", kwargs)
|
3121
|
-
kws_strip={k: v for k, v in kws_strip.items() if not k.startswith("kws_")}
|
3283
|
+
kws_strip = {k: v for k, v in kws_strip.items() if not k.startswith("kws_")}
|
3122
3284
|
dodge = kws_strip.pop("dodge", True)
|
3123
|
-
ax = sns.stripplot(
|
3285
|
+
ax = sns.stripplot(
|
3286
|
+
data=data, x=x, y=y, ax=ax, zorder=zorder, dodge=dodge, **kws_strip
|
3287
|
+
)
|
3124
3288
|
elif k == "swarmplot":
|
3125
3289
|
kws_swarm = kwargs.pop("kws_swarm", kwargs)
|
3126
|
-
kws_swarm={k: v for k, v in kws_swarm.items() if not k.startswith("kws_")}
|
3127
|
-
ax = sns.swarmplot(data=data, x=x, y=y, ax=ax,zorder=zorder, **kws_swarm)
|
3290
|
+
kws_swarm = {k: v for k, v in kws_swarm.items() if not k.startswith("kws_")}
|
3291
|
+
ax = sns.swarmplot(data=data, x=x, y=y, ax=ax, zorder=zorder, **kws_swarm)
|
3128
3292
|
elif k == "boxplot":
|
3129
3293
|
kws_box = kwargs.pop("kws_box", kwargs)
|
3130
|
-
kws_box={k: v for k, v in kws_box.items() if not k.startswith("kws_")}
|
3131
|
-
ax = sns.boxplot(data=data, x=x, y=y, ax=ax,zorder=zorder, **kws_box)
|
3294
|
+
kws_box = {k: v for k, v in kws_box.items() if not k.startswith("kws_")}
|
3295
|
+
ax = sns.boxplot(data=data, x=x, y=y, ax=ax, zorder=zorder, **kws_box)
|
3132
3296
|
elif k == "violinplot":
|
3133
3297
|
kws_violin = kwargs.pop("kws_violin", kwargs)
|
3134
|
-
kws_violin={
|
3135
|
-
|
3298
|
+
kws_violin = {
|
3299
|
+
k: v for k, v in kws_violin.items() if not k.startswith("kws_")
|
3300
|
+
}
|
3301
|
+
ax = sns.violinplot(data=data, x=x, y=y, ax=ax, zorder=zorder, **kws_violin)
|
3136
3302
|
elif k == "boxenplot":
|
3137
3303
|
kws_boxen = kwargs.pop("kws_boxen", kwargs)
|
3138
|
-
kws_boxen={k: v for k, v in kws_boxen.items() if not k.startswith("kws_")}
|
3139
|
-
ax = sns.boxenplot(data=data, x=x, y=y, ax=ax,zorder=zorder, **kws_boxen)
|
3304
|
+
kws_boxen = {k: v for k, v in kws_boxen.items() if not k.startswith("kws_")}
|
3305
|
+
ax = sns.boxenplot(data=data, x=x, y=y, ax=ax, zorder=zorder, **kws_boxen)
|
3140
3306
|
elif k == "pointplot":
|
3141
3307
|
kws_point = kwargs.pop("kws_point", kwargs)
|
3142
|
-
kws_point={k: v for k, v in kws_point.items() if not k.startswith("kws_")}
|
3143
|
-
ax = sns.pointplot(data=data, x=x, y=y, ax=ax,zorder=zorder, **kws_point)
|
3308
|
+
kws_point = {k: v for k, v in kws_point.items() if not k.startswith("kws_")}
|
3309
|
+
ax = sns.pointplot(data=data, x=x, y=y, ax=ax, zorder=zorder, **kws_point)
|
3144
3310
|
elif k == "barplot":
|
3145
3311
|
kws_bar = kwargs.pop("kws_bar", kwargs)
|
3146
|
-
kws_bar={k: v for k, v in kws_bar.items() if not k.startswith("kws_")}
|
3147
|
-
ax = sns.barplot(data=data, x=x, y=y, ax=ax,zorder=zorder, **kws_bar)
|
3312
|
+
kws_bar = {k: v for k, v in kws_bar.items() if not k.startswith("kws_")}
|
3313
|
+
ax = sns.barplot(data=data, x=x, y=y, ax=ax, zorder=zorder, **kws_bar)
|
3148
3314
|
elif k == "countplot":
|
3149
3315
|
kws_count = kwargs.pop("kws_count", kwargs)
|
3150
|
-
kws_count={k: v for k, v in kws_count.items() if not k.startswith("kws_")}
|
3151
|
-
if not kws_count.get("hue",None):
|
3152
|
-
kws_count.pop("palette",None)
|
3153
|
-
ax = sns.countplot(data=data, x=x,y=y, ax=ax,zorder=zorder, **kws_count)
|
3316
|
+
kws_count = {k: v for k, v in kws_count.items() if not k.startswith("kws_")}
|
3317
|
+
if not kws_count.get("hue", None):
|
3318
|
+
kws_count.pop("palette", None)
|
3319
|
+
ax = sns.countplot(data=data, x=x, y=y, ax=ax, zorder=zorder, **kws_count)
|
3154
3320
|
elif k == "regplot":
|
3155
3321
|
kws_reg = kwargs.pop("kws_reg", kwargs)
|
3156
|
-
kws_reg={k: v for k, v in kws_reg.items() if not k.startswith("kws_")}
|
3157
|
-
ax = sns.regplot(data=data, x=x, y=y, ax=ax,zorder=zorder, **kws_reg)
|
3322
|
+
kws_reg = {k: v for k, v in kws_reg.items() if not k.startswith("kws_")}
|
3323
|
+
ax = sns.regplot(data=data, x=x, y=y, ax=ax, zorder=zorder, **kws_reg)
|
3158
3324
|
elif k == "residplot":
|
3159
3325
|
kws_resid = kwargs.pop("kws_resid", kwargs)
|
3160
|
-
kws_resid={k: v for k, v in kws_resid.items() if not k.startswith("kws_")}
|
3161
|
-
ax = sns.residplot(
|
3326
|
+
kws_resid = {k: v for k, v in kws_resid.items() if not k.startswith("kws_")}
|
3327
|
+
ax = sns.residplot(
|
3328
|
+
data=data, x=x, y=y, lowess=True, zorder=zorder, ax=ax, **kws_resid
|
3329
|
+
)
|
3162
3330
|
elif k == "lineplot":
|
3163
3331
|
kws_line = kwargs.pop("kws_line", kwargs)
|
3164
|
-
kws_line={k: v for k, v in kws_line.items() if not k.startswith("kws_")}
|
3165
|
-
ax = sns.lineplot(ax=ax, data=data, x=x, y=y,zorder=zorder, **kws_line)
|
3332
|
+
kws_line = {k: v for k, v in kws_line.items() if not k.startswith("kws_")}
|
3333
|
+
ax = sns.lineplot(ax=ax, data=data, x=x, y=y, zorder=zorder, **kws_line)
|
3166
3334
|
|
3167
3335
|
figsets(ax=ax, **kws_figsets)
|
3168
3336
|
if kws_add_text:
|
@@ -3176,20 +3344,22 @@ def plotxy(
|
|
3176
3344
|
return g, ax
|
3177
3345
|
return ax
|
3178
3346
|
|
3347
|
+
|
3179
3348
|
def norm_cmap(data, cmap="coolwarm", min_max=[0, 1]):
|
3180
3349
|
norm_ = plt.Normalize(min_max[0], min_max[1])
|
3181
3350
|
colormap = plt.get_cmap(cmap)
|
3182
3351
|
return colormap(norm_(data))
|
3183
3352
|
|
3353
|
+
|
3184
3354
|
def volcano(
|
3185
|
-
data:pd.DataFrame,
|
3186
|
-
x:str,
|
3187
|
-
y:str,
|
3188
|
-
gene_col:str=None,
|
3189
|
-
top_genes=[5, 5],
|
3190
|
-
thr_x=np.log2(1.5),
|
3355
|
+
data: pd.DataFrame,
|
3356
|
+
x: str,
|
3357
|
+
y: str,
|
3358
|
+
gene_col: str = None,
|
3359
|
+
top_genes=[5, 5], # [down-regulated, up-regulated]
|
3360
|
+
thr_x=np.log2(1.5), # default: 0.585
|
3191
3361
|
thr_y=-np.log10(0.05),
|
3192
|
-
sort_xy="x",
|
3362
|
+
sort_xy="x", #'y', 'xy'
|
3193
3363
|
colors=("#00BFFF", "#9d9a9a", "#FF3030"),
|
3194
3364
|
s=20,
|
3195
3365
|
fill=True, # plot filled scatter
|
@@ -3201,11 +3371,10 @@ def volcano(
|
|
3201
3371
|
ax=None,
|
3202
3372
|
verbose=False,
|
3203
3373
|
kws_text=dict(fontsize=10, color="k"),
|
3204
|
-
kws_bbox=dict(
|
3205
|
-
|
3206
|
-
|
3207
|
-
|
3208
|
-
kws_arrow=dict(color="k", lw=0.5),# '{}' to hide
|
3374
|
+
kws_bbox=dict(
|
3375
|
+
facecolor="none", alpha=0.5, edgecolor="black", boxstyle="round,pad=0.3"
|
3376
|
+
), # '{}' to hide
|
3377
|
+
kws_arrow=dict(color="k", lw=0.5), # '{}' to hide
|
3209
3378
|
**kwargs,
|
3210
3379
|
):
|
3211
3380
|
"""
|
@@ -3275,39 +3444,35 @@ def volcano(
|
|
3275
3444
|
kws_figsets = v_arg
|
3276
3445
|
kwargs.pop(k_arg, None)
|
3277
3446
|
break
|
3278
|
-
|
3279
|
-
data=data.copy()
|
3447
|
+
|
3448
|
+
data = data.copy()
|
3280
3449
|
# filter nan
|
3281
3450
|
data = data.dropna(subset=[x, y]) # Drop rows with NaN in x or y
|
3282
|
-
data.loc[:,"color"] = np.where(
|
3451
|
+
data.loc[:, "color"] = np.where(
|
3283
3452
|
(data[x] > thr_x) & (data[y] > thr_y),
|
3284
3453
|
colors[2],
|
3285
|
-
np.where((data[x] < -thr_x) & (data[y] > thr_y),
|
3286
|
-
colors[0],
|
3287
|
-
colors[1]),
|
3454
|
+
np.where((data[x] < -thr_x) & (data[y] > thr_y), colors[0], colors[1]),
|
3288
3455
|
)
|
3289
|
-
top_genes=[top_genes, top_genes] if isinstance(top_genes,int) else top_genes
|
3290
|
-
|
3456
|
+
top_genes = [top_genes, top_genes] if isinstance(top_genes, int) else top_genes
|
3457
|
+
|
3291
3458
|
# could custom how to select the top genes, x: x has priority
|
3292
|
-
sort_by_x_y=[x,y] if sort_xy=="x" else [y,x]
|
3293
|
-
ascending_up=[True, True] if sort_xy=="x" else [False, True]
|
3294
|
-
ascending_down=[False, True] if sort_xy=="x" else [False, False]
|
3295
|
-
|
3296
|
-
down_reg_genes =
|
3297
|
-
(data["color"] == colors[0]) &
|
3298
|
-
|
3299
|
-
(
|
3300
|
-
|
3301
|
-
up_reg_genes =
|
3302
|
-
(data["color"] == colors[2]) &
|
3303
|
-
|
3304
|
-
(
|
3305
|
-
|
3459
|
+
sort_by_x_y = [x, y] if sort_xy == "x" else [y, x]
|
3460
|
+
ascending_up = [True, True] if sort_xy == "x" else [False, True]
|
3461
|
+
ascending_down = [False, True] if sort_xy == "x" else [False, False]
|
3462
|
+
|
3463
|
+
down_reg_genes = (
|
3464
|
+
data[(data["color"] == colors[0]) & (data[x].abs() > thr_x) & (data[y] > thr_y)]
|
3465
|
+
.sort_values(by=sort_by_x_y, ascending=ascending_up)
|
3466
|
+
.head(top_genes[0])
|
3467
|
+
)
|
3468
|
+
up_reg_genes = (
|
3469
|
+
data[(data["color"] == colors[2]) & (data[x].abs() > thr_x) & (data[y] > thr_y)]
|
3470
|
+
.sort_values(by=sort_by_x_y, ascending=ascending_down)
|
3471
|
+
.head(top_genes[1])
|
3472
|
+
)
|
3306
3473
|
sele_gene = pd.concat([down_reg_genes, up_reg_genes])
|
3307
|
-
|
3308
|
-
palette = {colors[0]: colors[0],
|
3309
|
-
colors[1]: colors[1],
|
3310
|
-
colors[2]: colors[2]}
|
3474
|
+
|
3475
|
+
palette = {colors[0]: colors[0], colors[1]: colors[1], colors[2]: colors[2]}
|
3311
3476
|
# Plot setup
|
3312
3477
|
if ax is None:
|
3313
3478
|
ax = plt.gca()
|
@@ -3337,9 +3502,9 @@ def volcano(
|
|
3337
3502
|
)
|
3338
3503
|
|
3339
3504
|
# Add threshold lines for x and y axes
|
3340
|
-
ax.axhline(y=thr_y, color="black", linestyle="--",lw=1)
|
3341
|
-
ax.axvline(x=-thr_x, color="black", linestyle="--",lw=1)
|
3342
|
-
ax.axvline(x=thr_x, color="black", linestyle="--",lw=1)
|
3505
|
+
ax.axhline(y=thr_y, color="black", linestyle="--", lw=1)
|
3506
|
+
ax.axvline(x=-thr_x, color="black", linestyle="--", lw=1)
|
3507
|
+
ax.axvline(x=thr_x, color="black", linestyle="--", lw=1)
|
3343
3508
|
|
3344
3509
|
# Add gene labels for selected significant points
|
3345
3510
|
if gene_col:
|
@@ -3349,19 +3514,28 @@ def volcano(
|
|
3349
3514
|
textcolor = kws_text.pop("color", "k")
|
3350
3515
|
fontsize = kws_text.pop("fontsize", 10)
|
3351
3516
|
arrowstyles = [
|
3352
|
-
"->",
|
3353
|
-
"
|
3354
|
-
"
|
3517
|
+
"->",
|
3518
|
+
"<-",
|
3519
|
+
"<->",
|
3520
|
+
"<|-",
|
3521
|
+
"-|>",
|
3522
|
+
"<|-|>",
|
3523
|
+
"-",
|
3524
|
+
"-[",
|
3525
|
+
"-[",
|
3526
|
+
"fancy",
|
3527
|
+
"simple",
|
3528
|
+
"wedge",
|
3355
3529
|
]
|
3356
3530
|
arrowstyle = kws_arrow.pop("style", "<|-")
|
3357
|
-
arrowstyle = strcmp(arrowstyle, arrowstyles,scorer=
|
3358
|
-
expand=kws_arrow.pop("expand",(1.05,1.1))
|
3531
|
+
arrowstyle = strcmp(arrowstyle, arrowstyles, scorer="strict")[0]
|
3532
|
+
expand = kws_arrow.pop("expand", (1.05, 1.1))
|
3359
3533
|
arrowcolor = kws_arrow.pop("color", "0.4")
|
3360
3534
|
arrowlinewidth = kws_arrow.pop("lw", 0.75)
|
3361
3535
|
shrinkA = kws_arrow.pop("shrinkA", 0)
|
3362
3536
|
shrinkB = kws_arrow.pop("shrinkB", 0)
|
3363
3537
|
mutation_scale = kws_arrow.pop("head", 10)
|
3364
|
-
arrow_fill=kws_arrow.pop("fill", False)
|
3538
|
+
arrow_fill = kws_arrow.pop("fill", False)
|
3365
3539
|
for i in range(sele_gene.shape[0]):
|
3366
3540
|
if isinstance(textcolor, list): # be consistant with dots's color
|
3367
3541
|
textcolor = colors[0] if sele_gene[x].iloc[i] > 0 else colors[1]
|
@@ -3382,7 +3556,7 @@ def volcano(
|
|
3382
3556
|
adjust_text(
|
3383
3557
|
texts,
|
3384
3558
|
expand=expand,
|
3385
|
-
min_arrow_len=5,
|
3559
|
+
min_arrow_len=5,
|
3386
3560
|
ax=ax,
|
3387
3561
|
arrowprops=dict(
|
3388
3562
|
arrowstyle=arrowstyle,
|
@@ -3393,7 +3567,7 @@ def volcano(
|
|
3393
3567
|
shrinkB=shrinkB,
|
3394
3568
|
mutation_scale=mutation_scale,
|
3395
3569
|
**kws_arrow,
|
3396
|
-
)
|
3570
|
+
),
|
3397
3571
|
)
|
3398
3572
|
|
3399
3573
|
figsets(**kws_figsets)
|
@@ -3499,6 +3673,7 @@ def desaturate_color(color, saturation_factor=0.5):
|
|
3499
3673
|
"""Reduce the saturation of a color by a given factor (between 0 and 1)."""
|
3500
3674
|
import matplotlib.colors as mcolors
|
3501
3675
|
import colorsys
|
3676
|
+
|
3502
3677
|
# Convert the color to RGB
|
3503
3678
|
rgb = mcolors.to_rgb(color)
|
3504
3679
|
# Convert RGB to HLS (Hue, Lightness, Saturation)
|
@@ -3508,8 +3683,19 @@ def desaturate_color(color, saturation_factor=0.5):
|
|
3508
3683
|
# Convert back to RGB
|
3509
3684
|
return colorsys.hls_to_rgb(h, l, s)
|
3510
3685
|
|
3511
|
-
|
3512
|
-
|
3686
|
+
|
3687
|
+
def textsets(
|
3688
|
+
text,
|
3689
|
+
fontname="Arial",
|
3690
|
+
fontsize=11,
|
3691
|
+
fontweight="normal",
|
3692
|
+
fontstyle="normal",
|
3693
|
+
fontcolor="k",
|
3694
|
+
backgroundcolor=None,
|
3695
|
+
shadow=False,
|
3696
|
+
ha="center",
|
3697
|
+
va="center",
|
3698
|
+
):
|
3513
3699
|
if text: # Ensure text exists
|
3514
3700
|
if fontname:
|
3515
3701
|
text.set_fontname(plt_font(fontname))
|
@@ -3522,26 +3708,28 @@ def textsets(text, fontname='Arial', fontsize=11, fontweight="normal", fontstyle
|
|
3522
3708
|
if fontcolor:
|
3523
3709
|
text.set_color(fontcolor)
|
3524
3710
|
if backgroundcolor:
|
3525
|
-
text.set_backgroundcolor(backgroundcolor)
|
3711
|
+
text.set_backgroundcolor(backgroundcolor)
|
3526
3712
|
text.set_horizontalalignment(ha)
|
3527
|
-
text.set_verticalalignment(va)
|
3713
|
+
text.set_verticalalignment(va)
|
3528
3714
|
if shadow:
|
3529
|
-
text.set_path_effects(
|
3530
|
-
matplotlib.patheffects.withStroke(linewidth=3, foreground="gray")
|
3531
|
-
|
3715
|
+
text.set_path_effects(
|
3716
|
+
[matplotlib.patheffects.withStroke(linewidth=3, foreground="gray")]
|
3717
|
+
)
|
3718
|
+
|
3719
|
+
|
3532
3720
|
def venn(
|
3533
|
-
lists:list,
|
3534
|
-
labels:list=None,
|
3721
|
+
lists: list,
|
3722
|
+
labels: list = None,
|
3535
3723
|
ax=None,
|
3536
3724
|
colors=None,
|
3537
3725
|
edgecolor=None,
|
3538
3726
|
alpha=0.5,
|
3539
|
-
saturation
|
3540
|
-
linewidth=0,
|
3727
|
+
saturation=0.75,
|
3728
|
+
linewidth=0, # default no edge
|
3541
3729
|
linestyle="-",
|
3542
|
-
fontname=
|
3730
|
+
fontname="Arial",
|
3543
3731
|
fontsize=10,
|
3544
|
-
fontcolor=
|
3732
|
+
fontcolor="k",
|
3545
3733
|
fontweight="normal",
|
3546
3734
|
fontstyle="normal",
|
3547
3735
|
ha="center",
|
@@ -3550,18 +3738,18 @@ def venn(
|
|
3550
3738
|
subset_fontsize=10,
|
3551
3739
|
subset_fontweight="normal",
|
3552
3740
|
subset_fontstyle="normal",
|
3553
|
-
subset_fontcolor=
|
3741
|
+
subset_fontcolor="k",
|
3554
3742
|
backgroundcolor=None,
|
3555
3743
|
custom_texts=None,
|
3556
|
-
show_percentages=True,
|
3744
|
+
show_percentages=True, # display percentage
|
3557
3745
|
fmt="{:.1%}",
|
3558
3746
|
ellipse_shape=False, # 椭圆形
|
3559
|
-
ellipse_scale=[1.5, 1], #not perfect, 椭圆形的形状
|
3560
|
-
**kwargs
|
3747
|
+
ellipse_scale=[1.5, 1], # not perfect, 椭圆形的形状
|
3748
|
+
**kwargs,
|
3561
3749
|
):
|
3562
3750
|
"""
|
3563
3751
|
Advanced Venn diagram plotting function with extensive customization options.
|
3564
|
-
Usage:
|
3752
|
+
Usage:
|
3565
3753
|
# Define the two sets
|
3566
3754
|
set1 = [1, 2, 3, 4, 5]
|
3567
3755
|
set2 = [4, 5, 6, 7, 8]
|
@@ -3612,56 +3800,70 @@ def venn(
|
|
3612
3800
|
"""
|
3613
3801
|
if ax is None:
|
3614
3802
|
ax = plt.gca()
|
3615
|
-
lists=[set(flatten(i, verbose=False)) for i in lists]
|
3803
|
+
lists = [set(flatten(i, verbose=False)) for i in lists]
|
3616
3804
|
# Function to apply text styles to labels
|
3617
3805
|
if colors is None:
|
3618
|
-
colors=["r","b"] if len(lists)==2 else ["r","g","b"]
|
3806
|
+
colors = ["r", "b"] if len(lists) == 2 else ["r", "g", "b"]
|
3619
3807
|
if labels is None:
|
3620
|
-
labels=["set1","set2"] if len(lists)==2 else ["set1","set2","set3"]
|
3808
|
+
labels = ["set1", "set2"] if len(lists) == 2 else ["set1", "set2", "set3"]
|
3621
3809
|
if edgecolor is None:
|
3622
|
-
edgecolor=colors
|
3810
|
+
edgecolor = colors
|
3623
3811
|
colors = [desaturate_color(color, saturation) for color in colors]
|
3624
3812
|
# Check colors and auto-calculate overlaps
|
3625
3813
|
if len(lists) == 2:
|
3814
|
+
|
3626
3815
|
def get_count_and_percentage(set_count, subset_count):
|
3627
3816
|
percent = subset_count / set_count if set_count > 0 else 0
|
3628
|
-
return
|
3817
|
+
return (
|
3818
|
+
f"{subset_count}\n({fmt.format(percent)})"
|
3819
|
+
if show_percentages
|
3820
|
+
else f"{subset_count}"
|
3821
|
+
)
|
3629
3822
|
|
3630
3823
|
from matplotlib_venn import venn2, venn2_circles
|
3824
|
+
|
3631
3825
|
# Auto-calculate overlap color for 2-set Venn diagram
|
3632
3826
|
overlap_color = get_color_overlap(colors[0], colors[1]) if colors else None
|
3633
|
-
|
3827
|
+
|
3634
3828
|
# Draw the venn diagram
|
3635
3829
|
v = venn2(subsets=lists, set_labels=labels, ax=ax, **kwargs)
|
3636
3830
|
venn_circles = venn2_circles(subsets=lists, ax=ax)
|
3637
|
-
set1,set2=lists[0],lists[1]
|
3638
|
-
v.get_patch_by_id(
|
3639
|
-
v.get_patch_by_id(
|
3640
|
-
v.get_patch_by_id(
|
3831
|
+
set1, set2 = lists[0], lists[1]
|
3832
|
+
v.get_patch_by_id("10").set_color(colors[0])
|
3833
|
+
v.get_patch_by_id("01").set_color(colors[1])
|
3834
|
+
v.get_patch_by_id("11").set_color(
|
3835
|
+
get_color_overlap(colors[0], colors[1]) if colors else None
|
3836
|
+
)
|
3641
3837
|
# v.get_label_by_id('10').set_text(len(set1 - set2))
|
3642
3838
|
# v.get_label_by_id('01').set_text(len(set2 - set1))
|
3643
3839
|
# v.get_label_by_id('11').set_text(len(set1 & set2))
|
3644
|
-
v.get_label_by_id(
|
3645
|
-
|
3646
|
-
|
3647
|
-
|
3840
|
+
v.get_label_by_id("10").set_text(
|
3841
|
+
get_count_and_percentage(len(set1), len(set1 - set2))
|
3842
|
+
)
|
3843
|
+
v.get_label_by_id("01").set_text(
|
3844
|
+
get_count_and_percentage(len(set2), len(set2 - set1))
|
3845
|
+
)
|
3846
|
+
v.get_label_by_id("11").set_text(
|
3847
|
+
get_count_and_percentage(len(set1 | set2), len(set1 & set2))
|
3848
|
+
)
|
3648
3849
|
|
3649
|
-
if not isinstance(linewidth,list):
|
3650
|
-
linewidth=[linewidth]
|
3651
|
-
if isinstance(linestyle,str):
|
3652
|
-
linestyle=[linestyle]
|
3850
|
+
if not isinstance(linewidth, list):
|
3851
|
+
linewidth = [linewidth]
|
3852
|
+
if isinstance(linestyle, str):
|
3853
|
+
linestyle = [linestyle]
|
3653
3854
|
if not isinstance(edgecolor, list):
|
3654
|
-
edgecolor=[edgecolor]
|
3655
|
-
linewidth=linewidth*2 if len(linewidth)==1 else linewidth
|
3656
|
-
linestyle=linestyle*2 if len(linestyle)==1 else linestyle
|
3657
|
-
edgecolor=edgecolor*2 if len(edgecolor)==1 else edgecolor
|
3855
|
+
edgecolor = [edgecolor]
|
3856
|
+
linewidth = linewidth * 2 if len(linewidth) == 1 else linewidth
|
3857
|
+
linestyle = linestyle * 2 if len(linestyle) == 1 else linestyle
|
3858
|
+
edgecolor = edgecolor * 2 if len(edgecolor) == 1 else edgecolor
|
3658
3859
|
for i in range(2):
|
3659
3860
|
venn_circles[i].set_lw(linewidth[i])
|
3660
3861
|
venn_circles[i].set_ls(linestyle[i])
|
3661
3862
|
venn_circles[i].set_edgecolor(edgecolor[i])
|
3662
3863
|
# 椭圆
|
3663
3864
|
if ellipse_shape:
|
3664
|
-
import matplotlib.patches as patches
|
3865
|
+
import matplotlib.patches as patches
|
3866
|
+
|
3665
3867
|
for patch in v.patches:
|
3666
3868
|
patch.set_visible(False) # Hide original patches if using ellipses
|
3667
3869
|
center1 = v.get_circle_center(0)
|
@@ -3672,9 +3874,13 @@ def venn(
|
|
3672
3874
|
height=ellipse_scale[1],
|
3673
3875
|
edgecolor=edgecolor[0] if edgecolor else colors[0],
|
3674
3876
|
facecolor=colors[0],
|
3675
|
-
lw=
|
3877
|
+
lw=(
|
3878
|
+
linewidth if isinstance(linewidth, (int, float)) else 1.0
|
3879
|
+
), # Ensure lw is a number
|
3676
3880
|
ls=linestyle[0],
|
3677
|
-
alpha=
|
3881
|
+
alpha=(
|
3882
|
+
alpha if isinstance(alpha, (int, float)) else 0.5
|
3883
|
+
), # Ensure alpha is a number
|
3678
3884
|
)
|
3679
3885
|
ellipse2 = patches.Ellipse(
|
3680
3886
|
(center2.x, center2.y),
|
@@ -3682,48 +3888,78 @@ def venn(
|
|
3682
3888
|
height=ellipse_scale[1],
|
3683
3889
|
edgecolor=edgecolor[1] if edgecolor else colors[1],
|
3684
3890
|
facecolor=colors[1],
|
3685
|
-
lw=
|
3891
|
+
lw=(
|
3892
|
+
linewidth if isinstance(linewidth, (int, float)) else 1.0
|
3893
|
+
), # Ensure lw is a number
|
3686
3894
|
ls=linestyle[0],
|
3687
|
-
alpha=
|
3895
|
+
alpha=(
|
3896
|
+
alpha if isinstance(alpha, (int, float)) else 0.5
|
3897
|
+
), # Ensure alpha is a number
|
3688
3898
|
)
|
3689
3899
|
ax.add_patch(ellipse1)
|
3690
3900
|
ax.add_patch(ellipse2)
|
3691
3901
|
# Apply styles to set labels
|
3692
3902
|
for i, text in enumerate(v.set_labels):
|
3693
|
-
textsets(
|
3694
|
-
|
3903
|
+
textsets(
|
3904
|
+
text,
|
3905
|
+
fontname=fontname,
|
3906
|
+
fontsize=fontsize,
|
3907
|
+
fontweight=fontweight,
|
3908
|
+
fontstyle=fontstyle,
|
3909
|
+
fontcolor=fontcolor,
|
3910
|
+
ha=ha,
|
3911
|
+
va=va,
|
3912
|
+
shadow=shadow,
|
3913
|
+
)
|
3695
3914
|
|
3696
3915
|
# Apply styles to subset labels
|
3697
3916
|
for i, text in enumerate(v.subset_labels):
|
3698
3917
|
if text: # Ensure text exists
|
3699
3918
|
if custom_texts: # Custom text handling
|
3700
|
-
text.set_text(custom_texts[i])
|
3701
|
-
textsets(
|
3702
|
-
|
3919
|
+
text.set_text(custom_texts[i])
|
3920
|
+
textsets(
|
3921
|
+
text,
|
3922
|
+
fontname=fontname,
|
3923
|
+
fontsize=subset_fontsize,
|
3924
|
+
fontweight=subset_fontweight,
|
3925
|
+
fontstyle=subset_fontstyle,
|
3926
|
+
fontcolor=subset_fontcolor,
|
3927
|
+
ha=ha,
|
3928
|
+
va=va,
|
3929
|
+
shadow=shadow,
|
3930
|
+
)
|
3703
3931
|
|
3704
3932
|
elif len(lists) == 3:
|
3933
|
+
|
3705
3934
|
def get_label(set_count, subset_count):
|
3706
3935
|
percent = subset_count / set_count if set_count > 0 else 0
|
3707
|
-
return
|
3708
|
-
|
3936
|
+
return (
|
3937
|
+
f"{subset_count}\n({fmt.format(percent)})"
|
3938
|
+
if show_percentages
|
3939
|
+
else f"{subset_count}"
|
3940
|
+
)
|
3941
|
+
|
3709
3942
|
from matplotlib_venn import venn3, venn3_circles
|
3943
|
+
|
3710
3944
|
# Auto-calculate overlap colors for 3-set Venn diagram
|
3711
3945
|
colorAB = get_color_overlap(colors[0], colors[1]) if colors else None
|
3712
3946
|
colorAC = get_color_overlap(colors[0], colors[2]) if colors else None
|
3713
3947
|
colorBC = get_color_overlap(colors[1], colors[2]) if colors else None
|
3714
|
-
colorABC =
|
3715
|
-
|
3948
|
+
colorABC = (
|
3949
|
+
get_color_overlap(colors[0], colors[1], colors[2]) if colors else None
|
3950
|
+
)
|
3951
|
+
set1, set2, set3 = lists[0], lists[1], lists[2]
|
3716
3952
|
|
3717
3953
|
# Draw the venn diagram
|
3718
|
-
v = venn3(subsets=lists, set_labels=labels, ax=ax
|
3719
|
-
v.get_patch_by_id(
|
3720
|
-
v.get_patch_by_id(
|
3721
|
-
v.get_patch_by_id(
|
3722
|
-
v.get_patch_by_id(
|
3723
|
-
v.get_patch_by_id(
|
3724
|
-
v.get_patch_by_id(
|
3725
|
-
v.get_patch_by_id(
|
3726
|
-
|
3954
|
+
v = venn3(subsets=lists, set_labels=labels, ax=ax, **kwargs)
|
3955
|
+
v.get_patch_by_id("100").set_color(colors[0])
|
3956
|
+
v.get_patch_by_id("010").set_color(colors[1])
|
3957
|
+
v.get_patch_by_id("001").set_color(colors[2])
|
3958
|
+
v.get_patch_by_id("110").set_color(colorAB)
|
3959
|
+
v.get_patch_by_id("101").set_color(colorAC)
|
3960
|
+
v.get_patch_by_id("011").set_color(colorBC)
|
3961
|
+
v.get_patch_by_id("111").set_color(colorABC)
|
3962
|
+
|
3727
3963
|
# Correctly labeling subset sizes
|
3728
3964
|
# v.get_label_by_id('100').set_text(len(set1 - set2 - set3))
|
3729
3965
|
# v.get_label_by_id('010').set_text(len(set2 - set1 - set3))
|
@@ -3732,63 +3968,93 @@ def venn(
|
|
3732
3968
|
# v.get_label_by_id('101').set_text(len(set1 & set3 - set2))
|
3733
3969
|
# v.get_label_by_id('011').set_text(len(set2 & set3 - set1))
|
3734
3970
|
# v.get_label_by_id('111').set_text(len(set1 & set2 & set3))
|
3735
|
-
v.get_label_by_id(
|
3736
|
-
v.get_label_by_id(
|
3737
|
-
v.get_label_by_id(
|
3738
|
-
v.get_label_by_id(
|
3739
|
-
|
3740
|
-
|
3741
|
-
v.get_label_by_id(
|
3742
|
-
|
3971
|
+
v.get_label_by_id("100").set_text(get_label(len(set1), len(set1 - set2 - set3)))
|
3972
|
+
v.get_label_by_id("010").set_text(get_label(len(set2), len(set2 - set1 - set3)))
|
3973
|
+
v.get_label_by_id("001").set_text(get_label(len(set3), len(set3 - set1 - set2)))
|
3974
|
+
v.get_label_by_id("110").set_text(
|
3975
|
+
get_label(len(set1 | set2), len(set1 & set2 - set3))
|
3976
|
+
)
|
3977
|
+
v.get_label_by_id("101").set_text(
|
3978
|
+
get_label(len(set1 | set3), len(set1 & set3 - set2))
|
3979
|
+
)
|
3980
|
+
v.get_label_by_id("011").set_text(
|
3981
|
+
get_label(len(set2 | set3), len(set2 & set3 - set1))
|
3982
|
+
)
|
3983
|
+
v.get_label_by_id("111").set_text(
|
3984
|
+
get_label(len(set1 | set2 | set3), len(set1 & set2 & set3))
|
3985
|
+
)
|
3743
3986
|
|
3744
3987
|
# Apply styles to set labels
|
3745
3988
|
for i, text in enumerate(v.set_labels):
|
3746
|
-
textsets(
|
3747
|
-
|
3989
|
+
textsets(
|
3990
|
+
text,
|
3991
|
+
fontname=fontname,
|
3992
|
+
fontsize=fontsize,
|
3993
|
+
fontweight=fontweight,
|
3994
|
+
fontstyle=fontstyle,
|
3995
|
+
fontcolor=fontcolor,
|
3996
|
+
ha=ha,
|
3997
|
+
va=va,
|
3998
|
+
shadow=shadow,
|
3999
|
+
)
|
3748
4000
|
|
3749
4001
|
# Apply styles to subset labels
|
3750
4002
|
for i, text in enumerate(v.subset_labels):
|
3751
4003
|
if text: # Ensure text exists
|
3752
4004
|
if custom_texts: # Custom text handling
|
3753
4005
|
text.set_text(custom_texts[i])
|
3754
|
-
textsets(
|
3755
|
-
|
4006
|
+
textsets(
|
4007
|
+
text,
|
4008
|
+
fontname=fontname,
|
4009
|
+
fontsize=subset_fontsize,
|
4010
|
+
fontweight=subset_fontweight,
|
4011
|
+
fontstyle=subset_fontstyle,
|
4012
|
+
fontcolor=subset_fontcolor,
|
4013
|
+
ha=ha,
|
4014
|
+
va=va,
|
4015
|
+
shadow=shadow,
|
4016
|
+
)
|
3756
4017
|
|
3757
4018
|
venn_circles = venn3_circles(subsets=lists, ax=ax)
|
3758
|
-
if not isinstance(linewidth,list):
|
3759
|
-
linewidth=[linewidth]
|
3760
|
-
if isinstance(linestyle,str):
|
3761
|
-
linestyle=[linestyle]
|
4019
|
+
if not isinstance(linewidth, list):
|
4020
|
+
linewidth = [linewidth]
|
4021
|
+
if isinstance(linestyle, str):
|
4022
|
+
linestyle = [linestyle]
|
3762
4023
|
if not isinstance(edgecolor, list):
|
3763
|
-
edgecolor=[edgecolor]
|
3764
|
-
linewidth=linewidth*3 if len(linewidth)==1 else linewidth
|
3765
|
-
linestyle=linestyle*3 if len(linestyle)==1 else linestyle
|
3766
|
-
edgecolor=edgecolor*3 if len(edgecolor)==1 else edgecolor
|
4024
|
+
edgecolor = [edgecolor]
|
4025
|
+
linewidth = linewidth * 3 if len(linewidth) == 1 else linewidth
|
4026
|
+
linestyle = linestyle * 3 if len(linestyle) == 1 else linestyle
|
4027
|
+
edgecolor = edgecolor * 3 if len(edgecolor) == 1 else edgecolor
|
3767
4028
|
|
3768
4029
|
# edgecolor=[to_rgba(i) for i in edgecolor]
|
3769
4030
|
|
3770
|
-
for i in range(3):
|
4031
|
+
for i in range(3):
|
3771
4032
|
venn_circles[i].set_lw(linewidth[i])
|
3772
|
-
venn_circles[i].set_ls(linestyle[i])
|
3773
|
-
venn_circles[i].set_edgecolor(edgecolor[i])
|
4033
|
+
venn_circles[i].set_ls(linestyle[i])
|
4034
|
+
venn_circles[i].set_edgecolor(edgecolor[i])
|
3774
4035
|
|
3775
|
-
|
4036
|
+
# 椭圆形
|
3776
4037
|
if ellipse_shape:
|
3777
|
-
import matplotlib.patches as patches
|
4038
|
+
import matplotlib.patches as patches
|
4039
|
+
|
3778
4040
|
for patch in v.patches:
|
3779
4041
|
patch.set_visible(False) # Hide original patches if using ellipses
|
3780
4042
|
center1 = v.get_circle_center(0)
|
3781
4043
|
center2 = v.get_circle_center(1)
|
3782
|
-
center3 = v.get_circle_center(2)
|
4044
|
+
center3 = v.get_circle_center(2)
|
3783
4045
|
ellipse1 = patches.Ellipse(
|
3784
4046
|
(center1.x, center1.y),
|
3785
4047
|
width=ellipse_scale[0],
|
3786
4048
|
height=ellipse_scale[1],
|
3787
4049
|
edgecolor=edgecolor[0] if edgecolor else colors[0],
|
3788
4050
|
facecolor=colors[0],
|
3789
|
-
lw=
|
4051
|
+
lw=(
|
4052
|
+
linewidth if isinstance(linewidth, (int, float)) else 1.0
|
4053
|
+
), # Ensure lw is a number
|
3790
4054
|
ls=linestyle[0],
|
3791
|
-
alpha=
|
4055
|
+
alpha=(
|
4056
|
+
alpha if isinstance(alpha, (int, float)) else 0.5
|
4057
|
+
), # Ensure alpha is a number
|
3792
4058
|
)
|
3793
4059
|
ellipse2 = patches.Ellipse(
|
3794
4060
|
(center2.x, center2.y),
|
@@ -3796,9 +4062,13 @@ def venn(
|
|
3796
4062
|
height=ellipse_scale[1],
|
3797
4063
|
edgecolor=edgecolor[1] if edgecolor else colors[1],
|
3798
4064
|
facecolor=colors[1],
|
3799
|
-
lw=
|
4065
|
+
lw=(
|
4066
|
+
linewidth if isinstance(linewidth, (int, float)) else 1.0
|
4067
|
+
), # Ensure lw is a number
|
3800
4068
|
ls=linestyle[0],
|
3801
|
-
alpha=
|
4069
|
+
alpha=(
|
4070
|
+
alpha if isinstance(alpha, (int, float)) else 0.5
|
4071
|
+
), # Ensure alpha is a number
|
3802
4072
|
)
|
3803
4073
|
ellipse3 = patches.Ellipse(
|
3804
4074
|
(center3.x, center3.y),
|
@@ -3806,9 +4076,13 @@ def venn(
|
|
3806
4076
|
height=ellipse_scale[1],
|
3807
4077
|
edgecolor=edgecolor[1] if edgecolor else colors[1],
|
3808
4078
|
facecolor=colors[1],
|
3809
|
-
lw=
|
4079
|
+
lw=(
|
4080
|
+
linewidth if isinstance(linewidth, (int, float)) else 1.0
|
4081
|
+
), # Ensure lw is a number
|
3810
4082
|
ls=linestyle[0],
|
3811
|
-
alpha=
|
4083
|
+
alpha=(
|
4084
|
+
alpha if isinstance(alpha, (int, float)) else 0.5
|
4085
|
+
), # Ensure alpha is a number
|
3812
4086
|
)
|
3813
4087
|
ax.add_patch(ellipse1)
|
3814
4088
|
ax.add_patch(ellipse2)
|
@@ -3820,17 +4094,20 @@ def venn(
|
|
3820
4094
|
for patch in v.patches:
|
3821
4095
|
if patch:
|
3822
4096
|
patch.set_alpha(alpha)
|
3823
|
-
if
|
4097
|
+
if "none" in edgecolor or 0 in linewidth:
|
3824
4098
|
patch.set_edgecolor("none")
|
3825
4099
|
return ax
|
3826
4100
|
|
4101
|
+
|
3827
4102
|
#! subplots, support automatic extend new axis
|
3828
|
-
def subplot(
|
3829
|
-
|
3830
|
-
|
3831
|
-
|
3832
|
-
|
3833
|
-
|
4103
|
+
def subplot(
|
4104
|
+
rows: int = 2,
|
4105
|
+
cols: int = 2,
|
4106
|
+
figsize: Union[tuple, list] = [8, 8],
|
4107
|
+
sharex=False,
|
4108
|
+
sharey=False,
|
4109
|
+
**kwargs,
|
4110
|
+
):
|
3834
4111
|
"""
|
3835
4112
|
nexttile = subplot(
|
3836
4113
|
8,
|
@@ -3850,11 +4127,14 @@ def subplot(rows:int=2,
|
|
3850
4127
|
ax.set_xlabel(f"Tile {i + 1}")
|
3851
4128
|
"""
|
3852
4129
|
from matplotlib.gridspec import GridSpec
|
4130
|
+
|
3853
4131
|
if run_once_within():
|
3854
|
-
print(
|
4132
|
+
print(
|
4133
|
+
f"usage:\n\tnexttile = subplot(2, 2, figsize=(5, 5), sharex=True, sharey=True)\n\tax = nexttile()"
|
4134
|
+
)
|
3855
4135
|
fig = plt.figure(figsize=figsize)
|
3856
4136
|
grid_spec = GridSpec(rows, cols, figure=fig)
|
3857
|
-
occupied = set()
|
4137
|
+
occupied = set()
|
3858
4138
|
row_first_axes = [None] * rows # Track the first axis in each row (for sharey)
|
3859
4139
|
col_first_axes = [None] * cols # Track the first axis in each column (for sharex)
|
3860
4140
|
|
@@ -3862,8 +4142,9 @@ def subplot(rows:int=2,
|
|
3862
4142
|
nonlocal rows, grid_spec
|
3863
4143
|
rows += 1 # Expands by adding a row
|
3864
4144
|
grid_spec = GridSpec(rows, cols, figure=fig)
|
4145
|
+
|
3865
4146
|
def nexttile(rowspan=1, colspan=1, **kwargs):
|
3866
|
-
nonlocal rows, cols, occupied, grid_spec
|
4147
|
+
nonlocal rows, cols, occupied, grid_spec
|
3867
4148
|
for row in range(rows):
|
3868
4149
|
for col in range(cols):
|
3869
4150
|
if all(
|
@@ -3873,29 +4154,29 @@ def subplot(rows:int=2,
|
|
3873
4154
|
):
|
3874
4155
|
break
|
3875
4156
|
else:
|
3876
|
-
continue
|
4157
|
+
continue
|
3877
4158
|
break
|
3878
|
-
else:
|
4159
|
+
else:
|
3879
4160
|
expand_ax()
|
3880
4161
|
return nexttile(rowspan=rowspan, colspan=colspan, **kwargs)
|
3881
4162
|
|
3882
|
-
sharex_ax,sharey_ax = None, None
|
4163
|
+
sharex_ax, sharey_ax = None, None
|
3883
4164
|
|
3884
|
-
if sharex:
|
4165
|
+
if sharex:
|
3885
4166
|
sharex_ax = col_first_axes[col]
|
3886
4167
|
|
3887
|
-
if sharey:
|
3888
|
-
sharey_ax = row_first_axes[row]
|
4168
|
+
if sharey:
|
4169
|
+
sharey_ax = row_first_axes[row]
|
3889
4170
|
ax = fig.add_subplot(
|
3890
4171
|
grid_spec[row : row + rowspan, col : col + colspan],
|
3891
4172
|
sharex=sharex_ax,
|
3892
4173
|
sharey=sharey_ax,
|
3893
|
-
**kwargs
|
3894
|
-
)
|
4174
|
+
**kwargs,
|
4175
|
+
)
|
3895
4176
|
if row_first_axes[row] is None:
|
3896
4177
|
row_first_axes[row] = ax
|
3897
4178
|
if col_first_axes[col] is None:
|
3898
|
-
col_first_axes[col] = ax
|
4179
|
+
col_first_axes[col] = ax
|
3899
4180
|
for r in range(row, row + rowspan):
|
3900
4181
|
for c in range(col, col + colspan):
|
3901
4182
|
occupied.add((r, c))
|
@@ -3907,17 +4188,17 @@ def subplot(rows:int=2,
|
|
3907
4188
|
|
3908
4189
|
#! radar chart
|
3909
4190
|
def radar(
|
3910
|
-
data: pd.DataFrame,
|
3911
|
-
ylim=(0,100),
|
4191
|
+
data: pd.DataFrame,
|
4192
|
+
ylim=(0, 100),
|
3912
4193
|
color=get_color(5),
|
3913
4194
|
fontsize=10,
|
3914
|
-
fontcolor=
|
4195
|
+
fontcolor="k",
|
3915
4196
|
size=6,
|
3916
4197
|
linewidth=1,
|
3917
4198
|
linestyle="-",
|
3918
4199
|
alpha=0.5,
|
3919
4200
|
marker="o",
|
3920
|
-
edgecolor=
|
4201
|
+
edgecolor="none",
|
3921
4202
|
edge_linewidth=0,
|
3922
4203
|
bg_color="0.8",
|
3923
4204
|
bg_alpha=None,
|
@@ -3928,18 +4209,19 @@ def radar(
|
|
3928
4209
|
legend_fontsize=10,
|
3929
4210
|
grid_color="gray",
|
3930
4211
|
grid_alpha=0.5,
|
3931
|
-
grid_linestyle="--",
|
4212
|
+
grid_linestyle="--",
|
4213
|
+
grid_linewidth=0.5,
|
3932
4214
|
circular: bool = False,
|
3933
4215
|
tick_fontsize=None,
|
3934
4216
|
tick_fontcolor="0.65",
|
3935
|
-
tick_loc
|
3936
|
-
turning
|
4217
|
+
tick_loc=None, # label position
|
4218
|
+
turning=None,
|
3937
4219
|
ax=None,
|
3938
4220
|
sp=2,
|
3939
|
-
**kwargs
|
4221
|
+
**kwargs,
|
3940
4222
|
):
|
3941
4223
|
"""
|
3942
|
-
Example DATA:
|
4224
|
+
Example DATA:
|
3943
4225
|
df = pd.DataFrame(
|
3944
4226
|
data=[
|
3945
4227
|
[80, 80, 80, 80, 80, 80, 80],
|
@@ -3948,7 +4230,7 @@ def radar(
|
|
3948
4230
|
],
|
3949
4231
|
index=["Hero", "Warrior", "Wizard"],
|
3950
4232
|
columns=["HP", "MP", "ATK", "DEF", "SP.ATK", "SP.DEF", "SPD"])
|
3951
|
-
|
4233
|
+
|
3952
4234
|
Parameters:
|
3953
4235
|
- data (pd.DataFrame): The data to plot. Each column corresponds to a variable, and each row represents a data point.
|
3954
4236
|
- ylim (tuple): The limits of the radial axis (y-axis). Default is (0, 100).
|
@@ -4002,8 +4284,12 @@ def radar(
|
|
4002
4284
|
|
4003
4285
|
# bg_color
|
4004
4286
|
if bg_alpha is None:
|
4005
|
-
bg_alpha=alpha
|
4006
|
-
|
4287
|
+
bg_alpha = alpha
|
4288
|
+
(
|
4289
|
+
ax.set_facecolor(to_rgba(bg_color, alpha=bg_alpha))
|
4290
|
+
if circular
|
4291
|
+
else ax.set_facecolor("none")
|
4292
|
+
)
|
4007
4293
|
# Set up the radar chart with straight-line connections
|
4008
4294
|
ax.set_theta_offset(np.pi / 2)
|
4009
4295
|
ax.set_theta_direction(-1)
|
@@ -4012,53 +4298,68 @@ def radar(
|
|
4012
4298
|
ax.set_xticks(angles[:-1])
|
4013
4299
|
ax.set_xticklabels(categories)
|
4014
4300
|
|
4015
|
-
# Set y-axis limits and grid intervals
|
4301
|
+
# Set y-axis limits and grid intervals
|
4016
4302
|
vmin, vmax = ylim
|
4017
|
-
if circular:
|
4018
|
-
|
4019
|
-
ax.yaxis.set_ticks(np.arange(vmin, vmax+1, vmax * grid_interval_ratio))
|
4020
|
-
ax.grid(
|
4021
|
-
|
4022
|
-
|
4023
|
-
|
4024
|
-
|
4025
|
-
|
4026
|
-
|
4027
|
-
|
4028
|
-
|
4303
|
+
if circular:
|
4304
|
+
# * cicular style
|
4305
|
+
ax.yaxis.set_ticks(np.arange(vmin, vmax + 1, vmax * grid_interval_ratio))
|
4306
|
+
ax.grid(
|
4307
|
+
axis="both",
|
4308
|
+
color=grid_color,
|
4309
|
+
linestyle=grid_linestyle,
|
4310
|
+
alpha=grid_alpha,
|
4311
|
+
linewidth=grid_linewidth,
|
4312
|
+
dash_capstyle="round",
|
4313
|
+
dash_joinstyle="round",
|
4314
|
+
)
|
4315
|
+
ax.spines["polar"].set_color(grid_color)
|
4029
4316
|
ax.spines["polar"].set_linewidth(grid_linewidth)
|
4030
|
-
ax.spines["polar"].set_linestyle(
|
4317
|
+
ax.spines["polar"].set_linestyle("-")
|
4031
4318
|
ax.spines["polar"].set_alpha(grid_alpha)
|
4032
|
-
ax.spines["polar"].set_capstyle(
|
4033
|
-
ax.spines["polar"].set_joinstyle(
|
4319
|
+
ax.spines["polar"].set_capstyle("round")
|
4320
|
+
ax.spines["polar"].set_joinstyle("round")
|
4034
4321
|
|
4035
4322
|
else:
|
4036
|
-
|
4323
|
+
# * spider style: spider-style grid (straight lines, not circles)
|
4037
4324
|
# Create the spider-style grid (straight lines, not circles)
|
4038
|
-
for i in range(1, int(vmax * grid_interval_ratio) + 1):
|
4325
|
+
for i in range(1, int(vmax * grid_interval_ratio) + 1):
|
4039
4326
|
ax.plot(
|
4040
4327
|
angles + [angles[0]], # Closing the loop
|
4041
|
-
[i * vmax * grid_interval_ratio] * (num_vars+1)
|
4042
|
-
|
4328
|
+
[i * vmax * grid_interval_ratio] * (num_vars + 1)
|
4329
|
+
+ [i * vmax * grid_interval_ratio],
|
4330
|
+
color=grid_color,
|
4331
|
+
linestyle=grid_linestyle,
|
4332
|
+
alpha=grid_alpha,
|
4333
|
+
linewidth=grid_linewidth,
|
4043
4334
|
)
|
4044
|
-
# set bg_color
|
4045
|
-
ax.fill(angles, [vmax]*(data.shape[1]+1), color=bg_color, alpha=bg_alpha)
|
4335
|
+
# set bg_color
|
4336
|
+
ax.fill(angles, [vmax] * (data.shape[1] + 1), color=bg_color, alpha=bg_alpha)
|
4046
4337
|
ax.yaxis.grid(False)
|
4047
4338
|
# Move radial labels away from plotted line
|
4048
4339
|
if tick_loc is None:
|
4049
|
-
tick_loc
|
4340
|
+
tick_loc = (
|
4341
|
+
np.mean([angles[0], angles[1]]) / (2 * np.pi) * 360 if circular else 0
|
4342
|
+
)
|
4050
4343
|
|
4051
4344
|
ax.set_rlabel_position(tick_loc)
|
4052
4345
|
ax.set_theta_offset(turning) if turning is not None else None
|
4053
|
-
ax.tick_params(
|
4054
|
-
|
4055
|
-
|
4346
|
+
ax.tick_params(
|
4347
|
+
axis="x", labelsize=fontsize, colors=fontcolor
|
4348
|
+
) # Optional: for angular labels
|
4349
|
+
tick_fontsize = fontsize - 2 if fontsize is None else tick_fontsize
|
4350
|
+
ax.tick_params(
|
4351
|
+
axis="y", labelsize=tick_fontsize, colors=tick_fontcolor
|
4352
|
+
) # For radial labels
|
4056
4353
|
if not circular:
|
4057
|
-
ax.spines[
|
4058
|
-
ax.tick_params(axis=
|
4059
|
-
ax.tick_params(axis=
|
4354
|
+
ax.spines["polar"].set_visible(False)
|
4355
|
+
ax.tick_params(axis="x", pad=sp) # move spines outward
|
4356
|
+
ax.tick_params(axis="y", pad=sp) # move spines outward
|
4060
4357
|
# colors
|
4061
|
-
colors =
|
4358
|
+
colors = (
|
4359
|
+
get_color(data.shape[0])
|
4360
|
+
if cmap is None
|
4361
|
+
else plt.get_cmap(cmap)(np.linspace(0, 1, data.shape[0]))
|
4362
|
+
)
|
4062
4363
|
# Plot each row with straight lines
|
4063
4364
|
for i, (index, row) in enumerate(data.iterrows()):
|
4064
4365
|
values = row.tolist()
|
@@ -4070,7 +4371,7 @@ def radar(
|
|
4070
4371
|
linewidth=linewidth,
|
4071
4372
|
linestyle=linestyle,
|
4072
4373
|
label=index,
|
4073
|
-
clip_on=False
|
4374
|
+
clip_on=False,
|
4074
4375
|
)
|
4075
4376
|
ax.fill(angles, values, color=colors[i], alpha=alpha)
|
4076
4377
|
|
@@ -4084,13 +4385,23 @@ def radar(
|
|
4084
4385
|
marker=marker,
|
4085
4386
|
markersize=size,
|
4086
4387
|
markeredgecolor=edgecolor,
|
4087
|
-
markeredgewidth
|
4388
|
+
markeredgewidth=edge_linewidth,
|
4389
|
+
zorder=10,
|
4390
|
+
clip_on=False,
|
4088
4391
|
)
|
4089
4392
|
# ax.tick_params(axis='y', labelleft=False, left=False)
|
4090
|
-
if
|
4393
|
+
if "legend" in kws_figsets:
|
4091
4394
|
figsets(ax=ax, **kws_figsets)
|
4092
4395
|
else:
|
4093
|
-
|
4094
|
-
figsets(
|
4095
|
-
|
4096
|
-
|
4396
|
+
|
4397
|
+
figsets(
|
4398
|
+
ax=ax,
|
4399
|
+
legend=dict(
|
4400
|
+
loc=legend_loc,
|
4401
|
+
fontsize=legend_fontsize,
|
4402
|
+
bbox_to_anchor=[1.1, 1.4],
|
4403
|
+
ncols=2,
|
4404
|
+
),
|
4405
|
+
**kws_figsets,
|
4406
|
+
)
|
4407
|
+
return ax
|