py2ls 0.2.4.5__py3-none-any.whl → 0.2.4.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- py2ls/.git/index +0 -0
- py2ls/bio.py +562 -52
- py2ls/ips.py +161 -63
- py2ls/mol.py +289 -0
- py2ls/plot.py +274 -132
- {py2ls-0.2.4.5.dist-info → py2ls-0.2.4.7.dist-info}/METADATA +1 -1
- {py2ls-0.2.4.5.dist-info → py2ls-0.2.4.7.dist-info}/RECORD +8 -7
- {py2ls-0.2.4.5.dist-info → py2ls-0.2.4.7.dist-info}/WHEEL +1 -1
py2ls/bio.py
CHANGED
@@ -166,9 +166,23 @@ def get_probe(
|
|
166
166
|
if platform_id is None:
|
167
167
|
df_meta = get_meta(geo=geo, dataset=dataset, verbose=False)
|
168
168
|
platform_id = df_meta["platform_id"].unique().tolist()
|
169
|
-
platform_id = platform_id[0] if len(platform_id) == 1 else platform_id
|
170
169
|
print(f"Platform: {platform_id}")
|
171
|
-
|
170
|
+
if len(platform_id) > 1:
|
171
|
+
df_probe= geo[dataset].gpls[platform_id[0]].table
|
172
|
+
# df_probe=pd.DataFrame()
|
173
|
+
# # Iterate over each platform ID and collect the probe tables
|
174
|
+
# for platform_id_ in platform_id:
|
175
|
+
# if platform_id_ in geo[dataset].gpls:
|
176
|
+
# df_probe_ = geo[dataset].gpls[platform_id_].table
|
177
|
+
# if not df_probe_.empty:
|
178
|
+
# df_probe=pd.concat([df_probe, df_probe_])
|
179
|
+
# else:
|
180
|
+
# print(f"Warning: Probe table for platform {platform_id_} is empty.")
|
181
|
+
# else:
|
182
|
+
# print(f"Warning: Platform ID {platform_id_} not found in dataset {dataset}.")
|
183
|
+
else:
|
184
|
+
df_probe= geo[dataset].gpls[platform_id[0]].table
|
185
|
+
|
172
186
|
if df_probe.empty:
|
173
187
|
print(
|
174
188
|
f"Warning: cannot find the probe info. 看一下是不是在单独的文件中包含了probe信息"
|
@@ -215,9 +229,12 @@ def get_data(geo: dict, dataset: str = "GSE25097", verbose=False):
|
|
215
229
|
df_expression = get_expression_data(geo, dataset=dataset)
|
216
230
|
if not df_expression.select_dtypes(include=["number"]).empty:
|
217
231
|
# 如果数据全部是counts类型的话, 则使用TMM进行normalize
|
218
|
-
if 'counts' in get_data_type(df_expression):
|
219
|
-
|
220
|
-
|
232
|
+
if 'counts' in get_data_type(df_expression):
|
233
|
+
try:
|
234
|
+
df_expression=counts2expression(df_expression.T).T
|
235
|
+
print(f"{dataset}'s type is raw read counts, nomalized(transformed) via 'TMM'")
|
236
|
+
except Exception as e:
|
237
|
+
print("raw counts data")
|
221
238
|
if any([df_probe.empty, df_expression.empty]):
|
222
239
|
print(
|
223
240
|
f"got empty values, check the probe info. 看一下是不是在单独的文件中包含了probe信息"
|
@@ -814,7 +831,11 @@ def counts_deseq(counts_sam_gene: pd.DataFrame,
|
|
814
831
|
# .reset_index()
|
815
832
|
# .rename(columns={"index": "gene"})
|
816
833
|
# )
|
817
|
-
|
834
|
+
df_norm=pd.DataFrame(dds.layers['normed_counts'])
|
835
|
+
df_norm.index=counts_sam_gene.index
|
836
|
+
df_norm.columns=counts_sam_gene.columns
|
837
|
+
print("res[0]: dds\nres[1]:diff\nres[2]:stat_res\nres[3]:df_normalized")
|
838
|
+
return dds, diff, stat_res,df_norm
|
818
839
|
|
819
840
|
def scope_genes(gene_list: list, scopes:str=None, fields: str = "symbol", species="human"):
|
820
841
|
"""
|
@@ -1005,25 +1026,25 @@ def plot_enrichr(results_df,
|
|
1005
1026
|
palette = plot.get_color(n_top, cmap=cmap)[::-1]
|
1006
1027
|
elif isinstance(cmap,list):
|
1007
1028
|
palette=cmap
|
1008
|
-
|
1009
|
-
|
1010
|
-
|
1011
|
-
|
1012
|
-
|
1013
|
-
|
1014
|
-
|
1015
|
-
|
1016
|
-
|
1017
|
-
|
1018
|
-
|
1019
|
-
|
1020
|
-
height_=9
|
1029
|
+
if n_top < 5:
|
1030
|
+
height_ = 3
|
1031
|
+
elif 5 <= n_top < 10:
|
1032
|
+
height_ = 3
|
1033
|
+
elif 10 <= n_top < 15:
|
1034
|
+
height_ = 3
|
1035
|
+
elif 15 <= n_top < 20:
|
1036
|
+
height_ =4
|
1037
|
+
elif 20 <= n_top < 30:
|
1038
|
+
height_ = 5
|
1039
|
+
elif 30 <= n_top < 40:
|
1040
|
+
height_ = int(n_top / 6)
|
1021
1041
|
else:
|
1022
|
-
height_=int(n_top/
|
1023
|
-
|
1024
|
-
_,ax=plt.subplots(1,1,figsize=[10, height_])
|
1042
|
+
height_ = int(n_top / 8)
|
1043
|
+
|
1025
1044
|
#! barplot
|
1026
1045
|
if 'bar' in kind.lower():
|
1046
|
+
if ax is None:
|
1047
|
+
_,ax=plt.subplots(1,1,figsize=[10, height_])
|
1027
1048
|
ax=plot.plotxy(
|
1028
1049
|
data=results_df.head(n_top),
|
1029
1050
|
kind="barplot",
|
@@ -1035,6 +1056,7 @@ def plot_enrichr(results_df,
|
|
1035
1056
|
)
|
1036
1057
|
plot.figsets(ax=ax, **kws_figsets)
|
1037
1058
|
return ax,results_df
|
1059
|
+
|
1038
1060
|
#! dotplot
|
1039
1061
|
elif 'dot' in kind.lower():
|
1040
1062
|
#! dotplot
|
@@ -1066,16 +1088,20 @@ def plot_enrichr(results_df,
|
|
1066
1088
|
print(f"Warning: trying cutoff={cutoff_curr}, cutoff={cutoff_curr-step} failed: {e} ")
|
1067
1089
|
plot.figsets(ax=ax, **kws_figsets)
|
1068
1090
|
return ax,results_df
|
1091
|
+
|
1069
1092
|
#! barplot with counts
|
1070
1093
|
elif 'count' in kind.lower():
|
1094
|
+
if ax is None:
|
1095
|
+
_,ax=plt.subplots(1,1,figsize=[10, height_])
|
1071
1096
|
# 从overlap中提取出个数
|
1072
|
-
results_df["
|
1097
|
+
results_df["Count"] = results_df["Overlap"].apply(
|
1073
1098
|
lambda x: int(x.split("/")[0]) if isinstance(x, str) else x)
|
1074
|
-
df_=results_df.sort_values(by="
|
1099
|
+
df_=results_df.sort_values(by="Count", ascending=False)
|
1100
|
+
|
1075
1101
|
ax=plot.plotxy(
|
1076
1102
|
data=df_.head(n_top),
|
1077
1103
|
kind="barplot",
|
1078
|
-
x="
|
1104
|
+
x="Count",
|
1079
1105
|
y="Term",
|
1080
1106
|
hue="Term",
|
1081
1107
|
palette=palette,
|
@@ -1094,11 +1120,12 @@ def plot_bp_cc_mf(
|
|
1094
1120
|
"GO_Molecular_Function_2023",
|
1095
1121
|
],
|
1096
1122
|
species="human",
|
1123
|
+
download=False,
|
1097
1124
|
n_top=10,
|
1098
1125
|
plot_=True,
|
1099
1126
|
ax=None,
|
1100
1127
|
palette=plot.get_color(3),
|
1101
|
-
**
|
1128
|
+
**kwargs,
|
1102
1129
|
):
|
1103
1130
|
|
1104
1131
|
def res_enrichr_2_count(res_enrichr, n_top=10):
|
@@ -1111,13 +1138,13 @@ def plot_bp_cc_mf(
|
|
1111
1138
|
return res_enrichr.head(n_top)#[["Term", "Count"]]
|
1112
1139
|
|
1113
1140
|
res_enrichr_BP = get_enrichr(
|
1114
|
-
deg_gene_list, gene_sets[0], species=species, plot_=False
|
1141
|
+
deg_gene_list, gene_sets[0], species=species, plot_=False,download=download
|
1115
1142
|
)
|
1116
1143
|
res_enrichr_CC = get_enrichr(
|
1117
|
-
deg_gene_list, gene_sets[1], species=species, plot_=False
|
1144
|
+
deg_gene_list, gene_sets[1], species=species, plot_=False,download=download
|
1118
1145
|
)
|
1119
1146
|
res_enrichr_MF = get_enrichr(
|
1120
|
-
deg_gene_list, gene_sets[2], species=species, plot_=False
|
1147
|
+
deg_gene_list, gene_sets[2], species=species, plot_=False,download=download
|
1121
1148
|
)
|
1122
1149
|
|
1123
1150
|
df_BP = res_enrichr_2_count(res_enrichr_BP, n_top=n_top)
|
@@ -1149,6 +1176,7 @@ def plot_bp_cc_mf(
|
|
1149
1176
|
if ax is None:
|
1150
1177
|
_,ax=plt.subplots(1,1,figsize=[10, height_])
|
1151
1178
|
# 作图
|
1179
|
+
display(df2plot)
|
1152
1180
|
if df2plot["Term"].tolist()[0].endswith(")"):
|
1153
1181
|
df2plot["Term"] = df2plot["Term"].apply(lambda x: x.split("(")[0][:-1])
|
1154
1182
|
if plot_:
|
@@ -1164,8 +1192,15 @@ def plot_bp_cc_mf(
|
|
1164
1192
|
)
|
1165
1193
|
return ax, df2plot
|
1166
1194
|
|
1167
|
-
def get_library_name():
|
1168
|
-
|
1195
|
+
def get_library_name(by=None, verbose=False):
|
1196
|
+
lib_names=gp.get_library_name()
|
1197
|
+
if by is None:
|
1198
|
+
if verbose:
|
1199
|
+
[print(i) for i in lib_names]
|
1200
|
+
return lib_names
|
1201
|
+
else:
|
1202
|
+
return ips.flatten(ips.strcmp(by, lib_names, get_rank=True,verbose=verbose),verbose=verbose)
|
1203
|
+
|
1169
1204
|
|
1170
1205
|
def get_gsva(
|
1171
1206
|
data_gene_samples: pd.DataFrame, # index(gene),columns(samples)
|
@@ -1398,7 +1433,385 @@ def plot_gsva(gsva_res, # output from bio.get_gsva()
|
|
1398
1433
|
plot.figsets(ax=ax, **kws_figsets)
|
1399
1434
|
return ax
|
1400
1435
|
|
1436
|
+
def get_prerank(
|
1437
|
+
rnk: pd.DataFrame,
|
1438
|
+
gene_sets: str,
|
1439
|
+
download: bool = False,
|
1440
|
+
species="Human",
|
1441
|
+
threads=8, # Number of CPU cores to use
|
1442
|
+
permutation_num=1000, # Number of permutations for significance
|
1443
|
+
min_size=1, # Minimum gene set size
|
1444
|
+
max_size=2000, # Maximum gene set size
|
1445
|
+
seed=1, # Seed for reproducibility
|
1446
|
+
verbose=True, # Verbosity
|
1447
|
+
dir_save="./",
|
1448
|
+
plot_=False,
|
1449
|
+
size=5,
|
1450
|
+
cutoff=0.25,
|
1451
|
+
show_ring=False,
|
1452
|
+
cmap="coolwarm",
|
1453
|
+
check_shared=True,
|
1454
|
+
**kwargs,
|
1455
|
+
):
|
1456
|
+
"""
|
1457
|
+
Note: Enrichr uses a list of Entrez gene symbols as input.
|
1458
|
+
|
1459
|
+
"""
|
1460
|
+
kws_figsets = {}
|
1461
|
+
for k_arg, v_arg in kwargs.items():
|
1462
|
+
if "figset" in k_arg:
|
1463
|
+
kws_figsets = v_arg
|
1464
|
+
kwargs.pop(k_arg, None)
|
1465
|
+
break
|
1466
|
+
species_org = species
|
1467
|
+
# organism (str) – Select one from { ‘Human’, ‘Mouse’, ‘Yeast’, ‘Fly’, ‘Fish’, ‘Worm’ }
|
1468
|
+
organisms = ["Human", "Mouse", "Yeast", "Fly", "Fish", "Worm"]
|
1469
|
+
species = ips.strcmp(species, organisms)[0]
|
1470
|
+
if species_org.lower() != species.lower():
|
1471
|
+
print(f"species was corrected to {species}, becasue only support {organisms}")
|
1472
|
+
if os.path.isfile(gene_sets):
|
1473
|
+
gene_sets_name = os.path.basename(gene_sets)
|
1474
|
+
gene_sets = ips.fload(gene_sets)
|
1475
|
+
else:
|
1476
|
+
lib_support_names = gp.get_library_name()
|
1477
|
+
# correct input gene_set name
|
1478
|
+
gene_sets_name = ips.strcmp(gene_sets, lib_support_names)[0]
|
1479
|
+
|
1480
|
+
# download it
|
1481
|
+
if download:
|
1482
|
+
gene_sets = gp.get_library(name=gene_sets_name, organism=species)
|
1483
|
+
else:
|
1484
|
+
gene_sets = gene_sets_name # 避免重复下载
|
1485
|
+
print(f"\ngene_sets get ready: {gene_sets_name}")
|
1486
|
+
|
1487
|
+
#! prerank
|
1488
|
+
try:
|
1489
|
+
pre_res = gp.prerank(
|
1490
|
+
rnk=rnk,
|
1491
|
+
gene_sets=gene_sets,
|
1492
|
+
threads=threads, # Number of CPU cores to use
|
1493
|
+
permutation_num=permutation_num, # Number of permutations for significance
|
1494
|
+
min_size=min_size, # Minimum gene set size
|
1495
|
+
max_size=max_size, # Maximum gene set size
|
1496
|
+
seed=seed, # Seed for reproducibility
|
1497
|
+
verbose=verbose, # Verbosity
|
1498
|
+
)
|
1499
|
+
except ValueError as e:
|
1500
|
+
print(f"\n{'!'*10} Error {'!'*10}\n{' '*4}{e}\n{'!'*10} Error {'!'*10}")
|
1501
|
+
return None
|
1502
|
+
df_prerank = pre_res.res2d
|
1503
|
+
if plot_:
|
1504
|
+
#! gseaplot
|
1505
|
+
# # (1) easy way
|
1506
|
+
# terms = df_prerank.Term
|
1507
|
+
# axs = pre_res.plot(terms=terms[0])
|
1508
|
+
# (2) # to make more control on the plot, use
|
1509
|
+
terms = df_prerank.Term
|
1510
|
+
axs = pre_res.plot(
|
1511
|
+
terms=terms[:7],
|
1512
|
+
# legend_kws={"loc": (1.2, 0)}, # set the legend loc
|
1513
|
+
# show_ranking=True, # whether to show the second yaxis
|
1514
|
+
figsize=(3, 4),
|
1515
|
+
)
|
1516
|
+
ips.figsave(dir_save + f"prerank_gseaplot_{gene_sets}.pdf")
|
1517
|
+
#!dotplot
|
1518
|
+
from gseapy import dotplot
|
1519
|
+
|
1520
|
+
# to save your figure, make sure that ``ofname`` is not None
|
1521
|
+
ax = dotplot(
|
1522
|
+
df_prerank,
|
1523
|
+
column="NOM p-val", # FDR q-val",
|
1524
|
+
cmap=cmap,
|
1525
|
+
size=size,
|
1526
|
+
figsize=(10, 5),
|
1527
|
+
cutoff=cutoff,
|
1528
|
+
show_ring=show_ring,
|
1529
|
+
)
|
1530
|
+
ips.figsave(dir_save + f"prerank_dotplot_{gene_sets}.pdf")
|
1531
|
+
|
1532
|
+
#! network plot
|
1533
|
+
from gseapy import enrichment_map
|
1534
|
+
import networkx as nx
|
1535
|
+
|
1536
|
+
for top_term in range(5, 50):
|
1537
|
+
try:
|
1538
|
+
# return two dataframe
|
1539
|
+
nodes, edges = enrichment_map(
|
1540
|
+
df=df_prerank,
|
1541
|
+
columns="FDR q-val",
|
1542
|
+
cutoff=0.25, # 0.25 when "FDR q-val"; 0.05 when "Nom p-value"
|
1543
|
+
top_term=top_term,
|
1544
|
+
)
|
1545
|
+
# build graph
|
1546
|
+
G = nx.from_pandas_edgelist(
|
1547
|
+
edges,
|
1548
|
+
source="src_idx",
|
1549
|
+
target="targ_idx",
|
1550
|
+
edge_attr=["jaccard_coef", "overlap_coef", "overlap_genes"],
|
1551
|
+
)
|
1552
|
+
# to check if nodes.Hits_ratio or nodes.NES doesn’t match the number of nodes
|
1553
|
+
if len(list(nodes.Hits_ratio)) == len(G.nodes):
|
1554
|
+
node_sizes = list(nodes.Hits_ratio * 1000)
|
1555
|
+
else:
|
1556
|
+
raise ValueError(
|
1557
|
+
"The size of node_size list does not match the number of nodes in the graph."
|
1558
|
+
)
|
1559
|
+
|
1560
|
+
layout = "circular"
|
1561
|
+
fig, ax = plt.subplots(figsize=(8, 8))
|
1562
|
+
if layout == "spring":
|
1563
|
+
pos = nx.layout.spring_layout(G)
|
1564
|
+
elif layout == "circular":
|
1565
|
+
pos = nx.layout.circular_layout(G)
|
1566
|
+
elif layout == "shell":
|
1567
|
+
pos = nx.layout.shell_layout(G)
|
1568
|
+
elif layout == "spectral":
|
1569
|
+
pos = nx.layout.spectral_layout(G)
|
1570
|
+
|
1571
|
+
# node_size = nx.get_node_attributes()
|
1572
|
+
# draw node
|
1573
|
+
nx.draw_networkx_nodes(
|
1574
|
+
G,
|
1575
|
+
pos=pos,
|
1576
|
+
cmap=plt.cm.RdYlBu,
|
1577
|
+
node_color=list(nodes.NES),
|
1578
|
+
node_size=list(nodes.Hits_ratio * 1000),
|
1579
|
+
)
|
1580
|
+
# draw node label
|
1581
|
+
nx.draw_networkx_labels(
|
1582
|
+
G,
|
1583
|
+
pos=pos,
|
1584
|
+
labels=nodes.Term.to_dict(),
|
1585
|
+
font_size=8,
|
1586
|
+
verticalalignment="bottom",
|
1587
|
+
)
|
1588
|
+
# draw edge
|
1589
|
+
edge_weight = nx.get_edge_attributes(G, "jaccard_coef").values()
|
1590
|
+
nx.draw_networkx_edges(
|
1591
|
+
G,
|
1592
|
+
pos=pos,
|
1593
|
+
width=list(map(lambda x: x * 10, edge_weight)),
|
1594
|
+
edge_color="#CDDBD4",
|
1595
|
+
)
|
1596
|
+
ax.set_axis_off()
|
1597
|
+
print(f"{gene_sets}(top_term={top_term})")
|
1598
|
+
plot.figsets(title=f"{gene_sets}(top_term={top_term})")
|
1599
|
+
ips.figsave(dir_save + f"prerank_network_{gene_sets}.pdf")
|
1600
|
+
break
|
1601
|
+
except:
|
1602
|
+
print(f"not work {top_term}")
|
1603
|
+
return df_prerank
|
1604
|
+
def plot_prerank(
|
1605
|
+
results_df,
|
1606
|
+
kind="bar", # 'barplot', 'dotplot'
|
1607
|
+
cutoff=0.25,
|
1608
|
+
show_ring=False,
|
1609
|
+
xticklabels_rot=0,
|
1610
|
+
title=None, # 'KEGG'
|
1611
|
+
cmap="coolwarm",
|
1612
|
+
n_top=10,
|
1613
|
+
size=5, # when size is None in network, by "NES"
|
1614
|
+
facecolor=None,# default by "NES"
|
1615
|
+
linewidth=None,# default by "NES"
|
1616
|
+
linecolor=None,# default by "NES"
|
1617
|
+
linealpha=None, # default by "NES"
|
1618
|
+
alpha=None,# default by "NES"
|
1619
|
+
ax=None,
|
1620
|
+
**kwargs,
|
1621
|
+
):
|
1622
|
+
kws_figsets = {}
|
1623
|
+
for k_arg, v_arg in kwargs.items():
|
1624
|
+
if "figset" in k_arg:
|
1625
|
+
kws_figsets = v_arg
|
1626
|
+
kwargs.pop(k_arg, None)
|
1627
|
+
break
|
1628
|
+
if isinstance(cmap, str):
|
1629
|
+
palette = plot.get_color(n_top, cmap=cmap)[::-1]
|
1630
|
+
elif isinstance(cmap, list):
|
1631
|
+
palette = cmap
|
1632
|
+
if n_top < 5:
|
1633
|
+
height_ = 4
|
1634
|
+
elif 5 <= n_top < 10:
|
1635
|
+
height_ = 5
|
1636
|
+
elif 10 <= n_top < 15:
|
1637
|
+
height_ = 6
|
1638
|
+
elif 15 <= n_top < 20:
|
1639
|
+
height_ = 7
|
1640
|
+
elif 20 <= n_top < 30:
|
1641
|
+
height_ = 8
|
1642
|
+
elif 30 <= n_top < 40:
|
1643
|
+
height_ = int(n_top / 5)
|
1644
|
+
else:
|
1645
|
+
height_ = int(n_top / 6)
|
1646
|
+
results_df["-log10(Adjusted P-value)"]=results_df["FDR q-val"].apply(lambda x : -np.log10(x))
|
1647
|
+
results_df["Count"] = results_df["Lead_genes"].apply(lambda x: len(x.split(";")))
|
1648
|
+
#! barplot
|
1649
|
+
if "bar" in kind.lower():
|
1650
|
+
df_=results_df.sort_values(by="-log10(Adjusted P-value)",ascending=False)
|
1651
|
+
if ax is None:
|
1652
|
+
_, ax = plt.subplots(1, 1, figsize=[10, height_])
|
1653
|
+
ax = plot.plotxy(
|
1654
|
+
data=df_.head(n_top),
|
1655
|
+
kind="barplot",
|
1656
|
+
x="-log10(Adjusted P-value)",
|
1657
|
+
y="Term",
|
1658
|
+
hue="Term",
|
1659
|
+
palette=palette,
|
1660
|
+
legend=None,
|
1661
|
+
)
|
1662
|
+
plot.figsets(ax=ax, **kws_figsets)
|
1663
|
+
return ax, df_
|
1664
|
+
|
1665
|
+
#! dotplot
|
1666
|
+
elif "dot" in kind.lower():
|
1667
|
+
#! dotplot
|
1668
|
+
cutoff_curr = cutoff
|
1669
|
+
step = 0.05
|
1670
|
+
cutoff_stop = 0.5
|
1671
|
+
while cutoff_curr <= cutoff_stop:
|
1672
|
+
try:
|
1673
|
+
if cutoff_curr != cutoff:
|
1674
|
+
plt.clf()
|
1675
|
+
ax = gp.dotplot(
|
1676
|
+
results_df,
|
1677
|
+
column="NOM p-val",
|
1678
|
+
show_ring=show_ring,
|
1679
|
+
xticklabels_rot=xticklabels_rot,
|
1680
|
+
title=title,
|
1681
|
+
cmap=cmap,
|
1682
|
+
cutoff=cutoff_curr,
|
1683
|
+
top_term=n_top,
|
1684
|
+
size=size,
|
1685
|
+
figsize=[10, height_],
|
1686
|
+
)
|
1687
|
+
if len(ax.collections) >= n_top:
|
1688
|
+
print(f"cutoff={cutoff_curr} done! ")
|
1689
|
+
break
|
1690
|
+
if cutoff_curr == cutoff_stop:
|
1691
|
+
break
|
1692
|
+
cutoff_curr += step
|
1693
|
+
except Exception as e:
|
1694
|
+
cutoff_curr += step
|
1695
|
+
print(
|
1696
|
+
f"Warning: trying cutoff={cutoff_curr}, cutoff={cutoff_curr-step} failed: {e} "
|
1697
|
+
)
|
1698
|
+
plot.figsets(ax=ax, **kws_figsets)
|
1699
|
+
return ax, results_df
|
1700
|
+
|
1701
|
+
#! barplot with counts
|
1702
|
+
elif "co" in kind.lower():
|
1703
|
+
if ax is None:
|
1704
|
+
_, ax = plt.subplots(1, 1, figsize=[10, height_])
|
1705
|
+
# 从overlap中提取出个数
|
1706
|
+
df_ = results_df.sort_values(by="Count", ascending=False)
|
1707
|
+
ax = plot.plotxy(
|
1708
|
+
data=df_.head(n_top),
|
1709
|
+
kind="barplot",
|
1710
|
+
x="Count",
|
1711
|
+
y="Term",
|
1712
|
+
hue="Term",
|
1713
|
+
palette=palette,
|
1714
|
+
legend=None,
|
1715
|
+
ax=ax,
|
1716
|
+
**kwargs,
|
1717
|
+
)
|
1718
|
+
|
1719
|
+
plot.figsets(ax=ax, **kws_figsets)
|
1720
|
+
return ax, df_
|
1721
|
+
#! scatter with counts
|
1722
|
+
elif "sca" in kind.lower():
|
1723
|
+
if isinstance(cmap, str):
|
1724
|
+
palette = plot.get_color(n_top, cmap=cmap)
|
1725
|
+
elif isinstance(cmap, list):
|
1726
|
+
palette = cmap
|
1727
|
+
if ax is None:
|
1728
|
+
_, ax = plt.subplots(1, 1, figsize=[10, height_])
|
1729
|
+
# 从overlap中提取出个数
|
1730
|
+
df_ = results_df.sort_values(by="Count", ascending=False)
|
1731
|
+
ax = plot.plotxy(
|
1732
|
+
data=df_.head(n_top),
|
1733
|
+
kind="scatter",
|
1734
|
+
x="Count",
|
1735
|
+
y="Term",
|
1736
|
+
hue="Count",
|
1737
|
+
size="Count",
|
1738
|
+
sizes=[10,50],
|
1739
|
+
palette=palette,
|
1740
|
+
legend=None,
|
1741
|
+
ax=ax,
|
1742
|
+
**kwargs,
|
1743
|
+
)
|
1401
1744
|
|
1745
|
+
plot.figsets(ax=ax, **kws_figsets)
|
1746
|
+
return ax, df_
|
1747
|
+
elif "net" in kind.lower():
|
1748
|
+
#! network plot
|
1749
|
+
from gseapy import enrichment_map
|
1750
|
+
import networkx as nx
|
1751
|
+
from matplotlib import cm
|
1752
|
+
# try:
|
1753
|
+
if cutoff>=1 or cutoff is None:
|
1754
|
+
print(f"cutoff is {cutoff} => Without applying filter")
|
1755
|
+
nodes, edges = enrichment_map(
|
1756
|
+
df=results_df,
|
1757
|
+
columns="NOM p-val",
|
1758
|
+
cutoff=1.1, # 0.25 when "FDR q-val"; 0.05 when "Nom p-value"
|
1759
|
+
top_term=n_top,
|
1760
|
+
)
|
1761
|
+
else:
|
1762
|
+
cutoff_curr = cutoff
|
1763
|
+
step = 0.05
|
1764
|
+
cutoff_stop = 1.0
|
1765
|
+
while cutoff_curr <= cutoff_stop:
|
1766
|
+
try:
|
1767
|
+
# return two dataframe
|
1768
|
+
nodes, edges = enrichment_map(
|
1769
|
+
df=results_df,
|
1770
|
+
columns="NOM p-val",
|
1771
|
+
cutoff=cutoff_curr, # 0.25 when "FDR q-val"; 0.05 when "Nom p-value"
|
1772
|
+
top_term=n_top,
|
1773
|
+
)
|
1774
|
+
|
1775
|
+
if nodes.shape[0] >= n_top:
|
1776
|
+
print(f"cutoff={cutoff_curr} done! ")
|
1777
|
+
break
|
1778
|
+
if cutoff_curr == cutoff_stop:
|
1779
|
+
break
|
1780
|
+
cutoff_curr += step
|
1781
|
+
except Exception as e:
|
1782
|
+
cutoff_curr += step
|
1783
|
+
print(
|
1784
|
+
f"{e}: trying cutoff={cutoff_curr}"
|
1785
|
+
)
|
1786
|
+
|
1787
|
+
print("size: by 'NES'") if size is None else print("")
|
1788
|
+
print("linewidth: by 'NES'") if linewidth is None else print("")
|
1789
|
+
print("linecolor: by 'NES'") if linecolor is None else print("")
|
1790
|
+
print("linealpha: by 'NES'") if linealpha is None else print("")
|
1791
|
+
print("facecolor: by 'NES'") if facecolor is None else print("")
|
1792
|
+
print("alpha: by '-log10(Adjusted P-value)'") if alpha is None else print("")
|
1793
|
+
edges.sort_values(by="jaccard_coef", ascending=False,inplace=True)
|
1794
|
+
colormap = cm.get_cmap(cmap) # Get the 'coolwarm' colormap
|
1795
|
+
G,ax=plot_ppi(
|
1796
|
+
interactions=edges,
|
1797
|
+
player1="src_name",
|
1798
|
+
player2="targ_name",
|
1799
|
+
weight="jaccard_coef",
|
1800
|
+
size=[
|
1801
|
+
node["NES"] * 300 for _, node in nodes.iterrows()
|
1802
|
+
] if size is None else size, # size nodes by NES
|
1803
|
+
facecolor=[colormap(node["NES"]) for _, node in nodes.iterrows()] if facecolor is None else facecolor, # Color by FDR q-val
|
1804
|
+
linewidth=[node["NES"] * 300 for _, node in nodes.iterrows()] if linewidth is None else linewidth,
|
1805
|
+
linecolor=[node["NES"] * 300 for _, node in nodes.iterrows()] if linecolor is None else linecolor,
|
1806
|
+
linealpha=[node["NES"] * 300 for _, node in nodes.iterrows()] if linealpha is None else linealpha,
|
1807
|
+
alpha=[node["NES"] * 300 for _, node in nodes.iterrows()] if alpha is None else alpha,
|
1808
|
+
**kwargs
|
1809
|
+
)
|
1810
|
+
# except Exception as e:
|
1811
|
+
# print(f"not work {n_top},{e}")
|
1812
|
+
return ax, G, nodes, edges
|
1813
|
+
|
1814
|
+
|
1402
1815
|
#! https://string-db.org/help/api/
|
1403
1816
|
|
1404
1817
|
import pandas as pd
|
@@ -1527,16 +1940,22 @@ def plot_ppi(
|
|
1527
1940
|
n_rank=[5, 10], # Nodes in each rank for the concentric layout
|
1528
1941
|
dist_node = 10, # Distance between each rank of circles
|
1529
1942
|
layout="degree",
|
1530
|
-
size=
|
1943
|
+
size=None,#700,
|
1944
|
+
sizes=(50,500),# min and max of size
|
1531
1945
|
facecolor="skyblue",
|
1532
1946
|
cmap='coolwarm',
|
1533
1947
|
edgecolor="k",
|
1534
1948
|
edgelinewidth=1.5,
|
1535
1949
|
alpha=.5,
|
1950
|
+
alphas=(0.1, 1.0),# min and max of alpha
|
1536
1951
|
marker="o",
|
1537
1952
|
node_hideticks=True,
|
1538
1953
|
linecolor="gray",
|
1954
|
+
line_cmap='coolwarm',
|
1539
1955
|
linewidth=1.5,
|
1956
|
+
linewidths=(0.5,5),# min and max of linewidth
|
1957
|
+
linealpha=1.0,
|
1958
|
+
linealphas=(0.1,1.0),# min and max of linealpha
|
1540
1959
|
linestyle="-",
|
1541
1960
|
line_arrowstyle='-',
|
1542
1961
|
fontsize=10,
|
@@ -1565,7 +1984,7 @@ def plot_ppi(
|
|
1565
1984
|
for col in [player1, player2, weight]:
|
1566
1985
|
if col not in interactions.columns:
|
1567
1986
|
raise ValueError(f"Column '{col}' is missing from the interactions DataFrame.")
|
1568
|
-
|
1987
|
+
interactions.sort_values(by=[weight], inplace=True)
|
1569
1988
|
# Initialize Pyvis network
|
1570
1989
|
net = Network(height="750px", width="100%", bgcolor=bgcolor, font_color=fontcolor)
|
1571
1990
|
net.force_atlas_2based(
|
@@ -1584,34 +2003,71 @@ def plot_ppi(
|
|
1584
2003
|
G = nx.Graph()
|
1585
2004
|
for _, row in interactions.iterrows():
|
1586
2005
|
G.add_edge(row[player1], row[player2], weight=row[weight])
|
2006
|
+
# G = nx.from_pandas_edgelist(interactions, source=player1, target=player2, edge_attr=weight)
|
2007
|
+
|
1587
2008
|
|
1588
2009
|
# Calculate node degrees
|
1589
2010
|
degrees = dict(G.degree())
|
1590
2011
|
norm = Normalize(vmin=min(degrees.values()), vmax=max(degrees.values()))
|
1591
2012
|
colormap = cm.get_cmap(cmap) # Get the 'coolwarm' colormap
|
1592
2013
|
|
2014
|
+
if not ips.isa(facecolor, 'color'):
|
2015
|
+
print("facecolor: based on degrees")
|
2016
|
+
facecolor = [colormap(norm(deg)) for deg in degrees.values()] # Use colormap
|
2017
|
+
num_nodes = G.number_of_nodes()
|
2018
|
+
#* size
|
1593
2019
|
# Set properties based on degrees
|
1594
2020
|
if not isinstance(size, (int,float,list)):
|
2021
|
+
print("size: based on degrees")
|
1595
2022
|
size = [deg * 50 for deg in degrees.values()] # Scale sizes
|
1596
|
-
if
|
1597
|
-
|
1598
|
-
|
1599
|
-
|
1600
|
-
|
1601
|
-
|
1602
|
-
|
1603
|
-
|
1604
|
-
|
1605
|
-
|
1606
|
-
|
1607
|
-
|
2023
|
+
size = (size[:num_nodes] if len(size) > num_nodes else size) if isinstance(size, list) else [size] * num_nodes
|
2024
|
+
if isinstance(size, list) and len(ips.flatten(size,verbose=False))!=1:
|
2025
|
+
# Normalize sizes
|
2026
|
+
min_size, max_size = sizes # Use sizes tuple for min and max values
|
2027
|
+
min_degree, max_degree = min(size), max(size)
|
2028
|
+
if max_degree > min_degree: # Avoid division by zero
|
2029
|
+
size = [
|
2030
|
+
min_size + (max_size - min_size) * (sz - min_degree) / (max_degree - min_degree)
|
2031
|
+
for sz in size
|
2032
|
+
]
|
2033
|
+
else:
|
2034
|
+
# If all values are the same, set them to a default of the midpoint
|
2035
|
+
size = [(min_size + max_size) / 2] * len(size)
|
2036
|
+
|
2037
|
+
#* facecolor
|
2038
|
+
facecolor = (facecolor[:num_nodes] if len(facecolor) > num_nodes else facecolor) if isinstance(facecolor, list) else [facecolor] * num_nodes
|
2039
|
+
# * facealpha
|
2040
|
+
if isinstance(alpha, list):
|
2041
|
+
alpha = (alpha[:num_nodes] if len(alpha) > num_nodes else alpha + [alpha[-1]] * (num_nodes - len(alpha)))
|
2042
|
+
min_alphas, max_alphas = alphas # Use alphas tuple for min and max values
|
2043
|
+
if len(alpha) > 0:
|
2044
|
+
# Normalize alpha based on the specified min and max
|
2045
|
+
min_alpha, max_alpha = min(alpha), max(alpha)
|
2046
|
+
if max_alpha > min_alpha: # Avoid division by zero
|
2047
|
+
alpha = [
|
2048
|
+
min_alphas + (max_alphas - min_alphas) * (ea - min_alpha) / (max_alpha - min_alpha)
|
2049
|
+
for ea in alpha
|
2050
|
+
]
|
2051
|
+
else:
|
2052
|
+
# If all alpha values are the same, set them to the average of min and max
|
2053
|
+
alpha = [(min_alphas + max_alphas) / 2] * len(alpha)
|
2054
|
+
else:
|
2055
|
+
# Default to a full opacity if no edges are provided
|
2056
|
+
alpha = [1.0] * num_nodes
|
2057
|
+
else:
|
2058
|
+
# If alpha is a single value, convert it to a list and normalize it
|
2059
|
+
alpha = [alpha] * num_nodes # Adjust based on alphas
|
2060
|
+
|
2061
|
+
for i, node in enumerate(G.nodes()):
|
1608
2062
|
net.add_node(
|
1609
2063
|
node,
|
1610
2064
|
label=node,
|
1611
|
-
size=size[
|
1612
|
-
color=facecolor[
|
2065
|
+
size=size[i],
|
2066
|
+
color=facecolor[i],
|
2067
|
+
alpha=alpha[i],
|
1613
2068
|
font={"size": fontsize, "color": fontcolor},
|
1614
2069
|
)
|
2070
|
+
print(f'nodes number: {i+1}')
|
1615
2071
|
|
1616
2072
|
for edge in G.edges(data=True):
|
1617
2073
|
net.add_edge(
|
@@ -1621,6 +2077,7 @@ def plot_ppi(
|
|
1621
2077
|
color=edgecolor,
|
1622
2078
|
width=edgelinewidth * edge[2]["weight"],
|
1623
2079
|
)
|
2080
|
+
|
1624
2081
|
layouts = [
|
1625
2082
|
"spring",
|
1626
2083
|
"circular",
|
@@ -1632,7 +2089,8 @@ def plot_ppi(
|
|
1632
2089
|
"degree"
|
1633
2090
|
]
|
1634
2091
|
layout = ips.strcmp(layout, layouts)[0]
|
1635
|
-
print(layout)
|
2092
|
+
print(f"layout:{layout}, or select one in {layouts}")
|
2093
|
+
|
1636
2094
|
# Choose layout
|
1637
2095
|
if layout == "spring":
|
1638
2096
|
pos = nx.spring_layout(G, k=k_value)
|
@@ -1658,7 +2116,9 @@ def plot_ppi(
|
|
1658
2116
|
# Calculate node degrees and sort nodes by degree
|
1659
2117
|
degrees = dict(G.degree())
|
1660
2118
|
sorted_nodes = sorted(degrees.items(), key=lambda x: x[1], reverse=True)
|
1661
|
-
|
2119
|
+
norm = Normalize(vmin=min(degrees.values()), vmax=max(degrees.values()))
|
2120
|
+
colormap = cm.get_cmap(cmap)
|
2121
|
+
|
1662
2122
|
# Create positions for concentric circles based on n_layers and n_rank
|
1663
2123
|
pos = {}
|
1664
2124
|
n_layers=len(n_rank)+1 if n_layers is None else n_layers
|
@@ -1689,8 +2149,8 @@ def plot_ppi(
|
|
1689
2149
|
|
1690
2150
|
# If ax is None, use plt.gca()
|
1691
2151
|
if ax is None:
|
1692
|
-
fig, ax = plt.subplots(1,1,figsize=figsize)
|
1693
|
-
|
2152
|
+
fig, ax = plt.subplots(1,1,figsize=figsize)
|
2153
|
+
|
1694
2154
|
# Draw nodes, edges, and labels with customization options
|
1695
2155
|
nx.draw_networkx_nodes(
|
1696
2156
|
G,
|
@@ -1704,6 +2164,54 @@ def plot_ppi(
|
|
1704
2164
|
hide_ticks=node_hideticks,
|
1705
2165
|
node_shape=marker
|
1706
2166
|
)
|
2167
|
+
|
2168
|
+
#* linewidth
|
2169
|
+
if not isinstance(linewidth, list):
|
2170
|
+
linewidth = [linewidth] * G.number_of_edges()
|
2171
|
+
else:
|
2172
|
+
linewidth = (linewidth[:G.number_of_edges()] if len(linewidth) > G.number_of_edges() else linewidth + [linewidth[-1]] * (G.number_of_edges() - len(linewidth)))
|
2173
|
+
# Normalize linewidth if it is a list
|
2174
|
+
if isinstance(linewidth, list):
|
2175
|
+
min_linewidth, max_linewidth = min(linewidth), max(linewidth)
|
2176
|
+
vmin, vmax = linewidths # Use linewidths tuple for min and max values
|
2177
|
+
if max_linewidth > min_linewidth: # Avoid division by zero
|
2178
|
+
# Scale between vmin and vmax
|
2179
|
+
linewidth = [
|
2180
|
+
vmin + (vmax - vmin) * (lw - min_linewidth) / (max_linewidth - min_linewidth)
|
2181
|
+
for lw in linewidth
|
2182
|
+
]
|
2183
|
+
else:
|
2184
|
+
# If all values are the same, set them to a default of the midpoint
|
2185
|
+
linewidth = [(vmin + vmax) / 2] * len(linewidth)
|
2186
|
+
else:
|
2187
|
+
# If linewidth is a single value, convert it to a list of that value
|
2188
|
+
linewidth = [linewidth] * G.number_of_edges()
|
2189
|
+
#* linecolor
|
2190
|
+
if not isinstance(linecolor, str):
|
2191
|
+
weights = [G[u][v]["weight"] for u, v in G.edges()]
|
2192
|
+
norm = Normalize(vmin=min(weights), vmax=max(weights))
|
2193
|
+
colormap = cm.get_cmap(line_cmap)
|
2194
|
+
linecolor = [colormap(norm(weight)) for weight in weights]
|
2195
|
+
else:
|
2196
|
+
linecolor = [linecolor] * G.number_of_edges()
|
2197
|
+
|
2198
|
+
# * linealpha
|
2199
|
+
if isinstance(linealpha, list):
|
2200
|
+
linealpha = (linealpha[:G.number_of_edges()] if len(linealpha) > G.number_of_edges() else linealpha + [linealpha[-1]] * (G.number_of_edges() - len(linealpha)))
|
2201
|
+
min_alpha, max_alpha = linealphas # Use linealphas tuple for min and max values
|
2202
|
+
if len(linealpha) > 0:
|
2203
|
+
min_linealpha, max_linealpha = min(linealpha), max(linealpha)
|
2204
|
+
if max_linealpha > min_linealpha: # Avoid division by zero
|
2205
|
+
linealpha = [
|
2206
|
+
min_alpha + (max_alpha - min_alpha) * (ea - min_linealpha) / (max_linealpha - min_linealpha)
|
2207
|
+
for ea in linealpha
|
2208
|
+
]
|
2209
|
+
else:
|
2210
|
+
linealpha = [(min_alpha + max_alpha) / 2] * len(linealpha)
|
2211
|
+
else:
|
2212
|
+
linealpha = [1.0] * G.number_of_edges() # 如果设置有误,则将它设置成1.0
|
2213
|
+
else:
|
2214
|
+
linealpha = [linealpha] * G.number_of_edges() # Convert to list if single value
|
1707
2215
|
nx.draw_networkx_edges(
|
1708
2216
|
G,
|
1709
2217
|
pos,
|
@@ -1712,8 +2220,9 @@ def plot_ppi(
|
|
1712
2220
|
width=linewidth,
|
1713
2221
|
style=linestyle,
|
1714
2222
|
arrowstyle=line_arrowstyle,
|
1715
|
-
alpha=
|
2223
|
+
alpha=linealpha
|
1716
2224
|
)
|
2225
|
+
|
1717
2226
|
nx.draw_networkx_labels(
|
1718
2227
|
G, pos, ax=ax, font_size=fontsize, font_color=fontcolor,horizontalalignment=ha,verticalalignment=va
|
1719
2228
|
)
|
@@ -1725,6 +2234,7 @@ def plot_ppi(
|
|
1725
2234
|
net.write_html(dir_save)
|
1726
2235
|
nx.write_graphml(G, dir_save.replace(".html",".graphml")) # Export to GraphML
|
1727
2236
|
print(f"could be edited in Cytoscape \n{dir_save.replace(".html",".graphml")}")
|
2237
|
+
ips.figsave(dir_save.replace(".html",".pdf"))
|
1728
2238
|
return G,ax
|
1729
2239
|
|
1730
2240
|
|