py2ls 0.2.4.24__py3-none-any.whl → 0.2.4.26__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
py2ls/plot.py CHANGED
@@ -20,7 +20,9 @@ from .ips import (
20
20
  flatten,
21
21
  plt_font,
22
22
  run_once_within,
23
- df_format
23
+ df_format,
24
+ df_corr,
25
+ df_scaler
24
26
  )
25
27
  import scipy.stats as scipy_stats
26
28
  from .stats import *
@@ -142,20 +144,37 @@ def add_text(ax=None, height_offset=0.5, fmt=".1f", **kwargs):
142
144
  **kwargs,
143
145
  )
144
146
 
147
def pval2str(p):
    """Convert a p-value into a significance-star string.

    Mapping (boundaries match the original if/elif chain):
        p < 0.001          -> "***"
        0.001 <= p < 0.01  -> "**"
        0.01 <= p <= 0.05  -> "*"
        p > 0.05           -> ""

    Parameters:
    - p (float): The p-value to convert.

    Returns:
    - str: "", "*", "**" or "***".

    A NaN p-value yields "" instead of raising: every comparison with NaN
    is False, so the original if/elif chain left its result variable
    unassigned and failed with UnboundLocalError on ``return``.
    """
    if p < 0.001:
        return "***"
    if p < 0.01:
        return "**"
    if p <= 0.05:
        return "*"
    # Reached for p > 0.05 and for NaN (no comparison above is True).
    return ""
145
157
 
