py2ls 0.2.4.10.5__py3-none-any.whl → 0.2.4.10.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
py2ls/plot.py
CHANGED
@@ -1,21 +1,37 @@
|
|
1
1
|
import numpy as np
|
2
|
-
import pandas as pd
|
2
|
+
import pandas as pd
|
3
3
|
import matplotlib.pyplot as plt
|
4
4
|
import matplotlib
|
5
5
|
import seaborn as sns
|
6
6
|
import warnings
|
7
7
|
import logging
|
8
8
|
from typing import Union
|
9
|
-
from .ips import
|
9
|
+
from .ips import (
|
10
|
+
isa,
|
11
|
+
fsave,
|
12
|
+
fload,
|
13
|
+
mkdir,
|
14
|
+
listdir,
|
15
|
+
figsave,
|
16
|
+
strcmp,
|
17
|
+
unique,
|
18
|
+
get_os,
|
19
|
+
ssplit,
|
20
|
+
flatten,
|
21
|
+
plt_font,
|
22
|
+
run_once_within,
|
23
|
+
)
|
10
24
|
from .stats import *
|
11
25
|
import os
|
26
|
+
|
12
27
|
# Suppress INFO messages from fontTools
|
13
28
|
logging.getLogger("fontTools").setLevel(logging.ERROR)
|
14
|
-
logging.getLogger(
|
29
|
+
logging.getLogger("matplotlib").setLevel(logging.ERROR)
|
15
30
|
|
16
31
|
warnings.simplefilter("ignore", category=pd.errors.SettingWithCopyWarning)
|
17
32
|
warnings.filterwarnings("ignore", category=pd.errors.PerformanceWarning)
|
18
33
|
|
34
|
+
|
19
35
|
def add_text(ax=None, height_offset=0.5, fmt=".1f", **kwargs):
|
20
36
|
"""Adds text annotations for various types of Seaborn and Matplotlib plots.
|
21
37
|
Args:
|
@@ -449,7 +465,7 @@ def catplot(data, *args, **kwargs):
|
|
449
465
|
"""
|
450
466
|
from matplotlib.colors import to_rgba
|
451
467
|
import os
|
452
|
-
|
468
|
+
|
453
469
|
def plot_bars(data, data_m, opt_b, xloc, ax, label=None):
|
454
470
|
if "l" in opt_b["loc"]:
|
455
471
|
xloc_s = xloc - opt_b["x_dist"]
|
@@ -755,6 +771,7 @@ def catplot(data, *args, **kwargs):
|
|
755
771
|
)
|
756
772
|
else:
|
757
773
|
from scipy.stats import gaussian_kde
|
774
|
+
|
758
775
|
kde = gaussian_kde(ys, bw_method=opt_v["BandWidth"])
|
759
776
|
min_val, max_val = ys.min(), ys.max()
|
760
777
|
y_vals = np.linspace(min_val, max_val, opt_v["NumPoints"])
|
@@ -1416,7 +1433,9 @@ def catplot(data, *args, **kwargs):
|
|
1416
1433
|
style_load = fload(dir_style + style_use + ".json")
|
1417
1434
|
else:
|
1418
1435
|
style_load = fload(
|
1419
|
-
listdir(dir_style, "json").path.tolist()[
|
1436
|
+
listdir(dir_style, "json", verbose=False).path.tolist()[
|
1437
|
+
style_use
|
1438
|
+
]
|
1420
1439
|
)
|
1421
1440
|
style_load = remove_colors_in_dict(style_load)
|
1422
1441
|
opt.update(style_load)
|
@@ -1812,16 +1831,17 @@ def read_mplstyle(style_file):
|
|
1812
1831
|
return style_dict
|
1813
1832
|
|
1814
1833
|
|
1815
|
-
def figsets(*args, **kwargs):
|
1834
|
+
def figsets(*args, **kwargs):
|
1816
1835
|
import matplotlib
|
1817
1836
|
from cycler import cycler
|
1837
|
+
|
1818
1838
|
matplotlib.rc("text", usetex=False)
|
1819
1839
|
|
1820
1840
|
fig = plt.gcf()
|
1821
|
-
fontsize = kwargs.get("fontsize",11)
|
1822
|
-
plt.rcParams["font.size"]=fontsize
|
1823
|
-
fontname = kwargs.pop("fontname","Arial")
|
1824
|
-
fontname=plt_font(fontname)
|
1841
|
+
fontsize = kwargs.get("fontsize", 11)
|
1842
|
+
plt.rcParams["font.size"] = fontsize
|
1843
|
+
fontname = kwargs.pop("fontname", "Arial")
|
1844
|
+
fontname = plt_font(fontname) # 显示中文
|
1825
1845
|
|
1826
1846
|
sns_themes = ["white", "whitegrid", "dark", "darkgrid", "ticks"]
|
1827
1847
|
sns_contexts = ["notebook", "talk", "poster"] # now available "paper"
|
@@ -1851,18 +1871,21 @@ def figsets(*args, **kwargs):
|
|
1851
1871
|
nonlocal fontsize, fontname
|
1852
1872
|
if ("fo" in key) and (("size" in key) or ("sz" in key)):
|
1853
1873
|
fontsize = value
|
1854
|
-
plt.rcParams.update(
|
1855
|
-
|
1856
|
-
|
1857
|
-
|
1858
|
-
|
1859
|
-
|
1860
|
-
|
1861
|
-
|
1862
|
-
|
1874
|
+
plt.rcParams.update(
|
1875
|
+
{
|
1876
|
+
"font.size": fontsize,
|
1877
|
+
"figure.titlesize": fontsize,
|
1878
|
+
"axes.titlesize": fontsize,
|
1879
|
+
"axes.labelsize": fontsize,
|
1880
|
+
"xtick.labelsize": fontsize,
|
1881
|
+
"ytick.labelsize": fontsize,
|
1882
|
+
"legend.fontsize": fontsize,
|
1883
|
+
"legend.title_fontsize": fontsize,
|
1884
|
+
}
|
1885
|
+
)
|
1863
1886
|
|
1864
1887
|
# Customize tick labels
|
1865
|
-
ax.tick_params(axis=
|
1888
|
+
ax.tick_params(axis="both", which="major", labelsize=fontsize)
|
1866
1889
|
for label in ax.get_xticklabels() + ax.get_yticklabels():
|
1867
1890
|
label.set_fontname(fontname)
|
1868
1891
|
|
@@ -1910,15 +1933,15 @@ def figsets(*args, **kwargs):
|
|
1910
1933
|
if ("x" in key.lower()) and (
|
1911
1934
|
"tic" not in key.lower() and "tk" not in key.lower()
|
1912
1935
|
):
|
1913
|
-
ax.set_xlabel(value, fontname=fontname,fontsize=fontsize)
|
1936
|
+
ax.set_xlabel(value, fontname=fontname, fontsize=fontsize)
|
1914
1937
|
if ("y" in key.lower()) and (
|
1915
1938
|
"tic" not in key.lower() and "tk" not in key.lower()
|
1916
1939
|
):
|
1917
|
-
ax.set_ylabel(value, fontname=fontname,fontsize=fontsize)
|
1940
|
+
ax.set_ylabel(value, fontname=fontname, fontsize=fontsize)
|
1918
1941
|
if ("z" in key.lower()) and (
|
1919
1942
|
"tic" not in key.lower() and "tk" not in key.lower()
|
1920
1943
|
):
|
1921
|
-
ax.set_zlabel(value, fontname=fontname,fontsize=fontsize)
|
1944
|
+
ax.set_zlabel(value, fontname=fontname, fontsize=fontsize)
|
1922
1945
|
if key == "xlabel" and isinstance(value, dict):
|
1923
1946
|
ax.set_xlabel(**value)
|
1924
1947
|
if key == "ylabel" and isinstance(value, dict):
|
@@ -2110,6 +2133,7 @@ def figsets(*args, **kwargs):
|
|
2110
2133
|
|
2111
2134
|
if "mi" in key.lower() and "tic" in key.lower(): # minor_ticks
|
2112
2135
|
import matplotlib.ticker as tck
|
2136
|
+
|
2113
2137
|
if "x" in value.lower() or "x" in key.lower():
|
2114
2138
|
ax.xaxis.set_minor_locator(tck.AutoMinorLocator()) # ax.minorticks_on()
|
2115
2139
|
if "y" in value.lower() or "y" in key.lower():
|
@@ -2162,9 +2186,9 @@ def figsets(*args, **kwargs):
|
|
2162
2186
|
ax.grid(visible=False)
|
2163
2187
|
if "tit" in key.lower():
|
2164
2188
|
if "sup" in key.lower():
|
2165
|
-
plt.suptitle(value,fontname=fontname,fontsize=fontsize)
|
2189
|
+
plt.suptitle(value, fontname=fontname, fontsize=fontsize)
|
2166
2190
|
else:
|
2167
|
-
ax.set_title(value,fontname=fontname,fontsize=fontsize)
|
2191
|
+
ax.set_title(value, fontname=fontname, fontsize=fontsize)
|
2168
2192
|
if key.lower() in ["spine", "adjust", "ad", "sp", "spi", "adj", "spines"]:
|
2169
2193
|
if isinstance(value, bool) or (value in ["go", "do", "ja", "yes"]):
|
2170
2194
|
if value:
|
@@ -2185,9 +2209,14 @@ def figsets(*args, **kwargs):
|
|
2185
2209
|
legend = ax.get_legend()
|
2186
2210
|
if legend is not None:
|
2187
2211
|
legend.remove()
|
2188
|
-
if
|
2212
|
+
if (
|
2213
|
+
any(["colorbar" in key.lower(), "cbar" in key.lower()])
|
2214
|
+
and "loc" in key.lower()
|
2215
|
+
):
|
2189
2216
|
cbar = ax.collections[0].colorbar # Access the colorbar from the plot
|
2190
|
-
cbar.ax.set_position(
|
2217
|
+
cbar.ax.set_position(
|
2218
|
+
value
|
2219
|
+
) # [left, bottom, width, height] [0.475, 0.15, 0.04, 0.25]
|
2191
2220
|
|
2192
2221
|
for arg in args:
|
2193
2222
|
if isinstance(arg, matplotlib.axes._axes.Axes):
|
@@ -2225,7 +2254,8 @@ def figsets(*args, **kwargs):
|
|
2225
2254
|
if len(fig.get_axes()) > 1:
|
2226
2255
|
plt.tight_layout()
|
2227
2256
|
|
2228
|
-
|
2257
|
+
|
2258
|
+
def split_legend(ax, n=2, loc=None, title=None, bbox=None, ncol=1, **kwargs):
|
2229
2259
|
"""
|
2230
2260
|
split_legend(
|
2231
2261
|
ax,
|
@@ -2238,15 +2268,15 @@ def split_legend(ax, n=2, loc=None, title=None, bbox=None,ncol=1, **kwargs):
|
|
2238
2268
|
# Retrieve all lines and labels from the axis
|
2239
2269
|
handles, labels = ax.get_legend_handles_labels()
|
2240
2270
|
num_labels = len(labels)
|
2241
|
-
|
2271
|
+
|
2242
2272
|
# Calculate the number of labels per legend part
|
2243
|
-
labels_per_part = (num_labels + n - 1) // n # Round up
|
2273
|
+
labels_per_part = (num_labels + n - 1) // n # Round up
|
2244
2274
|
# Create a list to hold each legend object
|
2245
2275
|
legends = []
|
2246
2276
|
|
2247
2277
|
# Default locations and titles if not specified
|
2248
2278
|
if loc is None:
|
2249
|
-
loc = [
|
2279
|
+
loc = ["best"] * n
|
2250
2280
|
if title is None:
|
2251
2281
|
title = [None] * n
|
2252
2282
|
if bbox is None:
|
@@ -2257,7 +2287,7 @@ def split_legend(ax, n=2, loc=None, title=None, bbox=None,ncol=1, **kwargs):
|
|
2257
2287
|
# Calculate the range of labels for this part
|
2258
2288
|
start_idx = i * labels_per_part
|
2259
2289
|
end_idx = min(start_idx + labels_per_part, num_labels)
|
2260
|
-
|
2290
|
+
|
2261
2291
|
# Skip if no labels in this range
|
2262
2292
|
if start_idx >= end_idx:
|
2263
2293
|
break
|
@@ -2267,19 +2297,24 @@ def split_legend(ax, n=2, loc=None, title=None, bbox=None,ncol=1, **kwargs):
|
|
2267
2297
|
part_labels = labels[start_idx:end_idx]
|
2268
2298
|
|
2269
2299
|
# Create the legend for this part
|
2270
|
-
legend = ax.legend(
|
2271
|
-
|
2272
|
-
|
2273
|
-
|
2274
|
-
|
2275
|
-
|
2276
|
-
|
2277
|
-
|
2300
|
+
legend = ax.legend(
|
2301
|
+
handles=part_handles,
|
2302
|
+
labels=part_labels,
|
2303
|
+
loc=loc[i],
|
2304
|
+
title=title[i],
|
2305
|
+
ncol=ncol,
|
2306
|
+
bbox_to_anchor=bbox[i],
|
2307
|
+
**kwargs,
|
2308
|
+
)
|
2309
|
+
|
2278
2310
|
# Add the legend to the axis and save it to the list
|
2279
|
-
|
2311
|
+
(
|
2312
|
+
ax.add_artist(legend) if i != (n - 1) else None
|
2313
|
+
) # the lastone will be added automaticaly
|
2280
2314
|
legends.append(legend)
|
2281
2315
|
return legends
|
2282
2316
|
|
2317
|
+
|
2283
2318
|
def get_colors(
|
2284
2319
|
n: int = 1,
|
2285
2320
|
cmap: str = "auto",
|
@@ -2289,7 +2324,9 @@ def get_colors(
|
|
2289
2324
|
*args,
|
2290
2325
|
**kwargs,
|
2291
2326
|
):
|
2292
|
-
return get_color(n,cmap,alpha,output
|
2327
|
+
return get_color(n, cmap, alpha, output, *args, **kwargs)
|
2328
|
+
|
2329
|
+
|
2293
2330
|
def get_color(
|
2294
2331
|
n: int = 1,
|
2295
2332
|
cmap: str = "auto",
|
@@ -2300,6 +2337,7 @@ def get_color(
|
|
2300
2337
|
**kwargs,
|
2301
2338
|
):
|
2302
2339
|
from cycler import cycler
|
2340
|
+
|
2303
2341
|
def cmap2hex(cmap_name):
|
2304
2342
|
cmap_ = matplotlib.pyplot.get_cmap(cmap_name)
|
2305
2343
|
colors = [cmap_(i) for i in range(cmap_.N)]
|
@@ -2347,49 +2385,137 @@ def get_color(
|
|
2347
2385
|
return "#{:02X}{:02X}{:02X}{:02X}".format(r, g, b, a)
|
2348
2386
|
else:
|
2349
2387
|
return "#{:02X}{:02X}{:02X}".format(r, g, b)
|
2350
|
-
|
2388
|
+
|
2351
2389
|
# sc.pl.palettes.default_20
|
2352
|
-
cmap_20= [
|
2353
|
-
|
2354
|
-
|
2390
|
+
cmap_20 = [
|
2391
|
+
"#1f77b4",
|
2392
|
+
"#ff7f0e",
|
2393
|
+
"#279e68",
|
2394
|
+
"#d62728",
|
2395
|
+
"#aa40fc",
|
2396
|
+
"#8c564b",
|
2397
|
+
"#e377c2",
|
2398
|
+
"#b5bd61",
|
2399
|
+
"#17becf",
|
2400
|
+
"#aec7e8",
|
2401
|
+
"#ffbb78",
|
2402
|
+
"#98df8a",
|
2403
|
+
"#ff9896",
|
2404
|
+
"#c5b0d5",
|
2405
|
+
"#c49c94",
|
2406
|
+
"#f7b6d2",
|
2407
|
+
"#dbdb8d",
|
2408
|
+
"#9edae5",
|
2409
|
+
"#ad494a",
|
2410
|
+
"#8c6d31",
|
2411
|
+
]
|
2355
2412
|
# sc.pl.palettes.zeileis_28
|
2356
|
-
cmap_28 = [
|
2357
|
-
|
2358
|
-
|
2359
|
-
|
2413
|
+
cmap_28 = [
|
2414
|
+
"#023fa5",
|
2415
|
+
"#7d87b9",
|
2416
|
+
"#bec1d4",
|
2417
|
+
"#d6bcc0",
|
2418
|
+
"#bb7784",
|
2419
|
+
"#8e063b",
|
2420
|
+
"#4a6fe3",
|
2421
|
+
"#8595e1",
|
2422
|
+
"#b5bbe3",
|
2423
|
+
"#e6afb9",
|
2424
|
+
"#e07b91",
|
2425
|
+
"#d33f6a",
|
2426
|
+
"#11c638",
|
2427
|
+
"#8dd593",
|
2428
|
+
"#c6dec7",
|
2429
|
+
"#ead3c6",
|
2430
|
+
"#f0b98d",
|
2431
|
+
"#ef9708",
|
2432
|
+
"#0fcfc0",
|
2433
|
+
"#9cded6",
|
2434
|
+
"#d5eae7",
|
2435
|
+
"#f3e1eb",
|
2436
|
+
"#f6c4e1",
|
2437
|
+
"#f79cd4",
|
2438
|
+
"#7f7f7f",
|
2439
|
+
"#c7c7c7",
|
2440
|
+
"#1CE6FF",
|
2441
|
+
"#336600",
|
2442
|
+
]
|
2360
2443
|
if cmap == "gray":
|
2361
2444
|
cmap = "grey"
|
2362
|
-
elif cmap=="20":
|
2363
|
-
cmap=cmap_20
|
2364
|
-
elif cmap=="28":
|
2365
|
-
cmap=cmap_28
|
2445
|
+
elif cmap == "20":
|
2446
|
+
cmap = cmap_20
|
2447
|
+
elif cmap == "28":
|
2448
|
+
cmap = cmap_28
|
2366
2449
|
# Determine color list based on cmap parameter
|
2367
|
-
if isinstance(cmap,str):
|
2450
|
+
if isinstance(cmap, str):
|
2368
2451
|
if "aut" in cmap:
|
2369
2452
|
if n == 1:
|
2370
2453
|
colorlist = ["#3A4453"]
|
2371
2454
|
elif n == 2:
|
2372
2455
|
colorlist = ["#3A4453", "#FF2C00"]
|
2373
2456
|
elif n == 3:
|
2374
|
-
colorlist = ["#66c2a5","#fc8d62","#8da0cb"]
|
2457
|
+
colorlist = ["#66c2a5", "#fc8d62", "#8da0cb"]
|
2375
2458
|
elif n == 4:
|
2376
|
-
colorlist = ["#FF2C00","#087cf7", "#FBAF63", "#3C898A"]
|
2459
|
+
colorlist = ["#FF2C00", "#087cf7", "#FBAF63", "#3C898A"]
|
2377
2460
|
elif n == 5:
|
2378
|
-
colorlist = ["#FF2C00","#459AA9", "#B25E9D", "#4B8C3B","#EF8632"]
|
2461
|
+
colorlist = ["#FF2C00", "#459AA9", "#B25E9D", "#4B8C3B", "#EF8632"]
|
2379
2462
|
elif n == 6:
|
2380
|
-
colorlist = [
|
2381
|
-
|
2382
|
-
|
2383
|
-
|
2463
|
+
colorlist = [
|
2464
|
+
"#FF2C00",
|
2465
|
+
"#91bfdb",
|
2466
|
+
"#B25E9D",
|
2467
|
+
"#4B8C3B",
|
2468
|
+
"#EF8632",
|
2469
|
+
"#24578E",
|
2470
|
+
]
|
2471
|
+
elif n == 7:
|
2472
|
+
colorlist = [
|
2473
|
+
"#7F7F7F",
|
2474
|
+
"#459AA9",
|
2475
|
+
"#B25E9D",
|
2476
|
+
"#4B8C3B",
|
2477
|
+
"#EF8632",
|
2478
|
+
"#24578E" "#FF2C00",
|
2479
|
+
]
|
2480
|
+
elif n == 8:
|
2384
2481
|
# colorlist = ['#1f77b4','#ff7f0e','#367B7F','#51B34F','#d62728','#aa40fc','#e377c2','#17becf']
|
2385
2482
|
# colorlist = ["#367C7E","#51B34F","#881A11","#E9374C","#EF893C","#010072","#385DCB","#EA43E3"]
|
2386
|
-
colorlist = [
|
2387
|
-
|
2388
|
-
|
2389
|
-
|
2390
|
-
|
2391
|
-
|
2392
|
-
|
2483
|
+
colorlist = [
|
2484
|
+
"#78BFDA",
|
2485
|
+
"#D52E6F",
|
2486
|
+
"#F7D648",
|
2487
|
+
"#A52D28",
|
2488
|
+
"#6B9F41",
|
2489
|
+
"#E18330",
|
2490
|
+
"#E18B9D",
|
2491
|
+
"#3C88CC",
|
2492
|
+
]
|
2493
|
+
elif n == 9:
|
2494
|
+
colorlist = [
|
2495
|
+
"#1f77b4",
|
2496
|
+
"#ff7f0e",
|
2497
|
+
"#367B7F",
|
2498
|
+
"#ff9896",
|
2499
|
+
"#d62728",
|
2500
|
+
"#aa40fc",
|
2501
|
+
"#e377c2",
|
2502
|
+
"#51B34F",
|
2503
|
+
"#17becf",
|
2504
|
+
]
|
2505
|
+
elif n == 10:
|
2506
|
+
colorlist = [
|
2507
|
+
"#1f77b4",
|
2508
|
+
"#ff7f0e",
|
2509
|
+
"#367B7F",
|
2510
|
+
"#ff9896",
|
2511
|
+
"#51B34F",
|
2512
|
+
"#d62728" "#aa40fc",
|
2513
|
+
"#e377c2",
|
2514
|
+
"#375FD2",
|
2515
|
+
"#17becf",
|
2516
|
+
]
|
2517
|
+
elif 10 < n <= 20:
|
2518
|
+
colorlist = cmap_20
|
2393
2519
|
else:
|
2394
2520
|
colorlist = cmap_28
|
2395
2521
|
by = "start"
|
@@ -2426,7 +2552,7 @@ def get_color(
|
|
2426
2552
|
by = "linspace"
|
2427
2553
|
colorlist = cmap2hex(cmap)
|
2428
2554
|
elif isinstance(cmap, list):
|
2429
|
-
colorlist=cmap
|
2555
|
+
colorlist = cmap
|
2430
2556
|
|
2431
2557
|
# Determine method for generating color list
|
2432
2558
|
if "st" in by.lower() or "be" in by.lower():
|
@@ -2446,7 +2572,7 @@ def get_color(
|
|
2446
2572
|
return hue_list
|
2447
2573
|
else:
|
2448
2574
|
raise ValueError("Invalid output type. Choose 'rgb' or 'hue'.")
|
2449
|
-
|
2575
|
+
|
2450
2576
|
|
2451
2577
|
def stdshade(ax=None, *args, **kwargs):
|
2452
2578
|
"""
|
@@ -2454,7 +2580,7 @@ def stdshade(ax=None, *args, **kwargs):
|
|
2454
2580
|
plot.stdshade(data_array, c=clist[1], lw=2, ls="-.", alpha=0.2)
|
2455
2581
|
"""
|
2456
2582
|
from scipy.signal import savgol_filter
|
2457
|
-
|
2583
|
+
|
2458
2584
|
# Separate kws_line and kws_fill if necessary
|
2459
2585
|
kws_line = kwargs.pop("kws_line", {})
|
2460
2586
|
kws_fill = kwargs.pop("kws_fill", {})
|
@@ -2724,37 +2850,47 @@ def adjust_spines(ax=None, spines=["left", "bottom"], distance=2):
|
|
2724
2850
|
# cax = fig.add_axes([l + w + pad, b, width, h]) # define cbar Axes
|
2725
2851
|
# return fig.colorbar(im, cax=cax, **kwargs) # draw cbar
|
2726
2852
|
|
2727
|
-
|
2728
|
-
|
2729
|
-
|
2730
|
-
|
2731
|
-
|
2732
|
-
|
2733
|
-
|
2734
|
-
|
2735
|
-
|
2853
|
+
|
2854
|
+
def add_colorbar(
|
2855
|
+
im,
|
2856
|
+
cmap="viridis",
|
2857
|
+
vmin=-1,
|
2858
|
+
vmax=1,
|
2859
|
+
orientation="vertical",
|
2860
|
+
width_ratio=0.05,
|
2861
|
+
pad_ratio=0.02,
|
2862
|
+
shrink=1.0,
|
2863
|
+
**kwargs,
|
2864
|
+
):
|
2736
2865
|
import matplotlib as mpl
|
2866
|
+
|
2737
2867
|
if all([cmap, vmin, vmax]):
|
2738
2868
|
norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
|
2739
2869
|
else:
|
2740
|
-
norm=False
|
2870
|
+
norm = False
|
2741
2871
|
sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
|
2742
2872
|
sm.set_array([])
|
2743
2873
|
l, b, w, h = im.axes.get_position().bounds # position: left, bottom, width, height
|
2744
|
-
if orientation ==
|
2874
|
+
if orientation == "vertical":
|
2745
2875
|
width = width_ratio * w
|
2746
2876
|
pad = pad_ratio * w
|
2747
|
-
cax = im.figure.add_axes(
|
2877
|
+
cax = im.figure.add_axes(
|
2878
|
+
[l + w + pad, b, width, h * shrink]
|
2879
|
+
) # Right of the image
|
2748
2880
|
else:
|
2749
2881
|
height = width_ratio * h
|
2750
2882
|
pad = pad_ratio * h
|
2751
|
-
cax = im.figure.add_axes(
|
2752
|
-
|
2753
|
-
|
2883
|
+
cax = im.figure.add_axes(
|
2884
|
+
[l, b - height - pad, w * shrink, height]
|
2885
|
+
) # Below the image
|
2886
|
+
cbar = im.figure.colorbar(sm, cax=cax, orientation=orientation, **kwargs)
|
2887
|
+
return cbar
|
2888
|
+
|
2754
2889
|
|
2755
2890
|
# Usage:
|
2756
2891
|
# add_colorbar(im, width_ratio=0.03, pad_ratio=0.01, orientation='horizontal', label="PSD (dB)")
|
2757
2892
|
|
2893
|
+
|
2758
2894
|
def generate_xticks_with_gap(x_len, hue_len):
|
2759
2895
|
"""
|
2760
2896
|
Generate a concatenated array based on x_len and hue_len,
|
@@ -2884,6 +3020,7 @@ def style_examples(
|
|
2884
3020
|
f = listdir(
|
2885
3021
|
"/Users/macjianfeng/Dropbox/github/python/py2ls/.venv/lib/python3.12/site-packages/py2ls/data/styles/",
|
2886
3022
|
kind=".json",
|
3023
|
+
verbose=False,
|
2887
3024
|
)
|
2888
3025
|
display(f.sample(2))
|
2889
3026
|
# def style_example(dir_save,)
|
@@ -2961,6 +3098,7 @@ def thumbnail(dir_img_list: list, figsize=(10, 10), dpi=100, show=False, verbose
|
|
2961
3098
|
|
2962
3099
|
def get_params_from_func_usage(function_signature):
|
2963
3100
|
import re
|
3101
|
+
|
2964
3102
|
# Regular expression to match parameter names, ignoring '*' and '**kwargs'
|
2965
3103
|
keys_pattern = r"(?<!\*\*)\b(\w+)="
|
2966
3104
|
# Find all matches
|
@@ -2987,7 +3125,7 @@ def plotxy(
|
|
2987
3125
|
x=None,
|
2988
3126
|
y=None,
|
2989
3127
|
ax=None,
|
2990
|
-
kind: Union[str, list] =
|
3128
|
+
kind: Union[str, list] = "scatter", # Specify the kind of plot
|
2991
3129
|
verbose=False,
|
2992
3130
|
**kwargs,
|
2993
3131
|
):
|
@@ -3011,14 +3149,15 @@ def plotxy(
|
|
3011
3149
|
# Check for valid plot kind
|
3012
3150
|
# Default arguments for various plot types
|
3013
3151
|
from pathlib import Path
|
3152
|
+
|
3014
3153
|
# Get the current script's directory as a Path object
|
3015
|
-
current_directory = Path(__file__).resolve().parent
|
3154
|
+
current_directory = Path(__file__).resolve().parent
|
3155
|
+
|
3156
|
+
if not "default_settings" in locals():
|
3157
|
+
default_settings = fload(current_directory / "data" / "usages_sns.json")
|
3158
|
+
if not "sns_info" in locals():
|
3159
|
+
sns_info = pd.DataFrame(fload(current_directory / "data" / "sns_info.json"))
|
3016
3160
|
|
3017
|
-
if not 'default_settings' in locals():
|
3018
|
-
default_settings = fload(current_directory / 'data' / 'usages_sns.json')
|
3019
|
-
if not 'sns_info' in locals():
|
3020
|
-
sns_info = pd.DataFrame(fload(current_directory / 'data' / 'sns_info.json'))
|
3021
|
-
|
3022
3161
|
valid_kinds = list(default_settings.keys())
|
3023
3162
|
# print(valid_kinds)
|
3024
3163
|
if kind is not None:
|
@@ -3032,7 +3171,7 @@ def plotxy(
|
|
3032
3171
|
if kind is not None:
|
3033
3172
|
for k in kind:
|
3034
3173
|
if k in valid_kinds:
|
3035
|
-
print(f"{k}:\n\t{default_settings[k]}")
|
3174
|
+
print(f"{k}:\n\t{default_settings[k]}")
|
3036
3175
|
usage_str = """plotxy(data=ranked_genes,
|
3037
3176
|
x="log2(fold_change)",
|
3038
3177
|
y="-log10(p-value)",
|
@@ -3057,15 +3196,25 @@ def plotxy(
|
|
3057
3196
|
kws_add_text = v_arg
|
3058
3197
|
kwargs.pop(k_arg, None)
|
3059
3198
|
break
|
3060
|
-
zorder=0
|
3199
|
+
zorder = 0
|
3061
3200
|
for k in kind:
|
3062
|
-
zorder+=1
|
3201
|
+
zorder += 1
|
3063
3202
|
# indicate 'col' features
|
3064
3203
|
col = kwargs.get("col", None)
|
3065
|
-
sns_with_col = [
|
3204
|
+
sns_with_col = [
|
3205
|
+
"catplot",
|
3206
|
+
"histplot",
|
3207
|
+
"relplot",
|
3208
|
+
"lmplot",
|
3209
|
+
"pairplot",
|
3210
|
+
"displot",
|
3211
|
+
"kdeplot",
|
3212
|
+
]
|
3066
3213
|
if col is not None:
|
3067
3214
|
if not k in sns_with_col:
|
3068
|
-
print(
|
3215
|
+
print(
|
3216
|
+
f"tips:\n'{k}' has no 'col' param, you could try with {sns_with_col}"
|
3217
|
+
)
|
3069
3218
|
# (1) return FcetGrid
|
3070
3219
|
if k == "jointplot":
|
3071
3220
|
kws_joint = kwargs.pop("kws_joint", kwargs)
|
@@ -3092,77 +3241,98 @@ def plotxy(
|
|
3092
3241
|
ax = stdshade(ax=ax, **kwargs)
|
3093
3242
|
elif k == "scatterplot":
|
3094
3243
|
kws_scatter = kwargs.pop("kws_scatter", kwargs)
|
3095
|
-
kws_scatter={
|
3244
|
+
kws_scatter = {
|
3245
|
+
k: v for k, v in kws_scatter.items() if not k.startswith("kws_")
|
3246
|
+
}
|
3096
3247
|
hue = kwargs.pop("hue", None)
|
3097
3248
|
if isinstance(kws_scatter, dict): # Check if kws_scatter is a dictionary
|
3098
3249
|
kws_scatter.pop("hue", None) # Safely remove 'hue' if it exists
|
3099
|
-
palette = kws_scatter.pop(
|
3250
|
+
palette = kws_scatter.pop(
|
3251
|
+
"palette", get_color(data[hue].nunique()) if hue is not None else None
|
3252
|
+
)
|
3100
3253
|
s = kws_scatter.pop("s", 10)
|
3101
3254
|
alpha = kws_scatter.pop("alpha", 0.7)
|
3102
|
-
ax = sns.scatterplot(
|
3255
|
+
ax = sns.scatterplot(
|
3256
|
+
ax=ax,
|
3257
|
+
data=data,
|
3258
|
+
x=x,
|
3259
|
+
y=y,
|
3260
|
+
hue=hue,
|
3261
|
+
palette=palette,
|
3262
|
+
s=s,
|
3263
|
+
alpha=alpha,
|
3264
|
+
zorder=zorder,
|
3265
|
+
**kws_scatter,
|
3266
|
+
)
|
3103
3267
|
elif k == "histplot":
|
3104
3268
|
kws_hist = kwargs.pop("kws_hist", kwargs)
|
3105
|
-
kws_hist={k: v for k, v in kws_hist.items() if not k.startswith("kws_")}
|
3106
|
-
ax = sns.histplot(data=data, x=x, ax=ax,zorder=zorder, **kws_hist)
|
3269
|
+
kws_hist = {k: v for k, v in kws_hist.items() if not k.startswith("kws_")}
|
3270
|
+
ax = sns.histplot(data=data, x=x, ax=ax, zorder=zorder, **kws_hist)
|
3107
3271
|
elif k == "kdeplot":
|
3108
3272
|
kws_kde = kwargs.pop("kws_kde", kwargs)
|
3109
|
-
kws_kde={k: v for k, v in kws_kde.items() if not k.startswith("kws_")}
|
3110
|
-
ax = sns.kdeplot(data=data, x=x, ax=ax,zorder=zorder, **kws_kde)
|
3273
|
+
kws_kde = {k: v for k, v in kws_kde.items() if not k.startswith("kws_")}
|
3274
|
+
ax = sns.kdeplot(data=data, x=x, ax=ax, zorder=zorder, **kws_kde)
|
3111
3275
|
elif k == "ecdfplot":
|
3112
3276
|
kws_ecdf = kwargs.pop("kws_ecdf", kwargs)
|
3113
|
-
kws_ecdf={k: v for k, v in kws_ecdf.items() if not k.startswith("kws_")}
|
3114
|
-
ax = sns.ecdfplot(data=data, x=x, ax=ax,zorder=zorder, **kws_ecdf)
|
3277
|
+
kws_ecdf = {k: v for k, v in kws_ecdf.items() if not k.startswith("kws_")}
|
3278
|
+
ax = sns.ecdfplot(data=data, x=x, ax=ax, zorder=zorder, **kws_ecdf)
|
3115
3279
|
elif k == "rugplot":
|
3116
3280
|
kws_rug = kwargs.pop("kws_rug", kwargs)
|
3117
|
-
kws_rug={k: v for k, v in kws_rug.items() if not k.startswith("kws_")}
|
3118
|
-
ax = sns.rugplot(data=data, x=x, ax=ax,zorder=zorder, **kws_rug)
|
3281
|
+
kws_rug = {k: v for k, v in kws_rug.items() if not k.startswith("kws_")}
|
3282
|
+
ax = sns.rugplot(data=data, x=x, ax=ax, zorder=zorder, **kws_rug)
|
3119
3283
|
elif k == "stripplot":
|
3120
3284
|
kws_strip = kwargs.pop("kws_strip", kwargs)
|
3121
|
-
kws_strip={k: v for k, v in kws_strip.items() if not k.startswith("kws_")}
|
3285
|
+
kws_strip = {k: v for k, v in kws_strip.items() if not k.startswith("kws_")}
|
3122
3286
|
dodge = kws_strip.pop("dodge", True)
|
3123
|
-
ax = sns.stripplot(
|
3287
|
+
ax = sns.stripplot(
|
3288
|
+
data=data, x=x, y=y, ax=ax, zorder=zorder, dodge=dodge, **kws_strip
|
3289
|
+
)
|
3124
3290
|
elif k == "swarmplot":
|
3125
3291
|
kws_swarm = kwargs.pop("kws_swarm", kwargs)
|
3126
|
-
kws_swarm={k: v for k, v in kws_swarm.items() if not k.startswith("kws_")}
|
3127
|
-
ax = sns.swarmplot(data=data, x=x, y=y, ax=ax,zorder=zorder, **kws_swarm)
|
3292
|
+
kws_swarm = {k: v for k, v in kws_swarm.items() if not k.startswith("kws_")}
|
3293
|
+
ax = sns.swarmplot(data=data, x=x, y=y, ax=ax, zorder=zorder, **kws_swarm)
|
3128
3294
|
elif k == "boxplot":
|
3129
3295
|
kws_box = kwargs.pop("kws_box", kwargs)
|
3130
|
-
kws_box={k: v for k, v in kws_box.items() if not k.startswith("kws_")}
|
3131
|
-
ax = sns.boxplot(data=data, x=x, y=y, ax=ax,zorder=zorder, **kws_box)
|
3296
|
+
kws_box = {k: v for k, v in kws_box.items() if not k.startswith("kws_")}
|
3297
|
+
ax = sns.boxplot(data=data, x=x, y=y, ax=ax, zorder=zorder, **kws_box)
|
3132
3298
|
elif k == "violinplot":
|
3133
3299
|
kws_violin = kwargs.pop("kws_violin", kwargs)
|
3134
|
-
kws_violin={
|
3135
|
-
|
3300
|
+
kws_violin = {
|
3301
|
+
k: v for k, v in kws_violin.items() if not k.startswith("kws_")
|
3302
|
+
}
|
3303
|
+
ax = sns.violinplot(data=data, x=x, y=y, ax=ax, zorder=zorder, **kws_violin)
|
3136
3304
|
elif k == "boxenplot":
|
3137
3305
|
kws_boxen = kwargs.pop("kws_boxen", kwargs)
|
3138
|
-
kws_boxen={k: v for k, v in kws_boxen.items() if not k.startswith("kws_")}
|
3139
|
-
ax = sns.boxenplot(data=data, x=x, y=y, ax=ax,zorder=zorder, **kws_boxen)
|
3306
|
+
kws_boxen = {k: v for k, v in kws_boxen.items() if not k.startswith("kws_")}
|
3307
|
+
ax = sns.boxenplot(data=data, x=x, y=y, ax=ax, zorder=zorder, **kws_boxen)
|
3140
3308
|
elif k == "pointplot":
|
3141
3309
|
kws_point = kwargs.pop("kws_point", kwargs)
|
3142
|
-
kws_point={k: v for k, v in kws_point.items() if not k.startswith("kws_")}
|
3143
|
-
ax = sns.pointplot(data=data, x=x, y=y, ax=ax,zorder=zorder, **kws_point)
|
3310
|
+
kws_point = {k: v for k, v in kws_point.items() if not k.startswith("kws_")}
|
3311
|
+
ax = sns.pointplot(data=data, x=x, y=y, ax=ax, zorder=zorder, **kws_point)
|
3144
3312
|
elif k == "barplot":
|
3145
3313
|
kws_bar = kwargs.pop("kws_bar", kwargs)
|
3146
|
-
kws_bar={k: v for k, v in kws_bar.items() if not k.startswith("kws_")}
|
3147
|
-
ax = sns.barplot(data=data, x=x, y=y, ax=ax,zorder=zorder, **kws_bar)
|
3314
|
+
kws_bar = {k: v for k, v in kws_bar.items() if not k.startswith("kws_")}
|
3315
|
+
ax = sns.barplot(data=data, x=x, y=y, ax=ax, zorder=zorder, **kws_bar)
|
3148
3316
|
elif k == "countplot":
|
3149
3317
|
kws_count = kwargs.pop("kws_count", kwargs)
|
3150
|
-
kws_count={k: v for k, v in kws_count.items() if not k.startswith("kws_")}
|
3151
|
-
if not kws_count.get("hue",None):
|
3152
|
-
kws_count.pop("palette",None)
|
3153
|
-
ax = sns.countplot(data=data, x=x,y=y, ax=ax,zorder=zorder, **kws_count)
|
3318
|
+
kws_count = {k: v for k, v in kws_count.items() if not k.startswith("kws_")}
|
3319
|
+
if not kws_count.get("hue", None):
|
3320
|
+
kws_count.pop("palette", None)
|
3321
|
+
ax = sns.countplot(data=data, x=x, y=y, ax=ax, zorder=zorder, **kws_count)
|
3154
3322
|
elif k == "regplot":
|
3155
3323
|
kws_reg = kwargs.pop("kws_reg", kwargs)
|
3156
|
-
kws_reg={k: v for k, v in kws_reg.items() if not k.startswith("kws_")}
|
3157
|
-
ax = sns.regplot(data=data, x=x, y=y, ax=ax,zorder=zorder, **kws_reg)
|
3324
|
+
kws_reg = {k: v for k, v in kws_reg.items() if not k.startswith("kws_")}
|
3325
|
+
ax = sns.regplot(data=data, x=x, y=y, ax=ax, zorder=zorder, **kws_reg)
|
3158
3326
|
elif k == "residplot":
|
3159
3327
|
kws_resid = kwargs.pop("kws_resid", kwargs)
|
3160
|
-
kws_resid={k: v for k, v in kws_resid.items() if not k.startswith("kws_")}
|
3161
|
-
ax = sns.residplot(
|
3328
|
+
kws_resid = {k: v for k, v in kws_resid.items() if not k.startswith("kws_")}
|
3329
|
+
ax = sns.residplot(
|
3330
|
+
data=data, x=x, y=y, lowess=True, zorder=zorder, ax=ax, **kws_resid
|
3331
|
+
)
|
3162
3332
|
elif k == "lineplot":
|
3163
3333
|
kws_line = kwargs.pop("kws_line", kwargs)
|
3164
|
-
kws_line={k: v for k, v in kws_line.items() if not k.startswith("kws_")}
|
3165
|
-
ax = sns.lineplot(ax=ax, data=data, x=x, y=y,zorder=zorder, **kws_line)
|
3334
|
+
kws_line = {k: v for k, v in kws_line.items() if not k.startswith("kws_")}
|
3335
|
+
ax = sns.lineplot(ax=ax, data=data, x=x, y=y, zorder=zorder, **kws_line)
|
3166
3336
|
|
3167
3337
|
figsets(ax=ax, **kws_figsets)
|
3168
3338
|
if kws_add_text:
|
@@ -3176,20 +3346,22 @@ def plotxy(
|
|
3176
3346
|
return g, ax
|
3177
3347
|
return ax
|
3178
3348
|
|
3349
|
+
|
3179
3350
|
def norm_cmap(data, cmap="coolwarm", min_max=[0, 1]):
|
3180
3351
|
norm_ = plt.Normalize(min_max[0], min_max[1])
|
3181
3352
|
colormap = plt.get_cmap(cmap)
|
3182
3353
|
return colormap(norm_(data))
|
3183
3354
|
|
3355
|
+
|
3184
3356
|
def volcano(
|
3185
|
-
data:pd.DataFrame,
|
3186
|
-
x:str,
|
3187
|
-
y:str,
|
3188
|
-
gene_col:str=None,
|
3189
|
-
top_genes=[5, 5],
|
3190
|
-
thr_x=np.log2(1.5),
|
3357
|
+
data: pd.DataFrame,
|
3358
|
+
x: str,
|
3359
|
+
y: str,
|
3360
|
+
gene_col: str = None,
|
3361
|
+
top_genes=[5, 5], # [down-regulated, up-regulated]
|
3362
|
+
thr_x=np.log2(1.5), # default: 0.585
|
3191
3363
|
thr_y=-np.log10(0.05),
|
3192
|
-
sort_xy="x",
|
3364
|
+
sort_xy="x", #'y', 'xy'
|
3193
3365
|
colors=("#00BFFF", "#9d9a9a", "#FF3030"),
|
3194
3366
|
s=20,
|
3195
3367
|
fill=True, # plot filled scatter
|
@@ -3201,11 +3373,10 @@ def volcano(
|
|
3201
3373
|
ax=None,
|
3202
3374
|
verbose=False,
|
3203
3375
|
kws_text=dict(fontsize=10, color="k"),
|
3204
|
-
kws_bbox=dict(
|
3205
|
-
|
3206
|
-
|
3207
|
-
|
3208
|
-
kws_arrow=dict(color="k", lw=0.5),# '{}' to hide
|
3376
|
+
kws_bbox=dict(
|
3377
|
+
facecolor="none", alpha=0.5, edgecolor="black", boxstyle="round,pad=0.3"
|
3378
|
+
), # '{}' to hide
|
3379
|
+
kws_arrow=dict(color="k", lw=0.5), # '{}' to hide
|
3209
3380
|
**kwargs,
|
3210
3381
|
):
|
3211
3382
|
"""
|
@@ -3275,39 +3446,35 @@ def volcano(
|
|
3275
3446
|
kws_figsets = v_arg
|
3276
3447
|
kwargs.pop(k_arg, None)
|
3277
3448
|
break
|
3278
|
-
|
3279
|
-
data=data.copy()
|
3449
|
+
|
3450
|
+
data = data.copy()
|
3280
3451
|
# filter nan
|
3281
3452
|
data = data.dropna(subset=[x, y]) # Drop rows with NaN in x or y
|
3282
|
-
data.loc[:,"color"] = np.where(
|
3453
|
+
data.loc[:, "color"] = np.where(
|
3283
3454
|
(data[x] > thr_x) & (data[y] > thr_y),
|
3284
3455
|
colors[2],
|
3285
|
-
np.where((data[x] < -thr_x) & (data[y] > thr_y),
|
3286
|
-
colors[0],
|
3287
|
-
colors[1]),
|
3456
|
+
np.where((data[x] < -thr_x) & (data[y] > thr_y), colors[0], colors[1]),
|
3288
3457
|
)
|
3289
|
-
top_genes=[top_genes, top_genes] if isinstance(top_genes,int) else top_genes
|
3290
|
-
|
3458
|
+
top_genes = [top_genes, top_genes] if isinstance(top_genes, int) else top_genes
|
3459
|
+
|
3291
3460
|
# could custom how to select the top genes, x: x has priority
|
3292
|
-
sort_by_x_y=[x,y] if sort_xy=="x" else [y,x]
|
3293
|
-
ascending_up=[True, True] if sort_xy=="x" else [False, True]
|
3294
|
-
ascending_down=[False, True] if sort_xy=="x" else [False, False]
|
3295
|
-
|
3296
|
-
down_reg_genes =
|
3297
|
-
(data["color"] == colors[0]) &
|
3298
|
-
|
3299
|
-
(
|
3300
|
-
|
3301
|
-
up_reg_genes =
|
3302
|
-
(data["color"] == colors[2]) &
|
3303
|
-
|
3304
|
-
(
|
3305
|
-
|
3461
|
+
sort_by_x_y = [x, y] if sort_xy == "x" else [y, x]
|
3462
|
+
ascending_up = [True, True] if sort_xy == "x" else [False, True]
|
3463
|
+
ascending_down = [False, True] if sort_xy == "x" else [False, False]
|
3464
|
+
|
3465
|
+
down_reg_genes = (
|
3466
|
+
data[(data["color"] == colors[0]) & (data[x].abs() > thr_x) & (data[y] > thr_y)]
|
3467
|
+
.sort_values(by=sort_by_x_y, ascending=ascending_up)
|
3468
|
+
.head(top_genes[0])
|
3469
|
+
)
|
3470
|
+
up_reg_genes = (
|
3471
|
+
data[(data["color"] == colors[2]) & (data[x].abs() > thr_x) & (data[y] > thr_y)]
|
3472
|
+
.sort_values(by=sort_by_x_y, ascending=ascending_down)
|
3473
|
+
.head(top_genes[1])
|
3474
|
+
)
|
3306
3475
|
sele_gene = pd.concat([down_reg_genes, up_reg_genes])
|
3307
|
-
|
3308
|
-
palette = {colors[0]: colors[0],
|
3309
|
-
colors[1]: colors[1],
|
3310
|
-
colors[2]: colors[2]}
|
3476
|
+
|
3477
|
+
palette = {colors[0]: colors[0], colors[1]: colors[1], colors[2]: colors[2]}
|
3311
3478
|
# Plot setup
|
3312
3479
|
if ax is None:
|
3313
3480
|
ax = plt.gca()
|
@@ -3337,9 +3504,9 @@ def volcano(
|
|
3337
3504
|
)
|
3338
3505
|
|
3339
3506
|
# Add threshold lines for x and y axes
|
3340
|
-
ax.axhline(y=thr_y, color="black", linestyle="--",lw=1)
|
3341
|
-
ax.axvline(x=-thr_x, color="black", linestyle="--",lw=1)
|
3342
|
-
ax.axvline(x=thr_x, color="black", linestyle="--",lw=1)
|
3507
|
+
ax.axhline(y=thr_y, color="black", linestyle="--", lw=1)
|
3508
|
+
ax.axvline(x=-thr_x, color="black", linestyle="--", lw=1)
|
3509
|
+
ax.axvline(x=thr_x, color="black", linestyle="--", lw=1)
|
3343
3510
|
|
3344
3511
|
# Add gene labels for selected significant points
|
3345
3512
|
if gene_col:
|
@@ -3349,19 +3516,28 @@ def volcano(
|
|
3349
3516
|
textcolor = kws_text.pop("color", "k")
|
3350
3517
|
fontsize = kws_text.pop("fontsize", 10)
|
3351
3518
|
arrowstyles = [
|
3352
|
-
"->",
|
3353
|
-
"
|
3354
|
-
"
|
3519
|
+
"->",
|
3520
|
+
"<-",
|
3521
|
+
"<->",
|
3522
|
+
"<|-",
|
3523
|
+
"-|>",
|
3524
|
+
"<|-|>",
|
3525
|
+
"-",
|
3526
|
+
"-[",
|
3527
|
+
"-[",
|
3528
|
+
"fancy",
|
3529
|
+
"simple",
|
3530
|
+
"wedge",
|
3355
3531
|
]
|
3356
3532
|
arrowstyle = kws_arrow.pop("style", "<|-")
|
3357
|
-
arrowstyle = strcmp(arrowstyle, arrowstyles,scorer=
|
3358
|
-
expand=kws_arrow.pop("expand",(1.05,1.1))
|
3533
|
+
arrowstyle = strcmp(arrowstyle, arrowstyles, scorer="strict")[0]
|
3534
|
+
expand = kws_arrow.pop("expand", (1.05, 1.1))
|
3359
3535
|
arrowcolor = kws_arrow.pop("color", "0.4")
|
3360
3536
|
arrowlinewidth = kws_arrow.pop("lw", 0.75)
|
3361
3537
|
shrinkA = kws_arrow.pop("shrinkA", 0)
|
3362
3538
|
shrinkB = kws_arrow.pop("shrinkB", 0)
|
3363
3539
|
mutation_scale = kws_arrow.pop("head", 10)
|
3364
|
-
arrow_fill=kws_arrow.pop("fill", False)
|
3540
|
+
arrow_fill = kws_arrow.pop("fill", False)
|
3365
3541
|
for i in range(sele_gene.shape[0]):
|
3366
3542
|
if isinstance(textcolor, list): # be consistant with dots's color
|
3367
3543
|
textcolor = colors[0] if sele_gene[x].iloc[i] > 0 else colors[1]
|
@@ -3382,7 +3558,7 @@ def volcano(
|
|
3382
3558
|
adjust_text(
|
3383
3559
|
texts,
|
3384
3560
|
expand=expand,
|
3385
|
-
min_arrow_len=5,
|
3561
|
+
min_arrow_len=5,
|
3386
3562
|
ax=ax,
|
3387
3563
|
arrowprops=dict(
|
3388
3564
|
arrowstyle=arrowstyle,
|
@@ -3393,7 +3569,7 @@ def volcano(
|
|
3393
3569
|
shrinkB=shrinkB,
|
3394
3570
|
mutation_scale=mutation_scale,
|
3395
3571
|
**kws_arrow,
|
3396
|
-
)
|
3572
|
+
),
|
3397
3573
|
)
|
3398
3574
|
|
3399
3575
|
figsets(**kws_figsets)
|
@@ -3499,6 +3675,7 @@ def desaturate_color(color, saturation_factor=0.5):
|
|
3499
3675
|
"""Reduce the saturation of a color by a given factor (between 0 and 1)."""
|
3500
3676
|
import matplotlib.colors as mcolors
|
3501
3677
|
import colorsys
|
3678
|
+
|
3502
3679
|
# Convert the color to RGB
|
3503
3680
|
rgb = mcolors.to_rgb(color)
|
3504
3681
|
# Convert RGB to HLS (Hue, Lightness, Saturation)
|
@@ -3508,8 +3685,19 @@ def desaturate_color(color, saturation_factor=0.5):
|
|
3508
3685
|
# Convert back to RGB
|
3509
3686
|
return colorsys.hls_to_rgb(h, l, s)
|
3510
3687
|
|
3511
|
-
|
3512
|
-
|
3688
|
+
|
3689
|
+
def textsets(
|
3690
|
+
text,
|
3691
|
+
fontname="Arial",
|
3692
|
+
fontsize=11,
|
3693
|
+
fontweight="normal",
|
3694
|
+
fontstyle="normal",
|
3695
|
+
fontcolor="k",
|
3696
|
+
backgroundcolor=None,
|
3697
|
+
shadow=False,
|
3698
|
+
ha="center",
|
3699
|
+
va="center",
|
3700
|
+
):
|
3513
3701
|
if text: # Ensure text exists
|
3514
3702
|
if fontname:
|
3515
3703
|
text.set_fontname(plt_font(fontname))
|
@@ -3522,26 +3710,28 @@ def textsets(text, fontname='Arial', fontsize=11, fontweight="normal", fontstyle
|
|
3522
3710
|
if fontcolor:
|
3523
3711
|
text.set_color(fontcolor)
|
3524
3712
|
if backgroundcolor:
|
3525
|
-
text.set_backgroundcolor(backgroundcolor)
|
3713
|
+
text.set_backgroundcolor(backgroundcolor)
|
3526
3714
|
text.set_horizontalalignment(ha)
|
3527
|
-
text.set_verticalalignment(va)
|
3715
|
+
text.set_verticalalignment(va)
|
3528
3716
|
if shadow:
|
3529
|
-
text.set_path_effects(
|
3530
|
-
matplotlib.patheffects.withStroke(linewidth=3, foreground="gray")
|
3531
|
-
|
3717
|
+
text.set_path_effects(
|
3718
|
+
[matplotlib.patheffects.withStroke(linewidth=3, foreground="gray")]
|
3719
|
+
)
|
3720
|
+
|
3721
|
+
|
3532
3722
|
def venn(
|
3533
|
-
lists:list,
|
3534
|
-
labels:list=None,
|
3723
|
+
lists: list,
|
3724
|
+
labels: list = None,
|
3535
3725
|
ax=None,
|
3536
3726
|
colors=None,
|
3537
3727
|
edgecolor=None,
|
3538
3728
|
alpha=0.5,
|
3539
|
-
saturation
|
3540
|
-
linewidth=0,
|
3729
|
+
saturation=0.75,
|
3730
|
+
linewidth=0, # default no edge
|
3541
3731
|
linestyle="-",
|
3542
|
-
fontname=
|
3732
|
+
fontname="Arial",
|
3543
3733
|
fontsize=10,
|
3544
|
-
fontcolor=
|
3734
|
+
fontcolor="k",
|
3545
3735
|
fontweight="normal",
|
3546
3736
|
fontstyle="normal",
|
3547
3737
|
ha="center",
|
@@ -3550,18 +3740,18 @@ def venn(
|
|
3550
3740
|
subset_fontsize=10,
|
3551
3741
|
subset_fontweight="normal",
|
3552
3742
|
subset_fontstyle="normal",
|
3553
|
-
subset_fontcolor=
|
3743
|
+
subset_fontcolor="k",
|
3554
3744
|
backgroundcolor=None,
|
3555
3745
|
custom_texts=None,
|
3556
|
-
show_percentages=True,
|
3746
|
+
show_percentages=True, # display percentage
|
3557
3747
|
fmt="{:.1%}",
|
3558
3748
|
ellipse_shape=False, # 椭圆形
|
3559
|
-
ellipse_scale=[1.5, 1], #not perfect, 椭圆形的形状
|
3560
|
-
**kwargs
|
3749
|
+
ellipse_scale=[1.5, 1], # not perfect, 椭圆形的形状
|
3750
|
+
**kwargs,
|
3561
3751
|
):
|
3562
3752
|
"""
|
3563
3753
|
Advanced Venn diagram plotting function with extensive customization options.
|
3564
|
-
Usage:
|
3754
|
+
Usage:
|
3565
3755
|
# Define the two sets
|
3566
3756
|
set1 = [1, 2, 3, 4, 5]
|
3567
3757
|
set2 = [4, 5, 6, 7, 8]
|
@@ -3612,56 +3802,70 @@ def venn(
|
|
3612
3802
|
"""
|
3613
3803
|
if ax is None:
|
3614
3804
|
ax = plt.gca()
|
3615
|
-
lists=[set(flatten(i, verbose=False)) for i in lists]
|
3805
|
+
lists = [set(flatten(i, verbose=False)) for i in lists]
|
3616
3806
|
# Function to apply text styles to labels
|
3617
3807
|
if colors is None:
|
3618
|
-
colors=["r","b"] if len(lists)==2 else ["r","g","b"]
|
3808
|
+
colors = ["r", "b"] if len(lists) == 2 else ["r", "g", "b"]
|
3619
3809
|
if labels is None:
|
3620
|
-
labels=["set1","set2"] if len(lists)==2 else ["set1","set2","set3"]
|
3810
|
+
labels = ["set1", "set2"] if len(lists) == 2 else ["set1", "set2", "set3"]
|
3621
3811
|
if edgecolor is None:
|
3622
|
-
edgecolor=colors
|
3812
|
+
edgecolor = colors
|
3623
3813
|
colors = [desaturate_color(color, saturation) for color in colors]
|
3624
3814
|
# Check colors and auto-calculate overlaps
|
3625
3815
|
if len(lists) == 2:
|
3816
|
+
|
3626
3817
|
def get_count_and_percentage(set_count, subset_count):
|
3627
3818
|
percent = subset_count / set_count if set_count > 0 else 0
|
3628
|
-
return
|
3819
|
+
return (
|
3820
|
+
f"{subset_count}\n({fmt.format(percent)})"
|
3821
|
+
if show_percentages
|
3822
|
+
else f"{subset_count}"
|
3823
|
+
)
|
3629
3824
|
|
3630
3825
|
from matplotlib_venn import venn2, venn2_circles
|
3826
|
+
|
3631
3827
|
# Auto-calculate overlap color for 2-set Venn diagram
|
3632
3828
|
overlap_color = get_color_overlap(colors[0], colors[1]) if colors else None
|
3633
|
-
|
3829
|
+
|
3634
3830
|
# Draw the venn diagram
|
3635
3831
|
v = venn2(subsets=lists, set_labels=labels, ax=ax, **kwargs)
|
3636
3832
|
venn_circles = venn2_circles(subsets=lists, ax=ax)
|
3637
|
-
set1,set2=lists[0],lists[1]
|
3638
|
-
v.get_patch_by_id(
|
3639
|
-
v.get_patch_by_id(
|
3640
|
-
v.get_patch_by_id(
|
3833
|
+
set1, set2 = lists[0], lists[1]
|
3834
|
+
v.get_patch_by_id("10").set_color(colors[0])
|
3835
|
+
v.get_patch_by_id("01").set_color(colors[1])
|
3836
|
+
v.get_patch_by_id("11").set_color(
|
3837
|
+
get_color_overlap(colors[0], colors[1]) if colors else None
|
3838
|
+
)
|
3641
3839
|
# v.get_label_by_id('10').set_text(len(set1 - set2))
|
3642
3840
|
# v.get_label_by_id('01').set_text(len(set2 - set1))
|
3643
3841
|
# v.get_label_by_id('11').set_text(len(set1 & set2))
|
3644
|
-
v.get_label_by_id(
|
3645
|
-
|
3646
|
-
|
3647
|
-
|
3842
|
+
v.get_label_by_id("10").set_text(
|
3843
|
+
get_count_and_percentage(len(set1), len(set1 - set2))
|
3844
|
+
)
|
3845
|
+
v.get_label_by_id("01").set_text(
|
3846
|
+
get_count_and_percentage(len(set2), len(set2 - set1))
|
3847
|
+
)
|
3848
|
+
v.get_label_by_id("11").set_text(
|
3849
|
+
get_count_and_percentage(len(set1 | set2), len(set1 & set2))
|
3850
|
+
)
|
3648
3851
|
|
3649
|
-
if not isinstance(linewidth,list):
|
3650
|
-
linewidth=[linewidth]
|
3651
|
-
if isinstance(linestyle,str):
|
3652
|
-
linestyle=[linestyle]
|
3852
|
+
if not isinstance(linewidth, list):
|
3853
|
+
linewidth = [linewidth]
|
3854
|
+
if isinstance(linestyle, str):
|
3855
|
+
linestyle = [linestyle]
|
3653
3856
|
if not isinstance(edgecolor, list):
|
3654
|
-
edgecolor=[edgecolor]
|
3655
|
-
linewidth=linewidth*2 if len(linewidth)==1 else linewidth
|
3656
|
-
linestyle=linestyle*2 if len(linestyle)==1 else linestyle
|
3657
|
-
edgecolor=edgecolor*2 if len(edgecolor)==1 else edgecolor
|
3857
|
+
edgecolor = [edgecolor]
|
3858
|
+
linewidth = linewidth * 2 if len(linewidth) == 1 else linewidth
|
3859
|
+
linestyle = linestyle * 2 if len(linestyle) == 1 else linestyle
|
3860
|
+
edgecolor = edgecolor * 2 if len(edgecolor) == 1 else edgecolor
|
3658
3861
|
for i in range(2):
|
3659
3862
|
venn_circles[i].set_lw(linewidth[i])
|
3660
3863
|
venn_circles[i].set_ls(linestyle[i])
|
3661
3864
|
venn_circles[i].set_edgecolor(edgecolor[i])
|
3662
3865
|
# 椭圆
|
3663
3866
|
if ellipse_shape:
|
3664
|
-
import matplotlib.patches as patches
|
3867
|
+
import matplotlib.patches as patches
|
3868
|
+
|
3665
3869
|
for patch in v.patches:
|
3666
3870
|
patch.set_visible(False) # Hide original patches if using ellipses
|
3667
3871
|
center1 = v.get_circle_center(0)
|
@@ -3672,9 +3876,13 @@ def venn(
|
|
3672
3876
|
height=ellipse_scale[1],
|
3673
3877
|
edgecolor=edgecolor[0] if edgecolor else colors[0],
|
3674
3878
|
facecolor=colors[0],
|
3675
|
-
lw=
|
3879
|
+
lw=(
|
3880
|
+
linewidth if isinstance(linewidth, (int, float)) else 1.0
|
3881
|
+
), # Ensure lw is a number
|
3676
3882
|
ls=linestyle[0],
|
3677
|
-
alpha=
|
3883
|
+
alpha=(
|
3884
|
+
alpha if isinstance(alpha, (int, float)) else 0.5
|
3885
|
+
), # Ensure alpha is a number
|
3678
3886
|
)
|
3679
3887
|
ellipse2 = patches.Ellipse(
|
3680
3888
|
(center2.x, center2.y),
|
@@ -3682,48 +3890,78 @@ def venn(
|
|
3682
3890
|
height=ellipse_scale[1],
|
3683
3891
|
edgecolor=edgecolor[1] if edgecolor else colors[1],
|
3684
3892
|
facecolor=colors[1],
|
3685
|
-
lw=
|
3893
|
+
lw=(
|
3894
|
+
linewidth if isinstance(linewidth, (int, float)) else 1.0
|
3895
|
+
), # Ensure lw is a number
|
3686
3896
|
ls=linestyle[0],
|
3687
|
-
alpha=
|
3897
|
+
alpha=(
|
3898
|
+
alpha if isinstance(alpha, (int, float)) else 0.5
|
3899
|
+
), # Ensure alpha is a number
|
3688
3900
|
)
|
3689
3901
|
ax.add_patch(ellipse1)
|
3690
3902
|
ax.add_patch(ellipse2)
|
3691
3903
|
# Apply styles to set labels
|
3692
3904
|
for i, text in enumerate(v.set_labels):
|
3693
|
-
textsets(
|
3694
|
-
|
3905
|
+
textsets(
|
3906
|
+
text,
|
3907
|
+
fontname=fontname,
|
3908
|
+
fontsize=fontsize,
|
3909
|
+
fontweight=fontweight,
|
3910
|
+
fontstyle=fontstyle,
|
3911
|
+
fontcolor=fontcolor,
|
3912
|
+
ha=ha,
|
3913
|
+
va=va,
|
3914
|
+
shadow=shadow,
|
3915
|
+
)
|
3695
3916
|
|
3696
3917
|
# Apply styles to subset labels
|
3697
3918
|
for i, text in enumerate(v.subset_labels):
|
3698
3919
|
if text: # Ensure text exists
|
3699
3920
|
if custom_texts: # Custom text handling
|
3700
|
-
text.set_text(custom_texts[i])
|
3701
|
-
textsets(
|
3702
|
-
|
3921
|
+
text.set_text(custom_texts[i])
|
3922
|
+
textsets(
|
3923
|
+
text,
|
3924
|
+
fontname=fontname,
|
3925
|
+
fontsize=subset_fontsize,
|
3926
|
+
fontweight=subset_fontweight,
|
3927
|
+
fontstyle=subset_fontstyle,
|
3928
|
+
fontcolor=subset_fontcolor,
|
3929
|
+
ha=ha,
|
3930
|
+
va=va,
|
3931
|
+
shadow=shadow,
|
3932
|
+
)
|
3703
3933
|
|
3704
3934
|
elif len(lists) == 3:
|
3935
|
+
|
3705
3936
|
def get_label(set_count, subset_count):
|
3706
3937
|
percent = subset_count / set_count if set_count > 0 else 0
|
3707
|
-
return
|
3708
|
-
|
3938
|
+
return (
|
3939
|
+
f"{subset_count}\n({fmt.format(percent)})"
|
3940
|
+
if show_percentages
|
3941
|
+
else f"{subset_count}"
|
3942
|
+
)
|
3943
|
+
|
3709
3944
|
from matplotlib_venn import venn3, venn3_circles
|
3945
|
+
|
3710
3946
|
# Auto-calculate overlap colors for 3-set Venn diagram
|
3711
3947
|
colorAB = get_color_overlap(colors[0], colors[1]) if colors else None
|
3712
3948
|
colorAC = get_color_overlap(colors[0], colors[2]) if colors else None
|
3713
3949
|
colorBC = get_color_overlap(colors[1], colors[2]) if colors else None
|
3714
|
-
colorABC =
|
3715
|
-
|
3950
|
+
colorABC = (
|
3951
|
+
get_color_overlap(colors[0], colors[1], colors[2]) if colors else None
|
3952
|
+
)
|
3953
|
+
set1, set2, set3 = lists[0], lists[1], lists[2]
|
3716
3954
|
|
3717
3955
|
# Draw the venn diagram
|
3718
|
-
v = venn3(subsets=lists, set_labels=labels, ax=ax
|
3719
|
-
v.get_patch_by_id(
|
3720
|
-
v.get_patch_by_id(
|
3721
|
-
v.get_patch_by_id(
|
3722
|
-
v.get_patch_by_id(
|
3723
|
-
v.get_patch_by_id(
|
3724
|
-
v.get_patch_by_id(
|
3725
|
-
v.get_patch_by_id(
|
3726
|
-
|
3956
|
+
v = venn3(subsets=lists, set_labels=labels, ax=ax, **kwargs)
|
3957
|
+
v.get_patch_by_id("100").set_color(colors[0])
|
3958
|
+
v.get_patch_by_id("010").set_color(colors[1])
|
3959
|
+
v.get_patch_by_id("001").set_color(colors[2])
|
3960
|
+
v.get_patch_by_id("110").set_color(colorAB)
|
3961
|
+
v.get_patch_by_id("101").set_color(colorAC)
|
3962
|
+
v.get_patch_by_id("011").set_color(colorBC)
|
3963
|
+
v.get_patch_by_id("111").set_color(colorABC)
|
3964
|
+
|
3727
3965
|
# Correctly labeling subset sizes
|
3728
3966
|
# v.get_label_by_id('100').set_text(len(set1 - set2 - set3))
|
3729
3967
|
# v.get_label_by_id('010').set_text(len(set2 - set1 - set3))
|
@@ -3732,63 +3970,93 @@ def venn(
|
|
3732
3970
|
# v.get_label_by_id('101').set_text(len(set1 & set3 - set2))
|
3733
3971
|
# v.get_label_by_id('011').set_text(len(set2 & set3 - set1))
|
3734
3972
|
# v.get_label_by_id('111').set_text(len(set1 & set2 & set3))
|
3735
|
-
v.get_label_by_id(
|
3736
|
-
v.get_label_by_id(
|
3737
|
-
v.get_label_by_id(
|
3738
|
-
v.get_label_by_id(
|
3739
|
-
|
3740
|
-
|
3741
|
-
v.get_label_by_id(
|
3742
|
-
|
3973
|
+
v.get_label_by_id("100").set_text(get_label(len(set1), len(set1 - set2 - set3)))
|
3974
|
+
v.get_label_by_id("010").set_text(get_label(len(set2), len(set2 - set1 - set3)))
|
3975
|
+
v.get_label_by_id("001").set_text(get_label(len(set3), len(set3 - set1 - set2)))
|
3976
|
+
v.get_label_by_id("110").set_text(
|
3977
|
+
get_label(len(set1 | set2), len(set1 & set2 - set3))
|
3978
|
+
)
|
3979
|
+
v.get_label_by_id("101").set_text(
|
3980
|
+
get_label(len(set1 | set3), len(set1 & set3 - set2))
|
3981
|
+
)
|
3982
|
+
v.get_label_by_id("011").set_text(
|
3983
|
+
get_label(len(set2 | set3), len(set2 & set3 - set1))
|
3984
|
+
)
|
3985
|
+
v.get_label_by_id("111").set_text(
|
3986
|
+
get_label(len(set1 | set2 | set3), len(set1 & set2 & set3))
|
3987
|
+
)
|
3743
3988
|
|
3744
3989
|
# Apply styles to set labels
|
3745
3990
|
for i, text in enumerate(v.set_labels):
|
3746
|
-
textsets(
|
3747
|
-
|
3991
|
+
textsets(
|
3992
|
+
text,
|
3993
|
+
fontname=fontname,
|
3994
|
+
fontsize=fontsize,
|
3995
|
+
fontweight=fontweight,
|
3996
|
+
fontstyle=fontstyle,
|
3997
|
+
fontcolor=fontcolor,
|
3998
|
+
ha=ha,
|
3999
|
+
va=va,
|
4000
|
+
shadow=shadow,
|
4001
|
+
)
|
3748
4002
|
|
3749
4003
|
# Apply styles to subset labels
|
3750
4004
|
for i, text in enumerate(v.subset_labels):
|
3751
4005
|
if text: # Ensure text exists
|
3752
4006
|
if custom_texts: # Custom text handling
|
3753
4007
|
text.set_text(custom_texts[i])
|
3754
|
-
textsets(
|
3755
|
-
|
4008
|
+
textsets(
|
4009
|
+
text,
|
4010
|
+
fontname=fontname,
|
4011
|
+
fontsize=subset_fontsize,
|
4012
|
+
fontweight=subset_fontweight,
|
4013
|
+
fontstyle=subset_fontstyle,
|
4014
|
+
fontcolor=subset_fontcolor,
|
4015
|
+
ha=ha,
|
4016
|
+
va=va,
|
4017
|
+
shadow=shadow,
|
4018
|
+
)
|
3756
4019
|
|
3757
4020
|
venn_circles = venn3_circles(subsets=lists, ax=ax)
|
3758
|
-
if not isinstance(linewidth,list):
|
3759
|
-
linewidth=[linewidth]
|
3760
|
-
if isinstance(linestyle,str):
|
3761
|
-
linestyle=[linestyle]
|
4021
|
+
if not isinstance(linewidth, list):
|
4022
|
+
linewidth = [linewidth]
|
4023
|
+
if isinstance(linestyle, str):
|
4024
|
+
linestyle = [linestyle]
|
3762
4025
|
if not isinstance(edgecolor, list):
|
3763
|
-
edgecolor=[edgecolor]
|
3764
|
-
linewidth=linewidth*3 if len(linewidth)==1 else linewidth
|
3765
|
-
linestyle=linestyle*3 if len(linestyle)==1 else linestyle
|
3766
|
-
edgecolor=edgecolor*3 if len(edgecolor)==1 else edgecolor
|
4026
|
+
edgecolor = [edgecolor]
|
4027
|
+
linewidth = linewidth * 3 if len(linewidth) == 1 else linewidth
|
4028
|
+
linestyle = linestyle * 3 if len(linestyle) == 1 else linestyle
|
4029
|
+
edgecolor = edgecolor * 3 if len(edgecolor) == 1 else edgecolor
|
3767
4030
|
|
3768
4031
|
# edgecolor=[to_rgba(i) for i in edgecolor]
|
3769
4032
|
|
3770
|
-
for i in range(3):
|
4033
|
+
for i in range(3):
|
3771
4034
|
venn_circles[i].set_lw(linewidth[i])
|
3772
|
-
venn_circles[i].set_ls(linestyle[i])
|
3773
|
-
venn_circles[i].set_edgecolor(edgecolor[i])
|
4035
|
+
venn_circles[i].set_ls(linestyle[i])
|
4036
|
+
venn_circles[i].set_edgecolor(edgecolor[i])
|
3774
4037
|
|
3775
|
-
|
4038
|
+
# 椭圆形
|
3776
4039
|
if ellipse_shape:
|
3777
|
-
import matplotlib.patches as patches
|
4040
|
+
import matplotlib.patches as patches
|
4041
|
+
|
3778
4042
|
for patch in v.patches:
|
3779
4043
|
patch.set_visible(False) # Hide original patches if using ellipses
|
3780
4044
|
center1 = v.get_circle_center(0)
|
3781
4045
|
center2 = v.get_circle_center(1)
|
3782
|
-
center3 = v.get_circle_center(2)
|
4046
|
+
center3 = v.get_circle_center(2)
|
3783
4047
|
ellipse1 = patches.Ellipse(
|
3784
4048
|
(center1.x, center1.y),
|
3785
4049
|
width=ellipse_scale[0],
|
3786
4050
|
height=ellipse_scale[1],
|
3787
4051
|
edgecolor=edgecolor[0] if edgecolor else colors[0],
|
3788
4052
|
facecolor=colors[0],
|
3789
|
-
lw=
|
4053
|
+
lw=(
|
4054
|
+
linewidth if isinstance(linewidth, (int, float)) else 1.0
|
4055
|
+
), # Ensure lw is a number
|
3790
4056
|
ls=linestyle[0],
|
3791
|
-
alpha=
|
4057
|
+
alpha=(
|
4058
|
+
alpha if isinstance(alpha, (int, float)) else 0.5
|
4059
|
+
), # Ensure alpha is a number
|
3792
4060
|
)
|
3793
4061
|
ellipse2 = patches.Ellipse(
|
3794
4062
|
(center2.x, center2.y),
|
@@ -3796,9 +4064,13 @@ def venn(
|
|
3796
4064
|
height=ellipse_scale[1],
|
3797
4065
|
edgecolor=edgecolor[1] if edgecolor else colors[1],
|
3798
4066
|
facecolor=colors[1],
|
3799
|
-
lw=
|
4067
|
+
lw=(
|
4068
|
+
linewidth if isinstance(linewidth, (int, float)) else 1.0
|
4069
|
+
), # Ensure lw is a number
|
3800
4070
|
ls=linestyle[0],
|
3801
|
-
alpha=
|
4071
|
+
alpha=(
|
4072
|
+
alpha if isinstance(alpha, (int, float)) else 0.5
|
4073
|
+
), # Ensure alpha is a number
|
3802
4074
|
)
|
3803
4075
|
ellipse3 = patches.Ellipse(
|
3804
4076
|
(center3.x, center3.y),
|
@@ -3806,9 +4078,13 @@ def venn(
|
|
3806
4078
|
height=ellipse_scale[1],
|
3807
4079
|
edgecolor=edgecolor[1] if edgecolor else colors[1],
|
3808
4080
|
facecolor=colors[1],
|
3809
|
-
lw=
|
4081
|
+
lw=(
|
4082
|
+
linewidth if isinstance(linewidth, (int, float)) else 1.0
|
4083
|
+
), # Ensure lw is a number
|
3810
4084
|
ls=linestyle[0],
|
3811
|
-
alpha=
|
4085
|
+
alpha=(
|
4086
|
+
alpha if isinstance(alpha, (int, float)) else 0.5
|
4087
|
+
), # Ensure alpha is a number
|
3812
4088
|
)
|
3813
4089
|
ax.add_patch(ellipse1)
|
3814
4090
|
ax.add_patch(ellipse2)
|
@@ -3820,17 +4096,20 @@ def venn(
|
|
3820
4096
|
for patch in v.patches:
|
3821
4097
|
if patch:
|
3822
4098
|
patch.set_alpha(alpha)
|
3823
|
-
if
|
4099
|
+
if "none" in edgecolor or 0 in linewidth:
|
3824
4100
|
patch.set_edgecolor("none")
|
3825
4101
|
return ax
|
3826
4102
|
|
4103
|
+
|
3827
4104
|
#! subplots, support automatic extend new axis
|
3828
|
-
def subplot(
|
3829
|
-
|
3830
|
-
|
3831
|
-
|
3832
|
-
|
3833
|
-
|
4105
|
+
def subplot(
|
4106
|
+
rows: int = 2,
|
4107
|
+
cols: int = 2,
|
4108
|
+
figsize: Union[tuple, list] = [8, 8],
|
4109
|
+
sharex=False,
|
4110
|
+
sharey=False,
|
4111
|
+
**kwargs,
|
4112
|
+
):
|
3834
4113
|
"""
|
3835
4114
|
nexttile = subplot(
|
3836
4115
|
8,
|
@@ -3850,11 +4129,14 @@ def subplot(rows:int=2,
|
|
3850
4129
|
ax.set_xlabel(f"Tile {i + 1}")
|
3851
4130
|
"""
|
3852
4131
|
from matplotlib.gridspec import GridSpec
|
4132
|
+
|
3853
4133
|
if run_once_within():
|
3854
|
-
print(
|
4134
|
+
print(
|
4135
|
+
f"usage:\n\tnexttile = subplot(2, 2, figsize=(5, 5), sharex=True, sharey=True)\n\tax = nexttile()"
|
4136
|
+
)
|
3855
4137
|
fig = plt.figure(figsize=figsize)
|
3856
4138
|
grid_spec = GridSpec(rows, cols, figure=fig)
|
3857
|
-
occupied = set()
|
4139
|
+
occupied = set()
|
3858
4140
|
row_first_axes = [None] * rows # Track the first axis in each row (for sharey)
|
3859
4141
|
col_first_axes = [None] * cols # Track the first axis in each column (for sharex)
|
3860
4142
|
|
@@ -3862,8 +4144,9 @@ def subplot(rows:int=2,
|
|
3862
4144
|
nonlocal rows, grid_spec
|
3863
4145
|
rows += 1 # Expands by adding a row
|
3864
4146
|
grid_spec = GridSpec(rows, cols, figure=fig)
|
4147
|
+
|
3865
4148
|
def nexttile(rowspan=1, colspan=1, **kwargs):
|
3866
|
-
nonlocal rows, cols, occupied, grid_spec
|
4149
|
+
nonlocal rows, cols, occupied, grid_spec
|
3867
4150
|
for row in range(rows):
|
3868
4151
|
for col in range(cols):
|
3869
4152
|
if all(
|
@@ -3873,29 +4156,29 @@ def subplot(rows:int=2,
|
|
3873
4156
|
):
|
3874
4157
|
break
|
3875
4158
|
else:
|
3876
|
-
continue
|
4159
|
+
continue
|
3877
4160
|
break
|
3878
|
-
else:
|
4161
|
+
else:
|
3879
4162
|
expand_ax()
|
3880
4163
|
return nexttile(rowspan=rowspan, colspan=colspan, **kwargs)
|
3881
4164
|
|
3882
|
-
sharex_ax,sharey_ax = None, None
|
4165
|
+
sharex_ax, sharey_ax = None, None
|
3883
4166
|
|
3884
|
-
if sharex:
|
4167
|
+
if sharex:
|
3885
4168
|
sharex_ax = col_first_axes[col]
|
3886
4169
|
|
3887
|
-
if sharey:
|
3888
|
-
sharey_ax = row_first_axes[row]
|
4170
|
+
if sharey:
|
4171
|
+
sharey_ax = row_first_axes[row]
|
3889
4172
|
ax = fig.add_subplot(
|
3890
4173
|
grid_spec[row : row + rowspan, col : col + colspan],
|
3891
4174
|
sharex=sharex_ax,
|
3892
4175
|
sharey=sharey_ax,
|
3893
|
-
**kwargs
|
3894
|
-
)
|
4176
|
+
**kwargs,
|
4177
|
+
)
|
3895
4178
|
if row_first_axes[row] is None:
|
3896
4179
|
row_first_axes[row] = ax
|
3897
4180
|
if col_first_axes[col] is None:
|
3898
|
-
col_first_axes[col] = ax
|
4181
|
+
col_first_axes[col] = ax
|
3899
4182
|
for r in range(row, row + rowspan):
|
3900
4183
|
for c in range(col, col + colspan):
|
3901
4184
|
occupied.add((r, c))
|
@@ -3907,17 +4190,17 @@ def subplot(rows:int=2,
|
|
3907
4190
|
|
3908
4191
|
#! radar chart
|
3909
4192
|
def radar(
|
3910
|
-
data: pd.DataFrame,
|
3911
|
-
ylim=(0,100),
|
4193
|
+
data: pd.DataFrame,
|
4194
|
+
ylim=(0, 100),
|
3912
4195
|
color=get_color(5),
|
3913
4196
|
fontsize=10,
|
3914
|
-
fontcolor=
|
4197
|
+
fontcolor="k",
|
3915
4198
|
size=6,
|
3916
4199
|
linewidth=1,
|
3917
4200
|
linestyle="-",
|
3918
4201
|
alpha=0.5,
|
3919
4202
|
marker="o",
|
3920
|
-
edgecolor=
|
4203
|
+
edgecolor="none",
|
3921
4204
|
edge_linewidth=0,
|
3922
4205
|
bg_color="0.8",
|
3923
4206
|
bg_alpha=None,
|
@@ -3928,18 +4211,19 @@ def radar(
|
|
3928
4211
|
legend_fontsize=10,
|
3929
4212
|
grid_color="gray",
|
3930
4213
|
grid_alpha=0.5,
|
3931
|
-
grid_linestyle="--",
|
4214
|
+
grid_linestyle="--",
|
4215
|
+
grid_linewidth=0.5,
|
3932
4216
|
circular: bool = False,
|
3933
4217
|
tick_fontsize=None,
|
3934
4218
|
tick_fontcolor="0.65",
|
3935
|
-
tick_loc
|
3936
|
-
turning
|
4219
|
+
tick_loc=None, # label position
|
4220
|
+
turning=None,
|
3937
4221
|
ax=None,
|
3938
4222
|
sp=2,
|
3939
|
-
**kwargs
|
4223
|
+
**kwargs,
|
3940
4224
|
):
|
3941
4225
|
"""
|
3942
|
-
Example DATA:
|
4226
|
+
Example DATA:
|
3943
4227
|
df = pd.DataFrame(
|
3944
4228
|
data=[
|
3945
4229
|
[80, 80, 80, 80, 80, 80, 80],
|
@@ -3948,7 +4232,7 @@ def radar(
|
|
3948
4232
|
],
|
3949
4233
|
index=["Hero", "Warrior", "Wizard"],
|
3950
4234
|
columns=["HP", "MP", "ATK", "DEF", "SP.ATK", "SP.DEF", "SPD"])
|
3951
|
-
|
4235
|
+
|
3952
4236
|
Parameters:
|
3953
4237
|
- data (pd.DataFrame): The data to plot. Each column corresponds to a variable, and each row represents a data point.
|
3954
4238
|
- ylim (tuple): The limits of the radial axis (y-axis). Default is (0, 100).
|
@@ -4002,8 +4286,12 @@ def radar(
|
|
4002
4286
|
|
4003
4287
|
# bg_color
|
4004
4288
|
if bg_alpha is None:
|
4005
|
-
bg_alpha=alpha
|
4006
|
-
|
4289
|
+
bg_alpha = alpha
|
4290
|
+
(
|
4291
|
+
ax.set_facecolor(to_rgba(bg_color, alpha=bg_alpha))
|
4292
|
+
if circular
|
4293
|
+
else ax.set_facecolor("none")
|
4294
|
+
)
|
4007
4295
|
# Set up the radar chart with straight-line connections
|
4008
4296
|
ax.set_theta_offset(np.pi / 2)
|
4009
4297
|
ax.set_theta_direction(-1)
|
@@ -4012,53 +4300,68 @@ def radar(
|
|
4012
4300
|
ax.set_xticks(angles[:-1])
|
4013
4301
|
ax.set_xticklabels(categories)
|
4014
4302
|
|
4015
|
-
# Set y-axis limits and grid intervals
|
4303
|
+
# Set y-axis limits and grid intervals
|
4016
4304
|
vmin, vmax = ylim
|
4017
|
-
if circular:
|
4018
|
-
|
4019
|
-
ax.yaxis.set_ticks(np.arange(vmin, vmax+1, vmax * grid_interval_ratio))
|
4020
|
-
ax.grid(
|
4021
|
-
|
4022
|
-
|
4023
|
-
|
4024
|
-
|
4025
|
-
|
4026
|
-
|
4027
|
-
|
4028
|
-
|
4305
|
+
if circular:
|
4306
|
+
# * cicular style
|
4307
|
+
ax.yaxis.set_ticks(np.arange(vmin, vmax + 1, vmax * grid_interval_ratio))
|
4308
|
+
ax.grid(
|
4309
|
+
axis="both",
|
4310
|
+
color=grid_color,
|
4311
|
+
linestyle=grid_linestyle,
|
4312
|
+
alpha=grid_alpha,
|
4313
|
+
linewidth=grid_linewidth,
|
4314
|
+
dash_capstyle="round",
|
4315
|
+
dash_joinstyle="round",
|
4316
|
+
)
|
4317
|
+
ax.spines["polar"].set_color(grid_color)
|
4029
4318
|
ax.spines["polar"].set_linewidth(grid_linewidth)
|
4030
|
-
ax.spines["polar"].set_linestyle(
|
4319
|
+
ax.spines["polar"].set_linestyle("-")
|
4031
4320
|
ax.spines["polar"].set_alpha(grid_alpha)
|
4032
|
-
ax.spines["polar"].set_capstyle(
|
4033
|
-
ax.spines["polar"].set_joinstyle(
|
4321
|
+
ax.spines["polar"].set_capstyle("round")
|
4322
|
+
ax.spines["polar"].set_joinstyle("round")
|
4034
4323
|
|
4035
4324
|
else:
|
4036
|
-
|
4325
|
+
# * spider style: spider-style grid (straight lines, not circles)
|
4037
4326
|
# Create the spider-style grid (straight lines, not circles)
|
4038
|
-
for i in range(1, int(vmax * grid_interval_ratio) + 1):
|
4327
|
+
for i in range(1, int(vmax * grid_interval_ratio) + 1):
|
4039
4328
|
ax.plot(
|
4040
4329
|
angles + [angles[0]], # Closing the loop
|
4041
|
-
[i * vmax * grid_interval_ratio] * (num_vars+1)
|
4042
|
-
|
4330
|
+
[i * vmax * grid_interval_ratio] * (num_vars + 1)
|
4331
|
+
+ [i * vmax * grid_interval_ratio],
|
4332
|
+
color=grid_color,
|
4333
|
+
linestyle=grid_linestyle,
|
4334
|
+
alpha=grid_alpha,
|
4335
|
+
linewidth=grid_linewidth,
|
4043
4336
|
)
|
4044
|
-
# set bg_color
|
4045
|
-
ax.fill(angles, [vmax]*(data.shape[1]+1), color=bg_color, alpha=bg_alpha)
|
4337
|
+
# set bg_color
|
4338
|
+
ax.fill(angles, [vmax] * (data.shape[1] + 1), color=bg_color, alpha=bg_alpha)
|
4046
4339
|
ax.yaxis.grid(False)
|
4047
4340
|
# Move radial labels away from plotted line
|
4048
4341
|
if tick_loc is None:
|
4049
|
-
tick_loc
|
4342
|
+
tick_loc = (
|
4343
|
+
np.mean([angles[0], angles[1]]) / (2 * np.pi) * 360 if circular else 0
|
4344
|
+
)
|
4050
4345
|
|
4051
4346
|
ax.set_rlabel_position(tick_loc)
|
4052
4347
|
ax.set_theta_offset(turning) if turning is not None else None
|
4053
|
-
ax.tick_params(
|
4054
|
-
|
4055
|
-
|
4348
|
+
ax.tick_params(
|
4349
|
+
axis="x", labelsize=fontsize, colors=fontcolor
|
4350
|
+
) # Optional: for angular labels
|
4351
|
+
tick_fontsize = fontsize - 2 if fontsize is None else tick_fontsize
|
4352
|
+
ax.tick_params(
|
4353
|
+
axis="y", labelsize=tick_fontsize, colors=tick_fontcolor
|
4354
|
+
) # For radial labels
|
4056
4355
|
if not circular:
|
4057
|
-
ax.spines[
|
4058
|
-
ax.tick_params(axis=
|
4059
|
-
ax.tick_params(axis=
|
4356
|
+
ax.spines["polar"].set_visible(False)
|
4357
|
+
ax.tick_params(axis="x", pad=sp) # move spines outward
|
4358
|
+
ax.tick_params(axis="y", pad=sp) # move spines outward
|
4060
4359
|
# colors
|
4061
|
-
colors =
|
4360
|
+
colors = (
|
4361
|
+
get_color(data.shape[0])
|
4362
|
+
if cmap is None
|
4363
|
+
else plt.get_cmap(cmap)(np.linspace(0, 1, data.shape[0]))
|
4364
|
+
)
|
4062
4365
|
# Plot each row with straight lines
|
4063
4366
|
for i, (index, row) in enumerate(data.iterrows()):
|
4064
4367
|
values = row.tolist()
|
@@ -4070,7 +4373,7 @@ def radar(
|
|
4070
4373
|
linewidth=linewidth,
|
4071
4374
|
linestyle=linestyle,
|
4072
4375
|
label=index,
|
4073
|
-
clip_on=False
|
4376
|
+
clip_on=False,
|
4074
4377
|
)
|
4075
4378
|
ax.fill(angles, values, color=colors[i], alpha=alpha)
|
4076
4379
|
|
@@ -4084,13 +4387,23 @@ def radar(
|
|
4084
4387
|
marker=marker,
|
4085
4388
|
markersize=size,
|
4086
4389
|
markeredgecolor=edgecolor,
|
4087
|
-
markeredgewidth
|
4390
|
+
markeredgewidth=edge_linewidth,
|
4391
|
+
zorder=10,
|
4392
|
+
clip_on=False,
|
4088
4393
|
)
|
4089
4394
|
# ax.tick_params(axis='y', labelleft=False, left=False)
|
4090
|
-
if
|
4395
|
+
if "legend" in kws_figsets:
|
4091
4396
|
figsets(ax=ax, **kws_figsets)
|
4092
4397
|
else:
|
4093
|
-
|
4094
|
-
figsets(
|
4095
|
-
|
4096
|
-
|
4398
|
+
|
4399
|
+
figsets(
|
4400
|
+
ax=ax,
|
4401
|
+
legend=dict(
|
4402
|
+
loc=legend_loc,
|
4403
|
+
fontsize=legend_fontsize,
|
4404
|
+
bbox_to_anchor=[1.1, 1.4],
|
4405
|
+
ncols=2,
|
4406
|
+
),
|
4407
|
+
**kws_figsets,
|
4408
|
+
)
|
4409
|
+
return ax
|