py2ls 0.2.4.25__py3-none-any.whl → 0.2.4.26__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- py2ls/.DS_Store +0 -0
- py2ls/.git/index +0 -0
- py2ls/corr.py +475 -0
- py2ls/data/.DS_Store +0 -0
- py2ls/data/hyper_param_autogluon_zeroshot2024.json +2383 -0
- py2ls/data/styles/.DS_Store +0 -0
- py2ls/data/styles/example/.DS_Store +0 -0
- py2ls/data/usages_sns.json +6 -1
- py2ls/ips.py +399 -91
- py2ls/ml2ls.py +758 -186
- py2ls/netfinder.py +16 -20
- py2ls/plot.py +916 -141
- {py2ls-0.2.4.25.dist-info → py2ls-0.2.4.26.dist-info}/METADATA +5 -1
- {py2ls-0.2.4.25.dist-info → py2ls-0.2.4.26.dist-info}/RECORD +15 -13
- py2ls/data/usages_pd copy.json +0 -1105
- {py2ls-0.2.4.25.dist-info → py2ls-0.2.4.26.dist-info}/WHEEL +0 -0
py2ls/plot.py
CHANGED
@@ -20,7 +20,9 @@ from .ips import (
|
|
20
20
|
flatten,
|
21
21
|
plt_font,
|
22
22
|
run_once_within,
|
23
|
-
df_format
|
23
|
+
df_format,
|
24
|
+
df_corr,
|
25
|
+
df_scaler
|
24
26
|
)
|
25
27
|
import scipy.stats as scipy_stats
|
26
28
|
from .stats import *
|
@@ -142,20 +144,37 @@ def add_text(ax=None, height_offset=0.5, fmt=".1f", **kwargs):
|
|
142
144
|
**kwargs,
|
143
145
|
)
|
144
146
|
|
147
|
+
def pval2str(p):
    """Convert a p-value into a significance-star annotation string.

    Thresholds (same as before):
        p > 0.05            -> ""     (not significant)
        0.01 <= p <= 0.05   -> "*"
        0.001 <= p < 0.01   -> "**"
        p < 0.001           -> "***"

    Args:
        p: p-value (a float-like number). ``None`` or NaN yields "" (no
            annotation) instead of raising.

    Returns:
        str: the star string for the given p-value.
    """
    import math

    # NaN fails every comparison below. In the previous if/elif chain this
    # left the result variable unbound and raised UnboundLocalError, so
    # missing or undefined p-values are treated as "no annotation".
    if p is None or math.isnan(p):
        return ""
    if p > 0.05:
        return ""
    if p >= 0.01:
        return "*"
    if p >= 0.001:
        return "**"
    return "***"
