py2ls 0.2.4.10.4__py3-none-any.whl → 0.2.4.10.6__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
py2ls/plot.py
CHANGED
@@ -1,21 +1,37 @@
|
|
1
1
|
import numpy as np
|
2
|
-
import pandas as pd
|
2
|
+
import pandas as pd
|
3
3
|
import matplotlib.pyplot as plt
|
4
4
|
import matplotlib
|
5
5
|
import seaborn as sns
|
6
6
|
import warnings
|
7
7
|
import logging
|
8
8
|
from typing import Union
|
9
|
-
from .ips import
|
9
|
+
from .ips import (
|
10
|
+
isa,
|
11
|
+
fsave,
|
12
|
+
fload,
|
13
|
+
mkdir,
|
14
|
+
listdir,
|
15
|
+
figsave,
|
16
|
+
strcmp,
|
17
|
+
unique,
|
18
|
+
get_os,
|
19
|
+
ssplit,
|
20
|
+
flatten,
|
21
|
+
plt_font,
|
22
|
+
run_once_within,
|
23
|
+
)
|
10
24
|
from .stats import *
|
11
25
|
import os
|
26
|
+
|
12
27
|
# Suppress INFO messages from fontTools
|
13
28
|
logging.getLogger("fontTools").setLevel(logging.ERROR)
|
14
|
-
logging.getLogger(
|
29
|
+
logging.getLogger("matplotlib").setLevel(logging.ERROR)
|
15
30
|
|
16
31
|
warnings.simplefilter("ignore", category=pd.errors.SettingWithCopyWarning)
|
17
32
|
warnings.filterwarnings("ignore", category=pd.errors.PerformanceWarning)
|
18
33
|
|
34
|
+
|
19
35
|
def add_text(ax=None, height_offset=0.5, fmt=".1f", **kwargs):
|
20
36
|
"""Adds text annotations for various types of Seaborn and Matplotlib plots.
|
21
37
|
Args:
|
@@ -449,7 +465,7 @@ def catplot(data, *args, **kwargs):
|
|
449
465
|
"""
|
450
466
|
from matplotlib.colors import to_rgba
|
451
467
|
import os
|
452
|
-
|
468
|
+
|
453
469
|
def plot_bars(data, data_m, opt_b, xloc, ax, label=None):
|
454
470
|
if "l" in opt_b["loc"]:
|
455
471
|
xloc_s = xloc - opt_b["x_dist"]
|
@@ -755,6 +771,7 @@ def catplot(data, *args, **kwargs):
|
|
755
771
|
)
|
756
772
|
else:
|
757
773
|
from scipy.stats import gaussian_kde
|
774
|
+
|
758
775
|
kde = gaussian_kde(ys, bw_method=opt_v["BandWidth"])
|
759
776
|
min_val, max_val = ys.min(), ys.max()
|
760
777
|
y_vals = np.linspace(min_val, max_val, opt_v["NumPoints"])
|
@@ -1812,16 +1829,17 @@ def read_mplstyle(style_file):
|
|
1812
1829
|
return style_dict
|
1813
1830
|
|
1814
1831
|
|
1815
|
-
def figsets(*args, **kwargs):
|
1832
|
+
def figsets(*args, **kwargs):
|
1816
1833
|
import matplotlib
|
1817
1834
|
from cycler import cycler
|
1835
|
+
|
1818
1836
|
matplotlib.rc("text", usetex=False)
|
1819
1837
|
|
1820
1838
|
fig = plt.gcf()
|
1821
|
-
fontsize = kwargs.get("fontsize",11)
|
1822
|
-
plt.rcParams["font.size"]=fontsize
|
1823
|
-
fontname = kwargs.pop("fontname","Arial")
|
1824
|
-
fontname=plt_font(fontname)
|
1839
|
+
fontsize = kwargs.get("fontsize", 11)
|
1840
|
+
plt.rcParams["font.size"] = fontsize
|
1841
|
+
fontname = kwargs.pop("fontname", "Arial")
|
1842
|
+
fontname = plt_font(fontname) # 显示中文
|
1825
1843
|
|
1826
1844
|
sns_themes = ["white", "whitegrid", "dark", "darkgrid", "ticks"]
|
1827
1845
|
sns_contexts = ["notebook", "talk", "poster"] # now available "paper"
|
@@ -1851,18 +1869,21 @@ def figsets(*args, **kwargs):
|
|
1851
1869
|
nonlocal fontsize, fontname
|
1852
1870
|
if ("fo" in key) and (("size" in key) or ("sz" in key)):
|
1853
1871
|
fontsize = value
|
1854
|
-
plt.rcParams.update(
|
1855
|
-
|
1856
|
-
|
1857
|
-
|
1858
|
-
|
1859
|
-
|
1860
|
-
|
1861
|
-
|
1862
|
-
|
1872
|
+
plt.rcParams.update(
|
1873
|
+
{
|
1874
|
+
"font.size": fontsize,
|
1875
|
+
"figure.titlesize": fontsize,
|
1876
|
+
"axes.titlesize": fontsize,
|
1877
|
+
"axes.labelsize": fontsize,
|
1878
|
+
"xtick.labelsize": fontsize,
|
1879
|
+
"ytick.labelsize": fontsize,
|
1880
|
+
"legend.fontsize": fontsize,
|
1881
|
+
"legend.title_fontsize": fontsize,
|
1882
|
+
}
|
1883
|
+
)
|
1863
1884
|
|
1864
1885
|
# Customize tick labels
|
1865
|
-
ax.tick_params(axis=
|
1886
|
+
ax.tick_params(axis="both", which="major", labelsize=fontsize)
|
1866
1887
|
for label in ax.get_xticklabels() + ax.get_yticklabels():
|
1867
1888
|
label.set_fontname(fontname)
|
1868
1889
|
|
@@ -1910,15 +1931,15 @@ def figsets(*args, **kwargs):
|
|
1910
1931
|
if ("x" in key.lower()) and (
|
1911
1932
|
"tic" not in key.lower() and "tk" not in key.lower()
|
1912
1933
|
):
|
1913
|
-
ax.set_xlabel(value, fontname=fontname,fontsize=fontsize)
|
1934
|
+
ax.set_xlabel(value, fontname=fontname, fontsize=fontsize)
|
1914
1935
|
if ("y" in key.lower()) and (
|
1915
1936
|
"tic" not in key.lower() and "tk" not in key.lower()
|
1916
1937
|
):
|
1917
|
-
ax.set_ylabel(value, fontname=fontname,fontsize=fontsize)
|
1938
|
+
ax.set_ylabel(value, fontname=fontname, fontsize=fontsize)
|
1918
1939
|
if ("z" in key.lower()) and (
|
1919
1940
|
"tic" not in key.lower() and "tk" not in key.lower()
|
1920
1941
|
):
|
1921
|
-
ax.set_zlabel(value, fontname=fontname,fontsize=fontsize)
|
1942
|
+
ax.set_zlabel(value, fontname=fontname, fontsize=fontsize)
|
1922
1943
|
if key == "xlabel" and isinstance(value, dict):
|
1923
1944
|
ax.set_xlabel(**value)
|
1924
1945
|
if key == "ylabel" and isinstance(value, dict):
|
@@ -2110,6 +2131,7 @@ def figsets(*args, **kwargs):
|
|
2110
2131
|
|
2111
2132
|
if "mi" in key.lower() and "tic" in key.lower(): # minor_ticks
|
2112
2133
|
import matplotlib.ticker as tck
|
2134
|
+
|
2113
2135
|
if "x" in value.lower() or "x" in key.lower():
|
2114
2136
|
ax.xaxis.set_minor_locator(tck.AutoMinorLocator()) # ax.minorticks_on()
|
2115
2137
|
if "y" in value.lower() or "y" in key.lower():
|
@@ -2162,9 +2184,9 @@ def figsets(*args, **kwargs):
|
|
2162
2184
|
ax.grid(visible=False)
|
2163
2185
|
if "tit" in key.lower():
|
2164
2186
|
if "sup" in key.lower():
|
2165
|
-
plt.suptitle(value,fontname=fontname,fontsize=fontsize)
|
2187
|
+
plt.suptitle(value, fontname=fontname, fontsize=fontsize)
|
2166
2188
|
else:
|
2167
|
-
ax.set_title(value,fontname=fontname,fontsize=fontsize)
|
2189
|
+
ax.set_title(value, fontname=fontname, fontsize=fontsize)
|
2168
2190
|
if key.lower() in ["spine", "adjust", "ad", "sp", "spi", "adj", "spines"]:
|
2169
2191
|
if isinstance(value, bool) or (value in ["go", "do", "ja", "yes"]):
|
2170
2192
|
if value:
|
@@ -2185,9 +2207,14 @@ def figsets(*args, **kwargs):
|
|
2185
2207
|
legend = ax.get_legend()
|
2186
2208
|
if legend is not None:
|
2187
2209
|
legend.remove()
|
2188
|
-
if
|
2210
|
+
if (
|
2211
|
+
any(["colorbar" in key.lower(), "cbar" in key.lower()])
|
2212
|
+
and "loc" in key.lower()
|
2213
|
+
):
|
2189
2214
|
cbar = ax.collections[0].colorbar # Access the colorbar from the plot
|
2190
|
-
cbar.ax.set_position(
|
2215
|
+
cbar.ax.set_position(
|
2216
|
+
value
|
2217
|
+
) # [left, bottom, width, height] [0.475, 0.15, 0.04, 0.25]
|
2191
2218
|
|
2192
2219
|
for arg in args:
|
2193
2220
|
if isinstance(arg, matplotlib.axes._axes.Axes):
|
@@ -2225,7 +2252,8 @@ def figsets(*args, **kwargs):
|
|
2225
2252
|
if len(fig.get_axes()) > 1:
|
2226
2253
|
plt.tight_layout()
|
2227
2254
|
|
2228
|
-
|
2255
|
+
|
2256
|
+
def split_legend(ax, n=2, loc=None, title=None, bbox=None, ncol=1, **kwargs):
|
2229
2257
|
"""
|
2230
2258
|
split_legend(
|
2231
2259
|
ax,
|
@@ -2238,15 +2266,15 @@ def split_legend(ax, n=2, loc=None, title=None, bbox=None,ncol=1, **kwargs):
|
|
2238
2266
|
# Retrieve all lines and labels from the axis
|
2239
2267
|
handles, labels = ax.get_legend_handles_labels()
|
2240
2268
|
num_labels = len(labels)
|
2241
|
-
|
2269
|
+
|
2242
2270
|
# Calculate the number of labels per legend part
|
2243
|
-
labels_per_part = (num_labels + n - 1) // n # Round up
|
2271
|
+
labels_per_part = (num_labels + n - 1) // n # Round up
|
2244
2272
|
# Create a list to hold each legend object
|
2245
2273
|
legends = []
|
2246
2274
|
|
2247
2275
|
# Default locations and titles if not specified
|
2248
2276
|
if loc is None:
|
2249
|
-
loc = [
|
2277
|
+
loc = ["best"] * n
|
2250
2278
|
if title is None:
|
2251
2279
|
title = [None] * n
|
2252
2280
|
if bbox is None:
|
@@ -2257,7 +2285,7 @@ def split_legend(ax, n=2, loc=None, title=None, bbox=None,ncol=1, **kwargs):
|
|
2257
2285
|
# Calculate the range of labels for this part
|
2258
2286
|
start_idx = i * labels_per_part
|
2259
2287
|
end_idx = min(start_idx + labels_per_part, num_labels)
|
2260
|
-
|
2288
|
+
|
2261
2289
|
# Skip if no labels in this range
|
2262
2290
|
if start_idx >= end_idx:
|
2263
2291
|
break
|
@@ -2267,19 +2295,24 @@ def split_legend(ax, n=2, loc=None, title=None, bbox=None,ncol=1, **kwargs):
|
|
2267
2295
|
part_labels = labels[start_idx:end_idx]
|
2268
2296
|
|
2269
2297
|
# Create the legend for this part
|
2270
|
-
legend = ax.legend(
|
2271
|
-
|
2272
|
-
|
2273
|
-
|
2274
|
-
|
2275
|
-
|
2276
|
-
|
2277
|
-
|
2298
|
+
legend = ax.legend(
|
2299
|
+
handles=part_handles,
|
2300
|
+
labels=part_labels,
|
2301
|
+
loc=loc[i],
|
2302
|
+
title=title[i],
|
2303
|
+
ncol=ncol,
|
2304
|
+
bbox_to_anchor=bbox[i],
|
2305
|
+
**kwargs,
|
2306
|
+
)
|
2307
|
+
|
2278
2308
|
# Add the legend to the axis and save it to the list
|
2279
|
-
|
2309
|
+
(
|
2310
|
+
ax.add_artist(legend) if i != (n - 1) else None
|
2311
|
+
) # the lastone will be added automaticaly
|
2280
2312
|
legends.append(legend)
|
2281
2313
|
return legends
|
2282
2314
|
|
2315
|
+
|
2283
2316
|
def get_colors(
|
2284
2317
|
n: int = 1,
|
2285
2318
|
cmap: str = "auto",
|
@@ -2289,7 +2322,9 @@ def get_colors(
|
|
2289
2322
|
*args,
|
2290
2323
|
**kwargs,
|
2291
2324
|
):
|
2292
|
-
return get_color(n,cmap,alpha,output
|
2325
|
+
return get_color(n, cmap, alpha, output, *args, **kwargs)
|
2326
|
+
|
2327
|
+
|
2293
2328
|
def get_color(
|
2294
2329
|
n: int = 1,
|
2295
2330
|
cmap: str = "auto",
|
@@ -2300,6 +2335,7 @@ def get_color(
|
|
2300
2335
|
**kwargs,
|
2301
2336
|
):
|
2302
2337
|
from cycler import cycler
|
2338
|
+
|
2303
2339
|
def cmap2hex(cmap_name):
|
2304
2340
|
cmap_ = matplotlib.pyplot.get_cmap(cmap_name)
|
2305
2341
|
colors = [cmap_(i) for i in range(cmap_.N)]
|
@@ -2347,49 +2383,137 @@ def get_color(
|
|
2347
2383
|
return "#{:02X}{:02X}{:02X}{:02X}".format(r, g, b, a)
|
2348
2384
|
else:
|
2349
2385
|
return "#{:02X}{:02X}{:02X}".format(r, g, b)
|
2350
|
-
|
2386
|
+
|
2351
2387
|
# sc.pl.palettes.default_20
|
2352
|
-
cmap_20= [
|
2353
|
-
|
2354
|
-
|
2388
|
+
cmap_20 = [
|
2389
|
+
"#1f77b4",
|
2390
|
+
"#ff7f0e",
|
2391
|
+
"#279e68",
|
2392
|
+
"#d62728",
|
2393
|
+
"#aa40fc",
|
2394
|
+
"#8c564b",
|
2395
|
+
"#e377c2",
|
2396
|
+
"#b5bd61",
|
2397
|
+
"#17becf",
|
2398
|
+
"#aec7e8",
|
2399
|
+
"#ffbb78",
|
2400
|
+
"#98df8a",
|
2401
|
+
"#ff9896",
|
2402
|
+
"#c5b0d5",
|
2403
|
+
"#c49c94",
|
2404
|
+
"#f7b6d2",
|
2405
|
+
"#dbdb8d",
|
2406
|
+
"#9edae5",
|
2407
|
+
"#ad494a",
|
2408
|
+
"#8c6d31",
|
2409
|
+
]
|
2355
2410
|
# sc.pl.palettes.zeileis_28
|
2356
|
-
cmap_28 = [
|
2357
|
-
|
2358
|
-
|
2359
|
-
|
2411
|
+
cmap_28 = [
|
2412
|
+
"#023fa5",
|
2413
|
+
"#7d87b9",
|
2414
|
+
"#bec1d4",
|
2415
|
+
"#d6bcc0",
|
2416
|
+
"#bb7784",
|
2417
|
+
"#8e063b",
|
2418
|
+
"#4a6fe3",
|
2419
|
+
"#8595e1",
|
2420
|
+
"#b5bbe3",
|
2421
|
+
"#e6afb9",
|
2422
|
+
"#e07b91",
|
2423
|
+
"#d33f6a",
|
2424
|
+
"#11c638",
|
2425
|
+
"#8dd593",
|
2426
|
+
"#c6dec7",
|
2427
|
+
"#ead3c6",
|
2428
|
+
"#f0b98d",
|
2429
|
+
"#ef9708",
|
2430
|
+
"#0fcfc0",
|
2431
|
+
"#9cded6",
|
2432
|
+
"#d5eae7",
|
2433
|
+
"#f3e1eb",
|
2434
|
+
"#f6c4e1",
|
2435
|
+
"#f79cd4",
|
2436
|
+
"#7f7f7f",
|
2437
|
+
"#c7c7c7",
|
2438
|
+
"#1CE6FF",
|
2439
|
+
"#336600",
|
2440
|
+
]
|
2360
2441
|
if cmap == "gray":
|
2361
2442
|
cmap = "grey"
|
2362
|
-
elif cmap=="20":
|
2363
|
-
cmap=cmap_20
|
2364
|
-
elif cmap=="28":
|
2365
|
-
cmap=cmap_28
|
2443
|
+
elif cmap == "20":
|
2444
|
+
cmap = cmap_20
|
2445
|
+
elif cmap == "28":
|
2446
|
+
cmap = cmap_28
|
2366
2447
|
# Determine color list based on cmap parameter
|
2367
|
-
if isinstance(cmap,str):
|
2448
|
+
if isinstance(cmap, str):
|
2368
2449
|
if "aut" in cmap:
|
2369
2450
|
if n == 1:
|
2370
2451
|
colorlist = ["#3A4453"]
|
2371
2452
|
elif n == 2:
|
2372
2453
|
colorlist = ["#3A4453", "#FF2C00"]
|
2373
2454
|
elif n == 3:
|
2374
|
-
colorlist = ["#66c2a5","#fc8d62","#8da0cb"]
|
2455
|
+
colorlist = ["#66c2a5", "#fc8d62", "#8da0cb"]
|
2375
2456
|
elif n == 4:
|
2376
|
-
colorlist = ["#FF2C00","#087cf7", "#FBAF63", "#3C898A"]
|
2457
|
+
colorlist = ["#FF2C00", "#087cf7", "#FBAF63", "#3C898A"]
|
2377
2458
|
elif n == 5:
|
2378
|
-
colorlist = ["#FF2C00","#459AA9", "#B25E9D", "#4B8C3B","#EF8632"]
|
2459
|
+
colorlist = ["#FF2C00", "#459AA9", "#B25E9D", "#4B8C3B", "#EF8632"]
|
2379
2460
|
elif n == 6:
|
2380
|
-
colorlist = [
|
2381
|
-
|
2382
|
-
|
2383
|
-
|
2461
|
+
colorlist = [
|
2462
|
+
"#FF2C00",
|
2463
|
+
"#91bfdb",
|
2464
|
+
"#B25E9D",
|
2465
|
+
"#4B8C3B",
|
2466
|
+
"#EF8632",
|
2467
|
+
"#24578E",
|
2468
|
+
]
|
2469
|
+
elif n == 7:
|
2470
|
+
colorlist = [
|
2471
|
+
"#7F7F7F",
|
2472
|
+
"#459AA9",
|
2473
|
+
"#B25E9D",
|
2474
|
+
"#4B8C3B",
|
2475
|
+
"#EF8632",
|
2476
|
+
"#24578E" "#FF2C00",
|
2477
|
+
]
|
2478
|
+
elif n == 8:
|
2384
2479
|
# colorlist = ['#1f77b4','#ff7f0e','#367B7F','#51B34F','#d62728','#aa40fc','#e377c2','#17becf']
|
2385
2480
|
# colorlist = ["#367C7E","#51B34F","#881A11","#E9374C","#EF893C","#010072","#385DCB","#EA43E3"]
|
2386
|
-
colorlist = [
|
2387
|
-
|
2388
|
-
|
2389
|
-
|
2390
|
-
|
2391
|
-
|
2392
|
-
|
2481
|
+
colorlist = [
|
2482
|
+
"#78BFDA",
|
2483
|
+
"#D52E6F",
|
2484
|
+
"#F7D648",
|
2485
|
+
"#A52D28",
|
2486
|
+
"#6B9F41",
|
2487
|
+
"#E18330",
|
2488
|
+
"#E18B9D",
|
2489
|
+
"#3C88CC",
|
2490
|
+
]
|
2491
|
+
elif n == 9:
|
2492
|
+
colorlist = [
|
2493
|
+
"#1f77b4",
|
2494
|
+
"#ff7f0e",
|
2495
|
+
"#367B7F",
|
2496
|
+
"#ff9896",
|
2497
|
+
"#d62728",
|
2498
|
+
"#aa40fc",
|
2499
|
+
"#e377c2",
|
2500
|
+
"#51B34F",
|
2501
|
+
"#17becf",
|
2502
|
+
]
|
2503
|
+
elif n == 10:
|
2504
|
+
colorlist = [
|
2505
|
+
"#1f77b4",
|
2506
|
+
"#ff7f0e",
|
2507
|
+
"#367B7F",
|
2508
|
+
"#ff9896",
|
2509
|
+
"#51B34F",
|
2510
|
+
"#d62728" "#aa40fc",
|
2511
|
+
"#e377c2",
|
2512
|
+
"#375FD2",
|
2513
|
+
"#17becf",
|
2514
|
+
]
|
2515
|
+
elif 10 < n <= 20:
|
2516
|
+
colorlist = cmap_20
|
2393
2517
|
else:
|
2394
2518
|
colorlist = cmap_28
|
2395
2519
|
by = "start"
|
@@ -2426,7 +2550,7 @@ def get_color(
|
|
2426
2550
|
by = "linspace"
|
2427
2551
|
colorlist = cmap2hex(cmap)
|
2428
2552
|
elif isinstance(cmap, list):
|
2429
|
-
colorlist=cmap
|
2553
|
+
colorlist = cmap
|
2430
2554
|
|
2431
2555
|
# Determine method for generating color list
|
2432
2556
|
if "st" in by.lower() or "be" in by.lower():
|
@@ -2446,7 +2570,7 @@ def get_color(
|
|
2446
2570
|
return hue_list
|
2447
2571
|
else:
|
2448
2572
|
raise ValueError("Invalid output type. Choose 'rgb' or 'hue'.")
|
2449
|
-
|
2573
|
+
|
2450
2574
|
|
2451
2575
|
def stdshade(ax=None, *args, **kwargs):
|
2452
2576
|
"""
|
@@ -2454,7 +2578,7 @@ def stdshade(ax=None, *args, **kwargs):
|
|
2454
2578
|
plot.stdshade(data_array, c=clist[1], lw=2, ls="-.", alpha=0.2)
|
2455
2579
|
"""
|
2456
2580
|
from scipy.signal import savgol_filter
|
2457
|
-
|
2581
|
+
|
2458
2582
|
# Separate kws_line and kws_fill if necessary
|
2459
2583
|
kws_line = kwargs.pop("kws_line", {})
|
2460
2584
|
kws_fill = kwargs.pop("kws_fill", {})
|
@@ -2724,37 +2848,47 @@ def adjust_spines(ax=None, spines=["left", "bottom"], distance=2):
|
|
2724
2848
|
# cax = fig.add_axes([l + w + pad, b, width, h]) # define cbar Axes
|
2725
2849
|
# return fig.colorbar(im, cax=cax, **kwargs) # draw cbar
|
2726
2850
|
|
2727
|
-
|
2728
|
-
|
2729
|
-
|
2730
|
-
|
2731
|
-
|
2732
|
-
|
2733
|
-
|
2734
|
-
|
2735
|
-
|
2851
|
+
|
2852
|
+
def add_colorbar(
|
2853
|
+
im,
|
2854
|
+
cmap="viridis",
|
2855
|
+
vmin=-1,
|
2856
|
+
vmax=1,
|
2857
|
+
orientation="vertical",
|
2858
|
+
width_ratio=0.05,
|
2859
|
+
pad_ratio=0.02,
|
2860
|
+
shrink=1.0,
|
2861
|
+
**kwargs,
|
2862
|
+
):
|
2736
2863
|
import matplotlib as mpl
|
2864
|
+
|
2737
2865
|
if all([cmap, vmin, vmax]):
|
2738
2866
|
norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
|
2739
2867
|
else:
|
2740
|
-
norm=False
|
2868
|
+
norm = False
|
2741
2869
|
sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
|
2742
2870
|
sm.set_array([])
|
2743
2871
|
l, b, w, h = im.axes.get_position().bounds # position: left, bottom, width, height
|
2744
|
-
if orientation ==
|
2872
|
+
if orientation == "vertical":
|
2745
2873
|
width = width_ratio * w
|
2746
2874
|
pad = pad_ratio * w
|
2747
|
-
cax = im.figure.add_axes(
|
2875
|
+
cax = im.figure.add_axes(
|
2876
|
+
[l + w + pad, b, width, h * shrink]
|
2877
|
+
) # Right of the image
|
2748
2878
|
else:
|
2749
2879
|
height = width_ratio * h
|
2750
2880
|
pad = pad_ratio * h
|
2751
|
-
cax = im.figure.add_axes(
|
2752
|
-
|
2753
|
-
|
2881
|
+
cax = im.figure.add_axes(
|
2882
|
+
[l, b - height - pad, w * shrink, height]
|
2883
|
+
) # Below the image
|
2884
|
+
cbar = im.figure.colorbar(sm, cax=cax, orientation=orientation, **kwargs)
|
2885
|
+
return cbar
|
2886
|
+
|
2754
2887
|
|
2755
2888
|
# Usage:
|
2756
2889
|
# add_colorbar(im, width_ratio=0.03, pad_ratio=0.01, orientation='horizontal', label="PSD (dB)")
|
2757
2890
|
|
2891
|
+
|
2758
2892
|
def generate_xticks_with_gap(x_len, hue_len):
|
2759
2893
|
"""
|
2760
2894
|
Generate a concatenated array based on x_len and hue_len,
|
@@ -2884,6 +3018,7 @@ def style_examples(
|
|
2884
3018
|
f = listdir(
|
2885
3019
|
"/Users/macjianfeng/Dropbox/github/python/py2ls/.venv/lib/python3.12/site-packages/py2ls/data/styles/",
|
2886
3020
|
kind=".json",
|
3021
|
+
verbose=False,
|
2887
3022
|
)
|
2888
3023
|
display(f.sample(2))
|
2889
3024
|
# def style_example(dir_save,)
|
@@ -2961,6 +3096,7 @@ def thumbnail(dir_img_list: list, figsize=(10, 10), dpi=100, show=False, verbose
|
|
2961
3096
|
|
2962
3097
|
def get_params_from_func_usage(function_signature):
|
2963
3098
|
import re
|
3099
|
+
|
2964
3100
|
# Regular expression to match parameter names, ignoring '*' and '**kwargs'
|
2965
3101
|
keys_pattern = r"(?<!\*\*)\b(\w+)="
|
2966
3102
|
# Find all matches
|
@@ -2987,7 +3123,7 @@ def plotxy(
|
|
2987
3123
|
x=None,
|
2988
3124
|
y=None,
|
2989
3125
|
ax=None,
|
2990
|
-
kind: Union[str, list] =
|
3126
|
+
kind: Union[str, list] = "scatter", # Specify the kind of plot
|
2991
3127
|
verbose=False,
|
2992
3128
|
**kwargs,
|
2993
3129
|
):
|
@@ -3011,14 +3147,15 @@ def plotxy(
|
|
3011
3147
|
# Check for valid plot kind
|
3012
3148
|
# Default arguments for various plot types
|
3013
3149
|
from pathlib import Path
|
3150
|
+
|
3014
3151
|
# Get the current script's directory as a Path object
|
3015
|
-
current_directory = Path(__file__).resolve().parent
|
3152
|
+
current_directory = Path(__file__).resolve().parent
|
3153
|
+
|
3154
|
+
if not "default_settings" in locals():
|
3155
|
+
default_settings = fload(current_directory / "data" / "usages_sns.json")
|
3156
|
+
if not "sns_info" in locals():
|
3157
|
+
sns_info = pd.DataFrame(fload(current_directory / "data" / "sns_info.json"))
|
3016
3158
|
|
3017
|
-
if not 'default_settings' in locals():
|
3018
|
-
default_settings = fload(current_directory / 'data' / 'usages_sns.json')
|
3019
|
-
if not 'sns_info' in locals():
|
3020
|
-
sns_info = pd.DataFrame(fload(current_directory / 'data' / 'sns_info.json'))
|
3021
|
-
|
3022
3159
|
valid_kinds = list(default_settings.keys())
|
3023
3160
|
# print(valid_kinds)
|
3024
3161
|
if kind is not None:
|
@@ -3032,7 +3169,7 @@ def plotxy(
|
|
3032
3169
|
if kind is not None:
|
3033
3170
|
for k in kind:
|
3034
3171
|
if k in valid_kinds:
|
3035
|
-
print(f"{k}:\n\t{default_settings[k]}")
|
3172
|
+
print(f"{k}:\n\t{default_settings[k]}")
|
3036
3173
|
usage_str = """plotxy(data=ranked_genes,
|
3037
3174
|
x="log2(fold_change)",
|
3038
3175
|
y="-log10(p-value)",
|
@@ -3057,15 +3194,25 @@ def plotxy(
|
|
3057
3194
|
kws_add_text = v_arg
|
3058
3195
|
kwargs.pop(k_arg, None)
|
3059
3196
|
break
|
3060
|
-
zorder=0
|
3197
|
+
zorder = 0
|
3061
3198
|
for k in kind:
|
3062
|
-
zorder+=1
|
3199
|
+
zorder += 1
|
3063
3200
|
# indicate 'col' features
|
3064
3201
|
col = kwargs.get("col", None)
|
3065
|
-
sns_with_col = [
|
3202
|
+
sns_with_col = [
|
3203
|
+
"catplot",
|
3204
|
+
"histplot",
|
3205
|
+
"relplot",
|
3206
|
+
"lmplot",
|
3207
|
+
"pairplot",
|
3208
|
+
"displot",
|
3209
|
+
"kdeplot",
|
3210
|
+
]
|
3066
3211
|
if col is not None:
|
3067
3212
|
if not k in sns_with_col:
|
3068
|
-
print(
|
3213
|
+
print(
|
3214
|
+
f"tips:\n'{k}' has no 'col' param, you could try with {sns_with_col}"
|
3215
|
+
)
|
3069
3216
|
# (1) return FcetGrid
|
3070
3217
|
if k == "jointplot":
|
3071
3218
|
kws_joint = kwargs.pop("kws_joint", kwargs)
|
@@ -3092,77 +3239,98 @@ def plotxy(
|
|
3092
3239
|
ax = stdshade(ax=ax, **kwargs)
|
3093
3240
|
elif k == "scatterplot":
|
3094
3241
|
kws_scatter = kwargs.pop("kws_scatter", kwargs)
|
3095
|
-
kws_scatter={
|
3242
|
+
kws_scatter = {
|
3243
|
+
k: v for k, v in kws_scatter.items() if not k.startswith("kws_")
|
3244
|
+
}
|
3096
3245
|
hue = kwargs.pop("hue", None)
|
3097
3246
|
if isinstance(kws_scatter, dict): # Check if kws_scatter is a dictionary
|
3098
3247
|
kws_scatter.pop("hue", None) # Safely remove 'hue' if it exists
|
3099
|
-
palette = kws_scatter.pop(
|
3248
|
+
palette = kws_scatter.pop(
|
3249
|
+
"palette", get_color(data[hue].nunique()) if hue is not None else None
|
3250
|
+
)
|
3100
3251
|
s = kws_scatter.pop("s", 10)
|
3101
3252
|
alpha = kws_scatter.pop("alpha", 0.7)
|
3102
|
-
ax = sns.scatterplot(
|
3253
|
+
ax = sns.scatterplot(
|
3254
|
+
ax=ax,
|
3255
|
+
data=data,
|
3256
|
+
x=x,
|
3257
|
+
y=y,
|
3258
|
+
hue=hue,
|
3259
|
+
palette=palette,
|
3260
|
+
s=s,
|
3261
|
+
alpha=alpha,
|
3262
|
+
zorder=zorder,
|
3263
|
+
**kws_scatter,
|
3264
|
+
)
|
3103
3265
|
elif k == "histplot":
|
3104
3266
|
kws_hist = kwargs.pop("kws_hist", kwargs)
|
3105
|
-
kws_hist={k: v for k, v in kws_hist.items() if not k.startswith("kws_")}
|
3106
|
-
ax = sns.histplot(data=data, x=x, ax=ax,zorder=zorder, **kws_hist)
|
3267
|
+
kws_hist = {k: v for k, v in kws_hist.items() if not k.startswith("kws_")}
|
3268
|
+
ax = sns.histplot(data=data, x=x, ax=ax, zorder=zorder, **kws_hist)
|
3107
3269
|
elif k == "kdeplot":
|
3108
3270
|
kws_kde = kwargs.pop("kws_kde", kwargs)
|
3109
|
-
kws_kde={k: v for k, v in kws_kde.items() if not k.startswith("kws_")}
|
3110
|
-
ax = sns.kdeplot(data=data, x=x, ax=ax,zorder=zorder, **kws_kde)
|
3271
|
+
kws_kde = {k: v for k, v in kws_kde.items() if not k.startswith("kws_")}
|
3272
|
+
ax = sns.kdeplot(data=data, x=x, ax=ax, zorder=zorder, **kws_kde)
|
3111
3273
|
elif k == "ecdfplot":
|
3112
3274
|
kws_ecdf = kwargs.pop("kws_ecdf", kwargs)
|
3113
|
-
kws_ecdf={k: v for k, v in kws_ecdf.items() if not k.startswith("kws_")}
|
3114
|
-
ax = sns.ecdfplot(data=data, x=x, ax=ax,zorder=zorder, **kws_ecdf)
|
3275
|
+
kws_ecdf = {k: v for k, v in kws_ecdf.items() if not k.startswith("kws_")}
|
3276
|
+
ax = sns.ecdfplot(data=data, x=x, ax=ax, zorder=zorder, **kws_ecdf)
|
3115
3277
|
elif k == "rugplot":
|
3116
3278
|
kws_rug = kwargs.pop("kws_rug", kwargs)
|
3117
|
-
kws_rug={k: v for k, v in kws_rug.items() if not k.startswith("kws_")}
|
3118
|
-
ax = sns.rugplot(data=data, x=x, ax=ax,zorder=zorder, **kws_rug)
|
3279
|
+
kws_rug = {k: v for k, v in kws_rug.items() if not k.startswith("kws_")}
|
3280
|
+
ax = sns.rugplot(data=data, x=x, ax=ax, zorder=zorder, **kws_rug)
|
3119
3281
|
elif k == "stripplot":
|
3120
3282
|
kws_strip = kwargs.pop("kws_strip", kwargs)
|
3121
|
-
kws_strip={k: v for k, v in kws_strip.items() if not k.startswith("kws_")}
|
3283
|
+
kws_strip = {k: v for k, v in kws_strip.items() if not k.startswith("kws_")}
|
3122
3284
|
dodge = kws_strip.pop("dodge", True)
|
3123
|
-
ax = sns.stripplot(
|
3285
|
+
ax = sns.stripplot(
|
3286
|
+
data=data, x=x, y=y, ax=ax, zorder=zorder, dodge=dodge, **kws_strip
|
3287
|
+
)
|
3124
3288
|
elif k == "swarmplot":
|
3125
3289
|
kws_swarm = kwargs.pop("kws_swarm", kwargs)
|
3126
|
-
kws_swarm={k: v for k, v in kws_swarm.items() if not k.startswith("kws_")}
|
3127
|
-
ax = sns.swarmplot(data=data, x=x, y=y, ax=ax,zorder=zorder, **kws_swarm)
|
3290
|
+
kws_swarm = {k: v for k, v in kws_swarm.items() if not k.startswith("kws_")}
|
3291
|
+
ax = sns.swarmplot(data=data, x=x, y=y, ax=ax, zorder=zorder, **kws_swarm)
|
3128
3292
|
elif k == "boxplot":
|
3129
3293
|
kws_box = kwargs.pop("kws_box", kwargs)
|
3130
|
-
kws_box={k: v for k, v in kws_box.items() if not k.startswith("kws_")}
|
3131
|
-
ax = sns.boxplot(data=data, x=x, y=y, ax=ax,zorder=zorder, **kws_box)
|
3294
|
+
kws_box = {k: v for k, v in kws_box.items() if not k.startswith("kws_")}
|
3295
|
+
ax = sns.boxplot(data=data, x=x, y=y, ax=ax, zorder=zorder, **kws_box)
|
3132
3296
|
elif k == "violinplot":
|
3133
3297
|
kws_violin = kwargs.pop("kws_violin", kwargs)
|
3134
|
-
kws_violin={
|
3135
|
-
|
3298
|
+
kws_violin = {
|
3299
|
+
k: v for k, v in kws_violin.items() if not k.startswith("kws_")
|
3300
|
+
}
|
3301
|
+
ax = sns.violinplot(data=data, x=x, y=y, ax=ax, zorder=zorder, **kws_violin)
|
3136
3302
|
elif k == "boxenplot":
|
3137
3303
|
kws_boxen = kwargs.pop("kws_boxen", kwargs)
|
3138
|
-
kws_boxen={k: v for k, v in kws_boxen.items() if not k.startswith("kws_")}
|
3139
|
-
ax = sns.boxenplot(data=data, x=x, y=y, ax=ax,zorder=zorder, **kws_boxen)
|
3304
|
+
kws_boxen = {k: v for k, v in kws_boxen.items() if not k.startswith("kws_")}
|
3305
|
+
ax = sns.boxenplot(data=data, x=x, y=y, ax=ax, zorder=zorder, **kws_boxen)
|
3140
3306
|
elif k == "pointplot":
|
3141
3307
|
kws_point = kwargs.pop("kws_point", kwargs)
|
3142
|
-
kws_point={k: v for k, v in kws_point.items() if not k.startswith("kws_")}
|
3143
|
-
ax = sns.pointplot(data=data, x=x, y=y, ax=ax,zorder=zorder, **kws_point)
|
3308
|
+
kws_point = {k: v for k, v in kws_point.items() if not k.startswith("kws_")}
|
3309
|
+
ax = sns.pointplot(data=data, x=x, y=y, ax=ax, zorder=zorder, **kws_point)
|
3144
3310
|
elif k == "barplot":
|
3145
3311
|
kws_bar = kwargs.pop("kws_bar", kwargs)
|
3146
|
-
kws_bar={k: v for k, v in kws_bar.items() if not k.startswith("kws_")}
|
3147
|
-
ax = sns.barplot(data=data, x=x, y=y, ax=ax,zorder=zorder, **kws_bar)
|
3312
|
+
kws_bar = {k: v for k, v in kws_bar.items() if not k.startswith("kws_")}
|
3313
|
+
ax = sns.barplot(data=data, x=x, y=y, ax=ax, zorder=zorder, **kws_bar)
|
3148
3314
|
elif k == "countplot":
|
3149
3315
|
kws_count = kwargs.pop("kws_count", kwargs)
|
3150
|
-
kws_count={k: v for k, v in kws_count.items() if not k.startswith("kws_")}
|
3151
|
-
if not kws_count.get("hue",None):
|
3152
|
-
kws_count.pop("palette",None)
|
3153
|
-
ax = sns.countplot(data=data, x=x,y=y, ax=ax,zorder=zorder, **kws_count)
|
3316
|
+
kws_count = {k: v for k, v in kws_count.items() if not k.startswith("kws_")}
|
3317
|
+
if not kws_count.get("hue", None):
|
3318
|
+
kws_count.pop("palette", None)
|
3319
|
+
ax = sns.countplot(data=data, x=x, y=y, ax=ax, zorder=zorder, **kws_count)
|
3154
3320
|
elif k == "regplot":
|
3155
3321
|
kws_reg = kwargs.pop("kws_reg", kwargs)
|
3156
|
-
kws_reg={k: v for k, v in kws_reg.items() if not k.startswith("kws_")}
|
3157
|
-
ax = sns.regplot(data=data, x=x, y=y, ax=ax,zorder=zorder, **kws_reg)
|
3322
|
+
kws_reg = {k: v for k, v in kws_reg.items() if not k.startswith("kws_")}
|
3323
|
+
ax = sns.regplot(data=data, x=x, y=y, ax=ax, zorder=zorder, **kws_reg)
|
3158
3324
|
elif k == "residplot":
|
3159
3325
|
kws_resid = kwargs.pop("kws_resid", kwargs)
|
3160
|
-
kws_resid={k: v for k, v in kws_resid.items() if not k.startswith("kws_")}
|
3161
|
-
ax = sns.residplot(
|
3326
|
+
kws_resid = {k: v for k, v in kws_resid.items() if not k.startswith("kws_")}
|
3327
|
+
ax = sns.residplot(
|
3328
|
+
data=data, x=x, y=y, lowess=True, zorder=zorder, ax=ax, **kws_resid
|
3329
|
+
)
|
3162
3330
|
elif k == "lineplot":
|
3163
3331
|
kws_line = kwargs.pop("kws_line", kwargs)
|
3164
|
-
kws_line={k: v for k, v in kws_line.items() if not k.startswith("kws_")}
|
3165
|
-
ax = sns.lineplot(ax=ax, data=data, x=x, y=y,zorder=zorder, **kws_line)
|
3332
|
+
kws_line = {k: v for k, v in kws_line.items() if not k.startswith("kws_")}
|
3333
|
+
ax = sns.lineplot(ax=ax, data=data, x=x, y=y, zorder=zorder, **kws_line)
|
3166
3334
|
|
3167
3335
|
figsets(ax=ax, **kws_figsets)
|
3168
3336
|
if kws_add_text:
|
@@ -3176,20 +3344,22 @@ def plotxy(
|
|
3176
3344
|
return g, ax
|
3177
3345
|
return ax
|
3178
3346
|
|
3347
|
+
|
3179
3348
|
def norm_cmap(data, cmap="coolwarm", min_max=[0, 1]):
|
3180
3349
|
norm_ = plt.Normalize(min_max[0], min_max[1])
|
3181
3350
|
colormap = plt.get_cmap(cmap)
|
3182
3351
|
return colormap(norm_(data))
|
3183
3352
|
|
3353
|
+
|
3184
3354
|
def volcano(
|
3185
|
-
data:pd.DataFrame,
|
3186
|
-
x:str,
|
3187
|
-
y:str,
|
3188
|
-
gene_col:str=None,
|
3189
|
-
top_genes=[5, 5],
|
3190
|
-
thr_x=np.log2(1.5),
|
3355
|
+
data: pd.DataFrame,
|
3356
|
+
x: str,
|
3357
|
+
y: str,
|
3358
|
+
gene_col: str = None,
|
3359
|
+
top_genes=[5, 5], # [down-regulated, up-regulated]
|
3360
|
+
thr_x=np.log2(1.5), # default: 0.585
|
3191
3361
|
thr_y=-np.log10(0.05),
|
3192
|
-
sort_xy="x",
|
3362
|
+
sort_xy="x", #'y', 'xy'
|
3193
3363
|
colors=("#00BFFF", "#9d9a9a", "#FF3030"),
|
3194
3364
|
s=20,
|
3195
3365
|
fill=True, # plot filled scatter
|
@@ -3201,11 +3371,10 @@ def volcano(
|
|
3201
3371
|
ax=None,
|
3202
3372
|
verbose=False,
|
3203
3373
|
kws_text=dict(fontsize=10, color="k"),
|
3204
|
-
kws_bbox=dict(
|
3205
|
-
|
3206
|
-
|
3207
|
-
|
3208
|
-
kws_arrow=dict(color="k", lw=0.5),# '{}' to hide
|
3374
|
+
kws_bbox=dict(
|
3375
|
+
facecolor="none", alpha=0.5, edgecolor="black", boxstyle="round,pad=0.3"
|
3376
|
+
), # '{}' to hide
|
3377
|
+
kws_arrow=dict(color="k", lw=0.5), # '{}' to hide
|
3209
3378
|
**kwargs,
|
3210
3379
|
):
|
3211
3380
|
"""
|
@@ -3275,39 +3444,35 @@ def volcano(
|
|
3275
3444
|
kws_figsets = v_arg
|
3276
3445
|
kwargs.pop(k_arg, None)
|
3277
3446
|
break
|
3278
|
-
|
3279
|
-
data=data.copy()
|
3447
|
+
|
3448
|
+
data = data.copy()
|
3280
3449
|
# filter nan
|
3281
3450
|
data = data.dropna(subset=[x, y]) # Drop rows with NaN in x or y
|
3282
|
-
data.loc[:,"color"] = np.where(
|
3451
|
+
data.loc[:, "color"] = np.where(
|
3283
3452
|
(data[x] > thr_x) & (data[y] > thr_y),
|
3284
3453
|
colors[2],
|
3285
|
-
np.where((data[x] < -thr_x) & (data[y] > thr_y),
|
3286
|
-
colors[0],
|
3287
|
-
colors[1]),
|
3454
|
+
np.where((data[x] < -thr_x) & (data[y] > thr_y), colors[0], colors[1]),
|
3288
3455
|
)
|
3289
|
-
top_genes=[top_genes, top_genes] if isinstance(top_genes,int) else top_genes
|
3290
|
-
|
3456
|
+
top_genes = [top_genes, top_genes] if isinstance(top_genes, int) else top_genes
|
3457
|
+
|
3291
3458
|
# could custom how to select the top genes, x: x has priority
|
3292
|
-
sort_by_x_y=[x,y] if sort_xy=="x" else [y,x]
|
3293
|
-
ascending_up=[True, True] if sort_xy=="x" else [False, True]
|
3294
|
-
ascending_down=[False, True] if sort_xy=="x" else [False, False]
|
3295
|
-
|
3296
|
-
down_reg_genes =
|
3297
|
-
(data["color"] == colors[0]) &
|
3298
|
-
|
3299
|
-
(
|
3300
|
-
|
3301
|
-
up_reg_genes =
|
3302
|
-
(data["color"] == colors[2]) &
|
3303
|
-
|
3304
|
-
(
|
3305
|
-
|
3459
|
+
sort_by_x_y = [x, y] if sort_xy == "x" else [y, x]
|
3460
|
+
ascending_up = [True, True] if sort_xy == "x" else [False, True]
|
3461
|
+
ascending_down = [False, True] if sort_xy == "x" else [False, False]
|
3462
|
+
|
3463
|
+
down_reg_genes = (
|
3464
|
+
data[(data["color"] == colors[0]) & (data[x].abs() > thr_x) & (data[y] > thr_y)]
|
3465
|
+
.sort_values(by=sort_by_x_y, ascending=ascending_up)
|
3466
|
+
.head(top_genes[0])
|
3467
|
+
)
|
3468
|
+
up_reg_genes = (
|
3469
|
+
data[(data["color"] == colors[2]) & (data[x].abs() > thr_x) & (data[y] > thr_y)]
|
3470
|
+
.sort_values(by=sort_by_x_y, ascending=ascending_down)
|
3471
|
+
.head(top_genes[1])
|
3472
|
+
)
|
3306
3473
|
sele_gene = pd.concat([down_reg_genes, up_reg_genes])
|
3307
|
-
|
3308
|
-
palette = {colors[0]: colors[0],
|
3309
|
-
colors[1]: colors[1],
|
3310
|
-
colors[2]: colors[2]}
|
3474
|
+
|
3475
|
+
palette = {colors[0]: colors[0], colors[1]: colors[1], colors[2]: colors[2]}
|
3311
3476
|
# Plot setup
|
3312
3477
|
if ax is None:
|
3313
3478
|
ax = plt.gca()
|
@@ -3337,9 +3502,9 @@ def volcano(
|
|
3337
3502
|
)
|
3338
3503
|
|
3339
3504
|
# Add threshold lines for x and y axes
|
3340
|
-
ax.axhline(y=thr_y, color="black", linestyle="--",lw=1)
|
3341
|
-
ax.axvline(x=-thr_x, color="black", linestyle="--",lw=1)
|
3342
|
-
ax.axvline(x=thr_x, color="black", linestyle="--",lw=1)
|
3505
|
+
ax.axhline(y=thr_y, color="black", linestyle="--", lw=1)
|
3506
|
+
ax.axvline(x=-thr_x, color="black", linestyle="--", lw=1)
|
3507
|
+
ax.axvline(x=thr_x, color="black", linestyle="--", lw=1)
|
3343
3508
|
|
3344
3509
|
# Add gene labels for selected significant points
|
3345
3510
|
if gene_col:
|
@@ -3349,19 +3514,28 @@ def volcano(
|
|
3349
3514
|
textcolor = kws_text.pop("color", "k")
|
3350
3515
|
fontsize = kws_text.pop("fontsize", 10)
|
3351
3516
|
arrowstyles = [
|
3352
|
-
"->",
|
3353
|
-
"
|
3354
|
-
"
|
3517
|
+
"->",
|
3518
|
+
"<-",
|
3519
|
+
"<->",
|
3520
|
+
"<|-",
|
3521
|
+
"-|>",
|
3522
|
+
"<|-|>",
|
3523
|
+
"-",
|
3524
|
+
"-[",
|
3525
|
+
"-[",
|
3526
|
+
"fancy",
|
3527
|
+
"simple",
|
3528
|
+
"wedge",
|
3355
3529
|
]
|
3356
3530
|
arrowstyle = kws_arrow.pop("style", "<|-")
|
3357
|
-
arrowstyle = strcmp(arrowstyle, arrowstyles,scorer=
|
3358
|
-
expand=kws_arrow.pop("expand",(1.05,1.1))
|
3531
|
+
arrowstyle = strcmp(arrowstyle, arrowstyles, scorer="strict")[0]
|
3532
|
+
expand = kws_arrow.pop("expand", (1.05, 1.1))
|
3359
3533
|
arrowcolor = kws_arrow.pop("color", "0.4")
|
3360
3534
|
arrowlinewidth = kws_arrow.pop("lw", 0.75)
|
3361
3535
|
shrinkA = kws_arrow.pop("shrinkA", 0)
|
3362
3536
|
shrinkB = kws_arrow.pop("shrinkB", 0)
|
3363
3537
|
mutation_scale = kws_arrow.pop("head", 10)
|
3364
|
-
arrow_fill=kws_arrow.pop("fill", False)
|
3538
|
+
arrow_fill = kws_arrow.pop("fill", False)
|
3365
3539
|
for i in range(sele_gene.shape[0]):
|
3366
3540
|
if isinstance(textcolor, list): # be consistant with dots's color
|
3367
3541
|
textcolor = colors[0] if sele_gene[x].iloc[i] > 0 else colors[1]
|
@@ -3382,7 +3556,7 @@ def volcano(
|
|
3382
3556
|
adjust_text(
|
3383
3557
|
texts,
|
3384
3558
|
expand=expand,
|
3385
|
-
min_arrow_len=5,
|
3559
|
+
min_arrow_len=5,
|
3386
3560
|
ax=ax,
|
3387
3561
|
arrowprops=dict(
|
3388
3562
|
arrowstyle=arrowstyle,
|
@@ -3393,7 +3567,7 @@ def volcano(
|
|
3393
3567
|
shrinkB=shrinkB,
|
3394
3568
|
mutation_scale=mutation_scale,
|
3395
3569
|
**kws_arrow,
|
3396
|
-
)
|
3570
|
+
),
|
3397
3571
|
)
|
3398
3572
|
|
3399
3573
|
figsets(**kws_figsets)
|
@@ -3499,6 +3673,7 @@ def desaturate_color(color, saturation_factor=0.5):
|
|
3499
3673
|
"""Reduce the saturation of a color by a given factor (between 0 and 1)."""
|
3500
3674
|
import matplotlib.colors as mcolors
|
3501
3675
|
import colorsys
|
3676
|
+
|
3502
3677
|
# Convert the color to RGB
|
3503
3678
|
rgb = mcolors.to_rgb(color)
|
3504
3679
|
# Convert RGB to HLS (Hue, Lightness, Saturation)
|
@@ -3508,8 +3683,19 @@ def desaturate_color(color, saturation_factor=0.5):
|
|
3508
3683
|
# Convert back to RGB
|
3509
3684
|
return colorsys.hls_to_rgb(h, l, s)
|
3510
3685
|
|
3511
|
-
|
3512
|
-
|
3686
|
+
|
3687
|
+
def textsets(
|
3688
|
+
text,
|
3689
|
+
fontname="Arial",
|
3690
|
+
fontsize=11,
|
3691
|
+
fontweight="normal",
|
3692
|
+
fontstyle="normal",
|
3693
|
+
fontcolor="k",
|
3694
|
+
backgroundcolor=None,
|
3695
|
+
shadow=False,
|
3696
|
+
ha="center",
|
3697
|
+
va="center",
|
3698
|
+
):
|
3513
3699
|
if text: # Ensure text exists
|
3514
3700
|
if fontname:
|
3515
3701
|
text.set_fontname(plt_font(fontname))
|
@@ -3522,26 +3708,28 @@ def textsets(text, fontname='Arial', fontsize=11, fontweight="normal", fontstyle
|
|
3522
3708
|
if fontcolor:
|
3523
3709
|
text.set_color(fontcolor)
|
3524
3710
|
if backgroundcolor:
|
3525
|
-
text.set_backgroundcolor(backgroundcolor)
|
3711
|
+
text.set_backgroundcolor(backgroundcolor)
|
3526
3712
|
text.set_horizontalalignment(ha)
|
3527
|
-
text.set_verticalalignment(va)
|
3713
|
+
text.set_verticalalignment(va)
|
3528
3714
|
if shadow:
|
3529
|
-
text.set_path_effects(
|
3530
|
-
matplotlib.patheffects.withStroke(linewidth=3, foreground="gray")
|
3531
|
-
|
3715
|
+
text.set_path_effects(
|
3716
|
+
[matplotlib.patheffects.withStroke(linewidth=3, foreground="gray")]
|
3717
|
+
)
|
3718
|
+
|
3719
|
+
|
3532
3720
|
def venn(
|
3533
|
-
lists:list,
|
3534
|
-
labels:list=None,
|
3721
|
+
lists: list,
|
3722
|
+
labels: list = None,
|
3535
3723
|
ax=None,
|
3536
3724
|
colors=None,
|
3537
3725
|
edgecolor=None,
|
3538
3726
|
alpha=0.5,
|
3539
|
-
saturation
|
3540
|
-
linewidth=0,
|
3727
|
+
saturation=0.75,
|
3728
|
+
linewidth=0, # default no edge
|
3541
3729
|
linestyle="-",
|
3542
|
-
fontname=
|
3730
|
+
fontname="Arial",
|
3543
3731
|
fontsize=10,
|
3544
|
-
fontcolor=
|
3732
|
+
fontcolor="k",
|
3545
3733
|
fontweight="normal",
|
3546
3734
|
fontstyle="normal",
|
3547
3735
|
ha="center",
|
@@ -3550,18 +3738,18 @@ def venn(
|
|
3550
3738
|
subset_fontsize=10,
|
3551
3739
|
subset_fontweight="normal",
|
3552
3740
|
subset_fontstyle="normal",
|
3553
|
-
subset_fontcolor=
|
3741
|
+
subset_fontcolor="k",
|
3554
3742
|
backgroundcolor=None,
|
3555
3743
|
custom_texts=None,
|
3556
|
-
show_percentages=True,
|
3744
|
+
show_percentages=True, # display percentage
|
3557
3745
|
fmt="{:.1%}",
|
3558
3746
|
ellipse_shape=False, # 椭圆形
|
3559
|
-
ellipse_scale=[1.5, 1], #not perfect, 椭圆形的形状
|
3560
|
-
**kwargs
|
3747
|
+
ellipse_scale=[1.5, 1], # not perfect, 椭圆形的形状
|
3748
|
+
**kwargs,
|
3561
3749
|
):
|
3562
3750
|
"""
|
3563
3751
|
Advanced Venn diagram plotting function with extensive customization options.
|
3564
|
-
Usage:
|
3752
|
+
Usage:
|
3565
3753
|
# Define the two sets
|
3566
3754
|
set1 = [1, 2, 3, 4, 5]
|
3567
3755
|
set2 = [4, 5, 6, 7, 8]
|
@@ -3612,56 +3800,70 @@ def venn(
|
|
3612
3800
|
"""
|
3613
3801
|
if ax is None:
|
3614
3802
|
ax = plt.gca()
|
3615
|
-
lists=[set(flatten(i, verbose=False)) for i in lists]
|
3803
|
+
lists = [set(flatten(i, verbose=False)) for i in lists]
|
3616
3804
|
# Function to apply text styles to labels
|
3617
3805
|
if colors is None:
|
3618
|
-
colors=["r","b"] if len(lists)==2 else ["r","g","b"]
|
3806
|
+
colors = ["r", "b"] if len(lists) == 2 else ["r", "g", "b"]
|
3619
3807
|
if labels is None:
|
3620
|
-
labels=["set1","set2"] if len(lists)==2 else ["set1","set2","set3"]
|
3808
|
+
labels = ["set1", "set2"] if len(lists) == 2 else ["set1", "set2", "set3"]
|
3621
3809
|
if edgecolor is None:
|
3622
|
-
edgecolor=colors
|
3810
|
+
edgecolor = colors
|
3623
3811
|
colors = [desaturate_color(color, saturation) for color in colors]
|
3624
3812
|
# Check colors and auto-calculate overlaps
|
3625
3813
|
if len(lists) == 2:
|
3814
|
+
|
3626
3815
|
def get_count_and_percentage(set_count, subset_count):
|
3627
3816
|
percent = subset_count / set_count if set_count > 0 else 0
|
3628
|
-
return
|
3817
|
+
return (
|
3818
|
+
f"{subset_count}\n({fmt.format(percent)})"
|
3819
|
+
if show_percentages
|
3820
|
+
else f"{subset_count}"
|
3821
|
+
)
|
3629
3822
|
|
3630
3823
|
from matplotlib_venn import venn2, venn2_circles
|
3824
|
+
|
3631
3825
|
# Auto-calculate overlap color for 2-set Venn diagram
|
3632
3826
|
overlap_color = get_color_overlap(colors[0], colors[1]) if colors else None
|
3633
|
-
|
3827
|
+
|
3634
3828
|
# Draw the venn diagram
|
3635
3829
|
v = venn2(subsets=lists, set_labels=labels, ax=ax, **kwargs)
|
3636
3830
|
venn_circles = venn2_circles(subsets=lists, ax=ax)
|
3637
|
-
set1,set2=lists[0],lists[1]
|
3638
|
-
v.get_patch_by_id(
|
3639
|
-
v.get_patch_by_id(
|
3640
|
-
v.get_patch_by_id(
|
3831
|
+
set1, set2 = lists[0], lists[1]
|
3832
|
+
v.get_patch_by_id("10").set_color(colors[0])
|
3833
|
+
v.get_patch_by_id("01").set_color(colors[1])
|
3834
|
+
v.get_patch_by_id("11").set_color(
|
3835
|
+
get_color_overlap(colors[0], colors[1]) if colors else None
|
3836
|
+
)
|
3641
3837
|
# v.get_label_by_id('10').set_text(len(set1 - set2))
|
3642
3838
|
# v.get_label_by_id('01').set_text(len(set2 - set1))
|
3643
3839
|
# v.get_label_by_id('11').set_text(len(set1 & set2))
|
3644
|
-
v.get_label_by_id(
|
3645
|
-
|
3646
|
-
|
3647
|
-
|
3840
|
+
v.get_label_by_id("10").set_text(
|
3841
|
+
get_count_and_percentage(len(set1), len(set1 - set2))
|
3842
|
+
)
|
3843
|
+
v.get_label_by_id("01").set_text(
|
3844
|
+
get_count_and_percentage(len(set2), len(set2 - set1))
|
3845
|
+
)
|
3846
|
+
v.get_label_by_id("11").set_text(
|
3847
|
+
get_count_and_percentage(len(set1 | set2), len(set1 & set2))
|
3848
|
+
)
|
3648
3849
|
|
3649
|
-
if not isinstance(linewidth,list):
|
3650
|
-
linewidth=[linewidth]
|
3651
|
-
if isinstance(linestyle,str):
|
3652
|
-
linestyle=[linestyle]
|
3850
|
+
if not isinstance(linewidth, list):
|
3851
|
+
linewidth = [linewidth]
|
3852
|
+
if isinstance(linestyle, str):
|
3853
|
+
linestyle = [linestyle]
|
3653
3854
|
if not isinstance(edgecolor, list):
|
3654
|
-
edgecolor=[edgecolor]
|
3655
|
-
linewidth=linewidth*2 if len(linewidth)==1 else linewidth
|
3656
|
-
linestyle=linestyle*2 if len(linestyle)==1 else linestyle
|
3657
|
-
edgecolor=edgecolor*2 if len(edgecolor)==1 else edgecolor
|
3855
|
+
edgecolor = [edgecolor]
|
3856
|
+
linewidth = linewidth * 2 if len(linewidth) == 1 else linewidth
|
3857
|
+
linestyle = linestyle * 2 if len(linestyle) == 1 else linestyle
|
3858
|
+
edgecolor = edgecolor * 2 if len(edgecolor) == 1 else edgecolor
|
3658
3859
|
for i in range(2):
|
3659
3860
|
venn_circles[i].set_lw(linewidth[i])
|
3660
3861
|
venn_circles[i].set_ls(linestyle[i])
|
3661
3862
|
venn_circles[i].set_edgecolor(edgecolor[i])
|
3662
3863
|
# 椭圆
|
3663
3864
|
if ellipse_shape:
|
3664
|
-
import matplotlib.patches as patches
|
3865
|
+
import matplotlib.patches as patches
|
3866
|
+
|
3665
3867
|
for patch in v.patches:
|
3666
3868
|
patch.set_visible(False) # Hide original patches if using ellipses
|
3667
3869
|
center1 = v.get_circle_center(0)
|
@@ -3672,9 +3874,13 @@ def venn(
|
|
3672
3874
|
height=ellipse_scale[1],
|
3673
3875
|
edgecolor=edgecolor[0] if edgecolor else colors[0],
|
3674
3876
|
facecolor=colors[0],
|
3675
|
-
lw=
|
3877
|
+
lw=(
|
3878
|
+
linewidth if isinstance(linewidth, (int, float)) else 1.0
|
3879
|
+
), # Ensure lw is a number
|
3676
3880
|
ls=linestyle[0],
|
3677
|
-
alpha=
|
3881
|
+
alpha=(
|
3882
|
+
alpha if isinstance(alpha, (int, float)) else 0.5
|
3883
|
+
), # Ensure alpha is a number
|
3678
3884
|
)
|
3679
3885
|
ellipse2 = patches.Ellipse(
|
3680
3886
|
(center2.x, center2.y),
|
@@ -3682,48 +3888,78 @@ def venn(
|
|
3682
3888
|
height=ellipse_scale[1],
|
3683
3889
|
edgecolor=edgecolor[1] if edgecolor else colors[1],
|
3684
3890
|
facecolor=colors[1],
|
3685
|
-
lw=
|
3891
|
+
lw=(
|
3892
|
+
linewidth if isinstance(linewidth, (int, float)) else 1.0
|
3893
|
+
), # Ensure lw is a number
|
3686
3894
|
ls=linestyle[0],
|
3687
|
-
alpha=
|
3895
|
+
alpha=(
|
3896
|
+
alpha if isinstance(alpha, (int, float)) else 0.5
|
3897
|
+
), # Ensure alpha is a number
|
3688
3898
|
)
|
3689
3899
|
ax.add_patch(ellipse1)
|
3690
3900
|
ax.add_patch(ellipse2)
|
3691
3901
|
# Apply styles to set labels
|
3692
3902
|
for i, text in enumerate(v.set_labels):
|
3693
|
-
textsets(
|
3694
|
-
|
3903
|
+
textsets(
|
3904
|
+
text,
|
3905
|
+
fontname=fontname,
|
3906
|
+
fontsize=fontsize,
|
3907
|
+
fontweight=fontweight,
|
3908
|
+
fontstyle=fontstyle,
|
3909
|
+
fontcolor=fontcolor,
|
3910
|
+
ha=ha,
|
3911
|
+
va=va,
|
3912
|
+
shadow=shadow,
|
3913
|
+
)
|
3695
3914
|
|
3696
3915
|
# Apply styles to subset labels
|
3697
3916
|
for i, text in enumerate(v.subset_labels):
|
3698
3917
|
if text: # Ensure text exists
|
3699
3918
|
if custom_texts: # Custom text handling
|
3700
|
-
text.set_text(custom_texts[i])
|
3701
|
-
textsets(
|
3702
|
-
|
3919
|
+
text.set_text(custom_texts[i])
|
3920
|
+
textsets(
|
3921
|
+
text,
|
3922
|
+
fontname=fontname,
|
3923
|
+
fontsize=subset_fontsize,
|
3924
|
+
fontweight=subset_fontweight,
|
3925
|
+
fontstyle=subset_fontstyle,
|
3926
|
+
fontcolor=subset_fontcolor,
|
3927
|
+
ha=ha,
|
3928
|
+
va=va,
|
3929
|
+
shadow=shadow,
|
3930
|
+
)
|
3703
3931
|
|
3704
3932
|
elif len(lists) == 3:
|
3933
|
+
|
3705
3934
|
def get_label(set_count, subset_count):
|
3706
3935
|
percent = subset_count / set_count if set_count > 0 else 0
|
3707
|
-
return
|
3708
|
-
|
3936
|
+
return (
|
3937
|
+
f"{subset_count}\n({fmt.format(percent)})"
|
3938
|
+
if show_percentages
|
3939
|
+
else f"{subset_count}"
|
3940
|
+
)
|
3941
|
+
|
3709
3942
|
from matplotlib_venn import venn3, venn3_circles
|
3943
|
+
|
3710
3944
|
# Auto-calculate overlap colors for 3-set Venn diagram
|
3711
3945
|
colorAB = get_color_overlap(colors[0], colors[1]) if colors else None
|
3712
3946
|
colorAC = get_color_overlap(colors[0], colors[2]) if colors else None
|
3713
3947
|
colorBC = get_color_overlap(colors[1], colors[2]) if colors else None
|
3714
|
-
colorABC =
|
3715
|
-
|
3948
|
+
colorABC = (
|
3949
|
+
get_color_overlap(colors[0], colors[1], colors[2]) if colors else None
|
3950
|
+
)
|
3951
|
+
set1, set2, set3 = lists[0], lists[1], lists[2]
|
3716
3952
|
|
3717
3953
|
# Draw the venn diagram
|
3718
|
-
v = venn3(subsets=lists, set_labels=labels, ax=ax
|
3719
|
-
v.get_patch_by_id(
|
3720
|
-
v.get_patch_by_id(
|
3721
|
-
v.get_patch_by_id(
|
3722
|
-
v.get_patch_by_id(
|
3723
|
-
v.get_patch_by_id(
|
3724
|
-
v.get_patch_by_id(
|
3725
|
-
v.get_patch_by_id(
|
3726
|
-
|
3954
|
+
v = venn3(subsets=lists, set_labels=labels, ax=ax, **kwargs)
|
3955
|
+
v.get_patch_by_id("100").set_color(colors[0])
|
3956
|
+
v.get_patch_by_id("010").set_color(colors[1])
|
3957
|
+
v.get_patch_by_id("001").set_color(colors[2])
|
3958
|
+
v.get_patch_by_id("110").set_color(colorAB)
|
3959
|
+
v.get_patch_by_id("101").set_color(colorAC)
|
3960
|
+
v.get_patch_by_id("011").set_color(colorBC)
|
3961
|
+
v.get_patch_by_id("111").set_color(colorABC)
|
3962
|
+
|
3727
3963
|
# Correctly labeling subset sizes
|
3728
3964
|
# v.get_label_by_id('100').set_text(len(set1 - set2 - set3))
|
3729
3965
|
# v.get_label_by_id('010').set_text(len(set2 - set1 - set3))
|
@@ -3732,63 +3968,93 @@ def venn(
|
|
3732
3968
|
# v.get_label_by_id('101').set_text(len(set1 & set3 - set2))
|
3733
3969
|
# v.get_label_by_id('011').set_text(len(set2 & set3 - set1))
|
3734
3970
|
# v.get_label_by_id('111').set_text(len(set1 & set2 & set3))
|
3735
|
-
v.get_label_by_id(
|
3736
|
-
v.get_label_by_id(
|
3737
|
-
v.get_label_by_id(
|
3738
|
-
v.get_label_by_id(
|
3739
|
-
|
3740
|
-
|
3741
|
-
v.get_label_by_id(
|
3742
|
-
|
3971
|
+
v.get_label_by_id("100").set_text(get_label(len(set1), len(set1 - set2 - set3)))
|
3972
|
+
v.get_label_by_id("010").set_text(get_label(len(set2), len(set2 - set1 - set3)))
|
3973
|
+
v.get_label_by_id("001").set_text(get_label(len(set3), len(set3 - set1 - set2)))
|
3974
|
+
v.get_label_by_id("110").set_text(
|
3975
|
+
get_label(len(set1 | set2), len(set1 & set2 - set3))
|
3976
|
+
)
|
3977
|
+
v.get_label_by_id("101").set_text(
|
3978
|
+
get_label(len(set1 | set3), len(set1 & set3 - set2))
|
3979
|
+
)
|
3980
|
+
v.get_label_by_id("011").set_text(
|
3981
|
+
get_label(len(set2 | set3), len(set2 & set3 - set1))
|
3982
|
+
)
|
3983
|
+
v.get_label_by_id("111").set_text(
|
3984
|
+
get_label(len(set1 | set2 | set3), len(set1 & set2 & set3))
|
3985
|
+
)
|
3743
3986
|
|
3744
3987
|
# Apply styles to set labels
|
3745
3988
|
for i, text in enumerate(v.set_labels):
|
3746
|
-
textsets(
|
3747
|
-
|
3989
|
+
textsets(
|
3990
|
+
text,
|
3991
|
+
fontname=fontname,
|
3992
|
+
fontsize=fontsize,
|
3993
|
+
fontweight=fontweight,
|
3994
|
+
fontstyle=fontstyle,
|
3995
|
+
fontcolor=fontcolor,
|
3996
|
+
ha=ha,
|
3997
|
+
va=va,
|
3998
|
+
shadow=shadow,
|
3999
|
+
)
|
3748
4000
|
|
3749
4001
|
# Apply styles to subset labels
|
3750
4002
|
for i, text in enumerate(v.subset_labels):
|
3751
4003
|
if text: # Ensure text exists
|
3752
4004
|
if custom_texts: # Custom text handling
|
3753
4005
|
text.set_text(custom_texts[i])
|
3754
|
-
textsets(
|
3755
|
-
|
4006
|
+
textsets(
|
4007
|
+
text,
|
4008
|
+
fontname=fontname,
|
4009
|
+
fontsize=subset_fontsize,
|
4010
|
+
fontweight=subset_fontweight,
|
4011
|
+
fontstyle=subset_fontstyle,
|
4012
|
+
fontcolor=subset_fontcolor,
|
4013
|
+
ha=ha,
|
4014
|
+
va=va,
|
4015
|
+
shadow=shadow,
|
4016
|
+
)
|
3756
4017
|
|
3757
4018
|
venn_circles = venn3_circles(subsets=lists, ax=ax)
|
3758
|
-
if not isinstance(linewidth,list):
|
3759
|
-
linewidth=[linewidth]
|
3760
|
-
if isinstance(linestyle,str):
|
3761
|
-
linestyle=[linestyle]
|
4019
|
+
if not isinstance(linewidth, list):
|
4020
|
+
linewidth = [linewidth]
|
4021
|
+
if isinstance(linestyle, str):
|
4022
|
+
linestyle = [linestyle]
|
3762
4023
|
if not isinstance(edgecolor, list):
|
3763
|
-
edgecolor=[edgecolor]
|
3764
|
-
linewidth=linewidth*3 if len(linewidth)==1 else linewidth
|
3765
|
-
linestyle=linestyle*3 if len(linestyle)==1 else linestyle
|
3766
|
-
edgecolor=edgecolor*3 if len(edgecolor)==1 else edgecolor
|
4024
|
+
edgecolor = [edgecolor]
|
4025
|
+
linewidth = linewidth * 3 if len(linewidth) == 1 else linewidth
|
4026
|
+
linestyle = linestyle * 3 if len(linestyle) == 1 else linestyle
|
4027
|
+
edgecolor = edgecolor * 3 if len(edgecolor) == 1 else edgecolor
|
3767
4028
|
|
3768
4029
|
# edgecolor=[to_rgba(i) for i in edgecolor]
|
3769
4030
|
|
3770
|
-
for i in range(3):
|
4031
|
+
for i in range(3):
|
3771
4032
|
venn_circles[i].set_lw(linewidth[i])
|
3772
|
-
venn_circles[i].set_ls(linestyle[i])
|
3773
|
-
venn_circles[i].set_edgecolor(edgecolor[i])
|
4033
|
+
venn_circles[i].set_ls(linestyle[i])
|
4034
|
+
venn_circles[i].set_edgecolor(edgecolor[i])
|
3774
4035
|
|
3775
|
-
|
4036
|
+
# 椭圆形
|
3776
4037
|
if ellipse_shape:
|
3777
|
-
import matplotlib.patches as patches
|
4038
|
+
import matplotlib.patches as patches
|
4039
|
+
|
3778
4040
|
for patch in v.patches:
|
3779
4041
|
patch.set_visible(False) # Hide original patches if using ellipses
|
3780
4042
|
center1 = v.get_circle_center(0)
|
3781
4043
|
center2 = v.get_circle_center(1)
|
3782
|
-
center3 = v.get_circle_center(2)
|
4044
|
+
center3 = v.get_circle_center(2)
|
3783
4045
|
ellipse1 = patches.Ellipse(
|
3784
4046
|
(center1.x, center1.y),
|
3785
4047
|
width=ellipse_scale[0],
|
3786
4048
|
height=ellipse_scale[1],
|
3787
4049
|
edgecolor=edgecolor[0] if edgecolor else colors[0],
|
3788
4050
|
facecolor=colors[0],
|
3789
|
-
lw=
|
4051
|
+
lw=(
|
4052
|
+
linewidth if isinstance(linewidth, (int, float)) else 1.0
|
4053
|
+
), # Ensure lw is a number
|
3790
4054
|
ls=linestyle[0],
|
3791
|
-
alpha=
|
4055
|
+
alpha=(
|
4056
|
+
alpha if isinstance(alpha, (int, float)) else 0.5
|
4057
|
+
), # Ensure alpha is a number
|
3792
4058
|
)
|
3793
4059
|
ellipse2 = patches.Ellipse(
|
3794
4060
|
(center2.x, center2.y),
|
@@ -3796,9 +4062,13 @@ def venn(
|
|
3796
4062
|
height=ellipse_scale[1],
|
3797
4063
|
edgecolor=edgecolor[1] if edgecolor else colors[1],
|
3798
4064
|
facecolor=colors[1],
|
3799
|
-
lw=
|
4065
|
+
lw=(
|
4066
|
+
linewidth if isinstance(linewidth, (int, float)) else 1.0
|
4067
|
+
), # Ensure lw is a number
|
3800
4068
|
ls=linestyle[0],
|
3801
|
-
alpha=
|
4069
|
+
alpha=(
|
4070
|
+
alpha if isinstance(alpha, (int, float)) else 0.5
|
4071
|
+
), # Ensure alpha is a number
|
3802
4072
|
)
|
3803
4073
|
ellipse3 = patches.Ellipse(
|
3804
4074
|
(center3.x, center3.y),
|
@@ -3806,9 +4076,13 @@ def venn(
|
|
3806
4076
|
height=ellipse_scale[1],
|
3807
4077
|
edgecolor=edgecolor[1] if edgecolor else colors[1],
|
3808
4078
|
facecolor=colors[1],
|
3809
|
-
lw=
|
4079
|
+
lw=(
|
4080
|
+
linewidth if isinstance(linewidth, (int, float)) else 1.0
|
4081
|
+
), # Ensure lw is a number
|
3810
4082
|
ls=linestyle[0],
|
3811
|
-
alpha=
|
4083
|
+
alpha=(
|
4084
|
+
alpha if isinstance(alpha, (int, float)) else 0.5
|
4085
|
+
), # Ensure alpha is a number
|
3812
4086
|
)
|
3813
4087
|
ax.add_patch(ellipse1)
|
3814
4088
|
ax.add_patch(ellipse2)
|
@@ -3820,17 +4094,20 @@ def venn(
|
|
3820
4094
|
for patch in v.patches:
|
3821
4095
|
if patch:
|
3822
4096
|
patch.set_alpha(alpha)
|
3823
|
-
if
|
4097
|
+
if "none" in edgecolor or 0 in linewidth:
|
3824
4098
|
patch.set_edgecolor("none")
|
3825
4099
|
return ax
|
3826
4100
|
|
4101
|
+
|
3827
4102
|
#! subplots, support automatic extend new axis
|
3828
|
-
def subplot(
|
3829
|
-
|
3830
|
-
|
3831
|
-
|
3832
|
-
|
3833
|
-
|
4103
|
+
def subplot(
|
4104
|
+
rows: int = 2,
|
4105
|
+
cols: int = 2,
|
4106
|
+
figsize: Union[tuple, list] = [8, 8],
|
4107
|
+
sharex=False,
|
4108
|
+
sharey=False,
|
4109
|
+
**kwargs,
|
4110
|
+
):
|
3834
4111
|
"""
|
3835
4112
|
nexttile = subplot(
|
3836
4113
|
8,
|
@@ -3850,11 +4127,14 @@ def subplot(rows:int=2,
|
|
3850
4127
|
ax.set_xlabel(f"Tile {i + 1}")
|
3851
4128
|
"""
|
3852
4129
|
from matplotlib.gridspec import GridSpec
|
4130
|
+
|
3853
4131
|
if run_once_within():
|
3854
|
-
print(
|
4132
|
+
print(
|
4133
|
+
f"usage:\n\tnexttile = subplot(2, 2, figsize=(5, 5), sharex=True, sharey=True)\n\tax = nexttile()"
|
4134
|
+
)
|
3855
4135
|
fig = plt.figure(figsize=figsize)
|
3856
4136
|
grid_spec = GridSpec(rows, cols, figure=fig)
|
3857
|
-
occupied = set()
|
4137
|
+
occupied = set()
|
3858
4138
|
row_first_axes = [None] * rows # Track the first axis in each row (for sharey)
|
3859
4139
|
col_first_axes = [None] * cols # Track the first axis in each column (for sharex)
|
3860
4140
|
|
@@ -3862,8 +4142,9 @@ def subplot(rows:int=2,
|
|
3862
4142
|
nonlocal rows, grid_spec
|
3863
4143
|
rows += 1 # Expands by adding a row
|
3864
4144
|
grid_spec = GridSpec(rows, cols, figure=fig)
|
4145
|
+
|
3865
4146
|
def nexttile(rowspan=1, colspan=1, **kwargs):
|
3866
|
-
nonlocal rows, cols, occupied, grid_spec
|
4147
|
+
nonlocal rows, cols, occupied, grid_spec
|
3867
4148
|
for row in range(rows):
|
3868
4149
|
for col in range(cols):
|
3869
4150
|
if all(
|
@@ -3873,29 +4154,29 @@ def subplot(rows:int=2,
|
|
3873
4154
|
):
|
3874
4155
|
break
|
3875
4156
|
else:
|
3876
|
-
continue
|
4157
|
+
continue
|
3877
4158
|
break
|
3878
|
-
else:
|
4159
|
+
else:
|
3879
4160
|
expand_ax()
|
3880
4161
|
return nexttile(rowspan=rowspan, colspan=colspan, **kwargs)
|
3881
4162
|
|
3882
|
-
sharex_ax,sharey_ax = None, None
|
4163
|
+
sharex_ax, sharey_ax = None, None
|
3883
4164
|
|
3884
|
-
if sharex:
|
4165
|
+
if sharex:
|
3885
4166
|
sharex_ax = col_first_axes[col]
|
3886
4167
|
|
3887
|
-
if sharey:
|
3888
|
-
sharey_ax = row_first_axes[row]
|
4168
|
+
if sharey:
|
4169
|
+
sharey_ax = row_first_axes[row]
|
3889
4170
|
ax = fig.add_subplot(
|
3890
4171
|
grid_spec[row : row + rowspan, col : col + colspan],
|
3891
4172
|
sharex=sharex_ax,
|
3892
4173
|
sharey=sharey_ax,
|
3893
|
-
**kwargs
|
3894
|
-
)
|
4174
|
+
**kwargs,
|
4175
|
+
)
|
3895
4176
|
if row_first_axes[row] is None:
|
3896
4177
|
row_first_axes[row] = ax
|
3897
4178
|
if col_first_axes[col] is None:
|
3898
|
-
col_first_axes[col] = ax
|
4179
|
+
col_first_axes[col] = ax
|
3899
4180
|
for r in range(row, row + rowspan):
|
3900
4181
|
for c in range(col, col + colspan):
|
3901
4182
|
occupied.add((r, c))
|
@@ -3907,17 +4188,17 @@ def subplot(rows:int=2,
|
|
3907
4188
|
|
3908
4189
|
#! radar chart
|
3909
4190
|
def radar(
|
3910
|
-
data: pd.DataFrame,
|
3911
|
-
ylim=(0,100),
|
4191
|
+
data: pd.DataFrame,
|
4192
|
+
ylim=(0, 100),
|
3912
4193
|
color=get_color(5),
|
3913
4194
|
fontsize=10,
|
3914
|
-
fontcolor=
|
4195
|
+
fontcolor="k",
|
3915
4196
|
size=6,
|
3916
4197
|
linewidth=1,
|
3917
4198
|
linestyle="-",
|
3918
4199
|
alpha=0.5,
|
3919
4200
|
marker="o",
|
3920
|
-
edgecolor=
|
4201
|
+
edgecolor="none",
|
3921
4202
|
edge_linewidth=0,
|
3922
4203
|
bg_color="0.8",
|
3923
4204
|
bg_alpha=None,
|
@@ -3928,18 +4209,19 @@ def radar(
|
|
3928
4209
|
legend_fontsize=10,
|
3929
4210
|
grid_color="gray",
|
3930
4211
|
grid_alpha=0.5,
|
3931
|
-
grid_linestyle="--",
|
4212
|
+
grid_linestyle="--",
|
4213
|
+
grid_linewidth=0.5,
|
3932
4214
|
circular: bool = False,
|
3933
4215
|
tick_fontsize=None,
|
3934
4216
|
tick_fontcolor="0.65",
|
3935
|
-
tick_loc
|
3936
|
-
turning
|
4217
|
+
tick_loc=None, # label position
|
4218
|
+
turning=None,
|
3937
4219
|
ax=None,
|
3938
4220
|
sp=2,
|
3939
|
-
**kwargs
|
4221
|
+
**kwargs,
|
3940
4222
|
):
|
3941
4223
|
"""
|
3942
|
-
Example DATA:
|
4224
|
+
Example DATA:
|
3943
4225
|
df = pd.DataFrame(
|
3944
4226
|
data=[
|
3945
4227
|
[80, 80, 80, 80, 80, 80, 80],
|
@@ -3948,7 +4230,7 @@ def radar(
|
|
3948
4230
|
],
|
3949
4231
|
index=["Hero", "Warrior", "Wizard"],
|
3950
4232
|
columns=["HP", "MP", "ATK", "DEF", "SP.ATK", "SP.DEF", "SPD"])
|
3951
|
-
|
4233
|
+
|
3952
4234
|
Parameters:
|
3953
4235
|
- data (pd.DataFrame): The data to plot. Each column corresponds to a variable, and each row represents a data point.
|
3954
4236
|
- ylim (tuple): The limits of the radial axis (y-axis). Default is (0, 100).
|
@@ -4002,8 +4284,12 @@ def radar(
|
|
4002
4284
|
|
4003
4285
|
# bg_color
|
4004
4286
|
if bg_alpha is None:
|
4005
|
-
bg_alpha=alpha
|
4006
|
-
|
4287
|
+
bg_alpha = alpha
|
4288
|
+
(
|
4289
|
+
ax.set_facecolor(to_rgba(bg_color, alpha=bg_alpha))
|
4290
|
+
if circular
|
4291
|
+
else ax.set_facecolor("none")
|
4292
|
+
)
|
4007
4293
|
# Set up the radar chart with straight-line connections
|
4008
4294
|
ax.set_theta_offset(np.pi / 2)
|
4009
4295
|
ax.set_theta_direction(-1)
|
@@ -4012,53 +4298,68 @@ def radar(
|
|
4012
4298
|
ax.set_xticks(angles[:-1])
|
4013
4299
|
ax.set_xticklabels(categories)
|
4014
4300
|
|
4015
|
-
# Set y-axis limits and grid intervals
|
4301
|
+
# Set y-axis limits and grid intervals
|
4016
4302
|
vmin, vmax = ylim
|
4017
|
-
if circular:
|
4018
|
-
|
4019
|
-
ax.yaxis.set_ticks(np.arange(vmin, vmax+1, vmax * grid_interval_ratio))
|
4020
|
-
ax.grid(
|
4021
|
-
|
4022
|
-
|
4023
|
-
|
4024
|
-
|
4025
|
-
|
4026
|
-
|
4027
|
-
|
4028
|
-
|
4303
|
+
if circular:
|
4304
|
+
# * cicular style
|
4305
|
+
ax.yaxis.set_ticks(np.arange(vmin, vmax + 1, vmax * grid_interval_ratio))
|
4306
|
+
ax.grid(
|
4307
|
+
axis="both",
|
4308
|
+
color=grid_color,
|
4309
|
+
linestyle=grid_linestyle,
|
4310
|
+
alpha=grid_alpha,
|
4311
|
+
linewidth=grid_linewidth,
|
4312
|
+
dash_capstyle="round",
|
4313
|
+
dash_joinstyle="round",
|
4314
|
+
)
|
4315
|
+
ax.spines["polar"].set_color(grid_color)
|
4029
4316
|
ax.spines["polar"].set_linewidth(grid_linewidth)
|
4030
|
-
ax.spines["polar"].set_linestyle(
|
4317
|
+
ax.spines["polar"].set_linestyle("-")
|
4031
4318
|
ax.spines["polar"].set_alpha(grid_alpha)
|
4032
|
-
ax.spines["polar"].set_capstyle(
|
4033
|
-
ax.spines["polar"].set_joinstyle(
|
4319
|
+
ax.spines["polar"].set_capstyle("round")
|
4320
|
+
ax.spines["polar"].set_joinstyle("round")
|
4034
4321
|
|
4035
4322
|
else:
|
4036
|
-
|
4323
|
+
# * spider style: spider-style grid (straight lines, not circles)
|
4037
4324
|
# Create the spider-style grid (straight lines, not circles)
|
4038
|
-
for i in range(1, int(vmax * grid_interval_ratio) + 1):
|
4325
|
+
for i in range(1, int(vmax * grid_interval_ratio) + 1):
|
4039
4326
|
ax.plot(
|
4040
4327
|
angles + [angles[0]], # Closing the loop
|
4041
|
-
[i * vmax * grid_interval_ratio] * (num_vars+1)
|
4042
|
-
|
4328
|
+
[i * vmax * grid_interval_ratio] * (num_vars + 1)
|
4329
|
+
+ [i * vmax * grid_interval_ratio],
|
4330
|
+
color=grid_color,
|
4331
|
+
linestyle=grid_linestyle,
|
4332
|
+
alpha=grid_alpha,
|
4333
|
+
linewidth=grid_linewidth,
|
4043
4334
|
)
|
4044
|
-
# set bg_color
|
4045
|
-
ax.fill(angles, [vmax]*(data.shape[1]+1), color=bg_color, alpha=bg_alpha)
|
4335
|
+
# set bg_color
|
4336
|
+
ax.fill(angles, [vmax] * (data.shape[1] + 1), color=bg_color, alpha=bg_alpha)
|
4046
4337
|
ax.yaxis.grid(False)
|
4047
4338
|
# Move radial labels away from plotted line
|
4048
4339
|
if tick_loc is None:
|
4049
|
-
tick_loc
|
4340
|
+
tick_loc = (
|
4341
|
+
np.mean([angles[0], angles[1]]) / (2 * np.pi) * 360 if circular else 0
|
4342
|
+
)
|
4050
4343
|
|
4051
4344
|
ax.set_rlabel_position(tick_loc)
|
4052
4345
|
ax.set_theta_offset(turning) if turning is not None else None
|
4053
|
-
ax.tick_params(
|
4054
|
-
|
4055
|
-
|
4346
|
+
ax.tick_params(
|
4347
|
+
axis="x", labelsize=fontsize, colors=fontcolor
|
4348
|
+
) # Optional: for angular labels
|
4349
|
+
tick_fontsize = fontsize - 2 if fontsize is None else tick_fontsize
|
4350
|
+
ax.tick_params(
|
4351
|
+
axis="y", labelsize=tick_fontsize, colors=tick_fontcolor
|
4352
|
+
) # For radial labels
|
4056
4353
|
if not circular:
|
4057
|
-
ax.spines[
|
4058
|
-
ax.tick_params(axis=
|
4059
|
-
ax.tick_params(axis=
|
4354
|
+
ax.spines["polar"].set_visible(False)
|
4355
|
+
ax.tick_params(axis="x", pad=sp) # move spines outward
|
4356
|
+
ax.tick_params(axis="y", pad=sp) # move spines outward
|
4060
4357
|
# colors
|
4061
|
-
colors =
|
4358
|
+
colors = (
|
4359
|
+
get_color(data.shape[0])
|
4360
|
+
if cmap is None
|
4361
|
+
else plt.get_cmap(cmap)(np.linspace(0, 1, data.shape[0]))
|
4362
|
+
)
|
4062
4363
|
# Plot each row with straight lines
|
4063
4364
|
for i, (index, row) in enumerate(data.iterrows()):
|
4064
4365
|
values = row.tolist()
|
@@ -4070,7 +4371,7 @@ def radar(
|
|
4070
4371
|
linewidth=linewidth,
|
4071
4372
|
linestyle=linestyle,
|
4072
4373
|
label=index,
|
4073
|
-
clip_on=False
|
4374
|
+
clip_on=False,
|
4074
4375
|
)
|
4075
4376
|
ax.fill(angles, values, color=colors[i], alpha=alpha)
|
4076
4377
|
|
@@ -4084,13 +4385,23 @@ def radar(
|
|
4084
4385
|
marker=marker,
|
4085
4386
|
markersize=size,
|
4086
4387
|
markeredgecolor=edgecolor,
|
4087
|
-
markeredgewidth
|
4388
|
+
markeredgewidth=edge_linewidth,
|
4389
|
+
zorder=10,
|
4390
|
+
clip_on=False,
|
4088
4391
|
)
|
4089
4392
|
# ax.tick_params(axis='y', labelleft=False, left=False)
|
4090
|
-
if
|
4393
|
+
if "legend" in kws_figsets:
|
4091
4394
|
figsets(ax=ax, **kws_figsets)
|
4092
4395
|
else:
|
4093
|
-
|
4094
|
-
figsets(
|
4095
|
-
|
4096
|
-
|
4396
|
+
|
4397
|
+
figsets(
|
4398
|
+
ax=ax,
|
4399
|
+
legend=dict(
|
4400
|
+
loc=legend_loc,
|
4401
|
+
fontsize=legend_fontsize,
|
4402
|
+
bbox_to_anchor=[1.1, 1.4],
|
4403
|
+
ncols=2,
|
4404
|
+
),
|
4405
|
+
**kws_figsets,
|
4406
|
+
)
|
4407
|
+
return ax
|