py2ls 0.2.4.32__py3-none-any.whl → 0.2.4.33__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
py2ls/plot.py CHANGED
@@ -22,7 +22,7 @@ from .ips import (
22
22
  run_once_within,
23
23
  get_df_format,
24
24
  df_corr,
25
- df_scaler
25
+ df_scaler,
26
26
  )
27
27
  import scipy.stats as scipy_stats
28
28
  from .stats import *
@@ -144,6 +144,7 @@ def add_text(ax=None, height_offset=0.5, fmt=".1f", **kwargs):
144
144
  **kwargs,
145
145
  )
146
146
 
147
+
147
148
  def pval2str(p):
148
149
  if p > 0.05:
149
150
  txt = ""
@@ -155,16 +156,17 @@ def pval2str(p):
155
156
  txt = "***"
156
157
  return txt
157
158
 
159
+
158
160
  def heatmap(
159
161
  data,
160
162
  ax=None,
161
163
  kind="corr", #'corr','direct','pivot'
162
- method="pearson",# for correlation: ‘pearson’(default), ‘kendall’, ‘spearman’
164
+ method="pearson", # for correlation: ‘pearson’(default), ‘kendall’, ‘spearman’
163
165
  columns="all", # pivot, default: all numeric columns
164
- style=0,# for correlation
166
+ style=0, # for correlation
165
167
  index=None, # pivot
166
168
  values=None, # pivot
167
- fontsize=10,
169
+ fontsize=10,
168
170
  tri="u",
169
171
  mask=True,
170
172
  k=1,
@@ -174,7 +176,7 @@ def heatmap(
174
176
  annot=True,
175
177
  cmap="coolwarm",
176
178
  fmt=".2f",
177
- show_indicator = True,# only for style==1
179
+ show_indicator=True, # only for style==1
178
180
  cluster=False,
179
181
  inplace=False,
180
182
  figsize=(10, 8),
@@ -309,7 +311,7 @@ def heatmap(
309
311
  df_col_cluster,
310
312
  )
311
313
  else:
312
- if style==0:
314
+ if style == 0:
313
315
  # Create a standard heatmap
314
316
  ax = sns.heatmap(
315
317
  data4heatmap,
@@ -321,7 +323,7 @@ def heatmap(
321
323
  **kwargs, # Pass any additional arguments to sns.heatmap
322
324
  )
323
325
  return ax
324
- elif style==1:
326
+ elif style == 1:
325
327
  if isinstance(cmap, str):
326
328
  cmap = plt.get_cmap(cmap)
327
329
  norm = plt.Normalize(vmin=-1, vmax=1)
@@ -332,12 +334,17 @@ def heatmap(
332
334
  # 循环绘制气泡图和数值
333
335
  for i in range(len(r_.columns)):
334
336
  for j in range(len(r_.columns)):
335
- if (i < j) if "u" in tri.lower() else (j<i): # 对角线左上部只显示气泡
337
+ if (
338
+ (i < j) if "u" in tri.lower() else (j < i)
339
+ ): # 对角线左上部只显示气泡
336
340
  color = cmap(norm(r_.iloc[i, j])) # 根据相关系数获取颜色
337
341
  scatter = ax.scatter(
338
- i, j, s=np.abs(r_.iloc[i, j])*size_scale, color=color,
342
+ i,
343
+ j,
344
+ s=np.abs(r_.iloc[i, j]) * size_scale,
345
+ color=color,
339
346
  # alpha=1,edgecolor=edgecolor,linewidth=linewidth,
340
- **kwargs
347
+ **kwargs,
341
348
  )
342
349
  scatter_handles.append(scatter) # 保存scatter对象用于颜色条
343
350
  # add *** indicators
@@ -351,8 +358,12 @@ def heatmap(
351
358
  color="k",
352
359
  fontsize=fontsize * 1.3,
353
360
  )
354
- elif (i > j) if "u" in tri.lower() else (j>i): # 对角只显示数值
355
- color = cmap(norm(r_.iloc[i, j])) # 数值的颜色同样基于相关系数
361
+ elif (
362
+ (i > j) if "u" in tri.lower() else (j > i)
363
+ ): # 对角只显示数值
364
+ color = cmap(
365
+ norm(r_.iloc[i, j])
366
+ ) # 数值的颜色同样基于相关系数
356
367
  ax.text(
357
368
  i,
358
369
  j,
@@ -365,15 +376,16 @@ def heatmap(
365
376
  else: # 对角线部分,显示空白
366
377
  ax.scatter(i, j, s=1, color="white")
367
378
  # 设置坐标轴标签
368
- figsets(xticks=range(len(r_.columns)),
369
- xticklabels=r_.columns,
370
- xangle=90,
371
- fontsize=fontsize,
372
- yticks=range(len(r_.columns)),
373
- yticklabels=r_.columns,
374
- xlim=[-0.5,len(r_.columns)-0.5],
375
- ylim=[-0.5,len(r_.columns)-0.5]
376
- )
379
+ figsets(
380
+ xticks=range(len(r_.columns)),
381
+ xticklabels=r_.columns,
382
+ xangle=90,
383
+ fontsize=fontsize,
384
+ yticks=range(len(r_.columns)),
385
+ yticklabels=r_.columns,
386
+ xlim=[-0.5, len(r_.columns) - 0.5],
387
+ ylim=[-0.5, len(r_.columns) - 0.5],
388
+ )
377
389
  # 添加颜色条
378
390
  sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
379
391
  sm.set_array([]) # 仅用于显示颜色条
@@ -468,7 +480,7 @@ def heatmap(
468
480
  )
469
481
  else:
470
482
  # Create a standard heatmap
471
- if style==0:
483
+ if style == 0:
472
484
  ax = sns.heatmap(
473
485
  data4heatmap,
474
486
  ax=ax,
@@ -478,42 +490,51 @@ def heatmap(
478
490
  **kwargs, # Pass any additional arguments to sns.heatmap
479
491
  )
480
492
  # Return the Axes object for further customization if needed
481
- return ax
482
- elif style==1:
493
+ return ax
494
+ elif style == 1:
483
495
  if isinstance(cmap, str):
484
496
  cmap = plt.get_cmap(cmap)
485
497
  if vmin is None:
486
- vmin=np.min(data4heatmap)
498
+ vmin = np.min(data4heatmap)
487
499
  if vmax is None:
488
- vmax=np.max(data4heatmap)
500
+ vmax = np.max(data4heatmap)
489
501
  norm = plt.Normalize(vmin=vmin, vmax=vmax)
490
502
 
491
503
  # 初始化一个空的可绘制对象用于颜色条
492
504
  scatter_handles = []
493
- # 循环绘制气泡图和数值
494
- print(len(data4heatmap.index),len(data4heatmap.columns))
505
+ # 循环绘制气泡图和数值
506
+ print(len(data4heatmap.index), len(data4heatmap.columns))
495
507
  for i in range(len(data4heatmap.index)):
496
508
  for j in range(len(data4heatmap.columns)):
497
509
  color = cmap(norm(data4heatmap.iloc[i, j])) # 根据相关系数获取颜色
498
- scatter = ax.scatter(j,i, s=np.abs(data4heatmap.iloc[i, j])*size_scale, color=color, **kwargs)
510
+ scatter = ax.scatter(
511
+ j,
512
+ i,
513
+ s=np.abs(data4heatmap.iloc[i, j]) * size_scale,
514
+ color=color,
515
+ **kwargs,
516
+ )
499
517
  scatter_handles.append(scatter) # 保存scatter对象用于颜色条
500
518
 
501
519
  # 设置坐标轴标签
502
- figsets(xticks=range(len(data4heatmap.columns)),
503
- xticklabels=data4heatmap.columns,
504
- xangle=90,
505
- fontsize=fontsize,
506
- yticks=range(len(data4heatmap.index)),
507
- yticklabels=data4heatmap.index,
508
- xlim=[-0.5,len(data4heatmap.columns)-0.5],
509
- ylim=[-0.5,len(data4heatmap.index)-0.5]
510
- )
520
+ figsets(
521
+ xticks=range(len(data4heatmap.columns)),
522
+ xticklabels=data4heatmap.columns,
523
+ xangle=90,
524
+ fontsize=fontsize,
525
+ yticks=range(len(data4heatmap.index)),
526
+ yticklabels=data4heatmap.index,
527
+ xlim=[-0.5, len(data4heatmap.columns) - 0.5],
528
+ ylim=[-0.5, len(data4heatmap.index) - 0.5],
529
+ )
511
530
  # 添加颜色条
512
531
  sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
513
532
  sm.set_array([]) # 仅用于显示颜色条
514
- plt.colorbar(sm, ax=ax,
515
- # label="Correlation Coefficient"
516
- )
533
+ plt.colorbar(
534
+ sm,
535
+ ax=ax,
536
+ # label="Correlation Coefficient"
537
+ )
517
538
  return ax
518
539
 
519
540
 
@@ -854,9 +875,11 @@ def catplot(data, *args, **kwargs):
854
875
  else:
855
876
  bx_opt["EdgeColor"] = bx_opt["EdgeColor"]
856
877
  if not isinstance(bx_opt["FaceColor"], list):
857
- bx_opt["FaceColor"]=[bx_opt["FaceColor"]]
858
- if len(bxp["boxes"])!= len(bx_opt["FaceColor"]) and (len(bx_opt["FaceColor"])==1):
859
- bx_opt["FaceColor"]=bx_opt["FaceColor"] *len(bxp["boxes"])
878
+ bx_opt["FaceColor"] = [bx_opt["FaceColor"]]
879
+ if len(bxp["boxes"]) != len(bx_opt["FaceColor"]) and (
880
+ len(bx_opt["FaceColor"]) == 1
881
+ ):
882
+ bx_opt["FaceColor"] = bx_opt["FaceColor"] * len(bxp["boxes"])
860
883
  for patch, color in zip(bxp["boxes"], bx_opt["FaceColor"]):
861
884
  patch.set_facecolor(to_rgba(color, bx_opt["FaceAlpha"]))
862
885
 
@@ -2439,7 +2462,15 @@ def split_legend(ax, n=2, loc=None, title=None, bbox=None, ncol=1, **kwargs):
2439
2462
  return legends
2440
2463
 
2441
2464
 
2442
- def get_colors(n: int = 1,cmap: str = "auto",by: str = "start",alpha: float = 1.0,output: str = "hue",*args,**kwargs):
2465
+ def get_colors(
2466
+ n: int = 1,
2467
+ cmap: str = "auto",
2468
+ by: str = "start",
2469
+ alpha: float = 1.0,
2470
+ output: str = "hue",
2471
+ *args,
2472
+ **kwargs,
2473
+ ):
2443
2474
  return get_color(n=n, cmap=cmap, alpha=alpha, output=output, *args, **kwargs)
2444
2475
 
2445
2476
 
@@ -3317,9 +3348,9 @@ def plotxy(
3317
3348
  for k in kind_:
3318
3349
  # preprocess data
3319
3350
  try:
3320
- data=df_preprocessing_(data, kind=k)
3321
- if 'variable' in data.columns and 'value' in data.columns:
3322
- x,y='variable','value'
3351
+ data = df_preprocessing_(data, kind=k)
3352
+ if "variable" in data.columns and "value" in data.columns:
3353
+ x, y = "variable", "value"
3323
3354
  except Exception as e:
3324
3355
  print(e)
3325
3356
  zorder += 1
@@ -3342,28 +3373,31 @@ def plotxy(
3342
3373
  # (1) return FacetGrid
3343
3374
  if k == "jointplot":
3344
3375
  kws_joint = kwargs.pop("kws_joint", kwargs)
3345
- kws_joint = {
3346
- k: v for k, v in kws_joint.items() if not k.startswith("kws_")
3347
- }
3376
+ kws_joint = {k: v for k, v in kws_joint.items() if not k.startswith("kws_")}
3348
3377
  hue = kwargs.get("hue", None)
3349
- if isinstance(kws_joint, dict) or hue is None: # Check if kws_ellipse is a dictionary
3378
+ if (
3379
+ isinstance(kws_joint, dict) or hue is None
3380
+ ): # Check if kws_ellipse is a dictionary
3350
3381
  kws_joint.pop("hue", None) # Safely remove 'hue' if it exists
3351
3382
 
3352
3383
  palette = kwargs.get("palette", None)
3353
3384
  if palette is None:
3354
3385
  palette = kws_joint.pop(
3355
- "palette", get_color(data[hue].nunique()) if hue is not None else None
3386
+ "palette",
3387
+ get_color(data[hue].nunique()) if hue is not None else None,
3356
3388
  )
3357
3389
  else:
3358
3390
  kws_joint.pop("palette", palette)
3359
- stats=kwargs.pop("stats",None)
3391
+ stats = kwargs.pop("stats", None)
3360
3392
  if stats:
3361
- stats=kws_joint.pop("stats",True)
3393
+ stats = kws_joint.pop("stats", True)
3362
3394
  if stats:
3363
3395
  r, p_value = scipy_stats.pearsonr(data[x], data[y])
3364
- for key in ["palette", "alpha", "hue","stats"]:
3365
- kws_joint.pop(key, None)
3366
- g = sns.jointplot(data=data, x=x, y=y,hue=hue,palette=palette, **kws_joint)
3396
+ for key in ["palette", "alpha", "hue", "stats"]:
3397
+ kws_joint.pop(key, None)
3398
+ g = sns.jointplot(
3399
+ data=data, x=x, y=y, hue=hue, palette=palette, **kws_joint
3400
+ )
3367
3401
  if stats:
3368
3402
  g.ax_joint.annotate(
3369
3403
  f"pearsonr = {r:.2f} p = {p_value:.3f}",
@@ -3391,41 +3425,64 @@ def plotxy(
3391
3425
  # If no hue, col, or row, calculate stats for the entire dataset
3392
3426
  if all([hue is None, col is None, row is None]):
3393
3427
  r, p_value = scipy_stats.pearsonr(data[x], data[y])
3394
- stats_per_facet[(None, None)] = (r, p_value) # Store stats for the entire dataset
3428
+ stats_per_facet[(None, None)] = (
3429
+ r,
3430
+ p_value,
3431
+ ) # Store stats for the entire dataset
3395
3432
 
3396
- else:
3433
+ else:
3397
3434
  if hue is None and (col is not None or row is not None):
3398
3435
  for ax in g.axes.flat:
3399
3436
  facet_name = ax.get_title()
3400
- if '=' in facet_name:
3437
+ if "=" in facet_name:
3401
3438
  # Assume facet_name is like 'Column = Value'
3402
- facet_column_name = facet_name.split('=')[0].strip() # Column name before '='
3403
- facet_value_str = facet_name.split('=')[1].strip() # Facet value after '='
3404
-
3439
+ facet_column_name = facet_name.split("=")[
3440
+ 0
3441
+ ].strip() # Column name before '='
3442
+ facet_value_str = facet_name.split("=")[
3443
+ 1
3444
+ ].strip() # Facet value after '='
3445
+
3405
3446
  # Try converting facet_value to match the data type of the DataFrame column
3406
3447
  facet_column_dtype = data[facet_column_name].dtype
3407
- if facet_column_dtype == 'int' or facet_column_dtype == 'float':
3408
- facet_value = pd.to_numeric(facet_value_str, errors='coerce') # Convert to numeric
3448
+ if (
3449
+ facet_column_dtype == "int"
3450
+ or facet_column_dtype == "float"
3451
+ ):
3452
+ facet_value = pd.to_numeric(
3453
+ facet_value_str, errors="coerce"
3454
+ ) # Convert to numeric
3409
3455
  else:
3410
3456
  facet_value = facet_value_str # Treat as a string if not numeric
3411
3457
  else:
3412
- facet_column_name = facet_name.split('=')[0].strip() # Column name before '='
3413
- facet_value=facet_name.split('=')[1].strip()
3458
+ facet_column_name = facet_name.split("=")[
3459
+ 0
3460
+ ].strip() # Column name before '='
3461
+ facet_value = facet_name.split("=")[1].strip()
3414
3462
  facet_data = data[data[facet_column_name] == facet_value]
3415
3463
  if not facet_data.empty:
3416
- r, p_value = scipy_stats.pearsonr(facet_data[x], facet_data[y])
3464
+ r, p_value = scipy_stats.pearsonr(
3465
+ facet_data[x], facet_data[y]
3466
+ )
3417
3467
  stats_per_facet[facet_name] = (r, p_value)
3418
3468
  else:
3419
- stats_per_facet[facet_name] = (None, None) # Handle empty facets
3469
+ stats_per_facet[facet_name] = (
3470
+ None,
3471
+ None,
3472
+ ) # Handle empty facets
3420
3473
 
3421
3474
  # Annotate the stats on the plot
3422
3475
  for ax in g.axes.flat:
3423
3476
  if stats:
3424
3477
  # Adjust the position for each facet to avoid overlap
3425
- idx=1
3426
- shift_factor = 0.02 * idx # Adjust this factor as needed to prevent overlap
3427
- y_position = 0.98 - shift_factor # Dynamic vertical shift for each facet
3428
-
3478
+ idx = 1
3479
+ shift_factor = (
3480
+ 0.02 * idx
3481
+ ) # Adjust this factor as needed to prevent overlap
3482
+ y_position = (
3483
+ 0.98 - shift_factor
3484
+ ) # Dynamic vertical shift for each facet
3485
+
3429
3486
  if all([hue is None, col is None, row is None]):
3430
3487
  # Use stats for the entire dataset if no hue, col, or row
3431
3488
  r, p_value = stats_per_facet.get((None, None), (None, None))
@@ -3448,15 +3505,21 @@ def plotxy(
3448
3505
  ha="center",
3449
3506
  )
3450
3507
  elif hue is not None:
3451
- if (col is None and row is None):
3452
- hue_categories = sorted(flatten(data[hue],verbose=0))
3453
- idx=1
3508
+ if col is None and row is None:
3509
+ hue_categories = sorted(flatten(data[hue], verbose=0))
3510
+ idx = 1
3454
3511
  for category in hue_categories:
3455
3512
  subset_data = data[data[hue] == category]
3456
- r, p_value = scipy_stats.pearsonr(subset_data[x], subset_data[y])
3513
+ r, p_value = scipy_stats.pearsonr(
3514
+ subset_data[x], subset_data[y]
3515
+ )
3457
3516
  stats_per_hue[category] = (r, p_value)
3458
- shift_factor = 0.05 * idx # Adjust this factor as needed to prevent overlap
3459
- y_position = 0.98 - shift_factor # Dynamic vertical shift for each facet
3517
+ shift_factor = (
3518
+ 0.05 * idx
3519
+ ) # Adjust this factor as needed to prevent overlap
3520
+ y_position = (
3521
+ 0.98 - shift_factor
3522
+ ) # Dynamic vertical shift for each facet
3460
3523
  ax.annotate(
3461
3524
  f"{category}: pearsonr = {r:.2f} p = {p_value:.3f}",
3462
3525
  xy=(0.6, y_position),
@@ -3465,31 +3528,49 @@ def plotxy(
3465
3528
  color="black",
3466
3529
  ha="center",
3467
3530
  )
3468
- idx+=1
3531
+ idx += 1
3469
3532
  else:
3470
3533
  for ax in g.axes.flat:
3471
3534
  facet_name = ax.get_title()
3472
- if '=' in facet_name:
3535
+ if "=" in facet_name:
3473
3536
  # Assume facet_name is like 'Column = Value'
3474
- facet_column_name = facet_name.split('=')[0].strip() # Column name before '='
3475
- facet_value_str = facet_name.split('=')[1].strip() # Facet value after '='
3476
-
3537
+ facet_column_name = facet_name.split("=")[
3538
+ 0
3539
+ ].strip() # Column name before '='
3540
+ facet_value_str = facet_name.split("=")[
3541
+ 1
3542
+ ].strip() # Facet value after '='
3543
+
3477
3544
  # Try converting facet_value to match the data type of the DataFrame column
3478
3545
  facet_column_dtype = data[facet_column_name].dtype
3479
- if facet_column_dtype == 'int' or facet_column_dtype == 'float':
3480
- facet_value = pd.to_numeric(facet_value_str, errors='coerce') # Convert to numeric
3546
+ if (
3547
+ facet_column_dtype == "int"
3548
+ or facet_column_dtype == "float"
3549
+ ):
3550
+ facet_value = pd.to_numeric(
3551
+ facet_value_str, errors="coerce"
3552
+ ) # Convert to numeric
3481
3553
  else:
3482
3554
  facet_value = facet_value_str # Treat as a string if not numeric
3483
3555
  else:
3484
- facet_column_name = facet_name.split('=')[0].strip() # Column name before '='
3485
- facet_value=facet_name.split('=')[1].strip()
3486
- facet_data = data[data[facet_column_name] == facet_value]
3556
+ facet_column_name = facet_name.split("=")[
3557
+ 0
3558
+ ].strip() # Column name before '='
3559
+ facet_value = facet_name.split("=")[1].strip()
3560
+ facet_data = data[
3561
+ data[facet_column_name] == facet_value
3562
+ ]
3487
3563
  if not facet_data.empty:
3488
- r, p_value = scipy_stats.pearsonr(facet_data[x], facet_data[y])
3564
+ r, p_value = scipy_stats.pearsonr(
3565
+ facet_data[x], facet_data[y]
3566
+ )
3489
3567
  stats_per_facet[facet_name] = (r, p_value)
3490
3568
  else:
3491
- stats_per_facet[facet_name] = (None, None) # Handle empty facets
3492
-
3569
+ stats_per_facet[facet_name] = (
3570
+ None,
3571
+ None,
3572
+ ) # Handle empty facets
3573
+
3493
3574
  ax.annotate(
3494
3575
  f"pearsonr = {r:.2f} p = {p_value:.3f}",
3495
3576
  xy=(0.6, y_position),
@@ -3497,7 +3578,7 @@ def plotxy(
3497
3578
  fontsize=12,
3498
3579
  color="black",
3499
3580
  ha="center",
3500
- )
3581
+ )
3501
3582
  elif hue is None and (col is not None or row is not None):
3502
3583
  # Annotate stats for each facet
3503
3584
  facet_name = ax.get_title()
@@ -3518,7 +3599,8 @@ def plotxy(
3518
3599
  xycoords="axes fraction",
3519
3600
  fontsize=12,
3520
3601
  color="black",
3521
- ha="center")
3602
+ ha="center",
3603
+ )
3522
3604
 
3523
3605
  elif k == "catplot_sns":
3524
3606
  kws_cat = kwargs.pop("kws_cat", kwargs)
@@ -3526,7 +3608,7 @@ def plotxy(
3526
3608
  elif k == "displot":
3527
3609
  kws_dis = kwargs.pop("kws_dis", kwargs)
3528
3610
  # displot creates a new figure and returns a FacetGrid
3529
- g = sns.displot(data=data, x=x,y=y, **kws_dis)
3611
+ g = sns.displot(data=data, x=x, y=y, **kws_dis)
3530
3612
 
3531
3613
  # (2) return axis
3532
3614
  if ax is None:
@@ -3537,30 +3619,35 @@ def plotxy(
3537
3619
  elif k == "stdshade":
3538
3620
  kws_stdshade = kwargs.pop("kws_stdshade", kwargs)
3539
3621
  ax = stdshade(ax=ax, **kwargs)
3540
- elif k=="ellipse":
3622
+ elif k == "ellipse":
3541
3623
  kws_ellipse = kwargs.pop("kws_ellipse", kwargs)
3542
3624
  kws_ellipse = {
3543
3625
  k: v for k, v in kws_ellipse.items() if not k.startswith("kws_")
3544
3626
  }
3545
3627
  hue = kwargs.get("hue", None)
3546
- if isinstance(kws_ellipse, dict) or hue is None: # Check if kws_ellipse is a dictionary
3628
+ if (
3629
+ isinstance(kws_ellipse, dict) or hue is None
3630
+ ): # Check if kws_ellipse is a dictionary
3547
3631
  kws_ellipse.pop("hue", None) # Safely remove 'hue' if it exists
3548
3632
 
3549
3633
  palette = kwargs.get("palette", None)
3550
3634
  if palette is None:
3551
3635
  palette = kws_ellipse.pop(
3552
- "palette", get_color(data[hue].nunique()) if hue is not None else None
3636
+ "palette",
3637
+ get_color(data[hue].nunique()) if hue is not None else None,
3553
3638
  )
3554
3639
  alpha = kws_ellipse.pop("alpha", 0.1)
3555
- hue_order = kwargs.get("hue_order",None)
3640
+ hue_order = kwargs.get("hue_order", None)
3556
3641
  if hue_order is None:
3557
- hue_order = kws_ellipse.get("hue_order",None)
3558
- if hue_order:
3559
- data["hue"] = pd.Categorical(data[hue], categories=hue_order, ordered=True)
3642
+ hue_order = kws_ellipse.get("hue_order", None)
3643
+ if hue_order:
3644
+ data["hue"] = pd.Categorical(
3645
+ data[hue], categories=hue_order, ordered=True
3646
+ )
3560
3647
  data = data.sort_values(by="hue")
3561
- for key in ["palette", "alpha", "hue","hue_order"]:
3562
- kws_ellipse.pop(key, None)
3563
- ax=ellipse(
3648
+ for key in ["palette", "alpha", "hue", "hue_order"]:
3649
+ kws_ellipse.pop(key, None)
3650
+ ax = ellipse(
3564
3651
  ax=ax,
3565
3652
  data=data,
3566
3653
  x=x,
@@ -3579,13 +3666,14 @@ def plotxy(
3579
3666
  hue = kwargs.get("hue", None)
3580
3667
  if isinstance(kws_scatter, dict): # Check if kws_scatter is a dictionary
3581
3668
  kws_scatter.pop("hue", None) # Safely remove 'hue' if it exists
3582
- palette=kws_scatter.get("palette",None)
3669
+ palette = kws_scatter.get("palette", None)
3583
3670
  if palette is None:
3584
3671
  palette = kws_scatter.pop(
3585
- "palette", get_color(data[hue].nunique()) if hue is not None else None
3672
+ "palette",
3673
+ get_color(data[hue].nunique()) if hue is not None else None,
3586
3674
  )
3587
3675
  s = kws_scatter.pop("s", 10)
3588
- alpha = kws_scatter.pop("alpha", 0.7)
3676
+ alpha = kws_scatter.pop("alpha", 0.7)
3589
3677
  for key in ["s", "palette", "alpha", "hue"]:
3590
3678
  kws_scatter.pop(key, None)
3591
3679
 
@@ -3605,28 +3693,39 @@ def plotxy(
3605
3693
  kws_hist = kwargs.pop("kws_hist", kwargs)
3606
3694
  kws_hist = {k: v for k, v in kws_hist.items() if not k.startswith("kws_")}
3607
3695
  ax = sns.histplot(data=data, x=x, y=y, ax=ax, zorder=zorder, **kws_hist)
3608
- elif k == "kdeplot":
3696
+ elif k == "kdeplot":
3609
3697
  kws_kde = kwargs.pop("kws_kde", kwargs)
3610
- kws_kde = {
3611
- k: v for k, v in kws_kde.items() if not k.startswith("kws_")
3612
- }
3698
+ kws_kde = {k: v for k, v in kws_kde.items() if not k.startswith("kws_")}
3613
3699
  hue = kwargs.get("hue", None)
3614
- if isinstance(kws_kde, dict) or hue is None: # Check if kws_kde is a dictionary
3700
+ if (
3701
+ isinstance(kws_kde, dict) or hue is None
3702
+ ): # Check if kws_kde is a dictionary
3615
3703
  kws_kde.pop("hue", None) # Safely remove 'hue' if it exists
3616
3704
 
3617
3705
  palette = kwargs.get("palette", None)
3618
3706
  if palette is None:
3619
3707
  palette = kws_kde.pop(
3620
- "palette", get_color(data[hue].nunique()) if hue is not None else None
3708
+ "palette",
3709
+ get_color(data[hue].nunique()) if hue is not None else None,
3621
3710
  )
3622
3711
  alpha = kws_kde.pop("alpha", 0.05)
3623
3712
  for key in ["palette", "alpha", "hue"]:
3624
- kws_kde.pop(key, None)
3625
- ax = sns.kdeplot(data=data, x=x, y=y, palette=palette,hue=hue, ax=ax,alpha=alpha, zorder=zorder, **kws_kde)
3713
+ kws_kde.pop(key, None)
3714
+ ax = sns.kdeplot(
3715
+ data=data,
3716
+ x=x,
3717
+ y=y,
3718
+ palette=palette,
3719
+ hue=hue,
3720
+ ax=ax,
3721
+ alpha=alpha,
3722
+ zorder=zorder,
3723
+ **kws_kde,
3724
+ )
3626
3725
  elif k == "ecdfplot":
3627
3726
  kws_ecdf = kwargs.pop("kws_ecdf", kwargs)
3628
3727
  kws_ecdf = {k: v for k, v in kws_ecdf.items() if not k.startswith("kws_")}
3629
- ax = sns.ecdfplot(data=data, x=x,y=y, ax=ax, zorder=zorder, **kws_ecdf)
3728
+ ax = sns.ecdfplot(data=data, x=x, y=y, ax=ax, zorder=zorder, **kws_ecdf)
3630
3729
  elif k == "rugplot":
3631
3730
  kws_rug = kwargs.pop("kws_rug", kwargs)
3632
3731
  kws_rug = {k: v for k, v in kws_rug.items() if not k.startswith("kws_")}
@@ -3635,7 +3734,9 @@ def plotxy(
3635
3734
  kws_strip = kwargs.pop("kws_strip", kwargs)
3636
3735
  kws_strip = {k: v for k, v in kws_strip.items() if not k.startswith("kws_")}
3637
3736
  dodge = kws_strip.pop("dodge", True)
3638
- ax = sns.stripplot(data=data, x=x, y=y, ax=ax, zorder=zorder, dodge=dodge, **kws_strip)
3737
+ ax = sns.stripplot(
3738
+ data=data, x=x, y=y, ax=ax, zorder=zorder, dodge=dodge, **kws_strip
3739
+ )
3639
3740
  elif k == "swarmplot":
3640
3741
  kws_swarm = kwargs.pop("kws_swarm", kwargs)
3641
3742
  kws_swarm = {k: v for k, v in kws_swarm.items() if not k.startswith("kws_")}
@@ -3672,21 +3773,21 @@ def plotxy(
3672
3773
  kws_reg = kwargs.pop("kws_reg", kwargs)
3673
3774
  kws_reg = {k: v for k, v in kws_reg.items() if not k.startswith("kws_")}
3674
3775
  stats = kwargs.pop("stats", True) # Flag to calculate stats
3675
-
3776
+
3676
3777
  # Compute Pearson correlation if stats is True
3677
3778
  if stats:
3678
3779
  r, p_value = scipy_stats.pearsonr(data[x], data[y])
3679
- ax = sns.regplot(data=data, x=x, y=y, ax=ax, **kws_reg)
3680
-
3780
+ ax = sns.regplot(data=data, x=x, y=y, ax=ax, **kws_reg)
3781
+
3681
3782
  # Annotate the Pearson correlation and p-value
3682
3783
  ax.annotate(
3683
- f"pearsonr = {r:.2f} p = {p_value:.3f}",
3684
- xy=(0.6, 0.98),
3685
- xycoords="axes fraction",
3686
- fontsize=12,
3687
- color="black",
3688
- ha="center",
3689
- )
3784
+ f"pearsonr = {r:.2f} p = {p_value:.3f}",
3785
+ xy=(0.6, 0.98),
3786
+ xycoords="axes fraction",
3787
+ fontsize=12,
3788
+ color="black",
3789
+ ha="center",
3790
+ )
3690
3791
  elif k == "residplot":
3691
3792
  kws_resid = kwargs.pop("kws_resid", kwargs)
3692
3793
  kws_resid = {k: v for k, v in kws_resid.items() if not k.startswith("kws_")}
@@ -3711,6 +3812,7 @@ def plotxy(
3711
3812
  return g, ax
3712
3813
  return ax
3713
3814
 
3815
+
3714
3816
  def df_preprocessing_(data, kind, verbose=False):
3715
3817
  """
3716
3818
  Automatically formats data for various seaborn plot types.
@@ -3733,10 +3835,10 @@ def df_preprocessing_(data, kind, verbose=False):
3733
3835
  "heatmap",
3734
3836
  "pairplot",
3735
3837
  "jointplot", # Typically requires wide format for axis variables
3736
- "facetgrid", # Used for creating small multiples (can work with wide format)
3737
- "barplot", # Can be used with wide format
3738
- "pointplot", # Works well with wide format
3739
- "pivot_table", # Works with wide format (aggregated data)
3838
+ "facetgrid", # Used for creating small multiples (can work with wide format)
3839
+ "barplot", # Can be used with wide format
3840
+ "pointplot", # Works well with wide format
3841
+ "pivot_table", # Works with wide format (aggregated data)
3740
3842
  "boxplot",
3741
3843
  "violinplot",
3742
3844
  "stripplot",
@@ -3745,23 +3847,23 @@ def df_preprocessing_(data, kind, verbose=False):
3745
3847
  "lineplot",
3746
3848
  "scatterplot",
3747
3849
  "relplot",
3748
- "barplot", # Can also work with long format (aggregated data in long form)
3749
- "boxenplot", # Similar to boxplot, works with long format
3750
- "countplot", # Works best with long format (categorical data)
3751
- "heatmap", # Can work with long format after reshaping
3752
- "lineplot", # Can work with long format (time series, continuous)
3753
- "histplot", # Can be used with both wide and long formats
3754
- "kdeplot", # Works with both wide and long formats
3755
- "ecdfplot", # Works with both formats
3756
- "scatterplot", # Can work with both formats depending on data structure
3757
- "lineplot", # Can work with both wide and long formats
3758
- "area plot", # Can work with both formats, useful for stacked areas
3850
+ "barplot", # Can also work with long format (aggregated data in long form)
3851
+ "boxenplot", # Similar to boxplot, works with long format
3852
+ "countplot", # Works best with long format (categorical data)
3853
+ "heatmap", # Can work with long format after reshaping
3854
+ "lineplot", # Can work with long format (time series, continuous)
3855
+ "histplot", # Can be used with both wide and long formats
3856
+ "kdeplot", # Works with both wide and long formats
3857
+ "ecdfplot", # Works with both formats
3858
+ "scatterplot", # Can work with both formats depending on data structure
3859
+ "lineplot", # Can work with both wide and long formats
3860
+ "area plot", # Can work with both formats, useful for stacked areas
3759
3861
  "violinplot", # Can work with both formats depending on categorical vs continuous data
3760
- "ellipse",# ellipse plot, default confidence=0.95
3862
+ "ellipse", # ellipse plot, default confidence=0.95
3761
3863
  ],
3762
3864
  )[0]
3763
3865
 
3764
- wide_kinds = [
3866
+ wide_kinds = [
3765
3867
  "pairplot",
3766
3868
  ]
3767
3869
 
@@ -3773,26 +3875,26 @@ def df_preprocessing_(data, kind, verbose=False):
3773
3875
  # Flexible kinds: distribution plots can use either format
3774
3876
  flexible_kinds = [
3775
3877
  "jointplot", # Typically requires wide format for axis variables
3776
- "lineplot", # Can work with long format (time series, continuous)
3878
+ "lineplot", # Can work with long format (time series, continuous)
3777
3879
  "lineplot",
3778
3880
  "scatterplot",
3779
- "barplot", # Can also work with long format (aggregated data in long form)
3780
- "boxenplot", # Similar to boxplot, works with long format
3781
- "countplot", # Works best with long format (categorical data)
3881
+ "barplot", # Can also work with long format (aggregated data in long form)
3882
+ "boxenplot", # Similar to boxplot, works with long format
3883
+ "countplot", # Works best with long format (categorical data)
3782
3884
  "regplot",
3783
3885
  "violinplot",
3784
3886
  "stripplot",
3785
3887
  "swarmplot",
3786
3888
  "boxplot",
3787
- "histplot", # Can be used with both wide and long formats
3788
- "kdeplot", # Works with both wide and long formats
3789
- "ecdfplot", # Works with both formats
3790
- "scatterplot", # Can work with both formats depending on data structure
3791
- "lineplot", # Can work with both wide and long formats
3792
- "area plot", # Can work with both formats, useful for stacked areas
3889
+ "histplot", # Can be used with both wide and long formats
3890
+ "kdeplot", # Works with both wide and long formats
3891
+ "ecdfplot", # Works with both formats
3892
+ "scatterplot", # Can work with both formats depending on data structure
3893
+ "lineplot", # Can work with both wide and long formats
3894
+ "area plot", # Can work with both formats, useful for stacked areas
3793
3895
  "violinplot", # Can work with both formats depending on categorical vs continuous data
3794
3896
  "relplot",
3795
- "pointplot", # Works well with wide format
3897
+ "pointplot", # Works well with wide format
3796
3898
  "ellipse",
3797
3899
  ]
3798
3900
 
@@ -4296,17 +4398,17 @@ def venn(
4296
4398
  colors = ["r", "b"] if len(lists) == 2 else ["r", "g", "b"]
4297
4399
  if labels is None:
4298
4400
  if len(lists) == 2:
4299
- labels = ["set1", "set2"]
4300
- elif len(lists) == 3:
4401
+ labels = ["set1", "set2"]
4402
+ elif len(lists) == 3:
4301
4403
  labels = ["set1", "set2", "set3"]
4302
- elif len(lists) == 4:
4303
- labels = ["set1", "set2", "set3","set4"]
4304
- elif len(lists) == 5:
4305
- labels = ["set1", "set2", "set3","set4","set55"]
4306
- elif len(lists) == 6:
4307
- labels = ["set1", "set2", "set3","set4","set5","set6"]
4308
- elif len(lists) == 7:
4309
- labels = ["set1", "set2", "set3","set4","set5","set6","set7"]
4404
+ elif len(lists) == 4:
4405
+ labels = ["set1", "set2", "set3", "set4"]
4406
+ elif len(lists) == 5:
4407
+ labels = ["set1", "set2", "set3", "set4", "set55"]
4408
+ elif len(lists) == 6:
4409
+ labels = ["set1", "set2", "set3", "set4", "set5", "set6"]
4410
+ elif len(lists) == 7:
4411
+ labels = ["set1", "set2", "set3", "set4", "set5", "set6", "set7"]
4310
4412
  if edgecolor is None:
4311
4413
  edgecolor = colors
4312
4414
  colors = [desaturate_color(color, saturation) for color in colors]
@@ -4320,10 +4422,11 @@ def venn(
4320
4422
  if show_percentages
4321
4423
  else f"{subset_count}"
4322
4424
  )
4425
+
4323
4426
  if fmt is not None:
4324
4427
  if not fmt.startswith("{"):
4325
- fmt="{:" + fmt + "}"
4326
- if len(lists) == 2:
4428
+ fmt = "{:" + fmt + "}"
4429
+ if len(lists) == 2:
4327
4430
 
4328
4431
  from matplotlib_venn import venn2, venn2_circles
4329
4432
 
@@ -4345,7 +4448,7 @@ def venn(
4345
4448
  # v.get_label_by_id('10').set_text(len(set1 - set2))
4346
4449
  # v.get_label_by_id('01').set_text(len(set2 - set1))
4347
4450
  # v.get_label_by_id('11').set_text(len(set1 & set2))
4348
-
4451
+
4349
4452
  v.get_label_by_id("10").set_text(
4350
4453
  get_count_and_percentage(universe, len(set1 - set2))
4351
4454
  )
@@ -4447,7 +4550,7 @@ def venn(
4447
4550
  if "none" in edgecolor or 0 in linewidth:
4448
4551
  patch.set_edgecolor("none")
4449
4552
  return ax
4450
- elif len(lists) == 3:
4553
+ elif len(lists) == 3:
4451
4554
 
4452
4555
  from matplotlib_venn import venn3, venn3_circles
4453
4556
 
@@ -4463,32 +4566,46 @@ def venn(
4463
4566
  # Draw the venn diagram
4464
4567
  v = venn3(subsets=lists, set_labels=labels, ax=ax, **kwargs)
4465
4568
  v.get_patch_by_id("100").set_color(colors[0])
4466
- v.get_label_by_id("100").set_text(get_count_and_percentage(universe, len(set1 - set2 - set3)))
4569
+ v.get_label_by_id("100").set_text(
4570
+ get_count_and_percentage(universe, len(set1 - set2 - set3))
4571
+ )
4467
4572
  v.get_patch_by_id("010").set_color(colors[1])
4468
- v.get_label_by_id("010").set_text(get_count_and_percentage(universe, len(set2 - set1 - set3)))
4573
+ v.get_label_by_id("010").set_text(
4574
+ get_count_and_percentage(universe, len(set2 - set1 - set3))
4575
+ )
4469
4576
  try:
4470
4577
  v.get_patch_by_id("001").set_color(colors[2])
4471
- v.get_label_by_id("001").set_text(get_count_and_percentage(universe, len(set3 - set1 - set2)))
4578
+ v.get_label_by_id("001").set_text(
4579
+ get_count_and_percentage(universe, len(set3 - set1 - set2))
4580
+ )
4472
4581
  except Exception as e:
4473
4582
  print(e)
4474
4583
  try:
4475
4584
  v.get_patch_by_id("110").set_color(colorAB)
4476
- v.get_label_by_id("110").set_text(get_count_and_percentage(universe, len(set1 & set2 - set3)))
4585
+ v.get_label_by_id("110").set_text(
4586
+ get_count_and_percentage(universe, len(set1 & set2 - set3))
4587
+ )
4477
4588
  except Exception as e:
4478
4589
  print(e)
4479
4590
  try:
4480
4591
  v.get_patch_by_id("101").set_color(colorAC)
4481
- v.get_label_by_id("101").set_text(get_count_and_percentage(universe, len(set1 & set3 - set2)))
4592
+ v.get_label_by_id("101").set_text(
4593
+ get_count_and_percentage(universe, len(set1 & set3 - set2))
4594
+ )
4482
4595
  except Exception as e:
4483
4596
  print(e)
4484
4597
  try:
4485
4598
  v.get_patch_by_id("011").set_color(colorBC)
4486
- v.get_label_by_id("011").set_text(get_count_and_percentage(universe, len(set2 & set3 - set1)))
4599
+ v.get_label_by_id("011").set_text(
4600
+ get_count_and_percentage(universe, len(set2 & set3 - set1))
4601
+ )
4487
4602
  except Exception as e:
4488
4603
  print(e)
4489
4604
  try:
4490
4605
  v.get_patch_by_id("111").set_color(colorABC)
4491
- v.get_label_by_id("111").set_text(get_count_and_percentage(universe, len(set1 & set2 & set3)))
4606
+ v.get_label_by_id("111").set_text(
4607
+ get_count_and_percentage(universe, len(set1 & set2 & set3))
4608
+ )
4492
4609
  except Exception as e:
4493
4610
  print(e)
4494
4611
 
@@ -4603,26 +4720,24 @@ def venn(
4603
4720
  patch.set_edgecolor("none")
4604
4721
  return ax
4605
4722
 
4606
-
4607
4723
  dict_data = {}
4608
4724
  for i_list, list_ in enumerate(lists):
4609
- dict_data[labels[i_list]]={*list_}
4725
+ dict_data[labels[i_list]] = {*list_}
4610
4726
 
4611
- if 3<len(lists)<6:
4727
+ if 3 < len(lists) < 6:
4612
4728
  from venn import venn as vn
4613
4729
 
4614
- legend_loc=kwargs.pop("legend_loc", "upper right")
4615
- ax=vn(dict_data,ax=ax,legend_loc=legend_loc,**kwargs)
4730
+ legend_loc = kwargs.pop("legend_loc", "upper right")
4731
+ ax = vn(dict_data, ax=ax, legend_loc=legend_loc, **kwargs)
4616
4732
 
4617
4733
  return ax
4618
4734
  else:
4619
4735
  from venn import pseudovenn
4620
- cmap=kwargs.pop("cmap","plasma")
4621
- ax=pseudovenn(dict_data, cmap=cmap,ax=ax,**kwargs)
4622
-
4623
- return ax
4624
4736
 
4737
+ cmap = kwargs.pop("cmap", "plasma")
4738
+ ax = pseudovenn(dict_data, cmap=cmap, ax=ax, **kwargs)
4625
4739
 
4740
+ return ax
4626
4741
 
4627
4742
 
4628
4743
  #! subplots, support automatic extend new axis
@@ -4670,7 +4785,7 @@ def subplot(
4670
4785
  col_first_axes = [None] * cols # Track the first axis in each column (for sharex)
4671
4786
 
4672
4787
  def expand_ax():
4673
- nonlocal rows, grid_spec,cols,row_first_axes,fig,figsize,figsize_recommend
4788
+ nonlocal rows, grid_spec, cols, row_first_axes, fig, figsize, figsize_recommend
4674
4789
  # fig_height = fig.get_figheight()
4675
4790
  # subplot_height = fig_height / rows
4676
4791
  rows += 1 # Expands by adding a row
@@ -4678,10 +4793,11 @@ def subplot(
4678
4793
  fig.set_size_inches(figsize)
4679
4794
  grid_spec = GridSpec(rows, cols, figure=fig)
4680
4795
  row_first_axes.append(None)
4681
- figsize_recommend=f"Warning: 建议设置 subplot({rows}, {cols})"
4796
+ figsize_recommend = f"Warning: 建议设置 subplot({rows}, {cols})"
4682
4797
  print(figsize_recommend)
4798
+
4683
4799
  def nexttile(rowspan=1, colspan=1, **kwargs):
4684
- nonlocal rows, cols, occupied, grid_spec,fig,figsize_recommend
4800
+ nonlocal rows, cols, occupied, grid_spec, fig, figsize_recommend
4685
4801
  for row in range(rows):
4686
4802
  for col in range(cols):
4687
4803
  if all(
@@ -4693,7 +4809,7 @@ def subplot(
4693
4809
  else:
4694
4810
  continue
4695
4811
  break
4696
-
4812
+
4697
4813
  else:
4698
4814
  expand_ax()
4699
4815
  return nexttile(rowspan=rowspan, colspan=colspan, **kwargs)
@@ -4715,12 +4831,13 @@ def subplot(
4715
4831
  row_first_axes[row] = ax
4716
4832
  if col_first_axes[col] is None:
4717
4833
  col_first_axes[col] = ax
4718
-
4834
+
4719
4835
  for r in range(row, row + rowspan):
4720
4836
  for c in range(col, col + colspan):
4721
4837
  occupied.add((r, c))
4722
-
4838
+
4723
4839
  return ax
4840
+
4724
4841
  return nexttile
4725
4842
 
4726
4843
 
@@ -4743,7 +4860,7 @@ def radar(
4743
4860
  bg_color="0.8",
4744
4861
  bg_alpha=None,
4745
4862
  grid_interval_ratio=0.2,
4746
- show_value=False,# show text for each value
4863
+ show_value=False, # show text for each value
4747
4864
  cmap=None,
4748
4865
  legend_loc="upper right",
4749
4866
  legend_fontsize=10,
@@ -4777,9 +4894,9 @@ def radar(
4777
4894
  index=["HP", "MP", "ATK", "DEF", "SP.ATK", "SP.DEF", "SPD"],
4778
4895
  columns=["Hero", "Warrior", "Wizard"],
4779
4896
  )
4780
- usage 1:
4897
+ usage 1:
4781
4898
  radar(data=df)
4782
- usage 2:
4899
+ usage 2:
4783
4900
  radar(data=df["Wizard"])
4784
4901
  usage 3:
4785
4902
  radar(data=df, columns="Wizard")
@@ -4819,8 +4936,8 @@ def radar(
4819
4936
  - sp (int): Padding for the ticks from the plot area.
4820
4937
  - **kwargs: Additional arguments for customization.
4821
4938
  """
4822
- if run_once_within(20,reverse=True) and verbose:
4823
- usage_="""usage:
4939
+ if run_once_within(20, reverse=True) and verbose:
4940
+ usage_ = """usage:
4824
4941
  radar(
4825
4942
  data: pd.DataFrame, #The data to plot. Each column corresponds to a variable, and each row represents a data point.
4826
4943
  ylim=(0, 100),# ylim (tuple): The limits of the radial axis (y-axis). Default is (0, 100).
@@ -4864,20 +4981,20 @@ def radar(
4864
4981
  kws_figsets = v_arg
4865
4982
  kwargs.pop(k_arg, None)
4866
4983
  break
4867
- if axis==1:
4868
- data=data.T
4984
+ if axis == 1:
4985
+ data = data.T
4869
4986
  if isinstance(data, dict):
4870
4987
  data = pd.DataFrame(pd.Series(data))
4871
4988
  if ~isinstance(data, pd.DataFrame):
4872
- data=pd.DataFrame(data)
4989
+ data = pd.DataFrame(data)
4873
4990
  if isinstance(data, pd.DataFrame):
4874
- data=data.select_dtypes(include=np.number)
4875
- if isinstance(columns,str):
4876
- columns=[columns]
4991
+ data = data.select_dtypes(include=np.number)
4992
+ if isinstance(columns, str):
4993
+ columns = [columns]
4877
4994
  if columns is None:
4878
4995
  columns = list(data.columns)
4879
- data=data[columns]
4880
- categories = list(data.index)
4996
+ data = data[columns]
4997
+ categories = list(data.index)
4881
4998
  num_vars = len(categories)
4882
4999
 
4883
5000
  # Set y-axis limits and grid intervals
@@ -4928,7 +5045,9 @@ def radar(
4928
5045
  else:
4929
5046
  # * spider style: spider-style grid (straight lines, not circles)
4930
5047
  # Create the spider-style grid (straight lines, not circles)
4931
- for i in range(1, int((vmax-vmin)/ ((vmax-vmin)*grid_interval_ratio))+1):#int(vmax * grid_interval_ratio) + 1):
5048
+ for i in range(
5049
+ 1, int((vmax - vmin) / ((vmax - vmin) * grid_interval_ratio)) + 1
5050
+ ): # int(vmax * grid_interval_ratio) + 1):
4932
5051
  ax.plot(
4933
5052
  angles + [angles[0]], # Closing the loop
4934
5053
  [i * vmax * grid_interval_ratio] * (num_vars + 1)
@@ -4962,8 +5081,8 @@ def radar(
4962
5081
  ax.tick_params(axis="y", pad=sp) # move spines outward
4963
5082
  # colors
4964
5083
  if facecolor is not None:
4965
- if not isinstance(facecolor,list):
4966
- facecolor=[facecolor]
5084
+ if not isinstance(facecolor, list):
5085
+ facecolor = [facecolor]
4967
5086
  colors = facecolor
4968
5087
  else:
4969
5088
  colors = (
@@ -4987,44 +5106,44 @@ def radar(
4987
5106
  )
4988
5107
  ax.fill(angles, values, color=colors[i], alpha=alpha)
4989
5108
  # Add text labels for each value at each angle
4990
- labeled_points = set() #这样同一个点就不会标多次了
5109
+ labeled_points = set() # 这样同一个点就不会标多次了
4991
5110
  if show_value:
4992
5111
  for angle, value in zip(angles, values):
4993
5112
  if (angle, value) not in labeled_points:
4994
5113
  # offset_radius = value * value_offset
4995
5114
  lim_ = np.max(values)
4996
- sep_in = lim_/5
4997
- sep_low=sep_in*2
4998
- sep_med=sep_in*3
4999
- sep_hig=sep_in*4
5000
- sep_out=lim_*5
5001
- if value<sep_in:
5115
+ sep_in = lim_ / 5
5116
+ sep_low = sep_in * 2
5117
+ sep_med = sep_in * 3
5118
+ sep_hig = sep_in * 4
5119
+ sep_out = lim_ * 5
5120
+ if value < sep_in:
5002
5121
  offset_radius = value * 0.7
5003
- elif value<sep_low:
5122
+ elif value < sep_low:
5004
5123
  offset_radius = value * 0.8
5005
- elif sep_low<=value<sep_med:
5124
+ elif sep_low <= value < sep_med:
5006
5125
  offset_radius = value * 0.85
5007
- elif sep_med<=value<sep_hig:
5126
+ elif sep_med <= value < sep_hig:
5008
5127
  offset_radius = value * 0.9
5009
- elif sep_hig<=value<sep_out:
5128
+ elif sep_hig <= value < sep_out:
5010
5129
  offset_radius = value * 0.93
5011
5130
  else:
5012
5131
  offset_radius = value * 0.98
5013
5132
  ax.text(
5014
- angle,
5133
+ angle,
5015
5134
  offset_radius,
5016
- f"{value:{fmt}}",
5017
- ha="center",
5018
- va="center",
5019
- fontsize=fontsize,
5135
+ f"{value:{fmt}}",
5136
+ ha="center",
5137
+ va="center",
5138
+ fontsize=fontsize,
5020
5139
  color=fontcolor,
5021
- zorder=11
5140
+ zorder=11,
5022
5141
  )
5023
5142
  labeled_points.add((angle, value))
5024
5143
 
5025
5144
  ax.set_ylim(ylim)
5026
5145
  # Add markers for each data point
5027
- for i, (col, val) in enumerate(data.items()):
5146
+ for i, (col, val) in enumerate(data.items()):
5028
5147
  ax.plot(
5029
5148
  angles,
5030
5149
  list(val) + [val[0]], # Close the loop for markers
@@ -5055,15 +5174,15 @@ def radar(
5055
5174
 
5056
5175
 
5057
5176
  def pie(
5058
- data:pd.Series,
5059
- columns:list = None,
5177
+ data: pd.Series,
5178
+ columns: list = None,
5060
5179
  facecolor=None,
5061
5180
  explode=[0.1],
5062
5181
  startangle=90,
5063
5182
  shadow=True,
5064
5183
  fontcolor="k",
5065
- fmt=".2f",
5066
- width=None,# the center blank
5184
+ fmt=".2f",
5185
+ width=None, # the center blank
5067
5186
  pctdistance=0.85,
5068
5187
  labeldistance=1.1,
5069
5188
  kws_wedge={},
@@ -5077,9 +5196,9 @@ def pie(
5077
5196
  edgewidth=1,
5078
5197
  cmap=None,
5079
5198
  show_value=False,
5080
- show_label=True,# False: only show the outer layer, if it is None, not show
5081
- expand_label=(1.2,1.2),
5082
- kws_bbox={},#dict(facecolor="none", alpha=0.5, edgecolor="black", boxstyle="round,pad=0.3"), # '{}' to hide
5199
+ show_label=True, # False: only show the outer layer, if it is None, not show
5200
+ expand_label=(1.2, 1.2),
5201
+ kws_bbox={}, # dict(facecolor="none", alpha=0.5, edgecolor="black", boxstyle="round,pad=0.3"), # '{}' to hide
5083
5202
  show_legend=True,
5084
5203
  legend_loc="upper right",
5085
5204
  bbox_to_anchor=[1.4, 1.1],
@@ -5087,11 +5206,12 @@ def pie(
5087
5206
  rotation_correction=0,
5088
5207
  verbose=True,
5089
5208
  ax=None,
5090
- **kwargs
5091
- ):
5209
+ **kwargs,
5210
+ ):
5092
5211
  from adjustText import adjust_text
5093
- if run_once_within(20,reverse=True) and verbose:
5094
- usage_="""usage:
5212
+
5213
+ if run_once_within(20, reverse=True) and verbose:
5214
+ usage_ = """usage:
5095
5215
  pie(
5096
5216
  data:pd.Series,
5097
5217
  columns:list = None,
@@ -5190,45 +5310,45 @@ def pie(
5190
5310
  if isinstance(data, dict):
5191
5311
  data = pd.DataFrame(pd.Series(data))
5192
5312
  if ~isinstance(data, pd.DataFrame):
5193
- data=pd.DataFrame(data)
5313
+ data = pd.DataFrame(data)
5194
5314
 
5195
5315
  if isinstance(data, pd.DataFrame):
5196
- data=data.select_dtypes(include=np.number)
5197
- if isinstance(columns,str):
5198
- columns=[columns]
5316
+ data = data.select_dtypes(include=np.number)
5317
+ if isinstance(columns, str):
5318
+ columns = [columns]
5199
5319
  if columns is None:
5200
5320
  columns = list(data.columns)
5201
5321
  # data=data[columns]
5202
5322
  # columns = list(data.columns)
5203
5323
  # print(columns)
5204
5324
  # 选择部分数据
5205
- df=data[columns]
5325
+ df = data[columns]
5206
5326
 
5207
5327
  if not isinstance(explode, list):
5208
- explode=[explode]
5328
+ explode = [explode]
5209
5329
  if explode == [None]:
5210
- explode=[0]
5330
+ explode = [0]
5211
5331
 
5212
5332
  if width is None:
5213
- if df.shape[1]>1:
5214
- width=1/(df.shape[1]+2)
5333
+ if df.shape[1] > 1:
5334
+ width = 1 / (df.shape[1] + 2)
5215
5335
  else:
5216
- width=1
5217
- if isinstance(width,(float,int)):
5218
- width=[width]
5219
- if len(width)<df.shape[1]:
5220
- width=width*df.shape[1]
5221
- if isinstance(radius,(float,int)):
5222
- radius=[radius]
5223
- radius_tile=[1]*df.shape[1]
5224
- radius=radius_tile.copy()
5225
- for i in range(1,df.shape[1]):
5226
- radius[i]=radius_tile[i]-np.sum(width[:i])
5336
+ width = 1
5337
+ if isinstance(width, (float, int)):
5338
+ width = [width]
5339
+ if len(width) < df.shape[1]:
5340
+ width = width * df.shape[1]
5341
+ if isinstance(radius, (float, int)):
5342
+ radius = [radius]
5343
+ radius_tile = [1] * df.shape[1]
5344
+ radius = radius_tile.copy()
5345
+ for i in range(1, df.shape[1]):
5346
+ radius[i] = radius_tile[i] - np.sum(width[:i])
5227
5347
 
5228
5348
  # colors
5229
5349
  if facecolor is not None:
5230
- if not isinstance(facecolor,list):
5231
- facecolor=[facecolor]
5350
+ if not isinstance(facecolor, list):
5351
+ facecolor = [facecolor]
5232
5352
  colors = facecolor
5233
5353
  else:
5234
5354
  colors = (
@@ -5239,38 +5359,38 @@ def pie(
5239
5359
  # to check if facecolor is nested list or not
5240
5360
  is_nested = True if any(isinstance(i, list) for i in colors) else False
5241
5361
  inested = 0
5242
- for column_,width_,radius_ in zip(columns, width,radius):
5243
- if column_!=columns[0]:
5362
+ for column_, width_, radius_ in zip(columns, width, radius):
5363
+ if column_ != columns[0]:
5244
5364
  labels = data.index if show_label else None
5245
5365
  else:
5246
5366
  labels = data.index if show_label is not None else None
5247
5367
  data = df[column_]
5248
- labels_legend=data.index
5368
+ labels_legend = data.index
5249
5369
  sizes = data.values
5250
-
5370
+
5251
5371
  # Set wedge and text properties if none are provided
5252
5372
  kws_wedge = kws_wedge or {"edgecolor": edgecolor, "linewidth": edgewidth}
5253
- kws_wedge.update({"width":width_})
5254
- fontcolor=kws_text.get("color",fontcolor)
5255
- fontsize=kws_text.get("fontsize",fontsize)
5373
+ kws_wedge.update({"width": width_})
5374
+ fontcolor = kws_text.get("color", fontcolor)
5375
+ fontsize = kws_text.get("fontsize", fontsize)
5256
5376
  kws_text.update({"color": fontcolor, "fontsize": fontsize})
5257
5377
 
5258
- if ax is None:
5259
- ax=plt.gca()
5260
- if len(explode)<len(labels_legend):
5261
- explode.extend([0]*(len(labels_legend)-len(explode)))
5378
+ if ax is None:
5379
+ ax = plt.gca()
5380
+ if len(explode) < len(labels_legend):
5381
+ explode.extend([0] * (len(labels_legend) - len(explode)))
5262
5382
  print(explode)
5263
5383
  if fmt:
5264
5384
  if not fmt.startswith("%"):
5265
- autopct =f"%{fmt}%%"
5385
+ autopct = f"%{fmt}%%"
5266
5386
  else:
5267
- autopct=None
5387
+ autopct = None
5268
5388
 
5269
5389
  if show_value is None:
5270
5390
  result = ax.pie(
5271
5391
  sizes,
5272
5392
  labels=labels,
5273
- autopct= None,
5393
+ autopct=None,
5274
5394
  startangle=startangle + rotation_correction,
5275
5395
  explode=explode,
5276
5396
  colors=colors[inested] if is_nested else colors,
@@ -5282,7 +5402,7 @@ def pie(
5282
5402
  center=center,
5283
5403
  radius=radius_,
5284
5404
  frame=frame,
5285
- **kwargs
5405
+ **kwargs,
5286
5406
  )
5287
5407
  else:
5288
5408
  result = ax.pie(
@@ -5292,7 +5412,7 @@ def pie(
5292
5412
  startangle=startangle + rotation_correction,
5293
5413
  explode=explode,
5294
5414
  colors=colors[inested] if is_nested else colors,
5295
- shadow=shadow,#shadow,
5415
+ shadow=shadow, # shadow,
5296
5416
  pctdistance=pctdistance,
5297
5417
  labeldistance=labeldistance,
5298
5418
  wedgeprops=kws_wedge,
@@ -5300,13 +5420,13 @@ def pie(
5300
5420
  center=center,
5301
5421
  radius=radius_,
5302
5422
  frame=frame,
5303
- **kwargs
5423
+ **kwargs,
5304
5424
  )
5305
5425
  if len(result) == 3:
5306
5426
  wedges, texts, autotexts = result
5307
5427
  elif len(result) == 2:
5308
5428
  wedges, texts = result
5309
- autotexts = None
5429
+ autotexts = None
5310
5430
  #! adjust_text
5311
5431
  if autotexts or texts:
5312
5432
  all_texts = []
@@ -5314,30 +5434,38 @@ def pie(
5314
5434
  all_texts.extend(autotexts)
5315
5435
  if texts and show_label:
5316
5436
  all_texts.extend(texts)
5317
-
5437
+
5318
5438
  adjust_text(
5319
5439
  all_texts,
5320
5440
  ax=ax,
5321
- arrowprops=kws_arrow,#dict(arrowstyle="-", color="gray", lw=0.5),
5441
+ arrowprops=kws_arrow, # dict(arrowstyle="-", color="gray", lw=0.5),
5322
5442
  bbox=kws_bbox if kws_bbox else None,
5323
5443
  expand=expand_label,
5324
5444
  fontdict={
5325
- "fontsize": fontsize,
5326
- "color": fontcolor,
5327
- },
5445
+ "fontsize": fontsize,
5446
+ "color": fontcolor,
5447
+ },
5328
5448
  )
5329
5449
  # Show exact values on wedges if show_value is True
5330
5450
  if show_value:
5331
5451
  for i, (wedge, txt) in enumerate(zip(wedges, texts)):
5332
5452
  angle = (wedge.theta2 - wedge.theta1) / 2 + wedge.theta1
5333
- x = np.cos(np.radians(angle)) * (pctdistance ) * radius_
5334
- y = np.sin(np.radians(angle)) * (pctdistance ) * radius_
5453
+ x = np.cos(np.radians(angle)) * (pctdistance) * radius_
5454
+ y = np.sin(np.radians(angle)) * (pctdistance) * radius_
5335
5455
  if not fmt.startswith("{"):
5336
5456
  value_text = f"{sizes[i]:{fmt}}"
5337
5457
  else:
5338
- value_text = fmt.format(sizes[i])
5339
- ax.text(x, y, value_text, ha="center", va="center", fontsize=fontsize,color=fontcolor)
5340
- inested+=1
5458
+ value_text = fmt.format(sizes[i])
5459
+ ax.text(
5460
+ x,
5461
+ y,
5462
+ value_text,
5463
+ ha="center",
5464
+ va="center",
5465
+ fontsize=fontsize,
5466
+ color=fontcolor,
5467
+ )
5468
+ inested += 1
5341
5469
  # Customize the legend
5342
5470
  if show_legend:
5343
5471
  ax.legend(
@@ -5347,10 +5475,11 @@ def pie(
5347
5475
  bbox_to_anchor=bbox_to_anchor,
5348
5476
  fontsize=legend_fontsize,
5349
5477
  title_fontsize=legend_fontsize,
5350
- )
5478
+ )
5351
5479
  ax.set(aspect="equal")
5352
5480
  return ax
5353
5481
 
5482
+
5354
5483
  def ellipse(
5355
5484
  data,
5356
5485
  x=None,
@@ -5363,7 +5492,7 @@ def ellipse(
5363
5492
  palette=None,
5364
5493
  facecolor=None,
5365
5494
  edgecolor=None,
5366
- label:bool=True,
5495
+ label: bool = True,
5367
5496
  **kwargs,
5368
5497
  ):
5369
5498
  """
@@ -5438,8 +5567,8 @@ def ellipse(
5438
5567
  groups = [None]
5439
5568
  color_map = {None: kwargs.get("edgecolor", "blue")}
5440
5569
  alpha = kwargs.pop("alpha", 0.2)
5441
- edgecolor=kwargs.pop("edgecolor", None)
5442
- facecolor=kwargs.pop("facecolor", None)
5570
+ edgecolor = kwargs.pop("edgecolor", None)
5571
+ facecolor = kwargs.pop("facecolor", None)
5443
5572
  for group in groups:
5444
5573
  group_data = data[data[hue] == group] if hue else data
5445
5574
 
@@ -5477,7 +5606,7 @@ def ellipse(
5477
5606
  height=height,
5478
5607
  angle=angle,
5479
5608
  edgecolor=edgecolor_,
5480
- facecolor=(facecolor_, alpha), #facecolor_, # only work on facecolor
5609
+ facecolor=(facecolor_, alpha), # facecolor_, # only work on facecolor
5481
5610
  # alpha=alpha,
5482
5611
  label=group if (hue and label) else None,
5483
5612
  **kwargs,
@@ -5504,6 +5633,7 @@ def ellipse(
5504
5633
 
5505
5634
  return ax
5506
5635
 
5636
+
5507
5637
  def ppi(
5508
5638
  interactions,
5509
5639
  player1="preferredName_A",
@@ -5511,39 +5641,39 @@ def ppi(
5511
5641
  weight="score",
5512
5642
  n_layers=None, # Number of concentric layers
5513
5643
  n_rank=[5, 10], # Nodes in each rank for the concentric layout
5514
- dist_node = 10, # Distance between each rank of circles
5515
- layout="degree",
5516
- size=None,#700,
5517
- sizes=(50,500),# min and max of size
5644
+ dist_node=10, # Distance between each rank of circles
5645
+ layout="degree",
5646
+ size=None, # 700,
5647
+ sizes=(50, 500), # min and max of size
5518
5648
  facecolor="skyblue",
5519
- cmap='coolwarm',
5649
+ cmap="coolwarm",
5520
5650
  edgecolor="k",
5521
5651
  edgelinewidth=1.5,
5522
- alpha=.5,
5523
- alphas=(0.1, 1.0),# min and max of alpha
5652
+ alpha=0.5,
5653
+ alphas=(0.1, 1.0), # min and max of alpha
5524
5654
  marker="o",
5525
5655
  node_hideticks=True,
5526
5656
  linecolor="gray",
5527
- line_cmap='coolwarm',
5657
+ line_cmap="coolwarm",
5528
5658
  linewidth=1.5,
5529
- linewidths=(0.5,5),# min and max of linewidth
5659
+ linewidths=(0.5, 5), # min and max of linewidth
5530
5660
  linealpha=1.0,
5531
- linealphas=(0.1,1.0),# min and max of linealpha
5661
+ linealphas=(0.1, 1.0), # min and max of linealpha
5532
5662
  linestyle="-",
5533
- line_arrowstyle='-',
5663
+ line_arrowstyle="-",
5534
5664
  fontsize=10,
5535
5665
  fontcolor="k",
5536
- ha:str="center",
5537
- va:str="center",
5666
+ ha: str = "center",
5667
+ va: str = "center",
5538
5668
  figsize=(12, 10),
5539
- k_value=0.3,
5669
+ k_value=0.3,
5540
5670
  bgcolor="w",
5541
5671
  dir_save="./ppi_network.html",
5542
5672
  physics=True,
5543
5673
  notebook=False,
5544
5674
  scale=1,
5545
5675
  ax=None,
5546
- **kwargs
5676
+ **kwargs,
5547
5677
  ):
5548
5678
  """
5549
5679
  Plot a Protein-Protein Interaction (PPI) network with adjustable appearance.
@@ -5560,14 +5690,14 @@ def ppi(
5560
5690
  )
5561
5691
  """
5562
5692
  from pyvis.network import Network
5563
- import networkx as nx
5693
+ import networkx as nx
5564
5694
  from IPython.display import IFrame
5565
5695
  from matplotlib.colors import Normalize
5566
5696
  from matplotlib import cm
5567
5697
  from . import ips
5568
5698
 
5569
5699
  if run_once_within():
5570
- usage_str="""
5700
+ usage_str = """
5571
5701
  ppi(
5572
5702
  interactions,
5573
5703
  player1="preferredName_A",
@@ -5611,17 +5741,19 @@ def ppi(
5611
5741
  ):
5612
5742
  """
5613
5743
  print(usage_str)
5614
-
5744
+
5615
5745
  # Check for required columns in the DataFrame
5616
5746
  for col in [player1, player2, weight]:
5617
5747
  if col not in interactions.columns:
5618
- raise ValueError(f"Column '{col}' is missing from the interactions DataFrame.")
5748
+ raise ValueError(
5749
+ f"Column '{col}' is missing from the interactions DataFrame."
5750
+ )
5619
5751
  interactions.sort_values(by=[weight], inplace=True)
5620
5752
  # Initialize Pyvis network
5621
5753
  net = Network(height="750px", width="100%", bgcolor=bgcolor, font_color=fontcolor)
5622
5754
  net.force_atlas_2based(
5623
5755
  gravity=-50, central_gravity=0.01, spring_length=100, spring_strength=0.1
5624
- )
5756
+ )
5625
5757
  net.toggle_physics(physics)
5626
5758
 
5627
5759
  kws_figsets = {}
@@ -5637,47 +5769,62 @@ def ppi(
5637
5769
  G.add_edge(row[player1], row[player2], weight=row[weight])
5638
5770
  # G = nx.from_pandas_edgelist(interactions, source=player1, target=player2, edge_attr=weight)
5639
5771
 
5640
-
5641
5772
  # Calculate node degrees
5642
5773
  degrees = dict(G.degree())
5643
5774
  norm = Normalize(vmin=min(degrees.values()), vmax=max(degrees.values()))
5644
5775
  colormap = cm.get_cmap(cmap) # Get the 'coolwarm' colormap
5645
5776
 
5646
- if not ips.isa(facecolor, 'color'):
5777
+ if not ips.isa(facecolor, "color"):
5647
5778
  print("facecolor: based on degrees")
5648
5779
  facecolor = [colormap(norm(deg)) for deg in degrees.values()] # Use colormap
5649
5780
  num_nodes = G.number_of_nodes()
5650
- #* size
5781
+ # * size
5651
5782
  # Set properties based on degrees
5652
- if not isinstance(size, (int,float,list)):
5783
+ if not isinstance(size, (int, float, list)):
5653
5784
  print("size: based on degrees")
5654
5785
  size = [deg * 50 for deg in degrees.values()] # Scale sizes
5655
- size = (size[:num_nodes] if len(size) > num_nodes else size) if isinstance(size, list) else [size] * num_nodes
5656
- if isinstance(size, list) and len(ips.flatten(size,verbose=False))!=1:
5786
+ size = (
5787
+ (size[:num_nodes] if len(size) > num_nodes else size)
5788
+ if isinstance(size, list)
5789
+ else [size] * num_nodes
5790
+ )
5791
+ if isinstance(size, list) and len(ips.flatten(size, verbose=False)) != 1:
5657
5792
  # Normalize sizes
5658
5793
  min_size, max_size = sizes # Use sizes tuple for min and max values
5659
5794
  min_degree, max_degree = min(size), max(size)
5660
5795
  if max_degree > min_degree: # Avoid division by zero
5661
5796
  size = [
5662
- min_size + (max_size - min_size) * (sz - min_degree) / (max_degree - min_degree)
5797
+ min_size
5798
+ + (max_size - min_size) * (sz - min_degree) / (max_degree - min_degree)
5663
5799
  for sz in size
5664
5800
  ]
5665
5801
  else:
5666
5802
  # If all values are the same, set them to a default of the midpoint
5667
5803
  size = [(min_size + max_size) / 2] * len(size)
5668
5804
 
5669
- #* facecolor
5670
- facecolor = (facecolor[:num_nodes] if len(facecolor) > num_nodes else facecolor) if isinstance(facecolor, list) else [facecolor] * num_nodes
5805
+ # * facecolor
5806
+ facecolor = (
5807
+ (facecolor[:num_nodes] if len(facecolor) > num_nodes else facecolor)
5808
+ if isinstance(facecolor, list)
5809
+ else [facecolor] * num_nodes
5810
+ )
5671
5811
  # * facealpha
5672
5812
  if isinstance(alpha, list):
5673
- alpha = (alpha[:num_nodes] if len(alpha) > num_nodes else alpha + [alpha[-1]] * (num_nodes - len(alpha)))
5813
+ alpha = (
5814
+ alpha[:num_nodes]
5815
+ if len(alpha) > num_nodes
5816
+ else alpha + [alpha[-1]] * (num_nodes - len(alpha))
5817
+ )
5674
5818
  min_alphas, max_alphas = alphas # Use alphas tuple for min and max values
5675
5819
  if len(alpha) > 0:
5676
5820
  # Normalize alpha based on the specified min and max
5677
5821
  min_alpha, max_alpha = min(alpha), max(alpha)
5678
5822
  if max_alpha > min_alpha: # Avoid division by zero
5679
5823
  alpha = [
5680
- min_alphas + (max_alphas - min_alphas) * (ea - min_alpha) / (max_alpha - min_alpha)
5824
+ min_alphas
5825
+ + (max_alphas - min_alphas)
5826
+ * (ea - min_alpha)
5827
+ / (max_alpha - min_alpha)
5681
5828
  for ea in alpha
5682
5829
  ]
5683
5830
  else:
@@ -5685,7 +5832,7 @@ def ppi(
5685
5832
  alpha = [(min_alphas + max_alphas) / 2] * len(alpha)
5686
5833
  else:
5687
5834
  # Default to a full opacity if no edges are provided
5688
- alpha = [1.0] * num_nodes
5835
+ alpha = [1.0] * num_nodes
5689
5836
  else:
5690
5837
  # If alpha is a single value, convert it to a list and normalize it
5691
5838
  alpha = [alpha] * num_nodes # Adjust based on alphas
@@ -5699,7 +5846,7 @@ def ppi(
5699
5846
  alpha=alpha[i],
5700
5847
  font={"size": fontsize, "color": fontcolor},
5701
5848
  )
5702
- print(f'nodes number: {i+1}')
5849
+ print(f"nodes number: {i+1}")
5703
5850
 
5704
5851
  for edge in G.edges(data=True):
5705
5852
  net.add_edge(
@@ -5718,11 +5865,11 @@ def ppi(
5718
5865
  "shell",
5719
5866
  "planar",
5720
5867
  "spiral",
5721
- "degree"
5868
+ "degree",
5722
5869
  ]
5723
5870
  layout = ips.strcmp(layout, layouts)[0]
5724
5871
  print(f"layout:{layout}, or select one in {layouts}")
5725
-
5872
+
5726
5873
  # Choose layout
5727
5874
  if layout == "spring":
5728
5875
  pos = nx.spring_layout(G, k=k_value)
@@ -5744,24 +5891,26 @@ def ppi(
5744
5891
  pos = nx.spring_layout(G, k=k_value)
5745
5892
  elif layout == "spiral":
5746
5893
  pos = nx.spiral_layout(G)
5747
- elif layout=='degree':
5894
+ elif layout == "degree":
5748
5895
  # Calculate node degrees and sort nodes by degree
5749
5896
  degrees = dict(G.degree())
5750
5897
  sorted_nodes = sorted(degrees.items(), key=lambda x: x[1], reverse=True)
5751
5898
  norm = Normalize(vmin=min(degrees.values()), vmax=max(degrees.values()))
5752
5899
  colormap = cm.get_cmap(cmap)
5753
-
5900
+
5754
5901
  # Create positions for concentric circles based on n_layers and n_rank
5755
5902
  pos = {}
5756
- n_layers=len(n_rank)+1 if n_layers is None else n_layers
5903
+ n_layers = len(n_rank) + 1 if n_layers is None else n_layers
5757
5904
  for rank_index in range(n_layers):
5758
5905
  if rank_index < len(n_rank):
5759
5906
  nodes_per_rank = n_rank[rank_index]
5760
- rank_nodes = sorted_nodes[sum(n_rank[:rank_index]): sum(n_rank[:rank_index + 1])]
5907
+ rank_nodes = sorted_nodes[
5908
+ sum(n_rank[:rank_index]) : sum(n_rank[: rank_index + 1])
5909
+ ]
5761
5910
  else:
5762
5911
  # 随机打乱剩余节点的顺序
5763
- remaining_nodes = sorted_nodes[sum(n_rank[:rank_index]):]
5764
- random_indices = np.random.permutation(len(remaining_nodes))
5912
+ remaining_nodes = sorted_nodes[sum(n_rank[:rank_index]) :]
5913
+ random_indices = np.random.permutation(len(remaining_nodes))
5765
5914
  rank_nodes = [remaining_nodes[i] for i in random_indices]
5766
5915
 
5767
5916
  radius = (rank_index + 1) * dist_node # Radius for this rank
@@ -5772,7 +5921,9 @@ def ppi(
5772
5921
  pos[node] = (radius * np.cos(angle), radius * np.sin(angle))
5773
5922
 
5774
5923
  else:
5775
- print(f"Unknown layout '{layout}', defaulting to 'spring',or可以用这些: {layouts}")
5924
+ print(
5925
+ f"Unknown layout '{layout}', defaulting to 'spring',or可以用这些: {layouts}"
5926
+ )
5776
5927
  pos = nx.spring_layout(G, k=k_value)
5777
5928
 
5778
5929
  for node, (x, y) in pos.items():
@@ -5781,8 +5932,8 @@ def ppi(
5781
5932
 
5782
5933
  # If ax is None, use plt.gca()
5783
5934
  if ax is None:
5784
- fig, ax = plt.subplots(1,1,figsize=figsize)
5785
-
5935
+ fig, ax = plt.subplots(1, 1, figsize=figsize)
5936
+
5786
5937
  # Draw nodes, edges, and labels with customization options
5787
5938
  nx.draw_networkx_nodes(
5788
5939
  G,
@@ -5794,14 +5945,18 @@ def ppi(
5794
5945
  edgecolors=edgecolor,
5795
5946
  alpha=alpha,
5796
5947
  hide_ticks=node_hideticks,
5797
- node_shape=marker
5948
+ node_shape=marker,
5798
5949
  )
5799
5950
 
5800
- #* linewidth
5951
+ # * linewidth
5801
5952
  if not isinstance(linewidth, list):
5802
5953
  linewidth = [linewidth] * G.number_of_edges()
5803
5954
  else:
5804
- linewidth = (linewidth[:G.number_of_edges()] if len(linewidth) > G.number_of_edges() else linewidth + [linewidth[-1]] * (G.number_of_edges() - len(linewidth)))
5955
+ linewidth = (
5956
+ linewidth[: G.number_of_edges()]
5957
+ if len(linewidth) > G.number_of_edges()
5958
+ else linewidth + [linewidth[-1]] * (G.number_of_edges() - len(linewidth))
5959
+ )
5805
5960
  # Normalize linewidth if it is a list
5806
5961
  if isinstance(linewidth, list):
5807
5962
  min_linewidth, max_linewidth = min(linewidth), max(linewidth)
@@ -5809,7 +5964,10 @@ def ppi(
5809
5964
  if max_linewidth > min_linewidth: # Avoid division by zero
5810
5965
  # Scale between vmin and vmax
5811
5966
  linewidth = [
5812
- vmin + (vmax - vmin) * (lw - min_linewidth) / (max_linewidth - min_linewidth)
5967
+ vmin
5968
+ + (vmax - vmin)
5969
+ * (lw - min_linewidth)
5970
+ / (max_linewidth - min_linewidth)
5813
5971
  for lw in linewidth
5814
5972
  ]
5815
5973
  else:
@@ -5818,8 +5976,8 @@ def ppi(
5818
5976
  else:
5819
5977
  # If linewidth is a single value, convert it to a list of that value
5820
5978
  linewidth = [linewidth] * G.number_of_edges()
5821
- #* linecolor
5822
- if not isinstance(linecolor, str):
5979
+ # * linecolor
5980
+ if not isinstance(linecolor, str):
5823
5981
  weights = [G[u][v]["weight"] for u, v in G.edges()]
5824
5982
  norm = Normalize(vmin=min(weights), vmax=max(weights))
5825
5983
  colormap = cm.get_cmap(line_cmap)
@@ -5829,45 +5987,58 @@ def ppi(
5829
5987
 
5830
5988
  # * linealpha
5831
5989
  if isinstance(linealpha, list):
5832
- linealpha = (linealpha[:G.number_of_edges()] if len(linealpha) > G.number_of_edges() else linealpha + [linealpha[-1]] * (G.number_of_edges() - len(linealpha)))
5990
+ linealpha = (
5991
+ linealpha[: G.number_of_edges()]
5992
+ if len(linealpha) > G.number_of_edges()
5993
+ else linealpha + [linealpha[-1]] * (G.number_of_edges() - len(linealpha))
5994
+ )
5833
5995
  min_alpha, max_alpha = linealphas # Use linealphas tuple for min and max values
5834
5996
  if len(linealpha) > 0:
5835
5997
  min_linealpha, max_linealpha = min(linealpha), max(linealpha)
5836
5998
  if max_linealpha > min_linealpha: # Avoid division by zero
5837
5999
  linealpha = [
5838
- min_alpha + (max_alpha - min_alpha) * (ea - min_linealpha) / (max_linealpha - min_linealpha)
6000
+ min_alpha
6001
+ + (max_alpha - min_alpha)
6002
+ * (ea - min_linealpha)
6003
+ / (max_linealpha - min_linealpha)
5839
6004
  for ea in linealpha
5840
6005
  ]
5841
6006
  else:
5842
6007
  linealpha = [(min_alpha + max_alpha) / 2] * len(linealpha)
5843
6008
  else:
5844
- linealpha = [1.0] * G.number_of_edges() # 如果设置有误,则将它设置成1.0
6009
+ linealpha = [1.0] * G.number_of_edges() # 如果设置有误,则将它设置成1.0
5845
6010
  else:
5846
6011
  linealpha = [linealpha] * G.number_of_edges() # Convert to list if single value
5847
6012
  nx.draw_networkx_edges(
5848
- G,
5849
- pos,
5850
- ax=ax,
5851
- edge_color=linecolor,
6013
+ G,
6014
+ pos,
6015
+ ax=ax,
6016
+ edge_color=linecolor,
5852
6017
  width=linewidth,
5853
6018
  style=linestyle,
5854
- arrowstyle=line_arrowstyle,
5855
- alpha=linealpha
6019
+ arrowstyle=line_arrowstyle,
6020
+ alpha=linealpha,
5856
6021
  )
5857
-
6022
+
5858
6023
  nx.draw_networkx_labels(
5859
- G, pos, ax=ax, font_size=fontsize, font_color=fontcolor,horizontalalignment=ha,verticalalignment=va
6024
+ G,
6025
+ pos,
6026
+ ax=ax,
6027
+ font_size=fontsize,
6028
+ font_color=fontcolor,
6029
+ horizontalalignment=ha,
6030
+ verticalalignment=va,
5860
6031
  )
5861
- figsets(ax=ax,**kws_figsets)
6032
+ figsets(ax=ax, **kws_figsets)
5862
6033
  ax.axis("off")
5863
6034
  if dir_save:
5864
6035
  if not os.path.basename(dir_save):
5865
- dir_save="_.html"
6036
+ dir_save = "_.html"
5866
6037
  net.write_html(dir_save)
5867
- nx.write_graphml(G, dir_save.replace(".html",".graphml")) # Export to GraphML
6038
+ nx.write_graphml(G, dir_save.replace(".html", ".graphml")) # Export to GraphML
5868
6039
  print(f"could be edited in Cytoscape \n{dir_save.replace(".html",".graphml")}")
5869
- ips.figsave(dir_save.replace(".html",".pdf"))
5870
- return G,ax
6040
+ ips.figsave(dir_save.replace(".html", ".pdf"))
6041
+ return G, ax
5871
6042
 
5872
6043
 
5873
6044
  def plot_map(
@@ -5886,6 +6057,7 @@ def plot_map(
5886
6057
  save_path=None, # Path to save the map in offline mode
5887
6058
  pydeck_map=False, # Whether to use pydeck for rendering (True for pydeck)
5888
6059
  pydeck_style="mapbox://styles/mapbox/streets-v11", # Map style for pydeck
6060
+ verbose=True, # show usage
5889
6061
  **kwargs, # Additional arguments for Folium Map
5890
6062
  ):
5891
6063
  """
@@ -5899,19 +6071,88 @@ def plot_map(
5899
6071
  df_tiles = pd.DataFrame({"tiles": tiles_support})
5900
6072
  fsave("....tiles.csv",df_tiles)
5901
6073
  """
6074
+ config_markers = """from folium import Icon
6075
+ # https://github.com/lennardv2/Leaflet.awesome-markers?tab=readme-ov-file
6076
+ markers = [
6077
+ {
6078
+ "location": [loc[0], loc[1]],
6079
+ "popup": "Center City",
6080
+ "tooltip": "Philadelphia",
6081
+ "icon": Icon(color="red", icon="flag"),
6082
+ },
6083
+ {
6084
+ "location": [loc[0], loc[1] + 0.05],
6085
+ "popup": "Rittenhouse Square",
6086
+ "tooltip": "A lovely park",
6087
+ "icon": Icon(
6088
+ color="purple", icon="flag", prefix="fa"
6089
+ ), # Purple marker with "star" icon (Font Awesome)
6090
+ },
6091
+ ]"""
6092
+ config_overlay = """
6093
+ from folium import Circle
6094
+
6095
+ circle = Circle(
6096
+ location=loc,
6097
+ radius=300, # In meters
6098
+ color="#EB686C",
6099
+ fill=True,
6100
+ fill_opacity=0.2,
6101
+ )
6102
+ markers = [
6103
+ {
6104
+ "location": [loc[0], loc[1]],
6105
+ "popup": "Center City",
6106
+ "tooltip": "Philadelphia",
6107
+ },
6108
+ {
6109
+ "location": [loc[0], loc[1] + 0.05],
6110
+ "popup": "Rittenhouse Square",
6111
+ "tooltip": "A lovely park",
6112
+ },
6113
+ ]
6114
+ plot_map(loc, overlays=[circle], zoom_start=14)
6115
+ """
6116
+ config_plugin = """
6117
+ from folium.plugins import HeatMap
6118
+ heat_data = [
6119
+ [48.54440975, 9.060237673391708, 1],
6120
+ [48.5421456, 9.057464182487431, 1],
6121
+ [48.54539175, 9.059915422200906, 1],
6122
+ ]
6123
+ heatmap = HeatMap(
6124
+ heat_data,
6125
+ radius=5, # Increase the radius of each point
6126
+ blur=5, # Adjust the blurring effect
6127
+ min_opacity=0.4, # Make the heatmap semi-transparent
6128
+ max_zoom=16, # Zoom level at which points appear
6129
+ gradient={ # Define a custom gradient
6130
+ 0.2: "blue",
6131
+ 0.4: "lime",
6132
+ 0.6: "yellow",
6133
+ 1.0: "#A34B00",
6134
+ },
6135
+ )
6136
+
6137
+ plot_map(loc, plugins=[heatmap])
6138
+ """
5902
6139
  from pathlib import Path
5903
6140
 
5904
6141
  # Get the current script's directory as a Path object
5905
6142
  current_directory = Path(__file__).resolve().parent
5906
6143
  if not "tiles_support" in locals():
5907
- tiles_support = fload(current_directory / "data" / "tiles.csv", verbose=0).iloc[:, 1].tolist()
5908
- tiles=strcmp(tiles, tiles_support)[0]
6144
+ tiles_support = (
6145
+ fload(current_directory / "data" / "tiles.csv", verbose=0)
6146
+ .iloc[:, 1]
6147
+ .tolist()
6148
+ )
6149
+ tiles = strcmp(tiles, tiles_support)[0]
5909
6150
  import folium
5910
6151
  import streamlit as st
5911
6152
  import pydeck as pdk
5912
6153
  from streamlit_folium import st_folium
5913
6154
  from folium.plugins import HeatMap
5914
-
6155
+
5915
6156
  if pydeck_map:
5916
6157
  view = pdk.ViewState(
5917
6158
  latitude=location[0],
@@ -5920,7 +6161,6 @@ def plot_map(
5920
6161
  pitch=0,
5921
6162
  )
5922
6163
 
5923
- # Example Layer (can be replaced by your custom layers)
5924
6164
  layer = pdk.Layer(
5925
6165
  "ScatterplotLayer",
5926
6166
  data=[{"lat": location[0], "lon": location[1]}],
@@ -5929,20 +6169,16 @@ def plot_map(
5929
6169
  get_radius=1000,
5930
6170
  )
5931
6171
 
5932
- # Create the deck
5933
6172
  deck = pdk.Deck(
5934
6173
  layers=[layer],
5935
6174
  initial_view_state=view,
5936
6175
  map_style=pydeck_style,
5937
6176
  )
5938
-
5939
- # Render map in Streamlit
5940
6177
  st.pydeck_chart(deck)
5941
6178
 
5942
6179
  return deck # Return the pydeck map
5943
6180
 
5944
6181
  else:
5945
- # Initialize the base map (Folium)
5946
6182
  m = folium.Map(
5947
6183
  location=location,
5948
6184
  zoom_start=zoom_start,
@@ -5950,39 +6186,39 @@ def plot_map(
5950
6186
  scrollWheelZoom=scroll_wheel_zoom,
5951
6187
  **kwargs,
5952
6188
  )
5953
-
5954
- # Add markers
5955
6189
  if markers:
6190
+ if verbose:
6191
+ print(config_markers)
5956
6192
  for marker in markers:
5957
6193
  folium.Marker(
5958
6194
  location=marker.get("location"),
5959
6195
  popup=marker.get("popup"),
5960
6196
  tooltip=marker.get("tooltip"),
5961
- icon=marker.get("icon", folium.Icon()), # Default icon if none specified
6197
+ icon=marker.get(
6198
+ "icon", folium.Icon()
6199
+ ), # Default icon if none specified
5962
6200
  ).add_to(m)
5963
6201
 
5964
- # Add overlays
5965
6202
  if overlays:
6203
+ if verbose:
6204
+ print(config_overlay)
5966
6205
  for overlay in overlays:
5967
6206
  overlay.add_to(m)
5968
6207
 
5969
- # Add custom layers
5970
6208
  if custom_layers:
5971
6209
  for layer in custom_layers:
5972
6210
  layer.add_to(m)
5973
6211
 
5974
- # Add plugins
5975
6212
  if plugins:
6213
+ if verbose:
6214
+ print(config_plugin)
5976
6215
  for plugin in plugins:
5977
6216
  plugin.add_to(m)
5978
6217
 
5979
- # Fit map bounds
5980
6218
  if fit_bounds:
5981
6219
  m.fit_bounds(fit_bounds)
5982
6220
 
5983
- # Handle rendering based on output
5984
6221
  if output == "streamlit":
5985
- # Render the map in Streamlit
5986
6222
  st_data = st_folium(m, width=map_width, height=map_height)
5987
6223
  return st_data
5988
6224
  elif output == "offline":