py2ls 0.2.4.5__py3-none-any.whl → 0.2.4.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
py2ls/bio.py CHANGED
@@ -166,9 +166,23 @@ def get_probe(
166
166
  if platform_id is None:
167
167
  df_meta = get_meta(geo=geo, dataset=dataset, verbose=False)
168
168
  platform_id = df_meta["platform_id"].unique().tolist()
169
- platform_id = platform_id[0] if len(platform_id) == 1 else platform_id
170
169
  print(f"Platform: {platform_id}")
171
- df_probe = geo[dataset].gpls[platform_id].table
170
+ if len(platform_id) > 1:
171
+ df_probe= geo[dataset].gpls[platform_id[0]].table
172
+ # df_probe=pd.DataFrame()
173
+ # # Iterate over each platform ID and collect the probe tables
174
+ # for platform_id_ in platform_id:
175
+ # if platform_id_ in geo[dataset].gpls:
176
+ # df_probe_ = geo[dataset].gpls[platform_id_].table
177
+ # if not df_probe_.empty:
178
+ # df_probe=pd.concat([df_probe, df_probe_])
179
+ # else:
180
+ # print(f"Warning: Probe table for platform {platform_id_} is empty.")
181
+ # else:
182
+ # print(f"Warning: Platform ID {platform_id_} not found in dataset {dataset}.")
183
+ else:
184
+ df_probe= geo[dataset].gpls[platform_id[0]].table
185
+
172
186
  if df_probe.empty:
173
187
  print(
174
188
  f"Warning: cannot find the probe info. 看一下是不是在单独的文件中包含了probe信息"
@@ -215,9 +229,12 @@ def get_data(geo: dict, dataset: str = "GSE25097", verbose=False):
215
229
  df_expression = get_expression_data(geo, dataset=dataset)
216
230
  if not df_expression.select_dtypes(include=["number"]).empty:
217
231
  # 如果数据全部是counts类型的话, 则使用TMM进行normalize
218
- if 'counts' in get_data_type(df_expression):
219
- print(f"{dataset}'s type is raw read counts, nomalized(transformed) via 'TMM'")
220
- df_expression=counts2expression(df_expression.T).T
232
+ if 'counts' in get_data_type(df_expression):
233
+ try:
234
+ df_expression=counts2expression(df_expression.T).T
235
+ print(f"{dataset}'s type is raw read counts, nomalized(transformed) via 'TMM'")
236
+ except Exception as e:
237
+ print("raw counts data")
221
238
  if any([df_probe.empty, df_expression.empty]):
222
239
  print(
223
240
  f"got empty values, check the probe info. 看一下是不是在单独的文件中包含了probe信息"
@@ -814,7 +831,11 @@ def counts_deseq(counts_sam_gene: pd.DataFrame,
814
831
  # .reset_index()
815
832
  # .rename(columns={"index": "gene"})
816
833
  # )
817
- return dds, diff,stat_res
834
+ df_norm=pd.DataFrame(dds.layers['normed_counts'])
835
+ df_norm.index=counts_sam_gene.index
836
+ df_norm.columns=counts_sam_gene.columns
837
+ print("res[0]: dds\nres[1]:diff\nres[2]:stat_res\nres[3]:df_normalized")
838
+ return dds, diff, stat_res,df_norm
818
839
 
819
840
  def scope_genes(gene_list: list, scopes:str=None, fields: str = "symbol", species="human"):
820
841
  """
@@ -1005,25 +1026,25 @@ def plot_enrichr(results_df,
1005
1026
  palette = plot.get_color(n_top, cmap=cmap)[::-1]
1006
1027
  elif isinstance(cmap,list):
1007
1028
  palette=cmap
1008
-
1009
- if n_top<5:
1010
- height_=4
1011
- elif 5<=n_top<10:
1012
- height_=5
1013
- elif 5<=n_top<10:
1014
- height_=6
1015
- elif 10<=n_top<15:
1016
- height_=7
1017
- elif 15<=n_top<20:
1018
- height_=8
1019
- elif 20<=n_top<30:
1020
- height_=9
1029
+ if n_top < 5:
1030
+ height_ = 3
1031
+ elif 5 <= n_top < 10:
1032
+ height_ = 3
1033
+ elif 10 <= n_top < 15:
1034
+ height_ = 3
1035
+ elif 15 <= n_top < 20:
1036
+ height_ =4
1037
+ elif 20 <= n_top < 30:
1038
+ height_ = 5
1039
+ elif 30 <= n_top < 40:
1040
+ height_ = int(n_top / 6)
1021
1041
  else:
1022
- height_=int(n_top/3)
1023
- if ax is None:
1024
- _,ax=plt.subplots(1,1,figsize=[10, height_])
1042
+ height_ = int(n_top / 8)
1043
+
1025
1044
  #! barplot
1026
1045
  if 'bar' in kind.lower():
1046
+ if ax is None:
1047
+ _,ax=plt.subplots(1,1,figsize=[10, height_])
1027
1048
  ax=plot.plotxy(
1028
1049
  data=results_df.head(n_top),
1029
1050
  kind="barplot",
@@ -1035,6 +1056,7 @@ def plot_enrichr(results_df,
1035
1056
  )
1036
1057
  plot.figsets(ax=ax, **kws_figsets)
1037
1058
  return ax,results_df
1059
+
1038
1060
  #! dotplot
1039
1061
  elif 'dot' in kind.lower():
1040
1062
  #! dotplot
@@ -1066,16 +1088,20 @@ def plot_enrichr(results_df,
1066
1088
  print(f"Warning: trying cutoff={cutoff_curr}, cutoff={cutoff_curr-step} failed: {e} ")
1067
1089
  plot.figsets(ax=ax, **kws_figsets)
1068
1090
  return ax,results_df
1091
+
1069
1092
  #! barplot with counts
1070
1093
  elif 'count' in kind.lower():
1094
+ if ax is None:
1095
+ _,ax=plt.subplots(1,1,figsize=[10, height_])
1071
1096
  # 从overlap中提取出个数
1072
- results_df["count"] = results_df["Overlap"].apply(
1097
+ results_df["Count"] = results_df["Overlap"].apply(
1073
1098
  lambda x: int(x.split("/")[0]) if isinstance(x, str) else x)
1074
- df_=results_df.sort_values(by="count", ascending=False)
1099
+ df_=results_df.sort_values(by="Count", ascending=False)
1100
+
1075
1101
  ax=plot.plotxy(
1076
1102
  data=df_.head(n_top),
1077
1103
  kind="barplot",
1078
- x="count",
1104
+ x="Count",
1079
1105
  y="Term",
1080
1106
  hue="Term",
1081
1107
  palette=palette,
@@ -1094,11 +1120,12 @@ def plot_bp_cc_mf(
1094
1120
  "GO_Molecular_Function_2023",
1095
1121
  ],
1096
1122
  species="human",
1123
+ download=False,
1097
1124
  n_top=10,
1098
1125
  plot_=True,
1099
1126
  ax=None,
1100
1127
  palette=plot.get_color(3),
1101
- ** kwargs,
1128
+ **kwargs,
1102
1129
  ):
1103
1130
 
1104
1131
  def res_enrichr_2_count(res_enrichr, n_top=10):
@@ -1111,13 +1138,13 @@ def plot_bp_cc_mf(
1111
1138
  return res_enrichr.head(n_top)#[["Term", "Count"]]
1112
1139
 
1113
1140
  res_enrichr_BP = get_enrichr(
1114
- deg_gene_list, gene_sets[0], species=species, plot_=False
1141
+ deg_gene_list, gene_sets[0], species=species, plot_=False,download=download
1115
1142
  )
1116
1143
  res_enrichr_CC = get_enrichr(
1117
- deg_gene_list, gene_sets[1], species=species, plot_=False
1144
+ deg_gene_list, gene_sets[1], species=species, plot_=False,download=download
1118
1145
  )
1119
1146
  res_enrichr_MF = get_enrichr(
1120
- deg_gene_list, gene_sets[2], species=species, plot_=False
1147
+ deg_gene_list, gene_sets[2], species=species, plot_=False,download=download
1121
1148
  )
1122
1149
 
1123
1150
  df_BP = res_enrichr_2_count(res_enrichr_BP, n_top=n_top)
@@ -1149,6 +1176,7 @@ def plot_bp_cc_mf(
1149
1176
  if ax is None:
1150
1177
  _,ax=plt.subplots(1,1,figsize=[10, height_])
1151
1178
  # 作图
1179
+ display(df2plot)
1152
1180
  if df2plot["Term"].tolist()[0].endswith(")"):
1153
1181
  df2plot["Term"] = df2plot["Term"].apply(lambda x: x.split("(")[0][:-1])
1154
1182
  if plot_:
@@ -1164,8 +1192,15 @@ def plot_bp_cc_mf(
1164
1192
  )
1165
1193
  return ax, df2plot
1166
1194
 
1167
- def get_library_name():
1168
- return gp.get_library_name()
1195
def get_library_name(by=None, verbose=False):
    """List the gene-set library names reported by ``gp.get_library_name()``.

    Args:
        by: Optional query. When given, the library names are passed through
            ``ips.strcmp(by, names, get_rank=True)`` and the result is
            flattened with ``ips.flatten`` before being returned.
        verbose: When ``by`` is None, print every library name on its own
            line; otherwise forwarded to ``ips.strcmp`` and ``ips.flatten``.

    Returns:
        The full collection of library names when ``by`` is None, otherwise
        the flattened ``ips.strcmp`` result for ``by``.
    """
    available = gp.get_library_name()

    # Query given: delegate the matching/ranking to the ips helpers.
    if by is not None:
        ranked = ips.strcmp(by, available, get_rank=True, verbose=verbose)
        return ips.flatten(ranked, verbose=verbose)

    # No query: optionally echo the names, then hand back the full list.
    if verbose:
        for library in available:
            print(library)
    return available
1203
+
1169
1204
 
1170
1205
  def get_gsva(
1171
1206
  data_gene_samples: pd.DataFrame, # index(gene),columns(samples)
@@ -1398,7 +1433,385 @@ def plot_gsva(gsva_res, # output from bio.get_gsva()
1398
1433
  plot.figsets(ax=ax, **kws_figsets)
1399
1434
  return ax
1400
1435
 
1436
def get_prerank(
    rnk: pd.DataFrame,
    gene_sets: str,
    download: bool = False,
    species="Human",
    threads=8,  # Number of CPU cores to use
    permutation_num=1000,  # Number of permutations for significance
    min_size=1,  # Minimum gene set size
    max_size=2000,  # Maximum gene set size
    seed=1,  # Seed for reproducibility
    verbose=True,  # Verbosity
    dir_save="./",
    plot_=False,
    size=5,
    cutoff=0.25,
    show_ring=False,
    cmap="coolwarm",
    check_shared=True,
    **kwargs,
):
    """Run ``gp.prerank`` on a ranked gene list and optionally save plots.

    Note: Enrichr uses a list of Entrez gene symbols as input.

    Args:
        rnk: Ranking passed unchanged to ``gp.prerank(rnk=...)``.
        gene_sets: Either a path to an existing file (loaded with
            ``ips.fload``) or a library name, which is matched against
            ``gp.get_library_name()`` with ``ips.strcmp``.
        download: For a library name, fetch the library with
            ``gp.get_library`` first; otherwise the corrected name itself is
            handed to ``gp.prerank``.
        species: Matched with ``ips.strcmp`` against Human, Mouse, Yeast,
            Fly, Fish, Worm; only used when ``download`` is True.
        threads, permutation_num, min_size, max_size, seed, verbose:
            Forwarded to ``gp.prerank``.
        dir_save: Prefix that is string-concatenated with the output file
            names (so it should end with a path separator).
        plot_: When True, save a GSEA plot, a dot plot and an enrichment-map
            network plot via ``ips.figsave``.
        size, cutoff, show_ring, cmap: Forwarded to ``gseapy.dotplot``.
        check_shared: Not used by this function.
        **kwargs: A key containing ``"figset"`` is removed; nothing else in
            ``kwargs`` is used by this function.

    Returns:
        ``pre_res.res2d`` from the prerank result, or ``None`` when
        ``gp.prerank`` raises ``ValueError`` (the error is printed).
    """
    # Drop the first "figset"-like option from kwargs. The extracted value is
    # not used by this function; the removal only keeps kwargs tidy.
    figset_key = next((key for key in kwargs if "figset" in key), None)
    if figset_key is not None:
        kwargs.pop(figset_key, None)

    species_org = species
    # organism (str) - select one from {'Human', 'Mouse', 'Yeast', 'Fly', 'Fish', 'Worm'}
    organisms = ["Human", "Mouse", "Yeast", "Fly", "Fish", "Worm"]
    species = ips.strcmp(species, organisms)[0]
    if species_org.lower() != species.lower():
        print(f"species was corrected to {species}, becasue only support {organisms}")

    if os.path.isfile(gene_sets):
        gene_sets_name = os.path.basename(gene_sets)
        gene_sets = ips.fload(gene_sets)
    else:
        lib_support_names = gp.get_library_name()
        # correct the input gene_set name
        gene_sets_name = ips.strcmp(gene_sets, lib_support_names)[0]
        if download:
            gene_sets = gp.get_library(name=gene_sets_name, organism=species)
        else:
            gene_sets = gene_sets_name  # avoid downloading the library again
    print(f"\ngene_sets get ready: {gene_sets_name}")

    #! prerank
    try:
        pre_res = gp.prerank(
            rnk=rnk,
            gene_sets=gene_sets,
            threads=threads,
            permutation_num=permutation_num,
            min_size=min_size,
            max_size=max_size,
            seed=seed,
            verbose=verbose,
        )
    except ValueError as e:
        print(f"\n{'!'*10} Error {'!'*10}\n{' '*4}{e}\n{'!'*10} Error {'!'*10}")
        return None

    df_prerank = pre_res.res2d
    if not plot_:
        return df_prerank

    # NOTE: file names and titles use gene_sets_name, never gene_sets itself:
    # after a download or a file load gene_sets is the loaded library object,
    # whose repr is not a usable file name.

    #! gseaplot (first seven terms)
    terms = df_prerank.Term
    pre_res.plot(
        terms=terms[:7],
        # legend_kws={"loc": (1.2, 0)},  # set the legend loc
        # show_ranking=True,  # whether to show the second yaxis
        figsize=(3, 4),
    )
    ips.figsave(dir_save + f"prerank_gseaplot_{gene_sets_name}.pdf")

    #! dotplot
    from gseapy import dotplot

    dotplot(
        df_prerank,
        column="NOM p-val",  # or "FDR q-val"
        cmap=cmap,
        size=size,
        figsize=(10, 5),
        cutoff=cutoff,
        show_ring=show_ring,
    )
    ips.figsave(dir_save + f"prerank_dotplot_{gene_sets_name}.pdf")

    #! network plot
    from gseapy import enrichment_map
    import networkx as nx

    layout = "circular"
    layout_funcs = {
        "spring": nx.layout.spring_layout,
        "circular": nx.layout.circular_layout,
        "shell": nx.layout.shell_layout,
        "spectral": nx.layout.spectral_layout,
    }

    # Retry with an increasing number of top terms until one attempt yields a
    # graph whose node count matches the node table; stop at the first success.
    for top_term in range(5, 50):
        try:
            # returns two dataframes
            nodes, edges = enrichment_map(
                df=df_prerank,
                columns="FDR q-val",
                cutoff=0.25,  # 0.25 when "FDR q-val"; 0.05 when "Nom p-value"
                top_term=top_term,
            )
            # build graph
            G = nx.from_pandas_edgelist(
                edges,
                source="src_idx",
                target="targ_idx",
                edge_attr=["jaccard_coef", "overlap_coef", "overlap_genes"],
            )
            # nodes without any edge are absent from G; sizes/colors would
            # then be misaligned, so treat that as a failed attempt
            if len(list(nodes.Hits_ratio)) != len(G.nodes):
                raise ValueError(
                    "The size of node_size list does not match the number of nodes in the graph."
                )
            node_sizes = list(nodes.Hits_ratio * 1000)

            fig, ax = plt.subplots(figsize=(8, 8))
            pos = layout_funcs[layout](G)

            # draw nodes
            nx.draw_networkx_nodes(
                G,
                pos=pos,
                cmap=plt.cm.RdYlBu,
                node_color=list(nodes.NES),
                node_size=node_sizes,
            )
            # draw node labels
            nx.draw_networkx_labels(
                G,
                pos=pos,
                labels=nodes.Term.to_dict(),
                font_size=8,
                verticalalignment="bottom",
            )
            # draw edges, width scaled by the Jaccard coefficient
            edge_weight = nx.get_edge_attributes(G, "jaccard_coef").values()
            nx.draw_networkx_edges(
                G,
                pos=pos,
                width=[weight * 10 for weight in edge_weight],
                edge_color="#CDDBD4",
            )
            ax.set_axis_off()
            print(f"{gene_sets_name}(top_term={top_term})")
            plot.figsets(title=f"{gene_sets_name}(top_term={top_term})")
            ips.figsave(dir_save + f"prerank_network_{gene_sets_name}.pdf")
            break
        except Exception:
            # best effort: try the next top_term (KeyboardInterrupt still propagates)
            print(f"not work {top_term}")
    return df_prerank
1604
def plot_prerank(
    results_df,
    kind="bar",  # 'barplot', 'dotplot'
    cutoff=0.25,
    show_ring=False,
    xticklabels_rot=0,
    title=None,  # 'KEGG'
    cmap="coolwarm",
    n_top=10,
    size=5,  # when size is None in network, by "NES"
    facecolor=None,  # default by "NES"
    linewidth=None,  # default by "NES"
    linecolor=None,  # default by "NES"
    linealpha=None,  # default by "NES"
    alpha=None,  # default by "NES"
    ax=None,
    **kwargs,
):
    """Plot a prerank result table as a bar, dot, count, scatter or network plot.

    ``results_df`` is modified in place: the columns
    ``"-log10(Adjusted P-value)"`` (from ``"FDR q-val"``) and ``"Count"``
    (number of ``;``-separated entries in ``"Lead_genes"``) are added.

    Args:
        results_df: Table with at least the columns ``"FDR q-val"``,
            ``"Lead_genes"`` and ``"Term"`` (presumably the output of
            ``get_prerank`` - confirm against callers).
        kind: Substring-matched, in this order: ``"bar"``, ``"dot"``,
            ``"co"`` (counts), ``"sca"`` (scatter), ``"net"`` (network).
        cutoff: Starting cutoff for the dot and network plots; it is raised in
            steps of 0.05 until enough terms pass. For the network plot,
            ``None`` or a value >= 1 disables filtering.
        show_ring, xticklabels_rot, title: Forwarded to ``gp.dotplot``.
        cmap: Colormap name (str) or an explicit list used as the palette.
        n_top: Number of terms to show.
        size, facecolor, linewidth, linecolor, linealpha, alpha: Network
            styling forwarded to ``plot_ppi``; each one that is ``None`` is
            derived from the node ``NES`` values.
        ax: Axes to draw on for the bar, count and scatter plots; created when
            ``None``.
        **kwargs: A key containing ``"figset"`` is removed and its value is
            passed to ``plot.figsets``; the rest goes to ``plot.plotxy``
            (count/scatter) or ``plot_ppi`` (network).

    Returns:
        ``(ax, dataframe)`` for bar, dot, count and scatter plots;
        ``(ax, G, nodes, edges)`` for the network plot; ``None`` when ``kind``
        matches none of the supported substrings.

    Raises:
        ValueError: For the network plot, when ``enrichment_map`` fails at
            every tried cutoff.
    """
    # Split the first "figset"-like option out of kwargs.
    kws_figsets = {}
    figset_key = next((key for key in kwargs if "figset" in key), None)
    if figset_key is not None:
        kws_figsets = kwargs.pop(figset_key)

    if isinstance(cmap, str):
        palette = plot.get_color(n_top, cmap=cmap)[::-1]
    elif isinstance(cmap, list):
        palette = cmap

    # Figure height grows with the number of terms shown.
    if n_top < 5:
        height_ = 4
    elif n_top < 10:
        height_ = 5
    elif n_top < 15:
        height_ = 6
    elif n_top < 20:
        height_ = 7
    elif n_top < 30:
        height_ = 8
    elif n_top < 40:
        height_ = int(n_top / 5)
    else:
        height_ = int(n_top / 6)

    # Element-wise apply (not a vectorised call) so that object-dtype columns
    # holding Python floats are handled as well.
    results_df["-log10(Adjusted P-value)"] = results_df["FDR q-val"].apply(
        lambda x: -np.log10(x)
    )
    results_df["Count"] = results_df["Lead_genes"].apply(lambda x: len(x.split(";")))

    kind_ = kind.lower()

    #! barplot
    if "bar" in kind_:
        df_ = results_df.sort_values(by="-log10(Adjusted P-value)", ascending=False)
        if ax is None:
            _, ax = plt.subplots(1, 1, figsize=[10, height_])
        ax = plot.plotxy(
            data=df_.head(n_top),
            kind="barplot",
            x="-log10(Adjusted P-value)",
            y="Term",
            hue="Term",
            palette=palette,
            legend=None,
            ax=ax,  # draw on the requested/created axes, as the other branches do
        )
        plot.figsets(ax=ax, **kws_figsets)
        return ax, df_

    #! dotplot
    elif "dot" in kind_:
        # Relax the cutoff step by step until at least n_top terms are drawn.
        cutoff_curr = cutoff
        step = 0.05
        cutoff_stop = 0.5
        while cutoff_curr <= cutoff_stop:
            try:
                if cutoff_curr != cutoff:
                    plt.clf()
                ax = gp.dotplot(
                    results_df,
                    column="NOM p-val",
                    show_ring=show_ring,
                    xticklabels_rot=xticklabels_rot,
                    title=title,
                    cmap=cmap,
                    cutoff=cutoff_curr,
                    top_term=n_top,
                    size=size,
                    figsize=[10, height_],
                )
                if len(ax.collections) >= n_top:
                    print(f"cutoff={cutoff_curr} done! ")
                    break
                if cutoff_curr == cutoff_stop:
                    break
                cutoff_curr += step
            except Exception as e:
                cutoff_curr += step
                print(
                    f"Warning: trying cutoff={cutoff_curr}, cutoff={cutoff_curr-step} failed: {e} "
                )
        plot.figsets(ax=ax, **kws_figsets)
        return ax, results_df

    #! barplot with counts
    elif "co" in kind_:
        if ax is None:
            _, ax = plt.subplots(1, 1, figsize=[10, height_])
        df_ = results_df.sort_values(by="Count", ascending=False)
        ax = plot.plotxy(
            data=df_.head(n_top),
            kind="barplot",
            x="Count",
            y="Term",
            hue="Term",
            palette=palette,
            legend=None,
            ax=ax,
            **kwargs,
        )
        plot.figsets(ax=ax, **kws_figsets)
        return ax, df_

    #! scatter with counts
    elif "sca" in kind_:
        # The scatter plot uses the palette in its natural (unreversed) order.
        if isinstance(cmap, str):
            palette = plot.get_color(n_top, cmap=cmap)
        elif isinstance(cmap, list):
            palette = cmap
        if ax is None:
            _, ax = plt.subplots(1, 1, figsize=[10, height_])
        df_ = results_df.sort_values(by="Count", ascending=False)
        ax = plot.plotxy(
            data=df_.head(n_top),
            kind="scatter",
            x="Count",
            y="Term",
            hue="Count",
            size="Count",
            sizes=[10, 50],
            palette=palette,
            legend=None,
            ax=ax,
            **kwargs,
        )
        plot.figsets(ax=ax, **kws_figsets)
        return ax, df_

    #! network plot
    elif "net" in kind_:
        from gseapy import enrichment_map

        nodes = edges = None
        # Test for None first: comparing None with a number raises TypeError.
        if cutoff is None or cutoff >= 1:
            print(f"cutoff is {cutoff} => Without applying filter")
            nodes, edges = enrichment_map(
                df=results_df,
                columns="NOM p-val",
                cutoff=1.1,  # above any p-value, i.e. no filtering
                top_term=n_top,
            )
        else:
            # Relax the cutoff step by step until at least n_top nodes remain.
            cutoff_curr = cutoff
            step = 0.05
            cutoff_stop = 1.0
            while cutoff_curr <= cutoff_stop:
                try:
                    # returns two dataframes
                    nodes, edges = enrichment_map(
                        df=results_df,
                        columns="NOM p-val",
                        cutoff=cutoff_curr,  # 0.25 when "FDR q-val"; 0.05 when "Nom p-value"
                        top_term=n_top,
                    )
                    if nodes.shape[0] >= n_top:
                        print(f"cutoff={cutoff_curr} done! ")
                        break
                    if cutoff_curr == cutoff_stop:
                        break
                    cutoff_curr += step
                except Exception as e:
                    cutoff_curr += step
                    print(f"{e}: trying cutoff={cutoff_curr}")
            if nodes is None or edges is None:
                raise ValueError(
                    f"enrichment_map failed for every cutoff between {cutoff} and {cutoff_stop}"
                )

        # Report which styling options fall back to data-driven defaults
        # (an empty line is printed for each explicitly given option).
        for label, value in (
            ("size", size),
            ("linewidth", linewidth),
            ("linecolor", linecolor),
            ("linealpha", linealpha),
            ("facecolor", facecolor),
        ):
            print(f"{label}: by 'NES'" if value is None else "")
        # NOTE(review): the message mentions -log10(Adjusted P-value), but the
        # fallback below is NES-based like the others - confirm the intent.
        print("alpha: by '-log10(Adjusted P-value)'" if alpha is None else "")

        edges.sort_values(by="jaccard_coef", ascending=False, inplace=True)
        # plt.get_cmap is used because matplotlib.cm.get_cmap no longer
        # exists in recent matplotlib releases.
        colormap = plt.get_cmap(cmap)
        nes_scaled = [node["NES"] * 300 for _, node in nodes.iterrows()]
        # NOTE(review): the user-supplied ``ax`` is not forwarded to plot_ppi;
        # confirm whether it should be.
        G, ax = plot_ppi(
            interactions=edges,
            player1="src_name",
            player2="targ_name",
            weight="jaccard_coef",
            size=list(nes_scaled) if size is None else size,
            # NOTE(review): NES is fed to the colormap unnormalised - confirm.
            facecolor=(
                [colormap(node["NES"]) for _, node in nodes.iterrows()]
                if facecolor is None
                else facecolor
            ),
            linewidth=list(nes_scaled) if linewidth is None else linewidth,
            linecolor=list(nes_scaled) if linecolor is None else linecolor,
            linealpha=list(nes_scaled) if linealpha is None else linealpha,
            alpha=list(nes_scaled) if alpha is None else alpha,
            **kwargs,
        )
        return ax, G, nodes, edges
1813
+
1814
+
1402
1815
  #! https://string-db.org/help/api/
1403
1816
 
1404
1817
  import pandas as pd
@@ -1527,16 +1940,22 @@ def plot_ppi(
1527
1940
  n_rank=[5, 10], # Nodes in each rank for the concentric layout
1528
1941
  dist_node = 10, # Distance between each rank of circles
1529
1942
  layout="degree",
1530
- size='auto',#700,
1943
+ size=None,#700,
1944
+ sizes=(50,500),# min and max of size
1531
1945
  facecolor="skyblue",
1532
1946
  cmap='coolwarm',
1533
1947
  edgecolor="k",
1534
1948
  edgelinewidth=1.5,
1535
1949
  alpha=.5,
1950
+ alphas=(0.1, 1.0),# min and max of alpha
1536
1951
  marker="o",
1537
1952
  node_hideticks=True,
1538
1953
  linecolor="gray",
1954
+ line_cmap='coolwarm',
1539
1955
  linewidth=1.5,
1956
+ linewidths=(0.5,5),# min and max of linewidth
1957
+ linealpha=1.0,
1958
+ linealphas=(0.1,1.0),# min and max of linealpha
1540
1959
  linestyle="-",
1541
1960
  line_arrowstyle='-',
1542
1961
  fontsize=10,
@@ -1565,7 +1984,7 @@ def plot_ppi(
1565
1984
  for col in [player1, player2, weight]:
1566
1985
  if col not in interactions.columns:
1567
1986
  raise ValueError(f"Column '{col}' is missing from the interactions DataFrame.")
1568
-
1987
+ interactions.sort_values(by=[weight], inplace=True)
1569
1988
  # Initialize Pyvis network
1570
1989
  net = Network(height="750px", width="100%", bgcolor=bgcolor, font_color=fontcolor)
1571
1990
  net.force_atlas_2based(
@@ -1584,34 +2003,71 @@ def plot_ppi(
1584
2003
  G = nx.Graph()
1585
2004
  for _, row in interactions.iterrows():
1586
2005
  G.add_edge(row[player1], row[player2], weight=row[weight])
2006
+ # G = nx.from_pandas_edgelist(interactions, source=player1, target=player2, edge_attr=weight)
2007
+
1587
2008
 
1588
2009
  # Calculate node degrees
1589
2010
  degrees = dict(G.degree())
1590
2011
  norm = Normalize(vmin=min(degrees.values()), vmax=max(degrees.values()))
1591
2012
  colormap = cm.get_cmap(cmap) # Get the 'coolwarm' colormap
1592
2013
 
2014
+ if not ips.isa(facecolor, 'color'):
2015
+ print("facecolor: based on degrees")
2016
+ facecolor = [colormap(norm(deg)) for deg in degrees.values()] # Use colormap
2017
+ num_nodes = G.number_of_nodes()
2018
+ #* size
1593
2019
  # Set properties based on degrees
1594
2020
  if not isinstance(size, (int,float,list)):
2021
+ print("size: based on degrees")
1595
2022
  size = [deg * 50 for deg in degrees.values()] # Scale sizes
1596
- if not ips.isa(facecolor, 'color'):
1597
- facecolor = [colormap(norm(deg)) for deg in degrees.values()] # Use colormap
1598
- if size is None:
1599
- size = [700] * G.number_of_nodes() # Default size for all nodes
1600
- elif isinstance(size, (int, float)):
1601
- size = [size] * G.number_of_nodes() # If a scalar, apply to all nodes
1602
- # else:
1603
- # size = size.tolist() # Ensure size is a list
1604
- if len(size)>G.number_of_nodes():
1605
- size=size[:G.number_of_nodes()]
1606
-
1607
- for node in G.nodes():
2023
+ size = (size[:num_nodes] if len(size) > num_nodes else size) if isinstance(size, list) else [size] * num_nodes
2024
+ if isinstance(size, list) and len(ips.flatten(size,verbose=False))!=1:
2025
+ # Normalize sizes
2026
+ min_size, max_size = sizes # Use sizes tuple for min and max values
2027
+ min_degree, max_degree = min(size), max(size)
2028
+ if max_degree > min_degree: # Avoid division by zero
2029
+ size = [
2030
+ min_size + (max_size - min_size) * (sz - min_degree) / (max_degree - min_degree)
2031
+ for sz in size
2032
+ ]
2033
+ else:
2034
+ # If all values are the same, set them to a default of the midpoint
2035
+ size = [(min_size + max_size) / 2] * len(size)
2036
+
2037
+ #* facecolor
2038
+ facecolor = (facecolor[:num_nodes] if len(facecolor) > num_nodes else facecolor) if isinstance(facecolor, list) else [facecolor] * num_nodes
2039
+ # * facealpha
2040
+ if isinstance(alpha, list):
2041
+ alpha = (alpha[:num_nodes] if len(alpha) > num_nodes else alpha + [alpha[-1]] * (num_nodes - len(alpha)))
2042
+ min_alphas, max_alphas = alphas # Use alphas tuple for min and max values
2043
+ if len(alpha) > 0:
2044
+ # Normalize alpha based on the specified min and max
2045
+ min_alpha, max_alpha = min(alpha), max(alpha)
2046
+ if max_alpha > min_alpha: # Avoid division by zero
2047
+ alpha = [
2048
+ min_alphas + (max_alphas - min_alphas) * (ea - min_alpha) / (max_alpha - min_alpha)
2049
+ for ea in alpha
2050
+ ]
2051
+ else:
2052
+ # If all alpha values are the same, set them to the average of min and max
2053
+ alpha = [(min_alphas + max_alphas) / 2] * len(alpha)
2054
+ else:
2055
+ # Default to a full opacity if no edges are provided
2056
+ alpha = [1.0] * num_nodes
2057
+ else:
2058
+ # If alpha is a single value, convert it to a list and normalize it
2059
+ alpha = [alpha] * num_nodes # Adjust based on alphas
2060
+
2061
+ for i, node in enumerate(G.nodes()):
1608
2062
  net.add_node(
1609
2063
  node,
1610
2064
  label=node,
1611
- size=size[list(G.nodes()).index(node)] if isinstance(size,list) else size[0],
1612
- color=facecolor[list(G.nodes()).index(node)] if isinstance(facecolor,list) else facecolor,
2065
+ size=size[i],
2066
+ color=facecolor[i],
2067
+ alpha=alpha[i],
1613
2068
  font={"size": fontsize, "color": fontcolor},
1614
2069
  )
2070
+ print(f'nodes number: {i+1}')
1615
2071
 
1616
2072
  for edge in G.edges(data=True):
1617
2073
  net.add_edge(
@@ -1621,6 +2077,7 @@ def plot_ppi(
1621
2077
  color=edgecolor,
1622
2078
  width=edgelinewidth * edge[2]["weight"],
1623
2079
  )
2080
+
1624
2081
  layouts = [
1625
2082
  "spring",
1626
2083
  "circular",
@@ -1632,7 +2089,8 @@ def plot_ppi(
1632
2089
  "degree"
1633
2090
  ]
1634
2091
  layout = ips.strcmp(layout, layouts)[0]
1635
- print(layout)
2092
+ print(f"layout:{layout}, or select one in {layouts}")
2093
+
1636
2094
  # Choose layout
1637
2095
  if layout == "spring":
1638
2096
  pos = nx.spring_layout(G, k=k_value)
@@ -1658,7 +2116,9 @@ def plot_ppi(
1658
2116
  # Calculate node degrees and sort nodes by degree
1659
2117
  degrees = dict(G.degree())
1660
2118
  sorted_nodes = sorted(degrees.items(), key=lambda x: x[1], reverse=True)
1661
-
2119
+ norm = Normalize(vmin=min(degrees.values()), vmax=max(degrees.values()))
2120
+ colormap = cm.get_cmap(cmap)
2121
+
1662
2122
  # Create positions for concentric circles based on n_layers and n_rank
1663
2123
  pos = {}
1664
2124
  n_layers=len(n_rank)+1 if n_layers is None else n_layers
@@ -1689,8 +2149,8 @@ def plot_ppi(
1689
2149
 
1690
2150
  # If ax is None, use plt.gca()
1691
2151
  if ax is None:
1692
- fig, ax = plt.subplots(1,1,figsize=figsize)
1693
-
2152
+ fig, ax = plt.subplots(1,1,figsize=figsize)
2153
+
1694
2154
  # Draw nodes, edges, and labels with customization options
1695
2155
  nx.draw_networkx_nodes(
1696
2156
  G,
@@ -1704,6 +2164,54 @@ def plot_ppi(
1704
2164
  hide_ticks=node_hideticks,
1705
2165
  node_shape=marker
1706
2166
  )
2167
+
2168
+ #* linewidth
2169
+ if not isinstance(linewidth, list):
2170
+ linewidth = [linewidth] * G.number_of_edges()
2171
+ else:
2172
+ linewidth = (linewidth[:G.number_of_edges()] if len(linewidth) > G.number_of_edges() else linewidth + [linewidth[-1]] * (G.number_of_edges() - len(linewidth)))
2173
+ # Normalize linewidth if it is a list
2174
+ if isinstance(linewidth, list):
2175
+ min_linewidth, max_linewidth = min(linewidth), max(linewidth)
2176
+ vmin, vmax = linewidths # Use linewidths tuple for min and max values
2177
+ if max_linewidth > min_linewidth: # Avoid division by zero
2178
+ # Scale between vmin and vmax
2179
+ linewidth = [
2180
+ vmin + (vmax - vmin) * (lw - min_linewidth) / (max_linewidth - min_linewidth)
2181
+ for lw in linewidth
2182
+ ]
2183
+ else:
2184
+ # If all values are the same, set them to a default of the midpoint
2185
+ linewidth = [(vmin + vmax) / 2] * len(linewidth)
2186
+ else:
2187
+ # If linewidth is a single value, convert it to a list of that value
2188
+ linewidth = [linewidth] * G.number_of_edges()
2189
+ #* linecolor
2190
+ if not isinstance(linecolor, str):
2191
+ weights = [G[u][v]["weight"] for u, v in G.edges()]
2192
+ norm = Normalize(vmin=min(weights), vmax=max(weights))
2193
+ colormap = cm.get_cmap(line_cmap)
2194
+ linecolor = [colormap(norm(weight)) for weight in weights]
2195
+ else:
2196
+ linecolor = [linecolor] * G.number_of_edges()
2197
+
2198
+ # * linealpha
2199
+ if isinstance(linealpha, list):
2200
+ linealpha = (linealpha[:G.number_of_edges()] if len(linealpha) > G.number_of_edges() else linealpha + [linealpha[-1]] * (G.number_of_edges() - len(linealpha)))
2201
+ min_alpha, max_alpha = linealphas # Use linealphas tuple for min and max values
2202
+ if len(linealpha) > 0:
2203
+ min_linealpha, max_linealpha = min(linealpha), max(linealpha)
2204
+ if max_linealpha > min_linealpha: # Avoid division by zero
2205
+ linealpha = [
2206
+ min_alpha + (max_alpha - min_alpha) * (ea - min_linealpha) / (max_linealpha - min_linealpha)
2207
+ for ea in linealpha
2208
+ ]
2209
+ else:
2210
+ linealpha = [(min_alpha + max_alpha) / 2] * len(linealpha)
2211
+ else:
2212
+ linealpha = [1.0] * G.number_of_edges() # 如果设置有误,则将它设置成1.0
2213
+ else:
2214
+ linealpha = [linealpha] * G.number_of_edges() # Convert to list if single value
1707
2215
  nx.draw_networkx_edges(
1708
2216
  G,
1709
2217
  pos,
@@ -1712,8 +2220,9 @@ def plot_ppi(
1712
2220
  width=linewidth,
1713
2221
  style=linestyle,
1714
2222
  arrowstyle=line_arrowstyle,
1715
- alpha=0.7
2223
+ alpha=linealpha
1716
2224
  )
2225
+
1717
2226
  nx.draw_networkx_labels(
1718
2227
  G, pos, ax=ax, font_size=fontsize, font_color=fontcolor,horizontalalignment=ha,verticalalignment=va
1719
2228
  )
@@ -1725,6 +2234,7 @@ def plot_ppi(
1725
2234
  net.write_html(dir_save)
1726
2235
  nx.write_graphml(G, dir_save.replace(".html",".graphml")) # Export to GraphML
1727
2236
  print(f"could be edited in Cytoscape \n{dir_save.replace(".html",".graphml")}")
2237
+ ips.figsave(dir_save.replace(".html",".pdf"))
1728
2238
  return G,ax
1729
2239
 
1730
2240