|
145
157
|
|
146
158
|
def heatmap(
|
147
159
|
data,
|
148
160
|
ax=None,
|
149
161
|
kind="corr", #'corr','direct','pivot'
|
162
|
+
method="pearson",# for correlation: ‘pearson’(default), ‘kendall’, ‘spearman’
|
150
163
|
columns="all", # pivot, default: coll numeric columns
|
164
|
+
style=0,# for correlation
|
151
165
|
index=None, # pivot
|
152
166
|
values=None, # pivot
|
167
|
+
fontsize=10,
|
153
168
|
tri="u",
|
154
169
|
mask=True,
|
155
170
|
k=1,
|
171
|
+
vmin=None,
|
172
|
+
vmax=None,
|
173
|
+
size_scale=500,
|
156
174
|
annot=True,
|
157
175
|
cmap="coolwarm",
|
158
176
|
fmt=".2f",
|
177
|
+
show_indicator = True,# only for style==1
|
159
178
|
cluster=False,
|
160
179
|
inplace=False,
|
161
180
|
figsize=(10, 8),
|
@@ -201,7 +220,7 @@ def heatmap(
|
|
201
220
|
ax = plt.gca()
|
202
221
|
# Select numeric columns or specific subset of columns
|
203
222
|
if columns == "all":
|
204
|
-
df_numeric = data.select_dtypes(include=[
|
223
|
+
df_numeric = data.select_dtypes(include=[np.number])
|
205
224
|
else:
|
206
225
|
df_numeric = data[columns]
|
207
226
|
|
@@ -209,8 +228,10 @@ def heatmap(
|
|
209
228
|
kind = strcmp(kind, kinds)[0]
|
210
229
|
print(kind)
|
211
230
|
if "corr" in kind: # correlation
|
231
|
+
methods = ["pearson", "spearman", "kendall"]
|
232
|
+
method = strcmp(method, methods)[0]
|
212
233
|
# Compute the correlation matrix
|
213
|
-
data4heatmap = df_numeric.corr()
|
234
|
+
data4heatmap = df_numeric.corr(method=method)
|
214
235
|
# Generate mask for the upper triangle if mask is True
|
215
236
|
if mask:
|
216
237
|
if "u" in tri.lower(): # upper => np.tril
|
@@ -288,18 +309,76 @@ def heatmap(
|
|
288
309
|
df_col_cluster,
|
289
310
|
)
|
290
311
|
else:
|
291
|
-
|
292
|
-
|
293
|
-
|
294
|
-
|
295
|
-
|
296
|
-
|
297
|
-
|
298
|
-
|
299
|
-
|
300
|
-
|
301
|
-
|
302
|
-
|
312
|
+
if style==0:
|
313
|
+
# Create a standard heatmap
|
314
|
+
ax = sns.heatmap(
|
315
|
+
data4heatmap,
|
316
|
+
ax=ax,
|
317
|
+
mask=mask_array,
|
318
|
+
annot=annot,
|
319
|
+
cmap=cmap,
|
320
|
+
fmt=fmt,
|
321
|
+
**kwargs, # Pass any additional arguments to sns.heatmap
|
322
|
+
)
|
323
|
+
return ax
|
324
|
+
elif style==1:
|
325
|
+
if isinstance(cmap, str):
|
326
|
+
cmap = plt.get_cmap(cmap)
|
327
|
+
norm = plt.Normalize(vmin=-1, vmax=1)
|
328
|
+
r_, p_ = df_corr(data4heatmap, method=method)
|
329
|
+
# size_r_norm=df_scaler(data=r_, method="minmax", vmin=-1,vmax=1)
|
330
|
+
# 初始化一个空的可绘制对象用于颜色条
|
331
|
+
scatter_handles = []
|
332
|
+
# 循环绘制气泡图和数值
|
333
|
+
for i in range(len(r_.columns)):
|
334
|
+
for j in range(len(r_.columns)):
|
335
|
+
if (i < j) if "u" in tri.lower() else (j<i): # 对角线左上部只显示气泡
|
336
|
+
color = cmap(norm(r_.iloc[i, j])) # 根据相关系数获取颜色
|
337
|
+
scatter = ax.scatter(
|
338
|
+
i, j, s=np.abs(r_.iloc[i, j])*size_scale, color=color,
|
339
|
+
# alpha=1,edgecolor=edgecolor,linewidth=linewidth,
|
340
|
+
**kwargs
|
341
|
+
)
|
342
|
+
scatter_handles.append(scatter) # 保存scatter对象用于颜色条
|
343
|
+
# add *** indicators
|
344
|
+
if show_indicator:
|
345
|
+
ax.text(
|
346
|
+
i,
|
347
|
+
j,
|
348
|
+
pval2str(p_.iloc[i, j]),
|
349
|
+
ha="center",
|
350
|
+
va="center",
|
351
|
+
color="k",
|
352
|
+
fontsize=fontsize * 1.3,
|
353
|
+
)
|
354
|
+
elif (i > j) if "u" in tri.lower() else (j>i): # 对角只显示数值
|
355
|
+
color = cmap(norm(r_.iloc[i, j])) # 数值的颜色同样基于相关系数
|
356
|
+
ax.text(
|
357
|
+
i,
|
358
|
+
j,
|
359
|
+
f"{r_.iloc[i, j]:{fmt}}",
|
360
|
+
ha="center",
|
361
|
+
va="center",
|
362
|
+
color=color,
|
363
|
+
fontsize=fontsize,
|
364
|
+
)
|
365
|
+
else: # 对角线部分,显示空白
|
366
|
+
ax.scatter(i, j, s=1, color="white")
|
367
|
+
# 设置坐标轴标签
|
368
|
+
figsets(xticks=range(len(r_.columns)),
|
369
|
+
xticklabels=r_.columns,
|
370
|
+
xangle=90,
|
371
|
+
fontsize=fontsize,
|
372
|
+
yticks=range(len(r_.columns)),
|
373
|
+
yticklabels=r_.columns,
|
374
|
+
xlim=[-0.5,len(r_.columns)-0.5],
|
375
|
+
ylim=[-0.5,len(r_.columns)-0.5]
|
376
|
+
)
|
377
|
+
# 添加颜色条
|
378
|
+
sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
|
379
|
+
sm.set_array([]) # 仅用于显示颜色条
|
380
|
+
plt.colorbar(sm, ax=ax, label="Correlation Coefficient")
|
381
|
+
return ax
|
303
382
|
elif "dir" in kind: # direct
|
304
383
|
data4heatmap = df_numeric
|
305
384
|
elif "pi" in kind: # pivot
|
@@ -389,16 +468,53 @@ def heatmap(
|
|
389
468
|
)
|
390
469
|
else:
|
391
470
|
# Create a standard heatmap
|
392
|
-
|
393
|
-
|
394
|
-
|
395
|
-
|
396
|
-
|
397
|
-
|
398
|
-
|
399
|
-
|
400
|
-
|
401
|
-
|
471
|
+
if style==0:
|
472
|
+
ax = sns.heatmap(
|
473
|
+
data4heatmap,
|
474
|
+
ax=ax,
|
475
|
+
annot=annot,
|
476
|
+
cmap=cmap,
|
477
|
+
fmt=fmt,
|
478
|
+
**kwargs, # Pass any additional arguments to sns.heatmap
|
479
|
+
)
|
480
|
+
# Return the Axes object for further customization if needed
|
481
|
+
return ax
|
482
|
+
elif style==1:
|
483
|
+
if isinstance(cmap, str):
|
484
|
+
cmap = plt.get_cmap(cmap)
|
485
|
+
if vmin is None:
|
486
|
+
vmin=np.min(data4heatmap)
|
487
|
+
if vmax is None:
|
488
|
+
vmax=np.max(data4heatmap)
|
489
|
+
norm = plt.Normalize(vmin=vmin, vmax=vmax)
|
490
|
+
|
491
|
+
# 初始化一个空的可绘制对象用于颜色条
|
492
|
+
scatter_handles = []
|
493
|
+
# 循环绘制气泡图和数值
|
494
|
+
print(len(data4heatmap.index),len(data4heatmap.columns))
|
495
|
+
for i in range(len(data4heatmap.index)):
|
496
|
+
for j in range(len(data4heatmap.columns)):
|
497
|
+
color = cmap(norm(data4heatmap.iloc[i, j])) # 根据相关系数获取颜色
|
498
|
+
scatter = ax.scatter(j,i, s=np.abs(data4heatmap.iloc[i, j])*size_scale, color=color, **kwargs)
|
499
|
+
scatter_handles.append(scatter) # 保存scatter对象用于颜色条
|
500
|
+
|
501
|
+
# 设置坐标轴标签
|
502
|
+
figsets(xticks=range(len(data4heatmap.columns)),
|
503
|
+
xticklabels=data4heatmap.columns,
|
504
|
+
xangle=90,
|
505
|
+
fontsize=fontsize,
|
506
|
+
yticks=range(len(data4heatmap.index)),
|
507
|
+
yticklabels=data4heatmap.index,
|
508
|
+
xlim=[-0.5,len(data4heatmap.columns)-0.5],
|
509
|
+
ylim=[-0.5,len(data4heatmap.index)-0.5]
|
510
|
+
)
|
511
|
+
# 添加颜色条
|
512
|
+
sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
|
513
|
+
sm.set_array([]) # 仅用于显示颜色条
|
514
|
+
plt.colorbar(sm, ax=ax,
|
515
|
+
# label="Correlation Coefficient"
|
516
|
+
)
|
517
|
+
return ax
|
402
518
|
|
403
519
|
|
404
520
|
# !usage: py2ls.plot.heatmap()
|
@@ -3226,18 +3342,37 @@ def plotxy(
|
|
3226
3342
|
# (1) return FcetGrid
|
3227
3343
|
if k == "jointplot":
|
3228
3344
|
kws_joint = kwargs.pop("kws_joint", kwargs)
|
3229
|
-
|
3345
|
+
kws_joint = {
|
3346
|
+
k: v for k, v in kws_joint.items() if not k.startswith("kws_")
|
3347
|
+
}
|
3348
|
+
hue = kwargs.get("hue", None)
|
3349
|
+
if isinstance(kws_joint, dict) or hue is None: # Check if kws_ellipse is a dictionary
|
3350
|
+
kws_joint.pop("hue", None) # Safely remove 'hue' if it exists
|
3351
|
+
|
3352
|
+
palette = kwargs.get("palette", None)
|
3353
|
+
if palette is None:
|
3354
|
+
palette = kws_joint.pop(
|
3355
|
+
"palette", get_color(data[hue].nunique()) if hue is not None else None
|
3356
|
+
)
|
3357
|
+
else:
|
3358
|
+
kws_joint.pop("palette", palette)
|
3359
|
+
stats=kwargs.pop("stats",None)
|
3360
|
+
if stats:
|
3361
|
+
stats=kws_joint.pop("stats",True)
|
3230
3362
|
if stats:
|
3231
3363
|
r, p_value = scipy_stats.pearsonr(data[x], data[y])
|
3232
|
-
|
3233
|
-
|
3234
|
-
|
3235
|
-
|
3236
|
-
|
3237
|
-
|
3238
|
-
|
3239
|
-
|
3240
|
-
|
3364
|
+
for key in ["palette", "alpha", "hue","stats"]:
|
3365
|
+
kws_joint.pop(key, None)
|
3366
|
+
g = sns.jointplot(data=data, x=x, y=y,hue=hue,palette=palette, **kws_joint)
|
3367
|
+
if stats:
|
3368
|
+
g.ax_joint.annotate(
|
3369
|
+
f"pearsonr = {r:.2f} p = {p_value:.3f}",
|
3370
|
+
xy=(0.6, 0.98),
|
3371
|
+
xycoords="axes fraction",
|
3372
|
+
fontsize=12,
|
3373
|
+
color="black",
|
3374
|
+
ha="center",
|
3375
|
+
)
|
3241
3376
|
elif k == "lmplot":
|
3242
3377
|
kws_lm = kwargs.pop("kws_lm", kwargs)
|
3243
3378
|
stats = kwargs.pop("stats", True) # Flag to calculate stats
|
@@ -3383,8 +3518,7 @@ def plotxy(
|
|
3383
3518
|
xycoords="axes fraction",
|
3384
3519
|
fontsize=12,
|
3385
3520
|
color="black",
|
3386
|
-
ha="center"
|
3387
|
-
)
|
3521
|
+
ha="center")
|
3388
3522
|
|
3389
3523
|
elif k == "catplot_sns":
|
3390
3524
|
kws_cat = kwargs.pop("kws_cat", kwargs)
|
@@ -3392,7 +3526,7 @@ def plotxy(
|
|
3392
3526
|
elif k == "displot":
|
3393
3527
|
kws_dis = kwargs.pop("kws_dis", kwargs)
|
3394
3528
|
# displot creates a new figure and returns a FacetGrid
|
3395
|
-
g = sns.displot(data=data, x=x,
|
3529
|
+
g = sns.displot(data=data, x=x,y=y, **kws_dis)
|
3396
3530
|
|
3397
3531
|
# (2) return axis
|
3398
3532
|
if ax is None:
|
@@ -3403,19 +3537,58 @@ def plotxy(
|
|
3403
3537
|
elif k == "stdshade":
|
3404
3538
|
kws_stdshade = kwargs.pop("kws_stdshade", kwargs)
|
3405
3539
|
ax = stdshade(ax=ax, **kwargs)
|
3540
|
+
elif k=="ellipse":
|
3541
|
+
kws_ellipse = kwargs.pop("kws_ellipse", kwargs)
|
3542
|
+
kws_ellipse = {
|
3543
|
+
k: v for k, v in kws_ellipse.items() if not k.startswith("kws_")
|
3544
|
+
}
|
3545
|
+
hue = kwargs.get("hue", None)
|
3546
|
+
if isinstance(kws_ellipse, dict) or hue is None: # Check if kws_ellipse is a dictionary
|
3547
|
+
kws_ellipse.pop("hue", None) # Safely remove 'hue' if it exists
|
3548
|
+
|
3549
|
+
palette = kwargs.get("palette", None)
|
3550
|
+
if palette is None:
|
3551
|
+
palette = kws_ellipse.pop(
|
3552
|
+
"palette", get_color(data[hue].nunique()) if hue is not None else None
|
3553
|
+
)
|
3554
|
+
alpha = kws_ellipse.pop("alpha", 0.1)
|
3555
|
+
hue_order = kwargs.get("hue_order",None)
|
3556
|
+
if hue_order is None:
|
3557
|
+
hue_order = kws_ellipse.get("hue_order",None)
|
3558
|
+
if hue_order:
|
3559
|
+
data["hue"] = pd.Categorical(data[hue], categories=hue_order, ordered=True)
|
3560
|
+
data = data.sort_values(by="hue")
|
3561
|
+
for key in ["palette", "alpha", "hue","hue_order"]:
|
3562
|
+
kws_ellipse.pop(key, None)
|
3563
|
+
ax=ellipse(
|
3564
|
+
ax=ax,
|
3565
|
+
data=data,
|
3566
|
+
x=x,
|
3567
|
+
y=y,
|
3568
|
+
hue=hue,
|
3569
|
+
palette=palette,
|
3570
|
+
alpha=alpha,
|
3571
|
+
zorder=zorder,
|
3572
|
+
**kws_ellipse,
|
3573
|
+
)
|
3406
3574
|
elif k == "scatterplot":
|
3407
3575
|
kws_scatter = kwargs.pop("kws_scatter", kwargs)
|
3408
3576
|
kws_scatter = {
|
3409
3577
|
k: v for k, v in kws_scatter.items() if not k.startswith("kws_")
|
3410
3578
|
}
|
3411
|
-
hue = kwargs.
|
3579
|
+
hue = kwargs.get("hue", None)
|
3412
3580
|
if isinstance(kws_scatter, dict): # Check if kws_scatter is a dictionary
|
3413
3581
|
kws_scatter.pop("hue", None) # Safely remove 'hue' if it exists
|
3414
|
-
palette
|
3415
|
-
|
3416
|
-
|
3582
|
+
palette=kws_scatter.get("palette",None)
|
3583
|
+
if palette is None:
|
3584
|
+
palette = kws_scatter.pop(
|
3585
|
+
"palette", get_color(data[hue].nunique()) if hue is not None else None
|
3586
|
+
)
|
3417
3587
|
s = kws_scatter.pop("s", 10)
|
3418
|
-
alpha = kws_scatter.pop("alpha", 0.7)
|
3588
|
+
alpha = kws_scatter.pop("alpha", 0.7)
|
3589
|
+
for key in ["s", "palette", "alpha", "hue"]:
|
3590
|
+
kws_scatter.pop(key, None)
|
3591
|
+
|
3419
3592
|
ax = sns.scatterplot(
|
3420
3593
|
ax=ax,
|
3421
3594
|
data=data,
|
@@ -3431,19 +3604,33 @@ def plotxy(
|
|
3431
3604
|
elif k == "histplot":
|
3432
3605
|
kws_hist = kwargs.pop("kws_hist", kwargs)
|
3433
3606
|
kws_hist = {k: v for k, v in kws_hist.items() if not k.startswith("kws_")}
|
3434
|
-
ax = sns.histplot(data=data, x=x, ax=ax, zorder=zorder, **kws_hist)
|
3435
|
-
elif k == "kdeplot":
|
3607
|
+
ax = sns.histplot(data=data, x=x, y=y, ax=ax, zorder=zorder, **kws_hist)
|
3608
|
+
elif k == "kdeplot":
|
3436
3609
|
kws_kde = kwargs.pop("kws_kde", kwargs)
|
3437
|
-
kws_kde = {
|
3438
|
-
|
3610
|
+
kws_kde = {
|
3611
|
+
k: v for k, v in kws_kde.items() if not k.startswith("kws_")
|
3612
|
+
}
|
3613
|
+
hue = kwargs.get("hue", None)
|
3614
|
+
if isinstance(kws_kde, dict) or hue is None: # Check if kws_kde is a dictionary
|
3615
|
+
kws_kde.pop("hue", None) # Safely remove 'hue' if it exists
|
3616
|
+
|
3617
|
+
palette = kwargs.get("palette", None)
|
3618
|
+
if palette is None:
|
3619
|
+
palette = kws_kde.pop(
|
3620
|
+
"palette", get_color(data[hue].nunique()) if hue is not None else None
|
3621
|
+
)
|
3622
|
+
alpha = kws_kde.pop("alpha", 0.05)
|
3623
|
+
for key in ["palette", "alpha", "hue"]:
|
3624
|
+
kws_kde.pop(key, None)
|
3625
|
+
ax = sns.kdeplot(data=data, x=x, y=y, palette=palette,hue=hue, ax=ax,alpha=alpha, zorder=zorder, **kws_kde)
|
3439
3626
|
elif k == "ecdfplot":
|
3440
3627
|
kws_ecdf = kwargs.pop("kws_ecdf", kwargs)
|
3441
3628
|
kws_ecdf = {k: v for k, v in kws_ecdf.items() if not k.startswith("kws_")}
|
3442
|
-
ax = sns.ecdfplot(data=data, x=x,
|
3629
|
+
ax = sns.ecdfplot(data=data, x=x,y=y, ax=ax, zorder=zorder, **kws_ecdf)
|
3443
3630
|
elif k == "rugplot":
|
3444
3631
|
kws_rug = kwargs.pop("kws_rug", kwargs)
|
3445
3632
|
kws_rug = {k: v for k, v in kws_rug.items() if not k.startswith("kws_")}
|
3446
|
-
ax = sns.rugplot(data=data, x=x, ax=ax, zorder=zorder, **kws_rug)
|
3633
|
+
ax = sns.rugplot(data=data, x=x, y=y, ax=ax, zorder=zorder, **kws_rug)
|
3447
3634
|
elif k == "stripplot":
|
3448
3635
|
kws_strip = kwargs.pop("kws_strip", kwargs)
|
3449
3636
|
kws_strip = {k: v for k, v in kws_strip.items() if not k.startswith("kws_")}
|
@@ -3516,7 +3703,8 @@ def plotxy(
|
|
3516
3703
|
figsets(ax=ax, **kws_figsets)
|
3517
3704
|
if kws_add_text:
|
3518
3705
|
add_text(ax=ax, **kws_add_text) if kws_add_text else None
|
3519
|
-
|
3706
|
+
if run_once_within(10):
|
3707
|
+
for k in kind_:
|
3520
3708
|
print(f"\n{k}⤵ ")
|
3521
3709
|
print(default_settings[k])
|
3522
3710
|
# print("=>\t",sns_info[sns_info["Functions"].str.contains(k)].iloc[:, -1].tolist()[0],"\n")
|
@@ -3574,6 +3762,7 @@ def df_preprocessing_(data, kind, verbose=False):
|
|
3574
3762
|
"lineplot", # Can work with both wide and long formats
|
3575
3763
|
"area plot", # Can work with both formats, useful for stacked areas
|
3576
3764
|
"violinplot", # Can work with both formats depending on categorical vs continuous data
|
3765
|
+
"ellipse",# ellipse plot, default confidence=0.95
|
3577
3766
|
],
|
3578
3767
|
)[0]
|
3579
3768
|
|
@@ -3609,6 +3798,7 @@ def df_preprocessing_(data, kind, verbose=False):
|
|
3609
3798
|
"violinplot", # Can work with both formats depending on categorical vs continuous data
|
3610
3799
|
"relplot",
|
3611
3800
|
"pointplot", # Works well with wide format
|
3801
|
+
"ellipse",
|
3612
3802
|
]
|
3613
3803
|
|
3614
3804
|
# Wide format (e.g., for heatmap and pairplot)
|
@@ -4100,25 +4290,45 @@ def venn(
|
|
4100
4290
|
"""
|
4101
4291
|
if ax is None:
|
4102
4292
|
ax = plt.gca()
|
4293
|
+
if isinstance(lists, dict):
|
4294
|
+
labels = list(lists.keys())
|
4295
|
+
lists = list(lists.values())
|
4296
|
+
if isinstance(lists[0], set):
|
4297
|
+
lists = [list(i) for i in lists]
|
4103
4298
|
lists = [set(flatten(i, verbose=False)) for i in lists]
|
4104
4299
|
# Function to apply text styles to labels
|
4105
4300
|
if colors is None:
|
4106
4301
|
colors = ["r", "b"] if len(lists) == 2 else ["r", "g", "b"]
|
4107
4302
|
if labels is None:
|
4108
|
-
|
4303
|
+
if len(lists) == 2:
|
4304
|
+
labels = ["set1", "set2"]
|
4305
|
+
elif len(lists) == 3:
|
4306
|
+
labels = ["set1", "set2", "set3"]
|
4307
|
+
elif len(lists) == 4:
|
4308
|
+
labels = ["set1", "set2", "set3","set4"]
|
4309
|
+
elif len(lists) == 5:
|
4310
|
+
labels = ["set1", "set2", "set3","set4","set55"]
|
4311
|
+
elif len(lists) == 6:
|
4312
|
+
labels = ["set1", "set2", "set3","set4","set5","set6"]
|
4313
|
+
elif len(lists) == 7:
|
4314
|
+
labels = ["set1", "set2", "set3","set4","set5","set6","set7"]
|
4109
4315
|
if edgecolor is None:
|
4110
4316
|
edgecolor = colors
|
4111
4317
|
colors = [desaturate_color(color, saturation) for color in colors]
|
4112
|
-
|
4113
|
-
if len(lists) == 2:
|
4318
|
+
universe = len(set.union(*lists))
|
4114
4319
|
|
4115
|
-
|
4116
|
-
|
4117
|
-
|
4118
|
-
|
4119
|
-
|
4120
|
-
|
4121
|
-
|
4320
|
+
# Check colors and auto-calculate overlaps
|
4321
|
+
def get_count_and_percentage(set_count, subset_count):
|
4322
|
+
percent = subset_count / set_count if set_count > 0 else 0
|
4323
|
+
return (
|
4324
|
+
f"{subset_count}\n({fmt.format(percent)})"
|
4325
|
+
if show_percentages
|
4326
|
+
else f"{subset_count}"
|
4327
|
+
)
|
4328
|
+
if fmt is not None:
|
4329
|
+
if not fmt.startswith("{"):
|
4330
|
+
fmt="{:" + fmt + "}"
|
4331
|
+
if len(lists) == 2:
|
4122
4332
|
|
4123
4333
|
from matplotlib_venn import venn2, venn2_circles
|
4124
4334
|
|
@@ -4131,21 +4341,28 @@ def venn(
|
|
4131
4341
|
set1, set2 = lists[0], lists[1]
|
4132
4342
|
v.get_patch_by_id("10").set_color(colors[0])
|
4133
4343
|
v.get_patch_by_id("01").set_color(colors[1])
|
4134
|
-
|
4135
|
-
|
4136
|
-
|
4344
|
+
try:
|
4345
|
+
v.get_patch_by_id("11").set_color(
|
4346
|
+
get_color_overlap(colors[0], colors[1]) if colors else None
|
4347
|
+
)
|
4348
|
+
except Exception as e:
|
4349
|
+
print(e)
|
4137
4350
|
# v.get_label_by_id('10').set_text(len(set1 - set2))
|
4138
4351
|
# v.get_label_by_id('01').set_text(len(set2 - set1))
|
4139
4352
|
# v.get_label_by_id('11').set_text(len(set1 & set2))
|
4353
|
+
|
4140
4354
|
v.get_label_by_id("10").set_text(
|
4141
|
-
get_count_and_percentage(
|
4355
|
+
get_count_and_percentage(universe, len(set1 - set2))
|
4142
4356
|
)
|
4143
4357
|
v.get_label_by_id("01").set_text(
|
4144
|
-
get_count_and_percentage(
|
4145
|
-
)
|
4146
|
-
v.get_label_by_id("11").set_text(
|
4147
|
-
get_count_and_percentage(len(set1 | set2), len(set1 & set2))
|
4358
|
+
get_count_and_percentage(universe, len(set2 - set1))
|
4148
4359
|
)
|
4360
|
+
try:
|
4361
|
+
v.get_label_by_id("11").set_text(
|
4362
|
+
get_count_and_percentage(universe, len(set1 & set2))
|
4363
|
+
)
|
4364
|
+
except Exception as e:
|
4365
|
+
print(e)
|
4149
4366
|
|
4150
4367
|
if not isinstance(linewidth, list):
|
4151
4368
|
linewidth = [linewidth]
|
@@ -4228,16 +4445,14 @@ def venn(
|
|
4228
4445
|
va=va,
|
4229
4446
|
shadow=shadow,
|
4230
4447
|
)
|
4231
|
-
|
4232
|
-
|
4233
|
-
|
4234
|
-
|
4235
|
-
|
4236
|
-
|
4237
|
-
|
4238
|
-
|
4239
|
-
else f"{subset_count}"
|
4240
|
-
)
|
4448
|
+
# Set transparency level
|
4449
|
+
for patch in v.patches:
|
4450
|
+
if patch:
|
4451
|
+
patch.set_alpha(alpha)
|
4452
|
+
if "none" in edgecolor or 0 in linewidth:
|
4453
|
+
patch.set_edgecolor("none")
|
4454
|
+
return ax
|
4455
|
+
elif len(lists) == 3:
|
4241
4456
|
|
4242
4457
|
from matplotlib_venn import venn3, venn3_circles
|
4243
4458
|
|
@@ -4253,36 +4468,34 @@ def venn(
|
|
4253
4468
|
# Draw the venn diagram
|
4254
4469
|
v = venn3(subsets=lists, set_labels=labels, ax=ax, **kwargs)
|
4255
4470
|
v.get_patch_by_id("100").set_color(colors[0])
|
4471
|
+
v.get_label_by_id("100").set_text(get_count_and_percentage(universe, len(set1 - set2 - set3)))
|
4256
4472
|
v.get_patch_by_id("010").set_color(colors[1])
|
4257
|
-
v.
|
4258
|
-
|
4259
|
-
|
4260
|
-
|
4261
|
-
|
4262
|
-
|
4263
|
-
|
4264
|
-
|
4265
|
-
|
4266
|
-
|
4267
|
-
|
4268
|
-
|
4269
|
-
|
4270
|
-
|
4271
|
-
|
4272
|
-
|
4273
|
-
|
4274
|
-
|
4275
|
-
|
4276
|
-
|
4277
|
-
|
4278
|
-
|
4279
|
-
|
4280
|
-
|
4281
|
-
|
4282
|
-
|
4283
|
-
v.get_label_by_id("111").set_text(
|
4284
|
-
get_label(len(set1 | set2 | set3), len(set1 & set2 & set3))
|
4285
|
-
)
|
4473
|
+
v.get_label_by_id("010").set_text(get_count_and_percentage(universe, len(set2 - set1 - set3)))
|
4474
|
+
try:
|
4475
|
+
v.get_patch_by_id("001").set_color(colors[2])
|
4476
|
+
v.get_label_by_id("001").set_text(get_count_and_percentage(universe, len(set3 - set1 - set2)))
|
4477
|
+
except Exception as e:
|
4478
|
+
print(e)
|
4479
|
+
try:
|
4480
|
+
v.get_patch_by_id("110").set_color(colorAB)
|
4481
|
+
v.get_label_by_id("110").set_text(get_count_and_percentage(universe, len(set1 & set2 - set3)))
|
4482
|
+
except Exception as e:
|
4483
|
+
print(e)
|
4484
|
+
try:
|
4485
|
+
v.get_patch_by_id("101").set_color(colorAC)
|
4486
|
+
v.get_label_by_id("101").set_text(get_count_and_percentage(universe, len(set1 & set3 - set2)))
|
4487
|
+
except Exception as e:
|
4488
|
+
print(e)
|
4489
|
+
try:
|
4490
|
+
v.get_patch_by_id("011").set_color(colorBC)
|
4491
|
+
v.get_label_by_id("011").set_text(get_count_and_percentage(universe, len(set2 & set3 - set1)))
|
4492
|
+
except Exception as e:
|
4493
|
+
print(e)
|
4494
|
+
try:
|
4495
|
+
v.get_patch_by_id("111").set_color(colorABC)
|
4496
|
+
v.get_label_by_id("111").set_text(get_count_and_percentage(universe, len(set1 & set2 & set3)))
|
4497
|
+
except Exception as e:
|
4498
|
+
print(e)
|
4286
4499
|
|
4287
4500
|
# Apply styles to set labels
|
4288
4501
|
for i, text in enumerate(v.set_labels):
|
@@ -4387,16 +4600,34 @@ def venn(
|
|
4387
4600
|
ax.add_patch(ellipse1)
|
4388
4601
|
ax.add_patch(ellipse2)
|
4389
4602
|
ax.add_patch(ellipse3)
|
4603
|
+
# Set transparency level
|
4604
|
+
for patch in v.patches:
|
4605
|
+
if patch:
|
4606
|
+
patch.set_alpha(alpha)
|
4607
|
+
if "none" in edgecolor or 0 in linewidth:
|
4608
|
+
patch.set_edgecolor("none")
|
4609
|
+
return ax
|
4610
|
+
|
4611
|
+
|
4612
|
+
dict_data = {}
|
4613
|
+
for i_list, list_ in enumerate(lists):
|
4614
|
+
dict_data[labels[i_list]]={*list_}
|
4615
|
+
|
4616
|
+
if 3<len(lists)<6:
|
4617
|
+
from venn import venn as vn
|
4618
|
+
|
4619
|
+
legend_loc=kwargs.pop("legend_loc", "upper right")
|
4620
|
+
ax=vn(dict_data,ax=ax,legend_loc=legend_loc,**kwargs)
|
4621
|
+
|
4622
|
+
return ax
|
4390
4623
|
else:
|
4391
|
-
|
4392
|
-
|
4393
|
-
|
4394
|
-
|
4395
|
-
|
4396
|
-
|
4397
|
-
|
4398
|
-
patch.set_edgecolor("none")
|
4399
|
-
return ax
|
4624
|
+
from venn import pseudovenn
|
4625
|
+
cmap=kwargs.pop("cmap","plasma")
|
4626
|
+
ax=pseudovenn(dict_data, cmap=cmap,ax=ax,**kwargs)
|
4627
|
+
|
4628
|
+
return ax
|
4629
|
+
|
4630
|
+
|
4400
4631
|
|
4401
4632
|
|
4402
4633
|
#! subplots, support automatic extend new axis
|
@@ -4407,6 +4638,7 @@ def subplot(
|
|
4407
4638
|
sharex=False,
|
4408
4639
|
sharey=False,
|
4409
4640
|
verbose=False,
|
4641
|
+
fig=None,
|
4410
4642
|
**kwargs,
|
4411
4643
|
):
|
4412
4644
|
"""
|
@@ -4435,8 +4667,8 @@ def subplot(
|
|
4435
4667
|
)
|
4436
4668
|
|
4437
4669
|
figsize_recommend = f"subplot({rows}, {cols}, figsize={figsize})"
|
4438
|
-
|
4439
|
-
|
4670
|
+
if fig is None:
|
4671
|
+
fig = plt.figure(figsize=figsize, constrained_layout=True)
|
4440
4672
|
grid_spec = GridSpec(rows, cols, figure=fig)
|
4441
4673
|
occupied = set()
|
4442
4674
|
row_first_axes = [None] * rows # Track the first axis in each row (for sharey)
|
@@ -4502,7 +4734,7 @@ def radar(
|
|
4502
4734
|
data: pd.DataFrame,
|
4503
4735
|
columns=None,
|
4504
4736
|
ylim=(0, 100),
|
4505
|
-
facecolor=
|
4737
|
+
facecolor=None,
|
4506
4738
|
edgecolor="none",
|
4507
4739
|
edge_linewidth=0.5,
|
4508
4740
|
fontsize=10,
|
@@ -4701,7 +4933,7 @@ def radar(
|
|
4701
4933
|
else:
|
4702
4934
|
# * spider style: spider-style grid (straight lines, not circles)
|
4703
4935
|
# Create the spider-style grid (straight lines, not circles)
|
4704
|
-
for i in range(1, int((vmax-vmin)/ ((vmax-vmin)*grid_interval_ratio))):#int(vmax * grid_interval_ratio) + 1):
|
4936
|
+
for i in range(1, int((vmax-vmin)/ ((vmax-vmin)*grid_interval_ratio))+1):#int(vmax * grid_interval_ratio) + 1):
|
4705
4937
|
ax.plot(
|
4706
4938
|
angles + [angles[0]], # Closing the loop
|
4707
4939
|
[i * vmax * grid_interval_ratio] * (num_vars + 1)
|
@@ -4740,9 +4972,9 @@ def radar(
|
|
4740
4972
|
colors = facecolor
|
4741
4973
|
else:
|
4742
4974
|
colors = (
|
4743
|
-
get_color(data.shape[
|
4975
|
+
get_color(data.shape[1])
|
4744
4976
|
if cmap is None
|
4745
|
-
else plt.get_cmap(cmap)(np.linspace(0, 1, data.shape[
|
4977
|
+
else plt.get_cmap(cmap)(np.linspace(0, 1, data.shape[1]))
|
4746
4978
|
)
|
4747
4979
|
|
4748
4980
|
# Plot each row with straight lines
|
@@ -4840,7 +5072,8 @@ def pie(
|
|
4840
5072
|
pctdistance=0.85,
|
4841
5073
|
labeldistance=1.1,
|
4842
5074
|
kws_wedge={},
|
4843
|
-
kws_text={},
|
5075
|
+
kws_text={},
|
5076
|
+
kws_arrow={},
|
4844
5077
|
center=(0, 0),
|
4845
5078
|
radius=1,
|
4846
5079
|
frame=False,
|
@@ -4850,6 +5083,8 @@ def pie(
|
|
4850
5083
|
cmap=None,
|
4851
5084
|
show_value=False,
|
4852
5085
|
show_label=True,# False: only show the outer layer, if it is None, not show
|
5086
|
+
expand_label=(1.2,1.2),
|
5087
|
+
kws_bbox={},#dict(facecolor="none", alpha=0.5, edgecolor="black", boxstyle="round,pad=0.3"), # '{}' to hide
|
4853
5088
|
show_legend=True,
|
4854
5089
|
legend_loc="upper right",
|
4855
5090
|
bbox_to_anchor=[1.4, 1.1],
|
@@ -4859,7 +5094,7 @@ def pie(
|
|
4859
5094
|
ax=None,
|
4860
5095
|
**kwargs
|
4861
5096
|
):
|
4862
|
-
|
5097
|
+
from adjustText import adjust_text
|
4863
5098
|
if run_once_within(20,reverse=True) and verbose:
|
4864
5099
|
usage_="""usage:
|
4865
5100
|
pie(
|
@@ -4974,6 +5209,11 @@ def pie(
|
|
4974
5209
|
# 选择部分数据
|
4975
5210
|
df=data[columns]
|
4976
5211
|
|
5212
|
+
if not isinstance(explode, list):
|
5213
|
+
explode=[explode]
|
5214
|
+
if explode == [None]:
|
5215
|
+
explode=[0]
|
5216
|
+
|
4977
5217
|
if width is None:
|
4978
5218
|
if df.shape[1]>1:
|
4979
5219
|
width=1/(df.shape[1]+2)
|
@@ -5022,10 +5262,9 @@ def pie(
|
|
5022
5262
|
|
5023
5263
|
if ax is None:
|
5024
5264
|
ax=plt.gca()
|
5025
|
-
if explode
|
5026
|
-
|
5027
|
-
|
5028
|
-
|
5265
|
+
if len(explode)<len(labels_legend):
|
5266
|
+
explode.extend([0]*(len(labels_legend)-len(explode)))
|
5267
|
+
print(explode)
|
5029
5268
|
if fmt:
|
5030
5269
|
if not fmt.startswith("%"):
|
5031
5270
|
autopct =f"%{fmt}%%"
|
@@ -5073,19 +5312,37 @@ def pie(
|
|
5073
5312
|
elif len(result) == 2:
|
5074
5313
|
wedges, texts = result
|
5075
5314
|
autotexts = None
|
5076
|
-
|
5077
|
-
|
5078
|
-
|
5079
|
-
|
5080
|
-
|
5081
|
-
|
5082
|
-
|
5083
|
-
|
5084
|
-
|
5085
|
-
|
5086
|
-
|
5087
|
-
|
5088
|
-
|
5315
|
+
#! adjust_text
|
5316
|
+
if autotexts or texts:
|
5317
|
+
all_texts = []
|
5318
|
+
if autotexts and show_value:
|
5319
|
+
all_texts.extend(autotexts)
|
5320
|
+
if texts and show_label:
|
5321
|
+
all_texts.extend(texts)
|
5322
|
+
|
5323
|
+
adjust_text(
|
5324
|
+
all_texts,
|
5325
|
+
ax=ax,
|
5326
|
+
arrowprops=kws_arrow,#dict(arrowstyle="-", color="gray", lw=0.5),
|
5327
|
+
bbox=kws_bbox if kws_bbox else None,
|
5328
|
+
expand=expand_label,
|
5329
|
+
fontdict={
|
5330
|
+
"fontsize": fontsize,
|
5331
|
+
"color": fontcolor,
|
5332
|
+
},
|
5333
|
+
)
|
5334
|
+
# Show exact values on wedges if show_value is True
|
5335
|
+
if show_value:
|
5336
|
+
for i, (wedge, txt) in enumerate(zip(wedges, texts)):
|
5337
|
+
angle = (wedge.theta2 - wedge.theta1) / 2 + wedge.theta1
|
5338
|
+
x = np.cos(np.radians(angle)) * (pctdistance ) * radius_
|
5339
|
+
y = np.sin(np.radians(angle)) * (pctdistance ) * radius_
|
5340
|
+
if not fmt.startswith("{"):
|
5341
|
+
value_text = f"{sizes[i]:{fmt}}"
|
5342
|
+
else:
|
5343
|
+
value_text = fmt.format(sizes[i])
|
5344
|
+
ax.text(x, y, value_text, ha="center", va="center", fontsize=fontsize,color=fontcolor)
|
5345
|
+
inested+=1
|
5089
5346
|
# Customize the legend
|
5090
5347
|
if show_legend:
|
5091
5348
|
ax.legend(
|
@@ -5098,3 +5355,521 @@ def pie(
|
|
5098
5355
|
)
|
5099
5356
|
ax.set(aspect="equal")
|
5100
5357
|
return ax
|
5358
|
+
|
5359
|
+
def ellipse(
|
5360
|
+
data,
|
5361
|
+
x=None,
|
5362
|
+
y=None,
|
5363
|
+
hue=None,
|
5364
|
+
n_std=1.5,
|
5365
|
+
ax=None,
|
5366
|
+
confidence=0.95,
|
5367
|
+
annotate_center=False,
|
5368
|
+
palette=None,
|
5369
|
+
facecolor=None,
|
5370
|
+
edgecolor=None,
|
5371
|
+
label:bool=True,
|
5372
|
+
**kwargs,
|
5373
|
+
):
|
5374
|
+
"""
|
5375
|
+
Plot advanced ellipses representing covariance for different groups
|
5376
|
+
# simulate data:
|
5377
|
+
control = np.random.multivariate_normal([0, 0], [[1, 0.5], [0.5, 1]], size=50)
|
5378
|
+
patient = np.random.multivariate_normal([2, 1], [[1, -0.3], [-0.3, 1]], size=50)
|
5379
|
+
df = pd.DataFrame(
|
5380
|
+
{
|
5381
|
+
"Dim1": np.concatenate([control[:, 0], patient[:, 0]]),
|
5382
|
+
"Dim2": np.concatenate([control[:, 1], patient[:, 1]]),
|
5383
|
+
"Group": ["Control"] * 50 + ["Patient"] * 50,
|
5384
|
+
}
|
5385
|
+
)
|
5386
|
+
plotxy(
|
5387
|
+
data=df,
|
5388
|
+
x="Dim1",
|
5389
|
+
y="Dim2",
|
5390
|
+
hue="Group",
|
5391
|
+
kind_="scatter",
|
5392
|
+
palette=get_color(8),
|
5393
|
+
)
|
5394
|
+
ellipse(
|
5395
|
+
data=df,
|
5396
|
+
x="Dim1",
|
5397
|
+
y="Dim2",
|
5398
|
+
hue="Group",
|
5399
|
+
palette=get_color(8),
|
5400
|
+
alpha=0.1,
|
5401
|
+
lw=2,
|
5402
|
+
)
|
5403
|
+
Parameters:
|
5404
|
+
data (DataFrame): Input DataFrame with columns for x, y, and hue.
|
5405
|
+
x (str): Column name for x-axis values.
|
5406
|
+
y (str): Column name for y-axis values.
|
5407
|
+
hue (str, optional): Column name for group labels.
|
5408
|
+
n_std (float): Number of standard deviations for the ellipse (overridden if confidence is provided).
|
5409
|
+
ax (matplotlib.axes.Axes, optional): Matplotlib Axes object to plot on. Defaults to current Axes.
|
5410
|
+
confidence (float, optional): Confidence level (e.g., 0.95 for 95% confidence interval).
|
5411
|
+
annotate_center (bool): Whether to annotate the ellipse center (mean).
|
5412
|
+
palette (dict or list, optional): A mapping of hues to colors or a list of colors.
|
5413
|
+
**kwargs: Additional keyword arguments for the Ellipse patch.
|
5414
|
+
|
5415
|
+
Returns:
|
5416
|
+
list: List of Ellipse objects added to the Axes.
|
5417
|
+
"""
|
5418
|
+
from matplotlib.patches import Ellipse
|
5419
|
+
import numpy as np
|
5420
|
+
import matplotlib.pyplot as plt
|
5421
|
+
import seaborn as sns
|
5422
|
+
import pandas as pd
|
5423
|
+
from scipy.stats import chi2
|
5424
|
+
|
5425
|
+
if ax is None:
|
5426
|
+
ax = plt.gca()
|
5427
|
+
|
5428
|
+
# Validate inputs
|
5429
|
+
if x is None or y is None:
|
5430
|
+
raise ValueError(
|
5431
|
+
"Both `x` and `y` must be specified as column names in the DataFrame."
|
5432
|
+
)
|
5433
|
+
if not isinstance(data, pd.DataFrame):
|
5434
|
+
raise ValueError("`data` must be a pandas DataFrame.")
|
5435
|
+
|
5436
|
+
# Prepare data for hue-based grouping
|
5437
|
+
ellipses = []
|
5438
|
+
if hue is not None:
|
5439
|
+
groups = data[hue].unique()
|
5440
|
+
colors = sns.color_palette(palette or "husl", len(groups))
|
5441
|
+
color_map = dict(zip(groups, colors))
|
5442
|
+
else:
|
5443
|
+
groups = [None]
|
5444
|
+
color_map = {None: kwargs.get("edgecolor", "blue")}
|
5445
|
+
alpha = kwargs.pop("alpha", 0.2)
|
5446
|
+
edgecolor=kwargs.pop("edgecolor", None)
|
5447
|
+
facecolor=kwargs.pop("facecolor", None)
|
5448
|
+
for group in groups:
|
5449
|
+
group_data = data[data[hue] == group] if hue else data
|
5450
|
+
|
5451
|
+
# Extract x and y columns for the group
|
5452
|
+
group_points = group_data[[x, y]].values
|
5453
|
+
|
5454
|
+
# Compute mean and covariance matrix
|
5455
|
+
# # 标准化处理
|
5456
|
+
# group_points = group_data[[x, y]].values
|
5457
|
+
# group_points -= group_points.mean(axis=0)
|
5458
|
+
# group_points /= group_points.std(axis=0)
|
5459
|
+
|
5460
|
+
cov = np.cov(group_points.T)
|
5461
|
+
mean = np.mean(group_points, axis=0)
|
5462
|
+
|
5463
|
+
# Eigenvalues and eigenvectors
|
5464
|
+
eigvals, eigvecs = np.linalg.eigh(cov)
|
5465
|
+
order = eigvals.argsort()[::-1]
|
5466
|
+
eigvals, eigvecs = eigvals[order], eigvecs[:, order]
|
5467
|
+
|
5468
|
+
# Rotation angle and ellipse dimensions
|
5469
|
+
angle = np.degrees(np.arctan2(*eigvecs[:, 0][::-1]))
|
5470
|
+
if confidence:
|
5471
|
+
n_std = np.sqrt(chi2.ppf(confidence, df=2)) # Chi-square quantile
|
5472
|
+
width, height = 2 * n_std * np.sqrt(eigvals)
|
5473
|
+
|
5474
|
+
# Create and style the ellipse
|
5475
|
+
if facecolor is None:
|
5476
|
+
facecolor_ = color_map[group]
|
5477
|
+
if edgecolor is None:
|
5478
|
+
edgecolor_ = color_map[group]
|
5479
|
+
ellipse = Ellipse(
|
5480
|
+
xy=mean,
|
5481
|
+
width=width,
|
5482
|
+
height=height,
|
5483
|
+
angle=angle,
|
5484
|
+
edgecolor=edgecolor_,
|
5485
|
+
facecolor=(facecolor_, alpha), #facecolor_, # only work on facecolor
|
5486
|
+
# alpha=alpha,
|
5487
|
+
label=group if (hue and label) else None,
|
5488
|
+
**kwargs,
|
5489
|
+
)
|
5490
|
+
ax.add_patch(ellipse)
|
5491
|
+
ellipses.append(ellipse)
|
5492
|
+
|
5493
|
+
# Annotate center
|
5494
|
+
if annotate_center:
|
5495
|
+
ax.annotate(
|
5496
|
+
f"Mean\n({mean[0]:.2f}, {mean[1]:.2f})",
|
5497
|
+
xy=mean,
|
5498
|
+
xycoords="data",
|
5499
|
+
fontsize=10,
|
5500
|
+
ha="center",
|
5501
|
+
color=ellipse_color,
|
5502
|
+
bbox=dict(
|
5503
|
+
boxstyle="round,pad=0.3",
|
5504
|
+
edgecolor="gray",
|
5505
|
+
facecolor="white",
|
5506
|
+
alpha=0.8,
|
5507
|
+
),
|
5508
|
+
)
|
5509
|
+
|
5510
|
+
return ax
|
5511
|
+
|
5512
|
+
def _fit_length(values, n, name):
    """Truncate ``values`` to ``n`` items, or pad it by repeating its last item."""
    if not values:
        raise ValueError(f"`{name}` must not be an empty list.")
    if len(values) >= n:
        return list(values[:n])
    return list(values) + [values[-1]] * (n - len(values))


def _rescale(values, lo, hi):
    """Linearly map ``values`` onto ``[lo, hi]``; a constant input maps to the midpoint."""
    v_min, v_max = min(values), max(values)
    if v_max > v_min:
        return [lo + (hi - lo) * (v - v_min) / (v_max - v_min) for v in values]
    return [(lo + hi) / 2] * len(values)


def _as_html_color(color):
    """Return ``color`` as a ``#rrggbb`` string for pyvis, or unchanged if Matplotlib cannot parse it."""
    from matplotlib.colors import to_hex

    try:
        return to_hex(color)
    except (ValueError, TypeError):
        return color


def _concentric_degree_layout(G, n_rank, n_layers, dist_node):
    """Place nodes on concentric circles, highest-degree nodes on the innermost ring."""
    import numpy as np

    sorted_nodes = sorted(
        dict(G.degree()).items(), key=lambda item: item[1], reverse=True
    )
    n_layers = len(n_rank) + 1 if n_layers is None else max(int(n_layers), 1)
    pos = {}
    for rank_index in range(n_layers):
        start = sum(n_rank[:rank_index])
        if rank_index == n_layers - 1:
            # The outermost ring collects every node not placed yet, so no node
            # is ever left without a position (drawing would fail otherwise).
            rank_nodes = sorted_nodes[start:]
            if rank_index >= len(n_rank):
                # Nodes beyond the explicit ranks are shuffled around the ring.
                rank_nodes = [
                    rank_nodes[j] for j in np.random.permutation(len(rank_nodes))
                ]
        elif rank_index < len(n_rank):
            rank_nodes = sorted_nodes[start:start + n_rank[rank_index]]
        else:
            # Intermediate rings beyond ``n_rank`` stay empty; the remainder is
            # placed on the outermost ring.
            continue

        radius = (rank_index + 1) * dist_node
        for i, (node, _degree) in enumerate(rank_nodes):
            angle = (i / len(rank_nodes)) * 2 * np.pi  # spread evenly on the ring
            pos[node] = (radius * np.cos(angle), radius * np.sin(angle))
    return pos


def ppi(
    interactions,
    player1="preferredName_A",
    player2="preferredName_B",
    weight="score",
    n_layers=None,  # Number of concentric layers
    n_rank=[5, 10],  # Nodes in each rank for the concentric layout
    dist_node=10,  # Distance between each rank of circles
    layout="degree",
    size=None,  # 700,
    sizes=(50, 500),  # min and max of size
    facecolor="skyblue",
    cmap='coolwarm',
    edgecolor="k",
    edgelinewidth=1.5,
    alpha=.5,
    alphas=(0.1, 1.0),  # min and max of alpha
    marker="o",
    node_hideticks=True,
    linecolor="gray",
    line_cmap='coolwarm',
    linewidth=1.5,
    linewidths=(0.5, 5),  # min and max of linewidth
    linealpha=1.0,
    linealphas=(0.1, 1.0),  # min and max of linealpha
    linestyle="-",
    line_arrowstyle='-',
    fontsize=10,
    fontcolor="k",
    ha: str = "center",
    va: str = "center",
    figsize=(12, 10),
    k_value=0.3,
    bgcolor="w",
    dir_save="./ppi_network.html",
    physics=True,
    notebook=False,
    scale=1,
    ax=None,
    **kwargs
):
    """
    Plot a Protein-Protein Interaction (PPI) network with adjustable appearance.

    The network is drawn with Matplotlib/NetworkX and also exported as an
    interactive pyvis HTML page, a GraphML file (editable in Cytoscape) and a PDF.

    Parameters
    ----------
    interactions : pandas.DataFrame
        Edge list containing the ``player1``, ``player2`` and ``weight`` columns.
        It is not modified; a sorted copy (ascending by ``weight``) is used.
    player1, player2 : str
        Column names holding the two interacting nodes of each edge.
    weight : str
        Column name holding the edge weight. It sets the HTML edge width and,
        when ``linecolor`` is not a string, the Matplotlib edge colour.
    n_layers : int, optional
        Number of concentric rings for ``layout="degree"``. Defaults to
        ``len(n_rank) + 1``.
    n_rank : list of int
        Number of nodes on each inner ring (highest degree first) for
        ``layout="degree"``. All remaining nodes go on the outermost ring.
    dist_node : float
        Radial distance between rings for ``layout="degree"``.
    layout : str
        One of "spring", "circular", "kamada_kawai", "spectral", "random",
        "shell", "planar", "spiral", "degree". Matched fuzzily via ``ips.strcmp``.
    size : int, float, list or None
        Node size.
        - A scalar is used for every node.
        - A list is truncated or padded to the node count and, if not constant,
          rescaled into ``sizes``.
        - Any other value (e.g. None) derives sizes from node degree.
    sizes : (float, float)
        Target (min, max) range used when rescaling list sizes.
    facecolor : color or list
        Node colour. If ``ips.isa(facecolor, "color")`` is false, colours are
        taken from ``cmap`` according to node degree.
    cmap : str
        Colormap for degree-based node colours.
    edgecolor : color
        Node border colour in Matplotlib; also the edge colour in the HTML export.
    edgelinewidth : float
        Node border width; also multiplied by the weight for HTML edge widths.
    alpha : float or list
        Node opacity. A list is rescaled into ``alphas``; an empty list means 1.0.
    alphas : (float, float)
        Target (min, max) range for list-valued ``alpha``.
    marker : str
        Node shape passed to ``nx.draw_networkx_nodes``.
    node_hideticks : bool
        Forwarded as ``hide_ticks`` to ``nx.draw_networkx_nodes``.
    linecolor : str or other
        Edge colour. Any non-string value colours edges by weight with
        ``line_cmap``.
    line_cmap : str
        Colormap for weight-based edge colours.
    linewidth : float or list
        Edge width. A scalar is used as-is; a list is rescaled into ``linewidths``.
    linewidths : (float, float)
        Target (min, max) range for list-valued ``linewidth``.
    linealpha : float or list
        Edge opacity. A list is rescaled into ``linealphas``; an empty list
        means 1.0.
    linealphas : (float, float)
        Target (min, max) range for list-valued ``linealpha``.
    linestyle, line_arrowstyle : str
        Forwarded to ``nx.draw_networkx_edges``.
    fontsize, fontcolor, ha, va
        Label styling (Matplotlib labels and HTML node fonts).
    figsize : tuple
        Figure size used when ``ax`` is None.
    k_value : float
        ``k`` parameter of the spring layout.
    bgcolor : str
        Background colour of the HTML page.
    dir_save : str
        Output path. The ``.html``, ``.graphml`` and ``.pdf`` files share its
        stem. A directory-only path gets the file name ``_.html``. A falsy
        value disables saving.
    physics : bool
        Toggles the pyvis physics simulation.
    notebook : bool
        Accepted for API compatibility; currently not used.
    scale : float
        Multiplier applied to layout coordinates for the HTML node positions.
    ax : matplotlib.axes.Axes, optional
        Axes to draw on; a new figure is created when None.
    **kwargs
        A key containing "figset" supplies a dict forwarded to ``figsets``.
        Other keys are ignored.

    Returns
    -------
    (networkx.Graph, matplotlib.axes.Axes)

    Raises
    ------
    ValueError
        If a required column is missing, the DataFrame is empty, or a
        list-valued ``size``/``facecolor``/``linewidth`` is empty.

    Example
    -------
        ppi(
            interactions_sort.iloc[:1000, :],
            player1="player1",
            player2="player2",
            weight="count",
            layout="spring",
            n_layers=13,
            fontsize=1,
            n_rank=[5, 10, 20, 40, 80, 80, 80, 80, 80, 80, 80, 80],
        )
    """
    import os

    import matplotlib.pyplot as plt
    import networkx as nx
    from matplotlib.colors import Normalize
    from pyvis.network import Network

    from . import ips

    if run_once_within():
        usage_str = """
        ppi(
            interactions,
            player1="preferredName_A",
            player2="preferredName_B",
            weight="score",
            n_layers=None, # Number of concentric layers
            n_rank=[5, 10], # Nodes in each rank for the concentric layout
            dist_node = 10, # Distance between each rank of circles
            layout="degree",
            size=None,#700,
            sizes=(50,500),# min and max of size
            facecolor="skyblue",
            cmap='coolwarm',
            edgecolor="k",
            edgelinewidth=1.5,
            alpha=.5,
            alphas=(0.1, 1.0),# min and max of alpha
            marker="o",
            node_hideticks=True,
            linecolor="gray",
            line_cmap='coolwarm',
            linewidth=1.5,
            linewidths=(0.5,5),# min and max of linewidth
            linealpha=1.0,
            linealphas=(0.1,1.0),# min and max of linealpha
            linestyle="-",
            line_arrowstyle='-',
            fontsize=10,
            fontcolor="k",
            ha:str="center",
            va:str="center",
            figsize=(12, 10),
            k_value=0.3,
            bgcolor="w",
            dir_save="./ppi_network.html",
            physics=True,
            notebook=False,
            scale=1,
            ax=None,
            **kwargs
        ):
        """
        print(usage_str)

    # Check for required columns in the DataFrame
    for col in [player1, player2, weight]:
        if col not in interactions.columns:
            raise ValueError(f"Column '{col}' is missing from the interactions DataFrame.")
    if interactions.empty:
        raise ValueError("The interactions DataFrame is empty; nothing to plot.")
    # Sort a copy so the caller's DataFrame is left untouched.
    interactions = interactions.sort_values(by=[weight])

    # Initialize Pyvis network
    net = Network(height="750px", width="100%", bgcolor=bgcolor, font_color=fontcolor)
    net.force_atlas_2based(
        gravity=-50, central_gravity=0.01, spring_length=100, spring_strength=0.1
    )
    net.toggle_physics(physics)

    # The first kwarg whose name contains "figset" carries options for figsets().
    figset_key = next((key for key in kwargs if "figset" in key), None)
    kws_figsets = kwargs.pop(figset_key) if figset_key is not None else {}

    # Create a NetworkX graph from the interaction data
    G = nx.Graph()
    for _, row in interactions.iterrows():
        G.add_edge(row[player1], row[player2], weight=row[weight])
    num_nodes = G.number_of_nodes()
    num_edges = G.number_of_edges()

    # Node degrees drive the default colours and sizes; dict order == G.nodes() order.
    degrees = dict(G.degree())
    norm = Normalize(vmin=min(degrees.values()), vmax=max(degrees.values()))
    colormap = plt.get_cmap(cmap)  # cm.get_cmap was removed in Matplotlib 3.9

    # * facecolor
    if not ips.isa(facecolor, 'color'):
        print("facecolor: based on degrees")
        facecolor = [colormap(norm(deg)) for deg in degrees.values()]

    # * size
    if not isinstance(size, (int, float, list)):
        print("size: based on degrees")
        size = [deg * 50 for deg in degrees.values()]  # Scale sizes
    if isinstance(size, list):
        size = _fit_length(size, num_nodes, "size")
        if max(size) > min(size):
            size = _rescale(size, *sizes)  # map into the (min, max) of `sizes`
    else:
        size = [size] * num_nodes

    # * facecolor (one entry per node)
    if isinstance(facecolor, list):
        facecolor = _fit_length(facecolor, num_nodes, "facecolor")
    else:
        facecolor = [facecolor] * num_nodes

    # * facealpha
    if isinstance(alpha, list):
        if alpha:
            alpha = _rescale(_fit_length(alpha, num_nodes, "alpha"), *alphas)
        else:
            alpha = [1.0] * num_nodes  # default to full opacity
    else:
        alpha = [alpha] * num_nodes

    for i, node in enumerate(G.nodes()):
        net.add_node(
            node,
            label=node,
            size=size[i],
            color=_as_html_color(facecolor[i]),  # pyvis/vis.js need string colours
            alpha=alpha[i],
            font={"size": fontsize, "color": fontcolor},
        )
    print(f'nodes number: {num_nodes}')

    for source, target, attrs in G.edges(data=True):
        net.add_edge(
            source,
            target,
            weight=attrs["weight"],
            color=edgecolor,
            width=edgelinewidth * attrs["weight"],
        )

    layouts = [
        "spring",
        "circular",
        "kamada_kawai",
        "spectral",
        "random",
        "shell",
        "planar",
        "spiral",
        "degree"
    ]
    layout = ips.strcmp(layout, layouts)[0]
    print(f"layout:{layout}, or select one in {layouts}")

    # Choose layout
    if layout == "spring":
        pos = nx.spring_layout(G, k=k_value)
    elif layout == "circular":
        pos = nx.circular_layout(G)
    elif layout == "kamada_kawai":
        pos = nx.kamada_kawai_layout(G)
    elif layout == "spectral":
        pos = nx.spectral_layout(G)
    elif layout == "random":
        pos = nx.random_layout(G)
    elif layout == "shell":
        pos = nx.shell_layout(G)
    elif layout == "planar":
        if nx.check_planarity(G)[0]:
            pos = nx.planar_layout(G)
        else:
            print("Graph is not planar; switching to spring layout.")
            pos = nx.spring_layout(G, k=k_value)
    elif layout == "spiral":
        pos = nx.spiral_layout(G)
    elif layout == 'degree':
        pos = _concentric_degree_layout(G, n_rank, n_layers, dist_node)
    else:
        print(f"Unknown layout '{layout}', defaulting to 'spring',or可以用这些: {layouts}")
        pos = nx.spring_layout(G, k=k_value)

    # Mirror the Matplotlib layout in the HTML export.
    for node, (px, py) in pos.items():
        net.get_node(node)["x"] = px * scale
        net.get_node(node)["y"] = py * scale

    # Create a new figure when no Axes is supplied
    if ax is None:
        _, ax = plt.subplots(1, 1, figsize=figsize)

    # Draw nodes, edges, and labels with customization options
    nx.draw_networkx_nodes(
        G,
        pos,
        ax=ax,
        node_size=size,
        node_color=facecolor,
        linewidths=edgelinewidth,
        edgecolors=edgecolor,
        alpha=alpha,
        hide_ticks=node_hideticks,
        node_shape=marker
    )

    # * linewidth: a scalar is used as-is; only a list is rescaled into `linewidths`.
    if isinstance(linewidth, list):
        linewidth = _rescale(_fit_length(linewidth, num_edges, "linewidth"), *linewidths)
    else:
        linewidth = [linewidth] * num_edges

    # * linecolor: any non-string value means "colour edges by weight".
    if not isinstance(linecolor, str):
        weights = [G[u][v]["weight"] for u, v in G.edges()]
        edge_norm = Normalize(vmin=min(weights), vmax=max(weights))
        edge_cmap = plt.get_cmap(line_cmap)
        linecolor = [edge_cmap(edge_norm(w)) for w in weights]
    else:
        linecolor = [linecolor] * num_edges

    # * linealpha
    if isinstance(linealpha, list):
        if linealpha:
            linealpha = _rescale(_fit_length(linealpha, num_edges, "linealpha"), *linealphas)
        else:
            linealpha = [1.0] * num_edges  # invalid (empty) setting: fall back to 1.0
    else:
        linealpha = [linealpha] * num_edges

    nx.draw_networkx_edges(
        G,
        pos,
        ax=ax,
        edge_color=linecolor,
        width=linewidth,
        style=linestyle,
        arrowstyle=line_arrowstyle,
        alpha=linealpha
    )

    nx.draw_networkx_labels(
        G, pos, ax=ax, font_size=fontsize, font_color=fontcolor,
        horizontalalignment=ha, verticalalignment=va
    )
    figsets(ax=ax, **kws_figsets)
    ax.axis("off")

    if dir_save:
        if not os.path.basename(dir_save):
            # Directory-only path: keep the directory, use a default file name.
            dir_save = os.path.join(dir_save, "_.html")
        # Derive all outputs from one stem so they never overwrite each other.
        stem, ext = os.path.splitext(dir_save)
        html_path = dir_save if ext.lower() == ".html" else stem + ".html"
        graphml_path = stem + ".graphml"
        net.write_html(html_path)
        nx.write_graphml(G, graphml_path)  # Export to GraphML
        print(f"could be edited in Cytoscape \n{graphml_path}")
        ips.figsave(stem + ".pdf")
    return G, ax
|