py2ls 0.2.4.25__py3-none-any.whl → 0.2.4.27__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- py2ls/.DS_Store +0 -0
- py2ls/.git/index +0 -0
- py2ls/corr.py +475 -0
- py2ls/data/.DS_Store +0 -0
- py2ls/data/hyper_param_autogluon_zeroshot2024.json +2383 -0
- py2ls/data/styles/.DS_Store +0 -0
- py2ls/data/styles/example/.DS_Store +0 -0
- py2ls/data/usages_sns.json +6 -1
- py2ls/ips.py +1059 -114
- py2ls/ml2ls.py +758 -186
- py2ls/netfinder.py +204 -20
- py2ls/ocr.py +60 -4
- py2ls/plot.py +916 -141
- {py2ls-0.2.4.25.dist-info → py2ls-0.2.4.27.dist-info}/METADATA +6 -1
- {py2ls-0.2.4.25.dist-info → py2ls-0.2.4.27.dist-info}/RECORD +16 -14
- py2ls/data/usages_pd copy.json +0 -1105
- {py2ls-0.2.4.25.dist-info → py2ls-0.2.4.27.dist-info}/WHEEL +0 -0
py2ls/plot.py
CHANGED
@@ -20,7 +20,9 @@ from .ips import (
|
|
20
20
|
flatten,
|
21
21
|
plt_font,
|
22
22
|
run_once_within,
|
23
|
-
df_format
|
23
|
+
df_format,
|
24
|
+
df_corr,
|
25
|
+
df_scaler
|
24
26
|
)
|
25
27
|
import scipy.stats as scipy_stats
|
26
28
|
from .stats import *
|
@@ -142,20 +144,37 @@ def add_text(ax=None, height_offset=0.5, fmt=".1f", **kwargs):
|
|
142
144
|
**kwargs,
|
143
145
|
)
|
144
146
|
|
147
|
+
def pval2str(p):
|
148
|
+
if p > 0.05:
|
149
|
+
txt = ""
|
150
|
+
elif 0.01 <= p <= 0.05:
|
151
|
+
txt = "*"
|
152
|
+
elif 0.001 <= p < 0.01:
|
153
|
+
txt = "**"
|
154
|
+
elif p < 0.001:
|
155
|
+
txt = "***"
|
156
|
+
return txt
|
145
157
|
|
146
158
|
def heatmap(
|
147
159
|
data,
|
148
160
|
ax=None,
|
149
161
|
kind="corr", #'corr','direct','pivot'
|
162
|
+
method="pearson",# for correlation: ‘pearson’(default), ‘kendall’, ‘spearman’
|
150
163
|
columns="all", # pivot, default: coll numeric columns
|
164
|
+
style=0,# for correlation
|
151
165
|
index=None, # pivot
|
152
166
|
values=None, # pivot
|
167
|
+
fontsize=10,
|
153
168
|
tri="u",
|
154
169
|
mask=True,
|
155
170
|
k=1,
|
171
|
+
vmin=None,
|
172
|
+
vmax=None,
|
173
|
+
size_scale=500,
|
156
174
|
annot=True,
|
157
175
|
cmap="coolwarm",
|
158
176
|
fmt=".2f",
|
177
|
+
show_indicator = True,# only for style==1
|
159
178
|
cluster=False,
|
160
179
|
inplace=False,
|
161
180
|
figsize=(10, 8),
|
@@ -201,7 +220,7 @@ def heatmap(
|
|
201
220
|
ax = plt.gca()
|
202
221
|
# Select numeric columns or specific subset of columns
|
203
222
|
if columns == "all":
|
204
|
-
df_numeric = data.select_dtypes(include=[
|
223
|
+
df_numeric = data.select_dtypes(include=[np.number])
|
205
224
|
else:
|
206
225
|
df_numeric = data[columns]
|
207
226
|
|
@@ -209,8 +228,10 @@ def heatmap(
|
|
209
228
|
kind = strcmp(kind, kinds)[0]
|
210
229
|
print(kind)
|
211
230
|
if "corr" in kind: # correlation
|
231
|
+
methods = ["pearson", "spearman", "kendall"]
|
232
|
+
method = strcmp(method, methods)[0]
|
212
233
|
# Compute the correlation matrix
|
213
|
-
data4heatmap = df_numeric.corr()
|
234
|
+
data4heatmap = df_numeric.corr(method=method)
|
214
235
|
# Generate mask for the upper triangle if mask is True
|
215
236
|
if mask:
|
216
237
|
if "u" in tri.lower(): # upper => np.tril
|
@@ -288,18 +309,76 @@ def heatmap(
|
|
288
309
|
df_col_cluster,
|
289
310
|
)
|
290
311
|
else:
|
291
|
-
|
292
|
-
|
293
|
-
|
294
|
-
|
295
|
-
|
296
|
-
|
297
|
-
|
298
|
-
|
299
|
-
|
300
|
-
|
301
|
-
|
302
|
-
|
312
|
+
if style==0:
|
313
|
+
# Create a standard heatmap
|
314
|
+
ax = sns.heatmap(
|
315
|
+
data4heatmap,
|
316
|
+
ax=ax,
|
317
|
+
mask=mask_array,
|
318
|
+
annot=annot,
|
319
|
+
cmap=cmap,
|
320
|
+
fmt=fmt,
|
321
|
+
**kwargs, # Pass any additional arguments to sns.heatmap
|
322
|
+
)
|
323
|
+
return ax
|
324
|
+
elif style==1:
|
325
|
+
if isinstance(cmap, str):
|
326
|
+
cmap = plt.get_cmap(cmap)
|
327
|
+
norm = plt.Normalize(vmin=-1, vmax=1)
|
328
|
+
r_, p_ = df_corr(data4heatmap, method=method)
|
329
|
+
# size_r_norm=df_scaler(data=r_, method="minmax", vmin=-1,vmax=1)
|
330
|
+
# 初始化一个空的可绘制对象用于颜色条
|
331
|
+
scatter_handles = []
|
332
|
+
# 循环绘制气泡图和数值
|
333
|
+
for i in range(len(r_.columns)):
|
334
|
+
for j in range(len(r_.columns)):
|
335
|
+
if (i < j) if "u" in tri.lower() else (j<i): # 对角线左上部只显示气泡
|
336
|
+
color = cmap(norm(r_.iloc[i, j])) # 根据相关系数获取颜色
|
337
|
+
scatter = ax.scatter(
|
338
|
+
i, j, s=np.abs(r_.iloc[i, j])*size_scale, color=color,
|
339
|
+
# alpha=1,edgecolor=edgecolor,linewidth=linewidth,
|
340
|
+
**kwargs
|
341
|
+
)
|
342
|
+
scatter_handles.append(scatter) # 保存scatter对象用于颜色条
|
343
|
+
# add *** indicators
|
344
|
+
if show_indicator:
|
345
|
+
ax.text(
|
346
|
+
i,
|
347
|
+
j,
|
348
|
+
pval2str(p_.iloc[i, j]),
|
349
|
+
ha="center",
|
350
|
+
va="center",
|
351
|
+
color="k",
|
352
|
+
fontsize=fontsize * 1.3,
|
353
|
+
)
|
354
|
+
elif (i > j) if "u" in tri.lower() else (j>i): # 对角只显示数值
|
355
|
+
color = cmap(norm(r_.iloc[i, j])) # 数值的颜色同样基于相关系数
|
356
|
+
ax.text(
|
357
|
+
i,
|
358
|
+
j,
|
359
|
+
f"{r_.iloc[i, j]:{fmt}}",
|
360
|
+
ha="center",
|
361
|
+
va="center",
|
362
|
+
color=color,
|
363
|
+
fontsize=fontsize,
|
364
|
+
)
|
365
|
+
else: # 对角线部分,显示空白
|
366
|
+
ax.scatter(i, j, s=1, color="white")
|
367
|
+
# 设置坐标轴标签
|
368
|
+
figsets(xticks=range(len(r_.columns)),
|
369
|
+
xticklabels=r_.columns,
|
370
|
+
xangle=90,
|
371
|
+
fontsize=fontsize,
|
372
|
+
yticks=range(len(r_.columns)),
|
373
|
+
yticklabels=r_.columns,
|
374
|
+
xlim=[-0.5,len(r_.columns)-0.5],
|
375
|
+
ylim=[-0.5,len(r_.columns)-0.5]
|
376
|
+
)
|
377
|
+
# 添加颜色条
|
378
|
+
sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
|
379
|
+
sm.set_array([]) # 仅用于显示颜色条
|
380
|
+
plt.colorbar(sm, ax=ax, label="Correlation Coefficient")
|
381
|
+
return ax
|
303
382
|
elif "dir" in kind: # direct
|
304
383
|
data4heatmap = df_numeric
|
305
384
|
elif "pi" in kind: # pivot
|
@@ -389,16 +468,53 @@ def heatmap(
|
|
389
468
|
)
|
390
469
|
else:
|
391
470
|
# Create a standard heatmap
|
392
|
-
|
393
|
-
|
394
|
-
|
395
|
-
|
396
|
-
|
397
|
-
|
398
|
-
|
399
|
-
|
400
|
-
|
401
|
-
|
471
|
+
if style==0:
|
472
|
+
ax = sns.heatmap(
|
473
|
+
data4heatmap,
|
474
|
+
ax=ax,
|
475
|
+
annot=annot,
|
476
|
+
cmap=cmap,
|
477
|
+
fmt=fmt,
|
478
|
+
**kwargs, # Pass any additional arguments to sns.heatmap
|
479
|
+
)
|
480
|
+
# Return the Axes object for further customization if needed
|
481
|
+
return ax
|
482
|
+
elif style==1:
|
483
|
+
if isinstance(cmap, str):
|
484
|
+
cmap = plt.get_cmap(cmap)
|
485
|
+
if vmin is None:
|
486
|
+
vmin=np.min(data4heatmap)
|
487
|
+
if vmax is None:
|
488
|
+
vmax=np.max(data4heatmap)
|
489
|
+
norm = plt.Normalize(vmin=vmin, vmax=vmax)
|
490
|
+
|
491
|
+
# 初始化一个空的可绘制对象用于颜色条
|
492
|
+
scatter_handles = []
|
493
|
+
# 循环绘制气泡图和数值
|
494
|
+
print(len(data4heatmap.index),len(data4heatmap.columns))
|
495
|
+
for i in range(len(data4heatmap.index)):
|
496
|
+
for j in range(len(data4heatmap.columns)):
|
497
|
+
color = cmap(norm(data4heatmap.iloc[i, j])) # 根据相关系数获取颜色
|
498
|
+
scatter = ax.scatter(j,i, s=np.abs(data4heatmap.iloc[i, j])*size_scale, color=color, **kwargs)
|
499
|
+
scatter_handles.append(scatter) # 保存scatter对象用于颜色条
|
500
|
+
|
501
|
+
# 设置坐标轴标签
|
502
|
+
figsets(xticks=range(len(data4heatmap.columns)),
|
503
|
+
xticklabels=data4heatmap.columns,
|
504
|
+
xangle=90,
|
505
|
+
fontsize=fontsize,
|
506
|
+
yticks=range(len(data4heatmap.index)),
|
507
|
+
yticklabels=data4heatmap.index,
|
508
|
+
xlim=[-0.5,len(data4heatmap.columns)-0.5],
|
509
|
+
ylim=[-0.5,len(data4heatmap.index)-0.5]
|
510
|
+
)
|
511
|
+
# 添加颜色条
|
512
|
+
sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
|
513
|
+
sm.set_array([]) # 仅用于显示颜色条
|
514
|
+
plt.colorbar(sm, ax=ax,
|
515
|
+
# label="Correlation Coefficient"
|
516
|
+
)
|
517
|
+
return ax
|
402
518
|
|
403
519
|
|
404
520
|
# !usage: py2ls.plot.heatmap()
|
@@ -3226,18 +3342,37 @@ def plotxy(
|
|
3226
3342
|
# (1) return FcetGrid
|
3227
3343
|
if k == "jointplot":
|
3228
3344
|
kws_joint = kwargs.pop("kws_joint", kwargs)
|
3229
|
-
|
3345
|
+
kws_joint = {
|
3346
|
+
k: v for k, v in kws_joint.items() if not k.startswith("kws_")
|
3347
|
+
}
|
3348
|
+
hue = kwargs.get("hue", None)
|
3349
|
+
if isinstance(kws_joint, dict) or hue is None: # Check if kws_ellipse is a dictionary
|
3350
|
+
kws_joint.pop("hue", None) # Safely remove 'hue' if it exists
|
3351
|
+
|
3352
|
+
palette = kwargs.get("palette", None)
|
3353
|
+
if palette is None:
|
3354
|
+
palette = kws_joint.pop(
|
3355
|
+
"palette", get_color(data[hue].nunique()) if hue is not None else None
|
3356
|
+
)
|
3357
|
+
else:
|
3358
|
+
kws_joint.pop("palette", palette)
|
3359
|
+
stats=kwargs.pop("stats",None)
|
3360
|
+
if stats:
|
3361
|
+
stats=kws_joint.pop("stats",True)
|
3230
3362
|
if stats:
|
3231
3363
|
r, p_value = scipy_stats.pearsonr(data[x], data[y])
|
3232
|
-
|
3233
|
-
|
3234
|
-
|
3235
|
-
|
3236
|
-
|
3237
|
-
|
3238
|
-
|
3239
|
-
|
3240
|
-
|
3364
|
+
for key in ["palette", "alpha", "hue","stats"]:
|
3365
|
+
kws_joint.pop(key, None)
|
3366
|
+
g = sns.jointplot(data=data, x=x, y=y,hue=hue,palette=palette, **kws_joint)
|
3367
|
+
if stats:
|
3368
|
+
g.ax_joint.annotate(
|
3369
|
+
f"pearsonr = {r:.2f} p = {p_value:.3f}",
|
3370
|
+
xy=(0.6, 0.98),
|
3371
|
+
xycoords="axes fraction",
|
3372
|
+
fontsize=12,
|
3373
|
+
color="black",
|
3374
|
+
ha="center",
|
3375
|
+
)
|
3241
3376
|
elif k == "lmplot":
|
3242
3377
|
kws_lm = kwargs.pop("kws_lm", kwargs)
|
3243
3378
|
stats = kwargs.pop("stats", True) # Flag to calculate stats
|
@@ -3383,8 +3518,7 @@ def plotxy(
|
|
3383
3518
|
xycoords="axes fraction",
|
3384
3519
|
fontsize=12,
|
3385
3520
|
color="black",
|
3386
|
-
ha="center"
|
3387
|
-
)
|
3521
|
+
ha="center")
|
3388
3522
|
|
3389
3523
|
elif k == "catplot_sns":
|
3390
3524
|
kws_cat = kwargs.pop("kws_cat", kwargs)
|
@@ -3392,7 +3526,7 @@ def plotxy(
|
|
3392
3526
|
elif k == "displot":
|
3393
3527
|
kws_dis = kwargs.pop("kws_dis", kwargs)
|
3394
3528
|
# displot creates a new figure and returns a FacetGrid
|
3395
|
-
g = sns.displot(data=data, x=x,
|
3529
|
+
g = sns.displot(data=data, x=x,y=y, **kws_dis)
|
3396
3530
|
|
3397
3531
|
# (2) return axis
|
3398
3532
|
if ax is None:
|
@@ -3403,19 +3537,58 @@ def plotxy(
|
|
3403
3537
|
elif k == "stdshade":
|
3404
3538
|
kws_stdshade = kwargs.pop("kws_stdshade", kwargs)
|
3405
3539
|
ax = stdshade(ax=ax, **kwargs)
|
3540
|
+
elif k=="ellipse":
|
3541
|
+
kws_ellipse = kwargs.pop("kws_ellipse", kwargs)
|
3542
|
+
kws_ellipse = {
|
3543
|
+
k: v for k, v in kws_ellipse.items() if not k.startswith("kws_")
|
3544
|
+
}
|
3545
|
+
hue = kwargs.get("hue", None)
|
3546
|
+
if isinstance(kws_ellipse, dict) or hue is None: # Check if kws_ellipse is a dictionary
|
3547
|
+
kws_ellipse.pop("hue", None) # Safely remove 'hue' if it exists
|
3548
|
+
|
3549
|
+
palette = kwargs.get("palette", None)
|
3550
|
+
if palette is None:
|
3551
|
+
palette = kws_ellipse.pop(
|
3552
|
+
"palette", get_color(data[hue].nunique()) if hue is not None else None
|
3553
|
+
)
|
3554
|
+
alpha = kws_ellipse.pop("alpha", 0.1)
|
3555
|
+
hue_order = kwargs.get("hue_order",None)
|
3556
|
+
if hue_order is None:
|
3557
|
+
hue_order = kws_ellipse.get("hue_order",None)
|
3558
|
+
if hue_order:
|
3559
|
+
data["hue"] = pd.Categorical(data[hue], categories=hue_order, ordered=True)
|
3560
|
+
data = data.sort_values(by="hue")
|
3561
|
+
for key in ["palette", "alpha", "hue","hue_order"]:
|
3562
|
+
kws_ellipse.pop(key, None)
|
3563
|
+
ax=ellipse(
|
3564
|
+
ax=ax,
|
3565
|
+
data=data,
|
3566
|
+
x=x,
|
3567
|
+
y=y,
|
3568
|
+
hue=hue,
|
3569
|
+
palette=palette,
|
3570
|
+
alpha=alpha,
|
3571
|
+
zorder=zorder,
|
3572
|
+
**kws_ellipse,
|
3573
|
+
)
|
3406
3574
|
elif k == "scatterplot":
|
3407
3575
|
kws_scatter = kwargs.pop("kws_scatter", kwargs)
|
3408
3576
|
kws_scatter = {
|
3409
3577
|
k: v for k, v in kws_scatter.items() if not k.startswith("kws_")
|
3410
3578
|
}
|
3411
|
-
hue = kwargs.
|
3579
|
+
hue = kwargs.get("hue", None)
|
3412
3580
|
if isinstance(kws_scatter, dict): # Check if kws_scatter is a dictionary
|
3413
3581
|
kws_scatter.pop("hue", None) # Safely remove 'hue' if it exists
|
3414
|
-
palette
|
3415
|
-
|
3416
|
-
|
3582
|
+
palette=kws_scatter.get("palette",None)
|
3583
|
+
if palette is None:
|
3584
|
+
palette = kws_scatter.pop(
|
3585
|
+
"palette", get_color(data[hue].nunique()) if hue is not None else None
|
3586
|
+
)
|
3417
3587
|
s = kws_scatter.pop("s", 10)
|
3418
|
-
alpha = kws_scatter.pop("alpha", 0.7)
|
3588
|
+
alpha = kws_scatter.pop("alpha", 0.7)
|
3589
|
+
for key in ["s", "palette", "alpha", "hue"]:
|
3590
|
+
kws_scatter.pop(key, None)
|
3591
|
+
|
3419
3592
|
ax = sns.scatterplot(
|
3420
3593
|
ax=ax,
|
3421
3594
|
data=data,
|
@@ -3431,19 +3604,33 @@ def plotxy(
|
|
3431
3604
|
elif k == "histplot":
|
3432
3605
|
kws_hist = kwargs.pop("kws_hist", kwargs)
|
3433
3606
|
kws_hist = {k: v for k, v in kws_hist.items() if not k.startswith("kws_")}
|
3434
|
-
ax = sns.histplot(data=data, x=x, ax=ax, zorder=zorder, **kws_hist)
|
3435
|
-
elif k == "kdeplot":
|
3607
|
+
ax = sns.histplot(data=data, x=x, y=y, ax=ax, zorder=zorder, **kws_hist)
|
3608
|
+
elif k == "kdeplot":
|
3436
3609
|
kws_kde = kwargs.pop("kws_kde", kwargs)
|
3437
|
-
kws_kde = {
|
3438
|
-
|
3610
|
+
kws_kde = {
|
3611
|
+
k: v for k, v in kws_kde.items() if not k.startswith("kws_")
|
3612
|
+
}
|
3613
|
+
hue = kwargs.get("hue", None)
|
3614
|
+
if isinstance(kws_kde, dict) or hue is None: # Check if kws_kde is a dictionary
|
3615
|
+
kws_kde.pop("hue", None) # Safely remove 'hue' if it exists
|
3616
|
+
|
3617
|
+
palette = kwargs.get("palette", None)
|
3618
|
+
if palette is None:
|
3619
|
+
palette = kws_kde.pop(
|
3620
|
+
"palette", get_color(data[hue].nunique()) if hue is not None else None
|
3621
|
+
)
|
3622
|
+
alpha = kws_kde.pop("alpha", 0.05)
|
3623
|
+
for key in ["palette", "alpha", "hue"]:
|
3624
|
+
kws_kde.pop(key, None)
|
3625
|
+
ax = sns.kdeplot(data=data, x=x, y=y, palette=palette,hue=hue, ax=ax,alpha=alpha, zorder=zorder, **kws_kde)
|
3439
3626
|
elif k == "ecdfplot":
|
3440
3627
|
kws_ecdf = kwargs.pop("kws_ecdf", kwargs)
|
3441
3628
|
kws_ecdf = {k: v for k, v in kws_ecdf.items() if not k.startswith("kws_")}
|
3442
|
-
ax = sns.ecdfplot(data=data, x=x,
|
3629
|
+
ax = sns.ecdfplot(data=data, x=x,y=y, ax=ax, zorder=zorder, **kws_ecdf)
|
3443
3630
|
elif k == "rugplot":
|
3444
3631
|
kws_rug = kwargs.pop("kws_rug", kwargs)
|
3445
3632
|
kws_rug = {k: v for k, v in kws_rug.items() if not k.startswith("kws_")}
|
3446
|
-
ax = sns.rugplot(data=data, x=x, ax=ax, zorder=zorder, **kws_rug)
|
3633
|
+
ax = sns.rugplot(data=data, x=x, y=y, ax=ax, zorder=zorder, **kws_rug)
|
3447
3634
|
elif k == "stripplot":
|
3448
3635
|
kws_strip = kwargs.pop("kws_strip", kwargs)
|
3449
3636
|
kws_strip = {k: v for k, v in kws_strip.items() if not k.startswith("kws_")}
|
@@ -3516,7 +3703,8 @@ def plotxy(
|
|
3516
3703
|
figsets(ax=ax, **kws_figsets)
|
3517
3704
|
if kws_add_text:
|
3518
3705
|
add_text(ax=ax, **kws_add_text) if kws_add_text else None
|
3519
|
-
|
3706
|
+
if run_once_within(10):
|
3707
|
+
for k in kind_:
|
3520
3708
|
print(f"\n{k}⤵ ")
|
3521
3709
|
print(default_settings[k])
|
3522
3710
|
# print("=>\t",sns_info[sns_info["Functions"].str.contains(k)].iloc[:, -1].tolist()[0],"\n")
|
@@ -3574,6 +3762,7 @@ def df_preprocessing_(data, kind, verbose=False):
|
|
3574
3762
|
"lineplot", # Can work with both wide and long formats
|
3575
3763
|
"area plot", # Can work with both formats, useful for stacked areas
|
3576
3764
|
"violinplot", # Can work with both formats depending on categorical vs continuous data
|
3765
|
+
"ellipse",# ellipse plot, default confidence=0.95
|
3577
3766
|
],
|
3578
3767
|
)[0]
|
3579
3768
|
|
@@ -3609,6 +3798,7 @@ def df_preprocessing_(data, kind, verbose=False):
|
|
3609
3798
|
"violinplot", # Can work with both formats depending on categorical vs continuous data
|
3610
3799
|
"relplot",
|
3611
3800
|
"pointplot", # Works well with wide format
|
3801
|
+
"ellipse",
|
3612
3802
|
]
|
3613
3803
|
|
3614
3804
|
# Wide format (e.g., for heatmap and pairplot)
|
@@ -4100,25 +4290,45 @@ def venn(
|
|
4100
4290
|
"""
|
4101
4291
|
if ax is None:
|
4102
4292
|
ax = plt.gca()
|
4293
|
+
if isinstance(lists, dict):
|
4294
|
+
labels = list(lists.keys())
|
4295
|
+
lists = list(lists.values())
|
4296
|
+
if isinstance(lists[0], set):
|
4297
|
+
lists = [list(i) for i in lists]
|
4103
4298
|
lists = [set(flatten(i, verbose=False)) for i in lists]
|
4104
4299
|
# Function to apply text styles to labels
|
4105
4300
|
if colors is None:
|
4106
4301
|
colors = ["r", "b"] if len(lists) == 2 else ["r", "g", "b"]
|
4107
4302
|
if labels is None:
|
4108
|
-
|
4303
|
+
if len(lists) == 2:
|
4304
|
+
labels = ["set1", "set2"]
|
4305
|
+
elif len(lists) == 3:
|
4306
|
+
labels = ["set1", "set2", "set3"]
|
4307
|
+
elif len(lists) == 4:
|
4308
|
+
labels = ["set1", "set2", "set3","set4"]
|
4309
|
+
elif len(lists) == 5:
|
4310
|
+
labels = ["set1", "set2", "set3","set4","set55"]
|
4311
|
+
elif len(lists) == 6:
|
4312
|
+
labels = ["set1", "set2", "set3","set4","set5","set6"]
|
4313
|
+
elif len(lists) == 7:
|
4314
|
+
labels = ["set1", "set2", "set3","set4","set5","set6","set7"]
|
4109
4315
|
if edgecolor is None:
|
4110
4316
|
edgecolor = colors
|
4111
4317
|
colors = [desaturate_color(color, saturation) for color in colors]
|
4112
|
-
|
4113
|
-
if len(lists) == 2:
|
4318
|
+
universe = len(set.union(*lists))
|
4114
4319
|
|
4115
|
-
|
4116
|
-
|
4117
|
-
|
4118
|
-
|
4119
|
-
|
4120
|
-
|
4121
|
-
|
4320
|
+
# Check colors and auto-calculate overlaps
|
4321
|
+
def get_count_and_percentage(set_count, subset_count):
|
4322
|
+
percent = subset_count / set_count if set_count > 0 else 0
|
4323
|
+
return (
|
4324
|
+
f"{subset_count}\n({fmt.format(percent)})"
|
4325
|
+
if show_percentages
|
4326
|
+
else f"{subset_count}"
|
4327
|
+
)
|
4328
|
+
if fmt is not None:
|
4329
|
+
if not fmt.startswith("{"):
|
4330
|
+
fmt="{:" + fmt + "}"
|
4331
|
+
if len(lists) == 2:
|
4122
4332
|
|
4123
4333
|
from matplotlib_venn import venn2, venn2_circles
|
4124
4334
|
|
@@ -4131,21 +4341,28 @@ def venn(
|
|
4131
4341
|
set1, set2 = lists[0], lists[1]
|
4132
4342
|
v.get_patch_by_id("10").set_color(colors[0])
|
4133
4343
|
v.get_patch_by_id("01").set_color(colors[1])
|
4134
|
-
|
4135
|
-
|
4136
|
-
|
4344
|
+
try:
|
4345
|
+
v.get_patch_by_id("11").set_color(
|
4346
|
+
get_color_overlap(colors[0], colors[1]) if colors else None
|
4347
|
+
)
|
4348
|
+
except Exception as e:
|
4349
|
+
print(e)
|
4137
4350
|
# v.get_label_by_id('10').set_text(len(set1 - set2))
|
4138
4351
|
# v.get_label_by_id('01').set_text(len(set2 - set1))
|
4139
4352
|
# v.get_label_by_id('11').set_text(len(set1 & set2))
|
4353
|
+
|
4140
4354
|
v.get_label_by_id("10").set_text(
|
4141
|
-
get_count_and_percentage(
|
4355
|
+
get_count_and_percentage(universe, len(set1 - set2))
|
4142
4356
|
)
|
4143
4357
|
v.get_label_by_id("01").set_text(
|
4144
|
-
get_count_and_percentage(
|
4145
|
-
)
|
4146
|
-
v.get_label_by_id("11").set_text(
|
4147
|
-
get_count_and_percentage(len(set1 | set2), len(set1 & set2))
|
4358
|
+
get_count_and_percentage(universe, len(set2 - set1))
|
4148
4359
|
)
|
4360
|
+
try:
|
4361
|
+
v.get_label_by_id("11").set_text(
|
4362
|
+
get_count_and_percentage(universe, len(set1 & set2))
|
4363
|
+
)
|
4364
|
+
except Exception as e:
|
4365
|
+
print(e)
|
4149
4366
|
|
4150
4367
|
if not isinstance(linewidth, list):
|
4151
4368
|
linewidth = [linewidth]
|
@@ -4228,16 +4445,14 @@ def venn(
|
|
4228
4445
|
va=va,
|
4229
4446
|
shadow=shadow,
|
4230
4447
|
)
|
4231
|
-
|
4232
|
-
|
4233
|
-
|
4234
|
-
|
4235
|
-
|
4236
|
-
|
4237
|
-
|
4238
|
-
|
4239
|
-
else f"{subset_count}"
|
4240
|
-
)
|
4448
|
+
# Set transparency level
|
4449
|
+
for patch in v.patches:
|
4450
|
+
if patch:
|
4451
|
+
patch.set_alpha(alpha)
|
4452
|
+
if "none" in edgecolor or 0 in linewidth:
|
4453
|
+
patch.set_edgecolor("none")
|
4454
|
+
return ax
|
4455
|
+
elif len(lists) == 3:
|
4241
4456
|
|
4242
4457
|
from matplotlib_venn import venn3, venn3_circles
|
4243
4458
|
|
@@ -4253,36 +4468,34 @@ def venn(
|
|
4253
4468
|
# Draw the venn diagram
|
4254
4469
|
v = venn3(subsets=lists, set_labels=labels, ax=ax, **kwargs)
|
4255
4470
|
v.get_patch_by_id("100").set_color(colors[0])
|
4471
|
+
v.get_label_by_id("100").set_text(get_count_and_percentage(universe, len(set1 - set2 - set3)))
|
4256
4472
|
v.get_patch_by_id("010").set_color(colors[1])
|
4257
|
-
v.
|
4258
|
-
|
4259
|
-
|
4260
|
-
|
4261
|
-
|
4262
|
-
|
4263
|
-
|
4264
|
-
|
4265
|
-
|
4266
|
-
|
4267
|
-
|
4268
|
-
|
4269
|
-
|
4270
|
-
|
4271
|
-
|
4272
|
-
|
4273
|
-
|
4274
|
-
|
4275
|
-
|
4276
|
-
|
4277
|
-
|
4278
|
-
|
4279
|
-
|
4280
|
-
|
4281
|
-
|
4282
|
-
|
4283
|
-
v.get_label_by_id("111").set_text(
|
4284
|
-
get_label(len(set1 | set2 | set3), len(set1 & set2 & set3))
|
4285
|
-
)
|
4473
|
+
v.get_label_by_id("010").set_text(get_count_and_percentage(universe, len(set2 - set1 - set3)))
|
4474
|
+
try:
|
4475
|
+
v.get_patch_by_id("001").set_color(colors[2])
|
4476
|
+
v.get_label_by_id("001").set_text(get_count_and_percentage(universe, len(set3 - set1 - set2)))
|
4477
|
+
except Exception as e:
|
4478
|
+
print(e)
|
4479
|
+
try:
|
4480
|
+
v.get_patch_by_id("110").set_color(colorAB)
|
4481
|
+
v.get_label_by_id("110").set_text(get_count_and_percentage(universe, len(set1 & set2 - set3)))
|
4482
|
+
except Exception as e:
|
4483
|
+
print(e)
|
4484
|
+
try:
|
4485
|
+
v.get_patch_by_id("101").set_color(colorAC)
|
4486
|
+
v.get_label_by_id("101").set_text(get_count_and_percentage(universe, len(set1 & set3 - set2)))
|
4487
|
+
except Exception as e:
|
4488
|
+
print(e)
|
4489
|
+
try:
|
4490
|
+
v.get_patch_by_id("011").set_color(colorBC)
|
4491
|
+
v.get_label_by_id("011").set_text(get_count_and_percentage(universe, len(set2 & set3 - set1)))
|
4492
|
+
except Exception as e:
|
4493
|
+
print(e)
|
4494
|
+
try:
|
4495
|
+
v.get_patch_by_id("111").set_color(colorABC)
|
4496
|
+
v.get_label_by_id("111").set_text(get_count_and_percentage(universe, len(set1 & set2 & set3)))
|
4497
|
+
except Exception as e:
|
4498
|
+
print(e)
|
4286
4499
|
|
4287
4500
|
# Apply styles to set labels
|
4288
4501
|
for i, text in enumerate(v.set_labels):
|
@@ -4387,16 +4600,34 @@ def venn(
|
|
4387
4600
|
ax.add_patch(ellipse1)
|
4388
4601
|
ax.add_patch(ellipse2)
|
4389
4602
|
ax.add_patch(ellipse3)
|
4603
|
+
# Set transparency level
|
4604
|
+
for patch in v.patches:
|
4605
|
+
if patch:
|
4606
|
+
patch.set_alpha(alpha)
|
4607
|
+
if "none" in edgecolor or 0 in linewidth:
|
4608
|
+
patch.set_edgecolor("none")
|
4609
|
+
return ax
|
4610
|
+
|
4611
|
+
|
4612
|
+
dict_data = {}
|
4613
|
+
for i_list, list_ in enumerate(lists):
|
4614
|
+
dict_data[labels[i_list]]={*list_}
|
4615
|
+
|
4616
|
+
if 3<len(lists)<6:
|
4617
|
+
from venn import venn as vn
|
4618
|
+
|
4619
|
+
legend_loc=kwargs.pop("legend_loc", "upper right")
|
4620
|
+
ax=vn(dict_data,ax=ax,legend_loc=legend_loc,**kwargs)
|
4621
|
+
|
4622
|
+
return ax
|
4390
4623
|
else:
|
4391
|
-
|
4392
|
-
|
4393
|
-
|
4394
|
-
|
4395
|
-
|
4396
|
-
|
4397
|
-
|
4398
|
-
patch.set_edgecolor("none")
|
4399
|
-
return ax
|
4624
|
+
from venn import pseudovenn
|
4625
|
+
cmap=kwargs.pop("cmap","plasma")
|
4626
|
+
ax=pseudovenn(dict_data, cmap=cmap,ax=ax,**kwargs)
|
4627
|
+
|
4628
|
+
return ax
|
4629
|
+
|
4630
|
+
|
4400
4631
|
|
4401
4632
|
|
4402
4633
|
#! subplots, support automatic extend new axis
|
@@ -4407,6 +4638,7 @@ def subplot(
|
|
4407
4638
|
sharex=False,
|
4408
4639
|
sharey=False,
|
4409
4640
|
verbose=False,
|
4641
|
+
fig=None,
|
4410
4642
|
**kwargs,
|
4411
4643
|
):
|
4412
4644
|
"""
|
@@ -4435,8 +4667,8 @@ def subplot(
|
|
4435
4667
|
)
|
4436
4668
|
|
4437
4669
|
figsize_recommend = f"subplot({rows}, {cols}, figsize={figsize})"
|
4438
|
-
|
4439
|
-
|
4670
|
+
if fig is None:
|
4671
|
+
fig = plt.figure(figsize=figsize, constrained_layout=True)
|
4440
4672
|
grid_spec = GridSpec(rows, cols, figure=fig)
|
4441
4673
|
occupied = set()
|
4442
4674
|
row_first_axes = [None] * rows # Track the first axis in each row (for sharey)
|
@@ -4502,7 +4734,7 @@ def radar(
|
|
4502
4734
|
data: pd.DataFrame,
|
4503
4735
|
columns=None,
|
4504
4736
|
ylim=(0, 100),
|
4505
|
-
facecolor=
|
4737
|
+
facecolor=None,
|
4506
4738
|
edgecolor="none",
|
4507
4739
|
edge_linewidth=0.5,
|
4508
4740
|
fontsize=10,
|
@@ -4701,7 +4933,7 @@ def radar(
|
|
4701
4933
|
else:
|
4702
4934
|
# * spider style: spider-style grid (straight lines, not circles)
|
4703
4935
|
# Create the spider-style grid (straight lines, not circles)
|
4704
|
-
for i in range(1, int((vmax-vmin)/ ((vmax-vmin)*grid_interval_ratio))):#int(vmax * grid_interval_ratio) + 1):
|
4936
|
+
for i in range(1, int((vmax-vmin)/ ((vmax-vmin)*grid_interval_ratio))+1):#int(vmax * grid_interval_ratio) + 1):
|
4705
4937
|
ax.plot(
|
4706
4938
|
angles + [angles[0]], # Closing the loop
|
4707
4939
|
[i * vmax * grid_interval_ratio] * (num_vars + 1)
|
@@ -4740,9 +4972,9 @@ def radar(
|
|
4740
4972
|
colors = facecolor
|
4741
4973
|
else:
|
4742
4974
|
colors = (
|
4743
|
-
get_color(data.shape[
|
4975
|
+
get_color(data.shape[1])
|
4744
4976
|
if cmap is None
|
4745
|
-
else plt.get_cmap(cmap)(np.linspace(0, 1, data.shape[
|
4977
|
+
else plt.get_cmap(cmap)(np.linspace(0, 1, data.shape[1]))
|
4746
4978
|
)
|
4747
4979
|
|
4748
4980
|
# Plot each row with straight lines
|
@@ -4840,7 +5072,8 @@ def pie(
|
|
4840
5072
|
pctdistance=0.85,
|
4841
5073
|
labeldistance=1.1,
|
4842
5074
|
kws_wedge={},
|
4843
|
-
kws_text={},
|
5075
|
+
kws_text={},
|
5076
|
+
kws_arrow={},
|
4844
5077
|
center=(0, 0),
|
4845
5078
|
radius=1,
|
4846
5079
|
frame=False,
|
@@ -4850,6 +5083,8 @@ def pie(
|
|
4850
5083
|
cmap=None,
|
4851
5084
|
show_value=False,
|
4852
5085
|
show_label=True,# False: only show the outer layer, if it is None, not show
|
5086
|
+
expand_label=(1.2,1.2),
|
5087
|
+
kws_bbox={},#dict(facecolor="none", alpha=0.5, edgecolor="black", boxstyle="round,pad=0.3"), # '{}' to hide
|
4853
5088
|
show_legend=True,
|
4854
5089
|
legend_loc="upper right",
|
4855
5090
|
bbox_to_anchor=[1.4, 1.1],
|
@@ -4859,7 +5094,7 @@ def pie(
|
|
4859
5094
|
ax=None,
|
4860
5095
|
**kwargs
|
4861
5096
|
):
|
4862
|
-
|
5097
|
+
from adjustText import adjust_text
|
4863
5098
|
if run_once_within(20,reverse=True) and verbose:
|
4864
5099
|
usage_="""usage:
|
4865
5100
|
pie(
|
@@ -4974,6 +5209,11 @@ def pie(
|
|
4974
5209
|
# 选择部分数据
|
4975
5210
|
df=data[columns]
|
4976
5211
|
|
5212
|
+
if not isinstance(explode, list):
|
5213
|
+
explode=[explode]
|
5214
|
+
if explode == [None]:
|
5215
|
+
explode=[0]
|
5216
|
+
|
4977
5217
|
if width is None:
|
4978
5218
|
if df.shape[1]>1:
|
4979
5219
|
width=1/(df.shape[1]+2)
|
@@ -5022,10 +5262,9 @@ def pie(
|
|
5022
5262
|
|
5023
5263
|
if ax is None:
|
5024
5264
|
ax=plt.gca()
|
5025
|
-
if explode
|
5026
|
-
|
5027
|
-
|
5028
|
-
|
5265
|
+
if len(explode)<len(labels_legend):
|
5266
|
+
explode.extend([0]*(len(labels_legend)-len(explode)))
|
5267
|
+
print(explode)
|
5029
5268
|
if fmt:
|
5030
5269
|
if not fmt.startswith("%"):
|
5031
5270
|
autopct =f"%{fmt}%%"
|
@@ -5073,19 +5312,37 @@ def pie(
|
|
5073
5312
|
elif len(result) == 2:
|
5074
5313
|
wedges, texts = result
|
5075
5314
|
autotexts = None
|
5076
|
-
|
5077
|
-
|
5078
|
-
|
5079
|
-
|
5080
|
-
|
5081
|
-
|
5082
|
-
|
5083
|
-
|
5084
|
-
|
5085
|
-
|
5086
|
-
|
5087
|
-
|
5088
|
-
|
5315
|
+
#! adjust_text
|
5316
|
+
if autotexts or texts:
|
5317
|
+
all_texts = []
|
5318
|
+
if autotexts and show_value:
|
5319
|
+
all_texts.extend(autotexts)
|
5320
|
+
if texts and show_label:
|
5321
|
+
all_texts.extend(texts)
|
5322
|
+
|
5323
|
+
adjust_text(
|
5324
|
+
all_texts,
|
5325
|
+
ax=ax,
|
5326
|
+
arrowprops=kws_arrow,#dict(arrowstyle="-", color="gray", lw=0.5),
|
5327
|
+
bbox=kws_bbox if kws_bbox else None,
|
5328
|
+
expand=expand_label,
|
5329
|
+
fontdict={
|
5330
|
+
"fontsize": fontsize,
|
5331
|
+
"color": fontcolor,
|
5332
|
+
},
|
5333
|
+
)
|
5334
|
+
# Show exact values on wedges if show_value is True
|
5335
|
+
if show_value:
|
5336
|
+
for i, (wedge, txt) in enumerate(zip(wedges, texts)):
|
5337
|
+
angle = (wedge.theta2 - wedge.theta1) / 2 + wedge.theta1
|
5338
|
+
x = np.cos(np.radians(angle)) * (pctdistance ) * radius_
|
5339
|
+
y = np.sin(np.radians(angle)) * (pctdistance ) * radius_
|
5340
|
+
if not fmt.startswith("{"):
|
5341
|
+
value_text = f"{sizes[i]:{fmt}}"
|
5342
|
+
else:
|
5343
|
+
value_text = fmt.format(sizes[i])
|
5344
|
+
ax.text(x, y, value_text, ha="center", va="center", fontsize=fontsize,color=fontcolor)
|
5345
|
+
inested+=1
|
5089
5346
|
# Customize the legend
|
5090
5347
|
if show_legend:
|
5091
5348
|
ax.legend(
|
@@ -5098,3 +5355,521 @@ def pie(
|
|
5098
5355
|
)
|
5099
5356
|
ax.set(aspect="equal")
|
5100
5357
|
return ax
|
5358
|
+
|
5359
|
+
def ellipse(
|
5360
|
+
data,
|
5361
|
+
x=None,
|
5362
|
+
y=None,
|
5363
|
+
hue=None,
|
5364
|
+
n_std=1.5,
|
5365
|
+
ax=None,
|
5366
|
+
confidence=0.95,
|
5367
|
+
annotate_center=False,
|
5368
|
+
palette=None,
|
5369
|
+
facecolor=None,
|
5370
|
+
edgecolor=None,
|
5371
|
+
label:bool=True,
|
5372
|
+
**kwargs,
|
5373
|
+
):
|
5374
|
+
"""
|
5375
|
+
Plot advanced ellipses representing covariance for different groups
|
5376
|
+
# simulate data:
|
5377
|
+
control = np.random.multivariate_normal([0, 0], [[1, 0.5], [0.5, 1]], size=50)
|
5378
|
+
patient = np.random.multivariate_normal([2, 1], [[1, -0.3], [-0.3, 1]], size=50)
|
5379
|
+
df = pd.DataFrame(
|
5380
|
+
{
|
5381
|
+
"Dim1": np.concatenate([control[:, 0], patient[:, 0]]),
|
5382
|
+
"Dim2": np.concatenate([control[:, 1], patient[:, 1]]),
|
5383
|
+
"Group": ["Control"] * 50 + ["Patient"] * 50,
|
5384
|
+
}
|
5385
|
+
)
|
5386
|
+
plotxy(
|
5387
|
+
data=df,
|
5388
|
+
x="Dim1",
|
5389
|
+
y="Dim2",
|
5390
|
+
hue="Group",
|
5391
|
+
kind_="scatter",
|
5392
|
+
palette=get_color(8),
|
5393
|
+
)
|
5394
|
+
ellipse(
|
5395
|
+
data=df,
|
5396
|
+
x="Dim1",
|
5397
|
+
y="Dim2",
|
5398
|
+
hue="Group",
|
5399
|
+
palette=get_color(8),
|
5400
|
+
alpha=0.1,
|
5401
|
+
lw=2,
|
5402
|
+
)
|
5403
|
+
Parameters:
|
5404
|
+
data (DataFrame): Input DataFrame with columns for x, y, and hue.
|
5405
|
+
x (str): Column name for x-axis values.
|
5406
|
+
y (str): Column name for y-axis values.
|
5407
|
+
hue (str, optional): Column name for group labels.
|
5408
|
+
n_std (float): Number of standard deviations for the ellipse (overridden if confidence is provided).
|
5409
|
+
ax (matplotlib.axes.Axes, optional): Matplotlib Axes object to plot on. Defaults to current Axes.
|
5410
|
+
confidence (float, optional): Confidence level (e.g., 0.95 for 95% confidence interval).
|
5411
|
+
annotate_center (bool): Whether to annotate the ellipse center (mean).
|
5412
|
+
palette (dict or list, optional): A mapping of hues to colors or a list of colors.
|
5413
|
+
**kwargs: Additional keyword arguments for the Ellipse patch.
|
5414
|
+
|
5415
|
+
Returns:
|
5416
|
+
list: List of Ellipse objects added to the Axes.
|
5417
|
+
"""
|
5418
|
+
from matplotlib.patches import Ellipse
|
5419
|
+
import numpy as np
|
5420
|
+
import matplotlib.pyplot as plt
|
5421
|
+
import seaborn as sns
|
5422
|
+
import pandas as pd
|
5423
|
+
from scipy.stats import chi2
|
5424
|
+
|
5425
|
+
if ax is None:
|
5426
|
+
ax = plt.gca()
|
5427
|
+
|
5428
|
+
# Validate inputs
|
5429
|
+
if x is None or y is None:
|
5430
|
+
raise ValueError(
|
5431
|
+
"Both `x` and `y` must be specified as column names in the DataFrame."
|
5432
|
+
)
|
5433
|
+
if not isinstance(data, pd.DataFrame):
|
5434
|
+
raise ValueError("`data` must be a pandas DataFrame.")
|
5435
|
+
|
5436
|
+
# Prepare data for hue-based grouping
|
5437
|
+
ellipses = []
|
5438
|
+
if hue is not None:
|
5439
|
+
groups = data[hue].unique()
|
5440
|
+
colors = sns.color_palette(palette or "husl", len(groups))
|
5441
|
+
color_map = dict(zip(groups, colors))
|
5442
|
+
else:
|
5443
|
+
groups = [None]
|
5444
|
+
color_map = {None: kwargs.get("edgecolor", "blue")}
|
5445
|
+
alpha = kwargs.pop("alpha", 0.2)
|
5446
|
+
edgecolor=kwargs.pop("edgecolor", None)
|
5447
|
+
facecolor=kwargs.pop("facecolor", None)
|
5448
|
+
for group in groups:
|
5449
|
+
group_data = data[data[hue] == group] if hue else data
|
5450
|
+
|
5451
|
+
# Extract x and y columns for the group
|
5452
|
+
group_points = group_data[[x, y]].values
|
5453
|
+
|
5454
|
+
# Compute mean and covariance matrix
|
5455
|
+
# # 标准化处理
|
5456
|
+
# group_points = group_data[[x, y]].values
|
5457
|
+
# group_points -= group_points.mean(axis=0)
|
5458
|
+
# group_points /= group_points.std(axis=0)
|
5459
|
+
|
5460
|
+
cov = np.cov(group_points.T)
|
5461
|
+
mean = np.mean(group_points, axis=0)
|
5462
|
+
|
5463
|
+
# Eigenvalues and eigenvectors
|
5464
|
+
eigvals, eigvecs = np.linalg.eigh(cov)
|
5465
|
+
order = eigvals.argsort()[::-1]
|
5466
|
+
eigvals, eigvecs = eigvals[order], eigvecs[:, order]
|
5467
|
+
|
5468
|
+
# Rotation angle and ellipse dimensions
|
5469
|
+
angle = np.degrees(np.arctan2(*eigvecs[:, 0][::-1]))
|
5470
|
+
if confidence:
|
5471
|
+
n_std = np.sqrt(chi2.ppf(confidence, df=2)) # Chi-square quantile
|
5472
|
+
width, height = 2 * n_std * np.sqrt(eigvals)
|
5473
|
+
|
5474
|
+
# Create and style the ellipse
|
5475
|
+
if facecolor is None:
|
5476
|
+
facecolor_ = color_map[group]
|
5477
|
+
if edgecolor is None:
|
5478
|
+
edgecolor_ = color_map[group]
|
5479
|
+
ellipse = Ellipse(
|
5480
|
+
xy=mean,
|
5481
|
+
width=width,
|
5482
|
+
height=height,
|
5483
|
+
angle=angle,
|
5484
|
+
edgecolor=edgecolor_,
|
5485
|
+
facecolor=(facecolor_, alpha), #facecolor_, # only work on facecolor
|
5486
|
+
# alpha=alpha,
|
5487
|
+
label=group if (hue and label) else None,
|
5488
|
+
**kwargs,
|
5489
|
+
)
|
5490
|
+
ax.add_patch(ellipse)
|
5491
|
+
ellipses.append(ellipse)
|
5492
|
+
|
5493
|
+
# Annotate center
|
5494
|
+
if annotate_center:
|
5495
|
+
ax.annotate(
|
5496
|
+
f"Mean\n({mean[0]:.2f}, {mean[1]:.2f})",
|
5497
|
+
xy=mean,
|
5498
|
+
xycoords="data",
|
5499
|
+
fontsize=10,
|
5500
|
+
ha="center",
|
5501
|
+
color=ellipse_color,
|
5502
|
+
bbox=dict(
|
5503
|
+
boxstyle="round,pad=0.3",
|
5504
|
+
edgecolor="gray",
|
5505
|
+
facecolor="white",
|
5506
|
+
alpha=0.8,
|
5507
|
+
),
|
5508
|
+
)
|
5509
|
+
|
5510
|
+
return ax
|
5511
|
+
|
5512
|
+
def _ppi_fit_length(values, n):
    """Truncate a non-empty list to ``n`` items, or pad it with its last item up to ``n``."""
    # When len(values) > n the padding multiplier is negative, which yields [].
    return values[:n] + [values[-1]] * (n - len(values))


def _ppi_rescale(values, lo, hi):
    """Min-max rescale ``values`` into ``[lo, hi]``; constant input maps to the midpoint."""
    vmin, vmax = min(values), max(values)
    if vmax > vmin:  # avoid division by zero
        return [lo + (hi - lo) * (v - vmin) / (vmax - vmin) for v in values]
    return [(lo + hi) / 2] * len(values)


def _ppi_per_item(values, n, bounds, empty_value):
    """Expand a scalar or list into ``n`` per-item values.

    Scalars are repeated unchanged; non-empty lists are fitted to ``n`` items
    and min-max rescaled into ``bounds``; an empty list yields ``empty_value``.
    """
    if not isinstance(values, list):
        return [values] * n
    if not values:
        return [empty_value] * n
    return _ppi_rescale(_ppi_fit_length(values, n), *bounds)


def _ppi_hex_color(color):
    """Convert a matplotlib color (e.g. an RGBA tuple from a colormap) to a hex
    string usable by pyvis/vis.js; values matplotlib cannot parse pass through."""
    from matplotlib.colors import to_hex

    try:
        return to_hex(color)
    except (ValueError, TypeError):
        return color


def ppi(
    interactions,
    player1="preferredName_A",
    player2="preferredName_B",
    weight="score",
    n_layers=None,  # Number of concentric layers
    n_rank=[5, 10],  # Nodes in each rank for the concentric layout
    dist_node=10,  # Distance between each rank of circles
    layout="degree",
    size=None,  # 700,
    sizes=(50, 500),  # min and max of size
    facecolor="skyblue",
    cmap="coolwarm",
    edgecolor="k",
    edgelinewidth=1.5,
    alpha=0.5,
    alphas=(0.1, 1.0),  # min and max of alpha
    marker="o",
    node_hideticks=True,
    linecolor="gray",
    line_cmap="coolwarm",
    linewidth=1.5,
    linewidths=(0.5, 5),  # min and max of linewidth
    linealpha=1.0,
    linealphas=(0.1, 1.0),  # min and max of linealpha
    linestyle="-",
    line_arrowstyle="-",
    fontsize=10,
    fontcolor="k",
    ha: str = "center",
    va: str = "center",
    figsize=(12, 10),
    k_value=0.3,
    bgcolor="w",
    dir_save="./ppi_network.html",
    physics=True,
    notebook=False,
    scale=1,
    ax=None,
    **kwargs,
):
    """
    Plot a Protein-Protein Interaction (PPI) network with adjustable appearance.

    The network is drawn with networkx on a matplotlib Axes and, when
    ``dir_save`` is set, also exported as an interactive pyvis HTML page, a
    GraphML file (for Cytoscape) and a PDF of the figure.

    Example:
        ppi(
            interactions_sort.iloc[:1000, :],
            player1="player1",
            player2="player2",
            weight="count",
            layout="spring",
            n_layers=13,
            fontsize=1,
            n_rank=[5, 10, 20, 40, 80, 80, 80, 80, 80, 80, 80, 80],
        )

    Parameters:
        interactions (DataFrame): Edge table; must contain ``player1``,
            ``player2`` and ``weight`` columns. It is not modified.
        player1, player2 (str): Column names of the two interacting nodes.
        weight (str): Column name of the edge weight.
        n_layers (int, optional): Number of concentric rings for the
            ``"degree"`` layout; defaults to ``len(n_rank) + 1``.
        n_rank (list[int]): Number of nodes per ring (highest degree first)
            for the ``"degree"`` layout. Nodes left over after ``n_layers``
            rings are placed on one extra outer ring.
        dist_node (float): Radial distance between consecutive rings.
        layout (str): One of "spring", "circular", "kamada_kawai",
            "spectral", "random", "shell", "planar", "spiral", "degree"
            (fuzzy-matched).
        size (number or list, optional): Node size(s). If not given, sizes
            are derived from node degree. Lists are fitted to the node count
            and rescaled into ``sizes``.
        sizes (tuple): (min, max) range used to rescale list sizes.
        facecolor: Node color. If it is not recognised as a color, nodes are
            colored by degree using ``cmap``.
        edgecolor, edgelinewidth: Node border color / width.
        alpha (number or list): Node alpha; lists are rescaled into ``alphas``.
        marker (str): Node marker shape.
        node_hideticks (bool): Forwarded as ``hide_ticks`` to networkx.
        linecolor: Edge color; a non-string value colors edges by weight
            using ``line_cmap``.
        linewidth (number or list): Edge width. A scalar is used as-is; a
            list is rescaled into ``linewidths``.
        linealpha (number or list): Edge alpha; lists are rescaled into
            ``linealphas``.
        linestyle, line_arrowstyle: Edge style options passed to networkx.
        fontsize, fontcolor, ha, va: Node label appearance.
        figsize (tuple): Figure size when a new Axes is created.
        k_value (float): ``k`` parameter of the spring layout.
        bgcolor (str): Background color of the pyvis page.
        dir_save (str): Output path. ".html" is added if missing; ".graphml"
            and ".pdf" files are written next to it. A directory path saves
            as "_.html" inside that directory. Falsy disables saving.
        physics (bool): Enable physics simulation in the pyvis page.
        notebook (bool): Currently unused; kept for API compatibility.
        scale (float): Multiplier applied to layout coordinates for pyvis.
        ax (matplotlib.axes.Axes, optional): Axes to draw on.
        **kwargs: A key containing "figset" supplies kwargs for ``figsets``.

    Returns:
        tuple: (networkx.Graph, matplotlib.axes.Axes)

    Raises:
        ValueError: If a required column is missing or there are no edges.
    """
    import os

    import matplotlib.pyplot as plt
    import networkx as nx
    import numpy as np
    from matplotlib.colors import Normalize
    from pyvis.network import Network

    from . import ips

    if run_once_within():
        usage_str = """
        ppi(
            interactions,
            player1="preferredName_A",
            player2="preferredName_B",
            weight="score",
            n_layers=None,  # Number of concentric layers
            n_rank=[5, 10],  # Nodes in each rank for the concentric layout
            dist_node = 10,  # Distance between each rank of circles
            layout="degree",
            size=None,#700,
            sizes=(50,500),# min and max of size
            facecolor="skyblue",
            cmap='coolwarm',
            edgecolor="k",
            edgelinewidth=1.5,
            alpha=.5,
            alphas=(0.1, 1.0),# min and max of alpha
            marker="o",
            node_hideticks=True,
            linecolor="gray",
            line_cmap='coolwarm',
            linewidth=1.5,
            linewidths=(0.5,5),# min and max of linewidth
            linealpha=1.0,
            linealphas=(0.1,1.0),# min and max of linealpha
            linestyle="-",
            line_arrowstyle='-',
            fontsize=10,
            fontcolor="k",
            ha:str="center",
            va:str="center",
            figsize=(12, 10),
            k_value=0.3,
            bgcolor="w",
            dir_save="./ppi_network.html",
            physics=True,
            notebook=False,
            scale=1,
            ax=None,
            **kwargs
        ):
        """
        print(usage_str)

    # Check for required columns in the DataFrame
    for col in (player1, player2, weight):
        if col not in interactions.columns:
            raise ValueError(f"Column '{col}' is missing from the interactions DataFrame.")
    # Sort a copy: the caller's DataFrame must not be mutated.
    interactions = interactions.sort_values(by=[weight])

    # Initialize Pyvis network
    net = Network(height="750px", width="100%", bgcolor=bgcolor, font_color=fontcolor)
    net.force_atlas_2based(
        gravity=-50, central_gravity=0.01, spring_length=100, spring_strength=0.1
    )
    net.toggle_physics(physics)

    # Pull the first "figset*" kwarg out for figsets(); iterate over a key
    # snapshot so popping cannot disturb the iteration.
    kws_figsets = {}
    for k_arg in list(kwargs):
        if "figset" in k_arg:
            kws_figsets = kwargs.pop(k_arg)
            break

    # Create a NetworkX graph from the interaction data
    G = nx.Graph()
    for node_a, node_b, w in zip(
        interactions[player1], interactions[player2], interactions[weight]
    ):
        G.add_edge(node_a, node_b, weight=w)
    if G.number_of_nodes() == 0:
        raise ValueError("The interactions DataFrame contains no edges to plot.")

    num_nodes = G.number_of_nodes()
    n_edges = G.number_of_edges()

    # Node degrees, iterated in G.nodes() order
    degrees = dict(G.degree())
    norm = Normalize(vmin=min(degrees.values()), vmax=max(degrees.values()))
    colormap = plt.get_cmap(cmap)  # cm.get_cmap was removed in matplotlib 3.9

    # * facecolor
    if not ips.isa(facecolor, "color"):
        print("facecolor: based on degrees")
        facecolor = [colormap(norm(deg)) for deg in degrees.values()]
    if isinstance(facecolor, list) and facecolor:
        facecolor = _ppi_fit_length(facecolor, num_nodes)
    else:
        facecolor = [facecolor] * num_nodes

    # * size
    if not isinstance(size, (int, float, list)) or size == []:
        print("size: based on degrees")
        size = [deg * 50 for deg in degrees.values()]
    size = _ppi_fit_length(size, num_nodes) if isinstance(size, list) else [size] * num_nodes
    if len(ips.flatten(size, verbose=False)) != 1:
        size = _ppi_rescale(size, *sizes)

    # * facealpha: lists are rescaled into `alphas`, scalars used as-is
    alpha = _ppi_per_item(alpha, num_nodes, alphas, empty_value=1.0)

    for i, node in enumerate(G.nodes()):
        net.add_node(
            node,
            label=node,
            size=size[i],
            # vis.js expects color strings, not RGBA tuples from a colormap
            color=_ppi_hex_color(facecolor[i]),
            alpha=alpha[i],
            font={"size": fontsize, "color": fontcolor},
        )
    print(f"nodes number: {num_nodes}")

    for src, dst, attrs in G.edges(data=True):
        net.add_edge(
            src,
            dst,
            weight=attrs["weight"],
            color=edgecolor,
            width=edgelinewidth * attrs["weight"],
        )

    layouts = [
        "spring",
        "circular",
        "kamada_kawai",
        "spectral",
        "random",
        "shell",
        "planar",
        "spiral",
        "degree",
    ]
    layout = ips.strcmp(layout, layouts)[0]
    print(f"layout:{layout}, or select one in {layouts}")

    # Choose layout
    if layout == "spring":
        pos = nx.spring_layout(G, k=k_value)
    elif layout == "circular":
        pos = nx.circular_layout(G)
    elif layout == "kamada_kawai":
        pos = nx.kamada_kawai_layout(G)
    elif layout == "spectral":
        pos = nx.spectral_layout(G)
    elif layout == "random":
        pos = nx.random_layout(G)
    elif layout == "shell":
        pos = nx.shell_layout(G)
    elif layout == "planar":
        if nx.check_planarity(G)[0]:
            pos = nx.planar_layout(G)
        else:
            print("Graph is not planar; switching to spring layout.")
            pos = nx.spring_layout(G, k=k_value)
    elif layout == "spiral":
        pos = nx.spiral_layout(G)
    elif layout == "degree":
        # Concentric rings: highest-degree nodes on the innermost ring.
        sorted_nodes = sorted(degrees.items(), key=lambda item: item[1], reverse=True)
        pos = {}
        n_layers = len(n_rank) + 1 if n_layers is None else n_layers
        for rank_index in range(n_layers):
            start = sum(n_rank[:rank_index])
            if rank_index < len(n_rank):
                rank_nodes = sorted_nodes[start : start + n_rank[rank_index]]
            else:
                # Randomly shuffle the order of the remaining nodes
                remaining = sorted_nodes[start:]
                rank_nodes = [remaining[j] for j in np.random.permutation(len(remaining))]

            radius = (rank_index + 1) * dist_node  # Radius for this rank
            for j, (node, _deg) in enumerate(rank_nodes):
                theta = (j / len(rank_nodes)) * 2 * np.pi  # Distribute around circle
                pos[node] = (radius * np.cos(theta), radius * np.sin(theta))

        # Nodes not covered by the first n_layers rings would otherwise have no
        # position (and crash the drawing); put them on one extra outer ring.
        unplaced = [node for node, _deg in sorted_nodes if node not in pos]
        if unplaced:
            radius = (n_layers + 1) * dist_node
            for j, node in enumerate(unplaced):
                theta = (j / len(unplaced)) * 2 * np.pi
                pos[node] = (radius * np.cos(theta), radius * np.sin(theta))
    else:
        print(f"Unknown layout '{layout}', defaulting to 'spring',or可以用这些: {layouts}")
        pos = nx.spring_layout(G, k=k_value)

    for node, (px, py) in pos.items():
        net_node = net.get_node(node)
        net_node["x"] = px * scale
        net_node["y"] = py * scale

    # If ax is None, create a new figure/axes
    if ax is None:
        _, ax = plt.subplots(1, 1, figsize=figsize)

    # Draw nodes, edges, and labels with customization options
    nx.draw_networkx_nodes(
        G,
        pos,
        ax=ax,
        node_size=size,
        node_color=facecolor,
        linewidths=edgelinewidth,
        edgecolors=edgecolor,
        alpha=alpha,
        hide_ticks=node_hideticks,
        node_shape=marker,
    )

    # * linewidth: a scalar is honoured as-is (previously it was always
    # replaced by the midpoint of `linewidths`); lists are rescaled.
    linewidth = _ppi_per_item(
        linewidth, n_edges, linewidths, empty_value=(linewidths[0] + linewidths[1]) / 2
    )

    # * linecolor: non-string values mean "color edges by weight"
    if not isinstance(linecolor, str):
        weights = [G[u][v]["weight"] for u, v in G.edges()]
        edge_norm = Normalize(vmin=min(weights), vmax=max(weights))
        edge_cmap = plt.get_cmap(line_cmap)
        linecolor = [edge_cmap(edge_norm(w)) for w in weights]
    else:
        linecolor = [linecolor] * n_edges

    # * linealpha: an empty list falls back to full opacity
    linealpha = _ppi_per_item(linealpha, n_edges, linealphas, empty_value=1.0)

    nx.draw_networkx_edges(
        G,
        pos,
        ax=ax,
        edge_color=linecolor,
        width=linewidth,
        style=linestyle,
        arrowstyle=line_arrowstyle,
        alpha=linealpha,
    )

    nx.draw_networkx_labels(
        G,
        pos,
        ax=ax,
        font_size=fontsize,
        font_color=fontcolor,
        horizontalalignment=ha,
        verticalalignment=va,
    )
    figsets(ax=ax, **kws_figsets)
    ax.axis("off")

    if dir_save:
        if not os.path.basename(dir_save):
            # A directory was given: save inside it rather than in the cwd.
            if not os.path.isdir(dir_save):
                os.makedirs(dir_save, exist_ok=True)
            dir_save = os.path.join(dir_save, "_.html")
        # Derive sibling paths from the stem so the HTML, GraphML and PDF can
        # never collide (str.replace silently did nothing without ".html").
        root, ext = os.path.splitext(dir_save)
        if ext.lower() != ".html":
            root = dir_save
        html_path = root + ".html"
        graphml_path = root + ".graphml"
        pdf_path = root + ".pdf"

        net.write_html(html_path)
        nx.write_graphml(G, graphml_path)  # Export to GraphML
        print(f"could be edited in Cytoscape \n{graphml_path}")
        ips.figsave(pdf_path)
    return G, ax
|