py2ls 0.2.4.24__py3-none-any.whl → 0.2.4.26__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- py2ls/.DS_Store +0 -0
- py2ls/.git/index +0 -0
- py2ls/corr.py +475 -0
- py2ls/data/.DS_Store +0 -0
- py2ls/data/hyper_param_autogluon_zeroshot2024.json +2383 -0
- py2ls/data/styles/.DS_Store +0 -0
- py2ls/data/styles/example/.DS_Store +0 -0
- py2ls/data/usages_sns.json +6 -1
- py2ls/ec2ls.py +61 -0
- py2ls/ips.py +496 -138
- py2ls/ml2ls.py +994 -288
- py2ls/netfinder.py +16 -20
- py2ls/nl2ls.py +283 -0
- py2ls/plot.py +1244 -158
- {py2ls-0.2.4.24.dist-info → py2ls-0.2.4.26.dist-info}/METADATA +5 -1
- {py2ls-0.2.4.24.dist-info → py2ls-0.2.4.26.dist-info}/RECORD +17 -14
- py2ls/data/usages_pd copy.json +0 -1105
- py2ls/ml2ls copy.py +0 -2906
- {py2ls-0.2.4.24.dist-info → py2ls-0.2.4.26.dist-info}/WHEEL +0 -0
py2ls/plot.py
CHANGED
@@ -20,7 +20,9 @@ from .ips import (
|
|
20
20
|
flatten,
|
21
21
|
plt_font,
|
22
22
|
run_once_within,
|
23
|
-
df_format
|
23
|
+
df_format,
|
24
|
+
df_corr,
|
25
|
+
df_scaler
|
24
26
|
)
|
25
27
|
import scipy.stats as scipy_stats
|
26
28
|
from .stats import *
|
@@ -142,20 +144,37 @@ def add_text(ax=None, height_offset=0.5, fmt=".1f", **kwargs):
|
|
142
144
|
**kwargs,
|
143
145
|
)
|
144
146
|
|
147
|
+
def pval2str(p):
    """Convert a p-value into a significance-star string.

    Thresholds (unchanged from the original implementation):
        p < 0.001          -> "***"
        0.001 <= p < 0.01  -> "**"
        0.01 <= p <= 0.05  -> "*"
        p > 0.05           -> ""

    Parameters:
        p: numeric p-value.

    Returns:
        str: "", "*", "**" or "***".

    A NaN p-value returns "" instead of raising: every comparison with NaN
    is False, so the previous if/elif chain matched no branch and hit
    ``return txt`` with ``txt`` unbound (UnboundLocalError).
    """
    # Guard clauses ordered from most to least significant; anything that
    # matches none of them (p > 0.05 or NaN) is "not significant".
    if p < 0.001:
        return "***"
    if p < 0.01:
        return "**"
    if p <= 0.05:
        return "*"
    return ""
