py2ls 0.2.4.10.5__py3-none-any.whl → 0.2.4.10.7__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
py2ls/plot.py
CHANGED
@@ -1,21 +1,37 @@
|
|
1
1
|
import numpy as np
|
2
|
-
import pandas as pd
|
2
|
+
import pandas as pd
|
3
3
|
import matplotlib.pyplot as plt
|
4
4
|
import matplotlib
|
5
5
|
import seaborn as sns
|
6
6
|
import warnings
|
7
7
|
import logging
|
8
8
|
from typing import Union
|
9
|
-
from .ips import
|
9
|
+
from .ips import (
|
10
|
+
isa,
|
11
|
+
fsave,
|
12
|
+
fload,
|
13
|
+
mkdir,
|
14
|
+
listdir,
|
15
|
+
figsave,
|
16
|
+
strcmp,
|
17
|
+
unique,
|
18
|
+
get_os,
|
19
|
+
ssplit,
|
20
|
+
flatten,
|
21
|
+
plt_font,
|
22
|
+
run_once_within,
|
23
|
+
)
|
10
24
|
from .stats import *
|
11
25
|
import os
|
26
|
+
|
12
27
|
# Suppress INFO messages from fontTools
|
13
28
|
logging.getLogger("fontTools").setLevel(logging.ERROR)
|
14
|
-
logging.getLogger(
|
29
|
+
logging.getLogger("matplotlib").setLevel(logging.ERROR)
|
15
30
|
|
16
31
|
warnings.simplefilter("ignore", category=pd.errors.SettingWithCopyWarning)
|
17
32
|
warnings.filterwarnings("ignore", category=pd.errors.PerformanceWarning)
|
18
33
|
|
34
|
+
|
19
35
|
def add_text(ax=None, height_offset=0.5, fmt=".1f", **kwargs):
|
20
36
|
"""Adds text annotations for various types of Seaborn and Matplotlib plots.
|
21
37
|
Args:
|
@@ -449,7 +465,7 @@ def catplot(data, *args, **kwargs):
|
|
449
465
|
"""
|
450
466
|
from matplotlib.colors import to_rgba
|
451
467
|
import os
|
452
|
-
|
468
|
+
|
453
469
|
def plot_bars(data, data_m, opt_b, xloc, ax, label=None):
|
454
470
|
if "l" in opt_b["loc"]:
|
455
471
|
xloc_s = xloc - opt_b["x_dist"]
|
@@ -755,6 +771,7 @@ def catplot(data, *args, **kwargs):
|
|
755
771
|
)
|
756
772
|
else:
|
757
773
|
from scipy.stats import gaussian_kde
|
774
|
+
|
758
775
|
kde = gaussian_kde(ys, bw_method=opt_v["BandWidth"])
|
759
776
|
min_val, max_val = ys.min(), ys.max()
|
760
777
|
y_vals = np.linspace(min_val, max_val, opt_v["NumPoints"])
|
@@ -1416,7 +1433,9 @@ def catplot(data, *args, **kwargs):
|
|
1416
1433
|
style_load = fload(dir_style + style_use + ".json")
|
1417
1434
|
else:
|
1418
1435
|
style_load = fload(
|
1419
|
-
listdir(dir_style, "json").path.tolist()[
|
1436
|
+
listdir(dir_style, "json", verbose=False).path.tolist()[
|
1437
|
+
style_use
|
1438
|
+
]
|
1420
1439
|
)
|
1421
1440
|
style_load = remove_colors_in_dict(style_load)
|
1422
1441
|
opt.update(style_load)
|
@@ -1812,16 +1831,17 @@ def read_mplstyle(style_file):
|
|
1812
1831
|
return style_dict
|
1813
1832
|
|
1814
1833
|
|
1815
|
-
def figsets(*args, **kwargs):
|
1834
|
+
def figsets(*args, **kwargs):
|
1816
1835
|
import matplotlib
|
1817
1836
|
from cycler import cycler
|
1837
|
+
|
1818
1838
|
matplotlib.rc("text", usetex=False)
|
1819
1839
|
|
1820
1840
|
fig = plt.gcf()
|
1821
|
-
fontsize = kwargs.get("fontsize",11)
|
1822
|
-
plt.rcParams["font.size"]=fontsize
|
1823
|
-
fontname = kwargs.pop("fontname","Arial")
|
1824
|
-
fontname=plt_font(fontname)
|
1841
|
+
fontsize = kwargs.get("fontsize", 11)
|
1842
|
+
plt.rcParams["font.size"] = fontsize
|
1843
|
+
fontname = kwargs.pop("fontname", "Arial")
|
1844
|
+
fontname = plt_font(fontname) # display Chinese characters
|
1825
1845
|
|
1826
1846
|
sns_themes = ["white", "whitegrid", "dark", "darkgrid", "ticks"]
|
1827
1847
|
sns_contexts = ["notebook", "talk", "poster"] # now available "paper"
|
@@ -1851,18 +1871,21 @@ def figsets(*args, **kwargs):
|
|
1851
1871
|
nonlocal fontsize, fontname
|
1852
1872
|
if ("fo" in key) and (("size" in key) or ("sz" in key)):
|
1853
1873
|
fontsize = value
|
1854
|
-
plt.rcParams.update(
|
1855
|
-
|
1856
|
-
|
1857
|
-
|
1858
|
-
|
1859
|
-
|
1860
|
-
|
1861
|
-
|
1862
|
-
|
1874
|
+
plt.rcParams.update(
|
1875
|
+
{
|
1876
|
+
"font.size": fontsize,
|
1877
|
+
"figure.titlesize": fontsize,
|
1878
|
+
"axes.titlesize": fontsize,
|
1879
|
+
"axes.labelsize": fontsize,
|
1880
|
+
"xtick.labelsize": fontsize,
|
1881
|
+
"ytick.labelsize": fontsize,
|
1882
|
+
"legend.fontsize": fontsize,
|
1883
|
+
"legend.title_fontsize": fontsize,
|
1884
|
+
}
|
1885
|
+
)
|
1863
1886
|
|
1864
1887
|
# Customize tick labels
|
1865
|
-
ax.tick_params(axis=
|
1888
|
+
ax.tick_params(axis="both", which="major", labelsize=fontsize)
|
1866
1889
|
for label in ax.get_xticklabels() + ax.get_yticklabels():
|
1867
1890
|
label.set_fontname(fontname)
|
1868
1891
|
|
@@ -1910,15 +1933,15 @@ def figsets(*args, **kwargs):
|
|
1910
1933
|
if ("x" in key.lower()) and (
|
1911
1934
|
"tic" not in key.lower() and "tk" not in key.lower()
|
1912
1935
|
):
|
1913
|
-
ax.set_xlabel(value, fontname=fontname,fontsize=fontsize)
|
1936
|
+
ax.set_xlabel(value, fontname=fontname, fontsize=fontsize)
|
1914
1937
|
if ("y" in key.lower()) and (
|
1915
1938
|
"tic" not in key.lower() and "tk" not in key.lower()
|
1916
1939
|
):
|
1917
|
-
ax.set_ylabel(value, fontname=fontname,fontsize=fontsize)
|
1940
|
+
ax.set_ylabel(value, fontname=fontname, fontsize=fontsize)
|
1918
1941
|
if ("z" in key.lower()) and (
|
1919
1942
|
"tic" not in key.lower() and "tk" not in key.lower()
|
1920
1943
|
):
|
1921
|
-
ax.set_zlabel(value, fontname=fontname,fontsize=fontsize)
|
1944
|
+
ax.set_zlabel(value, fontname=fontname, fontsize=fontsize)
|
1922
1945
|
if key == "xlabel" and isinstance(value, dict):
|
1923
1946
|
ax.set_xlabel(**value)
|
1924
1947
|
if key == "ylabel" and isinstance(value, dict):
|
@@ -2110,6 +2133,7 @@ def figsets(*args, **kwargs):
|
|
2110
2133
|
|
2111
2134
|
if "mi" in key.lower() and "tic" in key.lower(): # minor_ticks
|
2112
2135
|
import matplotlib.ticker as tck
|
2136
|
+
|
2113
2137
|
if "x" in value.lower() or "x" in key.lower():
|
2114
2138
|
ax.xaxis.set_minor_locator(tck.AutoMinorLocator()) # ax.minorticks_on()
|
2115
2139
|
if "y" in value.lower() or "y" in key.lower():
|
@@ -2162,9 +2186,9 @@ def figsets(*args, **kwargs):
|
|
2162
2186
|
ax.grid(visible=False)
|
2163
2187
|
if "tit" in key.lower():
|
2164
2188
|
if "sup" in key.lower():
|
2165
|
-
plt.suptitle(value,fontname=fontname,fontsize=fontsize)
|
2189
|
+
plt.suptitle(value, fontname=fontname, fontsize=fontsize)
|
2166
2190
|
else:
|
2167
|
-
ax.set_title(value,fontname=fontname,fontsize=fontsize)
|
2191
|
+
ax.set_title(value, fontname=fontname, fontsize=fontsize)
|
2168
2192
|
if key.lower() in ["spine", "adjust", "ad", "sp", "spi", "adj", "spines"]:
|
2169
2193
|
if isinstance(value, bool) or (value in ["go", "do", "ja", "yes"]):
|
2170
2194
|
if value:
|
@@ -2185,9 +2209,14 @@ def figsets(*args, **kwargs):
|
|
2185
2209
|
legend = ax.get_legend()
|
2186
2210
|
if legend is not None:
|
2187
2211
|
legend.remove()
|
2188
|
-
if
|
2212
|
+
if (
|
2213
|
+
any(["colorbar" in key.lower(), "cbar" in key.lower()])
|
2214
|
+
and "loc" in key.lower()
|
2215
|
+
):
|
2189
2216
|
cbar = ax.collections[0].colorbar # Access the colorbar from the plot
|
2190
|
-
cbar.ax.set_position(
|
2217
|
+
cbar.ax.set_position(
|
2218
|
+
value
|
2219
|
+
) # [left, bottom, width, height] [0.475, 0.15, 0.04, 0.25]
|
2191
2220
|
|
2192
2221
|
for arg in args:
|
2193
2222
|
if isinstance(arg, matplotlib.axes._axes.Axes):
|
@@ -2225,7 +2254,8 @@ def figsets(*args, **kwargs):
|
|
2225
2254
|
if len(fig.get_axes()) > 1:
|
2226
2255
|
plt.tight_layout()
|
2227
2256
|
|
2228
|
-
|
2257
|
+
|
2258
|
+
def split_legend(ax, n=2, loc=None, title=None, bbox=None, ncol=1, **kwargs):
|
2229
2259
|
"""
|
2230
2260
|
split_legend(
|
2231
2261
|
ax,
|
@@ -2238,15 +2268,15 @@ def split_legend(ax, n=2, loc=None, title=None, bbox=None,ncol=1, **kwargs):
|
|
2238
2268
|
# Retrieve all lines and labels from the axis
|
2239
2269
|
handles, labels = ax.get_legend_handles_labels()
|
2240
2270
|
num_labels = len(labels)
|
2241
|
-
|
2271
|
+
|
2242
2272
|
# Calculate the number of labels per legend part
|
2243
|
-
labels_per_part = (num_labels + n - 1) // n # Round up
|
2273
|
+
labels_per_part = (num_labels + n - 1) // n # Round up
|
2244
2274
|
# Create a list to hold each legend object
|
2245
2275
|
legends = []
|
2246
2276
|
|
2247
2277
|
# Default locations and titles if not specified
|
2248
2278
|
if loc is None:
|
2249
|
-
loc = [
|
2279
|
+
loc = ["best"] * n
|
2250
2280
|
if title is None:
|
2251
2281
|
title = [None] * n
|
2252
2282
|
if bbox is None:
|
@@ -2257,7 +2287,7 @@ def split_legend(ax, n=2, loc=None, title=None, bbox=None,ncol=1, **kwargs):
|
|
2257
2287
|
# Calculate the range of labels for this part
|
2258
2288
|
start_idx = i * labels_per_part
|
2259
2289
|
end_idx = min(start_idx + labels_per_part, num_labels)
|
2260
|
-
|
2290
|
+
|
2261
2291
|
# Skip if no labels in this range
|
2262
2292
|
if start_idx >= end_idx:
|
2263
2293
|
break
|
@@ -2267,19 +2297,24 @@ def split_legend(ax, n=2, loc=None, title=None, bbox=None,ncol=1, **kwargs):
|
|
2267
2297
|
part_labels = labels[start_idx:end_idx]
|
2268
2298
|
|
2269
2299
|
# Create the legend for this part
|
2270
|
-
legend = ax.legend(
|
2271
|
-
|
2272
|
-
|
2273
|
-
|
2274
|
-
|
2275
|
-
|
2276
|
-
|
2277
|
-
|
2300
|
+
legend = ax.legend(
|
2301
|
+
handles=part_handles,
|
2302
|
+
labels=part_labels,
|
2303
|
+
loc=loc[i],
|
2304
|
+
title=title[i],
|
2305
|
+
ncol=ncol,
|
2306
|
+
bbox_to_anchor=bbox[i],
|
2307
|
+
**kwargs,
|
2308
|
+
)
|
2309
|
+
|
2278
2310
|
# Add the legend to the axis and save it to the list
|
2279
|
-
|
2311
|
+
(
|
2312
|
+
ax.add_artist(legend) if i != (n - 1) else None
|
2313
|
+
) # the last one will be added automatically
|
2280
2314
|
legends.append(legend)
|
2281
2315
|
return legends
|
2282
2316
|
|
2317
|
+
|
2283
2318
|
def get_colors(
|
2284
2319
|
n: int = 1,
|
2285
2320
|
cmap: str = "auto",
|
@@ -2289,7 +2324,9 @@ def get_colors(
|
|
2289
2324
|
*args,
|
2290
2325
|
**kwargs,
|
2291
2326
|
):
|
2292
|
-
return get_color(n,cmap,alpha,output
|
2327
|
+
return get_color(n, cmap, alpha, output, *args, **kwargs)
|
2328
|
+
|
2329
|
+
|
2293
2330
|
def get_color(
|
2294
2331
|
n: int = 1,
|
2295
2332
|
cmap: str = "auto",
|
@@ -2300,6 +2337,7 @@ def get_color(
|
|
2300
2337
|
**kwargs,
|
2301
2338
|
):
|
2302
2339
|
from cycler import cycler
|
2340
|
+
|
2303
2341
|
def cmap2hex(cmap_name):
|
2304
2342
|
cmap_ = matplotlib.pyplot.get_cmap(cmap_name)
|
2305
2343
|
colors = [cmap_(i) for i in range(cmap_.N)]
|
@@ -2347,49 +2385,137 @@ def get_color(
|
|
2347
2385
|
return "#{:02X}{:02X}{:02X}{:02X}".format(r, g, b, a)
|
2348
2386
|
else:
|
2349
2387
|
return "#{:02X}{:02X}{:02X}".format(r, g, b)
|
2350
|
-
|
2388
|
+
|
2351
2389
|
# sc.pl.palettes.default_20
|
2352
|
-
cmap_20= [
|
2353
|
-
|
2354
|
-
|
2390
|
+
cmap_20 = [
|
2391
|
+
"#1f77b4",
|
2392
|
+
"#ff7f0e",
|
2393
|
+
"#279e68",
|
2394
|
+
"#d62728",
|
2395
|
+
"#aa40fc",
|
2396
|
+
"#8c564b",
|
2397
|
+
"#e377c2",
|
2398
|
+
"#b5bd61",
|
2399
|
+
"#17becf",
|
2400
|
+
"#aec7e8",
|
2401
|
+
"#ffbb78",
|
2402
|
+
"#98df8a",
|
2403
|
+
"#ff9896",
|
2404
|
+
"#c5b0d5",
|
2405
|
+
"#c49c94",
|
2406
|
+
"#f7b6d2",
|
2407
|
+
"#dbdb8d",
|
2408
|
+
"#9edae5",
|
2409
|
+
"#ad494a",
|
2410
|
+
"#8c6d31",
|
2411
|
+
]
|
2355
2412
|
# sc.pl.palettes.zeileis_28
|
2356
|
-
cmap_28 = [
|
2357
|
-
|
2358
|
-
|
2359
|
-
|
2413
|
+
cmap_28 = [
|
2414
|
+
"#023fa5",
|
2415
|
+
"#7d87b9",
|
2416
|
+
"#bec1d4",
|
2417
|
+
"#d6bcc0",
|
2418
|
+
"#bb7784",
|
2419
|
+
"#8e063b",
|
2420
|
+
"#4a6fe3",
|
2421
|
+
"#8595e1",
|
2422
|
+
"#b5bbe3",
|
2423
|
+
"#e6afb9",
|
2424
|
+
"#e07b91",
|
2425
|
+
"#d33f6a",
|
2426
|
+
"#11c638",
|
2427
|
+
"#8dd593",
|
2428
|
+
"#c6dec7",
|
2429
|
+
"#ead3c6",
|
2430
|
+
"#f0b98d",
|
2431
|
+
"#ef9708",
|
2432
|
+
"#0fcfc0",
|
2433
|
+
"#9cded6",
|
2434
|
+
"#d5eae7",
|
2435
|
+
"#f3e1eb",
|
2436
|
+
"#f6c4e1",
|
2437
|
+
"#f79cd4",
|
2438
|
+
"#7f7f7f",
|
2439
|
+
"#c7c7c7",
|
2440
|
+
"#1CE6FF",
|
2441
|
+
"#336600",
|
2442
|
+
]
|
2360
2443
|
if cmap == "gray":
|
2361
2444
|
cmap = "grey"
|
2362
|
-
elif cmap=="20":
|
2363
|
-
cmap=cmap_20
|
2364
|
-
elif cmap=="28":
|
2365
|
-
cmap=cmap_28
|
2445
|
+
elif cmap == "20":
|
2446
|
+
cmap = cmap_20
|
2447
|
+
elif cmap == "28":
|
2448
|
+
cmap = cmap_28
|
2366
2449
|
# Determine color list based on cmap parameter
|
2367
|
-
if isinstance(cmap,str):
|
2450
|
+
if isinstance(cmap, str):
|
2368
2451
|
if "aut" in cmap:
|
2369
2452
|
if n == 1:
|
2370
2453
|
colorlist = ["#3A4453"]
|
2371
2454
|
elif n == 2:
|
2372
2455
|
colorlist = ["#3A4453", "#FF2C00"]
|
2373
2456
|
elif n == 3:
|
2374
|
-
colorlist = ["#66c2a5","#fc8d62","#8da0cb"]
|
2457
|
+
colorlist = ["#66c2a5", "#fc8d62", "#8da0cb"]
|
2375
2458
|
elif n == 4:
|
2376
|
-
colorlist = ["#FF2C00","#087cf7", "#FBAF63", "#3C898A"]
|
2459
|
+
colorlist = ["#FF2C00", "#087cf7", "#FBAF63", "#3C898A"]
|
2377
2460
|
elif n == 5:
|
2378
|
-
colorlist = ["#FF2C00","#459AA9", "#B25E9D", "#4B8C3B","#EF8632"]
|
2461
|
+
colorlist = ["#FF2C00", "#459AA9", "#B25E9D", "#4B8C3B", "#EF8632"]
|
2379
2462
|
elif n == 6:
|
2380
|
-
colorlist = [
|
2381
|
-
|
2382
|
-
|
2383
|
-
|
2463
|
+
colorlist = [
|
2464
|
+
"#FF2C00",
|
2465
|
+
"#91bfdb",
|
2466
|
+
"#B25E9D",
|
2467
|
+
"#4B8C3B",
|
2468
|
+
"#EF8632",
|
2469
|
+
"#24578E",
|
2470
|
+
]
|
2471
|
+
elif n == 7:
|
2472
|
+
colorlist = [
|
2473
|
+
"#7F7F7F",
|
2474
|
+
"#459AA9",
|
2475
|
+
"#B25E9D",
|
2476
|
+
"#4B8C3B",
|
2477
|
+
"#EF8632",
|
2478
|
+
"#24578E" "#FF2C00",
|
2479
|
+
]
|
2480
|
+
elif n == 8:
|
2384
2481
|
# colorlist = ['#1f77b4','#ff7f0e','#367B7F','#51B34F','#d62728','#aa40fc','#e377c2','#17becf']
|
2385
2482
|
# colorlist = ["#367C7E","#51B34F","#881A11","#E9374C","#EF893C","#010072","#385DCB","#EA43E3"]
|
2386
|
-
colorlist = [
|
2387
|
-
|
2388
|
-
|
2389
|
-
|
2390
|
-
|
2391
|
-
|
2392
|
-
|
2483
|
+
colorlist = [
|
2484
|
+
"#78BFDA",
|
2485
|
+
"#D52E6F",
|
2486
|
+
"#F7D648",
|
2487
|
+
"#A52D28",
|
2488
|
+
"#6B9F41",
|
2489
|
+
"#E18330",
|
2490
|
+
"#E18B9D",
|
2491
|
+
"#3C88CC",
|
2492
|
+
]
|
2493
|
+
elif n == 9:
|
2494
|
+
colorlist = [
|
2495
|
+
"#1f77b4",
|
2496
|
+
"#ff7f0e",
|
2497
|
+
"#367B7F",
|
2498
|
+
"#ff9896",
|
2499
|
+
"#d62728",
|
2500
|
+
"#aa40fc",
|
2501
|
+
"#e377c2",
|
2502
|
+
"#51B34F",
|
2503
|
+
"#17becf",
|
2504
|
+
]
|
2505
|
+
elif n == 10:
|
2506
|
+
colorlist = [
|
2507
|
+
"#1f77b4",
|
2508
|
+
"#ff7f0e",
|
2509
|
+
"#367B7F",
|
2510
|
+
"#ff9896",
|
2511
|
+
"#51B34F",
|
2512
|
+
"#d62728" "#aa40fc",
|
2513
|
+
"#e377c2",
|
2514
|
+
"#375FD2",
|
2515
|
+
"#17becf",
|
2516
|
+
]
|
2517
|
+
elif 10 < n <= 20:
|
2518
|
+
colorlist = cmap_20
|
2393
2519
|
else:
|
2394
2520
|
colorlist = cmap_28
|
2395
2521
|
by = "start"
|
@@ -2426,7 +2552,7 @@ def get_color(
|
|
2426
2552
|
by = "linspace"
|
2427
2553
|
colorlist = cmap2hex(cmap)
|
2428
2554
|
elif isinstance(cmap, list):
|
2429
|
-
colorlist=cmap
|
2555
|
+
colorlist = cmap
|
2430
2556
|
|
2431
2557
|
# Determine method for generating color list
|
2432
2558
|
if "st" in by.lower() or "be" in by.lower():
|
@@ -2446,7 +2572,7 @@ def get_color(
|
|
2446
2572
|
return hue_list
|
2447
2573
|
else:
|
2448
2574
|
raise ValueError("Invalid output type. Choose 'rgb' or 'hue'.")
|
2449
|
-
|
2575
|
+
|
2450
2576
|
|
2451
2577
|
def stdshade(ax=None, *args, **kwargs):
|
2452
2578
|
"""
|
@@ -2454,7 +2580,7 @@ def stdshade(ax=None, *args, **kwargs):
|
|
2454
2580
|
plot.stdshade(data_array, c=clist[1], lw=2, ls="-.", alpha=0.2)
|
2455
2581
|
"""
|
2456
2582
|
from scipy.signal import savgol_filter
|
2457
|
-
|
2583
|
+
|
2458
2584
|
# Separate kws_line and kws_fill if necessary
|
2459
2585
|
kws_line = kwargs.pop("kws_line", {})
|
2460
2586
|
kws_fill = kwargs.pop("kws_fill", {})
|
@@ -2724,37 +2850,47 @@ def adjust_spines(ax=None, spines=["left", "bottom"], distance=2):
|
|
2724
2850
|
# cax = fig.add_axes([l + w + pad, b, width, h]) # define cbar Axes
|
2725
2851
|
# return fig.colorbar(im, cax=cax, **kwargs) # draw cbar
|
2726
2852
|
|
2727
|
-
|
2728
|
-
|
2729
|
-
|
2730
|
-
|
2731
|
-
|
2732
|
-
|
2733
|
-
|
2734
|
-
|
2735
|
-
|
2853
|
+
|
2854
|
+
def add_colorbar(
|
2855
|
+
im,
|
2856
|
+
cmap="viridis",
|
2857
|
+
vmin=-1,
|
2858
|
+
vmax=1,
|
2859
|
+
orientation="vertical",
|
2860
|
+
width_ratio=0.05,
|
2861
|
+
pad_ratio=0.02,
|
2862
|
+
shrink=1.0,
|
2863
|
+
**kwargs,
|
2864
|
+
):
|
2736
2865
|
import matplotlib as mpl
|
2866
|
+
|
2737
2867
|
if all([cmap, vmin, vmax]):
|
2738
2868
|
norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
|
2739
2869
|
else:
|
2740
|
-
norm=False
|
2870
|
+
norm = False
|
2741
2871
|
sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
|
2742
2872
|
sm.set_array([])
|
2743
2873
|
l, b, w, h = im.axes.get_position().bounds # position: left, bottom, width, height
|
2744
|
-
if orientation ==
|
2874
|
+
if orientation == "vertical":
|
2745
2875
|
width = width_ratio * w
|
2746
2876
|
pad = pad_ratio * w
|
2747
|
-
cax = im.figure.add_axes(
|
2877
|
+
cax = im.figure.add_axes(
|
2878
|
+
[l + w + pad, b, width, h * shrink]
|
2879
|
+
) # Right of the image
|
2748
2880
|
else:
|
2749
2881
|
height = width_ratio * h
|
2750
2882
|
pad = pad_ratio * h
|
2751
|
-
cax = im.figure.add_axes(
|
2752
|
-
|
2753
|
-
|
2883
|
+
cax = im.figure.add_axes(
|
2884
|
+
[l, b - height - pad, w * shrink, height]
|
2885
|
+
) # Below the image
|
2886
|
+
cbar = im.figure.colorbar(sm, cax=cax, orientation=orientation, **kwargs)
|
2887
|
+
return cbar
|
2888
|
+
|
2754
2889
|
|
2755
2890
|
# Usage:
|
2756
2891
|
# add_colorbar(im, width_ratio=0.03, pad_ratio=0.01, orientation='horizontal', label="PSD (dB)")
|
2757
2892
|
|
2893
|
+
|
2758
2894
|
def generate_xticks_with_gap(x_len, hue_len):
|
2759
2895
|
"""
|
2760
2896
|
Generate a concatenated array based on x_len and hue_len,
|
@@ -2884,6 +3020,7 @@ def style_examples(
|
|
2884
3020
|
f = listdir(
|
2885
3021
|
"/Users/macjianfeng/Dropbox/github/python/py2ls/.venv/lib/python3.12/site-packages/py2ls/data/styles/",
|
2886
3022
|
kind=".json",
|
3023
|
+
verbose=False,
|
2887
3024
|
)
|
2888
3025
|
display(f.sample(2))
|
2889
3026
|
# def style_example(dir_save,)
|
@@ -2961,6 +3098,7 @@ def thumbnail(dir_img_list: list, figsize=(10, 10), dpi=100, show=False, verbose
|
|
2961
3098
|
|
2962
3099
|
def get_params_from_func_usage(function_signature):
|
2963
3100
|
import re
|
3101
|
+
|
2964
3102
|
# Regular expression to match parameter names, ignoring '*' and '**kwargs'
|
2965
3103
|
keys_pattern = r"(?<!\*\*)\b(\w+)="
|
2966
3104
|
# Find all matches
|
@@ -2987,7 +3125,7 @@ def plotxy(
|
|
2987
3125
|
x=None,
|
2988
3126
|
y=None,
|
2989
3127
|
ax=None,
|
2990
|
-
kind: Union[str, list] =
|
3128
|
+
kind: Union[str, list] = "scatter", # Specify the kind of plot
|
2991
3129
|
verbose=False,
|
2992
3130
|
**kwargs,
|
2993
3131
|
):
|
@@ -3011,14 +3149,15 @@ def plotxy(
|
|
3011
3149
|
# Check for valid plot kind
|
3012
3150
|
# Default arguments for various plot types
|
3013
3151
|
from pathlib import Path
|
3152
|
+
|
3014
3153
|
# Get the current script's directory as a Path object
|
3015
|
-
current_directory = Path(__file__).resolve().parent
|
3154
|
+
current_directory = Path(__file__).resolve().parent
|
3155
|
+
|
3156
|
+
if not "default_settings" in locals():
|
3157
|
+
default_settings = fload(current_directory / "data" / "usages_sns.json")
|
3158
|
+
if not "sns_info" in locals():
|
3159
|
+
sns_info = pd.DataFrame(fload(current_directory / "data" / "sns_info.json"))
|
3016
3160
|
|
3017
|
-
if not 'default_settings' in locals():
|
3018
|
-
default_settings = fload(current_directory / 'data' / 'usages_sns.json')
|
3019
|
-
if not 'sns_info' in locals():
|
3020
|
-
sns_info = pd.DataFrame(fload(current_directory / 'data' / 'sns_info.json'))
|
3021
|
-
|
3022
3161
|
valid_kinds = list(default_settings.keys())
|
3023
3162
|
# print(valid_kinds)
|
3024
3163
|
if kind is not None:
|
@@ -3032,7 +3171,7 @@ def plotxy(
|
|
3032
3171
|
if kind is not None:
|
3033
3172
|
for k in kind:
|
3034
3173
|
if k in valid_kinds:
|
3035
|
-
print(f"{k}:\n\t{default_settings[k]}")
|
3174
|
+
print(f"{k}:\n\t{default_settings[k]}")
|
3036
3175
|
usage_str = """plotxy(data=ranked_genes,
|
3037
3176
|
x="log2(fold_change)",
|
3038
3177
|
y="-log10(p-value)",
|
@@ -3057,15 +3196,25 @@ def plotxy(
|
|
3057
3196
|
kws_add_text = v_arg
|
3058
3197
|
kwargs.pop(k_arg, None)
|
3059
3198
|
break
|
3060
|
-
zorder=0
|
3199
|
+
zorder = 0
|
3061
3200
|
for k in kind:
|
3062
|
-
zorder+=1
|
3201
|
+
zorder += 1
|
3063
3202
|
# indicate 'col' features
|
3064
3203
|
col = kwargs.get("col", None)
|
3065
|
-
sns_with_col = [
|
3204
|
+
sns_with_col = [
|
3205
|
+
"catplot",
|
3206
|
+
"histplot",
|
3207
|
+
"relplot",
|
3208
|
+
"lmplot",
|
3209
|
+
"pairplot",
|
3210
|
+
"displot",
|
3211
|
+
"kdeplot",
|
3212
|
+
]
|
3066
3213
|
if col is not None:
|
3067
3214
|
if not k in sns_with_col:
|
3068
|
-
print(
|
3215
|
+
print(
|
3216
|
+
f"tips:\n'{k}' has no 'col' param, you could try with {sns_with_col}"
|
3217
|
+
)
|
3069
3218
|
# (1) return FacetGrid
|
3070
3219
|
if k == "jointplot":
|
3071
3220
|
kws_joint = kwargs.pop("kws_joint", kwargs)
|
@@ -3092,77 +3241,98 @@ def plotxy(
|
|
3092
3241
|
ax = stdshade(ax=ax, **kwargs)
|
3093
3242
|
elif k == "scatterplot":
|
3094
3243
|
kws_scatter = kwargs.pop("kws_scatter", kwargs)
|
3095
|
-
kws_scatter={
|
3244
|
+
kws_scatter = {
|
3245
|
+
k: v for k, v in kws_scatter.items() if not k.startswith("kws_")
|
3246
|
+
}
|
3096
3247
|
hue = kwargs.pop("hue", None)
|
3097
3248
|
if isinstance(kws_scatter, dict): # Check if kws_scatter is a dictionary
|
3098
3249
|
kws_scatter.pop("hue", None) # Safely remove 'hue' if it exists
|
3099
|
-
palette = kws_scatter.pop(
|
3250
|
+
palette = kws_scatter.pop(
|
3251
|
+
"palette", get_color(data[hue].nunique()) if hue is not None else None
|
3252
|
+
)
|
3100
3253
|
s = kws_scatter.pop("s", 10)
|
3101
3254
|
alpha = kws_scatter.pop("alpha", 0.7)
|
3102
|
-
ax = sns.scatterplot(
|
3255
|
+
ax = sns.scatterplot(
|
3256
|
+
ax=ax,
|
3257
|
+
data=data,
|
3258
|
+
x=x,
|
3259
|
+
y=y,
|
3260
|
+
hue=hue,
|
3261
|
+
palette=palette,
|
3262
|
+
s=s,
|
3263
|
+
alpha=alpha,
|
3264
|
+
zorder=zorder,
|
3265
|
+
**kws_scatter,
|
3266
|
+
)
|
3103
3267
|
elif k == "histplot":
|
3104
3268
|
kws_hist = kwargs.pop("kws_hist", kwargs)
|
3105
|
-
kws_hist={k: v for k, v in kws_hist.items() if not k.startswith("kws_")}
|
3106
|
-
ax = sns.histplot(data=data, x=x, ax=ax,zorder=zorder, **kws_hist)
|
3269
|
+
kws_hist = {k: v for k, v in kws_hist.items() if not k.startswith("kws_")}
|
3270
|
+
ax = sns.histplot(data=data, x=x, ax=ax, zorder=zorder, **kws_hist)
|
3107
3271
|
elif k == "kdeplot":
|
3108
3272
|
kws_kde = kwargs.pop("kws_kde", kwargs)
|
3109
|
-
kws_kde={k: v for k, v in kws_kde.items() if not k.startswith("kws_")}
|
3110
|
-
ax = sns.kdeplot(data=data, x=x, ax=ax,zorder=zorder, **kws_kde)
|
3273
|
+
kws_kde = {k: v for k, v in kws_kde.items() if not k.startswith("kws_")}
|
3274
|
+
ax = sns.kdeplot(data=data, x=x, ax=ax, zorder=zorder, **kws_kde)
|
3111
3275
|
elif k == "ecdfplot":
|
3112
3276
|
kws_ecdf = kwargs.pop("kws_ecdf", kwargs)
|
3113
|
-
kws_ecdf={k: v for k, v in kws_ecdf.items() if not k.startswith("kws_")}
|
3114
|
-
ax = sns.ecdfplot(data=data, x=x, ax=ax,zorder=zorder, **kws_ecdf)
|
3277
|
+
kws_ecdf = {k: v for k, v in kws_ecdf.items() if not k.startswith("kws_")}
|
3278
|
+
ax = sns.ecdfplot(data=data, x=x, ax=ax, zorder=zorder, **kws_ecdf)
|
3115
3279
|
elif k == "rugplot":
|
3116
3280
|
kws_rug = kwargs.pop("kws_rug", kwargs)
|
3117
|
-
kws_rug={k: v for k, v in kws_rug.items() if not k.startswith("kws_")}
|
3118
|
-
ax = sns.rugplot(data=data, x=x, ax=ax,zorder=zorder, **kws_rug)
|
3281
|
+
kws_rug = {k: v for k, v in kws_rug.items() if not k.startswith("kws_")}
|
3282
|
+
ax = sns.rugplot(data=data, x=x, ax=ax, zorder=zorder, **kws_rug)
|
3119
3283
|
elif k == "stripplot":
|
3120
3284
|
kws_strip = kwargs.pop("kws_strip", kwargs)
|
3121
|
-
kws_strip={k: v for k, v in kws_strip.items() if not k.startswith("kws_")}
|
3285
|
+
kws_strip = {k: v for k, v in kws_strip.items() if not k.startswith("kws_")}
|
3122
3286
|
dodge = kws_strip.pop("dodge", True)
|
3123
|
-
ax = sns.stripplot(
|
3287
|
+
ax = sns.stripplot(
|
3288
|
+
data=data, x=x, y=y, ax=ax, zorder=zorder, dodge=dodge, **kws_strip
|
3289
|
+
)
|
3124
3290
|
elif k == "swarmplot":
|
3125
3291
|
kws_swarm = kwargs.pop("kws_swarm", kwargs)
|
3126
|
-
kws_swarm={k: v for k, v in kws_swarm.items() if not k.startswith("kws_")}
|
3127
|
-
ax = sns.swarmplot(data=data, x=x, y=y, ax=ax,zorder=zorder, **kws_swarm)
|
3292
|
+
kws_swarm = {k: v for k, v in kws_swarm.items() if not k.startswith("kws_")}
|
3293
|
+
ax = sns.swarmplot(data=data, x=x, y=y, ax=ax, zorder=zorder, **kws_swarm)
|
3128
3294
|
elif k == "boxplot":
|
3129
3295
|
kws_box = kwargs.pop("kws_box", kwargs)
|
3130
|
-
kws_box={k: v for k, v in kws_box.items() if not k.startswith("kws_")}
|
3131
|
-
ax = sns.boxplot(data=data, x=x, y=y, ax=ax,zorder=zorder, **kws_box)
|
3296
|
+
kws_box = {k: v for k, v in kws_box.items() if not k.startswith("kws_")}
|
3297
|
+
ax = sns.boxplot(data=data, x=x, y=y, ax=ax, zorder=zorder, **kws_box)
|
3132
3298
|
elif k == "violinplot":
|
3133
3299
|
kws_violin = kwargs.pop("kws_violin", kwargs)
|
3134
|
-
kws_violin={
|
3135
|
-
|
3300
|
+
kws_violin = {
|
3301
|
+
k: v for k, v in kws_violin.items() if not k.startswith("kws_")
|
3302
|
+
}
|
3303
|
+
ax = sns.violinplot(data=data, x=x, y=y, ax=ax, zorder=zorder, **kws_violin)
|
3136
3304
|
elif k == "boxenplot":
|
3137
3305
|
kws_boxen = kwargs.pop("kws_boxen", kwargs)
|
3138
|
-
kws_boxen={k: v for k, v in kws_boxen.items() if not k.startswith("kws_")}
|
3139
|
-
ax = sns.boxenplot(data=data, x=x, y=y, ax=ax,zorder=zorder, **kws_boxen)
|
3306
|
+
kws_boxen = {k: v for k, v in kws_boxen.items() if not k.startswith("kws_")}
|
3307
|
+
ax = sns.boxenplot(data=data, x=x, y=y, ax=ax, zorder=zorder, **kws_boxen)
|
3140
3308
|
elif k == "pointplot":
|
3141
3309
|
kws_point = kwargs.pop("kws_point", kwargs)
|
3142
|
-
kws_point={k: v for k, v in kws_point.items() if not k.startswith("kws_")}
|
3143
|
-
ax = sns.pointplot(data=data, x=x, y=y, ax=ax,zorder=zorder, **kws_point)
|
3310
|
+
kws_point = {k: v for k, v in kws_point.items() if not k.startswith("kws_")}
|
3311
|
+
ax = sns.pointplot(data=data, x=x, y=y, ax=ax, zorder=zorder, **kws_point)
|
3144
3312
|
elif k == "barplot":
|
3145
3313
|
kws_bar = kwargs.pop("kws_bar", kwargs)
|
3146
|
-
kws_bar={k: v for k, v in kws_bar.items() if not k.startswith("kws_")}
|
3147
|
-
ax = sns.barplot(data=data, x=x, y=y, ax=ax,zorder=zorder, **kws_bar)
|
3314
|
+
kws_bar = {k: v for k, v in kws_bar.items() if not k.startswith("kws_")}
|
3315
|
+
ax = sns.barplot(data=data, x=x, y=y, ax=ax, zorder=zorder, **kws_bar)
|
3148
3316
|
elif k == "countplot":
|
3149
3317
|
kws_count = kwargs.pop("kws_count", kwargs)
|
3150
|
-
kws_count={k: v for k, v in kws_count.items() if not k.startswith("kws_")}
|
3151
|
-
if not kws_count.get("hue",None):
|
3152
|
-
kws_count.pop("palette",None)
|
3153
|
-
ax = sns.countplot(data=data, x=x,y=y, ax=ax,zorder=zorder, **kws_count)
|
3318
|
+
kws_count = {k: v for k, v in kws_count.items() if not k.startswith("kws_")}
|
3319
|
+
if not kws_count.get("hue", None):
|
3320
|
+
kws_count.pop("palette", None)
|
3321
|
+
ax = sns.countplot(data=data, x=x, y=y, ax=ax, zorder=zorder, **kws_count)
|
3154
3322
|
elif k == "regplot":
|
3155
3323
|
kws_reg = kwargs.pop("kws_reg", kwargs)
|
3156
|
-
kws_reg={k: v for k, v in kws_reg.items() if not k.startswith("kws_")}
|
3157
|
-
ax = sns.regplot(data=data, x=x, y=y, ax=ax,zorder=zorder, **kws_reg)
|
3324
|
+
kws_reg = {k: v for k, v in kws_reg.items() if not k.startswith("kws_")}
|
3325
|
+
ax = sns.regplot(data=data, x=x, y=y, ax=ax, zorder=zorder, **kws_reg)
|
3158
3326
|
elif k == "residplot":
|
3159
3327
|
kws_resid = kwargs.pop("kws_resid", kwargs)
|
3160
|
-
kws_resid={k: v for k, v in kws_resid.items() if not k.startswith("kws_")}
|
3161
|
-
ax = sns.residplot(
|
3328
|
+
kws_resid = {k: v for k, v in kws_resid.items() if not k.startswith("kws_")}
|
3329
|
+
ax = sns.residplot(
|
3330
|
+
data=data, x=x, y=y, lowess=True, zorder=zorder, ax=ax, **kws_resid
|
3331
|
+
)
|
3162
3332
|
elif k == "lineplot":
|
3163
3333
|
kws_line = kwargs.pop("kws_line", kwargs)
|
3164
|
-
kws_line={k: v for k, v in kws_line.items() if not k.startswith("kws_")}
|
3165
|
-
ax = sns.lineplot(ax=ax, data=data, x=x, y=y,zorder=zorder, **kws_line)
|
3334
|
+
kws_line = {k: v for k, v in kws_line.items() if not k.startswith("kws_")}
|
3335
|
+
ax = sns.lineplot(ax=ax, data=data, x=x, y=y, zorder=zorder, **kws_line)
|
3166
3336
|
|
3167
3337
|
figsets(ax=ax, **kws_figsets)
|
3168
3338
|
if kws_add_text:
|
@@ -3176,20 +3346,22 @@ def plotxy(
|
|
3176
3346
|
return g, ax
|
3177
3347
|
return ax
|
3178
3348
|
|
3349
|
+
|
3179
3350
|
def norm_cmap(data, cmap="coolwarm", min_max=[0, 1]):
|
3180
3351
|
norm_ = plt.Normalize(min_max[0], min_max[1])
|
3181
3352
|
colormap = plt.get_cmap(cmap)
|
3182
3353
|
return colormap(norm_(data))
|
3183
3354
|
|
3355
|
+
|
3184
3356
|
def volcano(
|
3185
|
-
data:pd.DataFrame,
|
3186
|
-
x:str,
|
3187
|
-
y:str,
|
3188
|
-
gene_col:str=None,
|
3189
|
-
top_genes=[5, 5],
|
3190
|
-
thr_x=np.log2(1.5),
|
3357
|
+
data: pd.DataFrame,
|
3358
|
+
x: str,
|
3359
|
+
y: str,
|
3360
|
+
gene_col: str = None,
|
3361
|
+
top_genes=[5, 5], # [down-regulated, up-regulated]
|
3362
|
+
thr_x=np.log2(1.5), # default: 0.585
|
3191
3363
|
thr_y=-np.log10(0.05),
|
3192
|
-
sort_xy="x",
|
3364
|
+
sort_xy="x", #'y', 'xy'
|
3193
3365
|
colors=("#00BFFF", "#9d9a9a", "#FF3030"),
|
3194
3366
|
s=20,
|
3195
3367
|
fill=True, # plot filled scatter
|
@@ -3201,11 +3373,10 @@ def volcano(
|
|
3201
3373
|
ax=None,
|
3202
3374
|
verbose=False,
|
3203
3375
|
kws_text=dict(fontsize=10, color="k"),
|
3204
|
-
kws_bbox=dict(
|
3205
|
-
|
3206
|
-
|
3207
|
-
|
3208
|
-
kws_arrow=dict(color="k", lw=0.5),# '{}' to hide
|
3376
|
+
kws_bbox=dict(
|
3377
|
+
facecolor="none", alpha=0.5, edgecolor="black", boxstyle="round,pad=0.3"
|
3378
|
+
), # '{}' to hide
|
3379
|
+
kws_arrow=dict(color="k", lw=0.5), # '{}' to hide
|
3209
3380
|
**kwargs,
|
3210
3381
|
):
|
3211
3382
|
"""
|
@@ -3275,39 +3446,35 @@ def volcano(
|
|
3275
3446
|
kws_figsets = v_arg
|
3276
3447
|
kwargs.pop(k_arg, None)
|
3277
3448
|
break
|
3278
|
-
|
3279
|
-
data=data.copy()
|
3449
|
+
|
3450
|
+
data = data.copy()
|
3280
3451
|
# filter nan
|
3281
3452
|
data = data.dropna(subset=[x, y]) # Drop rows with NaN in x or y
|
3282
|
-
data.loc[:,"color"] = np.where(
|
3453
|
+
data.loc[:, "color"] = np.where(
|
3283
3454
|
(data[x] > thr_x) & (data[y] > thr_y),
|
3284
3455
|
colors[2],
|
3285
|
-
np.where((data[x] < -thr_x) & (data[y] > thr_y),
|
3286
|
-
colors[0],
|
3287
|
-
colors[1]),
|
3456
|
+
np.where((data[x] < -thr_x) & (data[y] > thr_y), colors[0], colors[1]),
|
3288
3457
|
)
|
3289
|
-
top_genes=[top_genes, top_genes] if isinstance(top_genes,int) else top_genes
|
3290
|
-
|
3458
|
+
top_genes = [top_genes, top_genes] if isinstance(top_genes, int) else top_genes
|
3459
|
+
|
3291
3460
|
# how the top genes are selected can be customized; "x": x has priority
|
3292
|
-
sort_by_x_y=[x,y] if sort_xy=="x" else [y,x]
|
3293
|
-
ascending_up=[True, True] if sort_xy=="x" else [False, True]
|
3294
|
-
ascending_down=[False, True] if sort_xy=="x" else [False, False]
|
3295
|
-
|
3296
|
-
down_reg_genes =
|
3297
|
-
(data["color"] == colors[0]) &
|
3298
|
-
|
3299
|
-
(
|
3300
|
-
|
3301
|
-
up_reg_genes =
|
3302
|
-
(data["color"] == colors[2]) &
|
3303
|
-
|
3304
|
-
(
|
3305
|
-
|
3461
|
+
sort_by_x_y = [x, y] if sort_xy == "x" else [y, x]
|
3462
|
+
ascending_up = [True, True] if sort_xy == "x" else [False, True]
|
3463
|
+
ascending_down = [False, True] if sort_xy == "x" else [False, False]
|
3464
|
+
|
3465
|
+
down_reg_genes = (
|
3466
|
+
data[(data["color"] == colors[0]) & (data[x].abs() > thr_x) & (data[y] > thr_y)]
|
3467
|
+
.sort_values(by=sort_by_x_y, ascending=ascending_up)
|
3468
|
+
.head(top_genes[0])
|
3469
|
+
)
|
3470
|
+
up_reg_genes = (
|
3471
|
+
data[(data["color"] == colors[2]) & (data[x].abs() > thr_x) & (data[y] > thr_y)]
|
3472
|
+
.sort_values(by=sort_by_x_y, ascending=ascending_down)
|
3473
|
+
.head(top_genes[1])
|
3474
|
+
)
|
3306
3475
|
sele_gene = pd.concat([down_reg_genes, up_reg_genes])
|
3307
|
-
|
3308
|
-
palette = {colors[0]: colors[0],
|
3309
|
-
colors[1]: colors[1],
|
3310
|
-
colors[2]: colors[2]}
|
3476
|
+
|
3477
|
+
palette = {colors[0]: colors[0], colors[1]: colors[1], colors[2]: colors[2]}
|
3311
3478
|
# Plot setup
|
3312
3479
|
if ax is None:
|
3313
3480
|
ax = plt.gca()
|
@@ -3337,9 +3504,9 @@ def volcano(
|
|
3337
3504
|
)
|
3338
3505
|
|
3339
3506
|
# Add threshold lines for x and y axes
|
3340
|
-
ax.axhline(y=thr_y, color="black", linestyle="--",lw=1)
|
3341
|
-
ax.axvline(x=-thr_x, color="black", linestyle="--",lw=1)
|
3342
|
-
ax.axvline(x=thr_x, color="black", linestyle="--",lw=1)
|
3507
|
+
ax.axhline(y=thr_y, color="black", linestyle="--", lw=1)
|
3508
|
+
ax.axvline(x=-thr_x, color="black", linestyle="--", lw=1)
|
3509
|
+
ax.axvline(x=thr_x, color="black", linestyle="--", lw=1)
|
3343
3510
|
|
3344
3511
|
# Add gene labels for selected significant points
|
3345
3512
|
if gene_col:
|
@@ -3349,19 +3516,28 @@ def volcano(
|
|
3349
3516
|
textcolor = kws_text.pop("color", "k")
|
3350
3517
|
fontsize = kws_text.pop("fontsize", 10)
|
3351
3518
|
arrowstyles = [
|
3352
|
-
"->",
|
3353
|
-
"
|
3354
|
-
"
|
3519
|
+
"->",
|
3520
|
+
"<-",
|
3521
|
+
"<->",
|
3522
|
+
"<|-",
|
3523
|
+
"-|>",
|
3524
|
+
"<|-|>",
|
3525
|
+
"-",
|
3526
|
+
"-[",
|
3527
|
+
"-[",
|
3528
|
+
"fancy",
|
3529
|
+
"simple",
|
3530
|
+
"wedge",
|
3355
3531
|
]
|
3356
3532
|
arrowstyle = kws_arrow.pop("style", "<|-")
|
3357
|
-
arrowstyle = strcmp(arrowstyle, arrowstyles,scorer=
|
3358
|
-
expand=kws_arrow.pop("expand",(1.05,1.1))
|
3533
|
+
arrowstyle = strcmp(arrowstyle, arrowstyles, scorer="strict")[0]
|
3534
|
+
expand = kws_arrow.pop("expand", (1.05, 1.1))
|
3359
3535
|
arrowcolor = kws_arrow.pop("color", "0.4")
|
3360
3536
|
arrowlinewidth = kws_arrow.pop("lw", 0.75)
|
3361
3537
|
shrinkA = kws_arrow.pop("shrinkA", 0)
|
3362
3538
|
shrinkB = kws_arrow.pop("shrinkB", 0)
|
3363
3539
|
mutation_scale = kws_arrow.pop("head", 10)
|
3364
|
-
arrow_fill=kws_arrow.pop("fill", False)
|
3540
|
+
arrow_fill = kws_arrow.pop("fill", False)
|
3365
3541
|
for i in range(sele_gene.shape[0]):
|
3366
3542
|
if isinstance(textcolor, list): # be consistant with dots's color
|
3367
3543
|
textcolor = colors[0] if sele_gene[x].iloc[i] > 0 else colors[1]
|
@@ -3382,7 +3558,7 @@ def volcano(
|
|
3382
3558
|
adjust_text(
|
3383
3559
|
texts,
|
3384
3560
|
expand=expand,
|
3385
|
-
min_arrow_len=5,
|
3561
|
+
min_arrow_len=5,
|
3386
3562
|
ax=ax,
|
3387
3563
|
arrowprops=dict(
|
3388
3564
|
arrowstyle=arrowstyle,
|
@@ -3393,7 +3569,7 @@ def volcano(
|
|
3393
3569
|
shrinkB=shrinkB,
|
3394
3570
|
mutation_scale=mutation_scale,
|
3395
3571
|
**kws_arrow,
|
3396
|
-
)
|
3572
|
+
),
|
3397
3573
|
)
|
3398
3574
|
|
3399
3575
|
figsets(**kws_figsets)
|
@@ -3499,6 +3675,7 @@ def desaturate_color(color, saturation_factor=0.5):
|
|
3499
3675
|
"""Reduce the saturation of a color by a given factor (between 0 and 1)."""
|
3500
3676
|
import matplotlib.colors as mcolors
|
3501
3677
|
import colorsys
|
3678
|
+
|
3502
3679
|
# Convert the color to RGB
|
3503
3680
|
rgb = mcolors.to_rgb(color)
|
3504
3681
|
# Convert RGB to HLS (Hue, Lightness, Saturation)
|
@@ -3508,8 +3685,19 @@ def desaturate_color(color, saturation_factor=0.5):
|
|
3508
3685
|
# Convert back to RGB
|
3509
3686
|
return colorsys.hls_to_rgb(h, l, s)
|
3510
3687
|
|
3511
|
-
|
3512
|
-
|
3688
|
+
|
3689
|
+
def textsets(
|
3690
|
+
text,
|
3691
|
+
fontname="Arial",
|
3692
|
+
fontsize=11,
|
3693
|
+
fontweight="normal",
|
3694
|
+
fontstyle="normal",
|
3695
|
+
fontcolor="k",
|
3696
|
+
backgroundcolor=None,
|
3697
|
+
shadow=False,
|
3698
|
+
ha="center",
|
3699
|
+
va="center",
|
3700
|
+
):
|
3513
3701
|
if text: # Ensure text exists
|
3514
3702
|
if fontname:
|
3515
3703
|
text.set_fontname(plt_font(fontname))
|
@@ -3522,26 +3710,28 @@ def textsets(text, fontname='Arial', fontsize=11, fontweight="normal", fontstyle
|
|
3522
3710
|
if fontcolor:
|
3523
3711
|
text.set_color(fontcolor)
|
3524
3712
|
if backgroundcolor:
|
3525
|
-
text.set_backgroundcolor(backgroundcolor)
|
3713
|
+
text.set_backgroundcolor(backgroundcolor)
|
3526
3714
|
text.set_horizontalalignment(ha)
|
3527
|
-
text.set_verticalalignment(va)
|
3715
|
+
text.set_verticalalignment(va)
|
3528
3716
|
if shadow:
|
3529
|
-
text.set_path_effects(
|
3530
|
-
matplotlib.patheffects.withStroke(linewidth=3, foreground="gray")
|
3531
|
-
|
3717
|
+
text.set_path_effects(
|
3718
|
+
[matplotlib.patheffects.withStroke(linewidth=3, foreground="gray")]
|
3719
|
+
)
|
3720
|
+
|
3721
|
+
|
3532
3722
|
def venn(
|
3533
|
-
lists:list,
|
3534
|
-
labels:list=None,
|
3723
|
+
lists: list,
|
3724
|
+
labels: list = None,
|
3535
3725
|
ax=None,
|
3536
3726
|
colors=None,
|
3537
3727
|
edgecolor=None,
|
3538
3728
|
alpha=0.5,
|
3539
|
-
saturation
|
3540
|
-
linewidth=0,
|
3729
|
+
saturation=0.75,
|
3730
|
+
linewidth=0, # default no edge
|
3541
3731
|
linestyle="-",
|
3542
|
-
fontname=
|
3732
|
+
fontname="Arial",
|
3543
3733
|
fontsize=10,
|
3544
|
-
fontcolor=
|
3734
|
+
fontcolor="k",
|
3545
3735
|
fontweight="normal",
|
3546
3736
|
fontstyle="normal",
|
3547
3737
|
ha="center",
|
@@ -3550,18 +3740,18 @@ def venn(
|
|
3550
3740
|
subset_fontsize=10,
|
3551
3741
|
subset_fontweight="normal",
|
3552
3742
|
subset_fontstyle="normal",
|
3553
|
-
subset_fontcolor=
|
3743
|
+
subset_fontcolor="k",
|
3554
3744
|
backgroundcolor=None,
|
3555
3745
|
custom_texts=None,
|
3556
|
-
show_percentages=True,
|
3746
|
+
show_percentages=True, # display percentage
|
3557
3747
|
fmt="{:.1%}",
|
3558
3748
|
ellipse_shape=False, # 椭圆形
|
3559
|
-
ellipse_scale=[1.5, 1], #not perfect, 椭圆形的形状
|
3560
|
-
**kwargs
|
3749
|
+
ellipse_scale=[1.5, 1], # not perfect, 椭圆形的形状
|
3750
|
+
**kwargs,
|
3561
3751
|
):
|
3562
3752
|
"""
|
3563
3753
|
Advanced Venn diagram plotting function with extensive customization options.
|
3564
|
-
Usage:
|
3754
|
+
Usage:
|
3565
3755
|
# Define the two sets
|
3566
3756
|
set1 = [1, 2, 3, 4, 5]
|
3567
3757
|
set2 = [4, 5, 6, 7, 8]
|
@@ -3612,56 +3802,70 @@ def venn(
|
|
3612
3802
|
"""
|
3613
3803
|
if ax is None:
|
3614
3804
|
ax = plt.gca()
|
3615
|
-
lists=[set(flatten(i, verbose=False)) for i in lists]
|
3805
|
+
lists = [set(flatten(i, verbose=False)) for i in lists]
|
3616
3806
|
# Function to apply text styles to labels
|
3617
3807
|
if colors is None:
|
3618
|
-
colors=["r","b"] if len(lists)==2 else ["r","g","b"]
|
3808
|
+
colors = ["r", "b"] if len(lists) == 2 else ["r", "g", "b"]
|
3619
3809
|
if labels is None:
|
3620
|
-
labels=["set1","set2"] if len(lists)==2 else ["set1","set2","set3"]
|
3810
|
+
labels = ["set1", "set2"] if len(lists) == 2 else ["set1", "set2", "set3"]
|
3621
3811
|
if edgecolor is None:
|
3622
|
-
edgecolor=colors
|
3812
|
+
edgecolor = colors
|
3623
3813
|
colors = [desaturate_color(color, saturation) for color in colors]
|
3624
3814
|
# Check colors and auto-calculate overlaps
|
3625
3815
|
if len(lists) == 2:
|
3816
|
+
|
3626
3817
|
def get_count_and_percentage(set_count, subset_count):
|
3627
3818
|
percent = subset_count / set_count if set_count > 0 else 0
|
3628
|
-
return
|
3819
|
+
return (
|
3820
|
+
f"{subset_count}\n({fmt.format(percent)})"
|
3821
|
+
if show_percentages
|
3822
|
+
else f"{subset_count}"
|
3823
|
+
)
|
3629
3824
|
|
3630
3825
|
from matplotlib_venn import venn2, venn2_circles
|
3826
|
+
|
3631
3827
|
# Auto-calculate overlap color for 2-set Venn diagram
|
3632
3828
|
overlap_color = get_color_overlap(colors[0], colors[1]) if colors else None
|
3633
|
-
|
3829
|
+
|
3634
3830
|
# Draw the venn diagram
|
3635
3831
|
v = venn2(subsets=lists, set_labels=labels, ax=ax, **kwargs)
|
3636
3832
|
venn_circles = venn2_circles(subsets=lists, ax=ax)
|
3637
|
-
set1,set2=lists[0],lists[1]
|
3638
|
-
v.get_patch_by_id(
|
3639
|
-
v.get_patch_by_id(
|
3640
|
-
v.get_patch_by_id(
|
3833
|
+
set1, set2 = lists[0], lists[1]
|
3834
|
+
v.get_patch_by_id("10").set_color(colors[0])
|
3835
|
+
v.get_patch_by_id("01").set_color(colors[1])
|
3836
|
+
v.get_patch_by_id("11").set_color(
|
3837
|
+
get_color_overlap(colors[0], colors[1]) if colors else None
|
3838
|
+
)
|
3641
3839
|
# v.get_label_by_id('10').set_text(len(set1 - set2))
|
3642
3840
|
# v.get_label_by_id('01').set_text(len(set2 - set1))
|
3643
3841
|
# v.get_label_by_id('11').set_text(len(set1 & set2))
|
3644
|
-
v.get_label_by_id(
|
3645
|
-
|
3646
|
-
|
3647
|
-
|
3842
|
+
v.get_label_by_id("10").set_text(
|
3843
|
+
get_count_and_percentage(len(set1), len(set1 - set2))
|
3844
|
+
)
|
3845
|
+
v.get_label_by_id("01").set_text(
|
3846
|
+
get_count_and_percentage(len(set2), len(set2 - set1))
|
3847
|
+
)
|
3848
|
+
v.get_label_by_id("11").set_text(
|
3849
|
+
get_count_and_percentage(len(set1 | set2), len(set1 & set2))
|
3850
|
+
)
|
3648
3851
|
|
3649
|
-
if not isinstance(linewidth,list):
|
3650
|
-
linewidth=[linewidth]
|
3651
|
-
if isinstance(linestyle,str):
|
3652
|
-
linestyle=[linestyle]
|
3852
|
+
if not isinstance(linewidth, list):
|
3853
|
+
linewidth = [linewidth]
|
3854
|
+
if isinstance(linestyle, str):
|
3855
|
+
linestyle = [linestyle]
|
3653
3856
|
if not isinstance(edgecolor, list):
|
3654
|
-
edgecolor=[edgecolor]
|
3655
|
-
linewidth=linewidth*2 if len(linewidth)==1 else linewidth
|
3656
|
-
linestyle=linestyle*2 if len(linestyle)==1 else linestyle
|
3657
|
-
edgecolor=edgecolor*2 if len(edgecolor)==1 else edgecolor
|
3857
|
+
edgecolor = [edgecolor]
|
3858
|
+
linewidth = linewidth * 2 if len(linewidth) == 1 else linewidth
|
3859
|
+
linestyle = linestyle * 2 if len(linestyle) == 1 else linestyle
|
3860
|
+
edgecolor = edgecolor * 2 if len(edgecolor) == 1 else edgecolor
|
3658
3861
|
for i in range(2):
|
3659
3862
|
venn_circles[i].set_lw(linewidth[i])
|
3660
3863
|
venn_circles[i].set_ls(linestyle[i])
|
3661
3864
|
venn_circles[i].set_edgecolor(edgecolor[i])
|
3662
3865
|
# 椭圆
|
3663
3866
|
if ellipse_shape:
|
3664
|
-
import matplotlib.patches as patches
|
3867
|
+
import matplotlib.patches as patches
|
3868
|
+
|
3665
3869
|
for patch in v.patches:
|
3666
3870
|
patch.set_visible(False) # Hide original patches if using ellipses
|
3667
3871
|
center1 = v.get_circle_center(0)
|
@@ -3672,9 +3876,13 @@ def venn(
|
|
3672
3876
|
height=ellipse_scale[1],
|
3673
3877
|
edgecolor=edgecolor[0] if edgecolor else colors[0],
|
3674
3878
|
facecolor=colors[0],
|
3675
|
-
lw=
|
3879
|
+
lw=(
|
3880
|
+
linewidth if isinstance(linewidth, (int, float)) else 1.0
|
3881
|
+
), # Ensure lw is a number
|
3676
3882
|
ls=linestyle[0],
|
3677
|
-
alpha=
|
3883
|
+
alpha=(
|
3884
|
+
alpha if isinstance(alpha, (int, float)) else 0.5
|
3885
|
+
), # Ensure alpha is a number
|
3678
3886
|
)
|
3679
3887
|
ellipse2 = patches.Ellipse(
|
3680
3888
|
(center2.x, center2.y),
|
@@ -3682,48 +3890,78 @@ def venn(
|
|
3682
3890
|
height=ellipse_scale[1],
|
3683
3891
|
edgecolor=edgecolor[1] if edgecolor else colors[1],
|
3684
3892
|
facecolor=colors[1],
|
3685
|
-
lw=
|
3893
|
+
lw=(
|
3894
|
+
linewidth if isinstance(linewidth, (int, float)) else 1.0
|
3895
|
+
), # Ensure lw is a number
|
3686
3896
|
ls=linestyle[0],
|
3687
|
-
alpha=
|
3897
|
+
alpha=(
|
3898
|
+
alpha if isinstance(alpha, (int, float)) else 0.5
|
3899
|
+
), # Ensure alpha is a number
|
3688
3900
|
)
|
3689
3901
|
ax.add_patch(ellipse1)
|
3690
3902
|
ax.add_patch(ellipse2)
|
3691
3903
|
# Apply styles to set labels
|
3692
3904
|
for i, text in enumerate(v.set_labels):
|
3693
|
-
textsets(
|
3694
|
-
|
3905
|
+
textsets(
|
3906
|
+
text,
|
3907
|
+
fontname=fontname,
|
3908
|
+
fontsize=fontsize,
|
3909
|
+
fontweight=fontweight,
|
3910
|
+
fontstyle=fontstyle,
|
3911
|
+
fontcolor=fontcolor,
|
3912
|
+
ha=ha,
|
3913
|
+
va=va,
|
3914
|
+
shadow=shadow,
|
3915
|
+
)
|
3695
3916
|
|
3696
3917
|
# Apply styles to subset labels
|
3697
3918
|
for i, text in enumerate(v.subset_labels):
|
3698
3919
|
if text: # Ensure text exists
|
3699
3920
|
if custom_texts: # Custom text handling
|
3700
|
-
text.set_text(custom_texts[i])
|
3701
|
-
textsets(
|
3702
|
-
|
3921
|
+
text.set_text(custom_texts[i])
|
3922
|
+
textsets(
|
3923
|
+
text,
|
3924
|
+
fontname=fontname,
|
3925
|
+
fontsize=subset_fontsize,
|
3926
|
+
fontweight=subset_fontweight,
|
3927
|
+
fontstyle=subset_fontstyle,
|
3928
|
+
fontcolor=subset_fontcolor,
|
3929
|
+
ha=ha,
|
3930
|
+
va=va,
|
3931
|
+
shadow=shadow,
|
3932
|
+
)
|
3703
3933
|
|
3704
3934
|
elif len(lists) == 3:
|
3935
|
+
|
3705
3936
|
def get_label(set_count, subset_count):
|
3706
3937
|
percent = subset_count / set_count if set_count > 0 else 0
|
3707
|
-
return
|
3708
|
-
|
3938
|
+
return (
|
3939
|
+
f"{subset_count}\n({fmt.format(percent)})"
|
3940
|
+
if show_percentages
|
3941
|
+
else f"{subset_count}"
|
3942
|
+
)
|
3943
|
+
|
3709
3944
|
from matplotlib_venn import venn3, venn3_circles
|
3945
|
+
|
3710
3946
|
# Auto-calculate overlap colors for 3-set Venn diagram
|
3711
3947
|
colorAB = get_color_overlap(colors[0], colors[1]) if colors else None
|
3712
3948
|
colorAC = get_color_overlap(colors[0], colors[2]) if colors else None
|
3713
3949
|
colorBC = get_color_overlap(colors[1], colors[2]) if colors else None
|
3714
|
-
colorABC =
|
3715
|
-
|
3950
|
+
colorABC = (
|
3951
|
+
get_color_overlap(colors[0], colors[1], colors[2]) if colors else None
|
3952
|
+
)
|
3953
|
+
set1, set2, set3 = lists[0], lists[1], lists[2]
|
3716
3954
|
|
3717
3955
|
# Draw the venn diagram
|
3718
|
-
v = venn3(subsets=lists, set_labels=labels, ax=ax
|
3719
|
-
v.get_patch_by_id(
|
3720
|
-
v.get_patch_by_id(
|
3721
|
-
v.get_patch_by_id(
|
3722
|
-
v.get_patch_by_id(
|
3723
|
-
v.get_patch_by_id(
|
3724
|
-
v.get_patch_by_id(
|
3725
|
-
v.get_patch_by_id(
|
3726
|
-
|
3956
|
+
v = venn3(subsets=lists, set_labels=labels, ax=ax, **kwargs)
|
3957
|
+
v.get_patch_by_id("100").set_color(colors[0])
|
3958
|
+
v.get_patch_by_id("010").set_color(colors[1])
|
3959
|
+
v.get_patch_by_id("001").set_color(colors[2])
|
3960
|
+
v.get_patch_by_id("110").set_color(colorAB)
|
3961
|
+
v.get_patch_by_id("101").set_color(colorAC)
|
3962
|
+
v.get_patch_by_id("011").set_color(colorBC)
|
3963
|
+
v.get_patch_by_id("111").set_color(colorABC)
|
3964
|
+
|
3727
3965
|
# Correctly labeling subset sizes
|
3728
3966
|
# v.get_label_by_id('100').set_text(len(set1 - set2 - set3))
|
3729
3967
|
# v.get_label_by_id('010').set_text(len(set2 - set1 - set3))
|
@@ -3732,63 +3970,93 @@ def venn(
|
|
3732
3970
|
# v.get_label_by_id('101').set_text(len(set1 & set3 - set2))
|
3733
3971
|
# v.get_label_by_id('011').set_text(len(set2 & set3 - set1))
|
3734
3972
|
# v.get_label_by_id('111').set_text(len(set1 & set2 & set3))
|
3735
|
-
v.get_label_by_id(
|
3736
|
-
v.get_label_by_id(
|
3737
|
-
v.get_label_by_id(
|
3738
|
-
v.get_label_by_id(
|
3739
|
-
|
3740
|
-
|
3741
|
-
v.get_label_by_id(
|
3742
|
-
|
3973
|
+
v.get_label_by_id("100").set_text(get_label(len(set1), len(set1 - set2 - set3)))
|
3974
|
+
v.get_label_by_id("010").set_text(get_label(len(set2), len(set2 - set1 - set3)))
|
3975
|
+
v.get_label_by_id("001").set_text(get_label(len(set3), len(set3 - set1 - set2)))
|
3976
|
+
v.get_label_by_id("110").set_text(
|
3977
|
+
get_label(len(set1 | set2), len(set1 & set2 - set3))
|
3978
|
+
)
|
3979
|
+
v.get_label_by_id("101").set_text(
|
3980
|
+
get_label(len(set1 | set3), len(set1 & set3 - set2))
|
3981
|
+
)
|
3982
|
+
v.get_label_by_id("011").set_text(
|
3983
|
+
get_label(len(set2 | set3), len(set2 & set3 - set1))
|
3984
|
+
)
|
3985
|
+
v.get_label_by_id("111").set_text(
|
3986
|
+
get_label(len(set1 | set2 | set3), len(set1 & set2 & set3))
|
3987
|
+
)
|
3743
3988
|
|
3744
3989
|
# Apply styles to set labels
|
3745
3990
|
for i, text in enumerate(v.set_labels):
|
3746
|
-
textsets(
|
3747
|
-
|
3991
|
+
textsets(
|
3992
|
+
text,
|
3993
|
+
fontname=fontname,
|
3994
|
+
fontsize=fontsize,
|
3995
|
+
fontweight=fontweight,
|
3996
|
+
fontstyle=fontstyle,
|
3997
|
+
fontcolor=fontcolor,
|
3998
|
+
ha=ha,
|
3999
|
+
va=va,
|
4000
|
+
shadow=shadow,
|
4001
|
+
)
|
3748
4002
|
|
3749
4003
|
# Apply styles to subset labels
|
3750
4004
|
for i, text in enumerate(v.subset_labels):
|
3751
4005
|
if text: # Ensure text exists
|
3752
4006
|
if custom_texts: # Custom text handling
|
3753
4007
|
text.set_text(custom_texts[i])
|
3754
|
-
textsets(
|
3755
|
-
|
4008
|
+
textsets(
|
4009
|
+
text,
|
4010
|
+
fontname=fontname,
|
4011
|
+
fontsize=subset_fontsize,
|
4012
|
+
fontweight=subset_fontweight,
|
4013
|
+
fontstyle=subset_fontstyle,
|
4014
|
+
fontcolor=subset_fontcolor,
|
4015
|
+
ha=ha,
|
4016
|
+
va=va,
|
4017
|
+
shadow=shadow,
|
4018
|
+
)
|
3756
4019
|
|
3757
4020
|
venn_circles = venn3_circles(subsets=lists, ax=ax)
|
3758
|
-
if not isinstance(linewidth,list):
|
3759
|
-
linewidth=[linewidth]
|
3760
|
-
if isinstance(linestyle,str):
|
3761
|
-
linestyle=[linestyle]
|
4021
|
+
if not isinstance(linewidth, list):
|
4022
|
+
linewidth = [linewidth]
|
4023
|
+
if isinstance(linestyle, str):
|
4024
|
+
linestyle = [linestyle]
|
3762
4025
|
if not isinstance(edgecolor, list):
|
3763
|
-
edgecolor=[edgecolor]
|
3764
|
-
linewidth=linewidth*3 if len(linewidth)==1 else linewidth
|
3765
|
-
linestyle=linestyle*3 if len(linestyle)==1 else linestyle
|
3766
|
-
edgecolor=edgecolor*3 if len(edgecolor)==1 else edgecolor
|
4026
|
+
edgecolor = [edgecolor]
|
4027
|
+
linewidth = linewidth * 3 if len(linewidth) == 1 else linewidth
|
4028
|
+
linestyle = linestyle * 3 if len(linestyle) == 1 else linestyle
|
4029
|
+
edgecolor = edgecolor * 3 if len(edgecolor) == 1 else edgecolor
|
3767
4030
|
|
3768
4031
|
# edgecolor=[to_rgba(i) for i in edgecolor]
|
3769
4032
|
|
3770
|
-
for i in range(3):
|
4033
|
+
for i in range(3):
|
3771
4034
|
venn_circles[i].set_lw(linewidth[i])
|
3772
|
-
venn_circles[i].set_ls(linestyle[i])
|
3773
|
-
venn_circles[i].set_edgecolor(edgecolor[i])
|
4035
|
+
venn_circles[i].set_ls(linestyle[i])
|
4036
|
+
venn_circles[i].set_edgecolor(edgecolor[i])
|
3774
4037
|
|
3775
|
-
|
4038
|
+
# 椭圆形
|
3776
4039
|
if ellipse_shape:
|
3777
|
-
import matplotlib.patches as patches
|
4040
|
+
import matplotlib.patches as patches
|
4041
|
+
|
3778
4042
|
for patch in v.patches:
|
3779
4043
|
patch.set_visible(False) # Hide original patches if using ellipses
|
3780
4044
|
center1 = v.get_circle_center(0)
|
3781
4045
|
center2 = v.get_circle_center(1)
|
3782
|
-
center3 = v.get_circle_center(2)
|
4046
|
+
center3 = v.get_circle_center(2)
|
3783
4047
|
ellipse1 = patches.Ellipse(
|
3784
4048
|
(center1.x, center1.y),
|
3785
4049
|
width=ellipse_scale[0],
|
3786
4050
|
height=ellipse_scale[1],
|
3787
4051
|
edgecolor=edgecolor[0] if edgecolor else colors[0],
|
3788
4052
|
facecolor=colors[0],
|
3789
|
-
lw=
|
4053
|
+
lw=(
|
4054
|
+
linewidth if isinstance(linewidth, (int, float)) else 1.0
|
4055
|
+
), # Ensure lw is a number
|
3790
4056
|
ls=linestyle[0],
|
3791
|
-
alpha=
|
4057
|
+
alpha=(
|
4058
|
+
alpha if isinstance(alpha, (int, float)) else 0.5
|
4059
|
+
), # Ensure alpha is a number
|
3792
4060
|
)
|
3793
4061
|
ellipse2 = patches.Ellipse(
|
3794
4062
|
(center2.x, center2.y),
|
@@ -3796,9 +4064,13 @@ def venn(
|
|
3796
4064
|
height=ellipse_scale[1],
|
3797
4065
|
edgecolor=edgecolor[1] if edgecolor else colors[1],
|
3798
4066
|
facecolor=colors[1],
|
3799
|
-
lw=
|
4067
|
+
lw=(
|
4068
|
+
linewidth if isinstance(linewidth, (int, float)) else 1.0
|
4069
|
+
), # Ensure lw is a number
|
3800
4070
|
ls=linestyle[0],
|
3801
|
-
alpha=
|
4071
|
+
alpha=(
|
4072
|
+
alpha if isinstance(alpha, (int, float)) else 0.5
|
4073
|
+
), # Ensure alpha is a number
|
3802
4074
|
)
|
3803
4075
|
ellipse3 = patches.Ellipse(
|
3804
4076
|
(center3.x, center3.y),
|
@@ -3806,9 +4078,13 @@ def venn(
|
|
3806
4078
|
height=ellipse_scale[1],
|
3807
4079
|
edgecolor=edgecolor[1] if edgecolor else colors[1],
|
3808
4080
|
facecolor=colors[1],
|
3809
|
-
lw=
|
4081
|
+
lw=(
|
4082
|
+
linewidth if isinstance(linewidth, (int, float)) else 1.0
|
4083
|
+
), # Ensure lw is a number
|
3810
4084
|
ls=linestyle[0],
|
3811
|
-
alpha=
|
4085
|
+
alpha=(
|
4086
|
+
alpha if isinstance(alpha, (int, float)) else 0.5
|
4087
|
+
), # Ensure alpha is a number
|
3812
4088
|
)
|
3813
4089
|
ax.add_patch(ellipse1)
|
3814
4090
|
ax.add_patch(ellipse2)
|
@@ -3820,17 +4096,20 @@ def venn(
|
|
3820
4096
|
for patch in v.patches:
|
3821
4097
|
if patch:
|
3822
4098
|
patch.set_alpha(alpha)
|
3823
|
-
if
|
4099
|
+
if "none" in edgecolor or 0 in linewidth:
|
3824
4100
|
patch.set_edgecolor("none")
|
3825
4101
|
return ax
|
3826
4102
|
|
4103
|
+
|
3827
4104
|
#! subplots, support automatic extend new axis
|
3828
|
-
def subplot(
|
3829
|
-
|
3830
|
-
|
3831
|
-
|
3832
|
-
|
3833
|
-
|
4105
|
+
def subplot(
|
4106
|
+
rows: int = 2,
|
4107
|
+
cols: int = 2,
|
4108
|
+
figsize: Union[tuple, list] = [8, 8],
|
4109
|
+
sharex=False,
|
4110
|
+
sharey=False,
|
4111
|
+
**kwargs,
|
4112
|
+
):
|
3834
4113
|
"""
|
3835
4114
|
nexttile = subplot(
|
3836
4115
|
8,
|
@@ -3850,11 +4129,14 @@ def subplot(rows:int=2,
|
|
3850
4129
|
ax.set_xlabel(f"Tile {i + 1}")
|
3851
4130
|
"""
|
3852
4131
|
from matplotlib.gridspec import GridSpec
|
4132
|
+
|
3853
4133
|
if run_once_within():
|
3854
|
-
print(
|
4134
|
+
print(
|
4135
|
+
f"usage:\n\tnexttile = subplot(2, 2, figsize=(5, 5), sharex=True, sharey=True)\n\tax = nexttile()"
|
4136
|
+
)
|
3855
4137
|
fig = plt.figure(figsize=figsize)
|
3856
4138
|
grid_spec = GridSpec(rows, cols, figure=fig)
|
3857
|
-
occupied = set()
|
4139
|
+
occupied = set()
|
3858
4140
|
row_first_axes = [None] * rows # Track the first axis in each row (for sharey)
|
3859
4141
|
col_first_axes = [None] * cols # Track the first axis in each column (for sharex)
|
3860
4142
|
|
@@ -3862,8 +4144,9 @@ def subplot(rows:int=2,
|
|
3862
4144
|
nonlocal rows, grid_spec
|
3863
4145
|
rows += 1 # Expands by adding a row
|
3864
4146
|
grid_spec = GridSpec(rows, cols, figure=fig)
|
4147
|
+
|
3865
4148
|
def nexttile(rowspan=1, colspan=1, **kwargs):
|
3866
|
-
nonlocal rows, cols, occupied, grid_spec
|
4149
|
+
nonlocal rows, cols, occupied, grid_spec
|
3867
4150
|
for row in range(rows):
|
3868
4151
|
for col in range(cols):
|
3869
4152
|
if all(
|
@@ -3873,29 +4156,29 @@ def subplot(rows:int=2,
|
|
3873
4156
|
):
|
3874
4157
|
break
|
3875
4158
|
else:
|
3876
|
-
continue
|
4159
|
+
continue
|
3877
4160
|
break
|
3878
|
-
else:
|
4161
|
+
else:
|
3879
4162
|
expand_ax()
|
3880
4163
|
return nexttile(rowspan=rowspan, colspan=colspan, **kwargs)
|
3881
4164
|
|
3882
|
-
sharex_ax,sharey_ax = None, None
|
4165
|
+
sharex_ax, sharey_ax = None, None
|
3883
4166
|
|
3884
|
-
if sharex:
|
4167
|
+
if sharex:
|
3885
4168
|
sharex_ax = col_first_axes[col]
|
3886
4169
|
|
3887
|
-
if sharey:
|
3888
|
-
sharey_ax = row_first_axes[row]
|
4170
|
+
if sharey:
|
4171
|
+
sharey_ax = row_first_axes[row]
|
3889
4172
|
ax = fig.add_subplot(
|
3890
4173
|
grid_spec[row : row + rowspan, col : col + colspan],
|
3891
4174
|
sharex=sharex_ax,
|
3892
4175
|
sharey=sharey_ax,
|
3893
|
-
**kwargs
|
3894
|
-
)
|
4176
|
+
**kwargs,
|
4177
|
+
)
|
3895
4178
|
if row_first_axes[row] is None:
|
3896
4179
|
row_first_axes[row] = ax
|
3897
4180
|
if col_first_axes[col] is None:
|
3898
|
-
col_first_axes[col] = ax
|
4181
|
+
col_first_axes[col] = ax
|
3899
4182
|
for r in range(row, row + rowspan):
|
3900
4183
|
for c in range(col, col + colspan):
|
3901
4184
|
occupied.add((r, c))
|
@@ -3907,17 +4190,17 @@ def subplot(rows:int=2,
|
|
3907
4190
|
|
3908
4191
|
#! radar chart
|
3909
4192
|
def radar(
|
3910
|
-
data: pd.DataFrame,
|
3911
|
-
ylim=(0,100),
|
4193
|
+
data: pd.DataFrame,
|
4194
|
+
ylim=(0, 100),
|
3912
4195
|
color=get_color(5),
|
3913
4196
|
fontsize=10,
|
3914
|
-
fontcolor=
|
4197
|
+
fontcolor="k",
|
3915
4198
|
size=6,
|
3916
4199
|
linewidth=1,
|
3917
4200
|
linestyle="-",
|
3918
4201
|
alpha=0.5,
|
3919
4202
|
marker="o",
|
3920
|
-
edgecolor=
|
4203
|
+
edgecolor="none",
|
3921
4204
|
edge_linewidth=0,
|
3922
4205
|
bg_color="0.8",
|
3923
4206
|
bg_alpha=None,
|
@@ -3928,18 +4211,19 @@ def radar(
|
|
3928
4211
|
legend_fontsize=10,
|
3929
4212
|
grid_color="gray",
|
3930
4213
|
grid_alpha=0.5,
|
3931
|
-
grid_linestyle="--",
|
4214
|
+
grid_linestyle="--",
|
4215
|
+
grid_linewidth=0.5,
|
3932
4216
|
circular: bool = False,
|
3933
4217
|
tick_fontsize=None,
|
3934
4218
|
tick_fontcolor="0.65",
|
3935
|
-
tick_loc
|
3936
|
-
turning
|
4219
|
+
tick_loc=None, # label position
|
4220
|
+
turning=None,
|
3937
4221
|
ax=None,
|
3938
4222
|
sp=2,
|
3939
|
-
**kwargs
|
4223
|
+
**kwargs,
|
3940
4224
|
):
|
3941
4225
|
"""
|
3942
|
-
Example DATA:
|
4226
|
+
Example DATA:
|
3943
4227
|
df = pd.DataFrame(
|
3944
4228
|
data=[
|
3945
4229
|
[80, 80, 80, 80, 80, 80, 80],
|
@@ -3948,7 +4232,7 @@ def radar(
|
|
3948
4232
|
],
|
3949
4233
|
index=["Hero", "Warrior", "Wizard"],
|
3950
4234
|
columns=["HP", "MP", "ATK", "DEF", "SP.ATK", "SP.DEF", "SPD"])
|
3951
|
-
|
4235
|
+
|
3952
4236
|
Parameters:
|
3953
4237
|
- data (pd.DataFrame): The data to plot. Each column corresponds to a variable, and each row represents a data point.
|
3954
4238
|
- ylim (tuple): The limits of the radial axis (y-axis). Default is (0, 100).
|
@@ -4002,8 +4286,12 @@ def radar(
|
|
4002
4286
|
|
4003
4287
|
# bg_color
|
4004
4288
|
if bg_alpha is None:
|
4005
|
-
bg_alpha=alpha
|
4006
|
-
|
4289
|
+
bg_alpha = alpha
|
4290
|
+
(
|
4291
|
+
ax.set_facecolor(to_rgba(bg_color, alpha=bg_alpha))
|
4292
|
+
if circular
|
4293
|
+
else ax.set_facecolor("none")
|
4294
|
+
)
|
4007
4295
|
# Set up the radar chart with straight-line connections
|
4008
4296
|
ax.set_theta_offset(np.pi / 2)
|
4009
4297
|
ax.set_theta_direction(-1)
|
@@ -4012,53 +4300,68 @@ def radar(
|
|
4012
4300
|
ax.set_xticks(angles[:-1])
|
4013
4301
|
ax.set_xticklabels(categories)
|
4014
4302
|
|
4015
|
-
# Set y-axis limits and grid intervals
|
4303
|
+
# Set y-axis limits and grid intervals
|
4016
4304
|
vmin, vmax = ylim
|
4017
|
-
if circular:
|
4018
|
-
|
4019
|
-
ax.yaxis.set_ticks(np.arange(vmin, vmax+1, vmax * grid_interval_ratio))
|
4020
|
-
ax.grid(
|
4021
|
-
|
4022
|
-
|
4023
|
-
|
4024
|
-
|
4025
|
-
|
4026
|
-
|
4027
|
-
|
4028
|
-
|
4305
|
+
if circular:
|
4306
|
+
# * cicular style
|
4307
|
+
ax.yaxis.set_ticks(np.arange(vmin, vmax + 1, vmax * grid_interval_ratio))
|
4308
|
+
ax.grid(
|
4309
|
+
axis="both",
|
4310
|
+
color=grid_color,
|
4311
|
+
linestyle=grid_linestyle,
|
4312
|
+
alpha=grid_alpha,
|
4313
|
+
linewidth=grid_linewidth,
|
4314
|
+
dash_capstyle="round",
|
4315
|
+
dash_joinstyle="round",
|
4316
|
+
)
|
4317
|
+
ax.spines["polar"].set_color(grid_color)
|
4029
4318
|
ax.spines["polar"].set_linewidth(grid_linewidth)
|
4030
|
-
ax.spines["polar"].set_linestyle(
|
4319
|
+
ax.spines["polar"].set_linestyle("-")
|
4031
4320
|
ax.spines["polar"].set_alpha(grid_alpha)
|
4032
|
-
ax.spines["polar"].set_capstyle(
|
4033
|
-
ax.spines["polar"].set_joinstyle(
|
4321
|
+
ax.spines["polar"].set_capstyle("round")
|
4322
|
+
ax.spines["polar"].set_joinstyle("round")
|
4034
4323
|
|
4035
4324
|
else:
|
4036
|
-
|
4325
|
+
# * spider style: spider-style grid (straight lines, not circles)
|
4037
4326
|
# Create the spider-style grid (straight lines, not circles)
|
4038
|
-
for i in range(1, int(vmax * grid_interval_ratio) + 1):
|
4327
|
+
for i in range(1, int(vmax * grid_interval_ratio) + 1):
|
4039
4328
|
ax.plot(
|
4040
4329
|
angles + [angles[0]], # Closing the loop
|
4041
|
-
[i * vmax * grid_interval_ratio] * (num_vars+1)
|
4042
|
-
|
4330
|
+
[i * vmax * grid_interval_ratio] * (num_vars + 1)
|
4331
|
+
+ [i * vmax * grid_interval_ratio],
|
4332
|
+
color=grid_color,
|
4333
|
+
linestyle=grid_linestyle,
|
4334
|
+
alpha=grid_alpha,
|
4335
|
+
linewidth=grid_linewidth,
|
4043
4336
|
)
|
4044
|
-
# set bg_color
|
4045
|
-
ax.fill(angles, [vmax]*(data.shape[1]+1), color=bg_color, alpha=bg_alpha)
|
4337
|
+
# set bg_color
|
4338
|
+
ax.fill(angles, [vmax] * (data.shape[1] + 1), color=bg_color, alpha=bg_alpha)
|
4046
4339
|
ax.yaxis.grid(False)
|
4047
4340
|
# Move radial labels away from plotted line
|
4048
4341
|
if tick_loc is None:
|
4049
|
-
tick_loc
|
4342
|
+
tick_loc = (
|
4343
|
+
np.mean([angles[0], angles[1]]) / (2 * np.pi) * 360 if circular else 0
|
4344
|
+
)
|
4050
4345
|
|
4051
4346
|
ax.set_rlabel_position(tick_loc)
|
4052
4347
|
ax.set_theta_offset(turning) if turning is not None else None
|
4053
|
-
ax.tick_params(
|
4054
|
-
|
4055
|
-
|
4348
|
+
ax.tick_params(
|
4349
|
+
axis="x", labelsize=fontsize, colors=fontcolor
|
4350
|
+
) # Optional: for angular labels
|
4351
|
+
tick_fontsize = fontsize - 2 if fontsize is None else tick_fontsize
|
4352
|
+
ax.tick_params(
|
4353
|
+
axis="y", labelsize=tick_fontsize, colors=tick_fontcolor
|
4354
|
+
) # For radial labels
|
4056
4355
|
if not circular:
|
4057
|
-
ax.spines[
|
4058
|
-
ax.tick_params(axis=
|
4059
|
-
ax.tick_params(axis=
|
4356
|
+
ax.spines["polar"].set_visible(False)
|
4357
|
+
ax.tick_params(axis="x", pad=sp) # move spines outward
|
4358
|
+
ax.tick_params(axis="y", pad=sp) # move spines outward
|
4060
4359
|
# colors
|
4061
|
-
colors =
|
4360
|
+
colors = (
|
4361
|
+
get_color(data.shape[0])
|
4362
|
+
if cmap is None
|
4363
|
+
else plt.get_cmap(cmap)(np.linspace(0, 1, data.shape[0]))
|
4364
|
+
)
|
4062
4365
|
# Plot each row with straight lines
|
4063
4366
|
for i, (index, row) in enumerate(data.iterrows()):
|
4064
4367
|
values = row.tolist()
|
@@ -4070,7 +4373,7 @@ def radar(
|
|
4070
4373
|
linewidth=linewidth,
|
4071
4374
|
linestyle=linestyle,
|
4072
4375
|
label=index,
|
4073
|
-
clip_on=False
|
4376
|
+
clip_on=False,
|
4074
4377
|
)
|
4075
4378
|
ax.fill(angles, values, color=colors[i], alpha=alpha)
|
4076
4379
|
|
@@ -4084,13 +4387,23 @@ def radar(
|
|
4084
4387
|
marker=marker,
|
4085
4388
|
markersize=size,
|
4086
4389
|
markeredgecolor=edgecolor,
|
4087
|
-
markeredgewidth
|
4390
|
+
markeredgewidth=edge_linewidth,
|
4391
|
+
zorder=10,
|
4392
|
+
clip_on=False,
|
4088
4393
|
)
|
4089
4394
|
# ax.tick_params(axis='y', labelleft=False, left=False)
|
4090
|
-
if
|
4395
|
+
if "legend" in kws_figsets:
|
4091
4396
|
figsets(ax=ax, **kws_figsets)
|
4092
4397
|
else:
|
4093
|
-
|
4094
|
-
figsets(
|
4095
|
-
|
4096
|
-
|
4398
|
+
|
4399
|
+
figsets(
|
4400
|
+
ax=ax,
|
4401
|
+
legend=dict(
|
4402
|
+
loc=legend_loc,
|
4403
|
+
fontsize=legend_fontsize,
|
4404
|
+
bbox_to_anchor=[1.1, 1.4],
|
4405
|
+
ncols=2,
|
4406
|
+
),
|
4407
|
+
**kws_figsets,
|
4408
|
+
)
|
4409
|
+
return ax
|