py2ls 0.2.4.24__py3-none-any.whl → 0.2.4.26__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- py2ls/.DS_Store +0 -0
- py2ls/.git/index +0 -0
- py2ls/corr.py +475 -0
- py2ls/data/.DS_Store +0 -0
- py2ls/data/hyper_param_autogluon_zeroshot2024.json +2383 -0
- py2ls/data/styles/.DS_Store +0 -0
- py2ls/data/styles/example/.DS_Store +0 -0
- py2ls/data/usages_sns.json +6 -1
- py2ls/ec2ls.py +61 -0
- py2ls/ips.py +496 -138
- py2ls/ml2ls.py +994 -288
- py2ls/netfinder.py +16 -20
- py2ls/nl2ls.py +283 -0
- py2ls/plot.py +1244 -158
- {py2ls-0.2.4.24.dist-info → py2ls-0.2.4.26.dist-info}/METADATA +5 -1
- {py2ls-0.2.4.24.dist-info → py2ls-0.2.4.26.dist-info}/RECORD +17 -14
- py2ls/data/usages_pd copy.json +0 -1105
- py2ls/ml2ls copy.py +0 -2906
- {py2ls-0.2.4.24.dist-info → py2ls-0.2.4.26.dist-info}/WHEEL +0 -0
py2ls/plot.py
CHANGED
@@ -20,7 +20,9 @@ from .ips import (
|
|
20
20
|
flatten,
|
21
21
|
plt_font,
|
22
22
|
run_once_within,
|
23
|
-
df_format
|
23
|
+
df_format,
|
24
|
+
df_corr,
|
25
|
+
df_scaler
|
24
26
|
)
|
25
27
|
import scipy.stats as scipy_stats
|
26
28
|
from .stats import *
|
@@ -142,20 +144,37 @@ def add_text(ax=None, height_offset=0.5, fmt=".1f", **kwargs):
|
|
142
144
|
**kwargs,
|
143
145
|
)
|
144
146
|
|
147
|
+
def pval2str(p):
|
148
|
+
if p > 0.05:
|
149
|
+
txt = ""
|
150
|
+
elif 0.01 <= p <= 0.05:
|
151
|
+
txt = "*"
|
152
|
+
elif 0.001 <= p < 0.01:
|
153
|
+
txt = "**"
|
154
|
+
elif p < 0.001:
|
155
|
+
txt = "***"
|
156
|
+
return txt
|
145
157
|
|
146
158
|
def heatmap(
|
147
159
|
data,
|
148
160
|
ax=None,
|
149
161
|
kind="corr", #'corr','direct','pivot'
|
162
|
+
method="pearson",# for correlation: ‘pearson’(default), ‘kendall’, ‘spearman’
|
150
163
|
columns="all", # pivot, default: coll numeric columns
|
164
|
+
style=0,# for correlation
|
151
165
|
index=None, # pivot
|
152
166
|
values=None, # pivot
|
167
|
+
fontsize=10,
|
153
168
|
tri="u",
|
154
169
|
mask=True,
|
155
170
|
k=1,
|
171
|
+
vmin=None,
|
172
|
+
vmax=None,
|
173
|
+
size_scale=500,
|
156
174
|
annot=True,
|
157
175
|
cmap="coolwarm",
|
158
176
|
fmt=".2f",
|
177
|
+
show_indicator = True,# only for style==1
|
159
178
|
cluster=False,
|
160
179
|
inplace=False,
|
161
180
|
figsize=(10, 8),
|
@@ -201,7 +220,7 @@ def heatmap(
|
|
201
220
|
ax = plt.gca()
|
202
221
|
# Select numeric columns or specific subset of columns
|
203
222
|
if columns == "all":
|
204
|
-
df_numeric = data.select_dtypes(include=[
|
223
|
+
df_numeric = data.select_dtypes(include=[np.number])
|
205
224
|
else:
|
206
225
|
df_numeric = data[columns]
|
207
226
|
|
@@ -209,8 +228,10 @@ def heatmap(
|
|
209
228
|
kind = strcmp(kind, kinds)[0]
|
210
229
|
print(kind)
|
211
230
|
if "corr" in kind: # correlation
|
231
|
+
methods = ["pearson", "spearman", "kendall"]
|
232
|
+
method = strcmp(method, methods)[0]
|
212
233
|
# Compute the correlation matrix
|
213
|
-
data4heatmap = df_numeric.corr()
|
234
|
+
data4heatmap = df_numeric.corr(method=method)
|
214
235
|
# Generate mask for the upper triangle if mask is True
|
215
236
|
if mask:
|
216
237
|
if "u" in tri.lower(): # upper => np.tril
|
@@ -288,18 +309,76 @@ def heatmap(
|
|
288
309
|
df_col_cluster,
|
289
310
|
)
|
290
311
|
else:
|
291
|
-
|
292
|
-
|
293
|
-
|
294
|
-
|
295
|
-
|
296
|
-
|
297
|
-
|
298
|
-
|
299
|
-
|
300
|
-
|
301
|
-
|
302
|
-
|
312
|
+
if style==0:
|
313
|
+
# Create a standard heatmap
|
314
|
+
ax = sns.heatmap(
|
315
|
+
data4heatmap,
|
316
|
+
ax=ax,
|
317
|
+
mask=mask_array,
|
318
|
+
annot=annot,
|
319
|
+
cmap=cmap,
|
320
|
+
fmt=fmt,
|
321
|
+
**kwargs, # Pass any additional arguments to sns.heatmap
|
322
|
+
)
|
323
|
+
return ax
|
324
|
+
elif style==1:
|
325
|
+
if isinstance(cmap, str):
|
326
|
+
cmap = plt.get_cmap(cmap)
|
327
|
+
norm = plt.Normalize(vmin=-1, vmax=1)
|
328
|
+
r_, p_ = df_corr(data4heatmap, method=method)
|
329
|
+
# size_r_norm=df_scaler(data=r_, method="minmax", vmin=-1,vmax=1)
|
330
|
+
# 初始化一个空的可绘制对象用于颜色条
|
331
|
+
scatter_handles = []
|
332
|
+
# 循环绘制气泡图和数值
|
333
|
+
for i in range(len(r_.columns)):
|
334
|
+
for j in range(len(r_.columns)):
|
335
|
+
if (i < j) if "u" in tri.lower() else (j<i): # 对角线左上部只显示气泡
|
336
|
+
color = cmap(norm(r_.iloc[i, j])) # 根据相关系数获取颜色
|
337
|
+
scatter = ax.scatter(
|
338
|
+
i, j, s=np.abs(r_.iloc[i, j])*size_scale, color=color,
|
339
|
+
# alpha=1,edgecolor=edgecolor,linewidth=linewidth,
|
340
|
+
**kwargs
|
341
|
+
)
|
342
|
+
scatter_handles.append(scatter) # 保存scatter对象用于颜色条
|
343
|
+
# add *** indicators
|
344
|
+
if show_indicator:
|
345
|
+
ax.text(
|
346
|
+
i,
|
347
|
+
j,
|
348
|
+
pval2str(p_.iloc[i, j]),
|
349
|
+
ha="center",
|
350
|
+
va="center",
|
351
|
+
color="k",
|
352
|
+
fontsize=fontsize * 1.3,
|
353
|
+
)
|
354
|
+
elif (i > j) if "u" in tri.lower() else (j>i): # 对角只显示数值
|
355
|
+
color = cmap(norm(r_.iloc[i, j])) # 数值的颜色同样基于相关系数
|
356
|
+
ax.text(
|
357
|
+
i,
|
358
|
+
j,
|
359
|
+
f"{r_.iloc[i, j]:{fmt}}",
|
360
|
+
ha="center",
|
361
|
+
va="center",
|
362
|
+
color=color,
|
363
|
+
fontsize=fontsize,
|
364
|
+
)
|
365
|
+
else: # 对角线部分,显示空白
|
366
|
+
ax.scatter(i, j, s=1, color="white")
|
367
|
+
# 设置坐标轴标签
|
368
|
+
figsets(xticks=range(len(r_.columns)),
|
369
|
+
xticklabels=r_.columns,
|
370
|
+
xangle=90,
|
371
|
+
fontsize=fontsize,
|
372
|
+
yticks=range(len(r_.columns)),
|
373
|
+
yticklabels=r_.columns,
|
374
|
+
xlim=[-0.5,len(r_.columns)-0.5],
|
375
|
+
ylim=[-0.5,len(r_.columns)-0.5]
|
376
|
+
)
|
377
|
+
# 添加颜色条
|
378
|
+
sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
|
379
|
+
sm.set_array([]) # 仅用于显示颜色条
|
380
|
+
plt.colorbar(sm, ax=ax, label="Correlation Coefficient")
|
381
|
+
return ax
|
303
382
|
elif "dir" in kind: # direct
|
304
383
|
data4heatmap = df_numeric
|
305
384
|
elif "pi" in kind: # pivot
|
@@ -389,16 +468,53 @@ def heatmap(
|
|
389
468
|
)
|
390
469
|
else:
|
391
470
|
# Create a standard heatmap
|
392
|
-
|
393
|
-
|
394
|
-
|
395
|
-
|
396
|
-
|
397
|
-
|
398
|
-
|
399
|
-
|
400
|
-
|
401
|
-
|
471
|
+
if style==0:
|
472
|
+
ax = sns.heatmap(
|
473
|
+
data4heatmap,
|
474
|
+
ax=ax,
|
475
|
+
annot=annot,
|
476
|
+
cmap=cmap,
|
477
|
+
fmt=fmt,
|
478
|
+
**kwargs, # Pass any additional arguments to sns.heatmap
|
479
|
+
)
|
480
|
+
# Return the Axes object for further customization if needed
|
481
|
+
return ax
|
482
|
+
elif style==1:
|
483
|
+
if isinstance(cmap, str):
|
484
|
+
cmap = plt.get_cmap(cmap)
|
485
|
+
if vmin is None:
|
486
|
+
vmin=np.min(data4heatmap)
|
487
|
+
if vmax is None:
|
488
|
+
vmax=np.max(data4heatmap)
|
489
|
+
norm = plt.Normalize(vmin=vmin, vmax=vmax)
|
490
|
+
|
491
|
+
# 初始化一个空的可绘制对象用于颜色条
|
492
|
+
scatter_handles = []
|
493
|
+
# 循环绘制气泡图和数值
|
494
|
+
print(len(data4heatmap.index),len(data4heatmap.columns))
|
495
|
+
for i in range(len(data4heatmap.index)):
|
496
|
+
for j in range(len(data4heatmap.columns)):
|
497
|
+
color = cmap(norm(data4heatmap.iloc[i, j])) # 根据相关系数获取颜色
|
498
|
+
scatter = ax.scatter(j,i, s=np.abs(data4heatmap.iloc[i, j])*size_scale, color=color, **kwargs)
|
499
|
+
scatter_handles.append(scatter) # 保存scatter对象用于颜色条
|
500
|
+
|
501
|
+
# 设置坐标轴标签
|
502
|
+
figsets(xticks=range(len(data4heatmap.columns)),
|
503
|
+
xticklabels=data4heatmap.columns,
|
504
|
+
xangle=90,
|
505
|
+
fontsize=fontsize,
|
506
|
+
yticks=range(len(data4heatmap.index)),
|
507
|
+
yticklabels=data4heatmap.index,
|
508
|
+
xlim=[-0.5,len(data4heatmap.columns)-0.5],
|
509
|
+
ylim=[-0.5,len(data4heatmap.index)-0.5]
|
510
|
+
)
|
511
|
+
# 添加颜色条
|
512
|
+
sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
|
513
|
+
sm.set_array([]) # 仅用于显示颜色条
|
514
|
+
plt.colorbar(sm, ax=ax,
|
515
|
+
# label="Correlation Coefficient"
|
516
|
+
)
|
517
|
+
return ax
|
402
518
|
|
403
519
|
|
404
520
|
# !usage: py2ls.plot.heatmap()
|
@@ -2475,7 +2591,8 @@ def get_color(
|
|
2475
2591
|
"#B25E9D",
|
2476
2592
|
"#4B8C3B",
|
2477
2593
|
"#EF8632",
|
2478
|
-
"#24578E"
|
2594
|
+
"#24578E",
|
2595
|
+
"#FF2C00",
|
2479
2596
|
]
|
2480
2597
|
elif n == 8:
|
2481
2598
|
# colorlist = ['#1f77b4','#ff7f0e','#367B7F','#51B34F','#d62728','#aa40fc','#e377c2','#17becf']
|
@@ -3199,9 +3316,12 @@ def plotxy(
|
|
3199
3316
|
zorder = 0
|
3200
3317
|
for k in kind_:
|
3201
3318
|
# preprocess data
|
3202
|
-
|
3203
|
-
|
3204
|
-
|
3319
|
+
try:
|
3320
|
+
data=df_preprocessing_(data, kind=k)
|
3321
|
+
if 'variable' in data.columns and 'value' in data.columns:
|
3322
|
+
x,y='variable','value'
|
3323
|
+
except Exception as e:
|
3324
|
+
print(e)
|
3205
3325
|
zorder += 1
|
3206
3326
|
# indicate 'col' features
|
3207
3327
|
col = kwargs.get("col", None)
|
@@ -3222,18 +3342,37 @@ def plotxy(
|
|
3222
3342
|
# (1) return FcetGrid
|
3223
3343
|
if k == "jointplot":
|
3224
3344
|
kws_joint = kwargs.pop("kws_joint", kwargs)
|
3225
|
-
|
3345
|
+
kws_joint = {
|
3346
|
+
k: v for k, v in kws_joint.items() if not k.startswith("kws_")
|
3347
|
+
}
|
3348
|
+
hue = kwargs.get("hue", None)
|
3349
|
+
if isinstance(kws_joint, dict) or hue is None: # Check if kws_ellipse is a dictionary
|
3350
|
+
kws_joint.pop("hue", None) # Safely remove 'hue' if it exists
|
3351
|
+
|
3352
|
+
palette = kwargs.get("palette", None)
|
3353
|
+
if palette is None:
|
3354
|
+
palette = kws_joint.pop(
|
3355
|
+
"palette", get_color(data[hue].nunique()) if hue is not None else None
|
3356
|
+
)
|
3357
|
+
else:
|
3358
|
+
kws_joint.pop("palette", palette)
|
3359
|
+
stats=kwargs.pop("stats",None)
|
3360
|
+
if stats:
|
3361
|
+
stats=kws_joint.pop("stats",True)
|
3226
3362
|
if stats:
|
3227
3363
|
r, p_value = scipy_stats.pearsonr(data[x], data[y])
|
3228
|
-
|
3229
|
-
|
3230
|
-
|
3231
|
-
|
3232
|
-
|
3233
|
-
|
3234
|
-
|
3235
|
-
|
3236
|
-
|
3364
|
+
for key in ["palette", "alpha", "hue","stats"]:
|
3365
|
+
kws_joint.pop(key, None)
|
3366
|
+
g = sns.jointplot(data=data, x=x, y=y,hue=hue,palette=palette, **kws_joint)
|
3367
|
+
if stats:
|
3368
|
+
g.ax_joint.annotate(
|
3369
|
+
f"pearsonr = {r:.2f} p = {p_value:.3f}",
|
3370
|
+
xy=(0.6, 0.98),
|
3371
|
+
xycoords="axes fraction",
|
3372
|
+
fontsize=12,
|
3373
|
+
color="black",
|
3374
|
+
ha="center",
|
3375
|
+
)
|
3237
3376
|
elif k == "lmplot":
|
3238
3377
|
kws_lm = kwargs.pop("kws_lm", kwargs)
|
3239
3378
|
stats = kwargs.pop("stats", True) # Flag to calculate stats
|
@@ -3379,8 +3518,7 @@ def plotxy(
|
|
3379
3518
|
xycoords="axes fraction",
|
3380
3519
|
fontsize=12,
|
3381
3520
|
color="black",
|
3382
|
-
ha="center"
|
3383
|
-
)
|
3521
|
+
ha="center")
|
3384
3522
|
|
3385
3523
|
elif k == "catplot_sns":
|
3386
3524
|
kws_cat = kwargs.pop("kws_cat", kwargs)
|
@@ -3388,7 +3526,7 @@ def plotxy(
|
|
3388
3526
|
elif k == "displot":
|
3389
3527
|
kws_dis = kwargs.pop("kws_dis", kwargs)
|
3390
3528
|
# displot creates a new figure and returns a FacetGrid
|
3391
|
-
g = sns.displot(data=data, x=x,
|
3529
|
+
g = sns.displot(data=data, x=x,y=y, **kws_dis)
|
3392
3530
|
|
3393
3531
|
# (2) return axis
|
3394
3532
|
if ax is None:
|
@@ -3399,19 +3537,58 @@ def plotxy(
|
|
3399
3537
|
elif k == "stdshade":
|
3400
3538
|
kws_stdshade = kwargs.pop("kws_stdshade", kwargs)
|
3401
3539
|
ax = stdshade(ax=ax, **kwargs)
|
3540
|
+
elif k=="ellipse":
|
3541
|
+
kws_ellipse = kwargs.pop("kws_ellipse", kwargs)
|
3542
|
+
kws_ellipse = {
|
3543
|
+
k: v for k, v in kws_ellipse.items() if not k.startswith("kws_")
|
3544
|
+
}
|
3545
|
+
hue = kwargs.get("hue", None)
|
3546
|
+
if isinstance(kws_ellipse, dict) or hue is None: # Check if kws_ellipse is a dictionary
|
3547
|
+
kws_ellipse.pop("hue", None) # Safely remove 'hue' if it exists
|
3548
|
+
|
3549
|
+
palette = kwargs.get("palette", None)
|
3550
|
+
if palette is None:
|
3551
|
+
palette = kws_ellipse.pop(
|
3552
|
+
"palette", get_color(data[hue].nunique()) if hue is not None else None
|
3553
|
+
)
|
3554
|
+
alpha = kws_ellipse.pop("alpha", 0.1)
|
3555
|
+
hue_order = kwargs.get("hue_order",None)
|
3556
|
+
if hue_order is None:
|
3557
|
+
hue_order = kws_ellipse.get("hue_order",None)
|
3558
|
+
if hue_order:
|
3559
|
+
data["hue"] = pd.Categorical(data[hue], categories=hue_order, ordered=True)
|
3560
|
+
data = data.sort_values(by="hue")
|
3561
|
+
for key in ["palette", "alpha", "hue","hue_order"]:
|
3562
|
+
kws_ellipse.pop(key, None)
|
3563
|
+
ax=ellipse(
|
3564
|
+
ax=ax,
|
3565
|
+
data=data,
|
3566
|
+
x=x,
|
3567
|
+
y=y,
|
3568
|
+
hue=hue,
|
3569
|
+
palette=palette,
|
3570
|
+
alpha=alpha,
|
3571
|
+
zorder=zorder,
|
3572
|
+
**kws_ellipse,
|
3573
|
+
)
|
3402
3574
|
elif k == "scatterplot":
|
3403
3575
|
kws_scatter = kwargs.pop("kws_scatter", kwargs)
|
3404
3576
|
kws_scatter = {
|
3405
3577
|
k: v for k, v in kws_scatter.items() if not k.startswith("kws_")
|
3406
3578
|
}
|
3407
|
-
hue = kwargs.
|
3579
|
+
hue = kwargs.get("hue", None)
|
3408
3580
|
if isinstance(kws_scatter, dict): # Check if kws_scatter is a dictionary
|
3409
3581
|
kws_scatter.pop("hue", None) # Safely remove 'hue' if it exists
|
3410
|
-
palette
|
3411
|
-
|
3412
|
-
|
3582
|
+
palette=kws_scatter.get("palette",None)
|
3583
|
+
if palette is None:
|
3584
|
+
palette = kws_scatter.pop(
|
3585
|
+
"palette", get_color(data[hue].nunique()) if hue is not None else None
|
3586
|
+
)
|
3413
3587
|
s = kws_scatter.pop("s", 10)
|
3414
|
-
alpha = kws_scatter.pop("alpha", 0.7)
|
3588
|
+
alpha = kws_scatter.pop("alpha", 0.7)
|
3589
|
+
for key in ["s", "palette", "alpha", "hue"]:
|
3590
|
+
kws_scatter.pop(key, None)
|
3591
|
+
|
3415
3592
|
ax = sns.scatterplot(
|
3416
3593
|
ax=ax,
|
3417
3594
|
data=data,
|
@@ -3427,19 +3604,33 @@ def plotxy(
|
|
3427
3604
|
elif k == "histplot":
|
3428
3605
|
kws_hist = kwargs.pop("kws_hist", kwargs)
|
3429
3606
|
kws_hist = {k: v for k, v in kws_hist.items() if not k.startswith("kws_")}
|
3430
|
-
ax = sns.histplot(data=data, x=x, ax=ax, zorder=zorder, **kws_hist)
|
3431
|
-
elif k == "kdeplot":
|
3607
|
+
ax = sns.histplot(data=data, x=x, y=y, ax=ax, zorder=zorder, **kws_hist)
|
3608
|
+
elif k == "kdeplot":
|
3432
3609
|
kws_kde = kwargs.pop("kws_kde", kwargs)
|
3433
|
-
kws_kde = {
|
3434
|
-
|
3610
|
+
kws_kde = {
|
3611
|
+
k: v for k, v in kws_kde.items() if not k.startswith("kws_")
|
3612
|
+
}
|
3613
|
+
hue = kwargs.get("hue", None)
|
3614
|
+
if isinstance(kws_kde, dict) or hue is None: # Check if kws_kde is a dictionary
|
3615
|
+
kws_kde.pop("hue", None) # Safely remove 'hue' if it exists
|
3616
|
+
|
3617
|
+
palette = kwargs.get("palette", None)
|
3618
|
+
if palette is None:
|
3619
|
+
palette = kws_kde.pop(
|
3620
|
+
"palette", get_color(data[hue].nunique()) if hue is not None else None
|
3621
|
+
)
|
3622
|
+
alpha = kws_kde.pop("alpha", 0.05)
|
3623
|
+
for key in ["palette", "alpha", "hue"]:
|
3624
|
+
kws_kde.pop(key, None)
|
3625
|
+
ax = sns.kdeplot(data=data, x=x, y=y, palette=palette,hue=hue, ax=ax,alpha=alpha, zorder=zorder, **kws_kde)
|
3435
3626
|
elif k == "ecdfplot":
|
3436
3627
|
kws_ecdf = kwargs.pop("kws_ecdf", kwargs)
|
3437
3628
|
kws_ecdf = {k: v for k, v in kws_ecdf.items() if not k.startswith("kws_")}
|
3438
|
-
ax = sns.ecdfplot(data=data, x=x,
|
3629
|
+
ax = sns.ecdfplot(data=data, x=x,y=y, ax=ax, zorder=zorder, **kws_ecdf)
|
3439
3630
|
elif k == "rugplot":
|
3440
3631
|
kws_rug = kwargs.pop("kws_rug", kwargs)
|
3441
3632
|
kws_rug = {k: v for k, v in kws_rug.items() if not k.startswith("kws_")}
|
3442
|
-
ax = sns.rugplot(data=data, x=x, ax=ax, zorder=zorder, **kws_rug)
|
3633
|
+
ax = sns.rugplot(data=data, x=x, y=y, ax=ax, zorder=zorder, **kws_rug)
|
3443
3634
|
elif k == "stripplot":
|
3444
3635
|
kws_strip = kwargs.pop("kws_strip", kwargs)
|
3445
3636
|
kws_strip = {k: v for k, v in kws_strip.items() if not k.startswith("kws_")}
|
@@ -3512,7 +3703,8 @@ def plotxy(
|
|
3512
3703
|
figsets(ax=ax, **kws_figsets)
|
3513
3704
|
if kws_add_text:
|
3514
3705
|
add_text(ax=ax, **kws_add_text) if kws_add_text else None
|
3515
|
-
|
3706
|
+
if run_once_within(10):
|
3707
|
+
for k in kind_:
|
3516
3708
|
print(f"\n{k}⤵ ")
|
3517
3709
|
print(default_settings[k])
|
3518
3710
|
# print("=>\t",sns_info[sns_info["Functions"].str.contains(k)].iloc[:, -1].tolist()[0],"\n")
|
@@ -3570,6 +3762,7 @@ def df_preprocessing_(data, kind, verbose=False):
|
|
3570
3762
|
"lineplot", # Can work with both wide and long formats
|
3571
3763
|
"area plot", # Can work with both formats, useful for stacked areas
|
3572
3764
|
"violinplot", # Can work with both formats depending on categorical vs continuous data
|
3765
|
+
"ellipse",# ellipse plot, default confidence=0.95
|
3573
3766
|
],
|
3574
3767
|
)[0]
|
3575
3768
|
|
@@ -3605,6 +3798,7 @@ def df_preprocessing_(data, kind, verbose=False):
|
|
3605
3798
|
"violinplot", # Can work with both formats depending on categorical vs continuous data
|
3606
3799
|
"relplot",
|
3607
3800
|
"pointplot", # Works well with wide format
|
3801
|
+
"ellipse",
|
3608
3802
|
]
|
3609
3803
|
|
3610
3804
|
# Wide format (e.g., for heatmap and pairplot)
|
@@ -4096,25 +4290,45 @@ def venn(
|
|
4096
4290
|
"""
|
4097
4291
|
if ax is None:
|
4098
4292
|
ax = plt.gca()
|
4293
|
+
if isinstance(lists, dict):
|
4294
|
+
labels = list(lists.keys())
|
4295
|
+
lists = list(lists.values())
|
4296
|
+
if isinstance(lists[0], set):
|
4297
|
+
lists = [list(i) for i in lists]
|
4099
4298
|
lists = [set(flatten(i, verbose=False)) for i in lists]
|
4100
4299
|
# Function to apply text styles to labels
|
4101
4300
|
if colors is None:
|
4102
4301
|
colors = ["r", "b"] if len(lists) == 2 else ["r", "g", "b"]
|
4103
4302
|
if labels is None:
|
4104
|
-
|
4303
|
+
if len(lists) == 2:
|
4304
|
+
labels = ["set1", "set2"]
|
4305
|
+
elif len(lists) == 3:
|
4306
|
+
labels = ["set1", "set2", "set3"]
|
4307
|
+
elif len(lists) == 4:
|
4308
|
+
labels = ["set1", "set2", "set3","set4"]
|
4309
|
+
elif len(lists) == 5:
|
4310
|
+
labels = ["set1", "set2", "set3","set4","set55"]
|
4311
|
+
elif len(lists) == 6:
|
4312
|
+
labels = ["set1", "set2", "set3","set4","set5","set6"]
|
4313
|
+
elif len(lists) == 7:
|
4314
|
+
labels = ["set1", "set2", "set3","set4","set5","set6","set7"]
|
4105
4315
|
if edgecolor is None:
|
4106
4316
|
edgecolor = colors
|
4107
4317
|
colors = [desaturate_color(color, saturation) for color in colors]
|
4108
|
-
|
4109
|
-
if len(lists) == 2:
|
4318
|
+
universe = len(set.union(*lists))
|
4110
4319
|
|
4111
|
-
|
4112
|
-
|
4113
|
-
|
4114
|
-
|
4115
|
-
|
4116
|
-
|
4117
|
-
|
4320
|
+
# Check colors and auto-calculate overlaps
|
4321
|
+
def get_count_and_percentage(set_count, subset_count):
|
4322
|
+
percent = subset_count / set_count if set_count > 0 else 0
|
4323
|
+
return (
|
4324
|
+
f"{subset_count}\n({fmt.format(percent)})"
|
4325
|
+
if show_percentages
|
4326
|
+
else f"{subset_count}"
|
4327
|
+
)
|
4328
|
+
if fmt is not None:
|
4329
|
+
if not fmt.startswith("{"):
|
4330
|
+
fmt="{:" + fmt + "}"
|
4331
|
+
if len(lists) == 2:
|
4118
4332
|
|
4119
4333
|
from matplotlib_venn import venn2, venn2_circles
|
4120
4334
|
|
@@ -4127,21 +4341,28 @@ def venn(
|
|
4127
4341
|
set1, set2 = lists[0], lists[1]
|
4128
4342
|
v.get_patch_by_id("10").set_color(colors[0])
|
4129
4343
|
v.get_patch_by_id("01").set_color(colors[1])
|
4130
|
-
|
4131
|
-
|
4132
|
-
|
4344
|
+
try:
|
4345
|
+
v.get_patch_by_id("11").set_color(
|
4346
|
+
get_color_overlap(colors[0], colors[1]) if colors else None
|
4347
|
+
)
|
4348
|
+
except Exception as e:
|
4349
|
+
print(e)
|
4133
4350
|
# v.get_label_by_id('10').set_text(len(set1 - set2))
|
4134
4351
|
# v.get_label_by_id('01').set_text(len(set2 - set1))
|
4135
4352
|
# v.get_label_by_id('11').set_text(len(set1 & set2))
|
4353
|
+
|
4136
4354
|
v.get_label_by_id("10").set_text(
|
4137
|
-
get_count_and_percentage(
|
4355
|
+
get_count_and_percentage(universe, len(set1 - set2))
|
4138
4356
|
)
|
4139
4357
|
v.get_label_by_id("01").set_text(
|
4140
|
-
get_count_and_percentage(
|
4141
|
-
)
|
4142
|
-
v.get_label_by_id("11").set_text(
|
4143
|
-
get_count_and_percentage(len(set1 | set2), len(set1 & set2))
|
4358
|
+
get_count_and_percentage(universe, len(set2 - set1))
|
4144
4359
|
)
|
4360
|
+
try:
|
4361
|
+
v.get_label_by_id("11").set_text(
|
4362
|
+
get_count_and_percentage(universe, len(set1 & set2))
|
4363
|
+
)
|
4364
|
+
except Exception as e:
|
4365
|
+
print(e)
|
4145
4366
|
|
4146
4367
|
if not isinstance(linewidth, list):
|
4147
4368
|
linewidth = [linewidth]
|
@@ -4224,16 +4445,14 @@ def venn(
|
|
4224
4445
|
va=va,
|
4225
4446
|
shadow=shadow,
|
4226
4447
|
)
|
4227
|
-
|
4228
|
-
|
4229
|
-
|
4230
|
-
|
4231
|
-
|
4232
|
-
|
4233
|
-
|
4234
|
-
|
4235
|
-
else f"{subset_count}"
|
4236
|
-
)
|
4448
|
+
# Set transparency level
|
4449
|
+
for patch in v.patches:
|
4450
|
+
if patch:
|
4451
|
+
patch.set_alpha(alpha)
|
4452
|
+
if "none" in edgecolor or 0 in linewidth:
|
4453
|
+
patch.set_edgecolor("none")
|
4454
|
+
return ax
|
4455
|
+
elif len(lists) == 3:
|
4237
4456
|
|
4238
4457
|
from matplotlib_venn import venn3, venn3_circles
|
4239
4458
|
|
@@ -4249,36 +4468,34 @@ def venn(
|
|
4249
4468
|
# Draw the venn diagram
|
4250
4469
|
v = venn3(subsets=lists, set_labels=labels, ax=ax, **kwargs)
|
4251
4470
|
v.get_patch_by_id("100").set_color(colors[0])
|
4471
|
+
v.get_label_by_id("100").set_text(get_count_and_percentage(universe, len(set1 - set2 - set3)))
|
4252
4472
|
v.get_patch_by_id("010").set_color(colors[1])
|
4253
|
-
v.
|
4254
|
-
|
4255
|
-
|
4256
|
-
|
4257
|
-
|
4258
|
-
|
4259
|
-
|
4260
|
-
|
4261
|
-
|
4262
|
-
|
4263
|
-
|
4264
|
-
|
4265
|
-
|
4266
|
-
|
4267
|
-
|
4268
|
-
|
4269
|
-
|
4270
|
-
|
4271
|
-
|
4272
|
-
|
4273
|
-
|
4274
|
-
|
4275
|
-
|
4276
|
-
|
4277
|
-
|
4278
|
-
|
4279
|
-
v.get_label_by_id("111").set_text(
|
4280
|
-
get_label(len(set1 | set2 | set3), len(set1 & set2 & set3))
|
4281
|
-
)
|
4473
|
+
v.get_label_by_id("010").set_text(get_count_and_percentage(universe, len(set2 - set1 - set3)))
|
4474
|
+
try:
|
4475
|
+
v.get_patch_by_id("001").set_color(colors[2])
|
4476
|
+
v.get_label_by_id("001").set_text(get_count_and_percentage(universe, len(set3 - set1 - set2)))
|
4477
|
+
except Exception as e:
|
4478
|
+
print(e)
|
4479
|
+
try:
|
4480
|
+
v.get_patch_by_id("110").set_color(colorAB)
|
4481
|
+
v.get_label_by_id("110").set_text(get_count_and_percentage(universe, len(set1 & set2 - set3)))
|
4482
|
+
except Exception as e:
|
4483
|
+
print(e)
|
4484
|
+
try:
|
4485
|
+
v.get_patch_by_id("101").set_color(colorAC)
|
4486
|
+
v.get_label_by_id("101").set_text(get_count_and_percentage(universe, len(set1 & set3 - set2)))
|
4487
|
+
except Exception as e:
|
4488
|
+
print(e)
|
4489
|
+
try:
|
4490
|
+
v.get_patch_by_id("011").set_color(colorBC)
|
4491
|
+
v.get_label_by_id("011").set_text(get_count_and_percentage(universe, len(set2 & set3 - set1)))
|
4492
|
+
except Exception as e:
|
4493
|
+
print(e)
|
4494
|
+
try:
|
4495
|
+
v.get_patch_by_id("111").set_color(colorABC)
|
4496
|
+
v.get_label_by_id("111").set_text(get_count_and_percentage(universe, len(set1 & set2 & set3)))
|
4497
|
+
except Exception as e:
|
4498
|
+
print(e)
|
4282
4499
|
|
4283
4500
|
# Apply styles to set labels
|
4284
4501
|
for i, text in enumerate(v.set_labels):
|
@@ -4383,16 +4600,34 @@ def venn(
|
|
4383
4600
|
ax.add_patch(ellipse1)
|
4384
4601
|
ax.add_patch(ellipse2)
|
4385
4602
|
ax.add_patch(ellipse3)
|
4603
|
+
# Set transparency level
|
4604
|
+
for patch in v.patches:
|
4605
|
+
if patch:
|
4606
|
+
patch.set_alpha(alpha)
|
4607
|
+
if "none" in edgecolor or 0 in linewidth:
|
4608
|
+
patch.set_edgecolor("none")
|
4609
|
+
return ax
|
4610
|
+
|
4611
|
+
|
4612
|
+
dict_data = {}
|
4613
|
+
for i_list, list_ in enumerate(lists):
|
4614
|
+
dict_data[labels[i_list]]={*list_}
|
4615
|
+
|
4616
|
+
if 3<len(lists)<6:
|
4617
|
+
from venn import venn as vn
|
4618
|
+
|
4619
|
+
legend_loc=kwargs.pop("legend_loc", "upper right")
|
4620
|
+
ax=vn(dict_data,ax=ax,legend_loc=legend_loc,**kwargs)
|
4621
|
+
|
4622
|
+
return ax
|
4386
4623
|
else:
|
4387
|
-
|
4388
|
-
|
4389
|
-
|
4390
|
-
|
4391
|
-
|
4392
|
-
|
4393
|
-
|
4394
|
-
patch.set_edgecolor("none")
|
4395
|
-
return ax
|
4624
|
+
from venn import pseudovenn
|
4625
|
+
cmap=kwargs.pop("cmap","plasma")
|
4626
|
+
ax=pseudovenn(dict_data, cmap=cmap,ax=ax,**kwargs)
|
4627
|
+
|
4628
|
+
return ax
|
4629
|
+
|
4630
|
+
|
4396
4631
|
|
4397
4632
|
|
4398
4633
|
#! subplots, support automatic extend new axis
|
@@ -4403,6 +4638,7 @@ def subplot(
|
|
4403
4638
|
sharex=False,
|
4404
4639
|
sharey=False,
|
4405
4640
|
verbose=False,
|
4641
|
+
fig=None,
|
4406
4642
|
**kwargs,
|
4407
4643
|
):
|
4408
4644
|
"""
|
@@ -4431,8 +4667,8 @@ def subplot(
|
|
4431
4667
|
)
|
4432
4668
|
|
4433
4669
|
figsize_recommend = f"subplot({rows}, {cols}, figsize={figsize})"
|
4434
|
-
|
4435
|
-
|
4670
|
+
if fig is None:
|
4671
|
+
fig = plt.figure(figsize=figsize, constrained_layout=True)
|
4436
4672
|
grid_spec = GridSpec(rows, cols, figure=fig)
|
4437
4673
|
occupied = set()
|
4438
4674
|
row_first_axes = [None] * rows # Track the first axis in each row (for sharey)
|
@@ -4496,18 +4732,19 @@ def subplot(
|
|
4496
4732
|
#! radar chart
|
4497
4733
|
def radar(
|
4498
4734
|
data: pd.DataFrame,
|
4499
|
-
|
4735
|
+
columns=None,
|
4500
4736
|
ylim=(0, 100),
|
4501
|
-
|
4737
|
+
facecolor=None,
|
4738
|
+
edgecolor="none",
|
4739
|
+
edge_linewidth=0.5,
|
4502
4740
|
fontsize=10,
|
4503
4741
|
fontcolor="k",
|
4504
4742
|
size=6,
|
4505
4743
|
linewidth=1,
|
4506
4744
|
linestyle="-",
|
4507
|
-
alpha=0.
|
4745
|
+
alpha=0.3,
|
4746
|
+
fmt=".1f",
|
4508
4747
|
marker="o",
|
4509
|
-
edgecolor="none",
|
4510
|
-
edge_linewidth=0.5,
|
4511
4748
|
bg_color="0.8",
|
4512
4749
|
bg_alpha=None,
|
4513
4750
|
grid_interval_ratio=0.2,
|
@@ -4527,19 +4764,34 @@ def radar(
|
|
4527
4764
|
ax=None,
|
4528
4765
|
sp=2,
|
4529
4766
|
verbose=True,
|
4767
|
+
axis=0,
|
4530
4768
|
**kwargs,
|
4531
4769
|
):
|
4532
4770
|
"""
|
4533
4771
|
Example DATA:
|
4534
4772
|
df = pd.DataFrame(
|
4535
|
-
|
4536
|
-
|
4537
|
-
|
4538
|
-
|
4539
|
-
|
4540
|
-
|
4541
|
-
|
4542
|
-
|
4773
|
+
data=[
|
4774
|
+
[80, 90, 60],
|
4775
|
+
[80, 20, 90],
|
4776
|
+
[80, 95, 20],
|
4777
|
+
[80, 95, 20],
|
4778
|
+
[80, 30, 100],
|
4779
|
+
[80, 30, 90],
|
4780
|
+
[80, 80, 50],
|
4781
|
+
],
|
4782
|
+
index=["HP", "MP", "ATK", "DEF", "SP.ATK", "SP.DEF", "SPD"],
|
4783
|
+
columns=["Hero", "Warrior", "Wizard"],
|
4784
|
+
)
|
4785
|
+
usage 1:
|
4786
|
+
radar(data=df)
|
4787
|
+
usage 2:
|
4788
|
+
radar(data=df["Wizard"])
|
4789
|
+
usage 3:
|
4790
|
+
radar(data=df, columns="Wizard")
|
4791
|
+
usage 4:
|
4792
|
+
nexttile = subplot(1, 2)
|
4793
|
+
radar(data=df, columns="Wizard", ax=nexttile(projection="polar"))
|
4794
|
+
pie(data=df, columns="Wizard", ax=nexttile(), width=0.5, pctdistance=0.7)
|
4543
4795
|
Parameters:
|
4544
4796
|
- data (pd.DataFrame): The data to plot. Each column corresponds to a variable, and each row represents a data point.
|
4545
4797
|
- ylim (tuple): The limits of the radial axis (y-axis). Default is (0, 100).
|
@@ -4556,7 +4808,6 @@ def radar(
|
|
4556
4808
|
- edge_linewidth (int): Line width for the marker edges.
|
4557
4809
|
- bg_color (str): Background color for the radar chart.
|
4558
4810
|
- grid_interval_ratio (float): Determines the intervals for the grid lines as a fraction of the y-limit.
|
4559
|
-
- title (str): The title of the radar chart.
|
4560
4811
|
- cmap (str): The colormap to use if `color` is a list.
|
4561
4812
|
- legend_loc (str): The location of the legend.
|
4562
4813
|
- legend_fontsize (int): Font size for the legend.
|
@@ -4573,22 +4824,22 @@ def radar(
|
|
4573
4824
|
- sp (int): Padding for the ticks from the plot area.
|
4574
4825
|
- **kwargs: Additional arguments for customization.
|
4575
4826
|
"""
|
4576
|
-
if run_once_within() and verbose:
|
4827
|
+
if run_once_within(20,reverse=True) and verbose:
|
4577
4828
|
usage_="""usage:
|
4578
4829
|
radar(
|
4579
4830
|
data: pd.DataFrame, #The data to plot. Each column corresponds to a variable, and each row represents a data point.
|
4580
|
-
title="Radar Chart",
|
4581
4831
|
ylim=(0, 100),# ylim (tuple): The limits of the radial axis (y-axis). Default is (0, 100).
|
4582
|
-
|
4832
|
+
facecolor=get_color(5),#The color(s) for the plot. Can be a single color or a list of colors.
|
4833
|
+
edgecolor="none",#for the marker edges.
|
4834
|
+
edge_linewidth=0.5,#for the marker edges.
|
4583
4835
|
fontsize=10,# Font size for the angular labels (x-axis).
|
4584
4836
|
fontcolor="k",# Color for the angular labels.
|
4585
4837
|
size=6,#The size of the markers for each data point.
|
4586
4838
|
linewidth=1,
|
4587
4839
|
linestyle="-",
|
4588
4840
|
alpha=0.5,#for the filled area.
|
4841
|
+
fmt=".1f",
|
4589
4842
|
marker="o",# for the data points.
|
4590
|
-
edgecolor="none",#for the marker edges.
|
4591
|
-
edge_linewidth=0.5,#for the marker edges.
|
4592
4843
|
bg_color="0.8",
|
4593
4844
|
bg_alpha=None,
|
4594
4845
|
grid_interval_ratio=0.2,#Determines the intervals for the grid lines as a fraction of the y-limit.
|
@@ -4618,9 +4869,25 @@ def radar(
|
|
4618
4869
|
kws_figsets = v_arg
|
4619
4870
|
kwargs.pop(k_arg, None)
|
4620
4871
|
break
|
4621
|
-
|
4872
|
+
if axis==1:
|
4873
|
+
data=data.T
|
4874
|
+
if isinstance(data, dict):
|
4875
|
+
data = pd.DataFrame(pd.Series(data))
|
4876
|
+
if ~isinstance(data, pd.DataFrame):
|
4877
|
+
data=pd.DataFrame(data)
|
4878
|
+
if isinstance(data, pd.DataFrame):
|
4879
|
+
data=data.select_dtypes(include=np.number)
|
4880
|
+
if isinstance(columns,str):
|
4881
|
+
columns=[columns]
|
4882
|
+
if columns is None:
|
4883
|
+
columns = list(data.columns)
|
4884
|
+
data=data[columns]
|
4885
|
+
categories = list(data.index)
|
4622
4886
|
num_vars = len(categories)
|
4623
4887
|
|
4888
|
+
# Set y-axis limits and grid intervals
|
4889
|
+
vmin, vmax = ylim
|
4890
|
+
|
4624
4891
|
# Set up angle for each category on radar chart
|
4625
4892
|
angles = np.linspace(0, 2 * np.pi, num_vars, endpoint=False).tolist()
|
4626
4893
|
angles += angles[:1] # Complete the loop to ensure straight-line connections
|
@@ -4644,9 +4911,6 @@ def radar(
|
|
4644
4911
|
# Draw one axis per variable and add labels
|
4645
4912
|
ax.set_xticks(angles[:-1])
|
4646
4913
|
ax.set_xticklabels(categories)
|
4647
|
-
|
4648
|
-
# Set y-axis limits and grid intervals
|
4649
|
-
vmin, vmax = ylim
|
4650
4914
|
if circular:
|
4651
4915
|
# * cicular style
|
4652
4916
|
ax.yaxis.set_ticks(np.arange(vmin, vmax + 1, vmax * grid_interval_ratio))
|
@@ -4669,7 +4933,7 @@ def radar(
|
|
4669
4933
|
else:
|
4670
4934
|
# * spider style: spider-style grid (straight lines, not circles)
|
4671
4935
|
# Create the spider-style grid (straight lines, not circles)
|
4672
|
-
for i in range(1, int(vmax * grid_interval_ratio) + 1):
|
4936
|
+
for i in range(1, int((vmax-vmin)/ ((vmax-vmin)*grid_interval_ratio))+1):#int(vmax * grid_interval_ratio) + 1):
|
4673
4937
|
ax.plot(
|
4674
4938
|
angles + [angles[0]], # Closing the loop
|
4675
4939
|
[i * vmax * grid_interval_ratio] * (num_vars + 1)
|
@@ -4680,7 +4944,7 @@ def radar(
|
|
4680
4944
|
linewidth=grid_linewidth,
|
4681
4945
|
)
|
4682
4946
|
# set bg_color
|
4683
|
-
ax.fill(angles, [vmax] * (data.shape[
|
4947
|
+
ax.fill(angles, [vmax] * (data.shape[0] + 1), color=bg_color, alpha=bg_alpha)
|
4684
4948
|
ax.yaxis.grid(False)
|
4685
4949
|
# Move radial labels away from plotted line
|
4686
4950
|
if tick_loc is None:
|
@@ -4702,14 +4966,20 @@ def radar(
|
|
4702
4966
|
ax.tick_params(axis="x", pad=sp) # move spines outward
|
4703
4967
|
ax.tick_params(axis="y", pad=sp) # move spines outward
|
4704
4968
|
# colors
|
4705
|
-
|
4706
|
-
|
4707
|
-
|
4708
|
-
|
4709
|
-
|
4969
|
+
if facecolor is not None:
|
4970
|
+
if not isinstance(facecolor,list):
|
4971
|
+
facecolor=[facecolor]
|
4972
|
+
colors = facecolor
|
4973
|
+
else:
|
4974
|
+
colors = (
|
4975
|
+
get_color(data.shape[1])
|
4976
|
+
if cmap is None
|
4977
|
+
else plt.get_cmap(cmap)(np.linspace(0, 1, data.shape[1]))
|
4978
|
+
)
|
4979
|
+
|
4710
4980
|
# Plot each row with straight lines
|
4711
|
-
for i, (
|
4712
|
-
values =
|
4981
|
+
for i, (col, val) in enumerate(data.items()):
|
4982
|
+
values = val.tolist()
|
4713
4983
|
values += values[:1] # Close the loop
|
4714
4984
|
ax.plot(
|
4715
4985
|
angles,
|
@@ -4717,7 +4987,7 @@ def radar(
|
|
4717
4987
|
color=colors[i],
|
4718
4988
|
linewidth=linewidth,
|
4719
4989
|
linestyle=linestyle,
|
4720
|
-
label=
|
4990
|
+
label=col,
|
4721
4991
|
clip_on=False,
|
4722
4992
|
)
|
4723
4993
|
ax.fill(angles, values, color=colors[i], alpha=alpha)
|
@@ -4748,7 +5018,7 @@ def radar(
|
|
4748
5018
|
ax.text(
|
4749
5019
|
angle,
|
4750
5020
|
offset_radius,
|
4751
|
-
|
5021
|
+
f"{value:{fmt}}",
|
4752
5022
|
ha="center",
|
4753
5023
|
va="center",
|
4754
5024
|
fontsize=fontsize,
|
@@ -4759,10 +5029,10 @@ def radar(
|
|
4759
5029
|
|
4760
5030
|
ax.set_ylim(ylim)
|
4761
5031
|
# Add markers for each data point
|
4762
|
-
for i,
|
5032
|
+
for i, (col, val) in enumerate(data.items()):
|
4763
5033
|
ax.plot(
|
4764
5034
|
angles,
|
4765
|
-
list(
|
5035
|
+
list(val) + [val[0]], # Close the loop for markers
|
4766
5036
|
color=colors[i],
|
4767
5037
|
marker=marker,
|
4768
5038
|
markersize=size,
|
@@ -4787,3 +5057,819 @@ def radar(
|
|
4787
5057
|
**kws_figsets,
|
4788
5058
|
)
|
4789
5059
|
return ax
|
5060
|
+
|
5061
|
+
|
5062
|
+
def pie(
    data: pd.Series,
    columns: list = None,
    facecolor=None,
    explode=[0.1],
    startangle=90,
    shadow=True,
    fontcolor="k",
    fmt=".2f",
    width=None,  # the center blank
    pctdistance=0.85,
    labeldistance=1.1,
    kws_wedge={},
    kws_text={},
    kws_arrow={},
    center=(0, 0),
    radius=1,
    frame=False,
    fontsize=10,
    edgecolor="white",
    edgewidth=1,
    cmap=None,
    show_value=False,
    show_label=True,  # False: only show the outer layer, if it is None, not show
    expand_label=(1.2, 1.2),
    kws_bbox={},  # dict(facecolor="none", alpha=0.5, edgecolor="black", boxstyle="round,pad=0.3"), # '{}' to hide
    show_legend=True,
    legend_loc="upper right",
    bbox_to_anchor=[1.4, 1.1],
    legend_fontsize=10,
    rotation_correction=0,
    verbose=True,
    ax=None,
    **kwargs
):
    """Draw a pie chart, or one concentric donut ring per selected column.

    Each selected numeric column of ``data`` becomes one ring and the rows
    (index) are the wedges of every ring. The first column is the outermost
    ring; every following ring is placed just inside the previous one.

    Parameters
    ----------
    data : dict, pandas.Series, pandas.DataFrame or array-like
        Values to plot. A dict becomes a one-column DataFrame; any other
        non-DataFrame input is passed to ``pd.DataFrame``. Only numeric
        columns are kept.
    columns : str or list of str, optional
        Column(s) to draw, outermost first. Defaults to all numeric columns.
    facecolor : color or list, optional
        Wedge colours. A list of lists gives one colour list per ring.
        When omitted, ``get_color`` (or ``cmap`` if given) supplies one
        colour per wedge.
    explode : float, list or None
        Radial offset per wedge, padded with 0 / truncated to the number of
        wedges. The argument is copied and never modified in place.
    fmt : str
        Number format: a format spec such as ``".2f"``, a ``"{:.2f}"``
        template, or a printf string such as ``"%1.1f%%"``. A falsy value
        disables the percentage labels.
    width : float or list of float, optional
        Ring width(s). Defaults to 1 for a single ring, otherwise
        ``1 / (n_rings + 2)`` for every ring.
    radius : float or list of float
        Outer radius of the outermost ring (inner rings move inward by the
        cumulative widths), or an explicit radius per ring.
    show_value : bool or None
        True also writes the raw values on the wedges; None hides the
        percentage labels entirely.
    show_label : bool or None
        True labels every ring, False labels only the outermost ring, None
        labels nothing.
    kws_wedge, kws_text, kws_arrow, kws_bbox : dict
        Extra properties for the wedges, the texts, and the ``adjust_text``
        arrows / boxes. They are copied, never modified.
    verbose : bool
        Print the usage text (also gated by ``run_once_within(20, reverse=True)``).
    ax : matplotlib.axes.Axes, optional
        Target axes; defaults to ``plt.gca()``.
    **kwargs
        Forwarded to ``Axes.pie``.

    Returns
    -------
    matplotlib.axes.Axes
    """
    from adjustText import adjust_text

    if run_once_within(20, reverse=True) and verbose:
        usage_ = """usage:
    pie(
        data:pd.Series,
        columns:list = None,
        facecolor=None,
        explode=[0.1],
        startangle=90,
        shadow=True,
        fontcolor="k",
        fmt=".2f",
        width=None,# the center blank
        pctdistance=0.85,
        labeldistance=1.1,
        kws_wedge={},
        kws_text={},
        center=(0, 0),
        radius=1,
        frame=False,
        fontsize=10,
        edgecolor="white",
        edgewidth=1,
        cmap=None,
        show_value=False,
        show_label=True,# False: only show the outer layer, if it is None, not show
        show_legend=True,
        legend_loc="upper right",
        bbox_to_anchor=[1.4, 1.1],
        legend_fontsize=10,
        rotation_correction=0,
        verbose=True,
        ax=None,
        **kwargs
    )

    usage 1:
    data = {"Segment A": 30, "Segment B": 50, "Segment C": 20}

    ax = pie(
        data=data,
        # columns="Segment A",
        explode=[0, 0.2, 0],
        # width=0.4,
        show_label=False,
        fontsize=10,
        # show_value=1,
        fmt=".3f",
    )

    # prepare dataset
    df = pd.DataFrame(
        data=[
                [80, 90, 60],
                [80, 20, 90],
                [80, 95, 20],
                [80, 95, 20],
                [80, 30, 100],
                [80, 30, 90],
                [80, 80, 50],
            ],
            index=["HP", "MP", "ATK", "DEF", "SP.ATK", "SP.DEF", "SPD"],
            columns=["Hero", "Warrior", "Wizard"],
        )
    usage 1: only plot one column
    pie(
        df,
        columns="Wizard",
        width=0.6,
        show_label=False,
        fmt=".0f",
    )
    usage 2:
    pie(df,columns=["Hero", "Warrior"],show_label=False)
    usage 3: set different width
    pie(df,
        columns=["Hero", "Warrior", "Wizard"],
        width=[0.3, 0.2, 0.2],
        show_label=False,
        fmt=".0f",
        )
    usage 4: set width the same for all columns
    pie(df,
        columns=["Hero", "Warrior", "Wizard"],
        width=0.2,
        show_label=False,
        fmt=".0f",
        )
    usage 5: adjust the labels' offset
    pie(df, columns="Wizard", width=0.6, show_label=False, fmt=".6f", labeldistance=1.2)

    usage 6:
    nexttile = subplot(1, 2)
    radar(data=df, columns="Wizard", ax=nexttile(projection="polar"))
    pie(data=df, columns="Wizard", ax=nexttile(), width=0.5, pctdistance=0.7)
    """
        print(usage_)

    # --- Normalise the input: numeric DataFrame, columns = rings, rows = wedges.
    if isinstance(data, dict):
        data = pd.DataFrame(pd.Series(data))
    if not isinstance(data, pd.DataFrame):  # `~isinstance(...)` was always truthy
        data = pd.DataFrame(data)
    data = data.select_dtypes(include=np.number)
    if isinstance(columns, str):
        columns = [columns]
    if columns is None:
        columns = list(data.columns)
    df = data[columns]
    n_wedges, n_rings = df.shape

    # --- explode: work on a private copy so the shared default list (and the
    # caller's list) is never extended in place; pad/truncate to the wedge count.
    if explode is None:
        explode = [0]
    elif isinstance(explode, (list, tuple, np.ndarray)):
        explode = list(explode)
    else:
        explode = [explode]
    if explode == [None]:
        explode = [0]
    explode = (explode + [0] * n_wedges)[:n_wedges]

    # --- Ring widths: one per ring, repeated when too few were given.
    if width is None:
        width = 1 / (n_rings + 2) if n_rings > 1 else 1
    if isinstance(width, (float, int)):
        width = [width]
    width = list(width)
    if len(width) < n_rings:
        width = width * n_rings

    # --- Ring radii: each ring starts where the outer rings end.
    def _stacked_radii(outer):
        return [outer - float(np.sum(width[:i])) for i in range(n_rings)]

    if isinstance(radius, (float, int)):
        radii = _stacked_radii(radius)
    else:
        radii = list(radius)
        if len(radii) < n_rings:
            radii = _stacked_radii(radii[0] if radii else 1)

    # --- Colours: explicit facecolor, else get_color / cmap (one per wedge).
    if facecolor is not None:
        colors = facecolor if isinstance(facecolor, list) else [facecolor]
    elif cmap is None:
        colors = get_color(n_wedges)
    else:
        colors = plt.get_cmap(cmap)(np.linspace(0, 1, n_wedges))
    # A list of lists means "one colour list per ring".
    is_nested = any(isinstance(c, list) for c in colors)

    # --- Wedge/text properties: private copies (the defaults are shared dicts).
    wedge_props = (
        dict(kws_wedge) if kws_wedge else {"edgecolor": edgecolor, "linewidth": edgewidth}
    )
    text_props = dict(kws_text)
    fontcolor = text_props.get("color", fontcolor)
    fontsize = text_props.get("fontsize", fontsize)
    text_props.update({"color": fontcolor, "fontsize": fontsize})

    # --- Percentage labels (matplotlib `autopct` accepts a printf string or a callable).
    if not fmt:
        autopct = None
    elif fmt.startswith("%"):
        autopct = fmt
    elif fmt.startswith("{"):
        autopct = lambda pct: fmt.format(pct) + "%"
    else:
        autopct = f"%{fmt}%%"
    if show_value is None:
        autopct = None

    def _format_value(value):
        """Format a raw wedge value with `fmt` (spec, '{}' template or printf)."""
        if not fmt:
            return str(value)
        if fmt.startswith("{"):
            return fmt.format(value)
        if fmt.startswith("%"):
            return fmt % value
        return f"{value:{fmt}}"

    if ax is None:
        ax = plt.gca()

    wedges = []
    labels_legend = df.index
    for i_ring, (column_, width_, radius_) in enumerate(zip(columns, width, radii)):
        # The outermost ring is labelled unless show_label is None; inner
        # rings only when show_label is truthy.
        if i_ring == 0:
            labels = df.index if show_label is not None else None
        else:
            labels = df.index if show_label else None
        sizes = df[column_].values
        wedge_props["width"] = width_

        result = ax.pie(
            sizes,
            labels=labels,
            autopct=autopct,
            startangle=startangle + rotation_correction,
            explode=explode,
            colors=colors[i_ring] if is_nested else colors,
            shadow=shadow,
            pctdistance=pctdistance,
            labeldistance=labeldistance,
            wedgeprops=wedge_props,
            textprops=text_props,
            center=center,
            radius=radius_,
            frame=frame,
            **kwargs,
        )
        if len(result) == 3:
            wedges, texts, autotexts = result
        else:
            wedges, texts = result
            autotexts = None

        # Let adjustText de-overlap only the texts the caller asked to show.
        all_texts = []
        if autotexts and show_value:
            all_texts.extend(autotexts)
        if texts and show_label:
            all_texts.extend(texts)
        if all_texts:
            adjust_text(
                all_texts,
                ax=ax,
                arrowprops=kws_arrow,  # e.g. dict(arrowstyle="-", color="gray", lw=0.5)
                bbox=kws_bbox if kws_bbox else None,
                expand=expand_label,
                fontdict={
                    "fontsize": fontsize,
                    "color": fontcolor,
                },
            )

        # Write the raw values at pctdistance along each wedge's bisector.
        if show_value:
            for i, wedge in enumerate(wedges):
                theta = np.radians((wedge.theta2 - wedge.theta1) / 2 + wedge.theta1)
                ax.text(
                    center[0] + np.cos(theta) * pctdistance * radius_,
                    center[1] + np.sin(theta) * pctdistance * radius_,
                    _format_value(sizes[i]),
                    ha="center",
                    va="center",
                    fontsize=fontsize,
                    color=fontcolor,
                )

    # The legend uses the wedges of the last (innermost) ring.
    if show_legend and wedges:
        ax.legend(
            wedges,
            labels_legend,
            loc=legend_loc,
            bbox_to_anchor=bbox_to_anchor,
            fontsize=legend_fontsize,
            title_fontsize=legend_fontsize,
        )
    ax.set(aspect="equal")
    return ax
|
5358
|
+
|
5359
|
+
def ellipse(
    data,
    x=None,
    y=None,
    hue=None,
    n_std=1.5,
    ax=None,
    confidence=0.95,
    annotate_center=False,
    palette=None,
    facecolor=None,
    edgecolor=None,
    label: bool = True,
    **kwargs,
):
    """
    Plot covariance (confidence) ellipses for the (x, y) points of each hue group.

    # simulate data:
    control = np.random.multivariate_normal([0, 0], [[1, 0.5], [0.5, 1]], size=50)
    patient = np.random.multivariate_normal([2, 1], [[1, -0.3], [-0.3, 1]], size=50)
    df = pd.DataFrame(
        {
            "Dim1": np.concatenate([control[:, 0], patient[:, 0]]),
            "Dim2": np.concatenate([control[:, 1], patient[:, 1]]),
            "Group": ["Control"] * 50 + ["Patient"] * 50,
        }
    )
    plotxy(
        data=df,
        x="Dim1",
        y="Dim2",
        hue="Group",
        kind_="scatter",
        palette=get_color(8),
    )
    ellipse(
        data=df,
        x="Dim1",
        y="Dim2",
        hue="Group",
        palette=get_color(8),
        alpha=0.1,
        lw=2,
    )

    Parameters:
        data (DataFrame): Input DataFrame with columns for x, y, and hue.
        x (str): Column name for x-axis values.
        y (str): Column name for y-axis values.
        hue (str, optional): Column name for group labels; one ellipse per group.
        n_std (float): Number of standard deviations for the half-axes. Only
            used when ``confidence`` is falsy.
        ax (matplotlib.axes.Axes, optional): Axes to draw on. Defaults to ``plt.gca()``.
        confidence (float, optional): Confidence level (e.g. 0.95). The
            half-axes are scaled by ``sqrt(chi2.ppf(confidence, df=2))``.
        annotate_center (bool): Whether to annotate each ellipse centre (mean).
        palette (str, list or dict, optional): seaborn palette name or list of
            colours, or a dict mapping hue values to colours (hue values missing
            from the dict fall back to the "husl" palette).
        facecolor (color, optional): Fill colour for every ellipse; defaults to
            the group colour.
        edgecolor (color, optional): Outline colour for every ellipse; defaults
            to the group colour.
        label (bool): Attach the group value as the patch label (for legends)
            when ``hue`` is given.
        **kwargs: ``alpha`` (default 0.2) sets the fill transparency only; all
            other keywords are passed to ``matplotlib.patches.Ellipse``.

    Groups with fewer than two complete (x, y) rows are skipped, because their
    covariance is undefined.

    Returns:
        matplotlib.axes.Axes: The Axes the ellipses were added to.
    """
    from matplotlib.patches import Ellipse
    import numpy as np
    import matplotlib.pyplot as plt
    import seaborn as sns
    import pandas as pd
    from scipy.stats import chi2

    # Validate inputs
    if x is None or y is None:
        raise ValueError(
            "Both `x` and `y` must be specified as column names in the DataFrame."
        )
    if not isinstance(data, pd.DataFrame):
        raise ValueError("`data` must be a pandas DataFrame.")
    if ax is None:
        ax = plt.gca()

    # alpha only tints the fill: it is folded into the (colour, alpha) facecolor.
    alpha = kwargs.pop("alpha", 0.2)
    if confidence:
        # Radius of the 2-D confidence region in standard-deviation units.
        n_std = np.sqrt(chi2.ppf(confidence, df=2))

    # Map every group to a colour.
    if hue is not None:
        groups = data[hue].unique()
        if isinstance(palette, dict):
            fallback = sns.color_palette("husl", len(groups))
            color_map = {g: palette.get(g, c) for g, c in zip(groups, fallback)}
        else:
            colors = sns.color_palette(palette or "husl", len(groups))
            color_map = dict(zip(groups, colors))
    else:
        groups = [None]
        color_map = {None: edgecolor if edgecolor is not None else "blue"}

    for group in groups:
        group_data = data if hue is None else data[data[hue] == group]
        points = group_data[[x, y]].dropna().to_numpy(dtype=float)
        if len(points) < 2:
            continue  # covariance needs at least two points

        mean = points.mean(axis=0)
        cov = np.cov(points, rowvar=False)

        # Principal axes: eigenvectors sorted by decreasing eigenvalue.
        eigvals, eigvecs = np.linalg.eigh(cov)
        order = eigvals.argsort()[::-1]
        eigvals, eigvecs = eigvals[order], eigvecs[:, order]

        # Orientation of the major axis and full axis lengths (clip guards
        # against tiny negative eigenvalues from round-off).
        angle = np.degrees(np.arctan2(eigvecs[1, 0], eigvecs[0, 0]))
        width, height = 2 * n_std * np.sqrt(np.clip(eigvals, 0, None))

        group_color = color_map[group]
        face = facecolor if facecolor is not None else group_color
        edge = edgecolor if edgecolor is not None else group_color
        patch = Ellipse(
            xy=tuple(mean),
            width=width,
            height=height,
            angle=angle,
            edgecolor=edge,
            facecolor=(face, alpha),  # alpha applies to the fill only
            label=group if (hue is not None and label) else None,
            **kwargs,
        )
        ax.add_patch(patch)

        if annotate_center:
            ax.annotate(
                f"Mean\n({mean[0]:.2f}, {mean[1]:.2f})",
                xy=tuple(mean),
                xycoords="data",
                fontsize=10,
                ha="center",
                color=edge,
                bbox=dict(
                    boxstyle="round,pad=0.3",
                    edgecolor="gray",
                    facecolor="white",
                    alpha=0.8,
                ),
            )

    return ax
|
5511
|
+
|
5512
|
+
def ppi(
    interactions,
    player1="preferredName_A",
    player2="preferredName_B",
    weight="score",
    n_layers=None,  # Number of concentric layers
    n_rank=[5, 10],  # Nodes in each rank for the concentric layout
    dist_node=10,  # Distance between each rank of circles
    layout="degree",
    size=None,  # 700,
    sizes=(50, 500),  # min and max of size
    facecolor="skyblue",
    cmap="coolwarm",
    edgecolor="k",
    edgelinewidth=1.5,
    alpha=0.5,
    alphas=(0.1, 1.0),  # min and max of alpha
    marker="o",
    node_hideticks=True,
    linecolor="gray",
    line_cmap="coolwarm",
    linewidth=1.5,
    linewidths=(0.5, 5),  # min and max of linewidth
    linealpha=1.0,
    linealphas=(0.1, 1.0),  # min and max of linealpha
    linestyle="-",
    line_arrowstyle="-",
    fontsize=10,
    fontcolor="k",
    ha: str = "center",
    va: str = "center",
    figsize=(12, 10),
    k_value=0.3,
    bgcolor="w",
    dir_save="./ppi_network.html",
    physics=True,
    notebook=False,
    scale=1,
    ax=None,
    **kwargs
):
    """
    Plot a protein-protein interaction (PPI) network with matplotlib and export it with pyvis.

    Example::

        ppi(
            interactions_sort.iloc[:1000, :],
            player1="player1",
            player2="player2",
            weight="count",
            layout="spring",
            n_layers=13,
            fontsize=1,
            n_rank=[5, 10, 20, 40, 80, 80, 80, 80, 80, 80, 80, 80],
        )

    Parameters
    ----------
    interactions : pandas.DataFrame
        One row per interaction; must contain the ``player1``, ``player2``
        and ``weight`` columns. It is not modified (a sorted copy is used).
    player1, player2, weight : str
        Column names of the two interacting nodes and of the edge weight.
    n_layers, n_rank, dist_node
        Options for ``layout="degree"``. Nodes sorted by degree fill
        concentric circles: circle k (radius ``(k + 1) * dist_node``) holds
        the next ``n_rank[k]`` nodes. All remaining nodes are shuffled onto
        one outer circle. ``n_layers`` defaults to ``len(n_rank) + 1``.
    layout : str
        Fuzzy-matched with ``ips.strcmp`` against: spring, circular,
        kamada_kawai, spectral, random, shell, planar, spiral, degree.
    size, sizes
        Node size (scalar or list). By default the size is the node degree.
        A list with differing values is rescaled into ``sizes`` = (min, max).
    facecolor, cmap
        Node colour. If ``facecolor`` is not a colour, nodes are coloured by
        degree using ``cmap``.
    alpha, alphas
        Node alpha. A list is rescaled into ``alphas``.
    linewidth, linewidths
        Edge width. A list is rescaled into ``linewidths``; a scalar is used as-is.
    linecolor, line_cmap
        Edge colour. A non-string value colours edges by weight using ``line_cmap``.
    linealpha, linealphas
        Edge alpha. A list is rescaled into ``linealphas``.
    dir_save : str
        pyvis HTML output path. A ``.graphml`` and a ``.pdf`` with the same
        stem are written next to it. A falsy value disables saving.
    scale : float
        Multiplier applied to the layout coordinates in the pyvis export.
    ax : matplotlib.axes.Axes, optional
        Created with ``figsize`` if omitted.
    **kwargs
        A key containing "figset" is passed as keyword arguments to ``figsets``.

    Returns
    -------
    (networkx.Graph, matplotlib.axes.Axes)
    """
    from pyvis.network import Network
    import networkx as nx
    from matplotlib.colors import Normalize, to_hex
    from . import ips

    if run_once_within():
        usage_str = """
    ppi(
        interactions,
        player1="preferredName_A",
        player2="preferredName_B",
        weight="score",
        n_layers=None,  # Number of concentric layers
        n_rank=[5, 10],  # Nodes in each rank for the concentric layout
        dist_node = 10,  # Distance between each rank of circles
        layout="degree",
        size=None,#700,
        sizes=(50,500),# min and max of size
        facecolor="skyblue",
        cmap='coolwarm',
        edgecolor="k",
        edgelinewidth=1.5,
        alpha=.5,
        alphas=(0.1, 1.0),# min and max of alpha
        marker="o",
        node_hideticks=True,
        linecolor="gray",
        line_cmap='coolwarm',
        linewidth=1.5,
        linewidths=(0.5,5),# min and max of linewidth
        linealpha=1.0,
        linealphas=(0.1,1.0),# min and max of linealpha
        linestyle="-",
        line_arrowstyle='-',
        fontsize=10,
        fontcolor="k",
        ha:str="center",
        va:str="center",
        figsize=(12, 10),
        k_value=0.3,
        bgcolor="w",
        dir_save="./ppi_network.html",
        physics=True,
        notebook=False,
        scale=1,
        ax=None,
        **kwargs
    ):
    """
        print(usage_str)

    # Check for required columns in the DataFrame
    for col in [player1, player2, weight]:
        if col not in interactions.columns:
            raise ValueError(f"Column '{col}' is missing from the interactions DataFrame.")
    if interactions.empty:
        raise ValueError("The interactions DataFrame is empty; nothing to plot.")
    # Sort a copy: the caller's DataFrame must not be reordered in place.
    interactions = interactions.sort_values(by=[weight])

    def _fit(values, n):
        """Truncate a list to length n, or pad it by repeating its last item."""
        if not values:
            return values
        if len(values) >= n:
            return values[:n]
        return values + [values[-1]] * (n - len(values))

    def _rescale(values, lo, hi):
        """Map values linearly onto [lo, hi]; constant input maps to the midpoint."""
        if not values:
            return []
        v_min, v_max = min(values), max(values)
        if v_max > v_min:
            return [lo + (hi - lo) * (v - v_min) / (v_max - v_min) for v in values]
        return [(lo + hi) / 2] * len(values)

    def _css_color(color):
        """Convert a matplotlib colour spec to '#rrggbb' for pyvis; leave unknown specs as-is."""
        try:
            return to_hex(color)
        except (ValueError, TypeError):
            return color

    # Initialize Pyvis network
    net = Network(
        height="750px",
        width="100%",
        bgcolor=_css_color(bgcolor),
        font_color=_css_color(fontcolor),
    )
    net.force_atlas_2based(
        gravity=-50, central_gravity=0.01, spring_length=100, spring_strength=0.1
    )
    net.toggle_physics(physics)

    # The first keyword containing "figset" carries the figsets() options.
    kws_figsets = {}
    for k_arg in list(kwargs):
        if "figset" in k_arg:
            kws_figsets = kwargs.pop(k_arg)
            break

    # Create a NetworkX graph from the interaction data
    G = nx.Graph()
    for src, dst, w in interactions[[player1, player2, weight]].itertuples(
        index=False, name=None
    ):
        G.add_edge(src, dst, weight=w)

    num_nodes = G.number_of_nodes()
    num_edges = G.number_of_edges()

    # Node degrees drive the default colours, sizes and the "degree" layout.
    degrees = dict(G.degree())
    norm = Normalize(vmin=min(degrees.values()), vmax=max(degrees.values()))
    colormap = plt.get_cmap(cmap)  # cm.get_cmap was removed in matplotlib 3.9

    if not ips.isa(facecolor, "color"):
        print("facecolor: based on degrees")
        facecolor = [colormap(norm(deg)) for deg in degrees.values()]

    # * size
    if not isinstance(size, (int, float, list)):
        print("size: based on degrees")
        size = [deg * 50 for deg in degrees.values()]
    size = _fit(size, num_nodes) if isinstance(size, list) else [size] * num_nodes
    if len(ips.flatten(size, verbose=False)) != 1:
        min_size, max_size = sizes
        size = _rescale(size, min_size, max_size)

    # * facecolor
    facecolor = (
        _fit(facecolor, num_nodes) if isinstance(facecolor, list) else [facecolor] * num_nodes
    )

    # * facealpha
    if isinstance(alpha, list):
        alpha = _fit(alpha, num_nodes)
        if alpha:
            min_alphas, max_alphas = alphas
            alpha = _rescale(alpha, min_alphas, max_alphas)
        else:
            alpha = [1.0] * num_nodes  # empty list: full opacity
    else:
        alpha = [alpha] * num_nodes

    for i, node in enumerate(G.nodes()):
        net.add_node(
            node,
            label=node,
            size=size[i],
            color=_css_color(facecolor[i]),
            alpha=alpha[i],
            font={"size": fontsize, "color": _css_color(fontcolor)},
        )
    print(f"nodes number: {num_nodes}")

    edge_css_color = _css_color(edgecolor)
    for u, v, attrs in G.edges(data=True):
        net.add_edge(
            u,
            v,
            weight=attrs["weight"],
            color=edge_css_color,
            width=edgelinewidth * attrs["weight"],
        )

    layouts = [
        "spring",
        "circular",
        "kamada_kawai",
        "spectral",
        "random",
        "shell",
        "planar",
        "spiral",
        "degree",
    ]
    layout = ips.strcmp(layout, layouts)[0]
    print(f"layout:{layout}, or select one in {layouts}")

    def _degree_layout():
        """Concentric circles: highest-degree nodes innermost, leftovers on one outer circle."""
        ranked = sorted(degrees.items(), key=lambda kv: kv[1], reverse=True)
        total_layers = len(n_rank) + 1 if n_layers is None else n_layers
        ranked_layers = min(total_layers, len(n_rank))
        positions = {}

        def _place(rank_nodes, ring_radius):
            for j, (node_, _deg) in enumerate(rank_nodes):
                theta = (j / len(rank_nodes)) * 2 * np.pi
                positions[node_] = (ring_radius * np.cos(theta), ring_radius * np.sin(theta))

        start = 0
        for rank_index in range(ranked_layers):
            stop = start + n_rank[rank_index]
            _place(ranked[start:stop], (rank_index + 1) * dist_node)
            start = stop

        remaining = ranked[start:]
        if remaining:
            # Leftover nodes go, shuffled, onto the outermost layer; a new
            # ring is added if all layers were consumed by n_rank.
            outer_index = max(total_layers - 1, ranked_layers)
            shuffled = [remaining[j] for j in np.random.permutation(len(remaining))]
            _place(shuffled, (outer_index + 1) * dist_node)
        return positions

    # Choose layout
    if layout == "spring":
        pos = nx.spring_layout(G, k=k_value)
    elif layout == "circular":
        pos = nx.circular_layout(G)
    elif layout == "kamada_kawai":
        pos = nx.kamada_kawai_layout(G)
    elif layout == "spectral":
        pos = nx.spectral_layout(G)
    elif layout == "random":
        pos = nx.random_layout(G)
    elif layout == "shell":
        pos = nx.shell_layout(G)
    elif layout == "planar":
        if nx.check_planarity(G)[0]:
            pos = nx.planar_layout(G)
        else:
            print("Graph is not planar; switching to spring layout.")
            pos = nx.spring_layout(G, k=k_value)
    elif layout == "spiral":
        pos = nx.spiral_layout(G)
    elif layout == "degree":
        pos = _degree_layout()
    else:
        print(f"Unknown layout '{layout}', defaulting to 'spring',or可以用这些: {layouts}")
        pos = nx.spring_layout(G, k=k_value)

    # Copy the layout into the pyvis export.
    for node, (px, py) in pos.items():
        pyvis_node = net.get_node(node)
        pyvis_node["x"] = px * scale
        pyvis_node["y"] = py * scale

    if ax is None:
        _, ax = plt.subplots(1, 1, figsize=figsize)

    # Draw nodes, edges, and labels with customization options
    nx.draw_networkx_nodes(
        G,
        pos,
        ax=ax,
        node_size=size,
        node_color=facecolor,
        linewidths=edgelinewidth,
        edgecolors=edgecolor,
        alpha=alpha,
        hide_ticks=node_hideticks,
        node_shape=marker,
    )

    # * linewidth: a list is rescaled into `linewidths`, a scalar is used as-is.
    if isinstance(linewidth, list):
        min_lw, max_lw = linewidths
        linewidth = _rescale(_fit(linewidth, num_edges), min_lw, max_lw)
    else:
        linewidth = [linewidth] * num_edges

    # * linecolor: non-string -> colour edges by weight.
    if isinstance(linecolor, str):
        linecolor = [linecolor] * num_edges
    else:
        edge_weights = [G[u][v]["weight"] for u, v in G.edges()]
        edge_norm = Normalize(vmin=min(edge_weights), vmax=max(edge_weights))
        edge_cmap = plt.get_cmap(line_cmap)
        linecolor = [edge_cmap(edge_norm(w)) for w in edge_weights]

    # * linealpha
    if isinstance(linealpha, list):
        linealpha = _fit(linealpha, num_edges)
        if linealpha:
            min_alpha, max_alpha = linealphas
            linealpha = _rescale(linealpha, min_alpha, max_alpha)
        else:
            linealpha = [1.0] * num_edges  # empty list: full opacity
    else:
        linealpha = [linealpha] * num_edges

    nx.draw_networkx_edges(
        G,
        pos,
        ax=ax,
        edge_color=linecolor,
        width=linewidth,
        style=linestyle,
        arrowstyle=line_arrowstyle,
        alpha=linealpha,
    )

    nx.draw_networkx_labels(
        G,
        pos,
        ax=ax,
        font_size=fontsize,
        font_color=fontcolor,
        horizontalalignment=ha,
        verticalalignment=va,
    )
    figsets(ax=ax, **kws_figsets)
    ax.axis("off")

    if dir_save:
        if not os.path.basename(dir_save):
            dir_save = "_.html"
        net.write_html(dir_save)
        # Derive sibling paths from the stem so a name without ".html"
        # cannot make the GraphML/PDF overwrite the HTML file.
        stem = os.path.splitext(dir_save)[0]
        graphml_path = stem + ".graphml"
        nx.write_graphml(G, graphml_path)  # Export to GraphML
        # No nested double quotes inside the f-string (SyntaxError before 3.12).
        print(f"could be edited in Cytoscape \n{graphml_path}")
        ips.figsave(stem + ".pdf")
    return G, ax
|