|
145
157
|
|
146
158
|
def heatmap(
|
147
159
|
data,
|
148
160
|
ax=None,
|
149
161
|
kind="corr", #'corr','direct','pivot'
|
162
|
+
method="pearson",# for correlation: ‘pearson’(default), ‘kendall’, ‘spearman’
|
150
163
|
columns="all", # pivot, default: coll numeric columns
|
164
|
+
style=0,# for correlation
|
151
165
|
index=None, # pivot
|
152
166
|
values=None, # pivot
|
167
|
+
fontsize=10,
|
153
168
|
tri="u",
|
154
169
|
mask=True,
|
155
170
|
k=1,
|
171
|
+
vmin=None,
|
172
|
+
vmax=None,
|
173
|
+
size_scale=500,
|
156
174
|
annot=True,
|
157
175
|
cmap="coolwarm",
|
158
176
|
fmt=".2f",
|
177
|
+
show_indicator = True,# only for style==1
|
159
178
|
cluster=False,
|
160
179
|
inplace=False,
|
161
180
|
figsize=(10, 8),
|
@@ -201,7 +220,7 @@ def heatmap(
|
|
201
220
|
ax = plt.gca()
|
202
221
|
# Select numeric columns or specific subset of columns
|
203
222
|
if columns == "all":
|
204
|
-
df_numeric = data.select_dtypes(include=[
|
223
|
+
df_numeric = data.select_dtypes(include=[np.number])
|
205
224
|
else:
|
206
225
|
df_numeric = data[columns]
|
207
226
|
|
@@ -209,8 +228,10 @@ def heatmap(
|
|
209
228
|
kind = strcmp(kind, kinds)[0]
|
210
229
|
print(kind)
|
211
230
|
if "corr" in kind: # correlation
|
231
|
+
methods = ["pearson", "spearman", "kendall"]
|
232
|
+
method = strcmp(method, methods)[0]
|
212
233
|
# Compute the correlation matrix
|
213
|
-
data4heatmap = df_numeric.corr()
|
234
|
+
data4heatmap = df_numeric.corr(method=method)
|
214
235
|
# Generate mask for the upper triangle if mask is True
|
215
236
|
if mask:
|
216
237
|
if "u" in tri.lower(): # upper => np.tril
|
@@ -288,18 +309,76 @@ def heatmap(
|
|
288
309
|
df_col_cluster,
|
289
310
|
)
|
290
311
|
else:
|
291
|
-
|
292
|
-
|
293
|
-
|
294
|
-
|
295
|
-
|
296
|
-
|
297
|
-
|
298
|
-
|
299
|
-
|
300
|
-
|
301
|
-
|
302
|
-
|
312
|
+
if style==0:
|
313
|
+
# Create a standard heatmap
|
314
|
+
ax = sns.heatmap(
|
315
|
+
data4heatmap,
|
316
|
+
ax=ax,
|
317
|
+
mask=mask_array,
|
318
|
+
annot=annot,
|
319
|
+
cmap=cmap,
|
320
|
+
fmt=fmt,
|
321
|
+
**kwargs, # Pass any additional arguments to sns.heatmap
|
322
|
+
)
|
323
|
+
return ax
|
324
|
+
elif style==1:
|
325
|
+
if isinstance(cmap, str):
|
326
|
+
cmap = plt.get_cmap(cmap)
|
327
|
+
norm = plt.Normalize(vmin=-1, vmax=1)
|
328
|
+
r_, p_ = df_corr(data4heatmap, method=method)
|
329
|
+
# size_r_norm=df_scaler(data=r_, method="minmax", vmin=-1,vmax=1)
|
330
|
+
# 初始化一个空的可绘制对象用于颜色条
|
331
|
+
scatter_handles = []
|
332
|
+
# 循环绘制气泡图和数值
|
333
|
+
for i in range(len(r_.columns)):
|
334
|
+
for j in range(len(r_.columns)):
|
335
|
+
if (i < j) if "u" in tri.lower() else (j<i): # 对角线左上部只显示气泡
|
336
|
+
color = cmap(norm(r_.iloc[i, j])) # 根据相关系数获取颜色
|
337
|
+
scatter = ax.scatter(
|
338
|
+
i, j, s=np.abs(r_.iloc[i, j])*size_scale, color=color,
|
339
|
+
# alpha=1,edgecolor=edgecolor,linewidth=linewidth,
|
340
|
+
**kwargs
|
341
|
+
)
|
342
|
+
scatter_handles.append(scatter) # 保存scatter对象用于颜色条
|
343
|
+
# add *** indicators
|
344
|
+
if show_indicator:
|
345
|
+
ax.text(
|
346
|
+
i,
|
347
|
+
j,
|
348
|
+
pval2str(p_.iloc[i, j]),
|
349
|
+
ha="center",
|
350
|
+
va="center",
|
351
|
+
color="k",
|
352
|
+
fontsize=fontsize * 1.3,
|
353
|
+
)
|
354
|
+
elif (i > j) if "u" in tri.lower() else (j>i): # 对角只显示数值
|
355
|
+
color = cmap(norm(r_.iloc[i, j])) # 数值的颜色同样基于相关系数
|
356
|
+
ax.text(
|
357
|
+
i,
|
358
|
+
j,
|
359
|
+
f"{r_.iloc[i, j]:{fmt}}",
|
360
|
+
ha="center",
|
361
|
+
va="center",
|
362
|
+
color=color,
|
363
|
+
fontsize=fontsize,
|
364
|
+
)
|
365
|
+
else: # 对角线部分,显示空白
|
366
|
+
ax.scatter(i, j, s=1, color="white")
|
367
|
+
# 设置坐标轴标签
|
368
|
+
figsets(xticks=range(len(r_.columns)),
|
369
|
+
xticklabels=r_.columns,
|
370
|
+
xangle=90,
|
371
|
+
fontsize=fontsize,
|
372
|
+
yticks=range(len(r_.columns)),
|
373
|
+
yticklabels=r_.columns,
|
374
|
+
xlim=[-0.5,len(r_.columns)-0.5],
|
375
|
+
ylim=[-0.5,len(r_.columns)-0.5]
|
376
|
+
)
|
377
|
+
# 添加颜色条
|
378
|
+
sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
|
379
|
+
sm.set_array([]) # 仅用于显示颜色条
|
380
|
+
plt.colorbar(sm, ax=ax, label="Correlation Coefficient")
|
381
|
+
return ax
|
303
382
|
elif "dir" in kind: # direct
|
304
383
|
data4heatmap = df_numeric
|
305
384
|
elif "pi" in kind: # pivot
|
@@ -389,16 +468,53 @@ def heatmap(
|
|
389
468
|
)
|
390
469
|
else:
|
391
470
|
# Create a standard heatmap
|
392
|
-
|
393
|
-
|
394
|
-
|
395
|
-
|
396
|
-
|
397
|
-
|
398
|
-
|
399
|
-
|
400
|
-
|
401
|
-
|
471
|
+
if style==0:
|
472
|
+
ax = sns.heatmap(
|
473
|
+
data4heatmap,
|
474
|
+
ax=ax,
|
475
|
+
annot=annot,
|
476
|
+
cmap=cmap,
|
477
|
+
fmt=fmt,
|
478
|
+
**kwargs, # Pass any additional arguments to sns.heatmap
|
479
|
+
)
|
480
|
+
# Return the Axes object for further customization if needed
|
481
|
+
return ax
|
482
|
+
elif style==1:
|
483
|
+
if isinstance(cmap, str):
|
484
|
+
cmap = plt.get_cmap(cmap)
|
485
|
+
if vmin is None:
|
486
|
+
vmin=np.min(data4heatmap)
|
487
|
+
if vmax is None:
|
488
|
+
vmax=np.max(data4heatmap)
|
489
|
+
norm = plt.Normalize(vmin=vmin, vmax=vmax)
|
490
|
+
|
491
|
+
# 初始化一个空的可绘制对象用于颜色条
|
492
|
+
scatter_handles = []
|
493
|
+
# 循环绘制气泡图和数值
|
494
|
+
print(len(data4heatmap.index),len(data4heatmap.columns))
|
495
|
+
for i in range(len(data4heatmap.index)):
|
496
|
+
for j in range(len(data4heatmap.columns)):
|
497
|
+
color = cmap(norm(data4heatmap.iloc[i, j])) # 根据相关系数获取颜色
|
498
|
+
scatter = ax.scatter(j,i, s=np.abs(data4heatmap.iloc[i, j])*size_scale, color=color, **kwargs)
|
499
|
+
scatter_handles.append(scatter) # 保存scatter对象用于颜色条
|
500
|
+
|
501
|
+
# 设置坐标轴标签
|
502
|
+
figsets(xticks=range(len(data4heatmap.columns)),
|
503
|
+
xticklabels=data4heatmap.columns,
|
504
|
+
xangle=90,
|
505
|
+
fontsize=fontsize,
|
506
|
+
yticks=range(len(data4heatmap.index)),
|
507
|
+
yticklabels=data4heatmap.index,
|
508
|
+
xlim=[-0.5,len(data4heatmap.columns)-0.5],
|
509
|
+
ylim=[-0.5,len(data4heatmap.index)-0.5]
|
510
|
+
)
|
511
|
+
# 添加颜色条
|
512
|
+
sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
|
513
|
+
sm.set_array([]) # 仅用于显示颜色条
|
514
|
+
plt.colorbar(sm, ax=ax,
|
515
|
+
# label="Correlation Coefficient"
|
516
|
+
)
|
517
|
+
return ax
|
402
518
|
|
403
519
|
|
404
520
|
# !usage: py2ls.plot.heatmap()
|
@@ -2475,7 +2591,8 @@ def get_color(
|
|
2475
2591
|
"#B25E9D",
|
2476
2592
|
"#4B8C3B",
|
2477
2593
|
"#EF8632",
|
2478
|
-
"#24578E"
|
2594
|
+
"#24578E",
|
2595
|
+
"#FF2C00",
|
2479
2596
|
]
|
2480
2597
|
elif n == 8:
|
2481
2598
|
# colorlist = ['#1f77b4','#ff7f0e','#367B7F','#51B34F','#d62728','#aa40fc','#e377c2','#17becf']
|
@@ -3199,9 +3316,12 @@ def plotxy(
|
|
3199
3316
|
zorder = 0
|
3200
3317
|
for k in kind_:
|
3201
3318
|
# preprocess data
|
3202
|
-
|
3203
|
-
|
3204
|
-
|
3319
|
+
try:
|
3320
|
+
data=df_preprocessing_(data, kind=k)
|
3321
|
+
if 'variable' in data.columns and 'value' in data.columns:
|
3322
|
+
x,y='variable','value'
|
3323
|
+
except Exception as e:
|
3324
|
+
print(e)
|
3205
3325
|
zorder += 1
|
3206
3326
|
# indicate 'col' features
|
3207
3327
|
col = kwargs.get("col", None)
|
@@ -3222,18 +3342,37 @@ def plotxy(
|
|
3222
3342
|
# (1) return FacetGrid
|
3223
3343
|
if k == "jointplot":
|
3224
3344
|
kws_joint = kwargs.pop("kws_joint", kwargs)
|
3225
|
-
|
3345
|
+
kws_joint = {
|
3346
|
+
k: v for k, v in kws_joint.items() if not k.startswith("kws_")
|
3347
|
+
}
|
3348
|
+
hue = kwargs.get("hue", None)
|
3349
|
+
if isinstance(kws_joint, dict) or hue is None: # Check if kws_ellipse is a dictionary
|
3350
|
+
kws_joint.pop("hue", None) # Safely remove 'hue' if it exists
|
3351
|
+
|
3352
|
+
palette = kwargs.get("palette", None)
|
3353
|
+
if palette is None:
|
3354
|
+
palette = kws_joint.pop(
|
3355
|
+
"palette", get_color(data[hue].nunique()) if hue is not None else None
|
3356
|
+
)
|
3357
|
+
else:
|
3358
|
+
kws_joint.pop("palette", palette)
|
3359
|
+
stats=kwargs.pop("stats",None)
|
3360
|
+
if stats:
|
3361
|
+
stats=kws_joint.pop("stats",True)
|
3226
3362
|
if stats:
|
3227
3363
|
r, p_value = scipy_stats.pearsonr(data[x], data[y])
|
3228
|
-
|
3229
|
-
|
3230
|
-
|
3231
|
-
|
3232
|
-
|
3233
|
-
|
3234
|
-
|
3235
|
-
|
3236
|
-
|
3364
|
+
for key in ["palette", "alpha", "hue","stats"]:
|
3365
|
+
kws_joint.pop(key, None)
|
3366
|
+
g = sns.jointplot(data=data, x=x, y=y,hue=hue,palette=palette, **kws_joint)
|
3367
|
+
if stats:
|
3368
|
+
g.ax_joint.annotate(
|
3369
|
+
f"pearsonr = {r:.2f} p = {p_value:.3f}",
|
3370
|
+
xy=(0.6, 0.98),
|
3371
|
+
xycoords="axes fraction",
|
3372
|
+
fontsize=12,
|
3373
|
+
color="black",
|
3374
|
+
ha="center",
|
3375
|
+
)
|
3237
3376
|
elif k == "lmplot":
|
3238
3377
|
kws_lm = kwargs.pop("kws_lm", kwargs)
|
3239
3378
|
stats = kwargs.pop("stats", True) # Flag to calculate stats
|
@@ -3379,8 +3518,7 @@ def plotxy(
|
|
3379
3518
|
xycoords="axes fraction",
|
3380
3519
|
fontsize=12,
|
3381
3520
|
color="black",
|
3382
|
-
ha="center"
|
3383
|
-
)
|
3521
|
+
ha="center")
|
3384
3522
|
|
3385
3523
|
elif k == "catplot_sns":
|
3386
3524
|
kws_cat = kwargs.pop("kws_cat", kwargs)
|
@@ -3388,7 +3526,7 @@ def plotxy(
|
|
3388
3526
|
elif k == "displot":
|
3389
3527
|
kws_dis = kwargs.pop("kws_dis", kwargs)
|
3390
3528
|
# displot creates a new figure and returns a FacetGrid
|
3391
|
-
g = sns.displot(data=data, x=x,
|
3529
|
+
g = sns.displot(data=data, x=x,y=y, **kws_dis)
|
3392
3530
|
|
3393
3531
|
# (2) return axis
|
3394
3532
|
if ax is None:
|
@@ -3399,19 +3537,58 @@ def plotxy(
|
|
3399
3537
|
elif k == "stdshade":
|
3400
3538
|
kws_stdshade = kwargs.pop("kws_stdshade", kwargs)
|
3401
3539
|
ax = stdshade(ax=ax, **kwargs)
|
3540
|
+
elif k=="ellipse":
|
3541
|
+
kws_ellipse = kwargs.pop("kws_ellipse", kwargs)
|
3542
|
+
kws_ellipse = {
|
3543
|
+
k: v for k, v in kws_ellipse.items() if not k.startswith("kws_")
|
3544
|
+
}
|
3545
|
+
hue = kwargs.get("hue", None)
|
3546
|
+
if isinstance(kws_ellipse, dict) or hue is None: # Check if kws_ellipse is a dictionary
|
3547
|
+
kws_ellipse.pop("hue", None) # Safely remove 'hue' if it exists
|
3548
|
+
|
3549
|
+
palette = kwargs.get("palette", None)
|
3550
|
+
if palette is None:
|
3551
|
+
palette = kws_ellipse.pop(
|
3552
|
+
"palette", get_color(data[hue].nunique()) if hue is not None else None
|
3553
|
+
)
|
3554
|
+
alpha = kws_ellipse.pop("alpha", 0.1)
|
3555
|
+
hue_order = kwargs.get("hue_order",None)
|
3556
|
+
if hue_order is None:
|
3557
|
+
hue_order = kws_ellipse.get("hue_order",None)
|
3558
|
+
if hue_order:
|
3559
|
+
data["hue"] = pd.Categorical(data[hue], categories=hue_order, ordered=True)
|
3560
|
+
data = data.sort_values(by="hue")
|
3561
|
+
for key in ["palette", "alpha", "hue","hue_order"]:
|
3562
|
+
kws_ellipse.pop(key, None)
|
3563
|
+
ax=ellipse(
|
3564
|
+
ax=ax,
|
3565
|
+
data=data,
|
3566
|
+
x=x,
|
3567
|
+
y=y,
|
3568
|
+
hue=hue,
|
3569
|
+
palette=palette,
|
3570
|
+
alpha=alpha,
|
3571
|
+
zorder=zorder,
|
3572
|
+
**kws_ellipse,
|
3573
|
+
)
|
3402
3574
|
elif k == "scatterplot":
|
3403
3575
|
kws_scatter = kwargs.pop("kws_scatter", kwargs)
|
3404
3576
|
kws_scatter = {
|
3405
3577
|
k: v for k, v in kws_scatter.items() if not k.startswith("kws_")
|
3406
3578
|
}
|
3407
|
-
hue = kwargs.
|
3579
|
+
hue = kwargs.get("hue", None)
|
3408
3580
|
if isinstance(kws_scatter, dict): # Check if kws_scatter is a dictionary
|
3409
3581
|
kws_scatter.pop("hue", None) # Safely remove 'hue' if it exists
|
3410
|
-
palette
|
3411
|
-
|
3412
|
-
|
3582
|
+
palette=kws_scatter.get("palette",None)
|
3583
|
+
if palette is None:
|
3584
|
+
palette = kws_scatter.pop(
|
3585
|
+
"palette", get_color(data[hue].nunique()) if hue is not None else None
|
3586
|
+
)
|
3413
3587
|
s = kws_scatter.pop("s", 10)
|
3414
|
-
alpha = kws_scatter.pop("alpha", 0.7)
|
3588
|
+
alpha = kws_scatter.pop("alpha", 0.7)
|
3589
|
+
for key in ["s", "palette", "alpha", "hue"]:
|
3590
|
+
kws_scatter.pop(key, None)
|
3591
|
+
|
3415
3592
|
ax = sns.scatterplot(
|
3416
3593
|
ax=ax,
|
3417
3594
|
data=data,
|
@@ -3427,19 +3604,33 @@ def plotxy(
|
|
3427
3604
|
elif k == "histplot":
|
3428
3605
|
kws_hist = kwargs.pop("kws_hist", kwargs)
|
3429
3606
|
kws_hist = {k: v for k, v in kws_hist.items() if not k.startswith("kws_")}
|
3430
|
-
ax = sns.histplot(data=data, x=x, ax=ax, zorder=zorder, **kws_hist)
|
3431
|
-
elif k == "kdeplot":
|
3607
|
+
ax = sns.histplot(data=data, x=x, y=y, ax=ax, zorder=zorder, **kws_hist)
|
3608
|
+
elif k == "kdeplot":
|
3432
3609
|
kws_kde = kwargs.pop("kws_kde", kwargs)
|
3433
|
-
kws_kde = {
|
3434
|
-
|
3610
|
+
kws_kde = {
|
3611
|
+
k: v for k, v in kws_kde.items() if not k.startswith("kws_")
|
3612
|
+
}
|
3613
|
+
hue = kwargs.get("hue", None)
|
3614
|
+
if isinstance(kws_kde, dict) or hue is None: # Check if kws_kde is a dictionary
|
3615
|
+
kws_kde.pop("hue", None) # Safely remove 'hue' if it exists
|
3616
|
+
|
3617
|
+
palette = kwargs.get("palette", None)
|
3618
|
+
if palette is None:
|
3619
|
+
palette = kws_kde.pop(
|
3620
|
+
"palette", get_color(data[hue].nunique()) if hue is not None else None
|
3621
|
+
)
|
3622
|
+
alpha = kws_kde.pop("alpha", 0.05)
|
3623
|
+
for key in ["palette", "alpha", "hue"]:
|
3624
|
+
kws_kde.pop(key, None)
|
3625
|
+
ax = sns.kdeplot(data=data, x=x, y=y, palette=palette,hue=hue, ax=ax,alpha=alpha, zorder=zorder, **kws_kde)
|
3435
3626
|
elif k == "ecdfplot":
|
3436
3627
|
kws_ecdf = kwargs.pop("kws_ecdf", kwargs)
|
3437
3628
|
kws_ecdf = {k: v for k, v in kws_ecdf.items() if not k.startswith("kws_")}
|
3438
|
-
ax = sns.ecdfplot(data=data, x=x,
|
3629
|
+
ax = sns.ecdfplot(data=data, x=x,y=y, ax=ax, zorder=zorder, **kws_ecdf)
|
3439
3630
|
elif k == "rugplot":
|
3440
3631
|
kws_rug = kwargs.pop("kws_rug", kwargs)
|
3441
3632
|
kws_rug = {k: v for k, v in kws_rug.items() if not k.startswith("kws_")}
|
3442
|
-
ax = sns.rugplot(data=data, x=x, ax=ax, zorder=zorder, **kws_rug)
|
3633
|
+
ax = sns.rugplot(data=data, x=x, y=y, ax=ax, zorder=zorder, **kws_rug)
|
3443
3634
|
elif k == "stripplot":
|
3444
3635
|
kws_strip = kwargs.pop("kws_strip", kwargs)
|
3445
3636
|
kws_strip = {k: v for k, v in kws_strip.items() if not k.startswith("kws_")}
|
@@ -3512,7 +3703,8 @@ def plotxy(
|
|
3512
3703
|
figsets(ax=ax, **kws_figsets)
|
3513
3704
|
if kws_add_text:
|
3514
3705
|
add_text(ax=ax, **kws_add_text) if kws_add_text else None
|
3515
|
-
|
3706
|
+
if run_once_within(10):
|
3707
|
+
for k in kind_:
|
3516
3708
|
print(f"\n{k}⤵ ")
|
3517
3709
|
print(default_settings[k])
|
3518
3710
|
# print("=>\t",sns_info[sns_info["Functions"].str.contains(k)].iloc[:, -1].tolist()[0],"\n")
|
@@ -3570,6 +3762,7 @@ def df_preprocessing_(data, kind, verbose=False):
|
|
3570
3762
|
"lineplot", # Can work with both wide and long formats
|
3571
3763
|
"area plot", # Can work with both formats, useful for stacked areas
|
3572
3764
|
"violinplot", # Can work with both formats depending on categorical vs continuous data
|
3765
|
+
"ellipse",# ellipse plot, default confidence=0.95
|
3573
3766
|
],
|
3574
3767
|
)[0]
|
3575
3768
|
|
@@ -3605,6 +3798,7 @@ def df_preprocessing_(data, kind, verbose=False):
|
|
3605
3798
|
"violinplot", # Can work with both formats depending on categorical vs continuous data
|
3606
3799
|
"relplot",
|
3607
3800
|
"pointplot", # Works well with wide format
|
3801
|
+
"ellipse",
|
3608
3802
|
]
|
3609
3803
|
|
3610
3804
|
# Wide format (e.g., for heatmap and pairplot)
|
@@ -4096,25 +4290,45 @@ def venn(
|
|
4096
4290
|
"""
|
4097
4291
|
if ax is None:
|
4098
4292
|
ax = plt.gca()
|
4293
|
+
if isinstance(lists, dict):
|
4294
|
+
labels = list(lists.keys())
|
4295
|
+
lists = list(lists.values())
|
4296
|
+
if isinstance(lists[0], set):
|
4297
|
+
lists = [list(i) for i in lists]
|
4099
4298
|
lists = [set(flatten(i, verbose=False)) for i in lists]
|
4100
4299
|
# Function to apply text styles to labels
|
4101
4300
|
if colors is None:
|
4102
4301
|
colors = ["r", "b"] if len(lists) == 2 else ["r", "g", "b"]
|
4103
4302
|
if labels is None:
|
4104
|
-
|
4303
|
+
if len(lists) == 2:
|
4304
|
+
labels = ["set1", "set2"]
|
4305
|
+
elif len(lists) == 3:
|
4306
|
+
labels = ["set1", "set2", "set3"]
|
4307
|
+
elif len(lists) == 4:
|
4308
|
+
labels = ["set1", "set2", "set3","set4"]
|
4309
|
+
elif len(lists) == 5:
|
4310
|
+
labels = ["set1", "set2", "set3","set4","set55"]
|
4311
|
+
elif len(lists) == 6:
|
4312
|
+
labels = ["set1", "set2", "set3","set4","set5","set6"]
|
4313
|
+
elif len(lists) == 7:
|
4314
|
+
labels = ["set1", "set2", "set3","set4","set5","set6","set7"]
|
4105
4315
|
if edgecolor is None:
|
4106
4316
|
edgecolor = colors
|
4107
4317
|
colors = [desaturate_color(color, saturation) for color in colors]
|
4108
|
-
|
4109
|
-
if len(lists) == 2:
|
4318
|
+
universe = len(set.union(*lists))
|
4110
4319
|
|
4111
|
-
|
4112
|
-
|
4113
|
-
|
4114
|
-
|
4115
|
-
|
4116
|
-
|
4117
|
-
|
4320
|
+
# Check colors and auto-calculate overlaps
|
4321
|
+
def get_count_and_percentage(set_count, subset_count):
|
4322
|
+
percent = subset_count / set_count if set_count > 0 else 0
|
4323
|
+
return (
|
4324
|
+
f"{subset_count}\n({fmt.format(percent)})"
|
4325
|
+
if show_percentages
|
4326
|
+
else f"{subset_count}"
|
4327
|
+
)
|
4328
|
+
if fmt is not None:
|
4329
|
+
if not fmt.startswith("{"):
|
4330
|
+
fmt="{:" + fmt + "}"
|
4331
|
+
if len(lists) == 2:
|
4118
4332
|
|
4119
4333
|
from matplotlib_venn import venn2, venn2_circles
|
4120
4334
|
|
@@ -4127,21 +4341,28 @@ def venn(
|
|
4127
4341
|
set1, set2 = lists[0], lists[1]
|
4128
4342
|
v.get_patch_by_id("10").set_color(colors[0])
|
4129
4343
|
v.get_patch_by_id("01").set_color(colors[1])
|
4130
|
-
|
4131
|
-
|
4132
|
-
|
4344
|
+
try:
|
4345
|
+
v.get_patch_by_id("11").set_color(
|
4346
|
+
get_color_overlap(colors[0], colors[1]) if colors else None
|
4347
|
+
)
|
4348
|
+
except Exception as e:
|
4349
|
+
print(e)
|
4133
4350
|
# v.get_label_by_id('10').set_text(len(set1 - set2))
|
4134
4351
|
# v.get_label_by_id('01').set_text(len(set2 - set1))
|
4135
4352
|
# v.get_label_by_id('11').set_text(len(set1 & set2))
|
4353
|
+
|
4136
4354
|
v.get_label_by_id("10").set_text(
|
4137
|
-
get_count_and_percentage(
|
4355
|
+
get_count_and_percentage(universe, len(set1 - set2))
|
4138
4356
|
)
|
4139
4357
|
v.get_label_by_id("01").set_text(
|
4140
|
-
get_count_and_percentage(
|
4141
|
-
)
|
4142
|
-
v.get_label_by_id("11").set_text(
|
4143
|
-
get_count_and_percentage(len(set1 | set2), len(set1 & set2))
|
4358
|
+
get_count_and_percentage(universe, len(set2 - set1))
|
4144
4359
|
)
|
4360
|
+
try:
|
4361
|
+
v.get_label_by_id("11").set_text(
|
4362
|
+
get_count_and_percentage(universe, len(set1 & set2))
|
4363
|
+
)
|
4364
|
+
except Exception as e:
|
4365
|
+
print(e)
|
4145
4366
|
|
4146
4367
|
if not isinstance(linewidth, list):
|
4147
4368
|
linewidth = [linewidth]
|
@@ -4224,16 +4445,14 @@ def venn(
|
|
4224
4445
|
va=va,
|
4225
4446
|
shadow=shadow,
|
4226
4447
|
)
|
4227
|
-
|
4228
|
-
|
4229
|
-
|
4230
|
-
|
4231
|
-
|
4232
|
-
|
4233
|
-
|
4234
|
-
|
4235
|
-
else f"{subset_count}"
|
4236
|
-
)
|
4448
|
+
# Set transparency level
|
4449
|
+
for patch in v.patches:
|
4450
|
+
if patch:
|
4451
|
+
patch.set_alpha(alpha)
|
4452
|
+
if "none" in edgecolor or 0 in linewidth:
|
4453
|
+
patch.set_edgecolor("none")
|
4454
|
+
return ax
|
4455
|
+
elif len(lists) == 3:
|
4237
4456
|
|
4238
4457
|
from matplotlib_venn import venn3, venn3_circles
|
4239
4458
|
|
@@ -4249,36 +4468,34 @@ def venn(
|
|
4249
4468
|
# Draw the venn diagram
|
4250
4469
|
v = venn3(subsets=lists, set_labels=labels, ax=ax, **kwargs)
|
4251
4470
|
v.get_patch_by_id("100").set_color(colors[0])
|
4471
|
+
v.get_label_by_id("100").set_text(get_count_and_percentage(universe, len(set1 - set2 - set3)))
|
4252
4472
|
v.get_patch_by_id("010").set_color(colors[1])
|
4253
|
-
v.
|
4254
|
-
|
4255
|
-
|
4256
|
-
|
4257
|
-
|
4258
|
-
|
4259
|
-
|
4260
|
-
|
4261
|
-
|
4262
|
-
|
4263
|
-
|
4264
|
-
|
4265
|
-
|
4266
|
-
|
4267
|
-
|
4268
|
-
|
4269
|
-
|
4270
|
-
|
4271
|
-
|
4272
|
-
|
4273
|
-
|
4274
|
-
|
4275
|
-
|
4276
|
-
|
4277
|
-
|
4278
|
-
|
4279
|
-
v.get_label_by_id("111").set_text(
|
4280
|
-
get_label(len(set1 | set2 | set3), len(set1 & set2 & set3))
|
4281
|
-
)
|
4473
|
+
v.get_label_by_id("010").set_text(get_count_and_percentage(universe, len(set2 - set1 - set3)))
|
4474
|
+
try:
|
4475
|
+
v.get_patch_by_id("001").set_color(colors[2])
|
4476
|
+
v.get_label_by_id("001").set_text(get_count_and_percentage(universe, len(set3 - set1 - set2)))
|
4477
|
+
except Exception as e:
|
4478
|
+
print(e)
|
4479
|
+
try:
|
4480
|
+
v.get_patch_by_id("110").set_color(colorAB)
|
4481
|
+
v.get_label_by_id("110").set_text(get_count_and_percentage(universe, len(set1 & set2 - set3)))
|
4482
|
+
except Exception as e:
|
4483
|
+
print(e)
|
4484
|
+
try:
|
4485
|
+
v.get_patch_by_id("101").set_color(colorAC)
|
4486
|
+
v.get_label_by_id("101").set_text(get_count_and_percentage(universe, len(set1 & set3 - set2)))
|
4487
|
+
except Exception as e:
|
4488
|
+
print(e)
|
4489
|
+
try:
|
4490
|
+
v.get_patch_by_id("011").set_color(colorBC)
|
4491
|
+
v.get_label_by_id("011").set_text(get_count_and_percentage(universe, len(set2 & set3 - set1)))
|
4492
|
+
except Exception as e:
|
4493
|
+
print(e)
|
4494
|
+
try:
|
4495
|
+
v.get_patch_by_id("111").set_color(colorABC)
|
4496
|
+
v.get_label_by_id("111").set_text(get_count_and_percentage(universe, len(set1 & set2 & set3)))
|
4497
|
+
except Exception as e:
|
4498
|
+
print(e)
|
4282
4499
|
|
4283
4500
|
# Apply styles to set labels
|
4284
4501
|
for i, text in enumerate(v.set_labels):
|
@@ -4383,16 +4600,34 @@ def venn(
|
|
4383
4600
|
ax.add_patch(ellipse1)
|
4384
4601
|
ax.add_patch(ellipse2)
|
4385
4602
|
ax.add_patch(ellipse3)
|
4603
|
+
# Set transparency level
|
4604
|
+
for patch in v.patches:
|
4605
|
+
if patch:
|
4606
|
+
patch.set_alpha(alpha)
|
4607
|
+
if "none" in edgecolor or 0 in linewidth:
|
4608
|
+
patch.set_edgecolor("none")
|
4609
|
+
return ax
|
4610
|
+
|
4611
|
+
|
4612
|
+
dict_data = {}
|
4613
|
+
for i_list, list_ in enumerate(lists):
|
4614
|
+
dict_data[labels[i_list]]={*list_}
|
4615
|
+
|
4616
|
+
if 3<len(lists)<6:
|
4617
|
+
from venn import venn as vn
|
4618
|
+
|
4619
|
+
legend_loc=kwargs.pop("legend_loc", "upper right")
|
4620
|
+
ax=vn(dict_data,ax=ax,legend_loc=legend_loc,**kwargs)
|
4621
|
+
|
4622
|
+
return ax
|
4386
4623
|
else:
|
4387
|
-
|
4388
|
-
|
4389
|
-
|
4390
|
-
|
4391
|
-
|
4392
|
-
|
4393
|
-
|
4394
|
-
patch.set_edgecolor("none")
|
4395
|
-
return ax
|
4624
|
+
from venn import pseudovenn
|
4625
|
+
cmap=kwargs.pop("cmap","plasma")
|
4626
|
+
ax=pseudovenn(dict_data, cmap=cmap,ax=ax,**kwargs)
|
4627
|
+
|
4628
|
+
return ax
|
4629
|
+
|
4630
|
+
|
4396
4631
|
|
4397
4632
|
|
4398
4633
|
#! subplots, support automatic extend new axis
|
@@ -4403,6 +4638,7 @@ def subplot(
|
|
4403
4638
|
sharex=False,
|
4404
4639
|
sharey=False,
|
4405
4640
|
verbose=False,
|
4641
|
+
fig=None,
|
4406
4642
|
**kwargs,
|
4407
4643
|
):
|
4408
4644
|
"""
|
@@ -4431,8 +4667,8 @@ def subplot(
|
|
4431
4667
|
)
|
4432
4668
|
|
4433
4669
|
figsize_recommend = f"subplot({rows}, {cols}, figsize={figsize})"
|
4434
|
-
|
4435
|
-
|
4670
|
+
if fig is None:
|
4671
|
+
fig = plt.figure(figsize=figsize, constrained_layout=True)
|
4436
4672
|
grid_spec = GridSpec(rows, cols, figure=fig)
|
4437
4673
|
occupied = set()
|
4438
4674
|
row_first_axes = [None] * rows # Track the first axis in each row (for sharey)
|
@@ -4496,18 +4732,19 @@ def subplot(
|
|
4496
4732
|
#! radar chart
|
4497
4733
|
def radar(
|
4498
4734
|
data: pd.DataFrame,
|
4499
|
-
|
4735
|
+
columns=None,
|
4500
4736
|
ylim=(0, 100),
|
4501
|
-
|
4737
|
+
facecolor=None,
|
4738
|
+
edgecolor="none",
|
4739
|
+
edge_linewidth=0.5,
|
4502
4740
|
fontsize=10,
|
4503
4741
|
fontcolor="k",
|
4504
4742
|
size=6,
|
4505
4743
|
linewidth=1,
|
4506
4744
|
linestyle="-",
|
4507
|
-
alpha=0.
|
4745
|
+
alpha=0.3,
|
4746
|
+
fmt=".1f",
|
4508
4747
|
marker="o",
|
4509
|
-
edgecolor="none",
|
4510
|
-
edge_linewidth=0.5,
|
4511
4748
|
bg_color="0.8",
|
4512
4749
|
bg_alpha=None,
|
4513
4750
|
grid_interval_ratio=0.2,
|
@@ -4527,19 +4764,34 @@ def radar(
|
|
4527
4764
|
ax=None,
|
4528
4765
|
sp=2,
|
4529
4766
|
verbose=True,
|
4767
|
+
axis=0,
|
4530
4768
|
**kwargs,
|
4531
4769
|
):
|
4532
4770
|
"""
|
4533
4771
|
Example DATA:
|
4534
4772
|
df = pd.DataFrame(
|
4535
|
-
|
4536
|
-
|
4537
|
-
|
4538
|
-
|
4539
|
-
|
4540
|
-
|
4541
|
-
|
4542
|
-
|
4773
|
+
data=[
|
4774
|
+
[80, 90, 60],
|
4775
|
+
[80, 20, 90],
|
4776
|
+
[80, 95, 20],
|
4777
|
+
[80, 95, 20],
|
4778
|
+
[80, 30, 100],
|
4779
|
+
[80, 30, 90],
|
4780
|
+
[80, 80, 50],
|
4781
|
+
],
|
4782
|
+
index=["HP", "MP", "ATK", "DEF", "SP.ATK", "SP.DEF", "SPD"],
|
4783
|
+
columns=["Hero", "Warrior", "Wizard"],
|
4784
|
+
)
|
4785
|
+
usage 1:
|
4786
|
+
radar(data=df)
|
4787
|
+
usage 2:
|
4788
|
+
radar(data=df["Wizard"])
|
4789
|
+
usage 3:
|
4790
|
+
radar(data=df, columns="Wizard")
|
4791
|
+
usage 4:
|
4792
|
+
nexttile = subplot(1, 2)
|
4793
|
+
radar(data=df, columns="Wizard", ax=nexttile(projection="polar"))
|
4794
|
+
pie(data=df, columns="Wizard", ax=nexttile(), width=0.5, pctdistance=0.7)
|
4543
4795
|
Parameters:
|
4544
4796
|
- data (pd.DataFrame): The data to plot. Each column corresponds to a variable, and each row represents a data point.
|
4545
4797
|
- ylim (tuple): The limits of the radial axis (y-axis). Default is (0, 100).
|
@@ -4556,7 +4808,6 @@ def radar(
|
|
4556
4808
|
- edge_linewidth (int): Line width for the marker edges.
|
4557
4809
|
- bg_color (str): Background color for the radar chart.
|
4558
4810
|
- grid_interval_ratio (float): Determines the intervals for the grid lines as a fraction of the y-limit.
|
4559
|
-
- title (str): The title of the radar chart.
|
4560
4811
|
- cmap (str): The colormap to use if `color` is a list.
|
4561
4812
|
- legend_loc (str): The location of the legend.
|
4562
4813
|
- legend_fontsize (int): Font size for the legend.
|
@@ -4573,22 +4824,22 @@ def radar(
|
|
4573
4824
|
- sp (int): Padding for the ticks from the plot area.
|
4574
4825
|
- **kwargs: Additional arguments for customization.
|
4575
4826
|
"""
|
4576
|
-
if run_once_within() and verbose:
|
4827
|
+
if run_once_within(20,reverse=True) and verbose:
|
4577
4828
|
usage_="""usage:
|
4578
4829
|
radar(
|
4579
4830
|
data: pd.DataFrame, #The data to plot. Each column corresponds to a variable, and each row represents a data point.
|
4580
|
-
title="Radar Chart",
|
4581
4831
|
ylim=(0, 100),# ylim (tuple): The limits of the radial axis (y-axis). Default is (0, 100).
|
4582
|
-
|
4832
|
+
facecolor=get_color(5),#The color(s) for the plot. Can be a single color or a list of colors.
|
4833
|
+
edgecolor="none",#for the marker edges.
|
4834
|
+
edge_linewidth=0.5,#for the marker edges.
|
4583
4835
|
fontsize=10,# Font size for the angular labels (x-axis).
|
4584
4836
|
fontcolor="k",# Color for the angular labels.
|
4585
4837
|
size=6,#The size of the markers for each data point.
|
4586
4838
|
linewidth=1,
|
4587
4839
|
linestyle="-",
|
4588
4840
|
alpha=0.5,#for the filled area.
|
4841
|
+
fmt=".1f",
|
4589
4842
|
marker="o",# for the data points.
|
4590
|
-
edgecolor="none",#for the marker edges.
|
4591
|
-
edge_linewidth=0.5,#for the marker edges.
|
4592
4843
|
bg_color="0.8",
|
4593
4844
|
bg_alpha=None,
|
4594
4845
|
grid_interval_ratio=0.2,#Determines the intervals for the grid lines as a fraction of the y-limit.
|
@@ -4618,9 +4869,25 @@ def radar(
|
|
4618
4869
|
kws_figsets = v_arg
|
4619
4870
|
kwargs.pop(k_arg, None)
|
4620
4871
|
break
|
4621
|
-
|
4872
|
+
if axis==1:
|
4873
|
+
data=data.T
|
4874
|
+
if isinstance(data, dict):
|
4875
|
+
data = pd.DataFrame(pd.Series(data))
|
4876
|
+
if ~isinstance(data, pd.DataFrame):
|
4877
|
+
data=pd.DataFrame(data)
|
4878
|
+
if isinstance(data, pd.DataFrame):
|
4879
|
+
data=data.select_dtypes(include=np.number)
|
4880
|
+
if isinstance(columns,str):
|
4881
|
+
columns=[columns]
|
4882
|
+
if columns is None:
|
4883
|
+
columns = list(data.columns)
|
4884
|
+
data=data[columns]
|
4885
|
+
categories = list(data.index)
|
4622
4886
|
num_vars = len(categories)
|
4623
4887
|
|
4888
|
+
# Set y-axis limits and grid intervals
|
4889
|
+
vmin, vmax = ylim
|
4890
|
+
|
4624
4891
|
# Set up angle for each category on radar chart
|
4625
4892
|
angles = np.linspace(0, 2 * np.pi, num_vars, endpoint=False).tolist()
|
4626
4893
|
angles += angles[:1] # Complete the loop to ensure straight-line connections
|
@@ -4644,9 +4911,6 @@ def radar(
|
|
4644
4911
|
# Draw one axis per variable and add labels
|
4645
4912
|
ax.set_xticks(angles[:-1])
|
4646
4913
|
ax.set_xticklabels(categories)
|
4647
|
-
|
4648
|
-
# Set y-axis limits and grid intervals
|
4649
|
-
vmin, vmax = ylim
|
4650
4914
|
if circular:
|
4651
4915
|
# * circular style
|
4652
4916
|
ax.yaxis.set_ticks(np.arange(vmin, vmax + 1, vmax * grid_interval_ratio))
|
@@ -4669,7 +4933,7 @@ def radar(
|
|
4669
4933
|
else:
|
4670
4934
|
# * spider style: spider-style grid (straight lines, not circles)
|
4671
4935
|
# Create the spider-style grid (straight lines, not circles)
|
4672
|
-
for i in range(1, int(vmax * grid_interval_ratio) + 1):
|
4936
|
+
for i in range(1, int((vmax-vmin)/ ((vmax-vmin)*grid_interval_ratio))+1):#int(vmax * grid_interval_ratio) + 1):
|
4673
4937
|
ax.plot(
|
4674
4938
|
angles + [angles[0]], # Closing the loop
|
4675
4939
|
[i * vmax * grid_interval_ratio] * (num_vars + 1)
|
@@ -4680,7 +4944,7 @@ def radar(
|
|
4680
4944
|
linewidth=grid_linewidth,
|
4681
4945
|
)
|
4682
4946
|
# set bg_color
|
4683
|
-
ax.fill(angles, [vmax] * (data.shape[
|
4947
|
+
ax.fill(angles, [vmax] * (data.shape[0] + 1), color=bg_color, alpha=bg_alpha)
|
4684
4948
|
ax.yaxis.grid(False)
|
4685
4949
|
# Move radial labels away from plotted line
|
4686
4950
|
if tick_loc is None:
|
@@ -4702,14 +4966,20 @@ def radar(
|
|
4702
4966
|
ax.tick_params(axis="x", pad=sp) # move spines outward
|
4703
4967
|
ax.tick_params(axis="y", pad=sp) # move spines outward
|
4704
4968
|
# colors
|
4705
|
-
|
4706
|
-
|
4707
|
-
|
4708
|
-
|
4709
|
-
|
4969
|
+
if facecolor is not None:
|
4970
|
+
if not isinstance(facecolor,list):
|
4971
|
+
facecolor=[facecolor]
|
4972
|
+
colors = facecolor
|
4973
|
+
else:
|
4974
|
+
colors = (
|
4975
|
+
get_color(data.shape[1])
|
4976
|
+
if cmap is None
|
4977
|
+
else plt.get_cmap(cmap)(np.linspace(0, 1, data.shape[1]))
|
4978
|
+
)
|
4979
|
+
|
4710
4980
|
# Plot each row with straight lines
|
4711
|
-
for i, (
|
4712
|
-
values =
|
4981
|
+
for i, (col, val) in enumerate(data.items()):
|
4982
|
+
values = val.tolist()
|
4713
4983
|
values += values[:1] # Close the loop
|
4714
4984
|
ax.plot(
|
4715
4985
|
angles,
|
@@ -4717,7 +4987,7 @@ def radar(
|
|
4717
4987
|
color=colors[i],
|
4718
4988
|
linewidth=linewidth,
|
4719
4989
|
linestyle=linestyle,
|
4720
|
-
label=
|
4990
|
+
label=col,
|
4721
4991
|
clip_on=False,
|
4722
4992
|
)
|
4723
4993
|
ax.fill(angles, values, color=colors[i], alpha=alpha)
|
@@ -4748,7 +5018,7 @@ def radar(
|
|
4748
5018
|
ax.text(
|
4749
5019
|
angle,
|
4750
5020
|
offset_radius,
|
4751
|
-
|
5021
|
+
f"{value:{fmt}}",
|
4752
5022
|
ha="center",
|
4753
5023
|
va="center",
|
4754
5024
|
fontsize=fontsize,
|
@@ -4759,10 +5029,10 @@ def radar(
|
|
4759
5029
|
|
4760
5030
|
ax.set_ylim(ylim)
|
4761
5031
|
# Add markers for each data point
|
4762
|
-
for i,
|
5032
|
+
for i, (col, val) in enumerate(data.items()):
|
4763
5033
|
ax.plot(
|
4764
5034
|
angles,
|
4765
|
-
list(
|
5035
|
+
list(val) + [val[0]], # Close the loop for markers
|
4766
5036
|
color=colors[i],
|
4767
5037
|
marker=marker,
|
4768
5038
|
markersize=size,
|
@@ -4787,3 +5057,819 @@ def radar(
|
|
4787
5057
|
**kws_figsets,
|
4788
5058
|
)
|
4789
5059
|
return ax
|
5060
|
+
|
5061
|
+
|
5062
|
+
def pie(
    data:pd.Series,
    columns:list = None,
    facecolor=None,
    explode=[0.1],
    startangle=90,
    shadow=True,
    fontcolor="k",
    fmt=".2f",
    width=None,# the center blank
    pctdistance=0.85,
    labeldistance=1.1,
    kws_wedge={},
    kws_text={},
    kws_arrow={},
    center=(0, 0),
    radius=1,
    frame=False,
    fontsize=10,
    edgecolor="white",
    edgewidth=1,
    cmap=None,
    show_value=False,
    show_label=True,# False: only show the outer layer, if it is None, not show
    expand_label=(1.2,1.2),
    kws_bbox={},#dict(facecolor="none", alpha=0.5, edgecolor="black", boxstyle="round,pad=0.3"), # '{}' to hide
    show_legend=True,
    legend_loc="upper right",
    bbox_to_anchor=[1.4, 1.1],
    legend_fontsize=10,
    rotation_correction=0,
    verbose=True,
    ax=None,
    **kwargs
):
    """Draw a pie chart, or nested rings when several columns are selected.

    Each selected numeric column becomes one ring; the first column is the
    outermost ring and every following ring starts where the previous ones
    end (outer radius minus the accumulated ring widths).

    Parameters:
        data: dict, Series, DataFrame or anything ``pd.DataFrame`` accepts.
            Only numeric columns are kept.
        columns: column name or list of names to plot; all numeric columns
            when None.
        facecolor: a color or list of colors. A list containing lists is
            treated as one color list per ring.
        explode: per-wedge offset(s); padded with zeros up to the number of
            rows. The caller's list is never modified.
        fmt: number format. A plain format spec (".2f") is turned into a
            percent pattern for ``autopct``; a "%"-pattern is passed through
            unchanged; a "{...}" template is applied with ``str.format``;
            a falsy value disables the percentage texts.
        width: ring thickness, scalar or one value per ring. Defaults to
            ``1 / (n_rings + 2)`` for several rings and ``1`` for one ring.
        kws_wedge, kws_text: extra wedge / text properties. Copies are used,
            so the dictionaries passed in are left untouched.
        kws_arrow, kws_bbox, expand_label: forwarded to ``adjust_text``.
        show_value: True additionally writes the raw values on the wedges;
            None suppresses the percentage texts.
        show_label: True labels every ring, False labels only the first
            (outer) ring, None labels nothing.
        show_legend, legend_loc, bbox_to_anchor, legend_fontsize: legend
            options; the legend is built from the last ring drawn.
        ax: target Axes; the current Axes when None.
        **kwargs: forwarded to ``Axes.pie``.

    Returns:
        The Axes the chart was drawn on.
    """
    from adjustText import adjust_text

    if run_once_within(20, reverse=True) and verbose:
        usage_ = """usage:
    pie(
        data:pd.Series,
        columns:list = None,
        facecolor=None,
        explode=[0.1],
        startangle=90,
        shadow=True,
        fontcolor="k",
        fmt=".2f",
        width=None,# the center blank
        pctdistance=0.85,
        labeldistance=1.1,
        kws_wedge={},
        kws_text={},
        center=(0, 0),
        radius=1,
        frame=False,
        fontsize=10,
        edgecolor="white",
        edgewidth=1,
        cmap=None,
        show_value=False,
        show_label=True,# False: only show the outer layer, if it is None, not show
        show_legend=True,
        legend_loc="upper right",
        bbox_to_anchor=[1.4, 1.1],
        legend_fontsize=10,
        rotation_correction=0,
        verbose=True,
        ax=None,
        **kwargs
    )

    usage 1:
    data = {"Segment A": 30, "Segment B": 50, "Segment C": 20}

    ax = pie(
        data=data,
        # columns="Segment A",
        explode=[0, 0.2, 0],
        # width=0.4,
        show_label=False,
        fontsize=10,
        # show_value=1,
        fmt=".3f",
    )

    # prepare dataset
    df = pd.DataFrame(
        data=[
                [80, 90, 60],
                [80, 20, 90],
                [80, 95, 20],
                [80, 95, 20],
                [80, 30, 100],
                [80, 30, 90],
                [80, 80, 50],
            ],
        index=["HP", "MP", "ATK", "DEF", "SP.ATK", "SP.DEF", "SPD"],
        columns=["Hero", "Warrior", "Wizard"],
    )
    usage 1: only plot one column
    pie(
        df,
        columns="Wizard",
        width=0.6,
        show_label=False,
        fmt=".0f",
    )
    usage 2:
    pie(df,columns=["Hero", "Warrior"],show_label=False)
    usage 3: set different width
    pie(df,
        columns=["Hero", "Warrior", "Wizard"],
        width=[0.3, 0.2, 0.2],
        show_label=False,
        fmt=".0f",
    )
    usage 4: set width the same for all columns
    pie(df,
        columns=["Hero", "Warrior", "Wizard"],
        width=0.2,
        show_label=False,
        fmt=".0f",
    )
    usage 5: adjust the labels' offset
    pie(df, columns="Wizard", width=0.6, show_label=False, fmt=".6f", labeldistance=1.2)

    usage 6:
    nexttile = subplot(1, 2)
    radar(data=df, columns="Wizard", ax=nexttile(projection="polar"))
    pie(data=df, columns="Wizard", ax=nexttile(), width=0.5, pctdistance=0.7)
    """
        print(usage_)

    # Normalise the input to a numeric-only DataFrame.
    if isinstance(data, dict):
        data = pd.DataFrame(pd.Series(data))
    # `not`, not `~`: bitwise NOT of a bool is -1/-2, i.e. always truthy.
    if not isinstance(data, pd.DataFrame):
        data = pd.DataFrame(data)
    data = data.select_dtypes(include=np.number)

    if isinstance(columns, str):
        columns = [columns]
    if columns is None:
        columns = list(data.columns)
    # Select the requested columns; each one becomes a ring.
    df = data[columns]
    n_rings = df.shape[1]

    # Work on a private copy so neither the shared default list nor a list
    # supplied by the caller is modified across calls.
    if isinstance(explode, (list, tuple, np.ndarray)):
        explode = list(explode)
    else:
        explode = [explode]
    if explode == [None]:
        explode = [0]
    if len(explode) < df.shape[0]:
        explode.extend([0] * (df.shape[0] - len(explode)))

    if width is None:
        width = 1 / (n_rings + 2) if n_rings > 1 else 1
    if isinstance(width, (float, int)):
        width = [width]
    if len(width) < n_rings:
        width = list(width) * n_rings

    # NOTE(review): the `radius` argument is accepted but has never been
    # applied; the outer radius is fixed at 1 and inner rings shrink by the
    # accumulated widths. Kept as-is to preserve existing figures.
    ring_radius = [1 if i == 0 else 1 - np.sum(width[:i]) for i in range(n_rings)]

    # colors
    if facecolor is not None:
        colors = facecolor if isinstance(facecolor, list) else [facecolor]
    elif cmap is None:
        colors = get_color(data.shape[0])
    else:
        colors = plt.get_cmap(cmap)(np.linspace(0, 1, data.shape[0]))
    # A list of lists means "one color list per ring".
    is_nested = any(isinstance(item, list) for item in colors)

    # Wedge / text properties are resolved once, on copies of the inputs.
    base_wedge = dict(kws_wedge) if kws_wedge else {"edgecolor": edgecolor, "linewidth": edgewidth}
    text_props = dict(kws_text)
    fontcolor = text_props.get("color", fontcolor)
    fontsize = text_props.get("fontsize", fontsize)
    text_props.update({"color": fontcolor, "fontsize": fontsize})

    # Always define autopct, whatever shape `fmt` has.
    if not fmt:
        autopct = None
    elif fmt.startswith("%"):
        autopct = fmt  # already a full %-pattern
    elif fmt.startswith("{"):
        autopct = fmt.format  # Axes.pie also accepts a callable
    else:
        autopct = f"%{fmt}%%"

    if ax is None:
        ax = plt.gca()

    wedges, labels_legend = None, None
    for ring, (column_, width_, radius_) in enumerate(zip(columns, width, ring_radius)):
        if column_ != columns[0]:
            labels = df.index if show_label else None
        else:
            # The first (outer) ring keeps its labels unless show_label is None.
            labels = df.index if show_label is not None else None
        series = df[column_]
        labels_legend = series.index
        sizes = series.values

        result = ax.pie(
            sizes,
            labels=labels,
            autopct=None if show_value is None else autopct,
            startangle=startangle + rotation_correction,
            explode=explode,
            colors=colors[ring] if is_nested else colors,
            shadow=shadow,
            pctdistance=pctdistance,
            labeldistance=labeldistance,
            wedgeprops={**base_wedge, "width": width_},
            textprops=text_props,
            center=center,
            radius=radius_,
            frame=frame,
            **kwargs
        )
        if len(result) == 3:
            wedges, texts, autotexts = result
        elif len(result) == 2:
            wedges, texts = result
            autotexts = None

        # Let adjustText separate overlapping labels / percentages.
        all_texts = []
        if autotexts and show_value:
            all_texts.extend(autotexts)
        if texts and show_label:
            all_texts.extend(texts)
        if all_texts:
            adjust_text(
                all_texts,
                ax=ax,
                arrowprops=kws_arrow,#dict(arrowstyle="-", color="gray", lw=0.5),
                bbox=kws_bbox if kws_bbox else None,
                expand=expand_label,
                fontdict={
                    "fontsize": fontsize,
                    "color": fontcolor,
                },
            )

        # Show exact values on wedges if show_value is True
        if show_value:
            for wedge, value in zip(wedges, sizes):
                angle = (wedge.theta2 - wedge.theta1) / 2 + wedge.theta1
                x = np.cos(np.radians(angle)) * pctdistance * radius_
                y = np.sin(np.radians(angle)) * pctdistance * radius_
                if not fmt:
                    value_text = str(value)
                elif fmt.startswith("{"):
                    value_text = fmt.format(value)
                elif fmt.startswith("%"):
                    value_text = fmt % value
                else:
                    value_text = f"{value:{fmt}}"
                ax.text(x, y, value_text, ha="center", va="center", fontsize=fontsize, color=fontcolor)

    # Customize the legend (built from the last ring that was drawn)
    if show_legend and wedges is not None:
        ax.legend(
            wedges,
            labels_legend,
            loc=legend_loc,
            bbox_to_anchor=bbox_to_anchor,
            fontsize=legend_fontsize,
            title_fontsize=legend_fontsize,
        )
    ax.set(aspect="equal")
    return ax
|
5358
|
+
|
5359
|
+
def ellipse(
    data,
    x=None,
    y=None,
    hue=None,
    n_std=1.5,
    ax=None,
    confidence=0.95,
    annotate_center=False,
    palette=None,
    facecolor=None,
    edgecolor=None,
    label:bool=True,
    **kwargs,
):
    """
    Plot covariance ellipses, one per group, on top of a 2-D scatter.

    # simulate data:
    control = np.random.multivariate_normal([0, 0], [[1, 0.5], [0.5, 1]], size=50)
    patient = np.random.multivariate_normal([2, 1], [[1, -0.3], [-0.3, 1]], size=50)
    df = pd.DataFrame(
        {
            "Dim1": np.concatenate([control[:, 0], patient[:, 0]]),
            "Dim2": np.concatenate([control[:, 1], patient[:, 1]]),
            "Group": ["Control"] * 50 + ["Patient"] * 50,
        }
    )
    plotxy(
        data=df,
        x="Dim1",
        y="Dim2",
        hue="Group",
        kind_="scatter",
        palette=get_color(8),
    )
    ellipse(
        data=df,
        x="Dim1",
        y="Dim2",
        hue="Group",
        palette=get_color(8),
        alpha=0.1,
        lw=2,
    )

    Parameters:
        data (DataFrame): Input DataFrame with columns for x, y, and hue.
        x (str): Column name for x-axis values.
        y (str): Column name for y-axis values.
        hue (str, optional): Column name for group labels.
        n_std (float): Number of standard deviations for the ellipse
            (overridden whenever `confidence` is truthy).
        ax (matplotlib.axes.Axes, optional): Axes to plot on. Defaults to current Axes.
        confidence (float, optional): Confidence level (e.g., 0.95); converted to a
            radius through the chi-square quantile with 2 degrees of freedom.
        annotate_center (bool): Whether to annotate the ellipse center (mean).
        palette (dict or list or str, optional): A mapping of hue values to colors,
            or anything `seaborn.color_palette` accepts. Defaults to "husl".
        facecolor, edgecolor (optional): Fixed fill / outline color for every
            ellipse; the group color is used when None.
        label (bool): Use the group value as the patch label (only with `hue`).
        **kwargs: Additional keyword arguments for the Ellipse patch. `alpha`
            (default 0.2) is applied to the face color only.

    Returns:
        matplotlib.axes.Axes: The Axes the ellipses were added to.

    Raises:
        ValueError: If `x`/`y` is missing or `data` is not a DataFrame.
    """
    import warnings

    from matplotlib.patches import Ellipse
    import numpy as np
    import matplotlib.pyplot as plt
    import seaborn as sns
    import pandas as pd
    from scipy.stats import chi2

    if ax is None:
        ax = plt.gca()

    # Validate inputs
    if x is None or y is None:
        raise ValueError(
            "Both `x` and `y` must be specified as column names in the DataFrame."
        )
    if not isinstance(data, pd.DataFrame):
        raise ValueError("`data` must be a pandas DataFrame.")

    # alpha is consumed here so it only affects the face, not the outline.
    alpha = kwargs.pop("alpha", 0.2)

    # Resolve one color per group.
    if hue is not None:
        groups = data[hue].unique()
        if isinstance(palette, dict):
            color_map = {group: palette[group] for group in groups}
        else:
            # `palette or "husl"` would fail on arrays (ambiguous truth value).
            use_default = palette is None or len(palette) == 0
            colors = sns.color_palette("husl" if use_default else palette, len(groups))
            color_map = dict(zip(groups, colors))
    else:
        groups = [None]
        color_map = {None: "blue" if edgecolor is None else edgecolor}

    # The radius multiplier does not depend on the group.
    if confidence:
        n_std = np.sqrt(chi2.ppf(confidence, df=2))  # Chi-square quantile

    for group in groups:
        group_data = data[data[hue] == group] if hue is not None else data
        group_points = group_data[[x, y]].values

        if len(group_points) < 2:
            # A covariance matrix needs at least two observations.
            warnings.warn(f"ellipse: group {group!r} has fewer than 2 points; skipped.")
            continue

        # Mean and covariance define center, orientation and axis lengths.
        cov = np.cov(group_points.T)
        mean = np.mean(group_points, axis=0)

        # Eigen-decomposition, largest eigenvalue first.
        eigvals, eigvecs = np.linalg.eigh(cov)
        order = eigvals.argsort()[::-1]
        eigvals, eigvecs = eigvals[order], eigvecs[:, order]

        # Rotation angle of the major axis and full axis lengths.
        angle = np.degrees(np.arctan2(*eigvecs[:, 0][::-1]))
        width, height = 2 * n_std * np.sqrt(eigvals)

        # Explicit colors win; otherwise fall back to the group color.
        facecolor_ = color_map[group] if facecolor is None else facecolor
        edgecolor_ = color_map[group] if edgecolor is None else edgecolor

        patch = Ellipse(
            xy=mean,
            width=width,
            height=height,
            angle=angle,
            edgecolor=edgecolor_,
            facecolor=(facecolor_, alpha),  # alpha applies to the face only
            label=group if (hue is not None and label) else None,
            **kwargs,
        )
        ax.add_patch(patch)

        # Annotate center
        if annotate_center:
            ax.annotate(
                f"Mean\n({mean[0]:.2f}, {mean[1]:.2f})",
                xy=mean,
                xycoords="data",
                fontsize=10,
                ha="center",
                color=edgecolor_,
                bbox=dict(
                    boxstyle="round,pad=0.3",
                    edgecolor="gray",
                    facecolor="white",
                    alpha=0.8,
                ),
            )

    return ax
|
5511
|
+
|
5512
|
+
def ppi(
    interactions,
    player1="preferredName_A",
    player2="preferredName_B",
    weight="score",
    n_layers=None, # Number of concentric layers
    n_rank=[5, 10], # Nodes in each rank for the concentric layout
    dist_node = 10, # Distance between each rank of circles
    layout="degree",
    size=None,#700,
    sizes=(50,500),# min and max of size
    facecolor="skyblue",
    cmap='coolwarm',
    edgecolor="k",
    edgelinewidth=1.5,
    alpha=.5,
    alphas=(0.1, 1.0),# min and max of alpha
    marker="o",
    node_hideticks=True,
    linecolor="gray",
    line_cmap='coolwarm',
    linewidth=1.5,
    linewidths=(0.5,5),# min and max of linewidth
    linealpha=1.0,
    linealphas=(0.1,1.0),# min and max of linealpha
    linestyle="-",
    line_arrowstyle='-',
    fontsize=10,
    fontcolor="k",
    ha:str="center",
    va:str="center",
    figsize=(12, 10),
    k_value=0.3,
    bgcolor="w",
    dir_save="./ppi_network.html",
    physics=True,
    notebook=False,
    scale=1,
    ax=None,
    **kwargs
):
    """
    Plot a Protein-Protein Interaction (PPI) network with adjustable appearance.

    The edge list is drawn with networkx/matplotlib and mirrored into a pyvis
    network; with `dir_save` set, an html, a graphml and a figure file are
    written next to each other.

    ppi(
        interactions_sort.iloc[:1000, :],
        player1="player1",
        player2="player2",
        weight="count",
        layout="spring",
        n_layers=13,
        fontsize=1,
        n_rank=[5, 10, 20, 40, 80, 80, 80, 80, 80, 80, 80, 80],
    )

    Parameters:
        interactions: DataFrame holding one edge per row. It is not modified.
        player1, player2, weight: column names of the two endpoints and the
            edge weight.
        n_layers, n_rank, dist_node: concentric ("degree") layout settings:
            number of rings, nodes per ring (highest degree first) and the
            radial distance between rings. Nodes that fit no ring are put on
            one extra outer ring.
        layout: one of spring, circular, kamada_kawai, spectral, random,
            shell, planar, spiral, degree (matched through `ips.strcmp`).
        size, sizes: node size (scalar or list) and the (min, max) range that
            varying sizes are rescaled into; degree-based when `size` is None.
        facecolor, cmap: node color; colored by degree when `facecolor` is
            not recognised as a color by `ips.isa`.
        alpha, alphas / linealpha, linealphas: node / edge opacity; lists are
            rescaled into the given (min, max) range, scalars are used as-is.
        linewidth, linewidths: edge width; a list is rescaled into
            `linewidths`, a scalar is used as-is.
        linecolor, line_cmap: edge color; colored by weight when `linecolor`
            is not a string.
        dir_save: html output path; falsy disables all file output.
        **kwargs: a key containing "figset" is forwarded to `figsets`.

    Returns:
        tuple: (networkx.Graph, matplotlib Axes)

    Raises:
        ValueError: If a required column is missing or there are no edges.
    """
    from pyvis.network import Network
    import networkx as nx
    from matplotlib.colors import Normalize
    from . import ips

    def _fit_length(values, n):
        # Truncate, or pad by repeating the last item, to exactly n entries.
        values = list(values)
        if not values or len(values) >= n:
            return values[:n]
        return values + [values[-1]] * (n - len(values))

    def _rescale(values, lo, hi):
        # Min-max rescale into [lo, hi]; constant input maps to the midpoint.
        v_min, v_max = min(values), max(values)
        if v_max > v_min:
            return [lo + (hi - lo) * (v - v_min) / (v_max - v_min) for v in values]
        return [(lo + hi) / 2] * len(values)

    def _ring(nodes, ring_radius):
        # Spread nodes evenly on a circle of the given radius.
        return {
            node: (
                ring_radius * np.cos((i / len(nodes)) * 2 * np.pi),
                ring_radius * np.sin((i / len(nodes)) * 2 * np.pi),
            )
            for i, node in enumerate(nodes)
        }

    if run_once_within():
        usage_str = """
        ppi(
            interactions,
            player1="preferredName_A",
            player2="preferredName_B",
            weight="score",
            n_layers=None, # Number of concentric layers
            n_rank=[5, 10], # Nodes in each rank for the concentric layout
            dist_node = 10, # Distance between each rank of circles
            layout="degree",
            size=None,#700,
            sizes=(50,500),# min and max of size
            facecolor="skyblue",
            cmap='coolwarm',
            edgecolor="k",
            edgelinewidth=1.5,
            alpha=.5,
            alphas=(0.1, 1.0),# min and max of alpha
            marker="o",
            node_hideticks=True,
            linecolor="gray",
            line_cmap='coolwarm',
            linewidth=1.5,
            linewidths=(0.5,5),# min and max of linewidth
            linealpha=1.0,
            linealphas=(0.1,1.0),# min and max of linealpha
            linestyle="-",
            line_arrowstyle='-',
            fontsize=10,
            fontcolor="k",
            ha:str="center",
            va:str="center",
            figsize=(12, 10),
            k_value=0.3,
            bgcolor="w",
            dir_save="./ppi_network.html",
            physics=True,
            notebook=False,
            scale=1,
            ax=None,
            **kwargs
        ):
        """
        print(usage_str)

    # Check for required columns in the DataFrame
    for col in [player1, player2, weight]:
        if col not in interactions.columns:
            raise ValueError(f"Column '{col}' is missing from the interactions DataFrame.")
    # Sort a copy: the caller's DataFrame must not be reordered as a side effect.
    interactions = interactions.sort_values(by=[weight])

    # Initialize Pyvis network
    net = Network(height="750px", width="100%", bgcolor=bgcolor, font_color=fontcolor)
    net.force_atlas_2based(
        gravity=-50, central_gravity=0.01, spring_length=100, spring_strength=0.1
    )
    net.toggle_physics(physics)

    # Pull the optional figsets dictionary out of kwargs.
    figset_key = next((key for key in kwargs if "figset" in key), None)
    kws_figsets = kwargs.pop(figset_key) if figset_key is not None else {}

    # Create a NetworkX graph from the interaction data
    G = nx.Graph()
    for _, row in interactions.iterrows():
        G.add_edge(row[player1], row[player2], weight=row[weight])
    num_nodes = G.number_of_nodes()
    n_edges = G.number_of_edges()
    if num_nodes == 0:
        raise ValueError("`interactions` contains no rows; there is no network to draw.")

    # Node degrees drive the default color and size.
    degrees = dict(G.degree())
    norm = Normalize(vmin=min(degrees.values()), vmax=max(degrees.values()))
    # plt.get_cmap instead of matplotlib.cm.get_cmap (removed in matplotlib 3.9).
    colormap = plt.get_cmap(cmap)

    if not ips.isa(facecolor, 'color'):
        print("facecolor: based on degrees")
        facecolor = [colormap(norm(deg)) for deg in degrees.values()]

    #* size
    if not isinstance(size, (int, float, list)):
        print("size: based on degrees")
        size = [deg * 50 for deg in degrees.values()]
    size = _fit_length(size, num_nodes) if isinstance(size, list) else [size] * num_nodes
    if len(ips.flatten(size, verbose=False)) != 1:
        size = _rescale(size, *sizes)

    #* facecolor
    if isinstance(facecolor, list):
        facecolor = _fit_length(facecolor, num_nodes)
    else:
        facecolor = [facecolor] * num_nodes

    # * facealpha
    if isinstance(alpha, list):
        alpha = _fit_length(alpha, num_nodes)
        # An empty list falls back to full opacity.
        alpha = _rescale(alpha, *alphas) if alpha else [1.0] * num_nodes
    else:
        alpha = [alpha] * num_nodes

    for i, node in enumerate(G.nodes()):
        net.add_node(
            node,
            label=node,
            size=size[i],
            color=facecolor[i],
            alpha=alpha[i],
            font={"size": fontsize, "color": fontcolor},
        )
    print(f'nodes number: {num_nodes}')

    for edge in G.edges(data=True):
        net.add_edge(
            edge[0],
            edge[1],
            weight=edge[2]["weight"],
            color=edgecolor,
            width=edgelinewidth * edge[2]["weight"],
        )

    layouts = [
        "spring",
        "circular",
        "kamada_kawai",
        "spectral",
        "random",
        "shell",
        "planar",
        "spiral",
        "degree"
    ]
    layout = ips.strcmp(layout, layouts)[0]
    print(f"layout:{layout}, or select one in {layouts}")

    # Choose layout
    if layout == "spring":
        pos = nx.spring_layout(G, k=k_value)
    elif layout == "circular":
        pos = nx.circular_layout(G)
    elif layout == "kamada_kawai":
        pos = nx.kamada_kawai_layout(G)
    elif layout == "spectral":
        pos = nx.spectral_layout(G)
    elif layout == "random":
        pos = nx.random_layout(G)
    elif layout == "shell":
        pos = nx.shell_layout(G)
    elif layout == "planar":
        if nx.check_planarity(G)[0]:
            pos = nx.planar_layout(G)
        else:
            print("Graph is not planar; switching to spring layout.")
            pos = nx.spring_layout(G, k=k_value)
    elif layout == "spiral":
        pos = nx.spiral_layout(G)
    elif layout == 'degree':
        # Concentric rings, highest-degree nodes on the innermost ring.
        sorted_nodes = [
            node
            for node, _ in sorted(degrees.items(), key=lambda item: item[1], reverse=True)
        ]
        pos = {}
        n_layers = len(n_rank) + 1 if n_layers is None else n_layers
        for rank_index in range(n_layers):
            if rank_index < len(n_rank):
                rank_nodes = sorted_nodes[sum(n_rank[:rank_index]): sum(n_rank[:rank_index + 1])]
            else:
                # Shuffle the nodes that are left over after the ranked rings.
                remaining_nodes = sorted_nodes[sum(n_rank[:rank_index]):]
                random_indices = np.random.permutation(len(remaining_nodes))
                rank_nodes = [remaining_nodes[i] for i in random_indices]
            pos.update(_ring(rank_nodes, (rank_index + 1) * dist_node))
        # With too few layers some nodes would have no position and drawing
        # would fail; put them on one extra outer ring instead.
        unplaced = [node for node in sorted_nodes if node not in pos]
        if unplaced:
            pos.update(_ring(unplaced, (n_layers + 1) * dist_node))
    else:
        print(f"Unknown layout '{layout}', defaulting to 'spring',or可以用这些: {layouts}")
        pos = nx.spring_layout(G, k=k_value)

    # Mirror the computed positions into the pyvis network.
    for node, (x, y) in pos.items():
        net.get_node(node)["x"] = x * scale
        net.get_node(node)["y"] = y * scale

    # Create a new figure when no Axes is supplied
    if ax is None:
        fig, ax = plt.subplots(1, 1, figsize=figsize)

    # Draw nodes, edges, and labels with customization options
    nx.draw_networkx_nodes(
        G,
        pos,
        ax=ax,
        node_size=size,
        node_color=facecolor,
        linewidths=edgelinewidth,
        edgecolors=edgecolor,
        alpha=alpha,
        hide_ticks=node_hideticks,
        node_shape=marker
    )

    #* linewidth: a list is rescaled into `linewidths`, a scalar is used as-is
    if isinstance(linewidth, list):
        linewidth = _rescale(_fit_length(linewidth, n_edges), *linewidths)
    else:
        linewidth = [linewidth] * n_edges

    #* linecolor
    if not isinstance(linecolor, str):
        weights = [G[u][v]["weight"] for u, v in G.edges()]
        weight_norm = Normalize(vmin=min(weights), vmax=max(weights))
        line_colormap = plt.get_cmap(line_cmap)
        linecolor = [line_colormap(weight_norm(w)) for w in weights]
    else:
        linecolor = [linecolor] * n_edges

    # * linealpha
    if isinstance(linealpha, list):
        linealpha = _fit_length(linealpha, n_edges)
        # An empty (invalid) list falls back to full opacity.
        linealpha = _rescale(linealpha, *linealphas) if linealpha else [1.0] * n_edges
    else:
        linealpha = [linealpha] * n_edges

    nx.draw_networkx_edges(
        G,
        pos,
        ax=ax,
        edge_color=linecolor,
        width=linewidth,
        style=linestyle,
        arrowstyle=line_arrowstyle,
        alpha=linealpha
    )

    nx.draw_networkx_labels(
        G, pos, ax=ax, font_size=fontsize, font_color=fontcolor, horizontalalignment=ha, verticalalignment=va
    )
    figsets(ax=ax, **kws_figsets)
    ax.axis("off")
    if dir_save:
        if not os.path.basename(dir_save):
            dir_save = "_.html"
        # Derive sibling paths from the stem so they can never collide with
        # the html file, whatever extension `dir_save` carries.
        stem = os.path.splitext(dir_save)[0]
        path_graphml = stem + ".graphml"
        net.write_html(dir_save)
        nx.write_graphml(G, path_graphml)  # Export to GraphML
        # Built outside the f-string: reusing the quote character inside a
        # replacement field is a SyntaxError before Python 3.12.
        print(f"could be edited in Cytoscape \n{path_graphml}")
        ips.figsave(stem + ".pdf")
    return G, ax
|