py2ls 0.2.4.24__py3-none-any.whl → 0.2.4.26__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
py2ls/plot.py CHANGED
@@ -20,7 +20,9 @@ from .ips import (
20
20
  flatten,
21
21
  plt_font,
22
22
  run_once_within,
23
- df_format
23
+ df_format,
24
+ df_corr,
25
+ df_scaler
24
26
  )
25
27
  import scipy.stats as scipy_stats
26
28
  from .stats import *
@@ -142,20 +144,37 @@ def add_text(ax=None, height_offset=0.5, fmt=".1f", **kwargs):
142
144
  **kwargs,
143
145
  )
144
146
 
147
def pval2str(p):
    """Convert a p-value into a significance-star annotation string.

    Thresholds (same as before):
        p > 0.05            -> ""     (not significant)
        0.01 <= p <= 0.05   -> "*"
        0.001 <= p < 0.01   -> "**"
        p < 0.001           -> "***"

    Args:
        p: A numeric p-value.

    Returns:
        str: The star annotation. Returns "" for non-significant values and
        also for missing p-values (None or NaN).
    """
    # NaN compares False against every threshold. The previous if/elif
    # cascade therefore never assigned its result and raised
    # UnboundLocalError. Missing values are treated as "no annotation".
    if p is None or p != p:
        return ""
    if p > 0.05:
        return ""
    if p >= 0.01:
        return "*"
    if p >= 0.001:
        return "**"
    return "***"
145
157
 
