py2ls 0.2.4.31__py3-none-any.whl → 0.2.4.33__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
py2ls/plot.py CHANGED
@@ -22,7 +22,7 @@ from .ips import (
22
22
  run_once_within,
23
23
  get_df_format,
24
24
  df_corr,
25
- df_scaler
25
+ df_scaler,
26
26
  )
27
27
  import scipy.stats as scipy_stats
28
28
  from .stats import *
@@ -144,6 +144,7 @@ def add_text(ax=None, height_offset=0.5, fmt=".1f", **kwargs):
144
144
  **kwargs,
145
145
  )
146
146
 
147
+
147
148
  def pval2str(p):
148
149
  if p > 0.05:
149
150
  txt = ""
@@ -155,16 +156,17 @@ def pval2str(p):
155
156
  txt = "***"
156
157
  return txt
157
158
 
159
+
158
160
  def heatmap(
159
161
  data,
160
162
  ax=None,
161
163
  kind="corr", #'corr','direct','pivot'
162
- method="pearson",# for correlation: ‘pearson’(default), ‘kendall’, ‘spearman’
164
+ method="pearson", # for correlation: ‘pearson’(default), ‘kendall’, ‘spearman’
163
165
  columns="all", # pivot, default: coll numeric columns
164
- style=0,# for correlation
166
+ style=0, # for correlation
165
167
  index=None, # pivot
166
168
  values=None, # pivot
167
- fontsize=10,
169
+ fontsize=10,
168
170
  tri="u",
169
171
  mask=True,
170
172
  k=1,
@@ -174,7 +176,7 @@ def heatmap(
174
176
  annot=True,
175
177
  cmap="coolwarm",
176
178
  fmt=".2f",
177
- show_indicator = True,# only for style==1
179
+ show_indicator=True, # only for style==1
178
180
  cluster=False,
179
181
  inplace=False,
180
182
  figsize=(10, 8),
@@ -309,7 +311,7 @@ def heatmap(
309
311
  df_col_cluster,
310
312
  )
311
313
  else:
312
- if style==0:
314
+ if style == 0:
313
315
  # Create a standard heatmap
314
316
  ax = sns.heatmap(
315
317
  data4heatmap,
@@ -321,7 +323,7 @@ def heatmap(
321
323
  **kwargs, # Pass any additional arguments to sns.heatmap
322
324
  )
323
325
  return ax
324
- elif style==1:
326
+ elif style == 1:
325
327
  if isinstance(cmap, str):
326
328
  cmap = plt.get_cmap(cmap)
327
329
  norm = plt.Normalize(vmin=-1, vmax=1)
@@ -332,12 +334,17 @@ def heatmap(
332
334
  # 循环绘制气泡图和数值
333
335
  for i in range(len(r_.columns)):
334
336
  for j in range(len(r_.columns)):
335
- if (i < j) if "u" in tri.lower() else (j<i): # 对角线左上部只显示气泡
337
+ if (
338
+ (i < j) if "u" in tri.lower() else (j < i)
339
+ ): # 对角线左上部只显示气泡
336
340
  color = cmap(norm(r_.iloc[i, j])) # 根据相关系数获取颜色
337
341
  scatter = ax.scatter(
338
- i, j, s=np.abs(r_.iloc[i, j])*size_scale, color=color,
342
+ i,
343
+ j,
344
+ s=np.abs(r_.iloc[i, j]) * size_scale,
345
+ color=color,
339
346
  # alpha=1,edgecolor=edgecolor,linewidth=linewidth,
340
- **kwargs
347
+ **kwargs,
341
348
  )
342
349
  scatter_handles.append(scatter) # 保存scatter对象用于颜色条
343
350
  # add *** indicators
@@ -351,8 +358,12 @@ def heatmap(
351
358
  color="k",
352
359
  fontsize=fontsize * 1.3,
353
360
  )
354
- elif (i > j) if "u" in tri.lower() else (j>i): # 对角只显示数值
355
- color = cmap(norm(r_.iloc[i, j])) # 数值的颜色同样基于相关系数
361
+ elif (
362
+ (i > j) if "u" in tri.lower() else (j > i)
363
+ ): # 对角只显示数值
364
+ color = cmap(
365
+ norm(r_.iloc[i, j])
366
+ ) # 数值的颜色同样基于相关系数
356
367
  ax.text(
357
368
  i,
358
369
  j,
@@ -365,15 +376,16 @@ def heatmap(
365
376
  else: # 对角线部分,显示空白
366
377
  ax.scatter(i, j, s=1, color="white")
367
378
  # 设置坐标轴标签
368
- figsets(xticks=range(len(r_.columns)),
369
- xticklabels=r_.columns,
370
- xangle=90,
371
- fontsize=fontsize,
372
- yticks=range(len(r_.columns)),
373
- yticklabels=r_.columns,
374
- xlim=[-0.5,len(r_.columns)-0.5],
375
- ylim=[-0.5,len(r_.columns)-0.5]
376
- )
379
+ figsets(
380
+ xticks=range(len(r_.columns)),
381
+ xticklabels=r_.columns,
382
+ xangle=90,
383
+ fontsize=fontsize,
384
+ yticks=range(len(r_.columns)),
385
+ yticklabels=r_.columns,
386
+ xlim=[-0.5, len(r_.columns) - 0.5],
387
+ ylim=[-0.5, len(r_.columns) - 0.5],
388
+ )
377
389
  # 添加颜色条
378
390
  sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
379
391
  sm.set_array([]) # 仅用于显示颜色条
@@ -468,7 +480,7 @@ def heatmap(
468
480
  )
469
481
  else:
470
482
  # Create a standard heatmap
471
- if style==0:
483
+ if style == 0:
472
484
  ax = sns.heatmap(
473
485
  data4heatmap,
474
486
  ax=ax,
@@ -478,42 +490,51 @@ def heatmap(
478
490
  **kwargs, # Pass any additional arguments to sns.heatmap
479
491
  )
480
492
  # Return the Axes object for further customization if needed
481
- return ax
482
- elif style==1:
493
+ return ax
494
+ elif style == 1:
483
495
  if isinstance(cmap, str):
484
496
  cmap = plt.get_cmap(cmap)
485
497
  if vmin is None:
486
- vmin=np.min(data4heatmap)
498
+ vmin = np.min(data4heatmap)
487
499
  if vmax is None:
488
- vmax=np.max(data4heatmap)
500
+ vmax = np.max(data4heatmap)
489
501
  norm = plt.Normalize(vmin=vmin, vmax=vmax)
490
502
 
491
503
  # 初始化一个空的可绘制对象用于颜色条
492
504
  scatter_handles = []
493
- # 循环绘制气泡图和数值
494
- print(len(data4heatmap.index),len(data4heatmap.columns))
505
+ # 循环绘制气泡图和数值
506
+ print(len(data4heatmap.index), len(data4heatmap.columns))
495
507
  for i in range(len(data4heatmap.index)):
496
508
  for j in range(len(data4heatmap.columns)):
497
509
  color = cmap(norm(data4heatmap.iloc[i, j])) # 根据相关系数获取颜色
498
- scatter = ax.scatter(j,i, s=np.abs(data4heatmap.iloc[i, j])*size_scale, color=color, **kwargs)
510
+ scatter = ax.scatter(
511
+ j,
512
+ i,
513
+ s=np.abs(data4heatmap.iloc[i, j]) * size_scale,
514
+ color=color,
515
+ **kwargs,
516
+ )
499
517
  scatter_handles.append(scatter) # 保存scatter对象用于颜色条
500
518
 
501
519
  # 设置坐标轴标签
502
- figsets(xticks=range(len(data4heatmap.columns)),
503
- xticklabels=data4heatmap.columns,
504
- xangle=90,
505
- fontsize=fontsize,
506
- yticks=range(len(data4heatmap.index)),
507
- yticklabels=data4heatmap.index,
508
- xlim=[-0.5,len(data4heatmap.columns)-0.5],
509
- ylim=[-0.5,len(data4heatmap.index)-0.5]
510
- )
520
+ figsets(
521
+ xticks=range(len(data4heatmap.columns)),
522
+ xticklabels=data4heatmap.columns,
523
+ xangle=90,
524
+ fontsize=fontsize,
525
+ yticks=range(len(data4heatmap.index)),
526
+ yticklabels=data4heatmap.index,
527
+ xlim=[-0.5, len(data4heatmap.columns) - 0.5],
528
+ ylim=[-0.5, len(data4heatmap.index) - 0.5],
529
+ )
511
530
  # 添加颜色条
512
531
  sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
513
532
  sm.set_array([]) # 仅用于显示颜色条
514
- plt.colorbar(sm, ax=ax,
515
- # label="Correlation Coefficient"
516
- )
533
+ plt.colorbar(
534
+ sm,
535
+ ax=ax,
536
+ # label="Correlation Coefficient"
537
+ )
517
538
  return ax
518
539
 
519
540
 
@@ -854,9 +875,11 @@ def catplot(data, *args, **kwargs):
854
875
  else:
855
876
  bx_opt["EdgeColor"] = bx_opt["EdgeColor"]
856
877
  if not isinstance(bx_opt["FaceColor"], list):
857
- bx_opt["FaceColor"]=[bx_opt["FaceColor"]]
858
- if len(bxp["boxes"])!= len(bx_opt["FaceColor"]) and (len(bx_opt["FaceColor"])==1):
859
- bx_opt["FaceColor"]=bx_opt["FaceColor"] *len(bxp["boxes"])
878
+ bx_opt["FaceColor"] = [bx_opt["FaceColor"]]
879
+ if len(bxp["boxes"]) != len(bx_opt["FaceColor"]) and (
880
+ len(bx_opt["FaceColor"]) == 1
881
+ ):
882
+ bx_opt["FaceColor"] = bx_opt["FaceColor"] * len(bxp["boxes"])
860
883
  for patch, color in zip(bxp["boxes"], bx_opt["FaceColor"]):
861
884
  patch.set_facecolor(to_rgba(color, bx_opt["FaceAlpha"]))
862
885
 
@@ -2439,7 +2462,15 @@ def split_legend(ax, n=2, loc=None, title=None, bbox=None, ncol=1, **kwargs):
2439
2462
  return legends
2440
2463
 
2441
2464
 
2442
- def get_colors(n: int = 1,cmap: str = "auto",by: str = "start",alpha: float = 1.0,output: str = "hue",*args,**kwargs):
2465
+ def get_colors(
2466
+ n: int = 1,
2467
+ cmap: str = "auto",
2468
+ by: str = "start",
2469
+ alpha: float = 1.0,
2470
+ output: str = "hue",
2471
+ *args,
2472
+ **kwargs,
2473
+ ):
2443
2474
  return get_color(n=n, cmap=cmap, alpha=alpha, output=output, *args, **kwargs)
2444
2475
 
2445
2476
 
@@ -3317,9 +3348,9 @@ def plotxy(
3317
3348
  for k in kind_:
3318
3349
  # preprocess data
3319
3350
  try:
3320
- data=df_preprocessing_(data, kind=k)
3321
- if 'variable' in data.columns and 'value' in data.columns:
3322
- x,y='variable','value'
3351
+ data = df_preprocessing_(data, kind=k)
3352
+ if "variable" in data.columns and "value" in data.columns:
3353
+ x, y = "variable", "value"
3323
3354
  except Exception as e:
3324
3355
  print(e)
3325
3356
  zorder += 1
@@ -3342,28 +3373,31 @@ def plotxy(
3342
3373
  # (1) return FcetGrid
3343
3374
  if k == "jointplot":
3344
3375
  kws_joint = kwargs.pop("kws_joint", kwargs)
3345
- kws_joint = {
3346
- k: v for k, v in kws_joint.items() if not k.startswith("kws_")
3347
- }
3376
+ kws_joint = {k: v for k, v in kws_joint.items() if not k.startswith("kws_")}
3348
3377
  hue = kwargs.get("hue", None)
3349
- if isinstance(kws_joint, dict) or hue is None: # Check if kws_ellipse is a dictionary
3378
+ if (
3379
+ isinstance(kws_joint, dict) or hue is None
3380
+ ): # Check if kws_ellipse is a dictionary
3350
3381
  kws_joint.pop("hue", None) # Safely remove 'hue' if it exists
3351
3382
 
3352
3383
  palette = kwargs.get("palette", None)
3353
3384
  if palette is None:
3354
3385
  palette = kws_joint.pop(
3355
- "palette", get_color(data[hue].nunique()) if hue is not None else None
3386
+ "palette",
3387
+ get_color(data[hue].nunique()) if hue is not None else None,
3356
3388
  )
3357
3389
  else:
3358
3390
  kws_joint.pop("palette", palette)
3359
- stats=kwargs.pop("stats",None)
3391
+ stats = kwargs.pop("stats", None)
3360
3392
  if stats:
3361
- stats=kws_joint.pop("stats",True)
3393
+ stats = kws_joint.pop("stats", True)
3362
3394
  if stats:
3363
3395
  r, p_value = scipy_stats.pearsonr(data[x], data[y])
3364
- for key in ["palette", "alpha", "hue","stats"]:
3365
- kws_joint.pop(key, None)
3366
- g = sns.jointplot(data=data, x=x, y=y,hue=hue,palette=palette, **kws_joint)
3396
+ for key in ["palette", "alpha", "hue", "stats"]:
3397
+ kws_joint.pop(key, None)
3398
+ g = sns.jointplot(
3399
+ data=data, x=x, y=y, hue=hue, palette=palette, **kws_joint
3400
+ )
3367
3401
  if stats:
3368
3402
  g.ax_joint.annotate(
3369
3403
  f"pearsonr = {r:.2f} p = {p_value:.3f}",
@@ -3391,41 +3425,64 @@ def plotxy(
3391
3425
  # If no hue, col, or row, calculate stats for the entire dataset
3392
3426
  if all([hue is None, col is None, row is None]):
3393
3427
  r, p_value = scipy_stats.pearsonr(data[x], data[y])
3394
- stats_per_facet[(None, None)] = (r, p_value) # Store stats for the entire dataset
3428
+ stats_per_facet[(None, None)] = (
3429
+ r,
3430
+ p_value,
3431
+ ) # Store stats for the entire dataset
3395
3432
 
3396
- else:
3433
+ else:
3397
3434
  if hue is None and (col is not None or row is not None):
3398
3435
  for ax in g.axes.flat:
3399
3436
  facet_name = ax.get_title()
3400
- if '=' in facet_name:
3437
+ if "=" in facet_name:
3401
3438
  # Assume facet_name is like 'Column = Value'
3402
- facet_column_name = facet_name.split('=')[0].strip() # Column name before '='
3403
- facet_value_str = facet_name.split('=')[1].strip() # Facet value after '='
3404
-
3439
+ facet_column_name = facet_name.split("=")[
3440
+ 0
3441
+ ].strip() # Column name before '='
3442
+ facet_value_str = facet_name.split("=")[
3443
+ 1
3444
+ ].strip() # Facet value after '='
3445
+
3405
3446
  # Try converting facet_value to match the data type of the DataFrame column
3406
3447
  facet_column_dtype = data[facet_column_name].dtype
3407
- if facet_column_dtype == 'int' or facet_column_dtype == 'float':
3408
- facet_value = pd.to_numeric(facet_value_str, errors='coerce') # Convert to numeric
3448
+ if (
3449
+ facet_column_dtype == "int"
3450
+ or facet_column_dtype == "float"
3451
+ ):
3452
+ facet_value = pd.to_numeric(
3453
+ facet_value_str, errors="coerce"
3454
+ ) # Convert to numeric
3409
3455
  else:
3410
3456
  facet_value = facet_value_str # Treat as a string if not numeric
3411
3457
  else:
3412
- facet_column_name = facet_name.split('=')[0].strip() # Column name before '='
3413
- facet_value=facet_name.split('=')[1].strip()
3458
+ facet_column_name = facet_name.split("=")[
3459
+ 0
3460
+ ].strip() # Column name before '='
3461
+ facet_value = facet_name.split("=")[1].strip()
3414
3462
  facet_data = data[data[facet_column_name] == facet_value]
3415
3463
  if not facet_data.empty:
3416
- r, p_value = scipy_stats.pearsonr(facet_data[x], facet_data[y])
3464
+ r, p_value = scipy_stats.pearsonr(
3465
+ facet_data[x], facet_data[y]
3466
+ )
3417
3467
  stats_per_facet[facet_name] = (r, p_value)
3418
3468
  else:
3419
- stats_per_facet[facet_name] = (None, None) # Handle empty facets
3469
+ stats_per_facet[facet_name] = (
3470
+ None,
3471
+ None,
3472
+ ) # Handle empty facets
3420
3473
 
3421
3474
  # Annotate the stats on the plot
3422
3475
  for ax in g.axes.flat:
3423
3476
  if stats:
3424
3477
  # Adjust the position for each facet to avoid overlap
3425
- idx=1
3426
- shift_factor = 0.02 * idx # Adjust this factor as needed to prevent overlap
3427
- y_position = 0.98 - shift_factor # Dynamic vertical shift for each facet
3428
-
3478
+ idx = 1
3479
+ shift_factor = (
3480
+ 0.02 * idx
3481
+ ) # Adjust this factor as needed to prevent overlap
3482
+ y_position = (
3483
+ 0.98 - shift_factor
3484
+ ) # Dynamic vertical shift for each facet
3485
+
3429
3486
  if all([hue is None, col is None, row is None]):
3430
3487
  # Use stats for the entire dataset if no hue, col, or row
3431
3488
  r, p_value = stats_per_facet.get((None, None), (None, None))
@@ -3448,15 +3505,21 @@ def plotxy(
3448
3505
  ha="center",
3449
3506
  )
3450
3507
  elif hue is not None:
3451
- if (col is None and row is None):
3452
- hue_categories = sorted(flatten(data[hue],verbose=0))
3453
- idx=1
3508
+ if col is None and row is None:
3509
+ hue_categories = sorted(flatten(data[hue], verbose=0))
3510
+ idx = 1
3454
3511
  for category in hue_categories:
3455
3512
  subset_data = data[data[hue] == category]
3456
- r, p_value = scipy_stats.pearsonr(subset_data[x], subset_data[y])
3513
+ r, p_value = scipy_stats.pearsonr(
3514
+ subset_data[x], subset_data[y]
3515
+ )
3457
3516
  stats_per_hue[category] = (r, p_value)
3458
- shift_factor = 0.05 * idx # Adjust this factor as needed to prevent overlap
3459
- y_position = 0.98 - shift_factor # Dynamic vertical shift for each facet
3517
+ shift_factor = (
3518
+ 0.05 * idx
3519
+ ) # Adjust this factor as needed to prevent overlap
3520
+ y_position = (
3521
+ 0.98 - shift_factor
3522
+ ) # Dynamic vertical shift for each facet
3460
3523
  ax.annotate(
3461
3524
  f"{category}: pearsonr = {r:.2f} p = {p_value:.3f}",
3462
3525
  xy=(0.6, y_position),
@@ -3465,31 +3528,49 @@ def plotxy(
3465
3528
  color="black",
3466
3529
  ha="center",
3467
3530
  )
3468
- idx+=1
3531
+ idx += 1
3469
3532
  else:
3470
3533
  for ax in g.axes.flat:
3471
3534
  facet_name = ax.get_title()
3472
- if '=' in facet_name:
3535
+ if "=" in facet_name:
3473
3536
  # Assume facet_name is like 'Column = Value'
3474
- facet_column_name = facet_name.split('=')[0].strip() # Column name before '='
3475
- facet_value_str = facet_name.split('=')[1].strip() # Facet value after '='
3476
-
3537
+ facet_column_name = facet_name.split("=")[
3538
+ 0
3539
+ ].strip() # Column name before '='
3540
+ facet_value_str = facet_name.split("=")[
3541
+ 1
3542
+ ].strip() # Facet value after '='
3543
+
3477
3544
  # Try converting facet_value to match the data type of the DataFrame column
3478
3545
  facet_column_dtype = data[facet_column_name].dtype
3479
- if facet_column_dtype == 'int' or facet_column_dtype == 'float':
3480
- facet_value = pd.to_numeric(facet_value_str, errors='coerce') # Convert to numeric
3546
+ if (
3547
+ facet_column_dtype == "int"
3548
+ or facet_column_dtype == "float"
3549
+ ):
3550
+ facet_value = pd.to_numeric(
3551
+ facet_value_str, errors="coerce"
3552
+ ) # Convert to numeric
3481
3553
  else:
3482
3554
  facet_value = facet_value_str # Treat as a string if not numeric
3483
3555
  else:
3484
- facet_column_name = facet_name.split('=')[0].strip() # Column name before '='
3485
- facet_value=facet_name.split('=')[1].strip()
3486
- facet_data = data[data[facet_column_name] == facet_value]
3556
+ facet_column_name = facet_name.split("=")[
3557
+ 0
3558
+ ].strip() # Column name before '='
3559
+ facet_value = facet_name.split("=")[1].strip()
3560
+ facet_data = data[
3561
+ data[facet_column_name] == facet_value
3562
+ ]
3487
3563
  if not facet_data.empty:
3488
- r, p_value = scipy_stats.pearsonr(facet_data[x], facet_data[y])
3564
+ r, p_value = scipy_stats.pearsonr(
3565
+ facet_data[x], facet_data[y]
3566
+ )
3489
3567
  stats_per_facet[facet_name] = (r, p_value)
3490
3568
  else:
3491
- stats_per_facet[facet_name] = (None, None) # Handle empty facets
3492
-
3569
+ stats_per_facet[facet_name] = (
3570
+ None,
3571
+ None,
3572
+ ) # Handle empty facets
3573
+
3493
3574
  ax.annotate(
3494
3575
  f"pearsonr = {r:.2f} p = {p_value:.3f}",
3495
3576
  xy=(0.6, y_position),
@@ -3497,7 +3578,7 @@ def plotxy(
3497
3578
  fontsize=12,
3498
3579
  color="black",
3499
3580
  ha="center",
3500
- )
3581
+ )
3501
3582
  elif hue is None and (col is not None or row is not None):
3502
3583
  # Annotate stats for each facet
3503
3584
  facet_name = ax.get_title()
@@ -3518,7 +3599,8 @@ def plotxy(
3518
3599
  xycoords="axes fraction",
3519
3600
  fontsize=12,
3520
3601
  color="black",
3521
- ha="center")
3602
+ ha="center",
3603
+ )
3522
3604
 
3523
3605
  elif k == "catplot_sns":
3524
3606
  kws_cat = kwargs.pop("kws_cat", kwargs)
@@ -3526,7 +3608,7 @@ def plotxy(
3526
3608
  elif k == "displot":
3527
3609
  kws_dis = kwargs.pop("kws_dis", kwargs)
3528
3610
  # displot creates a new figure and returns a FacetGrid
3529
- g = sns.displot(data=data, x=x,y=y, **kws_dis)
3611
+ g = sns.displot(data=data, x=x, y=y, **kws_dis)
3530
3612
 
3531
3613
  # (2) return axis
3532
3614
  if ax is None:
@@ -3537,30 +3619,35 @@ def plotxy(
3537
3619
  elif k == "stdshade":
3538
3620
  kws_stdshade = kwargs.pop("kws_stdshade", kwargs)
3539
3621
  ax = stdshade(ax=ax, **kwargs)
3540
- elif k=="ellipse":
3622
+ elif k == "ellipse":
3541
3623
  kws_ellipse = kwargs.pop("kws_ellipse", kwargs)
3542
3624
  kws_ellipse = {
3543
3625
  k: v for k, v in kws_ellipse.items() if not k.startswith("kws_")
3544
3626
  }
3545
3627
  hue = kwargs.get("hue", None)
3546
- if isinstance(kws_ellipse, dict) or hue is None: # Check if kws_ellipse is a dictionary
3628
+ if (
3629
+ isinstance(kws_ellipse, dict) or hue is None
3630
+ ): # Check if kws_ellipse is a dictionary
3547
3631
  kws_ellipse.pop("hue", None) # Safely remove 'hue' if it exists
3548
3632
 
3549
3633
  palette = kwargs.get("palette", None)
3550
3634
  if palette is None:
3551
3635
  palette = kws_ellipse.pop(
3552
- "palette", get_color(data[hue].nunique()) if hue is not None else None
3636
+ "palette",
3637
+ get_color(data[hue].nunique()) if hue is not None else None,
3553
3638
  )
3554
3639
  alpha = kws_ellipse.pop("alpha", 0.1)
3555
- hue_order = kwargs.get("hue_order",None)
3640
+ hue_order = kwargs.get("hue_order", None)
3556
3641
  if hue_order is None:
3557
- hue_order = kws_ellipse.get("hue_order",None)
3558
- if hue_order:
3559
- data["hue"] = pd.Categorical(data[hue], categories=hue_order, ordered=True)
3642
+ hue_order = kws_ellipse.get("hue_order", None)
3643
+ if hue_order:
3644
+ data["hue"] = pd.Categorical(
3645
+ data[hue], categories=hue_order, ordered=True
3646
+ )
3560
3647
  data = data.sort_values(by="hue")
3561
- for key in ["palette", "alpha", "hue","hue_order"]:
3562
- kws_ellipse.pop(key, None)
3563
- ax=ellipse(
3648
+ for key in ["palette", "alpha", "hue", "hue_order"]:
3649
+ kws_ellipse.pop(key, None)
3650
+ ax = ellipse(
3564
3651
  ax=ax,
3565
3652
  data=data,
3566
3653
  x=x,
@@ -3579,13 +3666,14 @@ def plotxy(
3579
3666
  hue = kwargs.get("hue", None)
3580
3667
  if isinstance(kws_scatter, dict): # Check if kws_scatter is a dictionary
3581
3668
  kws_scatter.pop("hue", None) # Safely remove 'hue' if it exists
3582
- palette=kws_scatter.get("palette",None)
3669
+ palette = kws_scatter.get("palette", None)
3583
3670
  if palette is None:
3584
3671
  palette = kws_scatter.pop(
3585
- "palette", get_color(data[hue].nunique()) if hue is not None else None
3672
+ "palette",
3673
+ get_color(data[hue].nunique()) if hue is not None else None,
3586
3674
  )
3587
3675
  s = kws_scatter.pop("s", 10)
3588
- alpha = kws_scatter.pop("alpha", 0.7)
3676
+ alpha = kws_scatter.pop("alpha", 0.7)
3589
3677
  for key in ["s", "palette", "alpha", "hue"]:
3590
3678
  kws_scatter.pop(key, None)
3591
3679
 
@@ -3605,28 +3693,39 @@ def plotxy(
3605
3693
  kws_hist = kwargs.pop("kws_hist", kwargs)
3606
3694
  kws_hist = {k: v for k, v in kws_hist.items() if not k.startswith("kws_")}
3607
3695
  ax = sns.histplot(data=data, x=x, y=y, ax=ax, zorder=zorder, **kws_hist)
3608
- elif k == "kdeplot":
3696
+ elif k == "kdeplot":
3609
3697
  kws_kde = kwargs.pop("kws_kde", kwargs)
3610
- kws_kde = {
3611
- k: v for k, v in kws_kde.items() if not k.startswith("kws_")
3612
- }
3698
+ kws_kde = {k: v for k, v in kws_kde.items() if not k.startswith("kws_")}
3613
3699
  hue = kwargs.get("hue", None)
3614
- if isinstance(kws_kde, dict) or hue is None: # Check if kws_kde is a dictionary
3700
+ if (
3701
+ isinstance(kws_kde, dict) or hue is None
3702
+ ): # Check if kws_kde is a dictionary
3615
3703
  kws_kde.pop("hue", None) # Safely remove 'hue' if it exists
3616
3704
 
3617
3705
  palette = kwargs.get("palette", None)
3618
3706
  if palette is None:
3619
3707
  palette = kws_kde.pop(
3620
- "palette", get_color(data[hue].nunique()) if hue is not None else None
3708
+ "palette",
3709
+ get_color(data[hue].nunique()) if hue is not None else None,
3621
3710
  )
3622
3711
  alpha = kws_kde.pop("alpha", 0.05)
3623
3712
  for key in ["palette", "alpha", "hue"]:
3624
- kws_kde.pop(key, None)
3625
- ax = sns.kdeplot(data=data, x=x, y=y, palette=palette,hue=hue, ax=ax,alpha=alpha, zorder=zorder, **kws_kde)
3713
+ kws_kde.pop(key, None)
3714
+ ax = sns.kdeplot(
3715
+ data=data,
3716
+ x=x,
3717
+ y=y,
3718
+ palette=palette,
3719
+ hue=hue,
3720
+ ax=ax,
3721
+ alpha=alpha,
3722
+ zorder=zorder,
3723
+ **kws_kde,
3724
+ )
3626
3725
  elif k == "ecdfplot":
3627
3726
  kws_ecdf = kwargs.pop("kws_ecdf", kwargs)
3628
3727
  kws_ecdf = {k: v for k, v in kws_ecdf.items() if not k.startswith("kws_")}
3629
- ax = sns.ecdfplot(data=data, x=x,y=y, ax=ax, zorder=zorder, **kws_ecdf)
3728
+ ax = sns.ecdfplot(data=data, x=x, y=y, ax=ax, zorder=zorder, **kws_ecdf)
3630
3729
  elif k == "rugplot":
3631
3730
  kws_rug = kwargs.pop("kws_rug", kwargs)
3632
3731
  kws_rug = {k: v for k, v in kws_rug.items() if not k.startswith("kws_")}
@@ -3635,7 +3734,9 @@ def plotxy(
3635
3734
  kws_strip = kwargs.pop("kws_strip", kwargs)
3636
3735
  kws_strip = {k: v for k, v in kws_strip.items() if not k.startswith("kws_")}
3637
3736
  dodge = kws_strip.pop("dodge", True)
3638
- ax = sns.stripplot(data=data, x=x, y=y, ax=ax, zorder=zorder, dodge=dodge, **kws_strip)
3737
+ ax = sns.stripplot(
3738
+ data=data, x=x, y=y, ax=ax, zorder=zorder, dodge=dodge, **kws_strip
3739
+ )
3639
3740
  elif k == "swarmplot":
3640
3741
  kws_swarm = kwargs.pop("kws_swarm", kwargs)
3641
3742
  kws_swarm = {k: v for k, v in kws_swarm.items() if not k.startswith("kws_")}
@@ -3672,21 +3773,21 @@ def plotxy(
3672
3773
  kws_reg = kwargs.pop("kws_reg", kwargs)
3673
3774
  kws_reg = {k: v for k, v in kws_reg.items() if not k.startswith("kws_")}
3674
3775
  stats = kwargs.pop("stats", True) # Flag to calculate stats
3675
-
3776
+
3676
3777
  # Compute Pearson correlation if stats is True
3677
3778
  if stats:
3678
3779
  r, p_value = scipy_stats.pearsonr(data[x], data[y])
3679
- ax = sns.regplot(data=data, x=x, y=y, ax=ax, **kws_reg)
3680
-
3780
+ ax = sns.regplot(data=data, x=x, y=y, ax=ax, **kws_reg)
3781
+
3681
3782
  # Annotate the Pearson correlation and p-value
3682
3783
  ax.annotate(
3683
- f"pearsonr = {r:.2f} p = {p_value:.3f}",
3684
- xy=(0.6, 0.98),
3685
- xycoords="axes fraction",
3686
- fontsize=12,
3687
- color="black",
3688
- ha="center",
3689
- )
3784
+ f"pearsonr = {r:.2f} p = {p_value:.3f}",
3785
+ xy=(0.6, 0.98),
3786
+ xycoords="axes fraction",
3787
+ fontsize=12,
3788
+ color="black",
3789
+ ha="center",
3790
+ )
3690
3791
  elif k == "residplot":
3691
3792
  kws_resid = kwargs.pop("kws_resid", kwargs)
3692
3793
  kws_resid = {k: v for k, v in kws_resid.items() if not k.startswith("kws_")}
@@ -3711,6 +3812,7 @@ def plotxy(
3711
3812
  return g, ax
3712
3813
  return ax
3713
3814
 
3815
+
3714
3816
  def df_preprocessing_(data, kind, verbose=False):
3715
3817
  """
3716
3818
  Automatically formats data for various seaborn plot types.
@@ -3733,10 +3835,10 @@ def df_preprocessing_(data, kind, verbose=False):
3733
3835
  "heatmap",
3734
3836
  "pairplot",
3735
3837
  "jointplot", # Typically requires wide format for axis variables
3736
- "facetgrid", # Used for creating small multiples (can work with wide format)
3737
- "barplot", # Can be used with wide format
3738
- "pointplot", # Works well with wide format
3739
- "pivot_table", # Works with wide format (aggregated data)
3838
+ "facetgrid", # Used for creating small multiples (can work with wide format)
3839
+ "barplot", # Can be used with wide format
3840
+ "pointplot", # Works well with wide format
3841
+ "pivot_table", # Works with wide format (aggregated data)
3740
3842
  "boxplot",
3741
3843
  "violinplot",
3742
3844
  "stripplot",
@@ -3745,23 +3847,23 @@ def df_preprocessing_(data, kind, verbose=False):
3745
3847
  "lineplot",
3746
3848
  "scatterplot",
3747
3849
  "relplot",
3748
- "barplot", # Can also work with long format (aggregated data in long form)
3749
- "boxenplot", # Similar to boxplot, works with long format
3750
- "countplot", # Works best with long format (categorical data)
3751
- "heatmap", # Can work with long format after reshaping
3752
- "lineplot", # Can work with long format (time series, continuous)
3753
- "histplot", # Can be used with both wide and long formats
3754
- "kdeplot", # Works with both wide and long formats
3755
- "ecdfplot", # Works with both formats
3756
- "scatterplot", # Can work with both formats depending on data structure
3757
- "lineplot", # Can work with both wide and long formats
3758
- "area plot", # Can work with both formats, useful for stacked areas
3850
+ "barplot", # Can also work with long format (aggregated data in long form)
3851
+ "boxenplot", # Similar to boxplot, works with long format
3852
+ "countplot", # Works best with long format (categorical data)
3853
+ "heatmap", # Can work with long format after reshaping
3854
+ "lineplot", # Can work with long format (time series, continuous)
3855
+ "histplot", # Can be used with both wide and long formats
3856
+ "kdeplot", # Works with both wide and long formats
3857
+ "ecdfplot", # Works with both formats
3858
+ "scatterplot", # Can work with both formats depending on data structure
3859
+ "lineplot", # Can work with both wide and long formats
3860
+ "area plot", # Can work with both formats, useful for stacked areas
3759
3861
  "violinplot", # Can work with both formats depending on categorical vs continuous data
3760
- "ellipse",# ellipse plot, default confidence=0.95
3862
+ "ellipse", # ellipse plot, default confidence=0.95
3761
3863
  ],
3762
3864
  )[0]
3763
3865
 
3764
- wide_kinds = [
3866
+ wide_kinds = [
3765
3867
  "pairplot",
3766
3868
  ]
3767
3869
 
@@ -3773,26 +3875,26 @@ def df_preprocessing_(data, kind, verbose=False):
3773
3875
  # Flexible kinds: distribution plots can use either format
3774
3876
  flexible_kinds = [
3775
3877
  "jointplot", # Typically requires wide format for axis variables
3776
- "lineplot", # Can work with long format (time series, continuous)
3878
+ "lineplot", # Can work with long format (time series, continuous)
3777
3879
  "lineplot",
3778
3880
  "scatterplot",
3779
- "barplot", # Can also work with long format (aggregated data in long form)
3780
- "boxenplot", # Similar to boxplot, works with long format
3781
- "countplot", # Works best with long format (categorical data)
3881
+ "barplot", # Can also work with long format (aggregated data in long form)
3882
+ "boxenplot", # Similar to boxplot, works with long format
3883
+ "countplot", # Works best with long format (categorical data)
3782
3884
  "regplot",
3783
3885
  "violinplot",
3784
3886
  "stripplot",
3785
3887
  "swarmplot",
3786
3888
  "boxplot",
3787
- "histplot", # Can be used with both wide and long formats
3788
- "kdeplot", # Works with both wide and long formats
3789
- "ecdfplot", # Works with both formats
3790
- "scatterplot", # Can work with both formats depending on data structure
3791
- "lineplot", # Can work with both wide and long formats
3792
- "area plot", # Can work with both formats, useful for stacked areas
3889
+ "histplot", # Can be used with both wide and long formats
3890
+ "kdeplot", # Works with both wide and long formats
3891
+ "ecdfplot", # Works with both formats
3892
+ "scatterplot", # Can work with both formats depending on data structure
3893
+ "lineplot", # Can work with both wide and long formats
3894
+ "area plot", # Can work with both formats, useful for stacked areas
3793
3895
  "violinplot", # Can work with both formats depending on categorical vs continuous data
3794
3896
  "relplot",
3795
- "pointplot", # Works well with wide format
3897
+ "pointplot", # Works well with wide format
3796
3898
  "ellipse",
3797
3899
  ]
3798
3900
 
@@ -4296,17 +4398,17 @@ def venn(
4296
4398
  colors = ["r", "b"] if len(lists) == 2 else ["r", "g", "b"]
4297
4399
  if labels is None:
4298
4400
  if len(lists) == 2:
4299
- labels = ["set1", "set2"]
4300
- elif len(lists) == 3:
4401
+ labels = ["set1", "set2"]
4402
+ elif len(lists) == 3:
4301
4403
  labels = ["set1", "set2", "set3"]
4302
- elif len(lists) == 4:
4303
- labels = ["set1", "set2", "set3","set4"]
4304
- elif len(lists) == 5:
4305
- labels = ["set1", "set2", "set3","set4","set55"]
4306
- elif len(lists) == 6:
4307
- labels = ["set1", "set2", "set3","set4","set5","set6"]
4308
- elif len(lists) == 7:
4309
- labels = ["set1", "set2", "set3","set4","set5","set6","set7"]
4404
+ elif len(lists) == 4:
4405
+ labels = ["set1", "set2", "set3", "set4"]
4406
+ elif len(lists) == 5:
4407
+ labels = ["set1", "set2", "set3", "set4", "set55"]
4408
+ elif len(lists) == 6:
4409
+ labels = ["set1", "set2", "set3", "set4", "set5", "set6"]
4410
+ elif len(lists) == 7:
4411
+ labels = ["set1", "set2", "set3", "set4", "set5", "set6", "set7"]
4310
4412
  if edgecolor is None:
4311
4413
  edgecolor = colors
4312
4414
  colors = [desaturate_color(color, saturation) for color in colors]
@@ -4320,10 +4422,11 @@ def venn(
4320
4422
  if show_percentages
4321
4423
  else f"{subset_count}"
4322
4424
  )
4425
+
4323
4426
  if fmt is not None:
4324
4427
  if not fmt.startswith("{"):
4325
- fmt="{:" + fmt + "}"
4326
- if len(lists) == 2:
4428
+ fmt = "{:" + fmt + "}"
4429
+ if len(lists) == 2:
4327
4430
 
4328
4431
  from matplotlib_venn import venn2, venn2_circles
4329
4432
 
@@ -4345,7 +4448,7 @@ def venn(
4345
4448
  # v.get_label_by_id('10').set_text(len(set1 - set2))
4346
4449
  # v.get_label_by_id('01').set_text(len(set2 - set1))
4347
4450
  # v.get_label_by_id('11').set_text(len(set1 & set2))
4348
-
4451
+
4349
4452
  v.get_label_by_id("10").set_text(
4350
4453
  get_count_and_percentage(universe, len(set1 - set2))
4351
4454
  )
@@ -4447,7 +4550,7 @@ def venn(
4447
4550
  if "none" in edgecolor or 0 in linewidth:
4448
4551
  patch.set_edgecolor("none")
4449
4552
  return ax
4450
- elif len(lists) == 3:
4553
+ elif len(lists) == 3:
4451
4554
 
4452
4555
  from matplotlib_venn import venn3, venn3_circles
4453
4556
 
@@ -4463,32 +4566,46 @@ def venn(
4463
4566
  # Draw the venn diagram
4464
4567
  v = venn3(subsets=lists, set_labels=labels, ax=ax, **kwargs)
4465
4568
  v.get_patch_by_id("100").set_color(colors[0])
4466
- v.get_label_by_id("100").set_text(get_count_and_percentage(universe, len(set1 - set2 - set3)))
4569
+ v.get_label_by_id("100").set_text(
4570
+ get_count_and_percentage(universe, len(set1 - set2 - set3))
4571
+ )
4467
4572
  v.get_patch_by_id("010").set_color(colors[1])
4468
- v.get_label_by_id("010").set_text(get_count_and_percentage(universe, len(set2 - set1 - set3)))
4573
+ v.get_label_by_id("010").set_text(
4574
+ get_count_and_percentage(universe, len(set2 - set1 - set3))
4575
+ )
4469
4576
  try:
4470
4577
  v.get_patch_by_id("001").set_color(colors[2])
4471
- v.get_label_by_id("001").set_text(get_count_and_percentage(universe, len(set3 - set1 - set2)))
4578
+ v.get_label_by_id("001").set_text(
4579
+ get_count_and_percentage(universe, len(set3 - set1 - set2))
4580
+ )
4472
4581
  except Exception as e:
4473
4582
  print(e)
4474
4583
  try:
4475
4584
  v.get_patch_by_id("110").set_color(colorAB)
4476
- v.get_label_by_id("110").set_text(get_count_and_percentage(universe, len(set1 & set2 - set3)))
4585
+ v.get_label_by_id("110").set_text(
4586
+ get_count_and_percentage(universe, len(set1 & set2 - set3))
4587
+ )
4477
4588
  except Exception as e:
4478
4589
  print(e)
4479
4590
  try:
4480
4591
  v.get_patch_by_id("101").set_color(colorAC)
4481
- v.get_label_by_id("101").set_text(get_count_and_percentage(universe, len(set1 & set3 - set2)))
4592
+ v.get_label_by_id("101").set_text(
4593
+ get_count_and_percentage(universe, len(set1 & set3 - set2))
4594
+ )
4482
4595
  except Exception as e:
4483
4596
  print(e)
4484
4597
  try:
4485
4598
  v.get_patch_by_id("011").set_color(colorBC)
4486
- v.get_label_by_id("011").set_text(get_count_and_percentage(universe, len(set2 & set3 - set1)))
4599
+ v.get_label_by_id("011").set_text(
4600
+ get_count_and_percentage(universe, len(set2 & set3 - set1))
4601
+ )
4487
4602
  except Exception as e:
4488
4603
  print(e)
4489
4604
  try:
4490
4605
  v.get_patch_by_id("111").set_color(colorABC)
4491
- v.get_label_by_id("111").set_text(get_count_and_percentage(universe, len(set1 & set2 & set3)))
4606
+ v.get_label_by_id("111").set_text(
4607
+ get_count_and_percentage(universe, len(set1 & set2 & set3))
4608
+ )
4492
4609
  except Exception as e:
4493
4610
  print(e)
4494
4611
 
@@ -4603,26 +4720,24 @@ def venn(
4603
4720
  patch.set_edgecolor("none")
4604
4721
  return ax
4605
4722
 
4606
-
4607
4723
  dict_data = {}
4608
4724
  for i_list, list_ in enumerate(lists):
4609
- dict_data[labels[i_list]]={*list_}
4725
+ dict_data[labels[i_list]] = {*list_}
4610
4726
 
4611
- if 3<len(lists)<6:
4727
+ if 3 < len(lists) < 6:
4612
4728
  from venn import venn as vn
4613
4729
 
4614
- legend_loc=kwargs.pop("legend_loc", "upper right")
4615
- ax=vn(dict_data,ax=ax,legend_loc=legend_loc,**kwargs)
4730
+ legend_loc = kwargs.pop("legend_loc", "upper right")
4731
+ ax = vn(dict_data, ax=ax, legend_loc=legend_loc, **kwargs)
4616
4732
 
4617
4733
  return ax
4618
4734
  else:
4619
4735
  from venn import pseudovenn
4620
- cmap=kwargs.pop("cmap","plasma")
4621
- ax=pseudovenn(dict_data, cmap=cmap,ax=ax,**kwargs)
4622
-
4623
- return ax
4624
4736
 
4737
+ cmap = kwargs.pop("cmap", "plasma")
4738
+ ax = pseudovenn(dict_data, cmap=cmap, ax=ax, **kwargs)
4625
4739
 
4740
+ return ax
4626
4741
 
4627
4742
 
4628
4743
  #! subplots, support automatic extend new axis
@@ -4670,7 +4785,7 @@ def subplot(
4670
4785
  col_first_axes = [None] * cols # Track the first axis in each column (for sharex)
4671
4786
 
4672
4787
  def expand_ax():
4673
- nonlocal rows, grid_spec,cols,row_first_axes,fig,figsize,figsize_recommend
4788
+ nonlocal rows, grid_spec, cols, row_first_axes, fig, figsize, figsize_recommend
4674
4789
  # fig_height = fig.get_figheight()
4675
4790
  # subplot_height = fig_height / rows
4676
4791
  rows += 1 # Expands by adding a row
@@ -4678,10 +4793,11 @@ def subplot(
4678
4793
  fig.set_size_inches(figsize)
4679
4794
  grid_spec = GridSpec(rows, cols, figure=fig)
4680
4795
  row_first_axes.append(None)
4681
- figsize_recommend=f"Warning: 建议设置 subplot({rows}, {cols})"
4796
+ figsize_recommend = f"Warning: 建议设置 subplot({rows}, {cols})"
4682
4797
  print(figsize_recommend)
4798
+
4683
4799
  def nexttile(rowspan=1, colspan=1, **kwargs):
4684
- nonlocal rows, cols, occupied, grid_spec,fig,figsize_recommend
4800
+ nonlocal rows, cols, occupied, grid_spec, fig, figsize_recommend
4685
4801
  for row in range(rows):
4686
4802
  for col in range(cols):
4687
4803
  if all(
@@ -4693,7 +4809,7 @@ def subplot(
4693
4809
  else:
4694
4810
  continue
4695
4811
  break
4696
-
4812
+
4697
4813
  else:
4698
4814
  expand_ax()
4699
4815
  return nexttile(rowspan=rowspan, colspan=colspan, **kwargs)
@@ -4715,12 +4831,13 @@ def subplot(
4715
4831
  row_first_axes[row] = ax
4716
4832
  if col_first_axes[col] is None:
4717
4833
  col_first_axes[col] = ax
4718
-
4834
+
4719
4835
  for r in range(row, row + rowspan):
4720
4836
  for c in range(col, col + colspan):
4721
4837
  occupied.add((r, c))
4722
-
4838
+
4723
4839
  return ax
4840
+
4724
4841
  return nexttile
4725
4842
 
4726
4843
 
@@ -4743,7 +4860,7 @@ def radar(
4743
4860
  bg_color="0.8",
4744
4861
  bg_alpha=None,
4745
4862
  grid_interval_ratio=0.2,
4746
- show_value=False,# show text for each value
4863
+ show_value=False, # show text for each value
4747
4864
  cmap=None,
4748
4865
  legend_loc="upper right",
4749
4866
  legend_fontsize=10,
@@ -4777,9 +4894,9 @@ def radar(
4777
4894
  index=["HP", "MP", "ATK", "DEF", "SP.ATK", "SP.DEF", "SPD"],
4778
4895
  columns=["Hero", "Warrior", "Wizard"],
4779
4896
  )
4780
- usage 1:
4897
+ usage 1:
4781
4898
  radar(data=df)
4782
- usage 2:
4899
+ usage 2:
4783
4900
  radar(data=df["Wizard"])
4784
4901
  usage 3:
4785
4902
  radar(data=df, columns="Wizard")
@@ -4819,8 +4936,8 @@ def radar(
4819
4936
  - sp (int): Padding for the ticks from the plot area.
4820
4937
  - **kwargs: Additional arguments for customization.
4821
4938
  """
4822
- if run_once_within(20,reverse=True) and verbose:
4823
- usage_="""usage:
4939
+ if run_once_within(20, reverse=True) and verbose:
4940
+ usage_ = """usage:
4824
4941
  radar(
4825
4942
  data: pd.DataFrame, #The data to plot. Each column corresponds to a variable, and each row represents a data point.
4826
4943
  ylim=(0, 100),# ylim (tuple): The limits of the radial axis (y-axis). Default is (0, 100).
@@ -4864,20 +4981,20 @@ def radar(
4864
4981
  kws_figsets = v_arg
4865
4982
  kwargs.pop(k_arg, None)
4866
4983
  break
4867
- if axis==1:
4868
- data=data.T
4984
+ if axis == 1:
4985
+ data = data.T
4869
4986
  if isinstance(data, dict):
4870
4987
  data = pd.DataFrame(pd.Series(data))
4871
4988
  if ~isinstance(data, pd.DataFrame):
4872
- data=pd.DataFrame(data)
4989
+ data = pd.DataFrame(data)
4873
4990
  if isinstance(data, pd.DataFrame):
4874
- data=data.select_dtypes(include=np.number)
4875
- if isinstance(columns,str):
4876
- columns=[columns]
4991
+ data = data.select_dtypes(include=np.number)
4992
+ if isinstance(columns, str):
4993
+ columns = [columns]
4877
4994
  if columns is None:
4878
4995
  columns = list(data.columns)
4879
- data=data[columns]
4880
- categories = list(data.index)
4996
+ data = data[columns]
4997
+ categories = list(data.index)
4881
4998
  num_vars = len(categories)
4882
4999
 
4883
5000
  # Set y-axis limits and grid intervals
@@ -4928,7 +5045,9 @@ def radar(
4928
5045
  else:
4929
5046
  # * spider style: spider-style grid (straight lines, not circles)
4930
5047
  # Create the spider-style grid (straight lines, not circles)
4931
- for i in range(1, int((vmax-vmin)/ ((vmax-vmin)*grid_interval_ratio))+1):#int(vmax * grid_interval_ratio) + 1):
5048
+ for i in range(
5049
+ 1, int((vmax - vmin) / ((vmax - vmin) * grid_interval_ratio)) + 1
5050
+ ): # int(vmax * grid_interval_ratio) + 1):
4932
5051
  ax.plot(
4933
5052
  angles + [angles[0]], # Closing the loop
4934
5053
  [i * vmax * grid_interval_ratio] * (num_vars + 1)
@@ -4962,8 +5081,8 @@ def radar(
4962
5081
  ax.tick_params(axis="y", pad=sp) # move spines outward
4963
5082
  # colors
4964
5083
  if facecolor is not None:
4965
- if not isinstance(facecolor,list):
4966
- facecolor=[facecolor]
5084
+ if not isinstance(facecolor, list):
5085
+ facecolor = [facecolor]
4967
5086
  colors = facecolor
4968
5087
  else:
4969
5088
  colors = (
@@ -4987,44 +5106,44 @@ def radar(
4987
5106
  )
4988
5107
  ax.fill(angles, values, color=colors[i], alpha=alpha)
4989
5108
  # Add text labels for each value at each angle
4990
- labeled_points = set() #这样同一个点就不会标多次了
5109
+ labeled_points = set() # 这样同一个点就不会标多次了
4991
5110
  if show_value:
4992
5111
  for angle, value in zip(angles, values):
4993
5112
  if (angle, value) not in labeled_points:
4994
5113
  # offset_radius = value * value_offset
4995
5114
  lim_ = np.max(values)
4996
- sep_in = lim_/5
4997
- sep_low=sep_in*2
4998
- sep_med=sep_in*3
4999
- sep_hig=sep_in*4
5000
- sep_out=lim_*5
5001
- if value<sep_in:
5115
+ sep_in = lim_ / 5
5116
+ sep_low = sep_in * 2
5117
+ sep_med = sep_in * 3
5118
+ sep_hig = sep_in * 4
5119
+ sep_out = lim_ * 5
5120
+ if value < sep_in:
5002
5121
  offset_radius = value * 0.7
5003
- elif value<sep_low:
5122
+ elif value < sep_low:
5004
5123
  offset_radius = value * 0.8
5005
- elif sep_low<=value<sep_med:
5124
+ elif sep_low <= value < sep_med:
5006
5125
  offset_radius = value * 0.85
5007
- elif sep_med<=value<sep_hig:
5126
+ elif sep_med <= value < sep_hig:
5008
5127
  offset_radius = value * 0.9
5009
- elif sep_hig<=value<sep_out:
5128
+ elif sep_hig <= value < sep_out:
5010
5129
  offset_radius = value * 0.93
5011
5130
  else:
5012
5131
  offset_radius = value * 0.98
5013
5132
  ax.text(
5014
- angle,
5133
+ angle,
5015
5134
  offset_radius,
5016
- f"{value:{fmt}}",
5017
- ha="center",
5018
- va="center",
5019
- fontsize=fontsize,
5135
+ f"{value:{fmt}}",
5136
+ ha="center",
5137
+ va="center",
5138
+ fontsize=fontsize,
5020
5139
  color=fontcolor,
5021
- zorder=11
5140
+ zorder=11,
5022
5141
  )
5023
5142
  labeled_points.add((angle, value))
5024
5143
 
5025
5144
  ax.set_ylim(ylim)
5026
5145
  # Add markers for each data point
5027
- for i, (col, val) in enumerate(data.items()):
5146
+ for i, (col, val) in enumerate(data.items()):
5028
5147
  ax.plot(
5029
5148
  angles,
5030
5149
  list(val) + [val[0]], # Close the loop for markers
@@ -5055,15 +5174,15 @@ def radar(
5055
5174
 
5056
5175
 
5057
5176
  def pie(
5058
- data:pd.Series,
5059
- columns:list = None,
5177
+ data: pd.Series,
5178
+ columns: list = None,
5060
5179
  facecolor=None,
5061
5180
  explode=[0.1],
5062
5181
  startangle=90,
5063
5182
  shadow=True,
5064
5183
  fontcolor="k",
5065
- fmt=".2f",
5066
- width=None,# the center blank
5184
+ fmt=".2f",
5185
+ width=None, # the center blank
5067
5186
  pctdistance=0.85,
5068
5187
  labeldistance=1.1,
5069
5188
  kws_wedge={},
@@ -5077,9 +5196,9 @@ def pie(
5077
5196
  edgewidth=1,
5078
5197
  cmap=None,
5079
5198
  show_value=False,
5080
- show_label=True,# False: only show the outer layer, if it is None, not show
5081
- expand_label=(1.2,1.2),
5082
- kws_bbox={},#dict(facecolor="none", alpha=0.5, edgecolor="black", boxstyle="round,pad=0.3"), # '{}' to hide
5199
+ show_label=True, # False: only show the outer layer, if it is None, not show
5200
+ expand_label=(1.2, 1.2),
5201
+ kws_bbox={}, # dict(facecolor="none", alpha=0.5, edgecolor="black", boxstyle="round,pad=0.3"), # '{}' to hide
5083
5202
  show_legend=True,
5084
5203
  legend_loc="upper right",
5085
5204
  bbox_to_anchor=[1.4, 1.1],
@@ -5087,11 +5206,12 @@ def pie(
5087
5206
  rotation_correction=0,
5088
5207
  verbose=True,
5089
5208
  ax=None,
5090
- **kwargs
5091
- ):
5209
+ **kwargs,
5210
+ ):
5092
5211
  from adjustText import adjust_text
5093
- if run_once_within(20,reverse=True) and verbose:
5094
- usage_="""usage:
5212
+
5213
+ if run_once_within(20, reverse=True) and verbose:
5214
+ usage_ = """usage:
5095
5215
  pie(
5096
5216
  data:pd.Series,
5097
5217
  columns:list = None,
@@ -5190,45 +5310,45 @@ def pie(
5190
5310
  if isinstance(data, dict):
5191
5311
  data = pd.DataFrame(pd.Series(data))
5192
5312
  if ~isinstance(data, pd.DataFrame):
5193
- data=pd.DataFrame(data)
5313
+ data = pd.DataFrame(data)
5194
5314
 
5195
5315
  if isinstance(data, pd.DataFrame):
5196
- data=data.select_dtypes(include=np.number)
5197
- if isinstance(columns,str):
5198
- columns=[columns]
5316
+ data = data.select_dtypes(include=np.number)
5317
+ if isinstance(columns, str):
5318
+ columns = [columns]
5199
5319
  if columns is None:
5200
5320
  columns = list(data.columns)
5201
5321
  # data=data[columns]
5202
5322
  # columns = list(data.columns)
5203
5323
  # print(columns)
5204
5324
  # 选择部分数据
5205
- df=data[columns]
5325
+ df = data[columns]
5206
5326
 
5207
5327
  if not isinstance(explode, list):
5208
- explode=[explode]
5328
+ explode = [explode]
5209
5329
  if explode == [None]:
5210
- explode=[0]
5330
+ explode = [0]
5211
5331
 
5212
5332
  if width is None:
5213
- if df.shape[1]>1:
5214
- width=1/(df.shape[1]+2)
5333
+ if df.shape[1] > 1:
5334
+ width = 1 / (df.shape[1] + 2)
5215
5335
  else:
5216
- width=1
5217
- if isinstance(width,(float,int)):
5218
- width=[width]
5219
- if len(width)<df.shape[1]:
5220
- width=width*df.shape[1]
5221
- if isinstance(radius,(float,int)):
5222
- radius=[radius]
5223
- radius_tile=[1]*df.shape[1]
5224
- radius=radius_tile.copy()
5225
- for i in range(1,df.shape[1]):
5226
- radius[i]=radius_tile[i]-np.sum(width[:i])
5336
+ width = 1
5337
+ if isinstance(width, (float, int)):
5338
+ width = [width]
5339
+ if len(width) < df.shape[1]:
5340
+ width = width * df.shape[1]
5341
+ if isinstance(radius, (float, int)):
5342
+ radius = [radius]
5343
+ radius_tile = [1] * df.shape[1]
5344
+ radius = radius_tile.copy()
5345
+ for i in range(1, df.shape[1]):
5346
+ radius[i] = radius_tile[i] - np.sum(width[:i])
5227
5347
 
5228
5348
  # colors
5229
5349
  if facecolor is not None:
5230
- if not isinstance(facecolor,list):
5231
- facecolor=[facecolor]
5350
+ if not isinstance(facecolor, list):
5351
+ facecolor = [facecolor]
5232
5352
  colors = facecolor
5233
5353
  else:
5234
5354
  colors = (
@@ -5239,38 +5359,38 @@ def pie(
5239
5359
  # to check if facecolor is nested list or not
5240
5360
  is_nested = True if any(isinstance(i, list) for i in colors) else False
5241
5361
  inested = 0
5242
- for column_,width_,radius_ in zip(columns, width,radius):
5243
- if column_!=columns[0]:
5362
+ for column_, width_, radius_ in zip(columns, width, radius):
5363
+ if column_ != columns[0]:
5244
5364
  labels = data.index if show_label else None
5245
5365
  else:
5246
5366
  labels = data.index if show_label is not None else None
5247
5367
  data = df[column_]
5248
- labels_legend=data.index
5368
+ labels_legend = data.index
5249
5369
  sizes = data.values
5250
-
5370
+
5251
5371
  # Set wedge and text properties if none are provided
5252
5372
  kws_wedge = kws_wedge or {"edgecolor": edgecolor, "linewidth": edgewidth}
5253
- kws_wedge.update({"width":width_})
5254
- fontcolor=kws_text.get("color",fontcolor)
5255
- fontsize=kws_text.get("fontsize",fontsize)
5373
+ kws_wedge.update({"width": width_})
5374
+ fontcolor = kws_text.get("color", fontcolor)
5375
+ fontsize = kws_text.get("fontsize", fontsize)
5256
5376
  kws_text.update({"color": fontcolor, "fontsize": fontsize})
5257
5377
 
5258
- if ax is None:
5259
- ax=plt.gca()
5260
- if len(explode)<len(labels_legend):
5261
- explode.extend([0]*(len(labels_legend)-len(explode)))
5378
+ if ax is None:
5379
+ ax = plt.gca()
5380
+ if len(explode) < len(labels_legend):
5381
+ explode.extend([0] * (len(labels_legend) - len(explode)))
5262
5382
  print(explode)
5263
5383
  if fmt:
5264
5384
  if not fmt.startswith("%"):
5265
- autopct =f"%{fmt}%%"
5385
+ autopct = f"%{fmt}%%"
5266
5386
  else:
5267
- autopct=None
5387
+ autopct = None
5268
5388
 
5269
5389
  if show_value is None:
5270
5390
  result = ax.pie(
5271
5391
  sizes,
5272
5392
  labels=labels,
5273
- autopct= None,
5393
+ autopct=None,
5274
5394
  startangle=startangle + rotation_correction,
5275
5395
  explode=explode,
5276
5396
  colors=colors[inested] if is_nested else colors,
@@ -5282,7 +5402,7 @@ def pie(
5282
5402
  center=center,
5283
5403
  radius=radius_,
5284
5404
  frame=frame,
5285
- **kwargs
5405
+ **kwargs,
5286
5406
  )
5287
5407
  else:
5288
5408
  result = ax.pie(
@@ -5292,7 +5412,7 @@ def pie(
5292
5412
  startangle=startangle + rotation_correction,
5293
5413
  explode=explode,
5294
5414
  colors=colors[inested] if is_nested else colors,
5295
- shadow=shadow,#shadow,
5415
+ shadow=shadow, # shadow,
5296
5416
  pctdistance=pctdistance,
5297
5417
  labeldistance=labeldistance,
5298
5418
  wedgeprops=kws_wedge,
@@ -5300,13 +5420,13 @@ def pie(
5300
5420
  center=center,
5301
5421
  radius=radius_,
5302
5422
  frame=frame,
5303
- **kwargs
5423
+ **kwargs,
5304
5424
  )
5305
5425
  if len(result) == 3:
5306
5426
  wedges, texts, autotexts = result
5307
5427
  elif len(result) == 2:
5308
5428
  wedges, texts = result
5309
- autotexts = None
5429
+ autotexts = None
5310
5430
  #! adjust_text
5311
5431
  if autotexts or texts:
5312
5432
  all_texts = []
@@ -5314,30 +5434,38 @@ def pie(
5314
5434
  all_texts.extend(autotexts)
5315
5435
  if texts and show_label:
5316
5436
  all_texts.extend(texts)
5317
-
5437
+
5318
5438
  adjust_text(
5319
5439
  all_texts,
5320
5440
  ax=ax,
5321
- arrowprops=kws_arrow,#dict(arrowstyle="-", color="gray", lw=0.5),
5441
+ arrowprops=kws_arrow, # dict(arrowstyle="-", color="gray", lw=0.5),
5322
5442
  bbox=kws_bbox if kws_bbox else None,
5323
5443
  expand=expand_label,
5324
5444
  fontdict={
5325
- "fontsize": fontsize,
5326
- "color": fontcolor,
5327
- },
5445
+ "fontsize": fontsize,
5446
+ "color": fontcolor,
5447
+ },
5328
5448
  )
5329
5449
  # Show exact values on wedges if show_value is True
5330
5450
  if show_value:
5331
5451
  for i, (wedge, txt) in enumerate(zip(wedges, texts)):
5332
5452
  angle = (wedge.theta2 - wedge.theta1) / 2 + wedge.theta1
5333
- x = np.cos(np.radians(angle)) * (pctdistance ) * radius_
5334
- y = np.sin(np.radians(angle)) * (pctdistance ) * radius_
5453
+ x = np.cos(np.radians(angle)) * (pctdistance) * radius_
5454
+ y = np.sin(np.radians(angle)) * (pctdistance) * radius_
5335
5455
  if not fmt.startswith("{"):
5336
5456
  value_text = f"{sizes[i]:{fmt}}"
5337
5457
  else:
5338
- value_text = fmt.format(sizes[i])
5339
- ax.text(x, y, value_text, ha="center", va="center", fontsize=fontsize,color=fontcolor)
5340
- inested+=1
5458
+ value_text = fmt.format(sizes[i])
5459
+ ax.text(
5460
+ x,
5461
+ y,
5462
+ value_text,
5463
+ ha="center",
5464
+ va="center",
5465
+ fontsize=fontsize,
5466
+ color=fontcolor,
5467
+ )
5468
+ inested += 1
5341
5469
  # Customize the legend
5342
5470
  if show_legend:
5343
5471
  ax.legend(
@@ -5347,10 +5475,11 @@ def pie(
5347
5475
  bbox_to_anchor=bbox_to_anchor,
5348
5476
  fontsize=legend_fontsize,
5349
5477
  title_fontsize=legend_fontsize,
5350
- )
5478
+ )
5351
5479
  ax.set(aspect="equal")
5352
5480
  return ax
5353
5481
 
5482
+
5354
5483
  def ellipse(
5355
5484
  data,
5356
5485
  x=None,
@@ -5363,7 +5492,7 @@ def ellipse(
5363
5492
  palette=None,
5364
5493
  facecolor=None,
5365
5494
  edgecolor=None,
5366
- label:bool=True,
5495
+ label: bool = True,
5367
5496
  **kwargs,
5368
5497
  ):
5369
5498
  """
@@ -5438,8 +5567,8 @@ def ellipse(
5438
5567
  groups = [None]
5439
5568
  color_map = {None: kwargs.get("edgecolor", "blue")}
5440
5569
  alpha = kwargs.pop("alpha", 0.2)
5441
- edgecolor=kwargs.pop("edgecolor", None)
5442
- facecolor=kwargs.pop("facecolor", None)
5570
+ edgecolor = kwargs.pop("edgecolor", None)
5571
+ facecolor = kwargs.pop("facecolor", None)
5443
5572
  for group in groups:
5444
5573
  group_data = data[data[hue] == group] if hue else data
5445
5574
 
@@ -5477,7 +5606,7 @@ def ellipse(
5477
5606
  height=height,
5478
5607
  angle=angle,
5479
5608
  edgecolor=edgecolor_,
5480
- facecolor=(facecolor_, alpha), #facecolor_, # only work on facecolor
5609
+ facecolor=(facecolor_, alpha), # facecolor_, # only work on facecolor
5481
5610
  # alpha=alpha,
5482
5611
  label=group if (hue and label) else None,
5483
5612
  **kwargs,
@@ -5504,6 +5633,7 @@ def ellipse(
5504
5633
 
5505
5634
  return ax
5506
5635
 
5636
+
5507
5637
  def ppi(
5508
5638
  interactions,
5509
5639
  player1="preferredName_A",
@@ -5511,39 +5641,39 @@ def ppi(
5511
5641
  weight="score",
5512
5642
  n_layers=None, # Number of concentric layers
5513
5643
  n_rank=[5, 10], # Nodes in each rank for the concentric layout
5514
- dist_node = 10, # Distance between each rank of circles
5515
- layout="degree",
5516
- size=None,#700,
5517
- sizes=(50,500),# min and max of size
5644
+ dist_node=10, # Distance between each rank of circles
5645
+ layout="degree",
5646
+ size=None, # 700,
5647
+ sizes=(50, 500), # min and max of size
5518
5648
  facecolor="skyblue",
5519
- cmap='coolwarm',
5649
+ cmap="coolwarm",
5520
5650
  edgecolor="k",
5521
5651
  edgelinewidth=1.5,
5522
- alpha=.5,
5523
- alphas=(0.1, 1.0),# min and max of alpha
5652
+ alpha=0.5,
5653
+ alphas=(0.1, 1.0), # min and max of alpha
5524
5654
  marker="o",
5525
5655
  node_hideticks=True,
5526
5656
  linecolor="gray",
5527
- line_cmap='coolwarm',
5657
+ line_cmap="coolwarm",
5528
5658
  linewidth=1.5,
5529
- linewidths=(0.5,5),# min and max of linewidth
5659
+ linewidths=(0.5, 5), # min and max of linewidth
5530
5660
  linealpha=1.0,
5531
- linealphas=(0.1,1.0),# min and max of linealpha
5661
+ linealphas=(0.1, 1.0), # min and max of linealpha
5532
5662
  linestyle="-",
5533
- line_arrowstyle='-',
5663
+ line_arrowstyle="-",
5534
5664
  fontsize=10,
5535
5665
  fontcolor="k",
5536
- ha:str="center",
5537
- va:str="center",
5666
+ ha: str = "center",
5667
+ va: str = "center",
5538
5668
  figsize=(12, 10),
5539
- k_value=0.3,
5669
+ k_value=0.3,
5540
5670
  bgcolor="w",
5541
5671
  dir_save="./ppi_network.html",
5542
5672
  physics=True,
5543
5673
  notebook=False,
5544
5674
  scale=1,
5545
5675
  ax=None,
5546
- **kwargs
5676
+ **kwargs,
5547
5677
  ):
5548
5678
  """
5549
5679
  Plot a Protein-Protein Interaction (PPI) network with adjustable appearance.
@@ -5560,14 +5690,14 @@ def ppi(
5560
5690
  )
5561
5691
  """
5562
5692
  from pyvis.network import Network
5563
- import networkx as nx
5693
+ import networkx as nx
5564
5694
  from IPython.display import IFrame
5565
5695
  from matplotlib.colors import Normalize
5566
5696
  from matplotlib import cm
5567
5697
  from . import ips
5568
5698
 
5569
5699
  if run_once_within():
5570
- usage_str="""
5700
+ usage_str = """
5571
5701
  ppi(
5572
5702
  interactions,
5573
5703
  player1="preferredName_A",
@@ -5611,17 +5741,19 @@ def ppi(
5611
5741
  ):
5612
5742
  """
5613
5743
  print(usage_str)
5614
-
5744
+
5615
5745
  # Check for required columns in the DataFrame
5616
5746
  for col in [player1, player2, weight]:
5617
5747
  if col not in interactions.columns:
5618
- raise ValueError(f"Column '{col}' is missing from the interactions DataFrame.")
5748
+ raise ValueError(
5749
+ f"Column '{col}' is missing from the interactions DataFrame."
5750
+ )
5619
5751
  interactions.sort_values(by=[weight], inplace=True)
5620
5752
  # Initialize Pyvis network
5621
5753
  net = Network(height="750px", width="100%", bgcolor=bgcolor, font_color=fontcolor)
5622
5754
  net.force_atlas_2based(
5623
5755
  gravity=-50, central_gravity=0.01, spring_length=100, spring_strength=0.1
5624
- )
5756
+ )
5625
5757
  net.toggle_physics(physics)
5626
5758
 
5627
5759
  kws_figsets = {}
@@ -5637,47 +5769,62 @@ def ppi(
5637
5769
  G.add_edge(row[player1], row[player2], weight=row[weight])
5638
5770
  # G = nx.from_pandas_edgelist(interactions, source=player1, target=player2, edge_attr=weight)
5639
5771
 
5640
-
5641
5772
  # Calculate node degrees
5642
5773
  degrees = dict(G.degree())
5643
5774
  norm = Normalize(vmin=min(degrees.values()), vmax=max(degrees.values()))
5644
5775
  colormap = cm.get_cmap(cmap) # Get the 'coolwarm' colormap
5645
5776
 
5646
- if not ips.isa(facecolor, 'color'):
5777
+ if not ips.isa(facecolor, "color"):
5647
5778
  print("facecolor: based on degrees")
5648
5779
  facecolor = [colormap(norm(deg)) for deg in degrees.values()] # Use colormap
5649
5780
  num_nodes = G.number_of_nodes()
5650
- #* size
5781
+ # * size
5651
5782
  # Set properties based on degrees
5652
- if not isinstance(size, (int,float,list)):
5783
+ if not isinstance(size, (int, float, list)):
5653
5784
  print("size: based on degrees")
5654
5785
  size = [deg * 50 for deg in degrees.values()] # Scale sizes
5655
- size = (size[:num_nodes] if len(size) > num_nodes else size) if isinstance(size, list) else [size] * num_nodes
5656
- if isinstance(size, list) and len(ips.flatten(size,verbose=False))!=1:
5786
+ size = (
5787
+ (size[:num_nodes] if len(size) > num_nodes else size)
5788
+ if isinstance(size, list)
5789
+ else [size] * num_nodes
5790
+ )
5791
+ if isinstance(size, list) and len(ips.flatten(size, verbose=False)) != 1:
5657
5792
  # Normalize sizes
5658
5793
  min_size, max_size = sizes # Use sizes tuple for min and max values
5659
5794
  min_degree, max_degree = min(size), max(size)
5660
5795
  if max_degree > min_degree: # Avoid division by zero
5661
5796
  size = [
5662
- min_size + (max_size - min_size) * (sz - min_degree) / (max_degree - min_degree)
5797
+ min_size
5798
+ + (max_size - min_size) * (sz - min_degree) / (max_degree - min_degree)
5663
5799
  for sz in size
5664
5800
  ]
5665
5801
  else:
5666
5802
  # If all values are the same, set them to a default of the midpoint
5667
5803
  size = [(min_size + max_size) / 2] * len(size)
5668
5804
 
5669
- #* facecolor
5670
- facecolor = (facecolor[:num_nodes] if len(facecolor) > num_nodes else facecolor) if isinstance(facecolor, list) else [facecolor] * num_nodes
5805
+ # * facecolor
5806
+ facecolor = (
5807
+ (facecolor[:num_nodes] if len(facecolor) > num_nodes else facecolor)
5808
+ if isinstance(facecolor, list)
5809
+ else [facecolor] * num_nodes
5810
+ )
5671
5811
  # * facealpha
5672
5812
  if isinstance(alpha, list):
5673
- alpha = (alpha[:num_nodes] if len(alpha) > num_nodes else alpha + [alpha[-1]] * (num_nodes - len(alpha)))
5813
+ alpha = (
5814
+ alpha[:num_nodes]
5815
+ if len(alpha) > num_nodes
5816
+ else alpha + [alpha[-1]] * (num_nodes - len(alpha))
5817
+ )
5674
5818
  min_alphas, max_alphas = alphas # Use alphas tuple for min and max values
5675
5819
  if len(alpha) > 0:
5676
5820
  # Normalize alpha based on the specified min and max
5677
5821
  min_alpha, max_alpha = min(alpha), max(alpha)
5678
5822
  if max_alpha > min_alpha: # Avoid division by zero
5679
5823
  alpha = [
5680
- min_alphas + (max_alphas - min_alphas) * (ea - min_alpha) / (max_alpha - min_alpha)
5824
+ min_alphas
5825
+ + (max_alphas - min_alphas)
5826
+ * (ea - min_alpha)
5827
+ / (max_alpha - min_alpha)
5681
5828
  for ea in alpha
5682
5829
  ]
5683
5830
  else:
@@ -5685,7 +5832,7 @@ def ppi(
5685
5832
  alpha = [(min_alphas + max_alphas) / 2] * len(alpha)
5686
5833
  else:
5687
5834
  # Default to a full opacity if no edges are provided
5688
- alpha = [1.0] * num_nodes
5835
+ alpha = [1.0] * num_nodes
5689
5836
  else:
5690
5837
  # If alpha is a single value, convert it to a list and normalize it
5691
5838
  alpha = [alpha] * num_nodes # Adjust based on alphas
@@ -5699,7 +5846,7 @@ def ppi(
5699
5846
  alpha=alpha[i],
5700
5847
  font={"size": fontsize, "color": fontcolor},
5701
5848
  )
5702
- print(f'nodes number: {i+1}')
5849
+ print(f"nodes number: {i+1}")
5703
5850
 
5704
5851
  for edge in G.edges(data=True):
5705
5852
  net.add_edge(
@@ -5718,11 +5865,11 @@ def ppi(
5718
5865
  "shell",
5719
5866
  "planar",
5720
5867
  "spiral",
5721
- "degree"
5868
+ "degree",
5722
5869
  ]
5723
5870
  layout = ips.strcmp(layout, layouts)[0]
5724
5871
  print(f"layout:{layout}, or select one in {layouts}")
5725
-
5872
+
5726
5873
  # Choose layout
5727
5874
  if layout == "spring":
5728
5875
  pos = nx.spring_layout(G, k=k_value)
@@ -5744,24 +5891,26 @@ def ppi(
5744
5891
  pos = nx.spring_layout(G, k=k_value)
5745
5892
  elif layout == "spiral":
5746
5893
  pos = nx.spiral_layout(G)
5747
- elif layout=='degree':
5894
+ elif layout == "degree":
5748
5895
  # Calculate node degrees and sort nodes by degree
5749
5896
  degrees = dict(G.degree())
5750
5897
  sorted_nodes = sorted(degrees.items(), key=lambda x: x[1], reverse=True)
5751
5898
  norm = Normalize(vmin=min(degrees.values()), vmax=max(degrees.values()))
5752
5899
  colormap = cm.get_cmap(cmap)
5753
-
5900
+
5754
5901
  # Create positions for concentric circles based on n_layers and n_rank
5755
5902
  pos = {}
5756
- n_layers=len(n_rank)+1 if n_layers is None else n_layers
5903
+ n_layers = len(n_rank) + 1 if n_layers is None else n_layers
5757
5904
  for rank_index in range(n_layers):
5758
5905
  if rank_index < len(n_rank):
5759
5906
  nodes_per_rank = n_rank[rank_index]
5760
- rank_nodes = sorted_nodes[sum(n_rank[:rank_index]): sum(n_rank[:rank_index + 1])]
5907
+ rank_nodes = sorted_nodes[
5908
+ sum(n_rank[:rank_index]) : sum(n_rank[: rank_index + 1])
5909
+ ]
5761
5910
  else:
5762
5911
  # 随机打乱剩余节点的顺序
5763
- remaining_nodes = sorted_nodes[sum(n_rank[:rank_index]):]
5764
- random_indices = np.random.permutation(len(remaining_nodes))
5912
+ remaining_nodes = sorted_nodes[sum(n_rank[:rank_index]) :]
5913
+ random_indices = np.random.permutation(len(remaining_nodes))
5765
5914
  rank_nodes = [remaining_nodes[i] for i in random_indices]
5766
5915
 
5767
5916
  radius = (rank_index + 1) * dist_node # Radius for this rank
@@ -5772,7 +5921,9 @@ def ppi(
5772
5921
  pos[node] = (radius * np.cos(angle), radius * np.sin(angle))
5773
5922
 
5774
5923
  else:
5775
- print(f"Unknown layout '{layout}', defaulting to 'spring',or可以用这些: {layouts}")
5924
+ print(
5925
+ f"Unknown layout '{layout}', defaulting to 'spring',or可以用这些: {layouts}"
5926
+ )
5776
5927
  pos = nx.spring_layout(G, k=k_value)
5777
5928
 
5778
5929
  for node, (x, y) in pos.items():
@@ -5781,8 +5932,8 @@ def ppi(
5781
5932
 
5782
5933
  # If ax is None, use plt.gca()
5783
5934
  if ax is None:
5784
- fig, ax = plt.subplots(1,1,figsize=figsize)
5785
-
5935
+ fig, ax = plt.subplots(1, 1, figsize=figsize)
5936
+
5786
5937
  # Draw nodes, edges, and labels with customization options
5787
5938
  nx.draw_networkx_nodes(
5788
5939
  G,
@@ -5794,14 +5945,18 @@ def ppi(
5794
5945
  edgecolors=edgecolor,
5795
5946
  alpha=alpha,
5796
5947
  hide_ticks=node_hideticks,
5797
- node_shape=marker
5948
+ node_shape=marker,
5798
5949
  )
5799
5950
 
5800
- #* linewidth
5951
+ # * linewidth
5801
5952
  if not isinstance(linewidth, list):
5802
5953
  linewidth = [linewidth] * G.number_of_edges()
5803
5954
  else:
5804
- linewidth = (linewidth[:G.number_of_edges()] if len(linewidth) > G.number_of_edges() else linewidth + [linewidth[-1]] * (G.number_of_edges() - len(linewidth)))
5955
+ linewidth = (
5956
+ linewidth[: G.number_of_edges()]
5957
+ if len(linewidth) > G.number_of_edges()
5958
+ else linewidth + [linewidth[-1]] * (G.number_of_edges() - len(linewidth))
5959
+ )
5805
5960
  # Normalize linewidth if it is a list
5806
5961
  if isinstance(linewidth, list):
5807
5962
  min_linewidth, max_linewidth = min(linewidth), max(linewidth)
@@ -5809,7 +5964,10 @@ def ppi(
5809
5964
  if max_linewidth > min_linewidth: # Avoid division by zero
5810
5965
  # Scale between vmin and vmax
5811
5966
  linewidth = [
5812
- vmin + (vmax - vmin) * (lw - min_linewidth) / (max_linewidth - min_linewidth)
5967
+ vmin
5968
+ + (vmax - vmin)
5969
+ * (lw - min_linewidth)
5970
+ / (max_linewidth - min_linewidth)
5813
5971
  for lw in linewidth
5814
5972
  ]
5815
5973
  else:
@@ -5818,8 +5976,8 @@ def ppi(
5818
5976
  else:
5819
5977
  # If linewidth is a single value, convert it to a list of that value
5820
5978
  linewidth = [linewidth] * G.number_of_edges()
5821
- #* linecolor
5822
- if not isinstance(linecolor, str):
5979
+ # * linecolor
5980
+ if not isinstance(linecolor, str):
5823
5981
  weights = [G[u][v]["weight"] for u, v in G.edges()]
5824
5982
  norm = Normalize(vmin=min(weights), vmax=max(weights))
5825
5983
  colormap = cm.get_cmap(line_cmap)
@@ -5829,45 +5987,58 @@ def ppi(
5829
5987
 
5830
5988
  # * linealpha
5831
5989
  if isinstance(linealpha, list):
5832
- linealpha = (linealpha[:G.number_of_edges()] if len(linealpha) > G.number_of_edges() else linealpha + [linealpha[-1]] * (G.number_of_edges() - len(linealpha)))
5990
+ linealpha = (
5991
+ linealpha[: G.number_of_edges()]
5992
+ if len(linealpha) > G.number_of_edges()
5993
+ else linealpha + [linealpha[-1]] * (G.number_of_edges() - len(linealpha))
5994
+ )
5833
5995
  min_alpha, max_alpha = linealphas # Use linealphas tuple for min and max values
5834
5996
  if len(linealpha) > 0:
5835
5997
  min_linealpha, max_linealpha = min(linealpha), max(linealpha)
5836
5998
  if max_linealpha > min_linealpha: # Avoid division by zero
5837
5999
  linealpha = [
5838
- min_alpha + (max_alpha - min_alpha) * (ea - min_linealpha) / (max_linealpha - min_linealpha)
6000
+ min_alpha
6001
+ + (max_alpha - min_alpha)
6002
+ * (ea - min_linealpha)
6003
+ / (max_linealpha - min_linealpha)
5839
6004
  for ea in linealpha
5840
6005
  ]
5841
6006
  else:
5842
6007
  linealpha = [(min_alpha + max_alpha) / 2] * len(linealpha)
5843
6008
  else:
5844
- linealpha = [1.0] * G.number_of_edges() # 如果设置有误,则将它设置成1.0
6009
+ linealpha = [1.0] * G.number_of_edges() # 如果设置有误,则将它设置成1.0
5845
6010
  else:
5846
6011
  linealpha = [linealpha] * G.number_of_edges() # Convert to list if single value
5847
6012
  nx.draw_networkx_edges(
5848
- G,
5849
- pos,
5850
- ax=ax,
5851
- edge_color=linecolor,
6013
+ G,
6014
+ pos,
6015
+ ax=ax,
6016
+ edge_color=linecolor,
5852
6017
  width=linewidth,
5853
6018
  style=linestyle,
5854
- arrowstyle=line_arrowstyle,
5855
- alpha=linealpha
6019
+ arrowstyle=line_arrowstyle,
6020
+ alpha=linealpha,
5856
6021
  )
5857
-
6022
+
5858
6023
  nx.draw_networkx_labels(
5859
- G, pos, ax=ax, font_size=fontsize, font_color=fontcolor,horizontalalignment=ha,verticalalignment=va
6024
+ G,
6025
+ pos,
6026
+ ax=ax,
6027
+ font_size=fontsize,
6028
+ font_color=fontcolor,
6029
+ horizontalalignment=ha,
6030
+ verticalalignment=va,
5860
6031
  )
5861
- figsets(ax=ax,**kws_figsets)
6032
+ figsets(ax=ax, **kws_figsets)
5862
6033
  ax.axis("off")
5863
6034
  if dir_save:
5864
6035
  if not os.path.basename(dir_save):
5865
- dir_save="_.html"
6036
+ dir_save = "_.html"
5866
6037
  net.write_html(dir_save)
5867
- nx.write_graphml(G, dir_save.replace(".html",".graphml")) # Export to GraphML
6038
+ nx.write_graphml(G, dir_save.replace(".html", ".graphml")) # Export to GraphML
5868
6039
  print(f"could be edited in Cytoscape \n{dir_save.replace(".html",".graphml")}")
5869
- ips.figsave(dir_save.replace(".html",".pdf"))
5870
- return G,ax
6040
+ ips.figsave(dir_save.replace(".html", ".pdf"))
6041
+ return G, ax
5871
6042
 
5872
6043
 
5873
6044
  def plot_map(
@@ -5886,6 +6057,7 @@ def plot_map(
5886
6057
  save_path=None, # Path to save the map in offline mode
5887
6058
  pydeck_map=False, # Whether to use pydeck for rendering (True for pydeck)
5888
6059
  pydeck_style="mapbox://styles/mapbox/streets-v11", # Map style for pydeck
6060
+ verbose=True, # show usage
5889
6061
  **kwargs, # Additional arguments for Folium Map
5890
6062
  ):
5891
6063
  """
@@ -5899,19 +6071,88 @@ def plot_map(
5899
6071
  df_tiles = pd.DataFrame({"tiles": tiles_support})
5900
6072
  fsave("....tiles.csv",df_tiles)
5901
6073
  """
6074
+ config_markers = """from folium import Icon
6075
+ # https://github.com/lennardv2/Leaflet.awesome-markers?tab=readme-ov-file
6076
+ markers = [
6077
+ {
6078
+ "location": [loc[0], loc[1]],
6079
+ "popup": "Center City",
6080
+ "tooltip": "Philadelphia",
6081
+ "icon": Icon(color="red", icon="flag"),
6082
+ },
6083
+ {
6084
+ "location": [loc[0], loc[1] + 0.05],
6085
+ "popup": "Rittenhouse Square",
6086
+ "tooltip": "A lovely park",
6087
+ "icon": Icon(
6088
+ color="purple", icon="flag", prefix="fa"
6089
+ ), # Purple marker with "star" icon (Font Awesome)
6090
+ },
6091
+ ]"""
6092
+ config_overlay = """
6093
+ from folium import Circle
6094
+
6095
+ circle = Circle(
6096
+ location=loc,
6097
+ radius=300, # In meters
6098
+ color="#EB686C",
6099
+ fill=True,
6100
+ fill_opacity=0.2,
6101
+ )
6102
+ markers = [
6103
+ {
6104
+ "location": [loc[0], loc[1]],
6105
+ "popup": "Center City",
6106
+ "tooltip": "Philadelphia",
6107
+ },
6108
+ {
6109
+ "location": [loc[0], loc[1] + 0.05],
6110
+ "popup": "Rittenhouse Square",
6111
+ "tooltip": "A lovely park",
6112
+ },
6113
+ ]
6114
+ plot_map(loc, overlays=[circle], zoom_start=14)
6115
+ """
6116
+ config_plugin = """
6117
+ from folium.plugins import HeatMap
6118
+ heat_data = [
6119
+ [48.54440975, 9.060237673391708, 1],
6120
+ [48.5421456, 9.057464182487431, 1],
6121
+ [48.54539175, 9.059915422200906, 1],
6122
+ ]
6123
+ heatmap = HeatMap(
6124
+ heat_data,
6125
+ radius=5, # Increase the radius of each point
6126
+ blur=5, # Adjust the blurring effect
6127
+ min_opacity=0.4, # Make the heatmap semi-transparent
6128
+ max_zoom=16, # Zoom level at which points appear
6129
+ gradient={ # Define a custom gradient
6130
+ 0.2: "blue",
6131
+ 0.4: "lime",
6132
+ 0.6: "yellow",
6133
+ 1.0: "#A34B00",
6134
+ },
6135
+ )
6136
+
6137
+ plot_map(loc, plugins=[heatmap])
6138
+ """
5902
6139
  from pathlib import Path
5903
6140
 
5904
6141
  # Get the current script's directory as a Path object
5905
6142
  current_directory = Path(__file__).resolve().parent
5906
6143
  if not "tiles_support" in locals():
5907
- tiles_support = fload(current_directory / "data" / "tiles.csv", verbose=0).iloc[:, 1].tolist()
5908
- tiles=strcmp(tiles, tiles_support)[0]
6144
+ tiles_support = (
6145
+ fload(current_directory / "data" / "tiles.csv", verbose=0)
6146
+ .iloc[:, 1]
6147
+ .tolist()
6148
+ )
6149
+ tiles = strcmp(tiles, tiles_support)[0]
5909
6150
  import folium
5910
6151
  import streamlit as st
5911
6152
  import pydeck as pdk
5912
6153
  from streamlit_folium import st_folium
5913
6154
  from folium.plugins import HeatMap
5914
-
6155
+
5915
6156
  if pydeck_map:
5916
6157
  view = pdk.ViewState(
5917
6158
  latitude=location[0],
@@ -5920,7 +6161,6 @@ def plot_map(
5920
6161
  pitch=0,
5921
6162
  )
5922
6163
 
5923
- # Example Layer (can be replaced by your custom layers)
5924
6164
  layer = pdk.Layer(
5925
6165
  "ScatterplotLayer",
5926
6166
  data=[{"lat": location[0], "lon": location[1]}],
@@ -5929,20 +6169,16 @@ def plot_map(
5929
6169
  get_radius=1000,
5930
6170
  )
5931
6171
 
5932
- # Create the deck
5933
6172
  deck = pdk.Deck(
5934
6173
  layers=[layer],
5935
6174
  initial_view_state=view,
5936
6175
  map_style=pydeck_style,
5937
6176
  )
5938
-
5939
- # Render map in Streamlit
5940
6177
  st.pydeck_chart(deck)
5941
6178
 
5942
6179
  return deck # Return the pydeck map
5943
6180
 
5944
6181
  else:
5945
- # Initialize the base map (Folium)
5946
6182
  m = folium.Map(
5947
6183
  location=location,
5948
6184
  zoom_start=zoom_start,
@@ -5950,39 +6186,39 @@ def plot_map(
5950
6186
  scrollWheelZoom=scroll_wheel_zoom,
5951
6187
  **kwargs,
5952
6188
  )
5953
-
5954
- # Add markers
5955
6189
  if markers:
6190
+ if verbose:
6191
+ print(config_markers)
5956
6192
  for marker in markers:
5957
6193
  folium.Marker(
5958
6194
  location=marker.get("location"),
5959
6195
  popup=marker.get("popup"),
5960
6196
  tooltip=marker.get("tooltip"),
5961
- icon=marker.get("icon", folium.Icon()), # Default icon if none specified
6197
+ icon=marker.get(
6198
+ "icon", folium.Icon()
6199
+ ), # Default icon if none specified
5962
6200
  ).add_to(m)
5963
6201
 
5964
- # Add overlays
5965
6202
  if overlays:
6203
+ if verbose:
6204
+ print(config_overlay)
5966
6205
  for overlay in overlays:
5967
6206
  overlay.add_to(m)
5968
6207
 
5969
- # Add custom layers
5970
6208
  if custom_layers:
5971
6209
  for layer in custom_layers:
5972
6210
  layer.add_to(m)
5973
6211
 
5974
- # Add plugins
5975
6212
  if plugins:
6213
+ if verbose:
6214
+ print(config_plugin)
5976
6215
  for plugin in plugins:
5977
6216
  plugin.add_to(m)
5978
6217
 
5979
- # Fit map bounds
5980
6218
  if fit_bounds:
5981
6219
  m.fit_bounds(fit_bounds)
5982
6220
 
5983
- # Handle rendering based on output
5984
6221
  if output == "streamlit":
5985
- # Render the map in Streamlit
5986
6222
  st_data = st_folium(m, width=map_width, height=map_height)
5987
6223
  return st_data
5988
6224
  elif output == "offline":