py2ls 0.2.4.25__py3-none-any.whl → 0.2.4.27__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
py2ls/plot.py CHANGED
@@ -20,7 +20,9 @@ from .ips import (
20
20
  flatten,
21
21
  plt_font,
22
22
  run_once_within,
23
- df_format
23
+ df_format,
24
+ df_corr,
25
+ df_scaler
24
26
  )
25
27
  import scipy.stats as scipy_stats
26
28
  from .stats import *
@@ -142,20 +144,37 @@ def add_text(ax=None, height_offset=0.5, fmt=".1f", **kwargs):
142
144
  **kwargs,
143
145
  )
144
146
 
147
def pval2str(p):
    """Convert a p-value into a significance-star string.

    Mapping (boundaries match the original implementation):
        p < 0.001          -> "***"
        0.001 <= p < 0.01  -> "**"
        0.01 <= p <= 0.05  -> "*"
        p > 0.05           -> ""

    A NaN p-value yields "" (not significant) instead of raising
    ``UnboundLocalError`` as the previous if/elif chain did.

    Parameters:
        p (float): The p-value to classify.

    Returns:
        str: "", "*", "**" or "***".
    """
    # Thresholds are tested from most to least significant. NaN compares
    # False against every threshold, so it falls through to the final "".
    if p < 0.001:
        return "***"
    if p < 0.01:
        return "**"
    if p <= 0.05:
        return "*"
    return ""
145
157
 