146
158
  def heatmap(
147
159
  data,
148
160
  ax=None,
149
161
  kind="corr", #'corr','direct','pivot'
162
+ method="pearson",# for correlation: ‘pearson’(default), ‘kendall’, ‘spearman’
150
163
  columns="all", # pivot, default: coll numeric columns
164
+ style=0,# for correlation
151
165
  index=None, # pivot
152
166
  values=None, # pivot
167
+ fontsize=10,
153
168
  tri="u",
154
169
  mask=True,
155
170
  k=1,
171
+ vmin=None,
172
+ vmax=None,
173
+ size_scale=500,
156
174
  annot=True,
157
175
  cmap="coolwarm",
158
176
  fmt=".2f",
177
+ show_indicator = True,# only for style==1
159
178
  cluster=False,
160
179
  inplace=False,
161
180
  figsize=(10, 8),
@@ -201,7 +220,7 @@ def heatmap(
201
220
  ax = plt.gca()
202
221
  # Select numeric columns or specific subset of columns
203
222
  if columns == "all":
204
- df_numeric = data.select_dtypes(include=[float, int])
223
+ df_numeric = data.select_dtypes(include=[np.number])
205
224
  else:
206
225
  df_numeric = data[columns]
207
226
 
@@ -209,8 +228,10 @@ def heatmap(
209
228
  kind = strcmp(kind, kinds)[0]
210
229
  print(kind)
211
230
  if "corr" in kind: # correlation
231
+ methods = ["pearson", "spearman", "kendall"]
232
+ method = strcmp(method, methods)[0]
212
233
  # Compute the correlation matrix
213
- data4heatmap = df_numeric.corr()
234
+ data4heatmap = df_numeric.corr(method=method)
214
235
  # Generate mask for the upper triangle if mask is True
215
236
  if mask:
216
237
  if "u" in tri.lower(): # upper => np.tril
@@ -288,18 +309,76 @@ def heatmap(
288
309
  df_col_cluster,
289
310
  )
290
311
  else:
291
- # Create a standard heatmap
292
- ax = sns.heatmap(
293
- data4heatmap,
294
- ax=ax,
295
- mask=mask_array,
296
- annot=annot,
297
- cmap=cmap,
298
- fmt=fmt,
299
- **kwargs, # Pass any additional arguments to sns.heatmap
300
- )
301
- # Return the Axes object for further customization if needed
302
- return ax
312
+ if style==0:
313
+ # Create a standard heatmap
314
+ ax = sns.heatmap(
315
+ data4heatmap,
316
+ ax=ax,
317
+ mask=mask_array,
318
+ annot=annot,
319
+ cmap=cmap,
320
+ fmt=fmt,
321
+ **kwargs, # Pass any additional arguments to sns.heatmap
322
+ )
323
+ return ax
324
+ elif style==1:
325
+ if isinstance(cmap, str):
326
+ cmap = plt.get_cmap(cmap)
327
+ norm = plt.Normalize(vmin=-1, vmax=1)
328
+ r_, p_ = df_corr(data4heatmap, method=method)
329
+ # size_r_norm=df_scaler(data=r_, method="minmax", vmin=-1,vmax=1)
330
+ # 初始化一个空的可绘制对象用于颜色条
331
+ scatter_handles = []
332
+ # 循环绘制气泡图和数值
333
+ for i in range(len(r_.columns)):
334
+ for j in range(len(r_.columns)):
335
+ if (i < j) if "u" in tri.lower() else (j<i): # 对角线左上部只显示气泡
336
+ color = cmap(norm(r_.iloc[i, j])) # 根据相关系数获取颜色
337
+ scatter = ax.scatter(
338
+ i, j, s=np.abs(r_.iloc[i, j])*size_scale, color=color,
339
+ # alpha=1,edgecolor=edgecolor,linewidth=linewidth,
340
+ **kwargs
341
+ )
342
+ scatter_handles.append(scatter) # 保存scatter对象用于颜色条
343
+ # add *** indicators
344
+ if show_indicator:
345
+ ax.text(
346
+ i,
347
+ j,
348
+ pval2str(p_.iloc[i, j]),
349
+ ha="center",
350
+ va="center",
351
+ color="k",
352
+ fontsize=fontsize * 1.3,
353
+ )
354
+ elif (i > j) if "u" in tri.lower() else (j>i): # 对角只显示数值
355
+ color = cmap(norm(r_.iloc[i, j])) # 数值的颜色同样基于相关系数
356
+ ax.text(
357
+ i,
358
+ j,
359
+ f"{r_.iloc[i, j]:{fmt}}",
360
+ ha="center",
361
+ va="center",
362
+ color=color,
363
+ fontsize=fontsize,
364
+ )
365
+ else: # 对角线部分,显示空白
366
+ ax.scatter(i, j, s=1, color="white")
367
+ # 设置坐标轴标签
368
+ figsets(xticks=range(len(r_.columns)),
369
+ xticklabels=r_.columns,
370
+ xangle=90,
371
+ fontsize=fontsize,
372
+ yticks=range(len(r_.columns)),
373
+ yticklabels=r_.columns,
374
+ xlim=[-0.5,len(r_.columns)-0.5],
375
+ ylim=[-0.5,len(r_.columns)-0.5]
376
+ )
377
+ # 添加颜色条
378
+ sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
379
+ sm.set_array([]) # 仅用于显示颜色条
380
+ plt.colorbar(sm, ax=ax, label="Correlation Coefficient")
381
+ return ax
303
382
  elif "dir" in kind: # direct
304
383
  data4heatmap = df_numeric
305
384
  elif "pi" in kind: # pivot
@@ -389,16 +468,53 @@ def heatmap(
389
468
  )
390
469
  else:
391
470
  # Create a standard heatmap
392
- ax = sns.heatmap(
393
- data4heatmap,
394
- ax=ax,
395
- annot=annot,
396
- cmap=cmap,
397
- fmt=fmt,
398
- **kwargs, # Pass any additional arguments to sns.heatmap
399
- )
400
- # Return the Axes object for further customization if needed
401
- return ax
471
+ if style==0:
472
+ ax = sns.heatmap(
473
+ data4heatmap,
474
+ ax=ax,
475
+ annot=annot,
476
+ cmap=cmap,
477
+ fmt=fmt,
478
+ **kwargs, # Pass any additional arguments to sns.heatmap
479
+ )
480
+ # Return the Axes object for further customization if needed
481
+ return ax
482
+ elif style==1:
483
+ if isinstance(cmap, str):
484
+ cmap = plt.get_cmap(cmap)
485
+ if vmin is None:
486
+ vmin=np.min(data4heatmap)
487
+ if vmax is None:
488
+ vmax=np.max(data4heatmap)
489
+ norm = plt.Normalize(vmin=vmin, vmax=vmax)
490
+
491
+ # 初始化一个空的可绘制对象用于颜色条
492
+ scatter_handles = []
493
+ # 循环绘制气泡图和数值
494
+ print(len(data4heatmap.index),len(data4heatmap.columns))
495
+ for i in range(len(data4heatmap.index)):
496
+ for j in range(len(data4heatmap.columns)):
497
+ color = cmap(norm(data4heatmap.iloc[i, j])) # 根据相关系数获取颜色
498
+ scatter = ax.scatter(j,i, s=np.abs(data4heatmap.iloc[i, j])*size_scale, color=color, **kwargs)
499
+ scatter_handles.append(scatter) # 保存scatter对象用于颜色条
500
+
501
+ # 设置坐标轴标签
502
+ figsets(xticks=range(len(data4heatmap.columns)),
503
+ xticklabels=data4heatmap.columns,
504
+ xangle=90,
505
+ fontsize=fontsize,
506
+ yticks=range(len(data4heatmap.index)),
507
+ yticklabels=data4heatmap.index,
508
+ xlim=[-0.5,len(data4heatmap.columns)-0.5],
509
+ ylim=[-0.5,len(data4heatmap.index)-0.5]
510
+ )
511
+ # 添加颜色条
512
+ sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
513
+ sm.set_array([]) # 仅用于显示颜色条
514
+ plt.colorbar(sm, ax=ax,
515
+ # label="Correlation Coefficient"
516
+ )
517
+ return ax
402
518
 
403
519
 
404
520
  # !usage: py2ls.plot.heatmap()
@@ -2475,7 +2591,8 @@ def get_color(
2475
2591
  "#B25E9D",
2476
2592
  "#4B8C3B",
2477
2593
  "#EF8632",
2478
- "#24578E" "#FF2C00",
2594
+ "#24578E",
2595
+ "#FF2C00",
2479
2596
  ]
2480
2597
  elif n == 8:
2481
2598
  # colorlist = ['#1f77b4','#ff7f0e','#367B7F','#51B34F','#d62728','#aa40fc','#e377c2','#17becf']
@@ -3199,9 +3316,12 @@ def plotxy(
3199
3316
  zorder = 0
3200
3317
  for k in kind_:
3201
3318
  # preprocess data
3202
- data=df_preprocessing_(data, kind=k)
3203
- if 'variable' in data.columns and 'value' in data.columns:
3204
- x,y='variable','value'
3319
+ try:
3320
+ data=df_preprocessing_(data, kind=k)
3321
+ if 'variable' in data.columns and 'value' in data.columns:
3322
+ x,y='variable','value'
3323
+ except Exception as e:
3324
+ print(e)
3205
3325
  zorder += 1
3206
3326
  # indicate 'col' features
3207
3327
  col = kwargs.get("col", None)
@@ -3222,18 +3342,37 @@ def plotxy(
3222
3342
  # (1) return FcetGrid
3223
3343
  if k == "jointplot":
3224
3344
  kws_joint = kwargs.pop("kws_joint", kwargs)
3225
- stats=kwargs.pop("stats",True)
3345
+ kws_joint = {
3346
+ k: v for k, v in kws_joint.items() if not k.startswith("kws_")
3347
+ }
3348
+ hue = kwargs.get("hue", None)
3349
+ if isinstance(kws_joint, dict) or hue is None: # Check if kws_ellipse is a dictionary
3350
+ kws_joint.pop("hue", None) # Safely remove 'hue' if it exists
3351
+
3352
+ palette = kwargs.get("palette", None)
3353
+ if palette is None:
3354
+ palette = kws_joint.pop(
3355
+ "palette", get_color(data[hue].nunique()) if hue is not None else None
3356
+ )
3357
+ else:
3358
+ kws_joint.pop("palette", palette)
3359
+ stats=kwargs.pop("stats",None)
3360
+ if stats:
3361
+ stats=kws_joint.pop("stats",True)
3226
3362
  if stats:
3227
3363
  r, p_value = scipy_stats.pearsonr(data[x], data[y])
3228
- g = sns.jointplot(data=data, x=x, y=y, **kws_joint)
3229
- g.ax_joint.annotate(
3230
- f"pearsonr = {r:.2f} p = {p_value:.3f}",
3231
- xy=(0.6, 0.98),
3232
- xycoords="axes fraction",
3233
- fontsize=12,
3234
- color="black",
3235
- ha="center",
3236
- )
3364
+ for key in ["palette", "alpha", "hue","stats"]:
3365
+ kws_joint.pop(key, None)
3366
+ g = sns.jointplot(data=data, x=x, y=y,hue=hue,palette=palette, **kws_joint)
3367
+ if stats:
3368
+ g.ax_joint.annotate(
3369
+ f"pearsonr = {r:.2f} p = {p_value:.3f}",
3370
+ xy=(0.6, 0.98),
3371
+ xycoords="axes fraction",
3372
+ fontsize=12,
3373
+ color="black",
3374
+ ha="center",
3375
+ )
3237
3376
  elif k == "lmplot":
3238
3377
  kws_lm = kwargs.pop("kws_lm", kwargs)
3239
3378
  stats = kwargs.pop("stats", True) # Flag to calculate stats
@@ -3379,8 +3518,7 @@ def plotxy(
3379
3518
  xycoords="axes fraction",
3380
3519
  fontsize=12,
3381
3520
  color="black",
3382
- ha="center",
3383
- )
3521
+ ha="center")
3384
3522
 
3385
3523
  elif k == "catplot_sns":
3386
3524
  kws_cat = kwargs.pop("kws_cat", kwargs)
@@ -3388,7 +3526,7 @@ def plotxy(
3388
3526
  elif k == "displot":
3389
3527
  kws_dis = kwargs.pop("kws_dis", kwargs)
3390
3528
  # displot creates a new figure and returns a FacetGrid
3391
- g = sns.displot(data=data, x=x, **kws_dis)
3529
+ g = sns.displot(data=data, x=x,y=y, **kws_dis)
3392
3530
 
3393
3531
  # (2) return axis
3394
3532
  if ax is None:
@@ -3399,19 +3537,58 @@ def plotxy(
3399
3537
  elif k == "stdshade":
3400
3538
  kws_stdshade = kwargs.pop("kws_stdshade", kwargs)
3401
3539
  ax = stdshade(ax=ax, **kwargs)
3540
+ elif k=="ellipse":
3541
+ kws_ellipse = kwargs.pop("kws_ellipse", kwargs)
3542
+ kws_ellipse = {
3543
+ k: v for k, v in kws_ellipse.items() if not k.startswith("kws_")
3544
+ }
3545
+ hue = kwargs.get("hue", None)
3546
+ if isinstance(kws_ellipse, dict) or hue is None: # Check if kws_ellipse is a dictionary
3547
+ kws_ellipse.pop("hue", None) # Safely remove 'hue' if it exists
3548
+
3549
+ palette = kwargs.get("palette", None)
3550
+ if palette is None:
3551
+ palette = kws_ellipse.pop(
3552
+ "palette", get_color(data[hue].nunique()) if hue is not None else None
3553
+ )
3554
+ alpha = kws_ellipse.pop("alpha", 0.1)
3555
+ hue_order = kwargs.get("hue_order",None)
3556
+ if hue_order is None:
3557
+ hue_order = kws_ellipse.get("hue_order",None)
3558
+ if hue_order:
3559
+ data["hue"] = pd.Categorical(data[hue], categories=hue_order, ordered=True)
3560
+ data = data.sort_values(by="hue")
3561
+ for key in ["palette", "alpha", "hue","hue_order"]:
3562
+ kws_ellipse.pop(key, None)
3563
+ ax=ellipse(
3564
+ ax=ax,
3565
+ data=data,
3566
+ x=x,
3567
+ y=y,
3568
+ hue=hue,
3569
+ palette=palette,
3570
+ alpha=alpha,
3571
+ zorder=zorder,
3572
+ **kws_ellipse,
3573
+ )
3402
3574
  elif k == "scatterplot":
3403
3575
  kws_scatter = kwargs.pop("kws_scatter", kwargs)
3404
3576
  kws_scatter = {
3405
3577
  k: v for k, v in kws_scatter.items() if not k.startswith("kws_")
3406
3578
  }
3407
- hue = kwargs.pop("hue", None)
3579
+ hue = kwargs.get("hue", None)
3408
3580
  if isinstance(kws_scatter, dict): # Check if kws_scatter is a dictionary
3409
3581
  kws_scatter.pop("hue", None) # Safely remove 'hue' if it exists
3410
- palette = kws_scatter.pop(
3411
- "palette", get_color(data[hue].nunique()) if hue is not None else None
3412
- )
3582
+ palette=kws_scatter.get("palette",None)
3583
+ if palette is None:
3584
+ palette = kws_scatter.pop(
3585
+ "palette", get_color(data[hue].nunique()) if hue is not None else None
3586
+ )
3413
3587
  s = kws_scatter.pop("s", 10)
3414
- alpha = kws_scatter.pop("alpha", 0.7)
3588
+ alpha = kws_scatter.pop("alpha", 0.7)
3589
+ for key in ["s", "palette", "alpha", "hue"]:
3590
+ kws_scatter.pop(key, None)
3591
+
3415
3592
  ax = sns.scatterplot(
3416
3593
  ax=ax,
3417
3594
  data=data,
@@ -3427,19 +3604,33 @@ def plotxy(
3427
3604
  elif k == "histplot":
3428
3605
  kws_hist = kwargs.pop("kws_hist", kwargs)
3429
3606
  kws_hist = {k: v for k, v in kws_hist.items() if not k.startswith("kws_")}
3430
- ax = sns.histplot(data=data, x=x, ax=ax, zorder=zorder, **kws_hist)
3431
- elif k == "kdeplot":
3607
+ ax = sns.histplot(data=data, x=x, y=y, ax=ax, zorder=zorder, **kws_hist)
3608
+ elif k == "kdeplot":
3432
3609
  kws_kde = kwargs.pop("kws_kde", kwargs)
3433
- kws_kde = {k: v for k, v in kws_kde.items() if not k.startswith("kws_")}
3434
- ax = sns.kdeplot(data=data, x=x, ax=ax, zorder=zorder, **kws_kde)
3610
+ kws_kde = {
3611
+ k: v for k, v in kws_kde.items() if not k.startswith("kws_")
3612
+ }
3613
+ hue = kwargs.get("hue", None)
3614
+ if isinstance(kws_kde, dict) or hue is None: # Check if kws_kde is a dictionary
3615
+ kws_kde.pop("hue", None) # Safely remove 'hue' if it exists
3616
+
3617
+ palette = kwargs.get("palette", None)
3618
+ if palette is None:
3619
+ palette = kws_kde.pop(
3620
+ "palette", get_color(data[hue].nunique()) if hue is not None else None
3621
+ )
3622
+ alpha = kws_kde.pop("alpha", 0.05)
3623
+ for key in ["palette", "alpha", "hue"]:
3624
+ kws_kde.pop(key, None)
3625
+ ax = sns.kdeplot(data=data, x=x, y=y, palette=palette,hue=hue, ax=ax,alpha=alpha, zorder=zorder, **kws_kde)
3435
3626
  elif k == "ecdfplot":
3436
3627
  kws_ecdf = kwargs.pop("kws_ecdf", kwargs)
3437
3628
  kws_ecdf = {k: v for k, v in kws_ecdf.items() if not k.startswith("kws_")}
3438
- ax = sns.ecdfplot(data=data, x=x, ax=ax, zorder=zorder, **kws_ecdf)
3629
+ ax = sns.ecdfplot(data=data, x=x,y=y, ax=ax, zorder=zorder, **kws_ecdf)
3439
3630
  elif k == "rugplot":
3440
3631
  kws_rug = kwargs.pop("kws_rug", kwargs)
3441
3632
  kws_rug = {k: v for k, v in kws_rug.items() if not k.startswith("kws_")}
3442
- ax = sns.rugplot(data=data, x=x, ax=ax, zorder=zorder, **kws_rug)
3633
+ ax = sns.rugplot(data=data, x=x, y=y, ax=ax, zorder=zorder, **kws_rug)
3443
3634
  elif k == "stripplot":
3444
3635
  kws_strip = kwargs.pop("kws_strip", kwargs)
3445
3636
  kws_strip = {k: v for k, v in kws_strip.items() if not k.startswith("kws_")}
@@ -3512,7 +3703,8 @@ def plotxy(
3512
3703
  figsets(ax=ax, **kws_figsets)
3513
3704
  if kws_add_text:
3514
3705
  add_text(ax=ax, **kws_add_text) if kws_add_text else None
3515
- if run_once_within(10):
3706
+ if run_once_within(10):
3707
+ for k in kind_:
3516
3708
  print(f"\n{k}⤵ ")
3517
3709
  print(default_settings[k])
3518
3710
  # print("=>\t",sns_info[sns_info["Functions"].str.contains(k)].iloc[:, -1].tolist()[0],"\n")
@@ -3570,6 +3762,7 @@ def df_preprocessing_(data, kind, verbose=False):
3570
3762
  "lineplot", # Can work with both wide and long formats
3571
3763
  "area plot", # Can work with both formats, useful for stacked areas
3572
3764
  "violinplot", # Can work with both formats depending on categorical vs continuous data
3765
+ "ellipse",# ellipse plot, default confidence=0.95
3573
3766
  ],
3574
3767
  )[0]
3575
3768
 
@@ -3605,6 +3798,7 @@ def df_preprocessing_(data, kind, verbose=False):
3605
3798
  "violinplot", # Can work with both formats depending on categorical vs continuous data
3606
3799
  "relplot",
3607
3800
  "pointplot", # Works well with wide format
3801
+ "ellipse",
3608
3802
  ]
3609
3803
 
3610
3804
  # Wide format (e.g., for heatmap and pairplot)
@@ -4096,25 +4290,45 @@ def venn(
4096
4290
  """
4097
4291
  if ax is None:
4098
4292
  ax = plt.gca()
4293
+ if isinstance(lists, dict):
4294
+ labels = list(lists.keys())
4295
+ lists = list(lists.values())
4296
+ if isinstance(lists[0], set):
4297
+ lists = [list(i) for i in lists]
4099
4298
  lists = [set(flatten(i, verbose=False)) for i in lists]
4100
4299
  # Function to apply text styles to labels
4101
4300
  if colors is None:
4102
4301
  colors = ["r", "b"] if len(lists) == 2 else ["r", "g", "b"]
4103
4302
  if labels is None:
4104
- labels = ["set1", "set2"] if len(lists) == 2 else ["set1", "set2", "set3"]
4303
+ if len(lists) == 2:
4304
+ labels = ["set1", "set2"]
4305
+ elif len(lists) == 3:
4306
+ labels = ["set1", "set2", "set3"]
4307
+ elif len(lists) == 4:
4308
+ labels = ["set1", "set2", "set3","set4"]
4309
+ elif len(lists) == 5:
4310
+ labels = ["set1", "set2", "set3","set4","set55"]
4311
+ elif len(lists) == 6:
4312
+ labels = ["set1", "set2", "set3","set4","set5","set6"]
4313
+ elif len(lists) == 7:
4314
+ labels = ["set1", "set2", "set3","set4","set5","set6","set7"]
4105
4315
  if edgecolor is None:
4106
4316
  edgecolor = colors
4107
4317
  colors = [desaturate_color(color, saturation) for color in colors]
4108
- # Check colors and auto-calculate overlaps
4109
- if len(lists) == 2:
4318
+ universe = len(set.union(*lists))
4110
4319
 
4111
- def get_count_and_percentage(set_count, subset_count):
4112
- percent = subset_count / set_count if set_count > 0 else 0
4113
- return (
4114
- f"{subset_count}\n({fmt.format(percent)})"
4115
- if show_percentages
4116
- else f"{subset_count}"
4117
- )
4320
+ # Check colors and auto-calculate overlaps
4321
+ def get_count_and_percentage(set_count, subset_count):
4322
+ percent = subset_count / set_count if set_count > 0 else 0
4323
+ return (
4324
+ f"{subset_count}\n({fmt.format(percent)})"
4325
+ if show_percentages
4326
+ else f"{subset_count}"
4327
+ )
4328
+ if fmt is not None:
4329
+ if not fmt.startswith("{"):
4330
+ fmt="{:" + fmt + "}"
4331
+ if len(lists) == 2:
4118
4332
 
4119
4333
  from matplotlib_venn import venn2, venn2_circles
4120
4334
 
@@ -4127,21 +4341,28 @@ def venn(
4127
4341
  set1, set2 = lists[0], lists[1]
4128
4342
  v.get_patch_by_id("10").set_color(colors[0])
4129
4343
  v.get_patch_by_id("01").set_color(colors[1])
4130
- v.get_patch_by_id("11").set_color(
4131
- get_color_overlap(colors[0], colors[1]) if colors else None
4132
- )
4344
+ try:
4345
+ v.get_patch_by_id("11").set_color(
4346
+ get_color_overlap(colors[0], colors[1]) if colors else None
4347
+ )
4348
+ except Exception as e:
4349
+ print(e)
4133
4350
  # v.get_label_by_id('10').set_text(len(set1 - set2))
4134
4351
  # v.get_label_by_id('01').set_text(len(set2 - set1))
4135
4352
  # v.get_label_by_id('11').set_text(len(set1 & set2))
4353
+
4136
4354
  v.get_label_by_id("10").set_text(
4137
- get_count_and_percentage(len(set1), len(set1 - set2))
4355
+ get_count_and_percentage(universe, len(set1 - set2))
4138
4356
  )
4139
4357
  v.get_label_by_id("01").set_text(
4140
- get_count_and_percentage(len(set2), len(set2 - set1))
4141
- )
4142
- v.get_label_by_id("11").set_text(
4143
- get_count_and_percentage(len(set1 | set2), len(set1 & set2))
4358
+ get_count_and_percentage(universe, len(set2 - set1))
4144
4359
  )
4360
+ try:
4361
+ v.get_label_by_id("11").set_text(
4362
+ get_count_and_percentage(universe, len(set1 & set2))
4363
+ )
4364
+ except Exception as e:
4365
+ print(e)
4145
4366
 
4146
4367
  if not isinstance(linewidth, list):
4147
4368
  linewidth = [linewidth]
@@ -4224,16 +4445,14 @@ def venn(
4224
4445
  va=va,
4225
4446
  shadow=shadow,
4226
4447
  )
4227
-
4228
- elif len(lists) == 3:
4229
-
4230
- def get_label(set_count, subset_count):
4231
- percent = subset_count / set_count if set_count > 0 else 0
4232
- return (
4233
- f"{subset_count}\n({fmt.format(percent)})"
4234
- if show_percentages
4235
- else f"{subset_count}"
4236
- )
4448
+ # Set transparency level
4449
+ for patch in v.patches:
4450
+ if patch:
4451
+ patch.set_alpha(alpha)
4452
+ if "none" in edgecolor or 0 in linewidth:
4453
+ patch.set_edgecolor("none")
4454
+ return ax
4455
+ elif len(lists) == 3:
4237
4456
 
4238
4457
  from matplotlib_venn import venn3, venn3_circles
4239
4458
 
@@ -4249,36 +4468,34 @@ def venn(
4249
4468
  # Draw the venn diagram
4250
4469
  v = venn3(subsets=lists, set_labels=labels, ax=ax, **kwargs)
4251
4470
  v.get_patch_by_id("100").set_color(colors[0])
4471
+ v.get_label_by_id("100").set_text(get_count_and_percentage(universe, len(set1 - set2 - set3)))
4252
4472
  v.get_patch_by_id("010").set_color(colors[1])
4253
- v.get_patch_by_id("001").set_color(colors[2])
4254
- v.get_patch_by_id("110").set_color(colorAB)
4255
- v.get_patch_by_id("101").set_color(colorAC)
4256
- v.get_patch_by_id("011").set_color(colorBC)
4257
- v.get_patch_by_id("111").set_color(colorABC)
4258
-
4259
- # Correctly labeling subset sizes
4260
- # v.get_label_by_id('100').set_text(len(set1 - set2 - set3))
4261
- # v.get_label_by_id('010').set_text(len(set2 - set1 - set3))
4262
- # v.get_label_by_id('001').set_text(len(set3 - set1 - set2))
4263
- # v.get_label_by_id('110').set_text(len(set1 & set2 - set3))
4264
- # v.get_label_by_id('101').set_text(len(set1 & set3 - set2))
4265
- # v.get_label_by_id('011').set_text(len(set2 & set3 - set1))
4266
- # v.get_label_by_id('111').set_text(len(set1 & set2 & set3))
4267
- v.get_label_by_id("100").set_text(get_label(len(set1), len(set1 - set2 - set3)))
4268
- v.get_label_by_id("010").set_text(get_label(len(set2), len(set2 - set1 - set3)))
4269
- v.get_label_by_id("001").set_text(get_label(len(set3), len(set3 - set1 - set2)))
4270
- v.get_label_by_id("110").set_text(
4271
- get_label(len(set1 | set2), len(set1 & set2 - set3))
4272
- )
4273
- v.get_label_by_id("101").set_text(
4274
- get_label(len(set1 | set3), len(set1 & set3 - set2))
4275
- )
4276
- v.get_label_by_id("011").set_text(
4277
- get_label(len(set2 | set3), len(set2 & set3 - set1))
4278
- )
4279
- v.get_label_by_id("111").set_text(
4280
- get_label(len(set1 | set2 | set3), len(set1 & set2 & set3))
4281
- )
4473
+ v.get_label_by_id("010").set_text(get_count_and_percentage(universe, len(set2 - set1 - set3)))
4474
+ try:
4475
+ v.get_patch_by_id("001").set_color(colors[2])
4476
+ v.get_label_by_id("001").set_text(get_count_and_percentage(universe, len(set3 - set1 - set2)))
4477
+ except Exception as e:
4478
+ print(e)
4479
+ try:
4480
+ v.get_patch_by_id("110").set_color(colorAB)
4481
+ v.get_label_by_id("110").set_text(get_count_and_percentage(universe, len(set1 & set2 - set3)))
4482
+ except Exception as e:
4483
+ print(e)
4484
+ try:
4485
+ v.get_patch_by_id("101").set_color(colorAC)
4486
+ v.get_label_by_id("101").set_text(get_count_and_percentage(universe, len(set1 & set3 - set2)))
4487
+ except Exception as e:
4488
+ print(e)
4489
+ try:
4490
+ v.get_patch_by_id("011").set_color(colorBC)
4491
+ v.get_label_by_id("011").set_text(get_count_and_percentage(universe, len(set2 & set3 - set1)))
4492
+ except Exception as e:
4493
+ print(e)
4494
+ try:
4495
+ v.get_patch_by_id("111").set_color(colorABC)
4496
+ v.get_label_by_id("111").set_text(get_count_and_percentage(universe, len(set1 & set2 & set3)))
4497
+ except Exception as e:
4498
+ print(e)
4282
4499
 
4283
4500
  # Apply styles to set labels
4284
4501
  for i, text in enumerate(v.set_labels):
@@ -4383,16 +4600,34 @@ def venn(
4383
4600
  ax.add_patch(ellipse1)
4384
4601
  ax.add_patch(ellipse2)
4385
4602
  ax.add_patch(ellipse3)
4603
+ # Set transparency level
4604
+ for patch in v.patches:
4605
+ if patch:
4606
+ patch.set_alpha(alpha)
4607
+ if "none" in edgecolor or 0 in linewidth:
4608
+ patch.set_edgecolor("none")
4609
+ return ax
4610
+
4611
+
4612
+ dict_data = {}
4613
+ for i_list, list_ in enumerate(lists):
4614
+ dict_data[labels[i_list]]={*list_}
4615
+
4616
+ if 3<len(lists)<6:
4617
+ from venn import venn as vn
4618
+
4619
+ legend_loc=kwargs.pop("legend_loc", "upper right")
4620
+ ax=vn(dict_data,ax=ax,legend_loc=legend_loc,**kwargs)
4621
+
4622
+ return ax
4386
4623
  else:
4387
- raise ValueError("只支持2或者3个list")
4388
-
4389
- # Set transparency level
4390
- for patch in v.patches:
4391
- if patch:
4392
- patch.set_alpha(alpha)
4393
- if "none" in edgecolor or 0 in linewidth:
4394
- patch.set_edgecolor("none")
4395
- return ax
4624
+ from venn import pseudovenn
4625
+ cmap=kwargs.pop("cmap","plasma")
4626
+ ax=pseudovenn(dict_data, cmap=cmap,ax=ax,**kwargs)
4627
+
4628
+ return ax
4629
+
4630
+
4396
4631
 
4397
4632
 
4398
4633
  #! subplots, support automatic extend new axis
@@ -4403,6 +4638,7 @@ def subplot(
4403
4638
  sharex=False,
4404
4639
  sharey=False,
4405
4640
  verbose=False,
4641
+ fig=None,
4406
4642
  **kwargs,
4407
4643
  ):
4408
4644
  """
@@ -4431,8 +4667,8 @@ def subplot(
4431
4667
  )
4432
4668
 
4433
4669
  figsize_recommend = f"subplot({rows}, {cols}, figsize={figsize})"
4434
-
4435
- fig = plt.figure(figsize=figsize, constrained_layout=True)
4670
+ if fig is None:
4671
+ fig = plt.figure(figsize=figsize, constrained_layout=True)
4436
4672
  grid_spec = GridSpec(rows, cols, figure=fig)
4437
4673
  occupied = set()
4438
4674
  row_first_axes = [None] * rows # Track the first axis in each row (for sharey)
@@ -4496,18 +4732,19 @@ def subplot(
4496
4732
  #! radar chart
4497
4733
  def radar(
4498
4734
  data: pd.DataFrame,
4499
- title="Radar Chart",
4735
+ columns=None,
4500
4736
  ylim=(0, 100),
4501
- color=get_color(5),
4737
+ facecolor=None,
4738
+ edgecolor="none",
4739
+ edge_linewidth=0.5,
4502
4740
  fontsize=10,
4503
4741
  fontcolor="k",
4504
4742
  size=6,
4505
4743
  linewidth=1,
4506
4744
  linestyle="-",
4507
- alpha=0.5,
4745
+ alpha=0.3,
4746
+ fmt=".1f",
4508
4747
  marker="o",
4509
- edgecolor="none",
4510
- edge_linewidth=0.5,
4511
4748
  bg_color="0.8",
4512
4749
  bg_alpha=None,
4513
4750
  grid_interval_ratio=0.2,
@@ -4527,19 +4764,34 @@ def radar(
4527
4764
  ax=None,
4528
4765
  sp=2,
4529
4766
  verbose=True,
4767
+ axis=0,
4530
4768
  **kwargs,
4531
4769
  ):
4532
4770
  """
4533
4771
  Example DATA:
4534
4772
  df = pd.DataFrame(
4535
- data=[
4536
- [80, 80, 80, 80, 80, 80, 80],
4537
- [90, 20, 95, 95, 30, 30, 80],
4538
- [60, 90, 20, 20, 100, 90, 50],
4539
- ],
4540
- index=["Hero", "Warrior", "Wizard"],
4541
- columns=["HP", "MP", "ATK", "DEF", "SP.ATK", "SP.DEF", "SPD"])
4542
-
4773
+ data=[
4774
+ [80, 90, 60],
4775
+ [80, 20, 90],
4776
+ [80, 95, 20],
4777
+ [80, 95, 20],
4778
+ [80, 30, 100],
4779
+ [80, 30, 90],
4780
+ [80, 80, 50],
4781
+ ],
4782
+ index=["HP", "MP", "ATK", "DEF", "SP.ATK", "SP.DEF", "SPD"],
4783
+ columns=["Hero", "Warrior", "Wizard"],
4784
+ )
4785
+ usage 1:
4786
+ radar(data=df)
4787
+ usage 2:
4788
+ radar(data=df["Wizard"])
4789
+ usage 3:
4790
+ radar(data=df, columns="Wizard")
4791
+ usage 4:
4792
+ nexttile = subplot(1, 2)
4793
+ radar(data=df, columns="Wizard", ax=nexttile(projection="polar"))
4794
+ pie(data=df, columns="Wizard", ax=nexttile(), width=0.5, pctdistance=0.7)
4543
4795
  Parameters:
4544
4796
  - data (pd.DataFrame): The data to plot. Each column corresponds to a variable, and each row represents a data point.
4545
4797
  - ylim (tuple): The limits of the radial axis (y-axis). Default is (0, 100).
@@ -4556,7 +4808,6 @@ def radar(
4556
4808
  - edge_linewidth (int): Line width for the marker edges.
4557
4809
  - bg_color (str): Background color for the radar chart.
4558
4810
  - grid_interval_ratio (float): Determines the intervals for the grid lines as a fraction of the y-limit.
4559
- - title (str): The title of the radar chart.
4560
4811
  - cmap (str): The colormap to use if `color` is a list.
4561
4812
  - legend_loc (str): The location of the legend.
4562
4813
  - legend_fontsize (int): Font size for the legend.
@@ -4573,22 +4824,22 @@ def radar(
4573
4824
  - sp (int): Padding for the ticks from the plot area.
4574
4825
  - **kwargs: Additional arguments for customization.
4575
4826
  """
4576
- if run_once_within() and verbose:
4827
+ if run_once_within(20,reverse=True) and verbose:
4577
4828
  usage_="""usage:
4578
4829
  radar(
4579
4830
  data: pd.DataFrame, #The data to plot. Each column corresponds to a variable, and each row represents a data point.
4580
- title="Radar Chart",
4581
4831
  ylim=(0, 100),# ylim (tuple): The limits of the radial axis (y-axis). Default is (0, 100).
4582
- color=get_color(5),#The color(s) for the plot. Can be a single color or a list of colors.
4832
+ facecolor=get_color(5),#The color(s) for the plot. Can be a single color or a list of colors.
4833
+ edgecolor="none",#for the marker edges.
4834
+ edge_linewidth=0.5,#for the marker edges.
4583
4835
  fontsize=10,# Font size for the angular labels (x-axis).
4584
4836
  fontcolor="k",# Color for the angular labels.
4585
4837
  size=6,#The size of the markers for each data point.
4586
4838
  linewidth=1,
4587
4839
  linestyle="-",
4588
4840
  alpha=0.5,#for the filled area.
4841
+ fmt=".1f",
4589
4842
  marker="o",# for the data points.
4590
- edgecolor="none",#for the marker edges.
4591
- edge_linewidth=0.5,#for the marker edges.
4592
4843
  bg_color="0.8",
4593
4844
  bg_alpha=None,
4594
4845
  grid_interval_ratio=0.2,#Determines the intervals for the grid lines as a fraction of the y-limit.
@@ -4618,9 +4869,25 @@ def radar(
4618
4869
  kws_figsets = v_arg
4619
4870
  kwargs.pop(k_arg, None)
4620
4871
  break
4621
- categories = list(data.columns)
4872
+ if axis==1:
4873
+ data=data.T
4874
+ if isinstance(data, dict):
4875
+ data = pd.DataFrame(pd.Series(data))
4876
+ if ~isinstance(data, pd.DataFrame):
4877
+ data=pd.DataFrame(data)
4878
+ if isinstance(data, pd.DataFrame):
4879
+ data=data.select_dtypes(include=np.number)
4880
+ if isinstance(columns,str):
4881
+ columns=[columns]
4882
+ if columns is None:
4883
+ columns = list(data.columns)
4884
+ data=data[columns]
4885
+ categories = list(data.index)
4622
4886
  num_vars = len(categories)
4623
4887
 
4888
+ # Set y-axis limits and grid intervals
4889
+ vmin, vmax = ylim
4890
+
4624
4891
  # Set up angle for each category on radar chart
4625
4892
  angles = np.linspace(0, 2 * np.pi, num_vars, endpoint=False).tolist()
4626
4893
  angles += angles[:1] # Complete the loop to ensure straight-line connections
@@ -4644,9 +4911,6 @@ def radar(
4644
4911
  # Draw one axis per variable and add labels
4645
4912
  ax.set_xticks(angles[:-1])
4646
4913
  ax.set_xticklabels(categories)
4647
-
4648
- # Set y-axis limits and grid intervals
4649
- vmin, vmax = ylim
4650
4914
  if circular:
4651
4915
  # * cicular style
4652
4916
  ax.yaxis.set_ticks(np.arange(vmin, vmax + 1, vmax * grid_interval_ratio))
@@ -4669,7 +4933,7 @@ def radar(
4669
4933
  else:
4670
4934
  # * spider style: spider-style grid (straight lines, not circles)
4671
4935
  # Create the spider-style grid (straight lines, not circles)
4672
- for i in range(1, int(vmax * grid_interval_ratio) + 1):
4936
+ for i in range(1, int((vmax-vmin)/ ((vmax-vmin)*grid_interval_ratio))+1):#int(vmax * grid_interval_ratio) + 1):
4673
4937
  ax.plot(
4674
4938
  angles + [angles[0]], # Closing the loop
4675
4939
  [i * vmax * grid_interval_ratio] * (num_vars + 1)
@@ -4680,7 +4944,7 @@ def radar(
4680
4944
  linewidth=grid_linewidth,
4681
4945
  )
4682
4946
  # set bg_color
4683
- ax.fill(angles, [vmax] * (data.shape[1] + 1), color=bg_color, alpha=bg_alpha)
4947
+ ax.fill(angles, [vmax] * (data.shape[0] + 1), color=bg_color, alpha=bg_alpha)
4684
4948
  ax.yaxis.grid(False)
4685
4949
  # Move radial labels away from plotted line
4686
4950
  if tick_loc is None:
@@ -4702,14 +4966,20 @@ def radar(
4702
4966
  ax.tick_params(axis="x", pad=sp) # move spines outward
4703
4967
  ax.tick_params(axis="y", pad=sp) # move spines outward
4704
4968
  # colors
4705
- colors = (
4706
- get_color(data.shape[0])
4707
- if cmap is None
4708
- else plt.get_cmap(cmap)(np.linspace(0, 1, data.shape[0]))
4709
- )
4969
+ if facecolor is not None:
4970
+ if not isinstance(facecolor,list):
4971
+ facecolor=[facecolor]
4972
+ colors = facecolor
4973
+ else:
4974
+ colors = (
4975
+ get_color(data.shape[1])
4976
+ if cmap is None
4977
+ else plt.get_cmap(cmap)(np.linspace(0, 1, data.shape[1]))
4978
+ )
4979
+
4710
4980
  # Plot each row with straight lines
4711
- for i, (index, row) in enumerate(data.iterrows()):
4712
- values = row.tolist()
4981
+ for i, (col, val) in enumerate(data.items()):
4982
+ values = val.tolist()
4713
4983
  values += values[:1] # Close the loop
4714
4984
  ax.plot(
4715
4985
  angles,
@@ -4717,7 +4987,7 @@ def radar(
4717
4987
  color=colors[i],
4718
4988
  linewidth=linewidth,
4719
4989
  linestyle=linestyle,
4720
- label=index,
4990
+ label=col,
4721
4991
  clip_on=False,
4722
4992
  )
4723
4993
  ax.fill(angles, values, color=colors[i], alpha=alpha)
@@ -4748,7 +5018,7 @@ def radar(
4748
5018
  ax.text(
4749
5019
  angle,
4750
5020
  offset_radius,
4751
- str(value),
5021
+ f"{value:{fmt}}",
4752
5022
  ha="center",
4753
5023
  va="center",
4754
5024
  fontsize=fontsize,
@@ -4759,10 +5029,10 @@ def radar(
4759
5029
 
4760
5030
  ax.set_ylim(ylim)
4761
5031
  # Add markers for each data point
4762
- for i, row in enumerate(data.values):
5032
+ for i, (col, val) in enumerate(data.items()):
4763
5033
  ax.plot(
4764
5034
  angles,
4765
- list(row) + [row[0]], # Close the loop for markers
5035
+ list(val) + [val[0]], # Close the loop for markers
4766
5036
  color=colors[i],
4767
5037
  marker=marker,
4768
5038
  markersize=size,
@@ -4787,3 +5057,819 @@ def radar(
4787
5057
  **kws_figsets,
4788
5058
  )
4789
5059
  return ax
5060
+
5061
+
5062
def pie(
    data: pd.Series,
    columns: list = None,
    facecolor=None,
    explode=[0.1],
    startangle=90,
    shadow=True,
    fontcolor="k",
    fmt=".2f",
    width=None,  # radial thickness of each ring (controls the blank centre)
    pctdistance=0.85,
    labeldistance=1.1,
    kws_wedge={},
    kws_text={},
    kws_arrow={},
    center=(0, 0),
    radius=1,
    frame=False,
    fontsize=10,
    edgecolor="white",
    edgewidth=1,
    cmap=None,
    show_value=False,
    show_label=True,  # False: label only the outer ring; None: no labels at all
    expand_label=(1.2, 1.2),
    kws_bbox={},  # e.g. dict(facecolor="none", alpha=0.5, edgecolor="black", boxstyle="round,pad=0.3"); {} hides the box
    show_legend=True,
    legend_loc="upper right",
    bbox_to_anchor=[1.4, 1.1],
    legend_fontsize=10,
    rotation_correction=0,
    verbose=True,
    ax=None,
    **kwargs
):
    """Draw a pie / donut chart, or nested rings when several columns are given.

    Each selected numeric column of ``data`` becomes one ring; the first
    column is the outermost ring and every following ring is moved inwards
    by the accumulated ``width`` of the rings before it.

    Args:
        data: dict, Series, DataFrame or anything ``pd.DataFrame`` accepts.
            Only numeric columns are kept; the index provides wedge labels.
        columns: column name or list of names to plot; ``None`` plots all.
        facecolor: a colour, a list of colours, or a nested list holding one
            colour list per ring. ``None`` falls back to ``cmap``/``get_color``.
        explode: per-wedge offset(s); padded with zeros (or truncated) to the
            number of wedges. ``None`` means no offset.
        fmt: number format. A plain spec such as ``".2f"`` is turned into a
            percent ``autopct``; a ``%``-style string is used as given; a
            ``"{...}"`` template is applied with ``str.format``.
        width: ring thickness, scalar or one value per ring. ``None`` picks
            a full pie for one column and thin rings for several.
        radius: outer radius, scalar or one base radius per ring.
        kws_wedge, kws_text, kws_arrow, kws_bbox: extra properties for the
            wedges, the texts, the ``adjust_text`` arrows and text boxes.
            These dicts are copied, never modified.
        show_value: truthy adds the raw values on the wedges; ``None``
            additionally suppresses the percentage texts.
        show_label: see the inline comment in the signature.
        verbose: print the usage text (rate limited by ``run_once_within``).
        ax: target axes; defaults to ``plt.gca()``.
        **kwargs: forwarded to ``Axes.pie``.

    Returns:
        The matplotlib axes the chart was drawn on.
    """
    from adjustText import adjust_text

    if run_once_within(20, reverse=True) and verbose:
        usage_ = """usage:
    pie(
        data:pd.Series,
        columns:list = None,
        facecolor=None,
        explode=[0.1],
        startangle=90,
        shadow=True,
        fontcolor="k",
        fmt=".2f",
        width=None,# the center blank
        pctdistance=0.85,
        labeldistance=1.1,
        kws_wedge={},
        kws_text={},
        center=(0, 0),
        radius=1,
        frame=False,
        fontsize=10,
        edgecolor="white",
        edgewidth=1,
        cmap=None,
        show_value=False,
        show_label=True,# False: only show the outer layer, if it is None, not show
        show_legend=True,
        legend_loc="upper right",
        bbox_to_anchor=[1.4, 1.1],
        legend_fontsize=10,
        rotation_correction=0,
        verbose=True,
        ax=None,
        **kwargs
    )

    usage 1:
    data = {"Segment A": 30, "Segment B": 50, "Segment C": 20}

    ax = pie(
        data=data,
        # columns="Segment A",
        explode=[0, 0.2, 0],
        # width=0.4,
        show_label=False,
        fontsize=10,
        # show_value=1,
        fmt=".3f",
    )

    # prepare dataset
    df = pd.DataFrame(
        data=[
                [80, 90, 60],
                [80, 20, 90],
                [80, 95, 20],
                [80, 95, 20],
                [80, 30, 100],
                [80, 30, 90],
                [80, 80, 50],
            ],
        index=["HP", "MP", "ATK", "DEF", "SP.ATK", "SP.DEF", "SPD"],
        columns=["Hero", "Warrior", "Wizard"],
    )
    usage 1: only plot one column
    pie(
        df,
        columns="Wizard",
        width=0.6,
        show_label=False,
        fmt=".0f",
    )
    usage 2:
    pie(df,columns=["Hero", "Warrior"],show_label=False)
    usage 3: set different width
    pie(df,
        columns=["Hero", "Warrior", "Wizard"],
        width=[0.3, 0.2, 0.2],
        show_label=False,
        fmt=".0f",
        )
    usage 4: set width the same for all columns
    pie(df,
        columns=["Hero", "Warrior", "Wizard"],
        width=0.2,
        show_label=False,
        fmt=".0f",
        )
    usage 5: adjust the labels' offset
    pie(df, columns="Wizard", width=0.6, show_label=False, fmt=".6f", labeldistance=1.2)

    usage 6:
    nexttile = subplot(1, 2)
    radar(data=df, columns="Wizard", ax=nexttile(projection="polar"))
    pie(data=df, columns="Wizard", ax=nexttile(), width=0.5, pctdistance=0.7)
    """
        print(usage_)

    # --- normalise the input to a numeric DataFrame -------------------------
    if isinstance(data, dict):
        data = pd.DataFrame(pd.Series(data))
    # `not`, not `~`: bitwise inversion of a bool is always truthy.
    if not isinstance(data, pd.DataFrame):
        data = pd.DataFrame(data)
    data = data.select_dtypes(include=np.number)
    if isinstance(columns, str):
        columns = [columns]
    if columns is None:
        columns = list(data.columns)
    # keep only the requested columns
    df = data[columns]
    n_wedges, n_rings = df.shape

    # --- explode: work on a private copy so neither the shared default list
    # nor a caller's list is extended in place
    explode = list(explode) if isinstance(explode, (list, tuple)) else [explode]
    explode = [0 if item is None else item for item in explode]
    explode = (explode + [0] * n_wedges)[:n_wedges]

    # --- ring geometry: honour `radius` instead of silently forcing 1
    sequence_types = (list, tuple, np.ndarray)
    radius = list(radius) if isinstance(radius, sequence_types) else [radius]
    if not radius:
        radius = [1]
    if len(radius) < n_rings:
        radius = radius * n_rings
    if width is None:
        # one column: full pie; several columns: thin rings with a blank centre
        width = radius[0] / (n_rings + 2) if n_rings > 1 else radius[0]
    width = list(width) if isinstance(width, sequence_types) else [width]
    if len(width) < n_rings:
        width = width * n_rings
    # every ring starts where the accumulated widths of the outer rings end
    ring_radii = [radius[i] - sum(width[:i]) for i in range(n_rings)]

    # --- colours
    if facecolor is not None:
        colors = facecolor if isinstance(facecolor, list) else [facecolor]
    elif cmap is None:
        colors = get_color(data.shape[0])
    else:
        colors = plt.get_cmap(cmap)(np.linspace(0, 1, data.shape[0]))
    # a nested list carries one colour list per ring
    is_nested = any(isinstance(item, list) for item in colors)

    # --- text properties: copy so the shared default dict never keeps the
    # colour / size of an earlier call
    kws_text = dict(kws_text) if kws_text else {}
    fontcolor = kws_text.get("color", fontcolor)
    fontsize = kws_text.get("fontsize", fontsize)
    kws_text.update({"color": fontcolor, "fontsize": fontsize})

    # --- percentage formatter handed to Axes.pie
    if not fmt:
        autopct = None
    elif fmt.startswith("%"):
        autopct = fmt  # already a %-style template
    elif fmt.startswith("{"):

        def autopct(pct, _template=fmt):
            return _template.format(pct) + "%"

    else:
        autopct = f"%{fmt}%%"
    if show_value is None:
        autopct = None

    if ax is None:
        ax = plt.gca()

    wedges = None
    labels_legend = None
    for ring, (column_, width_, radius_) in enumerate(
        zip(columns, width, ring_radii)
    ):
        # the outer ring is labelled unless show_label is None; inner rings
        # are labelled only when show_label is truthy
        if column_ != columns[0]:
            labels = df.index if show_label else None
        else:
            labels = df.index if show_label is not None else None
        series = df[column_]
        labels_legend = series.index
        sizes = series.values

        # fresh wedge properties per ring; the caller's dict is left untouched
        wedge_props = (
            dict(kws_wedge)
            if kws_wedge
            else {"edgecolor": edgecolor, "linewidth": edgewidth}
        )
        wedge_props["width"] = width_

        result = ax.pie(
            sizes,
            labels=labels,
            autopct=autopct,
            startangle=startangle + rotation_correction,
            explode=explode,
            colors=colors[ring % len(colors)] if is_nested else colors,
            shadow=shadow,
            pctdistance=pctdistance,
            labeldistance=labeldistance,
            wedgeprops=wedge_props,
            textprops=kws_text,
            center=center,
            radius=radius_,
            frame=frame,
            **kwargs
        )
        # Axes.pie returns (wedges, texts) or (wedges, texts, autotexts)
        if len(result) == 3:
            wedges, texts, autotexts = result
        else:
            wedges, texts = result[:2]
            autotexts = None

        # let adjust_text separate overlapping labels / percentages
        movable_texts = []
        if autotexts and show_value:
            movable_texts.extend(autotexts)
        if texts and show_label:
            movable_texts.extend(texts)
        if movable_texts:
            adjust_text(
                movable_texts,
                ax=ax,
                arrowprops=kws_arrow,  # e.g. dict(arrowstyle="-", color="gray", lw=0.5)
                bbox=kws_bbox if kws_bbox else None,
                expand=expand_label,
                fontdict={
                    "fontsize": fontsize,
                    "color": fontcolor,
                },
            )

        # write the raw values on the wedges
        if show_value:
            for wedge, value in zip(wedges, sizes):
                mid_angle = np.radians((wedge.theta1 + wedge.theta2) / 2)
                x = center[0] + np.cos(mid_angle) * pctdistance * radius_
                y = center[1] + np.sin(mid_angle) * pctdistance * radius_
                if not fmt:
                    value_text = str(value)
                elif fmt.startswith("{"):
                    value_text = fmt.format(value)
                elif fmt.startswith("%"):
                    value_text = fmt % value
                else:
                    value_text = f"{value:{fmt}}"
                ax.text(
                    x,
                    y,
                    value_text,
                    ha="center",
                    va="center",
                    fontsize=fontsize,
                    color=fontcolor,
                )

    # legend built from the last ring drawn (skipped if nothing was drawn)
    if show_legend and wedges is not None:
        ax.legend(
            wedges,
            labels_legend,
            loc=legend_loc,
            bbox_to_anchor=bbox_to_anchor,
            fontsize=legend_fontsize,
            title_fontsize=legend_fontsize,
        )
    ax.set(aspect="equal")
    return ax
5358
+
5359
def ellipse(
    data,
    x=None,
    y=None,
    hue=None,
    n_std=1.5,
    ax=None,
    confidence=0.95,
    annotate_center=False,
    palette=None,
    facecolor=None,
    edgecolor=None,
    label: bool = True,
    **kwargs,
):
    """Draw covariance ellipses, one per ``hue`` group (or one for all rows).

    Example:
        control = np.random.multivariate_normal([0, 0], [[1, 0.5], [0.5, 1]], size=50)
        patient = np.random.multivariate_normal([2, 1], [[1, -0.3], [-0.3, 1]], size=50)
        df = pd.DataFrame(
            {
                "Dim1": np.concatenate([control[:, 0], patient[:, 0]]),
                "Dim2": np.concatenate([control[:, 1], patient[:, 1]]),
                "Group": ["Control"] * 50 + ["Patient"] * 50,
            }
        )
        plotxy(data=df, x="Dim1", y="Dim2", hue="Group", kind_="scatter",
               palette=get_color(8))
        ellipse(data=df, x="Dim1", y="Dim2", hue="Group",
                palette=get_color(8), alpha=0.1, lw=2)

    Parameters:
        data (DataFrame): table holding the ``x``, ``y`` and ``hue`` columns.
        x (str): column with the x values.
        y (str): column with the y values.
        hue (str, optional): column with the group labels.
        n_std (float): ellipse half-size in standard deviations; only used
            when ``confidence`` is falsy.
        ax (matplotlib.axes.Axes, optional): target axes, default ``plt.gca()``.
        confidence (float, optional): confidence level (e.g. 0.95); when set
            it replaces ``n_std`` by the matching chi-square quantile (2 dof).
        annotate_center (bool): write the group mean at the ellipse centre.
        palette (dict or list, optional): mapping group -> colour, or any
            palette ``seaborn.color_palette`` accepts. Default ``"husl"``.
        facecolor, edgecolor: fixed colours for every ellipse; ``None`` uses
            the group colour (``"blue"`` when there is no ``hue``).
        label (bool): use the group name as legend label of the patch.
        **kwargs: forwarded to ``matplotlib.patches.Ellipse``. ``alpha``
            (default 0.2) is applied to the face colour only.

    Returns:
        matplotlib.axes.Axes: the axes the ellipses were added to.

    Raises:
        ValueError: if ``x`` or ``y`` is missing or ``data`` is not a DataFrame.
    """
    import numpy as np
    import pandas as pd

    # Validate before touching matplotlib so bad input does not create a figure.
    if x is None or y is None:
        raise ValueError(
            "Both `x` and `y` must be specified as column names in the DataFrame."
        )
    if not isinstance(data, pd.DataFrame):
        raise ValueError("`data` must be a pandas DataFrame.")

    from matplotlib.colors import to_rgba
    from matplotlib.patches import Ellipse
    import matplotlib.pyplot as plt
    import seaborn as sns
    from scipy.stats import chi2

    if ax is None:
        ax = plt.gca()

    # one colour per group
    if hue is not None:
        groups = data[hue].unique()
        if isinstance(palette, dict):
            fallback = sns.color_palette("husl", len(groups))
            color_map = {
                grp: palette.get(grp, spare) for grp, spare in zip(groups, fallback)
            }
        else:
            has_palette = palette is not None and len(palette) > 0
            colors = sns.color_palette(palette if has_palette else "husl", len(groups))
            color_map = dict(zip(groups, colors))
    else:
        groups = [None]
        color_map = {None: edgecolor if edgecolor is not None else "blue"}

    alpha = kwargs.pop("alpha", 0.2)
    if confidence:
        # chi-square quantile with 2 degrees of freedom -> scaling of the axes
        n_std = np.sqrt(chi2.ppf(confidence, df=2))

    for group in groups:
        group_data = data[data[hue] == group] if hue else data
        group_points = group_data[[x, y]].values
        if len(group_points) < 2:
            # a covariance needs at least two observations (also skips NaN groups)
            continue

        cov = np.cov(group_points.T)
        mean = np.mean(group_points, axis=0)

        # principal axes: eigen-decomposition, largest eigenvalue first
        eigvals, eigvecs = np.linalg.eigh(cov)
        order = eigvals.argsort()[::-1]
        eigvals, eigvecs = eigvals[order], eigvecs[:, order]
        eigvals = np.maximum(eigvals, 0)  # guard against tiny negative round-off

        angle = np.degrees(np.arctan2(*eigvecs[:, 0][::-1]))
        width, height = 2 * n_std * np.sqrt(eigvals)

        # explicit colours win, otherwise fall back to the group colour
        facecolor_ = color_map[group] if facecolor is None else facecolor
        edgecolor_ = color_map[group] if edgecolor is None else edgecolor
        patch = Ellipse(
            xy=mean,
            width=width,
            height=height,
            angle=angle,
            edgecolor=edgecolor_,
            facecolor=to_rgba(facecolor_, alpha),  # transparency on the face only
            label=group if (hue and label) else None,
            **kwargs,
        )
        ax.add_patch(patch)

        if annotate_center:
            ax.annotate(
                f"Mean\n({mean[0]:.2f}, {mean[1]:.2f})",
                xy=mean,
                xycoords="data",
                fontsize=10,
                ha="center",
                color=edgecolor_,
                bbox=dict(
                    boxstyle="round,pad=0.3",
                    edgecolor="gray",
                    facecolor="white",
                    alpha=0.8,
                ),
            )

    return ax
5511
+
5512
def _ppi_fit_length(values, n):
    """Return `values` as a list of exactly `n` items (truncate, or repeat the last item)."""
    values = list(values)
    if not values:
        return []
    if len(values) >= n:
        return values[:n]
    return values + [values[-1]] * (n - len(values))


def _ppi_rescale(values, low, high):
    """Map `values` linearly onto [low, high]; a constant input maps to the midpoint."""
    smallest, largest = min(values), max(values)
    if largest > smallest:  # avoid division by zero
        span = largest - smallest
        return [low + (high - low) * (v - smallest) / span for v in values]
    return [(low + high) / 2] * len(values)


def ppi(
    interactions,
    player1="preferredName_A",
    player2="preferredName_B",
    weight="score",
    n_layers=None,  # Number of concentric layers
    n_rank=[5, 10],  # Nodes in each rank for the concentric layout
    dist_node=10,  # Distance between each rank of circles
    layout="degree",
    size=None,  # 700,
    sizes=(50, 500),  # min and max of size
    facecolor="skyblue",
    cmap='coolwarm',
    edgecolor="k",
    edgelinewidth=1.5,
    alpha=.5,
    alphas=(0.1, 1.0),  # min and max of alpha
    marker="o",
    node_hideticks=True,
    linecolor="gray",
    line_cmap='coolwarm',
    linewidth=1.5,
    linewidths=(0.5, 5),  # min and max of linewidth
    linealpha=1.0,
    linealphas=(0.1, 1.0),  # min and max of linealpha
    linestyle="-",
    line_arrowstyle='-',
    fontsize=10,
    fontcolor="k",
    ha: str = "center",
    va: str = "center",
    figsize=(12, 10),
    k_value=0.3,
    bgcolor="w",
    dir_save="./ppi_network.html",
    physics=True,
    notebook=False,
    scale=1,
    ax=None,
    **kwargs
):
    """
    Plot a Protein-Protein Interaction (PPI) network with adjustable appearance.

    The edge table is turned into a NetworkX graph that is drawn with
    matplotlib and mirrored into a pyvis network. When ``dir_save`` is set,
    the pyvis HTML, a GraphML export and a figure (via ``ips.figsave``) are
    written next to each other.

    ppi(
        interactions_sort.iloc[:1000, :],
        player1="player1",
        player2="player2",
        weight="count",
        layout="spring",
        n_layers=13,
        fontsize=1,
        n_rank=[5, 10, 20, 40, 80, 80, 80, 80, 80, 80, 80, 80],
    )

    Args:
        interactions: DataFrame with one row per edge. It is not modified.
        player1, player2, weight: column names of the two nodes and the
            edge weight.
        n_layers, n_rank, dist_node: geometry of the "degree" layout: number
            of concentric circles, nodes per circle (highest degree first)
            and distance between circles. Nodes left over are placed on the
            last circle.
        layout: one of spring, circular, kamada_kawai, spectral, random,
            shell, planar, spiral, degree (matched through ``ips.strcmp``).
        size, sizes: node size (scalar or list; ``None`` derives it from the
            node degree) and the (min, max) range list values are scaled to.
        facecolor, cmap: node colour(s); anything ``ips.isa(..., 'color')``
            rejects is replaced by degree-based colours from ``cmap``.
        alpha, alphas / linewidth, linewidths / linealpha, linealphas:
            a scalar is used as given; a list is padded or truncated to the
            number of nodes / edges and rescaled into the (min, max) tuple.
        linecolor, line_cmap: a string colours all edges; any other value
            colours edges by weight using ``line_cmap``.
        dir_save: HTML output path; falsy disables all file output.
        **kwargs: the first key containing "figset" is forwarded to
            ``figsets``.

    Returns:
        tuple: ``(G, ax)``, the NetworkX graph and the matplotlib axes.

    Raises:
        ValueError: if a required column is missing or there are no rows.
    """
    from pyvis.network import Network
    import networkx as nx
    from matplotlib.colors import Normalize
    from . import ips

    if run_once_within():
        usage_str = """
    ppi(
        interactions,
        player1="preferredName_A",
        player2="preferredName_B",
        weight="score",
        n_layers=None,  # Number of concentric layers
        n_rank=[5, 10],  # Nodes in each rank for the concentric layout
        dist_node = 10,  # Distance between each rank of circles
        layout="degree",
        size=None,#700,
        sizes=(50,500),# min and max of size
        facecolor="skyblue",
        cmap='coolwarm',
        edgecolor="k",
        edgelinewidth=1.5,
        alpha=.5,
        alphas=(0.1, 1.0),# min and max of alpha
        marker="o",
        node_hideticks=True,
        linecolor="gray",
        line_cmap='coolwarm',
        linewidth=1.5,
        linewidths=(0.5,5),# min and max of linewidth
        linealpha=1.0,
        linealphas=(0.1,1.0),# min and max of linealpha
        linestyle="-",
        line_arrowstyle='-',
        fontsize=10,
        fontcolor="k",
        ha:str="center",
        va:str="center",
        figsize=(12, 10),
        k_value=0.3,
        bgcolor="w",
        dir_save="./ppi_network.html",
        physics=True,
        notebook=False,
        scale=1,
        ax=None,
        **kwargs
    ):
    """
        print(usage_str)

    # Check for required columns in the DataFrame
    for col in [player1, player2, weight]:
        if col not in interactions.columns:
            raise ValueError(f"Column '{col}' is missing from the interactions DataFrame.")
    if len(interactions) == 0:
        raise ValueError("`interactions` is empty; there is nothing to plot.")
    # sorted copy: the caller's DataFrame must not be reordered as a side effect
    interactions = interactions.sort_values(by=[weight])

    # Initialize Pyvis network
    net = Network(height="750px", width="100%", bgcolor=bgcolor, font_color=fontcolor)
    net.force_atlas_2based(
        gravity=-50, central_gravity=0.01, spring_length=100, spring_strength=0.1
    )
    net.toggle_physics(physics)

    # the first keyword containing "figset" carries the settings for figsets()
    kws_figsets = {}
    figset_key = next((key for key in kwargs if "figset" in key), None)
    if figset_key is not None:
        kws_figsets = kwargs.pop(figset_key)

    # Create a NetworkX graph from the interaction data
    G = nx.Graph()
    for _, row in interactions.iterrows():
        G.add_edge(row[player1], row[player2], weight=row[weight])

    num_nodes = G.number_of_nodes()
    num_edges = G.number_of_edges()

    # node degrees drive the default colours and sizes
    degrees = dict(G.degree())
    norm = Normalize(vmin=min(degrees.values()), vmax=max(degrees.values()))
    # plt.get_cmap: matplotlib.cm.get_cmap was removed in matplotlib 3.9
    colormap = plt.get_cmap(cmap)

    # * facecolor
    if not ips.isa(facecolor, 'color'):
        print("facecolor: based on degrees")
        facecolor = [colormap(norm(deg)) for deg in degrees.values()]
    if isinstance(facecolor, list):
        facecolor = _ppi_fit_length(facecolor, num_nodes)
    else:
        facecolor = [facecolor] * num_nodes

    # * size
    if not isinstance(size, (int, float, list)) or size == []:
        print("size: based on degrees")
        size = [deg * 50 for deg in degrees.values()]
    if isinstance(size, list):
        size = _ppi_fit_length(size, num_nodes)
    else:
        size = [size] * num_nodes
    # NOTE(review): this test only makes sense if ips.flatten collapses
    # identical values (so a uniform size is kept as given) - confirm.
    if len(ips.flatten(size, verbose=False)) != 1:
        min_size, max_size = sizes
        size = _ppi_rescale(size, min_size, max_size)

    # * facealpha: a list is rescaled into `alphas`, a scalar is used as is
    if isinstance(alpha, list):
        alpha = _ppi_fit_length(alpha, num_nodes)
        if alpha:
            min_alphas, max_alphas = alphas
            alpha = _ppi_rescale(alpha, min_alphas, max_alphas)
        else:
            # empty list: default to full opacity
            alpha = [1.0] * num_nodes
    else:
        alpha = [alpha] * num_nodes

    for i, node in enumerate(G.nodes()):
        net.add_node(
            node,
            label=node,
            size=size[i],
            color=facecolor[i],
            alpha=alpha[i],
            font={"size": fontsize, "color": fontcolor},
        )
    print(f'nodes number: {num_nodes}')

    for edge in G.edges(data=True):
        net.add_edge(
            edge[0],
            edge[1],
            weight=edge[2]["weight"],
            color=edgecolor,
            width=edgelinewidth * edge[2]["weight"],
        )

    layouts = [
        "spring",
        "circular",
        "kamada_kawai",
        "spectral",  # was handled below but could never be selected
        "random",
        "shell",
        "planar",
        "spiral",
        "degree"
    ]
    layout = ips.strcmp(layout, layouts)[0]
    print(f"layout:{layout}, or select one in {layouts}")

    # Choose layout
    if layout == "spring":
        pos = nx.spring_layout(G, k=k_value)
    elif layout == "circular":
        pos = nx.circular_layout(G)
    elif layout == "kamada_kawai":
        pos = nx.kamada_kawai_layout(G)
    elif layout == "spectral":
        pos = nx.spectral_layout(G)
    elif layout == "random":
        pos = nx.random_layout(G)
    elif layout == "shell":
        pos = nx.shell_layout(G)
    elif layout == "planar":
        if nx.check_planarity(G)[0]:
            pos = nx.planar_layout(G)
        else:
            print("Graph is not planar; switching to spring layout.")
            pos = nx.spring_layout(G, k=k_value)
    elif layout == "spiral":
        pos = nx.spiral_layout(G)
    elif layout == 'degree':
        # concentric circles, highest-degree nodes on the innermost circle
        sorted_nodes = sorted(degrees.items(), key=lambda item: item[1], reverse=True)
        pos = {}
        n_layers = len(n_rank) + 1 if n_layers is None else n_layers
        for rank_index in range(n_layers):
            start = sum(n_rank[:rank_index])
            if rank_index < len(n_rank):
                stop = start + n_rank[rank_index]
                if rank_index == n_layers - 1:
                    # last circle: take everything left so no node ends up
                    # without a position when n_layers is too small
                    stop = max(stop, len(sorted_nodes))
                rank_nodes = sorted_nodes[start:stop]
            else:
                # shuffle the order of the remaining nodes
                remaining_nodes = sorted_nodes[start:]
                random_indices = np.random.permutation(len(remaining_nodes))
                rank_nodes = [remaining_nodes[idx] for idx in random_indices]

            ring_radius = (rank_index + 1) * dist_node  # Radius for this rank

            # Arrange nodes in a circle for the current rank
            for i, (node, _degree) in enumerate(rank_nodes):
                angle = (i / len(rank_nodes)) * 2 * np.pi  # Distribute around circle
                pos[node] = (ring_radius * np.cos(angle), ring_radius * np.sin(angle))
    else:
        print(f"Unknown layout '{layout}', defaulting to 'spring',or可以用这些: {layouts}")
        pos = nx.spring_layout(G, k=k_value)

    # mirror the positions into the pyvis network
    for node, (x, y) in pos.items():
        net.get_node(node)["x"] = x * scale
        net.get_node(node)["y"] = y * scale

    # create a new figure when no axes was supplied
    if ax is None:
        _, ax = plt.subplots(1, 1, figsize=figsize)

    # Draw nodes, edges, and labels with customization options
    nx.draw_networkx_nodes(
        G,
        pos,
        ax=ax,
        node_size=size,
        node_color=facecolor,
        linewidths=edgelinewidth,
        edgecolors=edgecolor,
        alpha=alpha,
        hide_ticks=node_hideticks,
        node_shape=marker
    )

    # * linewidth: a list is rescaled into `linewidths`, a scalar is used as
    # is (previously a scalar was always replaced by the range midpoint)
    if isinstance(linewidth, list) and linewidth:
        min_width, max_width = linewidths
        linewidth = _ppi_rescale(
            _ppi_fit_length(linewidth, num_edges), min_width, max_width
        )
    elif isinstance(linewidth, list):
        # empty list: fall back to the midpoint of the allowed range
        linewidth = [(linewidths[0] + linewidths[1]) / 2] * num_edges
    else:
        linewidth = [linewidth] * num_edges

    # * linecolor: a string colours every edge, anything else -> by weight
    if not isinstance(linecolor, str):
        weights = [G[u][v]["weight"] for u, v in G.edges()]
        weight_norm = Normalize(vmin=min(weights), vmax=max(weights))
        line_colormap = plt.get_cmap(line_cmap)
        linecolor = [line_colormap(weight_norm(w)) for w in weights]
    else:
        linecolor = [linecolor] * num_edges

    # * linealpha
    if isinstance(linealpha, list):
        linealpha = _ppi_fit_length(linealpha, num_edges)
        if linealpha:
            min_alpha, max_alpha = linealphas
            linealpha = _ppi_rescale(linealpha, min_alpha, max_alpha)
        else:
            # invalid (empty) setting: fall back to 1.0
            linealpha = [1.0] * num_edges
    else:
        linealpha = [linealpha] * num_edges  # Convert to list if single value

    nx.draw_networkx_edges(
        G,
        pos,
        ax=ax,
        edge_color=linecolor,
        width=linewidth,
        style=linestyle,
        arrowstyle=line_arrowstyle,
        alpha=linealpha
    )

    nx.draw_networkx_labels(
        G, pos, ax=ax, font_size=fontsize, font_color=fontcolor,
        horizontalalignment=ha, verticalalignment=va
    )
    figsets(ax=ax, **kws_figsets)
    ax.axis("off")
    if dir_save:
        if not os.path.basename(dir_save):
            dir_save = "_.html"
        net.write_html(dir_save)
        # derive sibling paths from the stem so they can never collide with
        # the HTML file, even if dir_save does not end in ".html"
        stem = os.path.splitext(dir_save)[0]
        path_graphml = stem + ".graphml"
        nx.write_graphml(G, path_graphml)  # Export to GraphML
        # path computed beforehand: nesting double quotes inside a
        # double-quoted f-string is a SyntaxError before Python 3.12
        print(f"could be edited in Cytoscape \n{path_graphml}")
        ips.figsave(stem + ".pdf")
    return G, ax