py2ls 0.2.4.5__py3-none-any.whl → 0.2.4.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
py2ls/bio.py CHANGED
@@ -814,7 +814,11 @@ def counts_deseq(counts_sam_gene: pd.DataFrame,
814
814
  # .reset_index()
815
815
  # .rename(columns={"index": "gene"})
816
816
  # )
817
- return dds, diff,stat_res
817
+ df_norm=pd.DataFrame(dds.layers['normed_counts'])
818
+ df_norm.index=counts_sam_gene.index
819
+ df_norm.columns=counts_sam_gene.columns
820
+ print("res[0]: dds\nres[1]:diff\nres[2]:stat_res\nres[3]:df_normalized")
821
+ return dds, diff, stat_res,df_norm
818
822
 
819
823
  def scope_genes(gene_list: list, scopes:str=None, fields: str = "symbol", species="human"):
820
824
  """
@@ -1005,25 +1009,25 @@ def plot_enrichr(results_df,
1005
1009
  palette = plot.get_color(n_top, cmap=cmap)[::-1]
1006
1010
  elif isinstance(cmap,list):
1007
1011
  palette=cmap
1008
-
1009
- if n_top<5:
1010
- height_=4
1011
- elif 5<=n_top<10:
1012
- height_=5
1013
- elif 5<=n_top<10:
1014
- height_=6
1015
- elif 10<=n_top<15:
1016
- height_=7
1017
- elif 15<=n_top<20:
1018
- height_=8
1019
- elif 20<=n_top<30:
1020
- height_=9
1012
+ if n_top < 5:
1013
+ height_ = 3
1014
+ elif 5 <= n_top < 10:
1015
+ height_ = 3
1016
+ elif 10 <= n_top < 15:
1017
+ height_ = 3
1018
+ elif 15 <= n_top < 20:
1019
+ height_ =4
1020
+ elif 20 <= n_top < 30:
1021
+ height_ = 5
1022
+ elif 30 <= n_top < 40:
1023
+ height_ = int(n_top / 6)
1021
1024
  else:
1022
- height_=int(n_top/3)
1023
- if ax is None:
1024
- _,ax=plt.subplots(1,1,figsize=[10, height_])
1025
+ height_ = int(n_top / 8)
1026
+
1025
1027
  #! barplot
1026
1028
  if 'bar' in kind.lower():
1029
+ if ax is None:
1030
+ _,ax=plt.subplots(1,1,figsize=[10, height_])
1027
1031
  ax=plot.plotxy(
1028
1032
  data=results_df.head(n_top),
1029
1033
  kind="barplot",
@@ -1035,6 +1039,7 @@ def plot_enrichr(results_df,
1035
1039
  )
1036
1040
  plot.figsets(ax=ax, **kws_figsets)
1037
1041
  return ax,results_df
1042
+
1038
1043
  #! dotplot
1039
1044
  elif 'dot' in kind.lower():
1040
1045
  #! dotplot
@@ -1066,16 +1071,20 @@ def plot_enrichr(results_df,
1066
1071
  print(f"Warning: trying cutoff={cutoff_curr}, cutoff={cutoff_curr-step} failed: {e} ")
1067
1072
  plot.figsets(ax=ax, **kws_figsets)
1068
1073
  return ax,results_df
1074
+
1069
1075
  #! barplot with counts
1070
1076
  elif 'count' in kind.lower():
1077
+ if ax is None:
1078
+ _,ax=plt.subplots(1,1,figsize=[10, height_])
1071
1079
  # 从overlap中提取出个数
1072
- results_df["count"] = results_df["Overlap"].apply(
1080
+ results_df["Count"] = results_df["Overlap"].apply(
1073
1081
  lambda x: int(x.split("/")[0]) if isinstance(x, str) else x)
1074
- df_=results_df.sort_values(by="count", ascending=False)
1082
+ df_=results_df.sort_values(by="Count", ascending=False)
1083
+
1075
1084
  ax=plot.plotxy(
1076
1085
  data=df_.head(n_top),
1077
1086
  kind="barplot",
1078
- x="count",
1087
+ x="Count",
1079
1088
  y="Term",
1080
1089
  hue="Term",
1081
1090
  palette=palette,
@@ -1094,11 +1103,12 @@ def plot_bp_cc_mf(
1094
1103
  "GO_Molecular_Function_2023",
1095
1104
  ],
1096
1105
  species="human",
1106
+ download=False,
1097
1107
  n_top=10,
1098
1108
  plot_=True,
1099
1109
  ax=None,
1100
1110
  palette=plot.get_color(3),
1101
- ** kwargs,
1111
+ **kwargs,
1102
1112
  ):
1103
1113
 
1104
1114
  def res_enrichr_2_count(res_enrichr, n_top=10):
@@ -1111,13 +1121,13 @@ def plot_bp_cc_mf(
1111
1121
  return res_enrichr.head(n_top)#[["Term", "Count"]]
1112
1122
 
1113
1123
  res_enrichr_BP = get_enrichr(
1114
- deg_gene_list, gene_sets[0], species=species, plot_=False
1124
+ deg_gene_list, gene_sets[0], species=species, plot_=False,download=download
1115
1125
  )
1116
1126
  res_enrichr_CC = get_enrichr(
1117
- deg_gene_list, gene_sets[1], species=species, plot_=False
1127
+ deg_gene_list, gene_sets[1], species=species, plot_=False,download=download
1118
1128
  )
1119
1129
  res_enrichr_MF = get_enrichr(
1120
- deg_gene_list, gene_sets[2], species=species, plot_=False
1130
+ deg_gene_list, gene_sets[2], species=species, plot_=False,download=download
1121
1131
  )
1122
1132
 
1123
1133
  df_BP = res_enrichr_2_count(res_enrichr_BP, n_top=n_top)
@@ -1149,6 +1159,7 @@ def plot_bp_cc_mf(
1149
1159
  if ax is None:
1150
1160
  _,ax=plt.subplots(1,1,figsize=[10, height_])
1151
1161
  # 作图
1162
+ display(df2plot)
1152
1163
  if df2plot["Term"].tolist()[0].endswith(")"):
1153
1164
  df2plot["Term"] = df2plot["Term"].apply(lambda x: x.split("(")[0][:-1])
1154
1165
  if plot_:
@@ -1164,8 +1175,15 @@ def plot_bp_cc_mf(
1164
1175
  )
1165
1176
  return ax, df2plot
1166
1177
 
1167
def get_library_name(by=None, verbose=False):
    """
    List the Enrichr gene-set libraries known to gseapy, or fuzzy-search them.

    Parameters
    ----------
    by : str, optional
        Query string. When None, every library name from
        ``gp.get_library_name()`` is returned. Otherwise the names are ranked
        against ``by`` with ``ips.strcmp(..., get_rank=True)`` and the result
        is flattened with ``ips.flatten``.
    verbose : bool
        When True and ``by`` is None, print each library name on its own line.
        Also forwarded to ``ips.strcmp`` / ``ips.flatten``.

    Returns
    -------
    The full list of library names (``by is None``), or the ranked matches
    as returned by ``ips.flatten``.
    """
    lib_names = gp.get_library_name()
    if by is not None:
        ranked = ips.strcmp(by, lib_names, get_rank=True, verbose=verbose)
        return ips.flatten(ranked, verbose=verbose)
    if verbose:
        # printing is a side effect, so use a plain loop instead of building
        # a throw-away list of None values
        for name in lib_names:
            print(name)
    return lib_names
1169
1187
 
1170
1188
  def get_gsva(
1171
1189
  data_gene_samples: pd.DataFrame, # index(gene),columns(samples)
@@ -1398,7 +1416,385 @@ def plot_gsva(gsva_res, # output from bio.get_gsva()
1398
1416
  plot.figsets(ax=ax, **kws_figsets)
1399
1417
  return ax
1400
1418
 
1419
def get_prerank(
    rnk: pd.DataFrame,
    gene_sets: str,
    download: bool = False,
    species="Human",
    threads=8,  # Number of CPU cores to use
    permutation_num=1000,  # Number of permutations for significance
    min_size=1,  # Minimum gene set size
    max_size=2000,  # Maximum gene set size
    seed=1,  # Seed for reproducibility
    verbose=True,  # Verbosity
    dir_save="./",
    plot_=False,
    size=5,
    cutoff=0.25,
    show_ring=False,
    cmap="coolwarm",
    check_shared=True,
    **kwargs,
):
    """
    Run a GSEA pre-ranked analysis (``gp.prerank``) and optionally save plots.

    Unlike Enrichr, which takes an unordered gene list, prerank takes genes
    together with a ranking metric.

    Parameters
    ----------
    rnk : pd.DataFrame
        Ranked gene list, passed unchanged to ``gp.prerank(rnk=...)``.
        NOTE(review): presumably (gene, score) columns - confirm against the
        installed gseapy version.
    gene_sets : str or dict
        One of:
        * a path to a gene-set file, loaded with ``ips.fload``;
        * an Enrichr library name, fuzzy-matched against
          ``gp.get_library_name()``;
        * an already-built mapping (e.g. ``{term: [genes, ...]}``), used as-is.
    download : bool
        Library names only: fetch the library with ``gp.get_library`` first,
        instead of handing the name to ``gp.prerank``.
    species : str
        Fuzzy-matched to one of Human/Mouse/Yeast/Fly/Fish/Worm. Used as the
        organism when downloading a library.
    threads, permutation_num, min_size, max_size, seed, verbose
        Forwarded to ``gp.prerank``.
    dir_save : str
        Prefix (normally a directory ending in "/") for the saved PDFs when
        ``plot_`` is True.
    plot_ : bool
        If True, save a GSEA plot (top 7 terms), a dot plot and a network plot.
    size, cutoff, show_ring, cmap
        Forwarded to ``gseapy.dotplot``; ``cutoff`` applies to "NOM p-val".
    check_shared : bool
        Currently unused; kept for backward compatibility.
    **kwargs
        Currently unused; kept for backward compatibility.

    Returns
    -------
    pd.DataFrame or None
        ``pre_res.res2d`` from ``gp.prerank``, or None if it raised ValueError.
    """
    species_org = species
    # organism (str) - select one from {'Human', 'Mouse', 'Yeast', 'Fly', 'Fish', 'Worm'}
    organisms = ["Human", "Mouse", "Yeast", "Fly", "Fish", "Worm"]
    species = ips.strcmp(species, organisms)[0]
    if species_org.lower() != species.lower():
        print(f"species was corrected to {species}, becasue only support {organisms}")

    # Resolve gene_sets into what gp.prerank receives, plus a short readable
    # name used in file names and titles (never the gene-set object itself).
    if isinstance(gene_sets, dict):
        gene_sets_name = "custom_gene_sets"
    elif os.path.isfile(gene_sets):
        gene_sets_name = os.path.basename(gene_sets)
        gene_sets = ips.fload(gene_sets)
    else:
        lib_support_names = gp.get_library_name()
        # correct (fuzzy-match) the requested library name
        gene_sets_name = ips.strcmp(gene_sets, lib_support_names)[0]
        if download:
            gene_sets = gp.get_library(name=gene_sets_name, organism=species)
        else:
            # pass only the name and let gseapy fetch it (avoids a second download)
            gene_sets = gene_sets_name
    print(f"\ngene_sets get ready: {gene_sets_name}")

    #! prerank
    try:
        pre_res = gp.prerank(
            rnk=rnk,
            gene_sets=gene_sets,
            threads=threads,
            permutation_num=permutation_num,
            min_size=min_size,
            max_size=max_size,
            seed=seed,
            verbose=verbose,
        )
    except ValueError as e:
        print(f"\n{'!'*10} Error {'!'*10}\n{' '*4}{e}\n{'!'*10} Error {'!'*10}")
        return None
    df_prerank = pre_res.res2d

    if plot_:
        #! gseaplot of the leading terms
        terms = df_prerank.Term
        pre_res.plot(
            terms=terms[:7].tolist(),
            figsize=(3, 4),
        )
        ips.figsave(dir_save + f"prerank_gseaplot_{gene_sets_name}.pdf")

        #! dotplot
        from gseapy import dotplot

        dotplot(
            df_prerank,
            column="NOM p-val",  # alternatively "FDR q-val"
            cmap=cmap,
            size=size,
            figsize=(10, 5),
            cutoff=cutoff,
            show_ring=show_ring,
        )
        ips.figsave(dir_save + f"prerank_dotplot_{gene_sets_name}.pdf")

        #! network plot
        from gseapy import enrichment_map
        import networkx as nx

        # Small top_term values can give a graph whose nodes do not match
        # `nodes` one-to-one, so widen the selection until a plot succeeds.
        for top_term in range(5, 50):
            fig = None
            try:
                # returns two DataFrames: term nodes and term-term edges
                nodes, edges = enrichment_map(
                    df=df_prerank,
                    columns="FDR q-val",
                    cutoff=0.25,  # 0.25 when "FDR q-val"; 0.05 when "NOM p-val"
                    top_term=top_term,
                )
                G = nx.from_pandas_edgelist(
                    edges,
                    source="src_idx",
                    target="targ_idx",
                    edge_attr=["jaccard_coef", "overlap_coef", "overlap_genes"],
                )
                # Terms sharing no edge are absent from G; retry in that case.
                if len(nodes) != G.number_of_nodes():
                    raise ValueError(
                        "The size of node_size list does not match the number of nodes in the graph."
                    )
                # networkx consumes node_color / node_size positionally in
                # G.nodes order (edge-list order), not in `nodes` order, so
                # align the rows to the graph. Edge endpoints are `nodes`
                # index labels, which the label mapping below also relies on.
                nodes = nodes.loc[list(G.nodes)]

                fig, ax = plt.subplots(figsize=(8, 8))
                pos = nx.circular_layout(G)
                nx.draw_networkx_nodes(
                    G,
                    pos=pos,
                    cmap=plt.cm.RdYlBu,
                    node_color=list(nodes.NES),
                    node_size=list(nodes.Hits_ratio * 1000),
                )
                nx.draw_networkx_labels(
                    G,
                    pos=pos,
                    labels=nodes.Term.to_dict(),
                    font_size=8,
                    verticalalignment="bottom",
                )
                # get_edge_attributes iterates G.edges, the same order
                # draw_networkx_edges uses by default
                edge_weight = nx.get_edge_attributes(G, "jaccard_coef").values()
                nx.draw_networkx_edges(
                    G,
                    pos=pos,
                    width=[w * 10 for w in edge_weight],
                    edge_color="#CDDBD4",
                )
                ax.set_axis_off()
                print(f"{gene_sets_name}(top_term={top_term})")
                plot.figsets(title=f"{gene_sets_name}(top_term={top_term})")
                ips.figsave(dir_save + f"prerank_network_{gene_sets_name}.pdf")
                break
            except Exception:
                # best effort: drop the half-drawn figure and try a larger top_term
                if fig is not None:
                    plt.close(fig)
                print(f"not work {top_term}")
    return df_prerank
1587
def plot_prerank(
    results_df,
    kind="bar",  # 'barplot', 'dotplot', 'count', 'scatter', 'network'
    cutoff=0.25,
    show_ring=False,
    xticklabels_rot=0,
    title=None,  # 'KEGG'
    cmap="coolwarm",
    n_top=10,
    size=5,  # when size is None in network, by "NES"
    facecolor=None,  # default by "NES"
    linewidth=None,  # default by "NES"
    linecolor=None,  # default by "NES"
    linealpha=None,  # default by "NES"
    alpha=None,  # default by "NES"
    ax=None,
    **kwargs,
):
    """
    Plot GSEA prerank results (e.g. the DataFrame returned by ``get_prerank``).

    ``kind`` is matched by substring, checked in this order:

    * "bar": bars of -log10("FDR q-val") for the top ``n_top`` terms.
    * "dot": ``gp.dotplot`` on "NOM p-val". ``cutoff`` is relaxed in steps of
      0.05 up to 0.5.
    * "co": bars of the leading-edge gene count ("Count").
    * "sca": scatter of "Count" per term.
    * "net": enrichment map (``enrichment_map`` on "NOM p-val") drawn with
      ``plot_ppi``. ``cutoff`` is relaxed by 0.05 up to 1.0 until at least
      ``n_top`` terms are kept; ``cutoff`` None or >= 1 disables filtering.
      Node/edge styling left as None defaults to NES-derived values.

    The input frame is not modified. The derived columns
    "-log10(Adjusted P-value)" and "Count" are added to a copy. FDR values of
    0 are floored at the smallest positive FDR in the table so the -log10
    value stays finite.

    A keyword whose name contains "figset" is removed from ``kwargs`` and
    forwarded to ``plot.figsets``. The remaining kwargs go to ``plot.plotxy``
    ("co", "sca") or ``plot_ppi`` ("net").

    Returns
    -------
    (ax, df) for "bar", "dot", "co" and "sca"; (ax, G, nodes, edges) for
    "net"; None when ``kind`` matches none of them.

    Raises
    ------
    RuntimeError
        In "net" mode, when every ``enrichment_map`` attempt failed.
    """
    from matplotlib.colors import Normalize

    figset_key = next((k for k in kwargs if "figset" in k), None)
    kws_figsets = kwargs.pop(figset_key) if figset_key is not None else {}

    palette = plot.get_color(n_top, cmap=cmap)[::-1] if isinstance(cmap, str) else cmap

    if n_top < 5:
        height_ = 4
    elif 5 <= n_top < 10:
        height_ = 5
    elif 10 <= n_top < 15:
        height_ = 6
    elif 15 <= n_top < 20:
        height_ = 7
    elif 20 <= n_top < 30:
        height_ = 8
    elif 30 <= n_top < 40:
        height_ = int(n_top / 5)
    else:
        height_ = int(n_top / 6)

    # work on a copy so the caller's frame is left untouched
    results_df = results_df.copy()
    fdr = pd.to_numeric(results_df["FDR q-val"], errors="coerce")
    positive_fdr = fdr[fdr > 0]
    # prerank often reports FDR == 0; -log10(0) is inf and breaks bar plots
    fdr_floor = positive_fdr.min() if not positive_fdr.empty else np.finfo(float).tiny
    results_df["-log10(Adjusted P-value)"] = -np.log10(fdr.clip(lower=fdr_floor))
    # leading-edge gene count; non-string entries (e.g. NaN) count as 0
    results_df["Count"] = results_df["Lead_genes"].apply(
        lambda genes: len(genes.split(";")) if isinstance(genes, str) else 0
    )

    #! barplot
    if "bar" in kind.lower():
        df_ = results_df.sort_values(by="-log10(Adjusted P-value)", ascending=False)
        if ax is None:
            _, ax = plt.subplots(1, 1, figsize=[10, height_])
        ax = plot.plotxy(
            data=df_.head(n_top),
            kind="barplot",
            x="-log10(Adjusted P-value)",
            y="Term",
            hue="Term",
            palette=palette,
            legend=None,
            ax=ax,
        )
        plot.figsets(ax=ax, **kws_figsets)
        return ax, df_

    #! dotplot
    elif "dot" in kind.lower():
        # Relax the cutoff until the plot holds enough collections.
        # NOTE(review): gp.dotplot may draw all dots in a single collection,
        # in which case every cutoff up to cutoff_stop is tried - confirm.
        cutoff_curr = cutoff
        step = 0.05
        cutoff_stop = 0.5
        while cutoff_curr <= cutoff_stop:
            try:
                if cutoff_curr != cutoff:
                    plt.clf()
                ax = gp.dotplot(
                    results_df,
                    column="NOM p-val",
                    show_ring=show_ring,
                    xticklabels_rot=xticklabels_rot,
                    title=title,
                    cmap=cmap,
                    cutoff=cutoff_curr,
                    top_term=n_top,
                    size=size,
                    figsize=[10, height_],
                )
                if len(ax.collections) >= n_top:
                    print(f"cutoff={cutoff_curr} done! ")
                    break
                cutoff_curr += step
            except Exception as e:
                cutoff_curr += step
                print(
                    f"Warning: trying cutoff={cutoff_curr}, cutoff={cutoff_curr-step} failed: {e} "
                )
        plot.figsets(ax=ax, **kws_figsets)
        return ax, results_df

    #! barplot with counts
    elif "co" in kind.lower():
        if ax is None:
            _, ax = plt.subplots(1, 1, figsize=[10, height_])
        # counts come from the leading-edge genes ("Count" column above)
        df_ = results_df.sort_values(by="Count", ascending=False)
        ax = plot.plotxy(
            data=df_.head(n_top),
            kind="barplot",
            x="Count",
            y="Term",
            hue="Term",
            palette=palette,
            legend=None,
            ax=ax,
            **kwargs,
        )
        plot.figsets(ax=ax, **kws_figsets)
        return ax, df_

    #! scatter with counts
    elif "sca" in kind.lower():
        # the scatter uses the palette in its natural (non-reversed) order
        palette = plot.get_color(n_top, cmap=cmap) if isinstance(cmap, str) else cmap
        if ax is None:
            _, ax = plt.subplots(1, 1, figsize=[10, height_])
        df_ = results_df.sort_values(by="Count", ascending=False)
        ax = plot.plotxy(
            data=df_.head(n_top),
            kind="scatter",
            x="Count",
            y="Term",
            hue="Count",
            size="Count",
            sizes=[10, 50],
            palette=palette,
            legend=None,
            ax=ax,
            **kwargs,
        )
        plot.figsets(ax=ax, **kws_figsets)
        return ax, df_

    #! network plot
    elif "net" in kind.lower():
        from gseapy import enrichment_map

        nodes = edges = None
        # check None first: `None >= 1` raises TypeError
        if cutoff is None or cutoff >= 1:
            print(f"cutoff is {cutoff} => Without applying filter")
            nodes, edges = enrichment_map(
                df=results_df,
                columns="NOM p-val",
                cutoff=1.1,  # above any p-value, i.e. keep every term
                top_term=n_top,
            )
        else:
            cutoff_curr = cutoff
            step = 0.05
            cutoff_stop = 1.0
            while cutoff_curr <= cutoff_stop:
                try:
                    # returns two DataFrames: term nodes and term-term edges
                    nodes, edges = enrichment_map(
                        df=results_df,
                        columns="NOM p-val",
                        cutoff=cutoff_curr,  # 0.25 when "FDR q-val"; 0.05 when "NOM p-val"
                        top_term=n_top,
                    )
                    if nodes.shape[0] >= n_top:
                        print(f"cutoff={cutoff_curr} done! ")
                        break
                    cutoff_curr += step
                except Exception as e:
                    cutoff_curr += step
                    print(f"{e}: trying cutoff={cutoff_curr}")
        if nodes is None or edges is None:
            raise RuntimeError(
                f"enrichment_map failed for every cutoff between {cutoff} and 1.0"
            )

        # report which styling falls back to NES-derived defaults
        if size is None:
            print("size: by 'NES'")
        if linewidth is None:
            print("linewidth: by 'NES'")
        if linecolor is None:
            print("linecolor: by 'NES'")
        if linealpha is None:
            print("linealpha: by 'NES'")
        if facecolor is None:
            print("facecolor: by 'NES'")
        if alpha is None:
            print("alpha: by 'NES'")

        edges.sort_values(by="jaccard_coef", ascending=False, inplace=True)
        nes = pd.to_numeric(nodes["NES"], errors="coerce")
        nes_scaled = (nes * 300).tolist()
        # colormaps expect values in [0, 1]; raw NES (roughly -3..3) would saturate
        colormap = plt.get_cmap(cmap)
        nes_norm = Normalize(vmin=nes.min(), vmax=nes.max())
        nes_colors = [colormap(nes_norm(v)) for v in nes]
        # NOTE(review): these per-node lists are consumed positionally by
        # plot_ppi in its graph's node order (and linewidth/linealpha per
        # edge), which need not match the row order of `nodes` - verify.
        G, ax = plot_ppi(
            interactions=edges,
            player1="src_name",
            player2="targ_name",
            weight="jaccard_coef",
            size=list(nes_scaled) if size is None else size,
            facecolor=nes_colors if facecolor is None else facecolor,
            linewidth=list(nes_scaled) if linewidth is None else linewidth,
            linecolor=list(nes_scaled) if linecolor is None else linecolor,
            linealpha=list(nes_scaled) if linealpha is None else linealpha,
            alpha=list(nes_scaled) if alpha is None else alpha,
            **kwargs,
        )
        return ax, G, nodes, edges

    # no branch matched `kind`
    return None
1796
+
1797
+
1402
1798
  #! https://string-db.org/help/api/
1403
1799
 
1404
1800
  import pandas as pd
@@ -1527,16 +1923,22 @@ def plot_ppi(
1527
1923
  n_rank=[5, 10], # Nodes in each rank for the concentric layout
1528
1924
  dist_node = 10, # Distance between each rank of circles
1529
1925
  layout="degree",
1530
- size='auto',#700,
1926
+ size=None,#700,
1927
+ sizes=(50,500),# min and max of size
1531
1928
  facecolor="skyblue",
1532
1929
  cmap='coolwarm',
1533
1930
  edgecolor="k",
1534
1931
  edgelinewidth=1.5,
1535
1932
  alpha=.5,
1933
+ alphas=(0.1, 1.0),# min and max of alpha
1536
1934
  marker="o",
1537
1935
  node_hideticks=True,
1538
1936
  linecolor="gray",
1937
+ line_cmap='coolwarm',
1539
1938
  linewidth=1.5,
1939
+ linewidths=(0.5,5),# min and max of linewidth
1940
+ linealpha=1.0,
1941
+ linealphas=(0.1,1.0),# min and max of linealpha
1540
1942
  linestyle="-",
1541
1943
  line_arrowstyle='-',
1542
1944
  fontsize=10,
@@ -1565,7 +1967,7 @@ def plot_ppi(
1565
1967
  for col in [player1, player2, weight]:
1566
1968
  if col not in interactions.columns:
1567
1969
  raise ValueError(f"Column '{col}' is missing from the interactions DataFrame.")
1568
-
1970
+ interactions.sort_values(by=[weight], inplace=True)
1569
1971
  # Initialize Pyvis network
1570
1972
  net = Network(height="750px", width="100%", bgcolor=bgcolor, font_color=fontcolor)
1571
1973
  net.force_atlas_2based(
@@ -1584,34 +1986,71 @@ def plot_ppi(
1584
1986
  G = nx.Graph()
1585
1987
  for _, row in interactions.iterrows():
1586
1988
  G.add_edge(row[player1], row[player2], weight=row[weight])
1989
+ # G = nx.from_pandas_edgelist(interactions, source=player1, target=player2, edge_attr=weight)
1990
+
1587
1991
 
1588
1992
  # Calculate node degrees
1589
1993
  degrees = dict(G.degree())
1590
1994
  norm = Normalize(vmin=min(degrees.values()), vmax=max(degrees.values()))
1591
1995
  colormap = cm.get_cmap(cmap) # Get the 'coolwarm' colormap
1592
1996
 
1997
+ if not ips.isa(facecolor, 'color'):
1998
+ print("facecolor: based on degrees")
1999
+ facecolor = [colormap(norm(deg)) for deg in degrees.values()] # Use colormap
2000
+ num_nodes = G.number_of_nodes()
2001
+ #* size
1593
2002
  # Set properties based on degrees
1594
2003
  if not isinstance(size, (int,float,list)):
2004
+ print("size: based on degrees")
1595
2005
  size = [deg * 50 for deg in degrees.values()] # Scale sizes
1596
- if not ips.isa(facecolor, 'color'):
1597
- facecolor = [colormap(norm(deg)) for deg in degrees.values()] # Use colormap
1598
- if size is None:
1599
- size = [700] * G.number_of_nodes() # Default size for all nodes
1600
- elif isinstance(size, (int, float)):
1601
- size = [size] * G.number_of_nodes() # If a scalar, apply to all nodes
1602
- # else:
1603
- # size = size.tolist() # Ensure size is a list
1604
- if len(size)>G.number_of_nodes():
1605
- size=size[:G.number_of_nodes()]
1606
-
1607
- for node in G.nodes():
2006
+ size = (size[:num_nodes] if len(size) > num_nodes else size) if isinstance(size, list) else [size] * num_nodes
2007
+ if isinstance(size, list) and len(ips.flatten(size,verbose=False))!=1:
2008
+ # Normalize sizes
2009
+ min_size, max_size = sizes # Use sizes tuple for min and max values
2010
+ min_degree, max_degree = min(size), max(size)
2011
+ if max_degree > min_degree: # Avoid division by zero
2012
+ size = [
2013
+ min_size + (max_size - min_size) * (sz - min_degree) / (max_degree - min_degree)
2014
+ for sz in size
2015
+ ]
2016
+ else:
2017
+ # If all values are the same, set them to a default of the midpoint
2018
+ size = [(min_size + max_size) / 2] * len(size)
2019
+
2020
+ #* facecolor
2021
+ facecolor = (facecolor[:num_nodes] if len(facecolor) > num_nodes else facecolor) if isinstance(facecolor, list) else [facecolor] * num_nodes
2022
+ # * facealpha
2023
+ if isinstance(alpha, list):
2024
+ alpha = (alpha[:num_nodes] if len(alpha) > num_nodes else alpha + [alpha[-1]] * (num_nodes - len(alpha)))
2025
+ min_alphas, max_alphas = alphas # Use alphas tuple for min and max values
2026
+ if len(alpha) > 0:
2027
+ # Normalize alpha based on the specified min and max
2028
+ min_alpha, max_alpha = min(alpha), max(alpha)
2029
+ if max_alpha > min_alpha: # Avoid division by zero
2030
+ alpha = [
2031
+ min_alphas + (max_alphas - min_alphas) * (ea - min_alpha) / (max_alpha - min_alpha)
2032
+ for ea in alpha
2033
+ ]
2034
+ else:
2035
+ # If all alpha values are the same, set them to the average of min and max
2036
+ alpha = [(min_alphas + max_alphas) / 2] * len(alpha)
2037
+ else:
2038
+ # Default to a full opacity if no edges are provided
2039
+ alpha = [1.0] * num_nodes
2040
+ else:
2041
+ # If alpha is a single value, convert it to a list and normalize it
2042
+ alpha = [alpha] * num_nodes # Adjust based on alphas
2043
+
2044
+ for i, node in enumerate(G.nodes()):
1608
2045
  net.add_node(
1609
2046
  node,
1610
2047
  label=node,
1611
- size=size[list(G.nodes()).index(node)] if isinstance(size,list) else size[0],
1612
- color=facecolor[list(G.nodes()).index(node)] if isinstance(facecolor,list) else facecolor,
2048
+ size=size[i],
2049
+ color=facecolor[i],
2050
+ alpha=alpha[i],
1613
2051
  font={"size": fontsize, "color": fontcolor},
1614
2052
  )
2053
+ print(f'nodes number: {i+1}')
1615
2054
 
1616
2055
  for edge in G.edges(data=True):
1617
2056
  net.add_edge(
@@ -1621,6 +2060,7 @@ def plot_ppi(
1621
2060
  color=edgecolor,
1622
2061
  width=edgelinewidth * edge[2]["weight"],
1623
2062
  )
2063
+
1624
2064
  layouts = [
1625
2065
  "spring",
1626
2066
  "circular",
@@ -1632,7 +2072,8 @@ def plot_ppi(
1632
2072
  "degree"
1633
2073
  ]
1634
2074
  layout = ips.strcmp(layout, layouts)[0]
1635
- print(layout)
2075
+ print(f"layout:{layout}, or select one in {layouts}")
2076
+
1636
2077
  # Choose layout
1637
2078
  if layout == "spring":
1638
2079
  pos = nx.spring_layout(G, k=k_value)
@@ -1658,7 +2099,9 @@ def plot_ppi(
1658
2099
  # Calculate node degrees and sort nodes by degree
1659
2100
  degrees = dict(G.degree())
1660
2101
  sorted_nodes = sorted(degrees.items(), key=lambda x: x[1], reverse=True)
1661
-
2102
+ norm = Normalize(vmin=min(degrees.values()), vmax=max(degrees.values()))
2103
+ colormap = cm.get_cmap(cmap)
2104
+
1662
2105
  # Create positions for concentric circles based on n_layers and n_rank
1663
2106
  pos = {}
1664
2107
  n_layers=len(n_rank)+1 if n_layers is None else n_layers
@@ -1689,8 +2132,8 @@ def plot_ppi(
1689
2132
 
1690
2133
  # If ax is None, use plt.gca()
1691
2134
  if ax is None:
1692
- fig, ax = plt.subplots(1,1,figsize=figsize)
1693
-
2135
+ fig, ax = plt.subplots(1,1,figsize=figsize)
2136
+
1694
2137
  # Draw nodes, edges, and labels with customization options
1695
2138
  nx.draw_networkx_nodes(
1696
2139
  G,
@@ -1704,6 +2147,54 @@ def plot_ppi(
1704
2147
  hide_ticks=node_hideticks,
1705
2148
  node_shape=marker
1706
2149
  )
2150
+
2151
+ #* linewidth
2152
+ if not isinstance(linewidth, list):
2153
+ linewidth = [linewidth] * G.number_of_edges()
2154
+ else:
2155
+ linewidth = (linewidth[:G.number_of_edges()] if len(linewidth) > G.number_of_edges() else linewidth + [linewidth[-1]] * (G.number_of_edges() - len(linewidth)))
2156
+ # Normalize linewidth if it is a list
2157
+ if isinstance(linewidth, list):
2158
+ min_linewidth, max_linewidth = min(linewidth), max(linewidth)
2159
+ vmin, vmax = linewidths # Use linewidths tuple for min and max values
2160
+ if max_linewidth > min_linewidth: # Avoid division by zero
2161
+ # Scale between vmin and vmax
2162
+ linewidth = [
2163
+ vmin + (vmax - vmin) * (lw - min_linewidth) / (max_linewidth - min_linewidth)
2164
+ for lw in linewidth
2165
+ ]
2166
+ else:
2167
+ # If all values are the same, set them to a default of the midpoint
2168
+ linewidth = [(vmin + vmax) / 2] * len(linewidth)
2169
+ else:
2170
+ # If linewidth is a single value, convert it to a list of that value
2171
+ linewidth = [linewidth] * G.number_of_edges()
2172
+ #* linecolor
2173
+ if not isinstance(linecolor, str):
2174
+ weights = [G[u][v]["weight"] for u, v in G.edges()]
2175
+ norm = Normalize(vmin=min(weights), vmax=max(weights))
2176
+ colormap = cm.get_cmap(line_cmap)
2177
+ linecolor = [colormap(norm(weight)) for weight in weights]
2178
+ else:
2179
+ linecolor = [linecolor] * G.number_of_edges()
2180
+
2181
+ # * linealpha
2182
+ if isinstance(linealpha, list):
2183
+ linealpha = (linealpha[:G.number_of_edges()] if len(linealpha) > G.number_of_edges() else linealpha + [linealpha[-1]] * (G.number_of_edges() - len(linealpha)))
2184
+ min_alpha, max_alpha = linealphas # Use linealphas tuple for min and max values
2185
+ if len(linealpha) > 0:
2186
+ min_linealpha, max_linealpha = min(linealpha), max(linealpha)
2187
+ if max_linealpha > min_linealpha: # Avoid division by zero
2188
+ linealpha = [
2189
+ min_alpha + (max_alpha - min_alpha) * (ea - min_linealpha) / (max_linealpha - min_linealpha)
2190
+ for ea in linealpha
2191
+ ]
2192
+ else:
2193
+ linealpha = [(min_alpha + max_alpha) / 2] * len(linealpha)
2194
+ else:
2195
+ linealpha = [1.0] * G.number_of_edges() # 如果设置有误,则将它设置成1.0
2196
+ else:
2197
+ linealpha = [linealpha] * G.number_of_edges() # Convert to list if single value
1707
2198
  nx.draw_networkx_edges(
1708
2199
  G,
1709
2200
  pos,
@@ -1712,8 +2203,9 @@ def plot_ppi(
1712
2203
  width=linewidth,
1713
2204
  style=linestyle,
1714
2205
  arrowstyle=line_arrowstyle,
1715
- alpha=0.7
2206
+ alpha=linealpha
1716
2207
  )
2208
+
1717
2209
  nx.draw_networkx_labels(
1718
2210
  G, pos, ax=ax, font_size=fontsize, font_color=fontcolor,horizontalalignment=ha,verticalalignment=va
1719
2211
  )
@@ -1725,6 +2217,7 @@ def plot_ppi(
1725
2217
  net.write_html(dir_save)
1726
2218
  nx.write_graphml(G, dir_save.replace(".html",".graphml")) # Export to GraphML
1727
2219
  print(f"could be edited in Cytoscape \n{dir_save.replace(".html",".graphml")}")
2220
+ ips.figsave(dir_save.replace(".html",".pdf"))
1728
2221
  return G,ax
1729
2222
 
1730
2223