146
158
  def heatmap(
147
159
  data,
148
160
  ax=None,
149
161
  kind="corr", #'corr','direct','pivot'
162
+ method="pearson",# for correlation: ‘pearson’(default), ‘kendall’, ‘spearman’
150
163
  columns="all", # pivot, default: coll numeric columns
164
+ style=0,# for correlation
151
165
  index=None, # pivot
152
166
  values=None, # pivot
167
+ fontsize=10,
153
168
  tri="u",
154
169
  mask=True,
155
170
  k=1,
171
+ vmin=None,
172
+ vmax=None,
173
+ size_scale=500,
156
174
  annot=True,
157
175
  cmap="coolwarm",
158
176
  fmt=".2f",
177
+ show_indicator = True,# only for style==1
159
178
  cluster=False,
160
179
  inplace=False,
161
180
  figsize=(10, 8),
@@ -201,7 +220,7 @@ def heatmap(
201
220
  ax = plt.gca()
202
221
  # Select numeric columns or specific subset of columns
203
222
  if columns == "all":
204
- df_numeric = data.select_dtypes(include=[float, int])
223
+ df_numeric = data.select_dtypes(include=[np.number])
205
224
  else:
206
225
  df_numeric = data[columns]
207
226
 
@@ -209,8 +228,10 @@ def heatmap(
209
228
  kind = strcmp(kind, kinds)[0]
210
229
  print(kind)
211
230
  if "corr" in kind: # correlation
231
+ methods = ["pearson", "spearman", "kendall"]
232
+ method = strcmp(method, methods)[0]
212
233
  # Compute the correlation matrix
213
- data4heatmap = df_numeric.corr()
234
+ data4heatmap = df_numeric.corr(method=method)
214
235
  # Generate mask for the upper triangle if mask is True
215
236
  if mask:
216
237
  if "u" in tri.lower(): # upper => np.tril
@@ -288,18 +309,76 @@ def heatmap(
288
309
  df_col_cluster,
289
310
  )
290
311
  else:
291
- # Create a standard heatmap
292
- ax = sns.heatmap(
293
- data4heatmap,
294
- ax=ax,
295
- mask=mask_array,
296
- annot=annot,
297
- cmap=cmap,
298
- fmt=fmt,
299
- **kwargs, # Pass any additional arguments to sns.heatmap
300
- )
301
- # Return the Axes object for further customization if needed
302
- return ax
312
+ if style==0:
313
+ # Create a standard heatmap
314
+ ax = sns.heatmap(
315
+ data4heatmap,
316
+ ax=ax,
317
+ mask=mask_array,
318
+ annot=annot,
319
+ cmap=cmap,
320
+ fmt=fmt,
321
+ **kwargs, # Pass any additional arguments to sns.heatmap
322
+ )
323
+ return ax
324
+ elif style==1:
325
+ if isinstance(cmap, str):
326
+ cmap = plt.get_cmap(cmap)
327
+ norm = plt.Normalize(vmin=-1, vmax=1)
328
+ r_, p_ = df_corr(data4heatmap, method=method)
329
+ # size_r_norm=df_scaler(data=r_, method="minmax", vmin=-1,vmax=1)
330
+ # 初始化一个空的可绘制对象用于颜色条
331
+ scatter_handles = []
332
+ # 循环绘制气泡图和数值
333
+ for i in range(len(r_.columns)):
334
+ for j in range(len(r_.columns)):
335
+ if (i < j) if "u" in tri.lower() else (j<i): # 对角线左上部只显示气泡
336
+ color = cmap(norm(r_.iloc[i, j])) # 根据相关系数获取颜色
337
+ scatter = ax.scatter(
338
+ i, j, s=np.abs(r_.iloc[i, j])*size_scale, color=color,
339
+ # alpha=1,edgecolor=edgecolor,linewidth=linewidth,
340
+ **kwargs
341
+ )
342
+ scatter_handles.append(scatter) # 保存scatter对象用于颜色条
343
+ # add *** indicators
344
+ if show_indicator:
345
+ ax.text(
346
+ i,
347
+ j,
348
+ pval2str(p_.iloc[i, j]),
349
+ ha="center",
350
+ va="center",
351
+ color="k",
352
+ fontsize=fontsize * 1.3,
353
+ )
354
+ elif (i > j) if "u" in tri.lower() else (j>i): # 对角只显示数值
355
+ color = cmap(norm(r_.iloc[i, j])) # 数值的颜色同样基于相关系数
356
+ ax.text(
357
+ i,
358
+ j,
359
+ f"{r_.iloc[i, j]:{fmt}}",
360
+ ha="center",
361
+ va="center",
362
+ color=color,
363
+ fontsize=fontsize,
364
+ )
365
+ else: # 对角线部分,显示空白
366
+ ax.scatter(i, j, s=1, color="white")
367
+ # 设置坐标轴标签
368
+ figsets(xticks=range(len(r_.columns)),
369
+ xticklabels=r_.columns,
370
+ xangle=90,
371
+ fontsize=fontsize,
372
+ yticks=range(len(r_.columns)),
373
+ yticklabels=r_.columns,
374
+ xlim=[-0.5,len(r_.columns)-0.5],
375
+ ylim=[-0.5,len(r_.columns)-0.5]
376
+ )
377
+ # 添加颜色条
378
+ sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
379
+ sm.set_array([]) # 仅用于显示颜色条
380
+ plt.colorbar(sm, ax=ax, label="Correlation Coefficient")
381
+ return ax
303
382
  elif "dir" in kind: # direct
304
383
  data4heatmap = df_numeric
305
384
  elif "pi" in kind: # pivot
@@ -389,16 +468,53 @@ def heatmap(
389
468
  )
390
469
  else:
391
470
  # Create a standard heatmap
392
- ax = sns.heatmap(
393
- data4heatmap,
394
- ax=ax,
395
- annot=annot,
396
- cmap=cmap,
397
- fmt=fmt,
398
- **kwargs, # Pass any additional arguments to sns.heatmap
399
- )
400
- # Return the Axes object for further customization if needed
401
- return ax
471
+ if style==0:
472
+ ax = sns.heatmap(
473
+ data4heatmap,
474
+ ax=ax,
475
+ annot=annot,
476
+ cmap=cmap,
477
+ fmt=fmt,
478
+ **kwargs, # Pass any additional arguments to sns.heatmap
479
+ )
480
+ # Return the Axes object for further customization if needed
481
+ return ax
482
+ elif style==1:
483
+ if isinstance(cmap, str):
484
+ cmap = plt.get_cmap(cmap)
485
+ if vmin is None:
486
+ vmin=np.min(data4heatmap)
487
+ if vmax is None:
488
+ vmax=np.max(data4heatmap)
489
+ norm = plt.Normalize(vmin=vmin, vmax=vmax)
490
+
491
+ # 初始化一个空的可绘制对象用于颜色条
492
+ scatter_handles = []
493
+ # 循环绘制气泡图和数值
494
+ print(len(data4heatmap.index),len(data4heatmap.columns))
495
+ for i in range(len(data4heatmap.index)):
496
+ for j in range(len(data4heatmap.columns)):
497
+ color = cmap(norm(data4heatmap.iloc[i, j])) # 根据相关系数获取颜色
498
+ scatter = ax.scatter(j,i, s=np.abs(data4heatmap.iloc[i, j])*size_scale, color=color, **kwargs)
499
+ scatter_handles.append(scatter) # 保存scatter对象用于颜色条
500
+
501
+ # 设置坐标轴标签
502
+ figsets(xticks=range(len(data4heatmap.columns)),
503
+ xticklabels=data4heatmap.columns,
504
+ xangle=90,
505
+ fontsize=fontsize,
506
+ yticks=range(len(data4heatmap.index)),
507
+ yticklabels=data4heatmap.index,
508
+ xlim=[-0.5,len(data4heatmap.columns)-0.5],
509
+ ylim=[-0.5,len(data4heatmap.index)-0.5]
510
+ )
511
+ # 添加颜色条
512
+ sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
513
+ sm.set_array([]) # 仅用于显示颜色条
514
+ plt.colorbar(sm, ax=ax,
515
+ # label="Correlation Coefficient"
516
+ )
517
+ return ax
402
518
 
403
519
 
404
520
  # !usage: py2ls.plot.heatmap()
@@ -3226,18 +3342,37 @@ def plotxy(
3226
3342
  # (1) return FcetGrid
3227
3343
  if k == "jointplot":
3228
3344
  kws_joint = kwargs.pop("kws_joint", kwargs)
3229
- stats=kwargs.pop("stats",True)
3345
+ kws_joint = {
3346
+ k: v for k, v in kws_joint.items() if not k.startswith("kws_")
3347
+ }
3348
+ hue = kwargs.get("hue", None)
3349
+ if isinstance(kws_joint, dict) or hue is None: # Check if kws_ellipse is a dictionary
3350
+ kws_joint.pop("hue", None) # Safely remove 'hue' if it exists
3351
+
3352
+ palette = kwargs.get("palette", None)
3353
+ if palette is None:
3354
+ palette = kws_joint.pop(
3355
+ "palette", get_color(data[hue].nunique()) if hue is not None else None
3356
+ )
3357
+ else:
3358
+ kws_joint.pop("palette", palette)
3359
+ stats=kwargs.pop("stats",None)
3360
+ if stats:
3361
+ stats=kws_joint.pop("stats",True)
3230
3362
  if stats:
3231
3363
  r, p_value = scipy_stats.pearsonr(data[x], data[y])
3232
- g = sns.jointplot(data=data, x=x, y=y, **kws_joint)
3233
- g.ax_joint.annotate(
3234
- f"pearsonr = {r:.2f} p = {p_value:.3f}",
3235
- xy=(0.6, 0.98),
3236
- xycoords="axes fraction",
3237
- fontsize=12,
3238
- color="black",
3239
- ha="center",
3240
- )
3364
+ for key in ["palette", "alpha", "hue","stats"]:
3365
+ kws_joint.pop(key, None)
3366
+ g = sns.jointplot(data=data, x=x, y=y,hue=hue,palette=palette, **kws_joint)
3367
+ if stats:
3368
+ g.ax_joint.annotate(
3369
+ f"pearsonr = {r:.2f} p = {p_value:.3f}",
3370
+ xy=(0.6, 0.98),
3371
+ xycoords="axes fraction",
3372
+ fontsize=12,
3373
+ color="black",
3374
+ ha="center",
3375
+ )
3241
3376
  elif k == "lmplot":
3242
3377
  kws_lm = kwargs.pop("kws_lm", kwargs)
3243
3378
  stats = kwargs.pop("stats", True) # Flag to calculate stats
@@ -3383,8 +3518,7 @@ def plotxy(
3383
3518
  xycoords="axes fraction",
3384
3519
  fontsize=12,
3385
3520
  color="black",
3386
- ha="center",
3387
- )
3521
+ ha="center")
3388
3522
 
3389
3523
  elif k == "catplot_sns":
3390
3524
  kws_cat = kwargs.pop("kws_cat", kwargs)
@@ -3392,7 +3526,7 @@ def plotxy(
3392
3526
  elif k == "displot":
3393
3527
  kws_dis = kwargs.pop("kws_dis", kwargs)
3394
3528
  # displot creates a new figure and returns a FacetGrid
3395
- g = sns.displot(data=data, x=x, **kws_dis)
3529
+ g = sns.displot(data=data, x=x,y=y, **kws_dis)
3396
3530
 
3397
3531
  # (2) return axis
3398
3532
  if ax is None:
@@ -3403,19 +3537,58 @@ def plotxy(
3403
3537
  elif k == "stdshade":
3404
3538
  kws_stdshade = kwargs.pop("kws_stdshade", kwargs)
3405
3539
  ax = stdshade(ax=ax, **kwargs)
3540
+ elif k=="ellipse":
3541
+ kws_ellipse = kwargs.pop("kws_ellipse", kwargs)
3542
+ kws_ellipse = {
3543
+ k: v for k, v in kws_ellipse.items() if not k.startswith("kws_")
3544
+ }
3545
+ hue = kwargs.get("hue", None)
3546
+ if isinstance(kws_ellipse, dict) or hue is None: # Check if kws_ellipse is a dictionary
3547
+ kws_ellipse.pop("hue", None) # Safely remove 'hue' if it exists
3548
+
3549
+ palette = kwargs.get("palette", None)
3550
+ if palette is None:
3551
+ palette = kws_ellipse.pop(
3552
+ "palette", get_color(data[hue].nunique()) if hue is not None else None
3553
+ )
3554
+ alpha = kws_ellipse.pop("alpha", 0.1)
3555
+ hue_order = kwargs.get("hue_order",None)
3556
+ if hue_order is None:
3557
+ hue_order = kws_ellipse.get("hue_order",None)
3558
+ if hue_order:
3559
+ data["hue"] = pd.Categorical(data[hue], categories=hue_order, ordered=True)
3560
+ data = data.sort_values(by="hue")
3561
+ for key in ["palette", "alpha", "hue","hue_order"]:
3562
+ kws_ellipse.pop(key, None)
3563
+ ax=ellipse(
3564
+ ax=ax,
3565
+ data=data,
3566
+ x=x,
3567
+ y=y,
3568
+ hue=hue,
3569
+ palette=palette,
3570
+ alpha=alpha,
3571
+ zorder=zorder,
3572
+ **kws_ellipse,
3573
+ )
3406
3574
  elif k == "scatterplot":
3407
3575
  kws_scatter = kwargs.pop("kws_scatter", kwargs)
3408
3576
  kws_scatter = {
3409
3577
  k: v for k, v in kws_scatter.items() if not k.startswith("kws_")
3410
3578
  }
3411
- hue = kwargs.pop("hue", None)
3579
+ hue = kwargs.get("hue", None)
3412
3580
  if isinstance(kws_scatter, dict): # Check if kws_scatter is a dictionary
3413
3581
  kws_scatter.pop("hue", None) # Safely remove 'hue' if it exists
3414
- palette = kws_scatter.pop(
3415
- "palette", get_color(data[hue].nunique()) if hue is not None else None
3416
- )
3582
+ palette=kws_scatter.get("palette",None)
3583
+ if palette is None:
3584
+ palette = kws_scatter.pop(
3585
+ "palette", get_color(data[hue].nunique()) if hue is not None else None
3586
+ )
3417
3587
  s = kws_scatter.pop("s", 10)
3418
- alpha = kws_scatter.pop("alpha", 0.7)
3588
+ alpha = kws_scatter.pop("alpha", 0.7)
3589
+ for key in ["s", "palette", "alpha", "hue"]:
3590
+ kws_scatter.pop(key, None)
3591
+
3419
3592
  ax = sns.scatterplot(
3420
3593
  ax=ax,
3421
3594
  data=data,
@@ -3431,19 +3604,33 @@ def plotxy(
3431
3604
  elif k == "histplot":
3432
3605
  kws_hist = kwargs.pop("kws_hist", kwargs)
3433
3606
  kws_hist = {k: v for k, v in kws_hist.items() if not k.startswith("kws_")}
3434
- ax = sns.histplot(data=data, x=x, ax=ax, zorder=zorder, **kws_hist)
3435
- elif k == "kdeplot":
3607
+ ax = sns.histplot(data=data, x=x, y=y, ax=ax, zorder=zorder, **kws_hist)
3608
+ elif k == "kdeplot":
3436
3609
  kws_kde = kwargs.pop("kws_kde", kwargs)
3437
- kws_kde = {k: v for k, v in kws_kde.items() if not k.startswith("kws_")}
3438
- ax = sns.kdeplot(data=data, x=x, ax=ax, zorder=zorder, **kws_kde)
3610
+ kws_kde = {
3611
+ k: v for k, v in kws_kde.items() if not k.startswith("kws_")
3612
+ }
3613
+ hue = kwargs.get("hue", None)
3614
+ if isinstance(kws_kde, dict) or hue is None: # Check if kws_kde is a dictionary
3615
+ kws_kde.pop("hue", None) # Safely remove 'hue' if it exists
3616
+
3617
+ palette = kwargs.get("palette", None)
3618
+ if palette is None:
3619
+ palette = kws_kde.pop(
3620
+ "palette", get_color(data[hue].nunique()) if hue is not None else None
3621
+ )
3622
+ alpha = kws_kde.pop("alpha", 0.05)
3623
+ for key in ["palette", "alpha", "hue"]:
3624
+ kws_kde.pop(key, None)
3625
+ ax = sns.kdeplot(data=data, x=x, y=y, palette=palette,hue=hue, ax=ax,alpha=alpha, zorder=zorder, **kws_kde)
3439
3626
  elif k == "ecdfplot":
3440
3627
  kws_ecdf = kwargs.pop("kws_ecdf", kwargs)
3441
3628
  kws_ecdf = {k: v for k, v in kws_ecdf.items() if not k.startswith("kws_")}
3442
- ax = sns.ecdfplot(data=data, x=x, ax=ax, zorder=zorder, **kws_ecdf)
3629
+ ax = sns.ecdfplot(data=data, x=x,y=y, ax=ax, zorder=zorder, **kws_ecdf)
3443
3630
  elif k == "rugplot":
3444
3631
  kws_rug = kwargs.pop("kws_rug", kwargs)
3445
3632
  kws_rug = {k: v for k, v in kws_rug.items() if not k.startswith("kws_")}
3446
- ax = sns.rugplot(data=data, x=x, ax=ax, zorder=zorder, **kws_rug)
3633
+ ax = sns.rugplot(data=data, x=x, y=y, ax=ax, zorder=zorder, **kws_rug)
3447
3634
  elif k == "stripplot":
3448
3635
  kws_strip = kwargs.pop("kws_strip", kwargs)
3449
3636
  kws_strip = {k: v for k, v in kws_strip.items() if not k.startswith("kws_")}
@@ -3516,7 +3703,8 @@ def plotxy(
3516
3703
  figsets(ax=ax, **kws_figsets)
3517
3704
  if kws_add_text:
3518
3705
  add_text(ax=ax, **kws_add_text) if kws_add_text else None
3519
- if run_once_within(10):
3706
+ if run_once_within(10):
3707
+ for k in kind_:
3520
3708
  print(f"\n{k}⤵ ")
3521
3709
  print(default_settings[k])
3522
3710
  # print("=>\t",sns_info[sns_info["Functions"].str.contains(k)].iloc[:, -1].tolist()[0],"\n")
@@ -3574,6 +3762,7 @@ def df_preprocessing_(data, kind, verbose=False):
3574
3762
  "lineplot", # Can work with both wide and long formats
3575
3763
  "area plot", # Can work with both formats, useful for stacked areas
3576
3764
  "violinplot", # Can work with both formats depending on categorical vs continuous data
3765
+ "ellipse",# ellipse plot, default confidence=0.95
3577
3766
  ],
3578
3767
  )[0]
3579
3768
 
@@ -3609,6 +3798,7 @@ def df_preprocessing_(data, kind, verbose=False):
3609
3798
  "violinplot", # Can work with both formats depending on categorical vs continuous data
3610
3799
  "relplot",
3611
3800
  "pointplot", # Works well with wide format
3801
+ "ellipse",
3612
3802
  ]
3613
3803
 
3614
3804
  # Wide format (e.g., for heatmap and pairplot)
@@ -4100,25 +4290,45 @@ def venn(
4100
4290
  """
4101
4291
  if ax is None:
4102
4292
  ax = plt.gca()
4293
+ if isinstance(lists, dict):
4294
+ labels = list(lists.keys())
4295
+ lists = list(lists.values())
4296
+ if isinstance(lists[0], set):
4297
+ lists = [list(i) for i in lists]
4103
4298
  lists = [set(flatten(i, verbose=False)) for i in lists]
4104
4299
  # Function to apply text styles to labels
4105
4300
  if colors is None:
4106
4301
  colors = ["r", "b"] if len(lists) == 2 else ["r", "g", "b"]
4107
4302
  if labels is None:
4108
- labels = ["set1", "set2"] if len(lists) == 2 else ["set1", "set2", "set3"]
4303
+ if len(lists) == 2:
4304
+ labels = ["set1", "set2"]
4305
+ elif len(lists) == 3:
4306
+ labels = ["set1", "set2", "set3"]
4307
+ elif len(lists) == 4:
4308
+ labels = ["set1", "set2", "set3","set4"]
4309
+ elif len(lists) == 5:
4310
+ labels = ["set1", "set2", "set3","set4","set55"]
4311
+ elif len(lists) == 6:
4312
+ labels = ["set1", "set2", "set3","set4","set5","set6"]
4313
+ elif len(lists) == 7:
4314
+ labels = ["set1", "set2", "set3","set4","set5","set6","set7"]
4109
4315
  if edgecolor is None:
4110
4316
  edgecolor = colors
4111
4317
  colors = [desaturate_color(color, saturation) for color in colors]
4112
- # Check colors and auto-calculate overlaps
4113
- if len(lists) == 2:
4318
+ universe = len(set.union(*lists))
4114
4319
 
4115
- def get_count_and_percentage(set_count, subset_count):
4116
- percent = subset_count / set_count if set_count > 0 else 0
4117
- return (
4118
- f"{subset_count}\n({fmt.format(percent)})"
4119
- if show_percentages
4120
- else f"{subset_count}"
4121
- )
4320
+ # Check colors and auto-calculate overlaps
4321
+ def get_count_and_percentage(set_count, subset_count):
4322
+ percent = subset_count / set_count if set_count > 0 else 0
4323
+ return (
4324
+ f"{subset_count}\n({fmt.format(percent)})"
4325
+ if show_percentages
4326
+ else f"{subset_count}"
4327
+ )
4328
+ if fmt is not None:
4329
+ if not fmt.startswith("{"):
4330
+ fmt="{:" + fmt + "}"
4331
+ if len(lists) == 2:
4122
4332
 
4123
4333
  from matplotlib_venn import venn2, venn2_circles
4124
4334
 
@@ -4131,21 +4341,28 @@ def venn(
4131
4341
  set1, set2 = lists[0], lists[1]
4132
4342
  v.get_patch_by_id("10").set_color(colors[0])
4133
4343
  v.get_patch_by_id("01").set_color(colors[1])
4134
- v.get_patch_by_id("11").set_color(
4135
- get_color_overlap(colors[0], colors[1]) if colors else None
4136
- )
4344
+ try:
4345
+ v.get_patch_by_id("11").set_color(
4346
+ get_color_overlap(colors[0], colors[1]) if colors else None
4347
+ )
4348
+ except Exception as e:
4349
+ print(e)
4137
4350
  # v.get_label_by_id('10').set_text(len(set1 - set2))
4138
4351
  # v.get_label_by_id('01').set_text(len(set2 - set1))
4139
4352
  # v.get_label_by_id('11').set_text(len(set1 & set2))
4353
+
4140
4354
  v.get_label_by_id("10").set_text(
4141
- get_count_and_percentage(len(set1), len(set1 - set2))
4355
+ get_count_and_percentage(universe, len(set1 - set2))
4142
4356
  )
4143
4357
  v.get_label_by_id("01").set_text(
4144
- get_count_and_percentage(len(set2), len(set2 - set1))
4145
- )
4146
- v.get_label_by_id("11").set_text(
4147
- get_count_and_percentage(len(set1 | set2), len(set1 & set2))
4358
+ get_count_and_percentage(universe, len(set2 - set1))
4148
4359
  )
4360
+ try:
4361
+ v.get_label_by_id("11").set_text(
4362
+ get_count_and_percentage(universe, len(set1 & set2))
4363
+ )
4364
+ except Exception as e:
4365
+ print(e)
4149
4366
 
4150
4367
  if not isinstance(linewidth, list):
4151
4368
  linewidth = [linewidth]
@@ -4228,16 +4445,14 @@ def venn(
4228
4445
  va=va,
4229
4446
  shadow=shadow,
4230
4447
  )
4231
-
4232
- elif len(lists) == 3:
4233
-
4234
- def get_label(set_count, subset_count):
4235
- percent = subset_count / set_count if set_count > 0 else 0
4236
- return (
4237
- f"{subset_count}\n({fmt.format(percent)})"
4238
- if show_percentages
4239
- else f"{subset_count}"
4240
- )
4448
+ # Set transparency level
4449
+ for patch in v.patches:
4450
+ if patch:
4451
+ patch.set_alpha(alpha)
4452
+ if "none" in edgecolor or 0 in linewidth:
4453
+ patch.set_edgecolor("none")
4454
+ return ax
4455
+ elif len(lists) == 3:
4241
4456
 
4242
4457
  from matplotlib_venn import venn3, venn3_circles
4243
4458
 
@@ -4253,36 +4468,34 @@ def venn(
4253
4468
  # Draw the venn diagram
4254
4469
  v = venn3(subsets=lists, set_labels=labels, ax=ax, **kwargs)
4255
4470
  v.get_patch_by_id("100").set_color(colors[0])
4471
+ v.get_label_by_id("100").set_text(get_count_and_percentage(universe, len(set1 - set2 - set3)))
4256
4472
  v.get_patch_by_id("010").set_color(colors[1])
4257
- v.get_patch_by_id("001").set_color(colors[2])
4258
- v.get_patch_by_id("110").set_color(colorAB)
4259
- v.get_patch_by_id("101").set_color(colorAC)
4260
- v.get_patch_by_id("011").set_color(colorBC)
4261
- v.get_patch_by_id("111").set_color(colorABC)
4262
-
4263
- # Correctly labeling subset sizes
4264
- # v.get_label_by_id('100').set_text(len(set1 - set2 - set3))
4265
- # v.get_label_by_id('010').set_text(len(set2 - set1 - set3))
4266
- # v.get_label_by_id('001').set_text(len(set3 - set1 - set2))
4267
- # v.get_label_by_id('110').set_text(len(set1 & set2 - set3))
4268
- # v.get_label_by_id('101').set_text(len(set1 & set3 - set2))
4269
- # v.get_label_by_id('011').set_text(len(set2 & set3 - set1))
4270
- # v.get_label_by_id('111').set_text(len(set1 & set2 & set3))
4271
- v.get_label_by_id("100").set_text(get_label(len(set1), len(set1 - set2 - set3)))
4272
- v.get_label_by_id("010").set_text(get_label(len(set2), len(set2 - set1 - set3)))
4273
- v.get_label_by_id("001").set_text(get_label(len(set3), len(set3 - set1 - set2)))
4274
- v.get_label_by_id("110").set_text(
4275
- get_label(len(set1 | set2), len(set1 & set2 - set3))
4276
- )
4277
- v.get_label_by_id("101").set_text(
4278
- get_label(len(set1 | set3), len(set1 & set3 - set2))
4279
- )
4280
- v.get_label_by_id("011").set_text(
4281
- get_label(len(set2 | set3), len(set2 & set3 - set1))
4282
- )
4283
- v.get_label_by_id("111").set_text(
4284
- get_label(len(set1 | set2 | set3), len(set1 & set2 & set3))
4285
- )
4473
+ v.get_label_by_id("010").set_text(get_count_and_percentage(universe, len(set2 - set1 - set3)))
4474
+ try:
4475
+ v.get_patch_by_id("001").set_color(colors[2])
4476
+ v.get_label_by_id("001").set_text(get_count_and_percentage(universe, len(set3 - set1 - set2)))
4477
+ except Exception as e:
4478
+ print(e)
4479
+ try:
4480
+ v.get_patch_by_id("110").set_color(colorAB)
4481
+ v.get_label_by_id("110").set_text(get_count_and_percentage(universe, len(set1 & set2 - set3)))
4482
+ except Exception as e:
4483
+ print(e)
4484
+ try:
4485
+ v.get_patch_by_id("101").set_color(colorAC)
4486
+ v.get_label_by_id("101").set_text(get_count_and_percentage(universe, len(set1 & set3 - set2)))
4487
+ except Exception as e:
4488
+ print(e)
4489
+ try:
4490
+ v.get_patch_by_id("011").set_color(colorBC)
4491
+ v.get_label_by_id("011").set_text(get_count_and_percentage(universe, len(set2 & set3 - set1)))
4492
+ except Exception as e:
4493
+ print(e)
4494
+ try:
4495
+ v.get_patch_by_id("111").set_color(colorABC)
4496
+ v.get_label_by_id("111").set_text(get_count_and_percentage(universe, len(set1 & set2 & set3)))
4497
+ except Exception as e:
4498
+ print(e)
4286
4499
 
4287
4500
  # Apply styles to set labels
4288
4501
  for i, text in enumerate(v.set_labels):
@@ -4387,16 +4600,34 @@ def venn(
4387
4600
  ax.add_patch(ellipse1)
4388
4601
  ax.add_patch(ellipse2)
4389
4602
  ax.add_patch(ellipse3)
4603
+ # Set transparency level
4604
+ for patch in v.patches:
4605
+ if patch:
4606
+ patch.set_alpha(alpha)
4607
+ if "none" in edgecolor or 0 in linewidth:
4608
+ patch.set_edgecolor("none")
4609
+ return ax
4610
+
4611
+
4612
+ dict_data = {}
4613
+ for i_list, list_ in enumerate(lists):
4614
+ dict_data[labels[i_list]]={*list_}
4615
+
4616
+ if 3<len(lists)<6:
4617
+ from venn import venn as vn
4618
+
4619
+ legend_loc=kwargs.pop("legend_loc", "upper right")
4620
+ ax=vn(dict_data,ax=ax,legend_loc=legend_loc,**kwargs)
4621
+
4622
+ return ax
4390
4623
  else:
4391
- raise ValueError("只支持2或者3个list")
4392
-
4393
- # Set transparency level
4394
- for patch in v.patches:
4395
- if patch:
4396
- patch.set_alpha(alpha)
4397
- if "none" in edgecolor or 0 in linewidth:
4398
- patch.set_edgecolor("none")
4399
- return ax
4624
+ from venn import pseudovenn
4625
+ cmap=kwargs.pop("cmap","plasma")
4626
+ ax=pseudovenn(dict_data, cmap=cmap,ax=ax,**kwargs)
4627
+
4628
+ return ax
4629
+
4630
+
4400
4631
 
4401
4632
 
4402
4633
  #! subplots, support automatic extend new axis
@@ -4407,6 +4638,7 @@ def subplot(
4407
4638
  sharex=False,
4408
4639
  sharey=False,
4409
4640
  verbose=False,
4641
+ fig=None,
4410
4642
  **kwargs,
4411
4643
  ):
4412
4644
  """
@@ -4435,8 +4667,8 @@ def subplot(
4435
4667
  )
4436
4668
 
4437
4669
  figsize_recommend = f"subplot({rows}, {cols}, figsize={figsize})"
4438
-
4439
- fig = plt.figure(figsize=figsize, constrained_layout=True)
4670
+ if fig is None:
4671
+ fig = plt.figure(figsize=figsize, constrained_layout=True)
4440
4672
  grid_spec = GridSpec(rows, cols, figure=fig)
4441
4673
  occupied = set()
4442
4674
  row_first_axes = [None] * rows # Track the first axis in each row (for sharey)
@@ -4502,7 +4734,7 @@ def radar(
4502
4734
  data: pd.DataFrame,
4503
4735
  columns=None,
4504
4736
  ylim=(0, 100),
4505
- facecolor=get_color(5),
4737
+ facecolor=None,
4506
4738
  edgecolor="none",
4507
4739
  edge_linewidth=0.5,
4508
4740
  fontsize=10,
@@ -4701,7 +4933,7 @@ def radar(
4701
4933
  else:
4702
4934
  # * spider style: spider-style grid (straight lines, not circles)
4703
4935
  # Create the spider-style grid (straight lines, not circles)
4704
- for i in range(1, int((vmax-vmin)/ ((vmax-vmin)*grid_interval_ratio))):#int(vmax * grid_interval_ratio) + 1):
4936
+ for i in range(1, int((vmax-vmin)/ ((vmax-vmin)*grid_interval_ratio))+1):#int(vmax * grid_interval_ratio) + 1):
4705
4937
  ax.plot(
4706
4938
  angles + [angles[0]], # Closing the loop
4707
4939
  [i * vmax * grid_interval_ratio] * (num_vars + 1)
@@ -4740,9 +4972,9 @@ def radar(
4740
4972
  colors = facecolor
4741
4973
  else:
4742
4974
  colors = (
4743
- get_color(data.shape[0])
4975
+ get_color(data.shape[1])
4744
4976
  if cmap is None
4745
- else plt.get_cmap(cmap)(np.linspace(0, 1, data.shape[0]))
4977
+ else plt.get_cmap(cmap)(np.linspace(0, 1, data.shape[1]))
4746
4978
  )
4747
4979
 
4748
4980
  # Plot each row with straight lines
@@ -4840,7 +5072,8 @@ def pie(
4840
5072
  pctdistance=0.85,
4841
5073
  labeldistance=1.1,
4842
5074
  kws_wedge={},
4843
- kws_text={},
5075
+ kws_text={},
5076
+ kws_arrow={},
4844
5077
  center=(0, 0),
4845
5078
  radius=1,
4846
5079
  frame=False,
@@ -4850,6 +5083,8 @@ def pie(
4850
5083
  cmap=None,
4851
5084
  show_value=False,
4852
5085
  show_label=True,# False: only show the outer layer, if it is None, not show
5086
+ expand_label=(1.2,1.2),
5087
+ kws_bbox={},#dict(facecolor="none", alpha=0.5, edgecolor="black", boxstyle="round,pad=0.3"), # '{}' to hide
4853
5088
  show_legend=True,
4854
5089
  legend_loc="upper right",
4855
5090
  bbox_to_anchor=[1.4, 1.1],
@@ -4859,7 +5094,7 @@ def pie(
4859
5094
  ax=None,
4860
5095
  **kwargs
4861
5096
  ):
4862
-
5097
+ from adjustText import adjust_text
4863
5098
  if run_once_within(20,reverse=True) and verbose:
4864
5099
  usage_="""usage:
4865
5100
  pie(
@@ -4974,6 +5209,11 @@ def pie(
4974
5209
  # 选择部分数据
4975
5210
  df=data[columns]
4976
5211
 
5212
+ if not isinstance(explode, list):
5213
+ explode=[explode]
5214
+ if explode == [None]:
5215
+ explode=[0]
5216
+
4977
5217
  if width is None:
4978
5218
  if df.shape[1]>1:
4979
5219
  width=1/(df.shape[1]+2)
@@ -5022,10 +5262,9 @@ def pie(
5022
5262
 
5023
5263
  if ax is None:
5024
5264
  ax=plt.gca()
5025
- if explode is not None:
5026
- if len(explode)<len(labels_legend):
5027
- explode.extend([0]*(len(labels_legend)-len(explode)))
5028
-
5265
+ if len(explode)<len(labels_legend):
5266
+ explode.extend([0]*(len(labels_legend)-len(explode)))
5267
+ print(explode)
5029
5268
  if fmt:
5030
5269
  if not fmt.startswith("%"):
5031
5270
  autopct =f"%{fmt}%%"
@@ -5073,19 +5312,37 @@ def pie(
5073
5312
  elif len(result) == 2:
5074
5313
  wedges, texts = result
5075
5314
  autotexts = None
5076
-
5077
- # Show exact values on wedges if show_value is True
5078
- if show_value:
5079
- for i, (wedge, txt) in enumerate(zip(wedges, texts)):
5080
- angle = (wedge.theta2 - wedge.theta1) / 2 + wedge.theta1
5081
- x = np.cos(np.radians(angle)) * (pctdistance ) * radius_
5082
- y = np.sin(np.radians(angle)) * (pctdistance ) * radius_
5083
- if not fmt.startswith("{"):
5084
- value_text = f"{sizes[i]:{fmt}}"
5085
- else:
5086
- value_text = fmt.format(sizes[i])
5087
- ax.text(x, y, value_text, ha="center", va="center", fontsize=fontsize,color=fontcolor)
5088
- inested+=1
5315
+ #! adjust_text
5316
+ if autotexts or texts:
5317
+ all_texts = []
5318
+ if autotexts and show_value:
5319
+ all_texts.extend(autotexts)
5320
+ if texts and show_label:
5321
+ all_texts.extend(texts)
5322
+
5323
+ adjust_text(
5324
+ all_texts,
5325
+ ax=ax,
5326
+ arrowprops=kws_arrow,#dict(arrowstyle="-", color="gray", lw=0.5),
5327
+ bbox=kws_bbox if kws_bbox else None,
5328
+ expand=expand_label,
5329
+ fontdict={
5330
+ "fontsize": fontsize,
5331
+ "color": fontcolor,
5332
+ },
5333
+ )
5334
+ # Show exact values on wedges if show_value is True
5335
+ if show_value:
5336
+ for i, (wedge, txt) in enumerate(zip(wedges, texts)):
5337
+ angle = (wedge.theta2 - wedge.theta1) / 2 + wedge.theta1
5338
+ x = np.cos(np.radians(angle)) * (pctdistance ) * radius_
5339
+ y = np.sin(np.radians(angle)) * (pctdistance ) * radius_
5340
+ if not fmt.startswith("{"):
5341
+ value_text = f"{sizes[i]:{fmt}}"
5342
+ else:
5343
+ value_text = fmt.format(sizes[i])
5344
+ ax.text(x, y, value_text, ha="center", va="center", fontsize=fontsize,color=fontcolor)
5345
+ inested+=1
5089
5346
  # Customize the legend
5090
5347
  if show_legend:
5091
5348
  ax.legend(
@@ -5098,3 +5355,521 @@ def pie(
5098
5355
  )
5099
5356
  ax.set(aspect="equal")
5100
5357
  return ax
5358
+
5359
def ellipse(
    data,
    x=None,
    y=None,
    hue=None,
    n_std=1.5,
    ax=None,
    confidence=0.95,
    annotate_center=False,
    palette=None,
    facecolor=None,
    edgecolor=None,
    label: bool = True,
    **kwargs,
):
    """
    Draw one covariance ellipse per group on a Matplotlib Axes.

    For every group the sample mean and covariance of the (x, y) points are
    computed; the ellipse is centred on the mean, aligned with the principal
    axes of the covariance matrix and scaled to ``n_std`` standard deviations
    (or to the chi-square quantile matching ``confidence``).

    Example:
        control = np.random.multivariate_normal([0, 0], [[1, 0.5], [0.5, 1]], size=50)
        patient = np.random.multivariate_normal([2, 1], [[1, -0.3], [-0.3, 1]], size=50)
        df = pd.DataFrame(
            {
                "Dim1": np.concatenate([control[:, 0], patient[:, 0]]),
                "Dim2": np.concatenate([control[:, 1], patient[:, 1]]),
                "Group": ["Control"] * 50 + ["Patient"] * 50,
            }
        )
        plotxy(data=df, x="Dim1", y="Dim2", hue="Group", kind_="scatter", palette=get_color(8))
        ellipse(data=df, x="Dim1", y="Dim2", hue="Group", palette=get_color(8), alpha=0.1, lw=2)

    Parameters:
        data (DataFrame): Input DataFrame with columns for x, y, and hue.
        x (str): Column name for x-axis values.
        y (str): Column name for y-axis values.
        hue (str, optional): Column name for group labels; without it a single
            ellipse is drawn for the whole DataFrame.
        n_std (float): Number of standard deviations for the ellipse; only used
            when ``confidence`` is falsy (None or 0).
        ax (matplotlib.axes.Axes, optional): Axes to draw on. Defaults to the
            current Axes.
        confidence (float, optional): Confidence level (e.g. 0.95); when set it
            overrides ``n_std``.
        annotate_center (bool): Whether to annotate the ellipse centre (mean).
        palette (dict, list or str, optional): Mapping of hue value -> color,
            or anything accepted by ``seaborn.color_palette``.
        facecolor (color, optional): Fill color for every ellipse; defaults to
            the per-group palette color.
        edgecolor (color, optional): Outline color for every ellipse; defaults
            to the per-group palette color ("blue" when ``hue`` is None).
        label (bool): Whether to label each ellipse with its group value (only
            when ``hue`` is given).
        **kwargs: ``alpha`` (default 0.2, applied to the face color only) plus
            any additional keyword arguments for the Ellipse patch.

    Returns:
        matplotlib.axes.Axes: The Axes the ellipses were added to.

    Raises:
        ValueError: If ``x``/``y`` is missing or ``data`` is not a DataFrame.
    """
    import pandas as pd

    # Validate first so that bad input neither needs the plotting stack nor
    # creates a figure as a side effect of plt.gca().
    if x is None or y is None:
        raise ValueError(
            "Both `x` and `y` must be specified as column names in the DataFrame."
        )
    if not isinstance(data, pd.DataFrame):
        raise ValueError("`data` must be a pandas DataFrame.")

    import numpy as np
    import matplotlib.pyplot as plt
    import seaborn as sns
    from matplotlib.colors import to_rgba
    from matplotlib.patches import Ellipse
    from scipy.stats import chi2

    if ax is None:
        ax = plt.gca()

    has_hue = hue is not None

    # Resolve one default color per group
    if has_hue:
        groups = data[hue].unique()
        if isinstance(palette, dict):
            # Explicit mapping; groups missing from it fall back to "husl"
            fallback = sns.color_palette("husl", len(groups))
            color_map = {
                grp: palette.get(grp, fallback[i]) for i, grp in enumerate(groups)
            }
        else:
            colors = sns.color_palette(palette or "husl", len(groups))
            color_map = dict(zip(groups, colors))
    else:
        groups = [None]
        color_map = {None: "blue" if edgecolor is None else edgecolor}

    alpha = kwargs.pop("alpha", 0.2)

    # The scale factor is the same for every group, so compute it once.
    # With a confidence level, use the chi-square quantile (2 degrees of freedom).
    scale = np.sqrt(chi2.ppf(confidence, df=2)) if confidence else n_std

    for group in groups:
        group_data = data[data[hue] == group] if has_hue else data
        points = group_data[[x, y]].values

        # A covariance matrix needs at least two observations; this also skips
        # NaN group keys, which never match any row.
        if len(points) < 2:
            continue

        center = np.mean(points, axis=0)
        cov = np.cov(points.T)

        # Principal axes, largest variance first
        eigvals, eigvecs = np.linalg.eigh(cov)
        order = eigvals.argsort()[::-1]
        eigvals, eigvecs = eigvals[order], eigvecs[:, order]

        # Orientation of the major axis and full axis lengths
        angle = np.degrees(np.arctan2(*eigvecs[:, 0][::-1]))
        width, height = 2 * scale * np.sqrt(eigvals)

        face = color_map[group] if facecolor is None else facecolor
        edge = color_map[group] if edgecolor is None else edgecolor

        patch = Ellipse(
            xy=center,
            width=width,
            height=height,
            angle=angle,
            edgecolor=edge,
            # alpha is baked into the face color so the outline stays opaque
            facecolor=to_rgba(face, alpha),
            label=group if (has_hue and label) else None,
            **kwargs,
        )
        ax.add_patch(patch)

        if annotate_center:
            ax.annotate(
                f"Mean\n({center[0]:.2f}, {center[1]:.2f})",
                xy=center,
                xycoords="data",
                fontsize=10,
                ha="center",
                color=edge,
                bbox=dict(
                    boxstyle="round,pad=0.3",
                    edgecolor="gray",
                    facecolor="white",
                    alpha=0.8,
                ),
            )

    return ax
5511
+
5512
def ppi(
    interactions,
    player1="preferredName_A",
    player2="preferredName_B",
    weight="score",
    n_layers=None,  # Number of concentric layers
    n_rank=[5, 10],  # Nodes in each rank for the concentric layout
    dist_node=10,  # Distance between each rank of circles
    layout="degree",
    size=None,  # 700,
    sizes=(50, 500),  # min and max of size
    facecolor="skyblue",
    cmap="coolwarm",
    edgecolor="k",
    edgelinewidth=1.5,
    alpha=0.5,
    alphas=(0.1, 1.0),  # min and max of alpha
    marker="o",
    node_hideticks=True,
    linecolor="gray",
    line_cmap="coolwarm",
    linewidth=1.5,
    linewidths=(0.5, 5),  # min and max of linewidth
    linealpha=1.0,
    linealphas=(0.1, 1.0),  # min and max of linealpha
    linestyle="-",
    line_arrowstyle="-",
    fontsize=10,
    fontcolor="k",
    ha: str = "center",
    va: str = "center",
    figsize=(12, 10),
    k_value=0.3,
    bgcolor="w",
    dir_save="./ppi_network.html",
    physics=True,
    notebook=False,
    scale=1,
    ax=None,
    **kwargs
):
    """
    Plot a Protein-Protein Interaction (PPI) network with adjustable appearance.

    The network is drawn statically on a Matplotlib Axes with NetworkX and, when
    ``dir_save`` is set, also exported as an interactive pyvis HTML page, a
    GraphML file and a PDF next to it.

    Example:
        ppi(
            interactions_sort.iloc[:1000, :],
            player1="player1",
            player2="player2",
            weight="count",
            layout="spring",
            n_layers=13,
            fontsize=1,
            n_rank=[5, 10, 20, 40, 80, 80, 80, 80, 80, 80, 80, 80],
        )

    Parameters:
        interactions (DataFrame): Edge table; it is not modified.
        player1, player2 (str): Columns holding the two end points of an edge.
        weight (str): Column holding the edge weight.
        n_layers (int, optional): Number of concentric layers for the "degree"
            layout; defaults to ``len(n_rank) + 1``.
        n_rank (list[int]): Number of nodes on each inner ring ("degree" layout).
        dist_node (float): Radial distance between consecutive rings.
        layout (str): One of spring, circular, kamada_kawai, spectral, random,
            shell, planar, spiral, degree (matched through ``ips.strcmp``).
        size (number | list, optional): Node size(s); anything else (e.g. None)
            derives sizes from node degree. Non-constant lists are rescaled
            into ``sizes``.
        facecolor: Node color; when it is not a color (per ``ips.isa``) nodes
            are colored by degree with ``cmap``.
        edgecolor, edgelinewidth: Node outline color / width (also used as edge
            color / width factor in the pyvis export).
        alpha (number | list): Node opacity; lists are rescaled into ``alphas``.
        marker (str): Node shape for the Matplotlib drawing.
        node_hideticks (bool): Forwarded as ``hide_ticks`` to NetworkX.
        linecolor: A string gives one edge color; any non-string colors edges
            by weight with ``line_cmap``.
        linewidth (number | list): Edge width; lists are rescaled into
            ``linewidths``, scalars are used as given.
        linealpha (number | list): Edge opacity; lists are rescaled into
            ``linealphas``.
        linestyle, line_arrowstyle: Edge style options forwarded to NetworkX.
        fontsize, fontcolor, ha, va: Label appearance.
        figsize (tuple): Figure size used when ``ax`` is None.
        k_value (float): ``k`` of the spring layout.
        bgcolor: Background color of the pyvis page.
        dir_save (str, optional): Path of the HTML export; falsy disables all
            file output.
        physics (bool): Toggle pyvis physics.
        notebook (bool): Currently unused.
        scale (float): Factor applied to the positions in the pyvis export.
        ax (matplotlib.axes.Axes, optional): Axes to draw on.
        **kwargs: The first key containing "figset" is forwarded to ``figsets``.

    Returns:
        tuple: ``(G, ax)`` - the NetworkX graph and the Matplotlib Axes.

    Raises:
        ValueError: If a required column is missing or there are no interactions.
    """
    from pyvis.network import Network
    import networkx as nx
    from matplotlib.colors import Normalize, to_hex
    from . import ips

    def _css(color):
        """Convert a Matplotlib color ("k", RGBA tuple, ...) to a CSS hex string for pyvis."""
        try:
            return to_hex(color)
        except (ValueError, TypeError):
            # Not a Matplotlib color (e.g. "rgb(1,2,3)"): hand it over untouched
            return color

    def _fit_length(values, n):
        """Truncate, or pad by repeating the last item, so the list has n entries."""
        values = list(values)
        if not values or len(values) >= n:
            return values[:n]
        return values + [values[-1]] * (n - len(values))

    def _rescale(values, low, high):
        """Map values linearly onto [low, high]; constant input maps to the midpoint."""
        v_min, v_max = min(values), max(values)
        if v_max > v_min:  # Avoid division by zero
            return [low + (high - low) * (v - v_min) / (v_max - v_min) for v in values]
        return [(low + high) / 2] * len(values)

    def _per_item(value, n, bounds, fallback):
        """Expand a scalar to n entries; list inputs are fitted to n and rescaled into bounds."""
        if not isinstance(value, list):
            return [value] * n
        fitted = _fit_length(value, n)
        if not fitted:
            return [fallback] * n
        return _rescale(fitted, *bounds)

    if run_once_within():
        usage_str = """
        ppi(
        interactions,
        player1="preferredName_A",
        player2="preferredName_B",
        weight="score",
        n_layers=None, # Number of concentric layers
        n_rank=[5, 10], # Nodes in each rank for the concentric layout
        dist_node = 10, # Distance between each rank of circles
        layout="degree",
        size=None,#700,
        sizes=(50,500),# min and max of size
        facecolor="skyblue",
        cmap='coolwarm',
        edgecolor="k",
        edgelinewidth=1.5,
        alpha=.5,
        alphas=(0.1, 1.0),# min and max of alpha
        marker="o",
        node_hideticks=True,
        linecolor="gray",
        line_cmap='coolwarm',
        linewidth=1.5,
        linewidths=(0.5,5),# min and max of linewidth
        linealpha=1.0,
        linealphas=(0.1,1.0),# min and max of linealpha
        linestyle="-",
        line_arrowstyle='-',
        fontsize=10,
        fontcolor="k",
        ha:str="center",
        va:str="center",
        figsize=(12, 10),
        k_value=0.3,
        bgcolor="w",
        dir_save="./ppi_network.html",
        physics=True,
        notebook=False,
        scale=1,
        ax=None,
        **kwargs
        ):
        """
        print(usage_str)

    # Check for required columns in the DataFrame
    for col in [player1, player2, weight]:
        if col not in interactions.columns:
            raise ValueError(f"Column '{col}' is missing from the interactions DataFrame.")
    if len(interactions) == 0:
        raise ValueError("The interactions DataFrame is empty; there is nothing to plot.")
    # Sort a copy: the caller's DataFrame must not be reordered as a side effect
    interactions = interactions.sort_values(by=[weight])

    # Initialize Pyvis network (pyvis needs CSS colors, not Matplotlib shorthands)
    net = Network(
        height="750px", width="100%", bgcolor=_css(bgcolor), font_color=_css(fontcolor)
    )
    net.force_atlas_2based(
        gravity=-50, central_gravity=0.01, spring_length=100, spring_strength=0.1
    )
    net.toggle_physics(physics)

    # Pull the first "figset"-like keyword out of kwargs for figsets()
    kws_figsets = {}
    for k_arg in list(kwargs):
        if "figset" in k_arg:
            kws_figsets = kwargs.pop(k_arg)
            break

    # Create a NetworkX graph from the interaction data
    G = nx.Graph()
    for _, row in interactions.iterrows():
        G.add_edge(row[player1], row[player2], weight=row[weight])

    num_nodes = G.number_of_nodes()
    num_edges = G.number_of_edges()

    # Node degrees drive the default colors, sizes and the "degree" layout
    degrees = dict(G.degree())

    # * facecolor
    if not ips.isa(facecolor, "color"):
        print("facecolor: based on degrees")
        norm = Normalize(vmin=min(degrees.values()), vmax=max(degrees.values()))
        # plt.get_cmap replaces matplotlib.cm.get_cmap, which was removed in Matplotlib 3.9
        colormap = plt.get_cmap(cmap)
        facecolor = [colormap(norm(deg)) for deg in degrees.values()]
    facecolor = (
        _fit_length(facecolor, num_nodes)
        if isinstance(facecolor, list)
        else [facecolor] * num_nodes
    )

    # * size
    if not isinstance(size, (int, float, list)):
        print("size: based on degrees")
        size = [deg * 50 for deg in degrees.values()]  # Scale sizes
    size = _fit_length(size, num_nodes) if isinstance(size, list) else [size] * num_nodes
    if len(ips.flatten(size, verbose=False)) != 1:
        size = _rescale(size, *sizes)

    # * facealpha
    alpha = _per_item(alpha, num_nodes, alphas, 1.0)

    for i, node in enumerate(G.nodes()):
        net.add_node(
            node,
            label=node,
            size=size[i],
            color=_css(facecolor[i]),
            alpha=alpha[i],
            font={"size": fontsize, "color": _css(fontcolor)},
        )
    print(f"nodes number: {num_nodes}")

    for u, v, attrs in G.edges(data=True):
        net.add_edge(
            u,
            v,
            weight=attrs["weight"],
            color=_css(edgecolor),
            width=edgelinewidth * attrs["weight"],
        )

    layouts = [
        "spring",
        "circular",
        "kamada_kawai",
        "spectral",
        "random",
        "shell",
        "planar",
        "spiral",
        "degree",
    ]
    layout = ips.strcmp(layout, layouts)[0]
    print(f"layout:{layout}, or select one in {layouts}")

    # Choose layout
    if layout == "spring":
        pos = nx.spring_layout(G, k=k_value)
    elif layout == "circular":
        pos = nx.circular_layout(G)
    elif layout == "kamada_kawai":
        pos = nx.kamada_kawai_layout(G)
    elif layout == "spectral":
        pos = nx.spectral_layout(G)
    elif layout == "random":
        pos = nx.random_layout(G)
    elif layout == "shell":
        pos = nx.shell_layout(G)
    elif layout == "planar":
        if nx.check_planarity(G)[0]:
            pos = nx.planar_layout(G)
        else:
            print("Graph is not planar; switching to spring layout.")
            pos = nx.spring_layout(G, k=k_value)
    elif layout == "spiral":
        pos = nx.spiral_layout(G)
    elif layout == "degree":
        # Highest-degree nodes go on the innermost ring
        sorted_nodes = sorted(degrees.items(), key=lambda item: item[1], reverse=True)

        # Create positions for concentric circles based on n_layers and n_rank
        pos = {}
        n_layers = len(n_rank) + 1 if n_layers is None else n_layers
        for rank_index in range(n_layers):
            start = sum(n_rank[:rank_index])
            if rank_index < len(n_rank):
                rank_nodes = sorted_nodes[start : start + n_rank[rank_index]]
            else:
                # Shuffle the order of the remaining nodes
                remaining_nodes = sorted_nodes[start:]
                random_indices = np.random.permutation(len(remaining_nodes))
                rank_nodes = [remaining_nodes[i] for i in random_indices]

            radius = (rank_index + 1) * dist_node  # Radius for this rank

            # Arrange nodes in a circle for the current rank
            for i, (node, _) in enumerate(rank_nodes):
                angle = (i / len(rank_nodes)) * 2 * np.pi  # Distribute around circle
                pos[node] = (radius * np.cos(angle), radius * np.sin(angle))

        # When n_layers is too small for the rings in n_rank, some nodes have no
        # position yet; drawing would fail, so put them on one extra outer ring.
        leftover = [node for node, _ in sorted_nodes if node not in pos]
        if leftover:
            radius = (n_layers + 1) * dist_node
            for i, node in enumerate(leftover):
                angle = (i / len(leftover)) * 2 * np.pi
                pos[node] = (radius * np.cos(angle), radius * np.sin(angle))
    else:
        print(f"Unknown layout '{layout}', defaulting to 'spring',or可以用这些: {layouts}")
        pos = nx.spring_layout(G, k=k_value)

    # Mirror the computed positions in the pyvis export
    for node, (pos_x, pos_y) in pos.items():
        net.get_node(node)["x"] = pos_x * scale
        net.get_node(node)["y"] = pos_y * scale

    # Without a target Axes, create a new figure
    if ax is None:
        fig, ax = plt.subplots(1, 1, figsize=figsize)

    # Draw nodes, edges, and labels with customization options
    nx.draw_networkx_nodes(
        G,
        pos,
        ax=ax,
        node_size=size,
        node_color=facecolor,
        linewidths=edgelinewidth,
        edgecolors=edgecolor,
        alpha=alpha,
        hide_ticks=node_hideticks,
        node_shape=marker,
    )

    # * linewidth: a scalar is used as given, a list is rescaled into `linewidths`
    linewidth = _per_item(linewidth, num_edges, linewidths, sum(linewidths) / 2)

    # * linecolor
    if not isinstance(linecolor, str):
        weights = [G[u][v]["weight"] for u, v in G.edges()]
        line_norm = Normalize(vmin=min(weights), vmax=max(weights))
        line_colormap = plt.get_cmap(line_cmap)
        linecolor = [line_colormap(line_norm(w)) for w in weights]
    else:
        linecolor = [linecolor] * num_edges

    # * linealpha: an empty list falls back to fully opaque edges
    linealpha = _per_item(linealpha, num_edges, linealphas, 1.0)

    nx.draw_networkx_edges(
        G,
        pos,
        ax=ax,
        edge_color=linecolor,
        width=linewidth,
        style=linestyle,
        arrowstyle=line_arrowstyle,
        alpha=linealpha,
    )

    nx.draw_networkx_labels(
        G,
        pos,
        ax=ax,
        font_size=fontsize,
        font_color=fontcolor,
        horizontalalignment=ha,
        verticalalignment=va,
    )
    figsets(ax=ax, **kws_figsets)
    ax.axis("off")

    if dir_save:
        if not os.path.basename(dir_save):
            dir_save = "_.html"
        # Derive sibling file names from the stem so that a path without an
        # ".html" suffix cannot make the GraphML/PDF overwrite the HTML file.
        stem = os.path.splitext(dir_save)[0]
        path_graphml = stem + ".graphml"
        net.write_html(dir_save)
        nx.write_graphml(G, path_graphml)  # Export to GraphML
        print(f"could be edited in Cytoscape \n{path_graphml}")
        ips.figsave(stem + ".pdf")
    return G, ax