146
158
  def heatmap(
147
159
  data,
148
160
  ax=None,
149
161
  kind="corr", #'corr','direct','pivot'
162
+ method="pearson",# for correlation: ‘pearson’(default), ‘kendall’, ‘spearman’
150
163
  columns="all", # pivot, default: coll numeric columns
164
+ style=0,# for correlation
151
165
  index=None, # pivot
152
166
  values=None, # pivot
167
+ fontsize=10,
153
168
  tri="u",
154
169
  mask=True,
155
170
  k=1,
171
+ vmin=None,
172
+ vmax=None,
173
+ size_scale=500,
156
174
  annot=True,
157
175
  cmap="coolwarm",
158
176
  fmt=".2f",
177
+ show_indicator = True,# only for style==1
159
178
  cluster=False,
160
179
  inplace=False,
161
180
  figsize=(10, 8),
@@ -201,7 +220,7 @@ def heatmap(
201
220
  ax = plt.gca()
202
221
  # Select numeric columns or specific subset of columns
203
222
  if columns == "all":
204
- df_numeric = data.select_dtypes(include=[float, int])
223
+ df_numeric = data.select_dtypes(include=[np.number])
205
224
  else:
206
225
  df_numeric = data[columns]
207
226
 
@@ -209,8 +228,10 @@ def heatmap(
209
228
  kind = strcmp(kind, kinds)[0]
210
229
  print(kind)
211
230
  if "corr" in kind: # correlation
231
+ methods = ["pearson", "spearman", "kendall"]
232
+ method = strcmp(method, methods)[0]
212
233
  # Compute the correlation matrix
213
- data4heatmap = df_numeric.corr()
234
+ data4heatmap = df_numeric.corr(method=method)
214
235
  # Generate mask for the upper triangle if mask is True
215
236
  if mask:
216
237
  if "u" in tri.lower(): # upper => np.tril
@@ -288,18 +309,76 @@ def heatmap(
288
309
  df_col_cluster,
289
310
  )
290
311
  else:
291
- # Create a standard heatmap
292
- ax = sns.heatmap(
293
- data4heatmap,
294
- ax=ax,
295
- mask=mask_array,
296
- annot=annot,
297
- cmap=cmap,
298
- fmt=fmt,
299
- **kwargs, # Pass any additional arguments to sns.heatmap
300
- )
301
- # Return the Axes object for further customization if needed
302
- return ax
312
+ if style==0:
313
+ # Create a standard heatmap
314
+ ax = sns.heatmap(
315
+ data4heatmap,
316
+ ax=ax,
317
+ mask=mask_array,
318
+ annot=annot,
319
+ cmap=cmap,
320
+ fmt=fmt,
321
+ **kwargs, # Pass any additional arguments to sns.heatmap
322
+ )
323
+ return ax
324
+ elif style==1:
325
+ if isinstance(cmap, str):
326
+ cmap = plt.get_cmap(cmap)
327
+ norm = plt.Normalize(vmin=-1, vmax=1)
328
+ r_, p_ = df_corr(data4heatmap, method=method)
329
+ # size_r_norm=df_scaler(data=r_, method="minmax", vmin=-1,vmax=1)
330
+ # 初始化一个空的可绘制对象用于颜色条
331
+ scatter_handles = []
332
+ # 循环绘制气泡图和数值
333
+ for i in range(len(r_.columns)):
334
+ for j in range(len(r_.columns)):
335
+ if (i < j) if "u" in tri.lower() else (j<i): # 对角线左上部只显示气泡
336
+ color = cmap(norm(r_.iloc[i, j])) # 根据相关系数获取颜色
337
+ scatter = ax.scatter(
338
+ i, j, s=np.abs(r_.iloc[i, j])*size_scale, color=color,
339
+ # alpha=1,edgecolor=edgecolor,linewidth=linewidth,
340
+ **kwargs
341
+ )
342
+ scatter_handles.append(scatter) # 保存scatter对象用于颜色条
343
+ # add *** indicators
344
+ if show_indicator:
345
+ ax.text(
346
+ i,
347
+ j,
348
+ pval2str(p_.iloc[i, j]),
349
+ ha="center",
350
+ va="center",
351
+ color="k",
352
+ fontsize=fontsize * 1.3,
353
+ )
354
+ elif (i > j) if "u" in tri.lower() else (j>i): # 对角只显示数值
355
+ color = cmap(norm(r_.iloc[i, j])) # 数值的颜色同样基于相关系数
356
+ ax.text(
357
+ i,
358
+ j,
359
+ f"{r_.iloc[i, j]:{fmt}}",
360
+ ha="center",
361
+ va="center",
362
+ color=color,
363
+ fontsize=fontsize,
364
+ )
365
+ else: # 对角线部分,显示空白
366
+ ax.scatter(i, j, s=1, color="white")
367
+ # 设置坐标轴标签
368
+ figsets(xticks=range(len(r_.columns)),
369
+ xticklabels=r_.columns,
370
+ xangle=90,
371
+ fontsize=fontsize,
372
+ yticks=range(len(r_.columns)),
373
+ yticklabels=r_.columns,
374
+ xlim=[-0.5,len(r_.columns)-0.5],
375
+ ylim=[-0.5,len(r_.columns)-0.5]
376
+ )
377
+ # 添加颜色条
378
+ sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
379
+ sm.set_array([]) # 仅用于显示颜色条
380
+ plt.colorbar(sm, ax=ax, label="Correlation Coefficient")
381
+ return ax
303
382
  elif "dir" in kind: # direct
304
383
  data4heatmap = df_numeric
305
384
  elif "pi" in kind: # pivot
@@ -389,16 +468,53 @@ def heatmap(
389
468
  )
390
469
  else:
391
470
  # Create a standard heatmap
392
- ax = sns.heatmap(
393
- data4heatmap,
394
- ax=ax,
395
- annot=annot,
396
- cmap=cmap,
397
- fmt=fmt,
398
- **kwargs, # Pass any additional arguments to sns.heatmap
399
- )
400
- # Return the Axes object for further customization if needed
401
- return ax
471
+ if style==0:
472
+ ax = sns.heatmap(
473
+ data4heatmap,
474
+ ax=ax,
475
+ annot=annot,
476
+ cmap=cmap,
477
+ fmt=fmt,
478
+ **kwargs, # Pass any additional arguments to sns.heatmap
479
+ )
480
+ # Return the Axes object for further customization if needed
481
+ return ax
482
+ elif style==1:
483
+ if isinstance(cmap, str):
484
+ cmap = plt.get_cmap(cmap)
485
+ if vmin is None:
486
+ vmin=np.min(data4heatmap)
487
+ if vmax is None:
488
+ vmax=np.max(data4heatmap)
489
+ norm = plt.Normalize(vmin=vmin, vmax=vmax)
490
+
491
+ # 初始化一个空的可绘制对象用于颜色条
492
+ scatter_handles = []
493
+ # 循环绘制气泡图和数值
494
+ print(len(data4heatmap.index),len(data4heatmap.columns))
495
+ for i in range(len(data4heatmap.index)):
496
+ for j in range(len(data4heatmap.columns)):
497
+ color = cmap(norm(data4heatmap.iloc[i, j])) # 根据相关系数获取颜色
498
+ scatter = ax.scatter(j,i, s=np.abs(data4heatmap.iloc[i, j])*size_scale, color=color, **kwargs)
499
+ scatter_handles.append(scatter) # 保存scatter对象用于颜色条
500
+
501
+ # 设置坐标轴标签
502
+ figsets(xticks=range(len(data4heatmap.columns)),
503
+ xticklabels=data4heatmap.columns,
504
+ xangle=90,
505
+ fontsize=fontsize,
506
+ yticks=range(len(data4heatmap.index)),
507
+ yticklabels=data4heatmap.index,
508
+ xlim=[-0.5,len(data4heatmap.columns)-0.5],
509
+ ylim=[-0.5,len(data4heatmap.index)-0.5]
510
+ )
511
+ # 添加颜色条
512
+ sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
513
+ sm.set_array([]) # 仅用于显示颜色条
514
+ plt.colorbar(sm, ax=ax,
515
+ # label="Correlation Coefficient"
516
+ )
517
+ return ax
402
518
 
403
519
 
404
520
  # !usage: py2ls.plot.heatmap()
@@ -2475,7 +2591,8 @@ def get_color(
2475
2591
  "#B25E9D",
2476
2592
  "#4B8C3B",
2477
2593
  "#EF8632",
2478
- "#24578E" "#FF2C00",
2594
+ "#24578E",
2595
+ "#FF2C00",
2479
2596
  ]
2480
2597
  elif n == 8:
2481
2598
  # colorlist = ['#1f77b4','#ff7f0e','#367B7F','#51B34F','#d62728','#aa40fc','#e377c2','#17becf']
@@ -3199,9 +3316,12 @@ def plotxy(
3199
3316
  zorder = 0
3200
3317
  for k in kind_:
3201
3318
  # preprocess data
3202
- data=df_preprocessing_(data, kind=k)
3203
- if 'variable' in data.columns and 'value' in data.columns:
3204
- x,y='variable','value'
3319
+ try:
3320
+ data=df_preprocessing_(data, kind=k)
3321
+ if 'variable' in data.columns and 'value' in data.columns:
3322
+ x,y='variable','value'
3323
+ except Exception as e:
3324
+ print(e)
3205
3325
  zorder += 1
3206
3326
  # indicate 'col' features
3207
3327
  col = kwargs.get("col", None)
@@ -3222,18 +3342,37 @@ def plotxy(
3222
3342
  # (1) return FcetGrid
3223
3343
  if k == "jointplot":
3224
3344
  kws_joint = kwargs.pop("kws_joint", kwargs)
3225
- stats=kwargs.pop("stats",True)
3345
+ kws_joint = {
3346
+ k: v for k, v in kws_joint.items() if not k.startswith("kws_")
3347
+ }
3348
+ hue = kwargs.get("hue", None)
3349
+ if isinstance(kws_joint, dict) or hue is None: # Check if kws_ellipse is a dictionary
3350
+ kws_joint.pop("hue", None) # Safely remove 'hue' if it exists
3351
+
3352
+ palette = kwargs.get("palette", None)
3353
+ if palette is None:
3354
+ palette = kws_joint.pop(
3355
+ "palette", get_color(data[hue].nunique()) if hue is not None else None
3356
+ )
3357
+ else:
3358
+ kws_joint.pop("palette", palette)
3359
+ stats=kwargs.pop("stats",None)
3360
+ if stats:
3361
+ stats=kws_joint.pop("stats",True)
3226
3362
  if stats:
3227
3363
  r, p_value = scipy_stats.pearsonr(data[x], data[y])
3228
- g = sns.jointplot(data=data, x=x, y=y, **kws_joint)
3229
- g.ax_joint.annotate(
3230
- f"pearsonr = {r:.2f} p = {p_value:.3f}",
3231
- xy=(0.6, 0.98),
3232
- xycoords="axes fraction",
3233
- fontsize=12,
3234
- color="black",
3235
- ha="center",
3236
- )
3364
+ for key in ["palette", "alpha", "hue","stats"]:
3365
+ kws_joint.pop(key, None)
3366
+ g = sns.jointplot(data=data, x=x, y=y,hue=hue,palette=palette, **kws_joint)
3367
+ if stats:
3368
+ g.ax_joint.annotate(
3369
+ f"pearsonr = {r:.2f} p = {p_value:.3f}",
3370
+ xy=(0.6, 0.98),
3371
+ xycoords="axes fraction",
3372
+ fontsize=12,
3373
+ color="black",
3374
+ ha="center",
3375
+ )
3237
3376
  elif k == "lmplot":
3238
3377
  kws_lm = kwargs.pop("kws_lm", kwargs)
3239
3378
  stats = kwargs.pop("stats", True) # Flag to calculate stats
@@ -3379,8 +3518,7 @@ def plotxy(
3379
3518
  xycoords="axes fraction",
3380
3519
  fontsize=12,
3381
3520
  color="black",
3382
- ha="center",
3383
- )
3521
+ ha="center")
3384
3522
 
3385
3523
  elif k == "catplot_sns":
3386
3524
  kws_cat = kwargs.pop("kws_cat", kwargs)
@@ -3388,7 +3526,7 @@ def plotxy(
3388
3526
  elif k == "displot":
3389
3527
  kws_dis = kwargs.pop("kws_dis", kwargs)
3390
3528
  # displot creates a new figure and returns a FacetGrid
3391
- g = sns.displot(data=data, x=x, **kws_dis)
3529
+ g = sns.displot(data=data, x=x,y=y, **kws_dis)
3392
3530
 
3393
3531
  # (2) return axis
3394
3532
  if ax is None:
@@ -3399,19 +3537,58 @@ def plotxy(
3399
3537
  elif k == "stdshade":
3400
3538
  kws_stdshade = kwargs.pop("kws_stdshade", kwargs)
3401
3539
  ax = stdshade(ax=ax, **kwargs)
3540
+ elif k=="ellipse":
3541
+ kws_ellipse = kwargs.pop("kws_ellipse", kwargs)
3542
+ kws_ellipse = {
3543
+ k: v for k, v in kws_ellipse.items() if not k.startswith("kws_")
3544
+ }
3545
+ hue = kwargs.get("hue", None)
3546
+ if isinstance(kws_ellipse, dict) or hue is None: # Check if kws_ellipse is a dictionary
3547
+ kws_ellipse.pop("hue", None) # Safely remove 'hue' if it exists
3548
+
3549
+ palette = kwargs.get("palette", None)
3550
+ if palette is None:
3551
+ palette = kws_ellipse.pop(
3552
+ "palette", get_color(data[hue].nunique()) if hue is not None else None
3553
+ )
3554
+ alpha = kws_ellipse.pop("alpha", 0.1)
3555
+ hue_order = kwargs.get("hue_order",None)
3556
+ if hue_order is None:
3557
+ hue_order = kws_ellipse.get("hue_order",None)
3558
+ if hue_order:
3559
+ data["hue"] = pd.Categorical(data[hue], categories=hue_order, ordered=True)
3560
+ data = data.sort_values(by="hue")
3561
+ for key in ["palette", "alpha", "hue","hue_order"]:
3562
+ kws_ellipse.pop(key, None)
3563
+ ax=ellipse(
3564
+ ax=ax,
3565
+ data=data,
3566
+ x=x,
3567
+ y=y,
3568
+ hue=hue,
3569
+ palette=palette,
3570
+ alpha=alpha,
3571
+ zorder=zorder,
3572
+ **kws_ellipse,
3573
+ )
3402
3574
  elif k == "scatterplot":
3403
3575
  kws_scatter = kwargs.pop("kws_scatter", kwargs)
3404
3576
  kws_scatter = {
3405
3577
  k: v for k, v in kws_scatter.items() if not k.startswith("kws_")
3406
3578
  }
3407
- hue = kwargs.pop("hue", None)
3579
+ hue = kwargs.get("hue", None)
3408
3580
  if isinstance(kws_scatter, dict): # Check if kws_scatter is a dictionary
3409
3581
  kws_scatter.pop("hue", None) # Safely remove 'hue' if it exists
3410
- palette = kws_scatter.pop(
3411
- "palette", get_color(data[hue].nunique()) if hue is not None else None
3412
- )
3582
+ palette=kws_scatter.get("palette",None)
3583
+ if palette is None:
3584
+ palette = kws_scatter.pop(
3585
+ "palette", get_color(data[hue].nunique()) if hue is not None else None
3586
+ )
3413
3587
  s = kws_scatter.pop("s", 10)
3414
- alpha = kws_scatter.pop("alpha", 0.7)
3588
+ alpha = kws_scatter.pop("alpha", 0.7)
3589
+ for key in ["s", "palette", "alpha", "hue"]:
3590
+ kws_scatter.pop(key, None)
3591
+
3415
3592
  ax = sns.scatterplot(
3416
3593
  ax=ax,
3417
3594
  data=data,
@@ -3427,19 +3604,33 @@ def plotxy(
3427
3604
  elif k == "histplot":
3428
3605
  kws_hist = kwargs.pop("kws_hist", kwargs)
3429
3606
  kws_hist = {k: v for k, v in kws_hist.items() if not k.startswith("kws_")}
3430
- ax = sns.histplot(data=data, x=x, ax=ax, zorder=zorder, **kws_hist)
3431
- elif k == "kdeplot":
3607
+ ax = sns.histplot(data=data, x=x, y=y, ax=ax, zorder=zorder, **kws_hist)
3608
+ elif k == "kdeplot":
3432
3609
  kws_kde = kwargs.pop("kws_kde", kwargs)
3433
- kws_kde = {k: v for k, v in kws_kde.items() if not k.startswith("kws_")}
3434
- ax = sns.kdeplot(data=data, x=x, ax=ax, zorder=zorder, **kws_kde)
3610
+ kws_kde = {
3611
+ k: v for k, v in kws_kde.items() if not k.startswith("kws_")
3612
+ }
3613
+ hue = kwargs.get("hue", None)
3614
+ if isinstance(kws_kde, dict) or hue is None: # Check if kws_kde is a dictionary
3615
+ kws_kde.pop("hue", None) # Safely remove 'hue' if it exists
3616
+
3617
+ palette = kwargs.get("palette", None)
3618
+ if palette is None:
3619
+ palette = kws_kde.pop(
3620
+ "palette", get_color(data[hue].nunique()) if hue is not None else None
3621
+ )
3622
+ alpha = kws_kde.pop("alpha", 0.05)
3623
+ for key in ["palette", "alpha", "hue"]:
3624
+ kws_kde.pop(key, None)
3625
+ ax = sns.kdeplot(data=data, x=x, y=y, palette=palette,hue=hue, ax=ax,alpha=alpha, zorder=zorder, **kws_kde)
3435
3626
  elif k == "ecdfplot":
3436
3627
  kws_ecdf = kwargs.pop("kws_ecdf", kwargs)
3437
3628
  kws_ecdf = {k: v for k, v in kws_ecdf.items() if not k.startswith("kws_")}
3438
- ax = sns.ecdfplot(data=data, x=x, ax=ax, zorder=zorder, **kws_ecdf)
3629
+ ax = sns.ecdfplot(data=data, x=x,y=y, ax=ax, zorder=zorder, **kws_ecdf)
3439
3630
  elif k == "rugplot":
3440
3631
  kws_rug = kwargs.pop("kws_rug", kwargs)
3441
3632
  kws_rug = {k: v for k, v in kws_rug.items() if not k.startswith("kws_")}
3442
- ax = sns.rugplot(data=data, x=x, ax=ax, zorder=zorder, **kws_rug)
3633
+ ax = sns.rugplot(data=data, x=x, y=y, ax=ax, zorder=zorder, **kws_rug)
3443
3634
  elif k == "stripplot":
3444
3635
  kws_strip = kwargs.pop("kws_strip", kwargs)
3445
3636
  kws_strip = {k: v for k, v in kws_strip.items() if not k.startswith("kws_")}
@@ -3512,7 +3703,8 @@ def plotxy(
3512
3703
  figsets(ax=ax, **kws_figsets)
3513
3704
  if kws_add_text:
3514
3705
  add_text(ax=ax, **kws_add_text) if kws_add_text else None
3515
- if run_once_within(10):
3706
+ if run_once_within(10):
3707
+ for k in kind_:
3516
3708
  print(f"\n{k}⤵ ")
3517
3709
  print(default_settings[k])
3518
3710
  # print("=>\t",sns_info[sns_info["Functions"].str.contains(k)].iloc[:, -1].tolist()[0],"\n")
@@ -3570,6 +3762,7 @@ def df_preprocessing_(data, kind, verbose=False):
3570
3762
  "lineplot", # Can work with both wide and long formats
3571
3763
  "area plot", # Can work with both formats, useful for stacked areas
3572
3764
  "violinplot", # Can work with both formats depending on categorical vs continuous data
3765
+ "ellipse",# ellipse plot, default confidence=0.95
3573
3766
  ],
3574
3767
  )[0]
3575
3768
 
@@ -3605,6 +3798,7 @@ def df_preprocessing_(data, kind, verbose=False):
3605
3798
  "violinplot", # Can work with both formats depending on categorical vs continuous data
3606
3799
  "relplot",
3607
3800
  "pointplot", # Works well with wide format
3801
+ "ellipse",
3608
3802
  ]
3609
3803
 
3610
3804
  # Wide format (e.g., for heatmap and pairplot)
@@ -4096,25 +4290,45 @@ def venn(
4096
4290
  """
4097
4291
  if ax is None:
4098
4292
  ax = plt.gca()
4293
+ if isinstance(lists, dict):
4294
+ labels = list(lists.keys())
4295
+ lists = list(lists.values())
4296
+ if isinstance(lists[0], set):
4297
+ lists = [list(i) for i in lists]
4099
4298
  lists = [set(flatten(i, verbose=False)) for i in lists]
4100
4299
  # Function to apply text styles to labels
4101
4300
  if colors is None:
4102
4301
  colors = ["r", "b"] if len(lists) == 2 else ["r", "g", "b"]
4103
4302
  if labels is None:
4104
- labels = ["set1", "set2"] if len(lists) == 2 else ["set1", "set2", "set3"]
4303
+ if len(lists) == 2:
4304
+ labels = ["set1", "set2"]
4305
+ elif len(lists) == 3:
4306
+ labels = ["set1", "set2", "set3"]
4307
+ elif len(lists) == 4:
4308
+ labels = ["set1", "set2", "set3","set4"]
4309
+ elif len(lists) == 5:
4310
+ labels = ["set1", "set2", "set3","set4","set55"]
4311
+ elif len(lists) == 6:
4312
+ labels = ["set1", "set2", "set3","set4","set5","set6"]
4313
+ elif len(lists) == 7:
4314
+ labels = ["set1", "set2", "set3","set4","set5","set6","set7"]
4105
4315
  if edgecolor is None:
4106
4316
  edgecolor = colors
4107
4317
  colors = [desaturate_color(color, saturation) for color in colors]
4108
- # Check colors and auto-calculate overlaps
4109
- if len(lists) == 2:
4318
+ universe = len(set.union(*lists))
4110
4319
 
4111
- def get_count_and_percentage(set_count, subset_count):
4112
- percent = subset_count / set_count if set_count > 0 else 0
4113
- return (
4114
- f"{subset_count}\n({fmt.format(percent)})"
4115
- if show_percentages
4116
- else f"{subset_count}"
4117
- )
4320
+ # Check colors and auto-calculate overlaps
4321
+ def get_count_and_percentage(set_count, subset_count):
4322
+ percent = subset_count / set_count if set_count > 0 else 0
4323
+ return (
4324
+ f"{subset_count}\n({fmt.format(percent)})"
4325
+ if show_percentages
4326
+ else f"{subset_count}"
4327
+ )
4328
+ if fmt is not None:
4329
+ if not fmt.startswith("{"):
4330
+ fmt="{:" + fmt + "}"
4331
+ if len(lists) == 2:
4118
4332
 
4119
4333
  from matplotlib_venn import venn2, venn2_circles
4120
4334
 
@@ -4127,21 +4341,28 @@ def venn(
4127
4341
  set1, set2 = lists[0], lists[1]
4128
4342
  v.get_patch_by_id("10").set_color(colors[0])
4129
4343
  v.get_patch_by_id("01").set_color(colors[1])
4130
- v.get_patch_by_id("11").set_color(
4131
- get_color_overlap(colors[0], colors[1]) if colors else None
4132
- )
4344
+ try:
4345
+ v.get_patch_by_id("11").set_color(
4346
+ get_color_overlap(colors[0], colors[1]) if colors else None
4347
+ )
4348
+ except Exception as e:
4349
+ print(e)
4133
4350
  # v.get_label_by_id('10').set_text(len(set1 - set2))
4134
4351
  # v.get_label_by_id('01').set_text(len(set2 - set1))
4135
4352
  # v.get_label_by_id('11').set_text(len(set1 & set2))
4353
+
4136
4354
  v.get_label_by_id("10").set_text(
4137
- get_count_and_percentage(len(set1), len(set1 - set2))
4355
+ get_count_and_percentage(universe, len(set1 - set2))
4138
4356
  )
4139
4357
  v.get_label_by_id("01").set_text(
4140
- get_count_and_percentage(len(set2), len(set2 - set1))
4141
- )
4142
- v.get_label_by_id("11").set_text(
4143
- get_count_and_percentage(len(set1 | set2), len(set1 & set2))
4358
+ get_count_and_percentage(universe, len(set2 - set1))
4144
4359
  )
4360
+ try:
4361
+ v.get_label_by_id("11").set_text(
4362
+ get_count_and_percentage(universe, len(set1 & set2))
4363
+ )
4364
+ except Exception as e:
4365
+ print(e)
4145
4366
 
4146
4367
  if not isinstance(linewidth, list):
4147
4368
  linewidth = [linewidth]
@@ -4224,16 +4445,14 @@ def venn(
4224
4445
  va=va,
4225
4446
  shadow=shadow,
4226
4447
  )
4227
-
4228
- elif len(lists) == 3:
4229
-
4230
- def get_label(set_count, subset_count):
4231
- percent = subset_count / set_count if set_count > 0 else 0
4232
- return (
4233
- f"{subset_count}\n({fmt.format(percent)})"
4234
- if show_percentages
4235
- else f"{subset_count}"
4236
- )
4448
+ # Set transparency level
4449
+ for patch in v.patches:
4450
+ if patch:
4451
+ patch.set_alpha(alpha)
4452
+ if "none" in edgecolor or 0 in linewidth:
4453
+ patch.set_edgecolor("none")
4454
+ return ax
4455
+ elif len(lists) == 3:
4237
4456
 
4238
4457
  from matplotlib_venn import venn3, venn3_circles
4239
4458
 
@@ -4249,36 +4468,34 @@ def venn(
4249
4468
  # Draw the venn diagram
4250
4469
  v = venn3(subsets=lists, set_labels=labels, ax=ax, **kwargs)
4251
4470
  v.get_patch_by_id("100").set_color(colors[0])
4471
+ v.get_label_by_id("100").set_text(get_count_and_percentage(universe, len(set1 - set2 - set3)))
4252
4472
  v.get_patch_by_id("010").set_color(colors[1])
4253
- v.get_patch_by_id("001").set_color(colors[2])
4254
- v.get_patch_by_id("110").set_color(colorAB)
4255
- v.get_patch_by_id("101").set_color(colorAC)
4256
- v.get_patch_by_id("011").set_color(colorBC)
4257
- v.get_patch_by_id("111").set_color(colorABC)
4258
-
4259
- # Correctly labeling subset sizes
4260
- # v.get_label_by_id('100').set_text(len(set1 - set2 - set3))
4261
- # v.get_label_by_id('010').set_text(len(set2 - set1 - set3))
4262
- # v.get_label_by_id('001').set_text(len(set3 - set1 - set2))
4263
- # v.get_label_by_id('110').set_text(len(set1 & set2 - set3))
4264
- # v.get_label_by_id('101').set_text(len(set1 & set3 - set2))
4265
- # v.get_label_by_id('011').set_text(len(set2 & set3 - set1))
4266
- # v.get_label_by_id('111').set_text(len(set1 & set2 & set3))
4267
- v.get_label_by_id("100").set_text(get_label(len(set1), len(set1 - set2 - set3)))
4268
- v.get_label_by_id("010").set_text(get_label(len(set2), len(set2 - set1 - set3)))
4269
- v.get_label_by_id("001").set_text(get_label(len(set3), len(set3 - set1 - set2)))
4270
- v.get_label_by_id("110").set_text(
4271
- get_label(len(set1 | set2), len(set1 & set2 - set3))
4272
- )
4273
- v.get_label_by_id("101").set_text(
4274
- get_label(len(set1 | set3), len(set1 & set3 - set2))
4275
- )
4276
- v.get_label_by_id("011").set_text(
4277
- get_label(len(set2 | set3), len(set2 & set3 - set1))
4278
- )
4279
- v.get_label_by_id("111").set_text(
4280
- get_label(len(set1 | set2 | set3), len(set1 & set2 & set3))
4281
- )
4473
+ v.get_label_by_id("010").set_text(get_count_and_percentage(universe, len(set2 - set1 - set3)))
4474
+ try:
4475
+ v.get_patch_by_id("001").set_color(colors[2])
4476
+ v.get_label_by_id("001").set_text(get_count_and_percentage(universe, len(set3 - set1 - set2)))
4477
+ except Exception as e:
4478
+ print(e)
4479
+ try:
4480
+ v.get_patch_by_id("110").set_color(colorAB)
4481
+ v.get_label_by_id("110").set_text(get_count_and_percentage(universe, len(set1 & set2 - set3)))
4482
+ except Exception as e:
4483
+ print(e)
4484
+ try:
4485
+ v.get_patch_by_id("101").set_color(colorAC)
4486
+ v.get_label_by_id("101").set_text(get_count_and_percentage(universe, len(set1 & set3 - set2)))
4487
+ except Exception as e:
4488
+ print(e)
4489
+ try:
4490
+ v.get_patch_by_id("011").set_color(colorBC)
4491
+ v.get_label_by_id("011").set_text(get_count_and_percentage(universe, len(set2 & set3 - set1)))
4492
+ except Exception as e:
4493
+ print(e)
4494
+ try:
4495
+ v.get_patch_by_id("111").set_color(colorABC)
4496
+ v.get_label_by_id("111").set_text(get_count_and_percentage(universe, len(set1 & set2 & set3)))
4497
+ except Exception as e:
4498
+ print(e)
4282
4499
 
4283
4500
  # Apply styles to set labels
4284
4501
  for i, text in enumerate(v.set_labels):
@@ -4383,16 +4600,34 @@ def venn(
4383
4600
  ax.add_patch(ellipse1)
4384
4601
  ax.add_patch(ellipse2)
4385
4602
  ax.add_patch(ellipse3)
4603
+ # Set transparency level
4604
+ for patch in v.patches:
4605
+ if patch:
4606
+ patch.set_alpha(alpha)
4607
+ if "none" in edgecolor or 0 in linewidth:
4608
+ patch.set_edgecolor("none")
4609
+ return ax
4610
+
4611
+
4612
+ dict_data = {}
4613
+ for i_list, list_ in enumerate(lists):
4614
+ dict_data[labels[i_list]]={*list_}
4615
+
4616
+ if 3<len(lists)<6:
4617
+ from venn import venn as vn
4618
+
4619
+ legend_loc=kwargs.pop("legend_loc", "upper right")
4620
+ ax=vn(dict_data,ax=ax,legend_loc=legend_loc,**kwargs)
4621
+
4622
+ return ax
4386
4623
  else:
4387
- raise ValueError("只支持2或者3个list")
4388
-
4389
- # Set transparency level
4390
- for patch in v.patches:
4391
- if patch:
4392
- patch.set_alpha(alpha)
4393
- if "none" in edgecolor or 0 in linewidth:
4394
- patch.set_edgecolor("none")
4395
- return ax
4624
+ from venn import pseudovenn
4625
+ cmap=kwargs.pop("cmap","plasma")
4626
+ ax=pseudovenn(dict_data, cmap=cmap,ax=ax,**kwargs)
4627
+
4628
+ return ax
4629
+
4630
+
4396
4631
 
4397
4632
 
4398
4633
  #! subplots, support automatic extend new axis
@@ -4403,6 +4638,7 @@ def subplot(
4403
4638
  sharex=False,
4404
4639
  sharey=False,
4405
4640
  verbose=False,
4641
+ fig=None,
4406
4642
  **kwargs,
4407
4643
  ):
4408
4644
  """
@@ -4431,8 +4667,8 @@ def subplot(
4431
4667
  )
4432
4668
 
4433
4669
  figsize_recommend = f"subplot({rows}, {cols}, figsize={figsize})"
4434
-
4435
- fig = plt.figure(figsize=figsize, constrained_layout=True)
4670
+ if fig is None:
4671
+ fig = plt.figure(figsize=figsize, constrained_layout=True)
4436
4672
  grid_spec = GridSpec(rows, cols, figure=fig)
4437
4673
  occupied = set()
4438
4674
  row_first_axes = [None] * rows # Track the first axis in each row (for sharey)
@@ -4496,18 +4732,19 @@ def subplot(
4496
4732
  #! radar chart
4497
4733
  def radar(
4498
4734
  data: pd.DataFrame,
4499
- title="Radar Chart",
4735
+ columns=None,
4500
4736
  ylim=(0, 100),
4501
- color=get_color(5),
4737
+ facecolor=None,
4738
+ edgecolor="none",
4739
+ edge_linewidth=0.5,
4502
4740
  fontsize=10,
4503
4741
  fontcolor="k",
4504
4742
  size=6,
4505
4743
  linewidth=1,
4506
4744
  linestyle="-",
4507
- alpha=0.5,
4745
+ alpha=0.3,
4746
+ fmt=".1f",
4508
4747
  marker="o",
4509
- edgecolor="none",
4510
- edge_linewidth=0.5,
4511
4748
  bg_color="0.8",
4512
4749
  bg_alpha=None,
4513
4750
  grid_interval_ratio=0.2,
@@ -4527,19 +4764,34 @@ def radar(
4527
4764
  ax=None,
4528
4765
  sp=2,
4529
4766
  verbose=True,
4767
+ axis=0,
4530
4768
  **kwargs,
4531
4769
  ):
4532
4770
  """
4533
4771
  Example DATA:
4534
4772
  df = pd.DataFrame(
4535
- data=[
4536
- [80, 80, 80, 80, 80, 80, 80],
4537
- [90, 20, 95, 95, 30, 30, 80],
4538
- [60, 90, 20, 20, 100, 90, 50],
4539
- ],
4540
- index=["Hero", "Warrior", "Wizard"],
4541
- columns=["HP", "MP", "ATK", "DEF", "SP.ATK", "SP.DEF", "SPD"])
4542
-
4773
+ data=[
4774
+ [80, 90, 60],
4775
+ [80, 20, 90],
4776
+ [80, 95, 20],
4777
+ [80, 95, 20],
4778
+ [80, 30, 100],
4779
+ [80, 30, 90],
4780
+ [80, 80, 50],
4781
+ ],
4782
+ index=["HP", "MP", "ATK", "DEF", "SP.ATK", "SP.DEF", "SPD"],
4783
+ columns=["Hero", "Warrior", "Wizard"],
4784
+ )
4785
+ usage 1:
4786
+ radar(data=df)
4787
+ usage 2:
4788
+ radar(data=df["Wizard"])
4789
+ usage 3:
4790
+ radar(data=df, columns="Wizard")
4791
+ usage 4:
4792
+ nexttile = subplot(1, 2)
4793
+ radar(data=df, columns="Wizard", ax=nexttile(projection="polar"))
4794
+ pie(data=df, columns="Wizard", ax=nexttile(), width=0.5, pctdistance=0.7)
4543
4795
  Parameters:
4544
4796
  - data (pd.DataFrame): The data to plot. Each column corresponds to a variable, and each row represents a data point.
4545
4797
  - ylim (tuple): The limits of the radial axis (y-axis). Default is (0, 100).
@@ -4556,7 +4808,6 @@ def radar(
4556
4808
  - edge_linewidth (int): Line width for the marker edges.
4557
4809
  - bg_color (str): Background color for the radar chart.
4558
4810
  - grid_interval_ratio (float): Determines the intervals for the grid lines as a fraction of the y-limit.
4559
- - title (str): The title of the radar chart.
4560
4811
  - cmap (str): The colormap to use if `color` is a list.
4561
4812
  - legend_loc (str): The location of the legend.
4562
4813
  - legend_fontsize (int): Font size for the legend.
@@ -4573,22 +4824,22 @@ def radar(
4573
4824
  - sp (int): Padding for the ticks from the plot area.
4574
4825
  - **kwargs: Additional arguments for customization.
4575
4826
  """
4576
- if run_once_within() and verbose:
4827
+ if run_once_within(20,reverse=True) and verbose:
4577
4828
  usage_="""usage:
4578
4829
  radar(
4579
4830
  data: pd.DataFrame, #The data to plot. Each column corresponds to a variable, and each row represents a data point.
4580
- title="Radar Chart",
4581
4831
  ylim=(0, 100),# ylim (tuple): The limits of the radial axis (y-axis). Default is (0, 100).
4582
- color=get_color(5),#The color(s) for the plot. Can be a single color or a list of colors.
4832
+ facecolor=get_color(5),#The color(s) for the plot. Can be a single color or a list of colors.
4833
+ edgecolor="none",#for the marker edges.
4834
+ edge_linewidth=0.5,#for the marker edges.
4583
4835
  fontsize=10,# Font size for the angular labels (x-axis).
4584
4836
  fontcolor="k",# Color for the angular labels.
4585
4837
  size=6,#The size of the markers for each data point.
4586
4838
  linewidth=1,
4587
4839
  linestyle="-",
4588
4840
  alpha=0.5,#for the filled area.
4841
+ fmt=".1f",
4589
4842
  marker="o",# for the data points.
4590
- edgecolor="none",#for the marker edges.
4591
- edge_linewidth=0.5,#for the marker edges.
4592
4843
  bg_color="0.8",
4593
4844
  bg_alpha=None,
4594
4845
  grid_interval_ratio=0.2,#Determines the intervals for the grid lines as a fraction of the y-limit.
@@ -4618,9 +4869,25 @@ def radar(
4618
4869
  kws_figsets = v_arg
4619
4870
  kwargs.pop(k_arg, None)
4620
4871
  break
4621
- categories = list(data.columns)
4872
+ if axis==1:
4873
+ data=data.T
4874
+ if isinstance(data, dict):
4875
+ data = pd.DataFrame(pd.Series(data))
4876
+ if ~isinstance(data, pd.DataFrame):
4877
+ data=pd.DataFrame(data)
4878
+ if isinstance(data, pd.DataFrame):
4879
+ data=data.select_dtypes(include=np.number)
4880
+ if isinstance(columns,str):
4881
+ columns=[columns]
4882
+ if columns is None:
4883
+ columns = list(data.columns)
4884
+ data=data[columns]
4885
+ categories = list(data.index)
4622
4886
  num_vars = len(categories)
4623
4887
 
4888
+ # Set y-axis limits and grid intervals
4889
+ vmin, vmax = ylim
4890
+
4624
4891
  # Set up angle for each category on radar chart
4625
4892
  angles = np.linspace(0, 2 * np.pi, num_vars, endpoint=False).tolist()
4626
4893
  angles += angles[:1] # Complete the loop to ensure straight-line connections
@@ -4644,9 +4911,6 @@ def radar(
4644
4911
  # Draw one axis per variable and add labels
4645
4912
  ax.set_xticks(angles[:-1])
4646
4913
  ax.set_xticklabels(categories)
4647
-
4648
- # Set y-axis limits and grid intervals
4649
- vmin, vmax = ylim
4650
4914
  if circular:
4651
4915
  # * cicular style
4652
4916
  ax.yaxis.set_ticks(np.arange(vmin, vmax + 1, vmax * grid_interval_ratio))
@@ -4669,7 +4933,7 @@ def radar(
4669
4933
  else:
4670
4934
  # * spider style: spider-style grid (straight lines, not circles)
4671
4935
  # Create the spider-style grid (straight lines, not circles)
4672
- for i in range(1, int(vmax * grid_interval_ratio) + 1):
4936
+ for i in range(1, int((vmax-vmin)/ ((vmax-vmin)*grid_interval_ratio))+1):#int(vmax * grid_interval_ratio) + 1):
4673
4937
  ax.plot(
4674
4938
  angles + [angles[0]], # Closing the loop
4675
4939
  [i * vmax * grid_interval_ratio] * (num_vars + 1)
@@ -4680,7 +4944,7 @@ def radar(
4680
4944
  linewidth=grid_linewidth,
4681
4945
  )
4682
4946
  # set bg_color
4683
- ax.fill(angles, [vmax] * (data.shape[1] + 1), color=bg_color, alpha=bg_alpha)
4947
+ ax.fill(angles, [vmax] * (data.shape[0] + 1), color=bg_color, alpha=bg_alpha)
4684
4948
  ax.yaxis.grid(False)
4685
4949
  # Move radial labels away from plotted line
4686
4950
  if tick_loc is None:
@@ -4702,14 +4966,20 @@ def radar(
4702
4966
  ax.tick_params(axis="x", pad=sp) # move spines outward
4703
4967
  ax.tick_params(axis="y", pad=sp) # move spines outward
4704
4968
  # colors
4705
- colors = (
4706
- get_color(data.shape[0])
4707
- if cmap is None
4708
- else plt.get_cmap(cmap)(np.linspace(0, 1, data.shape[0]))
4709
- )
4969
+ if facecolor is not None:
4970
+ if not isinstance(facecolor,list):
4971
+ facecolor=[facecolor]
4972
+ colors = facecolor
4973
+ else:
4974
+ colors = (
4975
+ get_color(data.shape[1])
4976
+ if cmap is None
4977
+ else plt.get_cmap(cmap)(np.linspace(0, 1, data.shape[1]))
4978
+ )
4979
+
4710
4980
  # Plot each row with straight lines
4711
- for i, (index, row) in enumerate(data.iterrows()):
4712
- values = row.tolist()
4981
+ for i, (col, val) in enumerate(data.items()):
4982
+ values = val.tolist()
4713
4983
  values += values[:1] # Close the loop
4714
4984
  ax.plot(
4715
4985
  angles,
@@ -4717,7 +4987,7 @@ def radar(
4717
4987
  color=colors[i],
4718
4988
  linewidth=linewidth,
4719
4989
  linestyle=linestyle,
4720
- label=index,
4990
+ label=col,
4721
4991
  clip_on=False,
4722
4992
  )
4723
4993
  ax.fill(angles, values, color=colors[i], alpha=alpha)
@@ -4748,7 +5018,7 @@ def radar(
4748
5018
  ax.text(
4749
5019
  angle,
4750
5020
  offset_radius,
4751
- str(value),
5021
+ f"{value:{fmt}}",
4752
5022
  ha="center",
4753
5023
  va="center",
4754
5024
  fontsize=fontsize,
@@ -4759,10 +5029,10 @@ def radar(
4759
5029
 
4760
5030
  ax.set_ylim(ylim)
4761
5031
  # Add markers for each data point
4762
- for i, row in enumerate(data.values):
5032
+ for i, (col, val) in enumerate(data.items()):
4763
5033
  ax.plot(
4764
5034
  angles,
4765
- list(row) + [row[0]], # Close the loop for markers
5035
+ list(val) + [val[0]], # Close the loop for markers
4766
5036
  color=colors[i],
4767
5037
  marker=marker,
4768
5038
  markersize=size,
@@ -4787,3 +5057,819 @@ def radar(
4787
5057
  **kws_figsets,
4788
5058
  )
4789
5059
  return ax
5060
+
5061
+
5062
+ def pie(
5063
+ data:pd.Series,
5064
+ columns:list = None,
5065
+ facecolor=None,
5066
+ explode=[0.1],
5067
+ startangle=90,
5068
+ shadow=True,
5069
+ fontcolor="k",
5070
+ fmt=".2f",
5071
+ width=None,# the center blank
5072
+ pctdistance=0.85,
5073
+ labeldistance=1.1,
5074
+ kws_wedge={},
5075
+ kws_text={},
5076
+ kws_arrow={},
5077
+ center=(0, 0),
5078
+ radius=1,
5079
+ frame=False,
5080
+ fontsize=10,
5081
+ edgecolor="white",
5082
+ edgewidth=1,
5083
+ cmap=None,
5084
+ show_value=False,
5085
+ show_label=True,# False: only show the outer layer, if it is None, not show
5086
+ expand_label=(1.2,1.2),
5087
+ kws_bbox={},#dict(facecolor="none", alpha=0.5, edgecolor="black", boxstyle="round,pad=0.3"), # '{}' to hide
5088
+ show_legend=True,
5089
+ legend_loc="upper right",
5090
+ bbox_to_anchor=[1.4, 1.1],
5091
+ legend_fontsize=10,
5092
+ rotation_correction=0,
5093
+ verbose=True,
5094
+ ax=None,
5095
+ **kwargs
5096
+ ):
5097
+ from adjustText import adjust_text
5098
+ if run_once_within(20,reverse=True) and verbose:
5099
+ usage_="""usage:
5100
+ pie(
5101
+ data:pd.Series,
5102
+ columns:list = None,
5103
+ facecolor=None,
5104
+ explode=[0.1],
5105
+ startangle=90,
5106
+ shadow=True,
5107
+ fontcolor="k",
5108
+ fmt=".2f",
5109
+ width=None,# the center blank
5110
+ pctdistance=0.85,
5111
+ labeldistance=1.1,
5112
+ kws_wedge={},
5113
+ kws_text={},
5114
+ center=(0, 0),
5115
+ radius=1,
5116
+ frame=False,
5117
+ fontsize=10,
5118
+ edgecolor="white",
5119
+ edgewidth=1,
5120
+ cmap=None,
5121
+ show_value=False,
5122
+ show_label=True,# False: only show the outer layer, if it is None, not show
5123
+ show_legend=True,
5124
+ legend_loc="upper right",
5125
+ bbox_to_anchor=[1.4, 1.1],
5126
+ legend_fontsize=10,
5127
+ rotation_correction=0,
5128
+ verbose=True,
5129
+ ax=None,
5130
+ **kwargs
5131
+ )
5132
+
5133
+ usage 1:
5134
+ data = {"Segment A": 30, "Segment B": 50, "Segment C": 20}
5135
+
5136
+ ax = pie(
5137
+ data=data,
5138
+ # columns="Segment A",
5139
+ explode=[0, 0.2, 0],
5140
+ # width=0.4,
5141
+ show_label=False,
5142
+ fontsize=10,
5143
+ # show_value=1,
5144
+ fmt=".3f",
5145
+ )
5146
+
5147
+ # prepare dataset
5148
+ df = pd.DataFrame(
5149
+ data=[
5150
+ [80, 90, 60],
5151
+ [80, 20, 90],
5152
+ [80, 95, 20],
5153
+ [80, 95, 20],
5154
+ [80, 30, 100],
5155
+ [80, 30, 90],
5156
+ [80, 80, 50],
5157
+ ],
5158
+ index=["HP", "MP", "ATK", "DEF", "SP.ATK", "SP.DEF", "SPD"],
5159
+ columns=["Hero", "Warrior", "Wizard"],
5160
+ )
5161
+ usage 1: only plot one column
5162
+ pie(
5163
+ df,
5164
+ columns="Wizard",
5165
+ width=0.6,
5166
+ show_label=False,
5167
+ fmt=".0f",
5168
+ )
5169
+ usage 2:
5170
+ pie(df,columns=["Hero", "Warrior"],show_label=False)
5171
+ usage 3: set different width
5172
+ pie(df,
5173
+ columns=["Hero", "Warrior", "Wizard"],
5174
+ width=[0.3, 0.2, 0.2],
5175
+ show_label=False,
5176
+ fmt=".0f",
5177
+ )
5178
+ usage 4: set width the same for all columns
5179
+ pie(df,
5180
+ columns=["Hero", "Warrior", "Wizard"],
5181
+ width=0.2,
5182
+ show_label=False,
5183
+ fmt=".0f",
5184
+ )
5185
+ usage 5: adjust the labels' offset
5186
+ pie(df, columns="Wizard", width=0.6, show_label=False, fmt=".6f", labeldistance=1.2)
5187
+
5188
+ usage 6:
5189
+ nexttile = subplot(1, 2)
5190
+ radar(data=df, columns="Wizard", ax=nexttile(projection="polar"))
5191
+ pie(data=df, columns="Wizard", ax=nexttile(), width=0.5, pctdistance=0.7)
5192
+ """
5193
+ print(usage_)
5194
+ # Convert data to a Pandas Series if needed
5195
+ if isinstance(data, dict):
5196
+ data = pd.DataFrame(pd.Series(data))
5197
+ if ~isinstance(data, pd.DataFrame):
5198
+ data=pd.DataFrame(data)
5199
+
5200
+ if isinstance(data, pd.DataFrame):
5201
+ data=data.select_dtypes(include=np.number)
5202
+ if isinstance(columns,str):
5203
+ columns=[columns]
5204
+ if columns is None:
5205
+ columns = list(data.columns)
5206
+ # data=data[columns]
5207
+ # columns = list(data.columns)
5208
+ # print(columns)
5209
+ # 选择部分数据
5210
+ df=data[columns]
5211
+
5212
+ if not isinstance(explode, list):
5213
+ explode=[explode]
5214
+ if explode == [None]:
5215
+ explode=[0]
5216
+
5217
+ if width is None:
5218
+ if df.shape[1]>1:
5219
+ width=1/(df.shape[1]+2)
5220
+ else:
5221
+ width=1
5222
+ if isinstance(width,(float,int)):
5223
+ width=[width]
5224
+ if len(width)<df.shape[1]:
5225
+ width=width*df.shape[1]
5226
+ if isinstance(radius,(float,int)):
5227
+ radius=[radius]
5228
+ radius_tile=[1]*df.shape[1]
5229
+ radius=radius_tile.copy()
5230
+ for i in range(1,df.shape[1]):
5231
+ radius[i]=radius_tile[i]-np.sum(width[:i])
5232
+
5233
+ # colors
5234
+ if facecolor is not None:
5235
+ if not isinstance(facecolor,list):
5236
+ facecolor=[facecolor]
5237
+ colors = facecolor
5238
+ else:
5239
+ colors = (
5240
+ get_color(data.shape[0])
5241
+ if cmap is None
5242
+ else plt.get_cmap(cmap)(np.linspace(0, 1, data.shape[0]))
5243
+ )
5244
+ # to check if facecolor is nested list or not
5245
+ is_nested = True if any(isinstance(i, list) for i in colors) else False
5246
+ inested = 0
5247
+ for column_,width_,radius_ in zip(columns, width,radius):
5248
+ if column_!=columns[0]:
5249
+ labels = data.index if show_label else None
5250
+ else:
5251
+ labels = data.index if show_label is not None else None
5252
+ data = df[column_]
5253
+ labels_legend=data.index
5254
+ sizes = data.values
5255
+
5256
+ # Set wedge and text properties if none are provided
5257
+ kws_wedge = kws_wedge or {"edgecolor": edgecolor, "linewidth": edgewidth}
5258
+ kws_wedge.update({"width":width_})
5259
+ fontcolor=kws_text.get("color",fontcolor)
5260
+ fontsize=kws_text.get("fontsize",fontsize)
5261
+ kws_text.update({"color": fontcolor, "fontsize": fontsize})
5262
+
5263
+ if ax is None:
5264
+ ax=plt.gca()
5265
+ if len(explode)<len(labels_legend):
5266
+ explode.extend([0]*(len(labels_legend)-len(explode)))
5267
+ print(explode)
5268
+ if fmt:
5269
+ if not fmt.startswith("%"):
5270
+ autopct =f"%{fmt}%%"
5271
+ else:
5272
+ autopct=None
5273
+
5274
+ if show_value is None:
5275
+ result = ax.pie(
5276
+ sizes,
5277
+ labels=labels,
5278
+ autopct= None,
5279
+ startangle=startangle + rotation_correction,
5280
+ explode=explode,
5281
+ colors=colors[inested] if is_nested else colors,
5282
+ shadow=shadow,
5283
+ pctdistance=pctdistance,
5284
+ labeldistance=labeldistance,
5285
+ wedgeprops=kws_wedge,
5286
+ textprops=kws_text,
5287
+ center=center,
5288
+ radius=radius_,
5289
+ frame=frame,
5290
+ **kwargs
5291
+ )
5292
+ else:
5293
+ result = ax.pie(
5294
+ sizes,
5295
+ labels=labels,
5296
+ autopct=autopct if autopct else None,
5297
+ startangle=startangle + rotation_correction,
5298
+ explode=explode,
5299
+ colors=colors[inested] if is_nested else colors,
5300
+ shadow=shadow,#shadow,
5301
+ pctdistance=pctdistance,
5302
+ labeldistance=labeldistance,
5303
+ wedgeprops=kws_wedge,
5304
+ textprops=kws_text,
5305
+ center=center,
5306
+ radius=radius_,
5307
+ frame=frame,
5308
+ **kwargs
5309
+ )
5310
+ if len(result) == 3:
5311
+ wedges, texts, autotexts = result
5312
+ elif len(result) == 2:
5313
+ wedges, texts = result
5314
+ autotexts = None
5315
+ #! adjust_text
5316
+ if autotexts or texts:
5317
+ all_texts = []
5318
+ if autotexts and show_value:
5319
+ all_texts.extend(autotexts)
5320
+ if texts and show_label:
5321
+ all_texts.extend(texts)
5322
+
5323
+ adjust_text(
5324
+ all_texts,
5325
+ ax=ax,
5326
+ arrowprops=kws_arrow,#dict(arrowstyle="-", color="gray", lw=0.5),
5327
+ bbox=kws_bbox if kws_bbox else None,
5328
+ expand=expand_label,
5329
+ fontdict={
5330
+ "fontsize": fontsize,
5331
+ "color": fontcolor,
5332
+ },
5333
+ )
5334
+ # Show exact values on wedges if show_value is True
5335
+ if show_value:
5336
+ for i, (wedge, txt) in enumerate(zip(wedges, texts)):
5337
+ angle = (wedge.theta2 - wedge.theta1) / 2 + wedge.theta1
5338
+ x = np.cos(np.radians(angle)) * (pctdistance ) * radius_
5339
+ y = np.sin(np.radians(angle)) * (pctdistance ) * radius_
5340
+ if not fmt.startswith("{"):
5341
+ value_text = f"{sizes[i]:{fmt}}"
5342
+ else:
5343
+ value_text = fmt.format(sizes[i])
5344
+ ax.text(x, y, value_text, ha="center", va="center", fontsize=fontsize,color=fontcolor)
5345
+ inested+=1
5346
+ # Customize the legend
5347
+ if show_legend:
5348
+ ax.legend(
5349
+ wedges,
5350
+ labels_legend,
5351
+ loc=legend_loc,
5352
+ bbox_to_anchor=bbox_to_anchor,
5353
+ fontsize=legend_fontsize,
5354
+ title_fontsize=legend_fontsize,
5355
+ )
5356
+ ax.set(aspect="equal")
5357
+ return ax
5358
+
5359
+ def ellipse(
5360
+ data,
5361
+ x=None,
5362
+ y=None,
5363
+ hue=None,
5364
+ n_std=1.5,
5365
+ ax=None,
5366
+ confidence=0.95,
5367
+ annotate_center=False,
5368
+ palette=None,
5369
+ facecolor=None,
5370
+ edgecolor=None,
5371
+ label:bool=True,
5372
+ **kwargs,
5373
+ ):
5374
+ """
5375
+ Plot advanced ellipses representing covariance for different groups
5376
+ # simulate data:
5377
+ control = np.random.multivariate_normal([0, 0], [[1, 0.5], [0.5, 1]], size=50)
5378
+ patient = np.random.multivariate_normal([2, 1], [[1, -0.3], [-0.3, 1]], size=50)
5379
+ df = pd.DataFrame(
5380
+ {
5381
+ "Dim1": np.concatenate([control[:, 0], patient[:, 0]]),
5382
+ "Dim2": np.concatenate([control[:, 1], patient[:, 1]]),
5383
+ "Group": ["Control"] * 50 + ["Patient"] * 50,
5384
+ }
5385
+ )
5386
+ plotxy(
5387
+ data=df,
5388
+ x="Dim1",
5389
+ y="Dim2",
5390
+ hue="Group",
5391
+ kind_="scatter",
5392
+ palette=get_color(8),
5393
+ )
5394
+ ellipse(
5395
+ data=df,
5396
+ x="Dim1",
5397
+ y="Dim2",
5398
+ hue="Group",
5399
+ palette=get_color(8),
5400
+ alpha=0.1,
5401
+ lw=2,
5402
+ )
5403
+ Parameters:
5404
+ data (DataFrame): Input DataFrame with columns for x, y, and hue.
5405
+ x (str): Column name for x-axis values.
5406
+ y (str): Column name for y-axis values.
5407
+ hue (str, optional): Column name for group labels.
5408
+ n_std (float): Number of standard deviations for the ellipse (overridden if confidence is provided).
5409
+ ax (matplotlib.axes.Axes, optional): Matplotlib Axes object to plot on. Defaults to current Axes.
5410
+ confidence (float, optional): Confidence level (e.g., 0.95 for 95% confidence interval).
5411
+ annotate_center (bool): Whether to annotate the ellipse center (mean).
5412
+ palette (dict or list, optional): A mapping of hues to colors or a list of colors.
5413
+ **kwargs: Additional keyword arguments for the Ellipse patch.
5414
+
5415
+ Returns:
5416
+ list: List of Ellipse objects added to the Axes.
5417
+ """
5418
+ from matplotlib.patches import Ellipse
5419
+ import numpy as np
5420
+ import matplotlib.pyplot as plt
5421
+ import seaborn as sns
5422
+ import pandas as pd
5423
+ from scipy.stats import chi2
5424
+
5425
+ if ax is None:
5426
+ ax = plt.gca()
5427
+
5428
+ # Validate inputs
5429
+ if x is None or y is None:
5430
+ raise ValueError(
5431
+ "Both `x` and `y` must be specified as column names in the DataFrame."
5432
+ )
5433
+ if not isinstance(data, pd.DataFrame):
5434
+ raise ValueError("`data` must be a pandas DataFrame.")
5435
+
5436
+ # Prepare data for hue-based grouping
5437
+ ellipses = []
5438
+ if hue is not None:
5439
+ groups = data[hue].unique()
5440
+ colors = sns.color_palette(palette or "husl", len(groups))
5441
+ color_map = dict(zip(groups, colors))
5442
+ else:
5443
+ groups = [None]
5444
+ color_map = {None: kwargs.get("edgecolor", "blue")}
5445
+ alpha = kwargs.pop("alpha", 0.2)
5446
+ edgecolor=kwargs.pop("edgecolor", None)
5447
+ facecolor=kwargs.pop("facecolor", None)
5448
+ for group in groups:
5449
+ group_data = data[data[hue] == group] if hue else data
5450
+
5451
+ # Extract x and y columns for the group
5452
+ group_points = group_data[[x, y]].values
5453
+
5454
+ # Compute mean and covariance matrix
5455
+ # # 标准化处理
5456
+ # group_points = group_data[[x, y]].values
5457
+ # group_points -= group_points.mean(axis=0)
5458
+ # group_points /= group_points.std(axis=0)
5459
+
5460
+ cov = np.cov(group_points.T)
5461
+ mean = np.mean(group_points, axis=0)
5462
+
5463
+ # Eigenvalues and eigenvectors
5464
+ eigvals, eigvecs = np.linalg.eigh(cov)
5465
+ order = eigvals.argsort()[::-1]
5466
+ eigvals, eigvecs = eigvals[order], eigvecs[:, order]
5467
+
5468
+ # Rotation angle and ellipse dimensions
5469
+ angle = np.degrees(np.arctan2(*eigvecs[:, 0][::-1]))
5470
+ if confidence:
5471
+ n_std = np.sqrt(chi2.ppf(confidence, df=2)) # Chi-square quantile
5472
+ width, height = 2 * n_std * np.sqrt(eigvals)
5473
+
5474
+ # Create and style the ellipse
5475
+ if facecolor is None:
5476
+ facecolor_ = color_map[group]
5477
+ if edgecolor is None:
5478
+ edgecolor_ = color_map[group]
5479
+ ellipse = Ellipse(
5480
+ xy=mean,
5481
+ width=width,
5482
+ height=height,
5483
+ angle=angle,
5484
+ edgecolor=edgecolor_,
5485
+ facecolor=(facecolor_, alpha), #facecolor_, # only work on facecolor
5486
+ # alpha=alpha,
5487
+ label=group if (hue and label) else None,
5488
+ **kwargs,
5489
+ )
5490
+ ax.add_patch(ellipse)
5491
+ ellipses.append(ellipse)
5492
+
5493
+ # Annotate center
5494
+ if annotate_center:
5495
+ ax.annotate(
5496
+ f"Mean\n({mean[0]:.2f}, {mean[1]:.2f})",
5497
+ xy=mean,
5498
+ xycoords="data",
5499
+ fontsize=10,
5500
+ ha="center",
5501
+ color=ellipse_color,
5502
+ bbox=dict(
5503
+ boxstyle="round,pad=0.3",
5504
+ edgecolor="gray",
5505
+ facecolor="white",
5506
+ alpha=0.8,
5507
+ ),
5508
+ )
5509
+
5510
+ return ax
5511
+
5512
+ def ppi(
5513
+ interactions,
5514
+ player1="preferredName_A",
5515
+ player2="preferredName_B",
5516
+ weight="score",
5517
+ n_layers=None, # Number of concentric layers
5518
+ n_rank=[5, 10], # Nodes in each rank for the concentric layout
5519
+ dist_node = 10, # Distance between each rank of circles
5520
+ layout="degree",
5521
+ size=None,#700,
5522
+ sizes=(50,500),# min and max of size
5523
+ facecolor="skyblue",
5524
+ cmap='coolwarm',
5525
+ edgecolor="k",
5526
+ edgelinewidth=1.5,
5527
+ alpha=.5,
5528
+ alphas=(0.1, 1.0),# min and max of alpha
5529
+ marker="o",
5530
+ node_hideticks=True,
5531
+ linecolor="gray",
5532
+ line_cmap='coolwarm',
5533
+ linewidth=1.5,
5534
+ linewidths=(0.5,5),# min and max of linewidth
5535
+ linealpha=1.0,
5536
+ linealphas=(0.1,1.0),# min and max of linealpha
5537
+ linestyle="-",
5538
+ line_arrowstyle='-',
5539
+ fontsize=10,
5540
+ fontcolor="k",
5541
+ ha:str="center",
5542
+ va:str="center",
5543
+ figsize=(12, 10),
5544
+ k_value=0.3,
5545
+ bgcolor="w",
5546
+ dir_save="./ppi_network.html",
5547
+ physics=True,
5548
+ notebook=False,
5549
+ scale=1,
5550
+ ax=None,
5551
+ **kwargs
5552
+ ):
5553
+ """
5554
+ Plot a Protein-Protein Interaction (PPI) network with adjustable appearance.
5555
+
5556
+ ppi(
5557
+ interactions_sort.iloc[:1000, :],
5558
+ player1="player1",
5559
+ player2="player2",
5560
+ weight="count",
5561
+ layout="spring",
5562
+ n_layers=13,
5563
+ fontsize=1,
5564
+ n_rank=[5, 10, 20, 40, 80, 80, 80, 80, 80, 80, 80, 80],
5565
+ )
5566
+ """
5567
+ from pyvis.network import Network
5568
+ import networkx as nx
5569
+ from IPython.display import IFrame
5570
+ from matplotlib.colors import Normalize
5571
+ from matplotlib import cm
5572
+ from . import ips
5573
+
5574
+ if run_once_within():
5575
+ usage_str="""
5576
+ ppi(
5577
+ interactions,
5578
+ player1="preferredName_A",
5579
+ player2="preferredName_B",
5580
+ weight="score",
5581
+ n_layers=None, # Number of concentric layers
5582
+ n_rank=[5, 10], # Nodes in each rank for the concentric layout
5583
+ dist_node = 10, # Distance between each rank of circles
5584
+ layout="degree",
5585
+ size=None,#700,
5586
+ sizes=(50,500),# min and max of size
5587
+ facecolor="skyblue",
5588
+ cmap='coolwarm',
5589
+ edgecolor="k",
5590
+ edgelinewidth=1.5,
5591
+ alpha=.5,
5592
+ alphas=(0.1, 1.0),# min and max of alpha
5593
+ marker="o",
5594
+ node_hideticks=True,
5595
+ linecolor="gray",
5596
+ line_cmap='coolwarm',
5597
+ linewidth=1.5,
5598
+ linewidths=(0.5,5),# min and max of linewidth
5599
+ linealpha=1.0,
5600
+ linealphas=(0.1,1.0),# min and max of linealpha
5601
+ linestyle="-",
5602
+ line_arrowstyle='-',
5603
+ fontsize=10,
5604
+ fontcolor="k",
5605
+ ha:str="center",
5606
+ va:str="center",
5607
+ figsize=(12, 10),
5608
+ k_value=0.3,
5609
+ bgcolor="w",
5610
+ dir_save="./ppi_network.html",
5611
+ physics=True,
5612
+ notebook=False,
5613
+ scale=1,
5614
+ ax=None,
5615
+ **kwargs
5616
+ ):
5617
+ """
5618
+ print(usage_str)
5619
+
5620
+ # Check for required columns in the DataFrame
5621
+ for col in [player1, player2, weight]:
5622
+ if col not in interactions.columns:
5623
+ raise ValueError(f"Column '{col}' is missing from the interactions DataFrame.")
5624
+ interactions.sort_values(by=[weight], inplace=True)
5625
+ # Initialize Pyvis network
5626
+ net = Network(height="750px", width="100%", bgcolor=bgcolor, font_color=fontcolor)
5627
+ net.force_atlas_2based(
5628
+ gravity=-50, central_gravity=0.01, spring_length=100, spring_strength=0.1
5629
+ )
5630
+ net.toggle_physics(physics)
5631
+
5632
+ kws_figsets = {}
5633
+ for k_arg, v_arg in kwargs.items():
5634
+ if "figset" in k_arg:
5635
+ kws_figsets = v_arg
5636
+ kwargs.pop(k_arg, None)
5637
+ break
5638
+
5639
+ # Create a NetworkX graph from the interaction data
5640
+ G = nx.Graph()
5641
+ for _, row in interactions.iterrows():
5642
+ G.add_edge(row[player1], row[player2], weight=row[weight])
5643
+ # G = nx.from_pandas_edgelist(interactions, source=player1, target=player2, edge_attr=weight)
5644
+
5645
+
5646
+ # Calculate node degrees
5647
+ degrees = dict(G.degree())
5648
+ norm = Normalize(vmin=min(degrees.values()), vmax=max(degrees.values()))
5649
+ colormap = cm.get_cmap(cmap) # Get the 'coolwarm' colormap
5650
+
5651
+ if not ips.isa(facecolor, 'color'):
5652
+ print("facecolor: based on degrees")
5653
+ facecolor = [colormap(norm(deg)) for deg in degrees.values()] # Use colormap
5654
+ num_nodes = G.number_of_nodes()
5655
+ #* size
5656
+ # Set properties based on degrees
5657
+ if not isinstance(size, (int,float,list)):
5658
+ print("size: based on degrees")
5659
+ size = [deg * 50 for deg in degrees.values()] # Scale sizes
5660
+ size = (size[:num_nodes] if len(size) > num_nodes else size) if isinstance(size, list) else [size] * num_nodes
5661
+ if isinstance(size, list) and len(ips.flatten(size,verbose=False))!=1:
5662
+ # Normalize sizes
5663
+ min_size, max_size = sizes # Use sizes tuple for min and max values
5664
+ min_degree, max_degree = min(size), max(size)
5665
+ if max_degree > min_degree: # Avoid division by zero
5666
+ size = [
5667
+ min_size + (max_size - min_size) * (sz - min_degree) / (max_degree - min_degree)
5668
+ for sz in size
5669
+ ]
5670
+ else:
5671
+ # If all values are the same, set them to a default of the midpoint
5672
+ size = [(min_size + max_size) / 2] * len(size)
5673
+
5674
+ #* facecolor
5675
+ facecolor = (facecolor[:num_nodes] if len(facecolor) > num_nodes else facecolor) if isinstance(facecolor, list) else [facecolor] * num_nodes
5676
+ # * facealpha
5677
+ if isinstance(alpha, list):
5678
+ alpha = (alpha[:num_nodes] if len(alpha) > num_nodes else alpha + [alpha[-1]] * (num_nodes - len(alpha)))
5679
+ min_alphas, max_alphas = alphas # Use alphas tuple for min and max values
5680
+ if len(alpha) > 0:
5681
+ # Normalize alpha based on the specified min and max
5682
+ min_alpha, max_alpha = min(alpha), max(alpha)
5683
+ if max_alpha > min_alpha: # Avoid division by zero
5684
+ alpha = [
5685
+ min_alphas + (max_alphas - min_alphas) * (ea - min_alpha) / (max_alpha - min_alpha)
5686
+ for ea in alpha
5687
+ ]
5688
+ else:
5689
+ # If all alpha values are the same, set them to the average of min and max
5690
+ alpha = [(min_alphas + max_alphas) / 2] * len(alpha)
5691
+ else:
5692
+ # Default to a full opacity if no edges are provided
5693
+ alpha = [1.0] * num_nodes
5694
+ else:
5695
+ # If alpha is a single value, convert it to a list and normalize it
5696
+ alpha = [alpha] * num_nodes # Adjust based on alphas
5697
+
5698
+ for i, node in enumerate(G.nodes()):
5699
+ net.add_node(
5700
+ node,
5701
+ label=node,
5702
+ size=size[i],
5703
+ color=facecolor[i],
5704
+ alpha=alpha[i],
5705
+ font={"size": fontsize, "color": fontcolor},
5706
+ )
5707
+ print(f'nodes number: {i+1}')
5708
+
5709
+ for edge in G.edges(data=True):
5710
+ net.add_edge(
5711
+ edge[0],
5712
+ edge[1],
5713
+ weight=edge[2]["weight"],
5714
+ color=edgecolor,
5715
+ width=edgelinewidth * edge[2]["weight"],
5716
+ )
5717
+
5718
+ layouts = [
5719
+ "spring",
5720
+ "circular",
5721
+ "kamada_kawai",
5722
+ "random",
5723
+ "shell",
5724
+ "planar",
5725
+ "spiral",
5726
+ "degree"
5727
+ ]
5728
+ layout = ips.strcmp(layout, layouts)[0]
5729
+ print(f"layout:{layout}, or select one in {layouts}")
5730
+
5731
+ # Choose layout
5732
+ if layout == "spring":
5733
+ pos = nx.spring_layout(G, k=k_value)
5734
+ elif layout == "circular":
5735
+ pos = nx.circular_layout(G)
5736
+ elif layout == "kamada_kawai":
5737
+ pos = nx.kamada_kawai_layout(G)
5738
+ elif layout == "spectral":
5739
+ pos = nx.spectral_layout(G)
5740
+ elif layout == "random":
5741
+ pos = nx.random_layout(G)
5742
+ elif layout == "shell":
5743
+ pos = nx.shell_layout(G)
5744
+ elif layout == "planar":
5745
+ if nx.check_planarity(G)[0]:
5746
+ pos = nx.planar_layout(G)
5747
+ else:
5748
+ print("Graph is not planar; switching to spring layout.")
5749
+ pos = nx.spring_layout(G, k=k_value)
5750
+ elif layout == "spiral":
5751
+ pos = nx.spiral_layout(G)
5752
+ elif layout=='degree':
5753
+ # Calculate node degrees and sort nodes by degree
5754
+ degrees = dict(G.degree())
5755
+ sorted_nodes = sorted(degrees.items(), key=lambda x: x[1], reverse=True)
5756
+ norm = Normalize(vmin=min(degrees.values()), vmax=max(degrees.values()))
5757
+ colormap = cm.get_cmap(cmap)
5758
+
5759
+ # Create positions for concentric circles based on n_layers and n_rank
5760
+ pos = {}
5761
+ n_layers=len(n_rank)+1 if n_layers is None else n_layers
5762
+ for rank_index in range(n_layers):
5763
+ if rank_index < len(n_rank):
5764
+ nodes_per_rank = n_rank[rank_index]
5765
+ rank_nodes = sorted_nodes[sum(n_rank[:rank_index]): sum(n_rank[:rank_index + 1])]
5766
+ else:
5767
+ # 随机打乱剩余节点的顺序
5768
+ remaining_nodes = sorted_nodes[sum(n_rank[:rank_index]):]
5769
+ random_indices = np.random.permutation(len(remaining_nodes))
5770
+ rank_nodes = [remaining_nodes[i] for i in random_indices]
5771
+
5772
+ radius = (rank_index + 1) * dist_node # Radius for this rank
5773
+
5774
+ # Arrange nodes in a circle for the current rank
5775
+ for i, (node, degree) in enumerate(rank_nodes):
5776
+ angle = (i / len(rank_nodes)) * 2 * np.pi # Distribute around circle
5777
+ pos[node] = (radius * np.cos(angle), radius * np.sin(angle))
5778
+
5779
+ else:
5780
+ print(f"Unknown layout '{layout}', defaulting to 'spring',or可以用这些: {layouts}")
5781
+ pos = nx.spring_layout(G, k=k_value)
5782
+
5783
+ for node, (x, y) in pos.items():
5784
+ net.get_node(node)["x"] = x * scale
5785
+ net.get_node(node)["y"] = y * scale
5786
+
5787
+ # If ax is None, use plt.gca()
5788
+ if ax is None:
5789
+ fig, ax = plt.subplots(1,1,figsize=figsize)
5790
+
5791
+ # Draw nodes, edges, and labels with customization options
5792
+ nx.draw_networkx_nodes(
5793
+ G,
5794
+ pos,
5795
+ ax=ax,
5796
+ node_size=size,
5797
+ node_color=facecolor,
5798
+ linewidths=edgelinewidth,
5799
+ edgecolors=edgecolor,
5800
+ alpha=alpha,
5801
+ hide_ticks=node_hideticks,
5802
+ node_shape=marker
5803
+ )
5804
+
5805
+ #* linewidth
5806
+ if not isinstance(linewidth, list):
5807
+ linewidth = [linewidth] * G.number_of_edges()
5808
+ else:
5809
+ linewidth = (linewidth[:G.number_of_edges()] if len(linewidth) > G.number_of_edges() else linewidth + [linewidth[-1]] * (G.number_of_edges() - len(linewidth)))
5810
+ # Normalize linewidth if it is a list
5811
+ if isinstance(linewidth, list):
5812
+ min_linewidth, max_linewidth = min(linewidth), max(linewidth)
5813
+ vmin, vmax = linewidths # Use linewidths tuple for min and max values
5814
+ if max_linewidth > min_linewidth: # Avoid division by zero
5815
+ # Scale between vmin and vmax
5816
+ linewidth = [
5817
+ vmin + (vmax - vmin) * (lw - min_linewidth) / (max_linewidth - min_linewidth)
5818
+ for lw in linewidth
5819
+ ]
5820
+ else:
5821
+ # If all values are the same, set them to a default of the midpoint
5822
+ linewidth = [(vmin + vmax) / 2] * len(linewidth)
5823
+ else:
5824
+ # If linewidth is a single value, convert it to a list of that value
5825
+ linewidth = [linewidth] * G.number_of_edges()
5826
+ #* linecolor
5827
+ if not isinstance(linecolor, str):
5828
+ weights = [G[u][v]["weight"] for u, v in G.edges()]
5829
+ norm = Normalize(vmin=min(weights), vmax=max(weights))
5830
+ colormap = cm.get_cmap(line_cmap)
5831
+ linecolor = [colormap(norm(weight)) for weight in weights]
5832
+ else:
5833
+ linecolor = [linecolor] * G.number_of_edges()
5834
+
5835
+ # * linealpha
5836
+ if isinstance(linealpha, list):
5837
+ linealpha = (linealpha[:G.number_of_edges()] if len(linealpha) > G.number_of_edges() else linealpha + [linealpha[-1]] * (G.number_of_edges() - len(linealpha)))
5838
+ min_alpha, max_alpha = linealphas # Use linealphas tuple for min and max values
5839
+ if len(linealpha) > 0:
5840
+ min_linealpha, max_linealpha = min(linealpha), max(linealpha)
5841
+ if max_linealpha > min_linealpha: # Avoid division by zero
5842
+ linealpha = [
5843
+ min_alpha + (max_alpha - min_alpha) * (ea - min_linealpha) / (max_linealpha - min_linealpha)
5844
+ for ea in linealpha
5845
+ ]
5846
+ else:
5847
+ linealpha = [(min_alpha + max_alpha) / 2] * len(linealpha)
5848
+ else:
5849
+ linealpha = [1.0] * G.number_of_edges() # 如果设置有误,则将它设置成1.0
5850
+ else:
5851
+ linealpha = [linealpha] * G.number_of_edges() # Convert to list if single value
5852
+ nx.draw_networkx_edges(
5853
+ G,
5854
+ pos,
5855
+ ax=ax,
5856
+ edge_color=linecolor,
5857
+ width=linewidth,
5858
+ style=linestyle,
5859
+ arrowstyle=line_arrowstyle,
5860
+ alpha=linealpha
5861
+ )
5862
+
5863
+ nx.draw_networkx_labels(
5864
+ G, pos, ax=ax, font_size=fontsize, font_color=fontcolor,horizontalalignment=ha,verticalalignment=va
5865
+ )
5866
+ figsets(ax=ax,**kws_figsets)
5867
+ ax.axis("off")
5868
+ if dir_save:
5869
+ if not os.path.basename(dir_save):
5870
+ dir_save="_.html"
5871
+ net.write_html(dir_save)
5872
+ nx.write_graphml(G, dir_save.replace(".html",".graphml")) # Export to GraphML
5873
+ print(f"could be edited in Cytoscape \n{dir_save.replace(".html",".graphml")}")
5874
+ ips.figsave(dir_save.replace(".html",".pdf"))
5875
+ return G,ax