py2ls 0.2.4.4__py3-none-any.whl → 0.2.4.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- py2ls/.git/index +0 -0
- py2ls/bio.py +959 -38
- py2ls/ips.py +15 -6
- py2ls/mol.py +289 -0
- py2ls/plot.py +304 -109
- {py2ls-0.2.4.4.dist-info → py2ls-0.2.4.6.dist-info}/METADATA +1 -1
- {py2ls-0.2.4.4.dist-info → py2ls-0.2.4.6.dist-info}/RECORD +8 -7
- {py2ls-0.2.4.4.dist-info → py2ls-0.2.4.6.dist-info}/WHEEL +0 -0
py2ls/bio.py
CHANGED
@@ -324,7 +324,7 @@ def find_condition(data:pd.DataFrame, columns=["characteristics_ch1","title"]):
|
|
324
324
|
# 详细看看每个信息的有哪些类, 其中有数字的, 要去除
|
325
325
|
for col in columns:
|
326
326
|
print(f"{"="*10} {col} {"="*10}")
|
327
|
-
display(ips.flatten([ips.ssplit(i, by="numer")[0] for i in data[col]]))
|
327
|
+
display(ips.flatten([ips.ssplit(i, by="numer")[0] for i in data[col]],verbose=False))
|
328
328
|
|
329
329
|
def add_condition(
|
330
330
|
data: pd.DataFrame,
|
@@ -581,7 +581,7 @@ def batch_effect(
|
|
581
581
|
return df_corrected
|
582
582
|
|
583
583
|
def get_common_genes(elment1, elment2):
|
584
|
-
common_genes=ips.shared(elment1, elment2)
|
584
|
+
common_genes=ips.shared(elment1, elment2,verbose=False)
|
585
585
|
return common_genes
|
586
586
|
|
587
587
|
def counts2expression(
|
@@ -667,7 +667,7 @@ def counts2expression(
|
|
667
667
|
|
668
668
|
length.index=length.index.astype(str).str.strip()
|
669
669
|
counts.columns = counts.columns.astype(str).str.strip()
|
670
|
-
shared_genes=ips.shared(length.index, counts.columns)
|
670
|
+
shared_genes=ips.shared(length.index, counts.columns,verbose=False)
|
671
671
|
length=length.loc[shared_genes]
|
672
672
|
counts=counts.loc[:,shared_genes]
|
673
673
|
columns_org = counts.columns.tolist()
|
@@ -814,7 +814,11 @@ def counts_deseq(counts_sam_gene: pd.DataFrame,
|
|
814
814
|
# .reset_index()
|
815
815
|
# .rename(columns={"index": "gene"})
|
816
816
|
# )
|
817
|
-
|
817
|
+
df_norm=pd.DataFrame(dds.layers['normed_counts'])
|
818
|
+
df_norm.index=counts_sam_gene.index
|
819
|
+
df_norm.columns=counts_sam_gene.columns
|
820
|
+
print("res[0]: dds\nres[1]:diff\nres[2]:stat_res\nres[3]:df_normalized")
|
821
|
+
return dds, diff, stat_res,df_norm
|
818
822
|
|
819
823
|
def scope_genes(gene_list: list, scopes:str=None, fields: str = "symbol", species="human"):
|
820
824
|
"""
|
@@ -842,6 +846,7 @@ def scope_genes(gene_list: list, scopes:str=None, fields: str = "symbol", specie
|
|
842
846
|
|
843
847
|
def get_enrichr(gene_symbol_list,
|
844
848
|
gene_sets:str,
|
849
|
+
download:bool = False,
|
845
850
|
species='Human',
|
846
851
|
dir_save="./",
|
847
852
|
plot_=False,
|
@@ -854,6 +859,7 @@ def get_enrichr(gene_symbol_list,
|
|
854
859
|
title=None,# 'KEGG'
|
855
860
|
cutoff=0.05,
|
856
861
|
cmap="coolwarm",
|
862
|
+
size=5,
|
857
863
|
**kwargs):
|
858
864
|
"""
|
859
865
|
Note: Enrichr uses a list of Entrez gene symbols as input.
|
@@ -878,16 +884,22 @@ def get_enrichr(gene_symbol_list,
|
|
878
884
|
lib_support_names = gp.get_library_name()
|
879
885
|
# correct input gene_set name
|
880
886
|
gene_sets_name=ips.strcmp(gene_sets,lib_support_names)[0]
|
887
|
+
|
881
888
|
# download it
|
882
|
-
|
883
|
-
|
889
|
+
if download:
|
890
|
+
gene_sets = gp.get_library(name=gene_sets_name, organism=species)
|
891
|
+
else:
|
892
|
+
gene_sets = gene_sets_name # 避免重复下载
|
893
|
+
print(f"\ngene_sets get ready: {gene_sets_name}")
|
884
894
|
|
885
895
|
# gene symbols are uppercase
|
886
896
|
gene_symbol_list=[str(i).upper() for i in gene_symbol_list]
|
887
897
|
|
888
898
|
# # check how shared genes
|
889
|
-
if check_shared:
|
890
|
-
shared_genes=ips.shared(ips.flatten(gene_symbol_list,verbose=False),
|
899
|
+
if check_shared and isinstance(gene_sets, dict):
|
900
|
+
shared_genes=ips.shared(ips.flatten(gene_symbol_list,verbose=False),
|
901
|
+
ips.flatten(gene_sets,verbose=False),
|
902
|
+
verbose=False)
|
891
903
|
|
892
904
|
#! enrichr
|
893
905
|
try:
|
@@ -903,13 +915,13 @@ def get_enrichr(gene_symbol_list,
|
|
903
915
|
return None
|
904
916
|
|
905
917
|
results_df = enr.results
|
906
|
-
print(f"got enrichr reslutls; shape: {results_df.shape}")
|
918
|
+
print(f"got enrichr reslutls; shape: {results_df.shape}\n")
|
907
919
|
results_df["-log10(Adjusted P-value)"] = -np.log10(results_df["Adjusted P-value"])
|
908
920
|
results_df.sort_values("-log10(Adjusted P-value)", inplace=True, ascending=False)
|
909
921
|
|
910
922
|
if plot_:
|
911
923
|
if palette is None:
|
912
|
-
palette=plot.get_color(n_top, cmap=
|
924
|
+
palette=plot.get_color(n_top, cmap=cmap)[::-1]
|
913
925
|
#! barplot
|
914
926
|
if n_top<5:
|
915
927
|
height_=4
|
@@ -921,11 +933,12 @@ def get_enrichr(gene_symbol_list,
|
|
921
933
|
height_=7
|
922
934
|
elif 15<=n_top<20:
|
923
935
|
height_=8
|
924
|
-
elif
|
936
|
+
elif 20<=n_top<30:
|
925
937
|
height_=9
|
926
938
|
else:
|
927
939
|
height_=int(n_top/3)
|
928
|
-
plt.figure(figsize=[
|
940
|
+
plt.figure(figsize=[10, height_])
|
941
|
+
|
929
942
|
ax1=plot.plotxy(
|
930
943
|
data=results_df.head(n_top),
|
931
944
|
kind="barplot",
|
@@ -935,18 +948,17 @@ def get_enrichr(gene_symbol_list,
|
|
935
948
|
palette=palette,
|
936
949
|
legend=None,
|
937
950
|
)
|
951
|
+
plot.figsets(ax=ax1, **kws_figsets)
|
938
952
|
if dir_save:
|
939
953
|
ips.figsave(f"{dir_save} enr_barplot.pdf")
|
940
|
-
plot.figsets(ax=ax1, **kws_figsets)
|
941
954
|
plt.show()
|
942
955
|
|
943
956
|
#! dotplot
|
944
957
|
cutoff_curr = cutoff
|
945
958
|
step=0.05
|
946
959
|
cutoff_stop = 0.5
|
947
|
-
while cutoff_curr <=cutoff_stop:
|
960
|
+
while cutoff_curr <= cutoff_stop:
|
948
961
|
try:
|
949
|
-
print(kws_figsets)
|
950
962
|
if cutoff_curr!=cutoff:
|
951
963
|
plt.clf()
|
952
964
|
ax2 = gp.dotplot(enr.res2d,
|
@@ -957,7 +969,8 @@ def get_enrichr(gene_symbol_list,
|
|
957
969
|
cmap=cmap,
|
958
970
|
cutoff=cutoff_curr,
|
959
971
|
top_term=n_top,
|
960
|
-
|
972
|
+
size=size,
|
973
|
+
figsize=[10, height_])
|
961
974
|
if len(ax2.collections)>=n_top:
|
962
975
|
print(f"cutoff={cutoff_curr} done! ")
|
963
976
|
break
|
@@ -975,7 +988,813 @@ def get_enrichr(gene_symbol_list,
|
|
975
988
|
|
976
989
|
return results_df
|
977
990
|
|
991
|
+
def plot_enrichr(results_df,
|
992
|
+
kind="bar",# 'barplot', 'dotplot'
|
993
|
+
cutoff=0.05,
|
994
|
+
show_ring=False,
|
995
|
+
xticklabels_rot=0,
|
996
|
+
title=None,# 'KEGG'
|
997
|
+
cmap="coolwarm",
|
998
|
+
n_top=10,
|
999
|
+
size=5,
|
1000
|
+
ax=None,
|
1001
|
+
**kwargs):
|
1002
|
+
kws_figsets = {}
|
1003
|
+
for k_arg, v_arg in kwargs.items():
|
1004
|
+
if "figset" in k_arg:
|
1005
|
+
kws_figsets = v_arg
|
1006
|
+
kwargs.pop(k_arg, None)
|
1007
|
+
break
|
1008
|
+
if isinstance(cmap,str):
|
1009
|
+
palette = plot.get_color(n_top, cmap=cmap)[::-1]
|
1010
|
+
elif isinstance(cmap,list):
|
1011
|
+
palette=cmap
|
1012
|
+
if n_top < 5:
|
1013
|
+
height_ = 3
|
1014
|
+
elif 5 <= n_top < 10:
|
1015
|
+
height_ = 3
|
1016
|
+
elif 10 <= n_top < 15:
|
1017
|
+
height_ = 3
|
1018
|
+
elif 15 <= n_top < 20:
|
1019
|
+
height_ =4
|
1020
|
+
elif 20 <= n_top < 30:
|
1021
|
+
height_ = 5
|
1022
|
+
elif 30 <= n_top < 40:
|
1023
|
+
height_ = int(n_top / 6)
|
1024
|
+
else:
|
1025
|
+
height_ = int(n_top / 8)
|
1026
|
+
|
1027
|
+
#! barplot
|
1028
|
+
if 'bar' in kind.lower():
|
1029
|
+
if ax is None:
|
1030
|
+
_,ax=plt.subplots(1,1,figsize=[10, height_])
|
1031
|
+
ax=plot.plotxy(
|
1032
|
+
data=results_df.head(n_top),
|
1033
|
+
kind="barplot",
|
1034
|
+
x="-log10(Adjusted P-value)",
|
1035
|
+
y="Term",
|
1036
|
+
hue="Term",
|
1037
|
+
palette=palette,
|
1038
|
+
legend=None,
|
1039
|
+
)
|
1040
|
+
plot.figsets(ax=ax, **kws_figsets)
|
1041
|
+
return ax,results_df
|
1042
|
+
|
1043
|
+
#! dotplot
|
1044
|
+
elif 'dot' in kind.lower():
|
1045
|
+
#! dotplot
|
1046
|
+
cutoff_curr = cutoff
|
1047
|
+
step=0.05
|
1048
|
+
cutoff_stop = 0.5
|
1049
|
+
while cutoff_curr <= cutoff_stop:
|
1050
|
+
try:
|
1051
|
+
if cutoff_curr!=cutoff:
|
1052
|
+
plt.clf()
|
1053
|
+
ax = gp.dotplot(results_df,
|
1054
|
+
column="Adjusted P-value",
|
1055
|
+
show_ring=show_ring,
|
1056
|
+
xticklabels_rot=xticklabels_rot,
|
1057
|
+
title=title,
|
1058
|
+
cmap=cmap,
|
1059
|
+
cutoff=cutoff_curr,
|
1060
|
+
top_term=n_top,
|
1061
|
+
size=size,
|
1062
|
+
figsize=[10, height_])
|
1063
|
+
if len(ax.collections)>=n_top:
|
1064
|
+
print(f"cutoff={cutoff_curr} done! ")
|
1065
|
+
break
|
1066
|
+
if cutoff_curr==cutoff_stop:
|
1067
|
+
break
|
1068
|
+
cutoff_curr+=step
|
1069
|
+
except Exception as e:
|
1070
|
+
cutoff_curr+=step
|
1071
|
+
print(f"Warning: trying cutoff={cutoff_curr}, cutoff={cutoff_curr-step} failed: {e} ")
|
1072
|
+
plot.figsets(ax=ax, **kws_figsets)
|
1073
|
+
return ax,results_df
|
1074
|
+
|
1075
|
+
#! barplot with counts
|
1076
|
+
elif 'count' in kind.lower():
|
1077
|
+
if ax is None:
|
1078
|
+
_,ax=plt.subplots(1,1,figsize=[10, height_])
|
1079
|
+
# 从overlap中提取出个数
|
1080
|
+
results_df["Count"] = results_df["Overlap"].apply(
|
1081
|
+
lambda x: int(x.split("/")[0]) if isinstance(x, str) else x)
|
1082
|
+
df_=results_df.sort_values(by="Count", ascending=False)
|
1083
|
+
|
1084
|
+
ax=plot.plotxy(
|
1085
|
+
data=df_.head(n_top),
|
1086
|
+
kind="barplot",
|
1087
|
+
x="Count",
|
1088
|
+
y="Term",
|
1089
|
+
hue="Term",
|
1090
|
+
palette=palette,
|
1091
|
+
legend=None,
|
1092
|
+
ax=ax
|
1093
|
+
)
|
1094
|
+
|
1095
|
+
plot.figsets(ax=ax, **kws_figsets)
|
1096
|
+
return ax,df_
|
1097
|
+
|
1098
|
+
def plot_bp_cc_mf(
|
1099
|
+
deg_gene_list,
|
1100
|
+
gene_sets=[
|
1101
|
+
"GO_Biological_Process_2023",
|
1102
|
+
"GO_Cellular_Component_2023",
|
1103
|
+
"GO_Molecular_Function_2023",
|
1104
|
+
],
|
1105
|
+
species="human",
|
1106
|
+
download=False,
|
1107
|
+
n_top=10,
|
1108
|
+
plot_=True,
|
1109
|
+
ax=None,
|
1110
|
+
palette=plot.get_color(3),
|
1111
|
+
**kwargs,
|
1112
|
+
):
|
1113
|
+
|
1114
|
+
def res_enrichr_2_count(res_enrichr, n_top=10):
|
1115
|
+
"""把enrich resulst 提取出count,并排序"""
|
1116
|
+
res_enrichr["Count"] = res_enrichr["Overlap"].apply(
|
1117
|
+
lambda x: int(x.split("/")[0]) if isinstance(x, str) else x
|
1118
|
+
)
|
1119
|
+
res_enrichr.sort_values(by="Count", ascending=False, inplace=True)
|
1120
|
+
|
1121
|
+
return res_enrichr.head(n_top)#[["Term", "Count"]]
|
1122
|
+
|
1123
|
+
res_enrichr_BP = get_enrichr(
|
1124
|
+
deg_gene_list, gene_sets[0], species=species, plot_=False,download=download
|
1125
|
+
)
|
1126
|
+
res_enrichr_CC = get_enrichr(
|
1127
|
+
deg_gene_list, gene_sets[1], species=species, plot_=False,download=download
|
1128
|
+
)
|
1129
|
+
res_enrichr_MF = get_enrichr(
|
1130
|
+
deg_gene_list, gene_sets[2], species=species, plot_=False,download=download
|
1131
|
+
)
|
1132
|
+
|
1133
|
+
df_BP = res_enrichr_2_count(res_enrichr_BP, n_top=n_top)
|
1134
|
+
df_BP["Ontology"] = ["BP"] * n_top
|
1135
|
+
|
1136
|
+
df_CC = res_enrichr_2_count(res_enrichr_CC, n_top=n_top)
|
1137
|
+
df_CC["Ontology"] = ["CC"] * n_top
|
1138
|
+
|
1139
|
+
df_MF = res_enrichr_2_count(res_enrichr_MF, n_top=n_top)
|
1140
|
+
df_MF["Ontology"] = ["MF"] * n_top
|
1141
|
+
|
1142
|
+
# 合并
|
1143
|
+
df2plot = pd.concat([df_BP, df_CC, df_MF])
|
1144
|
+
n_top=n_top*3
|
1145
|
+
if n_top < 5:
|
1146
|
+
height_ = 4
|
1147
|
+
elif 5 <= n_top < 10:
|
1148
|
+
height_ = 5
|
1149
|
+
elif 10 <= n_top < 15:
|
1150
|
+
height_ = 6
|
1151
|
+
elif 15 <= n_top < 20:
|
1152
|
+
height_ = 7
|
1153
|
+
elif 20 <= n_top < 30:
|
1154
|
+
height_ = 8
|
1155
|
+
elif 30 <= n_top < 40:
|
1156
|
+
height_ = int(n_top / 4)
|
1157
|
+
else:
|
1158
|
+
height_ = int(n_top / 5)
|
1159
|
+
if ax is None:
|
1160
|
+
_,ax=plt.subplots(1,1,figsize=[10, height_])
|
1161
|
+
# 作图
|
1162
|
+
display(df2plot)
|
1163
|
+
if df2plot["Term"].tolist()[0].endswith(")"):
|
1164
|
+
df2plot["Term"] = df2plot["Term"].apply(lambda x: x.split("(")[0][:-1])
|
1165
|
+
if plot_:
|
1166
|
+
ax = plot.plotxy(
|
1167
|
+
data=df2plot,
|
1168
|
+
x="Count",
|
1169
|
+
y="Term",
|
1170
|
+
hue="Ontology",
|
1171
|
+
kind="bar",
|
1172
|
+
palette=palette,
|
1173
|
+
ax=ax,
|
1174
|
+
**kwargs
|
1175
|
+
)
|
1176
|
+
return ax, df2plot
|
1177
|
+
|
1178
|
+
def get_library_name(by=None, verbose=False):
|
1179
|
+
lib_names=gp.get_library_name()
|
1180
|
+
if by is None:
|
1181
|
+
if verbose:
|
1182
|
+
[print(i) for i in lib_names]
|
1183
|
+
return lib_names
|
1184
|
+
else:
|
1185
|
+
return ips.flatten(ips.strcmp(by, lib_names, get_rank=True,verbose=verbose),verbose=verbose)
|
1186
|
+
|
1187
|
+
|
1188
|
+
def get_gsva(
|
1189
|
+
data_gene_samples: pd.DataFrame, # index(gene),columns(samples)
|
1190
|
+
gene_sets: str,
|
1191
|
+
species:str="Human",
|
1192
|
+
dir_save:str="./",
|
1193
|
+
plot_:bool=False,
|
1194
|
+
n_top:int=30,
|
1195
|
+
check_shared:bool=True,
|
1196
|
+
cmap="coolwarm",
|
1197
|
+
min_size=1,
|
1198
|
+
max_size=1000,
|
1199
|
+
kcdf="Gaussian",# 'Gaussian' for continuous data
|
1200
|
+
method='gsva',
|
1201
|
+
seed=1,
|
1202
|
+
**kwargs,
|
1203
|
+
):
|
1204
|
+
kws_figsets = {}
|
1205
|
+
for k_arg, v_arg in kwargs.items():
|
1206
|
+
if "figset" in k_arg:
|
1207
|
+
kws_figsets = v_arg
|
1208
|
+
kwargs.pop(k_arg, None)
|
1209
|
+
break
|
1210
|
+
species_org = species
|
1211
|
+
# organism (str) – Select one from { ‘Human’, ‘Mouse’, ‘Yeast’, ‘Fly’, ‘Fish’, ‘Worm’ }
|
1212
|
+
organisms = ["Human", "Mouse", "Yeast", "Fly", "Fish", "Worm"]
|
1213
|
+
species = ips.strcmp(species, organisms)[0]
|
1214
|
+
if species_org.lower() != species.lower():
|
1215
|
+
print(f"species was corrected to {species}, becasue only support {organisms}")
|
1216
|
+
if os.path.isfile(gene_sets):
|
1217
|
+
gene_sets_name = os.path.basename(gene_sets)
|
1218
|
+
gene_sets = ips.fload(gene_sets)
|
1219
|
+
else:
|
1220
|
+
lib_support_names = gp.get_library_name()
|
1221
|
+
# correct input gene_set name
|
1222
|
+
gene_sets_name = ips.strcmp(gene_sets, lib_support_names)[0]
|
1223
|
+
# download it
|
1224
|
+
gene_sets = gp.get_library(name=gene_sets_name, organism=species)
|
1225
|
+
print(f"gene_sets get ready: {gene_sets_name}")
|
1226
|
+
|
1227
|
+
# gene symbols are uppercase
|
1228
|
+
gene_symbol_list = [str(i).upper() for i in data_gene_samples.index]
|
1229
|
+
data_gene_samples.index=gene_symbol_list
|
1230
|
+
# display(data_gene_samples.head(3))
|
1231
|
+
# # check how shared genes
|
1232
|
+
if check_shared:
|
1233
|
+
ips.shared(
|
1234
|
+
ips.flatten(gene_symbol_list, verbose=False),
|
1235
|
+
ips.flatten(gene_sets, verbose=False),
|
1236
|
+
verbose=False
|
1237
|
+
)
|
1238
|
+
gsva_results = gp.gsva(
|
1239
|
+
data=data_gene_samples, # matrix should have genes as rows and samples as columns
|
1240
|
+
gene_sets=gene_sets,
|
1241
|
+
outdir=None,
|
1242
|
+
kcdf=kcdf, # 'Gaussian' for continuous data
|
1243
|
+
min_size=min_size,
|
1244
|
+
method=method,
|
1245
|
+
max_size=max_size,
|
1246
|
+
verbose=True,
|
1247
|
+
seed=seed,
|
1248
|
+
# no_plot=False,
|
1249
|
+
)
|
1250
|
+
gsva_res = gsva_results.res2d.copy()
|
1251
|
+
gsva_res["ES_abs"] = gsva_res["ES"].apply(np.abs)
|
1252
|
+
gsva_res = gsva_res.sort_values(by="ES_abs", ascending=False)
|
1253
|
+
gsva_res = (
|
1254
|
+
gsva_res.drop_duplicates(subset="Term").drop(columns="ES_abs")
|
1255
|
+
# .iloc[:80, :]
|
1256
|
+
.reset_index(drop=True)
|
1257
|
+
)
|
1258
|
+
gsva_res = gsva_res.sort_values(by="ES", ascending=False)
|
1259
|
+
if plot_:
|
1260
|
+
if gsva_res.shape[0]>=2*n_top:
|
1261
|
+
gsva_res_plot=pd.concat([gsva_res.head(n_top),gsva_res.tail(n_top)])
|
1262
|
+
else:
|
1263
|
+
gsva_res_plot = gsva_res
|
1264
|
+
if isinstance(cmap,str):
|
1265
|
+
palette = plot.get_color(n_top*2, cmap=cmap)[::-1]
|
1266
|
+
elif isinstance(cmap,list):
|
1267
|
+
if len(cmap)==2:
|
1268
|
+
palette = [cmap[0]]*n_top+[cmap[1]]*n_top
|
1269
|
+
else:
|
1270
|
+
palette=cmap
|
1271
|
+
# ! barplot
|
1272
|
+
if n_top < 5:
|
1273
|
+
height_ = 3
|
1274
|
+
elif 5 <= n_top < 10:
|
1275
|
+
height_ = 4
|
1276
|
+
elif 10 <= n_top < 15:
|
1277
|
+
height_ = 5
|
1278
|
+
elif 15 <= n_top < 20:
|
1279
|
+
height_ = 6
|
1280
|
+
elif 20 <= n_top < 30:
|
1281
|
+
height_ = 7
|
1282
|
+
elif 30 <= n_top < 40:
|
1283
|
+
height_ = int(n_top / 3.5)
|
1284
|
+
else:
|
1285
|
+
height_ = int(n_top / 3)
|
1286
|
+
plt.figure(figsize=[10, height_])
|
1287
|
+
ax2 = plot.plotxy(
|
1288
|
+
data=gsva_res_plot,
|
1289
|
+
x="ES",
|
1290
|
+
y="Term",
|
1291
|
+
hue="Term",
|
1292
|
+
palette=palette,
|
1293
|
+
kind=["bar"],
|
1294
|
+
figsets=dict(yticklabel=[], ticksloc="b", boxloc="b", ylabel=None),
|
1295
|
+
)
|
1296
|
+
# 改变labels的位置
|
1297
|
+
for i, bar in enumerate(ax2.patches):
|
1298
|
+
term = gsva_res_plot.iloc[i]["Term"]
|
1299
|
+
es_value = gsva_res_plot.iloc[i]["ES"]
|
1300
|
+
|
1301
|
+
# Positive ES values: Align y-labels to the left
|
1302
|
+
if es_value > 0:
|
1303
|
+
ax2.annotate(
|
1304
|
+
term,
|
1305
|
+
xy=(0, bar.get_y() + bar.get_height() / 2),
|
1306
|
+
xytext=(-5, 0), # Move to the left
|
1307
|
+
textcoords="offset points",
|
1308
|
+
ha="right",
|
1309
|
+
va="center", # Align labels to the right
|
1310
|
+
fontsize=10,
|
1311
|
+
color="black",
|
1312
|
+
)
|
1313
|
+
# Negative ES values: Align y-labels to the right
|
1314
|
+
else:
|
1315
|
+
ax2.annotate(
|
1316
|
+
term,
|
1317
|
+
xy=(0, bar.get_y() + bar.get_height() / 2),
|
1318
|
+
xytext=(5, 0), # Move to the right
|
1319
|
+
textcoords="offset points",
|
1320
|
+
ha="left",
|
1321
|
+
va="center", # Align labels to the left
|
1322
|
+
fontsize=10,
|
1323
|
+
color="black",
|
1324
|
+
)
|
1325
|
+
plot.figsets(ax=ax2, **kws_figsets)
|
1326
|
+
if dir_save:
|
1327
|
+
ips.figsave(dir_save + f"GSVA_{gene_sets_name}.pdf")
|
1328
|
+
plt.show()
|
1329
|
+
return gsva_res.reset_index(drop=True)
|
1330
|
+
|
1331
|
+
def plot_gsva(gsva_res, # output from bio.get_gsva()
|
1332
|
+
n_top=10,
|
1333
|
+
ax=None,
|
1334
|
+
x="ES",
|
1335
|
+
y="Term",
|
1336
|
+
hue="Term",
|
1337
|
+
cmap="coolwarm",
|
1338
|
+
**kwargs
|
1339
|
+
):
|
1340
|
+
kws_figsets = {}
|
1341
|
+
for k_arg, v_arg in kwargs.items():
|
1342
|
+
if "figset" in k_arg:
|
1343
|
+
kws_figsets = v_arg
|
1344
|
+
kwargs.pop(k_arg, None)
|
1345
|
+
break
|
1346
|
+
# ! barplot
|
1347
|
+
if n_top < 5:
|
1348
|
+
height_ = 4
|
1349
|
+
elif 5 <= n_top < 10:
|
1350
|
+
height_ = 5
|
1351
|
+
elif 10 <= n_top < 15:
|
1352
|
+
height_ = 6
|
1353
|
+
elif 15 <= n_top < 20:
|
1354
|
+
height_ = 7
|
1355
|
+
elif 20 <= n_top < 30:
|
1356
|
+
height_ = 8
|
1357
|
+
elif 30 <= n_top < 40:
|
1358
|
+
height_ = int(n_top / 3.5)
|
1359
|
+
else:
|
1360
|
+
height_ = int(n_top / 3)
|
1361
|
+
if ax is None:
|
1362
|
+
_,ax=plt.subplots(1,1,figsize=[10, height_])
|
1363
|
+
gsva_res = gsva_res.sort_values(by=x, ascending=False)
|
1364
|
+
|
1365
|
+
if gsva_res.shape[0]>=2*n_top:
|
1366
|
+
gsva_res_plot=pd.concat([gsva_res.head(n_top),gsva_res.tail(n_top)])
|
1367
|
+
else:
|
1368
|
+
gsva_res_plot = gsva_res
|
1369
|
+
if isinstance(cmap,str):
|
1370
|
+
palette = plot.get_color(n_top*2, cmap=cmap)[::-1]
|
1371
|
+
elif isinstance(cmap,list):
|
1372
|
+
if len(cmap)==2:
|
1373
|
+
palette = [cmap[0]]*n_top+[cmap[1]]*n_top
|
1374
|
+
else:
|
1375
|
+
palette=cmap
|
1376
|
+
|
1377
|
+
ax = plot.plotxy(
|
1378
|
+
ax=ax,
|
1379
|
+
data=gsva_res_plot,
|
1380
|
+
x=x,
|
1381
|
+
y=y,
|
1382
|
+
hue=hue,
|
1383
|
+
palette=palette,
|
1384
|
+
kind=["bar"],
|
1385
|
+
figsets=dict(yticklabel=[], ticksloc="b", boxloc="b", ylabel=None),
|
1386
|
+
)
|
1387
|
+
# 改变labels的位置
|
1388
|
+
for i, bar in enumerate(ax.patches):
|
1389
|
+
term = gsva_res_plot.iloc[i]["Term"]
|
1390
|
+
es_value = gsva_res_plot.iloc[i]["ES"]
|
1391
|
+
|
1392
|
+
# Positive ES values: Align y-labels to the left
|
1393
|
+
if es_value > 0:
|
1394
|
+
ax.annotate(
|
1395
|
+
term,
|
1396
|
+
xy=(0, bar.get_y() + bar.get_height() / 2),
|
1397
|
+
xytext=(-5, 0), # Move to the left
|
1398
|
+
textcoords="offset points",
|
1399
|
+
ha="right",
|
1400
|
+
va="center", # Align labels to the right
|
1401
|
+
fontsize=10,
|
1402
|
+
color="black",
|
1403
|
+
)
|
1404
|
+
# Negative ES values: Align y-labels to the right
|
1405
|
+
else:
|
1406
|
+
ax.annotate(
|
1407
|
+
term,
|
1408
|
+
xy=(0, bar.get_y() + bar.get_height() / 2),
|
1409
|
+
xytext=(5, 0), # Move to the right
|
1410
|
+
textcoords="offset points",
|
1411
|
+
ha="left",
|
1412
|
+
va="center", # Align labels to the left
|
1413
|
+
fontsize=10,
|
1414
|
+
color="black",
|
1415
|
+
)
|
1416
|
+
plot.figsets(ax=ax, **kws_figsets)
|
1417
|
+
return ax
|
1418
|
+
|
1419
|
+
def get_prerank(
|
1420
|
+
rnk: pd.DataFrame,
|
1421
|
+
gene_sets: str,
|
1422
|
+
download: bool = False,
|
1423
|
+
species="Human",
|
1424
|
+
threads=8, # Number of CPU cores to use
|
1425
|
+
permutation_num=1000, # Number of permutations for significance
|
1426
|
+
min_size=1, # Minimum gene set size
|
1427
|
+
max_size=2000, # Maximum gene set size
|
1428
|
+
seed=1, # Seed for reproducibility
|
1429
|
+
verbose=True, # Verbosity
|
1430
|
+
dir_save="./",
|
1431
|
+
plot_=False,
|
1432
|
+
size=5,
|
1433
|
+
cutoff=0.25,
|
1434
|
+
show_ring=False,
|
1435
|
+
cmap="coolwarm",
|
1436
|
+
check_shared=True,
|
1437
|
+
**kwargs,
|
1438
|
+
):
|
1439
|
+
"""
|
1440
|
+
Note: Enrichr uses a list of Entrez gene symbols as input.
|
1441
|
+
|
1442
|
+
"""
|
1443
|
+
kws_figsets = {}
|
1444
|
+
for k_arg, v_arg in kwargs.items():
|
1445
|
+
if "figset" in k_arg:
|
1446
|
+
kws_figsets = v_arg
|
1447
|
+
kwargs.pop(k_arg, None)
|
1448
|
+
break
|
1449
|
+
species_org = species
|
1450
|
+
# organism (str) – Select one from { ‘Human’, ‘Mouse’, ‘Yeast’, ‘Fly’, ‘Fish’, ‘Worm’ }
|
1451
|
+
organisms = ["Human", "Mouse", "Yeast", "Fly", "Fish", "Worm"]
|
1452
|
+
species = ips.strcmp(species, organisms)[0]
|
1453
|
+
if species_org.lower() != species.lower():
|
1454
|
+
print(f"species was corrected to {species}, becasue only support {organisms}")
|
1455
|
+
if os.path.isfile(gene_sets):
|
1456
|
+
gene_sets_name = os.path.basename(gene_sets)
|
1457
|
+
gene_sets = ips.fload(gene_sets)
|
1458
|
+
else:
|
1459
|
+
lib_support_names = gp.get_library_name()
|
1460
|
+
# correct input gene_set name
|
1461
|
+
gene_sets_name = ips.strcmp(gene_sets, lib_support_names)[0]
|
1462
|
+
|
1463
|
+
# download it
|
1464
|
+
if download:
|
1465
|
+
gene_sets = gp.get_library(name=gene_sets_name, organism=species)
|
1466
|
+
else:
|
1467
|
+
gene_sets = gene_sets_name # 避免重复下载
|
1468
|
+
print(f"\ngene_sets get ready: {gene_sets_name}")
|
978
1469
|
|
1470
|
+
#! prerank
|
1471
|
+
try:
|
1472
|
+
pre_res = gp.prerank(
|
1473
|
+
rnk=rnk,
|
1474
|
+
gene_sets=gene_sets,
|
1475
|
+
threads=threads, # Number of CPU cores to use
|
1476
|
+
permutation_num=permutation_num, # Number of permutations for significance
|
1477
|
+
min_size=min_size, # Minimum gene set size
|
1478
|
+
max_size=max_size, # Maximum gene set size
|
1479
|
+
seed=seed, # Seed for reproducibility
|
1480
|
+
verbose=verbose, # Verbosity
|
1481
|
+
)
|
1482
|
+
except ValueError as e:
|
1483
|
+
print(f"\n{'!'*10} Error {'!'*10}\n{' '*4}{e}\n{'!'*10} Error {'!'*10}")
|
1484
|
+
return None
|
1485
|
+
df_prerank = pre_res.res2d
|
1486
|
+
if plot_:
|
1487
|
+
#! gseaplot
|
1488
|
+
# # (1) easy way
|
1489
|
+
# terms = df_prerank.Term
|
1490
|
+
# axs = pre_res.plot(terms=terms[0])
|
1491
|
+
# (2) # to make more control on the plot, use
|
1492
|
+
terms = df_prerank.Term
|
1493
|
+
axs = pre_res.plot(
|
1494
|
+
terms=terms[:7],
|
1495
|
+
# legend_kws={"loc": (1.2, 0)}, # set the legend loc
|
1496
|
+
# show_ranking=True, # whether to show the second yaxis
|
1497
|
+
figsize=(3, 4),
|
1498
|
+
)
|
1499
|
+
ips.figsave(dir_save + f"prerank_gseaplot_{gene_sets}.pdf")
|
1500
|
+
#!dotplot
|
1501
|
+
from gseapy import dotplot
|
1502
|
+
|
1503
|
+
# to save your figure, make sure that ``ofname`` is not None
|
1504
|
+
ax = dotplot(
|
1505
|
+
df_prerank,
|
1506
|
+
column="NOM p-val", # FDR q-val",
|
1507
|
+
cmap=cmap,
|
1508
|
+
size=size,
|
1509
|
+
figsize=(10, 5),
|
1510
|
+
cutoff=cutoff,
|
1511
|
+
show_ring=show_ring,
|
1512
|
+
)
|
1513
|
+
ips.figsave(dir_save + f"prerank_dotplot_{gene_sets}.pdf")
|
1514
|
+
|
1515
|
+
#! network plot
|
1516
|
+
from gseapy import enrichment_map
|
1517
|
+
import networkx as nx
|
1518
|
+
|
1519
|
+
for top_term in range(5, 50):
|
1520
|
+
try:
|
1521
|
+
# return two dataframe
|
1522
|
+
nodes, edges = enrichment_map(
|
1523
|
+
df=df_prerank,
|
1524
|
+
columns="FDR q-val",
|
1525
|
+
cutoff=0.25, # 0.25 when "FDR q-val"; 0.05 when "Nom p-value"
|
1526
|
+
top_term=top_term,
|
1527
|
+
)
|
1528
|
+
# build graph
|
1529
|
+
G = nx.from_pandas_edgelist(
|
1530
|
+
edges,
|
1531
|
+
source="src_idx",
|
1532
|
+
target="targ_idx",
|
1533
|
+
edge_attr=["jaccard_coef", "overlap_coef", "overlap_genes"],
|
1534
|
+
)
|
1535
|
+
# to check if nodes.Hits_ratio or nodes.NES doesn’t match the number of nodes
|
1536
|
+
if len(list(nodes.Hits_ratio)) == len(G.nodes):
|
1537
|
+
node_sizes = list(nodes.Hits_ratio * 1000)
|
1538
|
+
else:
|
1539
|
+
raise ValueError(
|
1540
|
+
"The size of node_size list does not match the number of nodes in the graph."
|
1541
|
+
)
|
1542
|
+
|
1543
|
+
layout = "circular"
|
1544
|
+
fig, ax = plt.subplots(figsize=(8, 8))
|
1545
|
+
if layout == "spring":
|
1546
|
+
pos = nx.layout.spring_layout(G)
|
1547
|
+
elif layout == "circular":
|
1548
|
+
pos = nx.layout.circular_layout(G)
|
1549
|
+
elif layout == "shell":
|
1550
|
+
pos = nx.layout.shell_layout(G)
|
1551
|
+
elif layout == "spectral":
|
1552
|
+
pos = nx.layout.spectral_layout(G)
|
1553
|
+
|
1554
|
+
# node_size = nx.get_node_attributes()
|
1555
|
+
# draw node
|
1556
|
+
nx.draw_networkx_nodes(
|
1557
|
+
G,
|
1558
|
+
pos=pos,
|
1559
|
+
cmap=plt.cm.RdYlBu,
|
1560
|
+
node_color=list(nodes.NES),
|
1561
|
+
node_size=list(nodes.Hits_ratio * 1000),
|
1562
|
+
)
|
1563
|
+
# draw node label
|
1564
|
+
nx.draw_networkx_labels(
|
1565
|
+
G,
|
1566
|
+
pos=pos,
|
1567
|
+
labels=nodes.Term.to_dict(),
|
1568
|
+
font_size=8,
|
1569
|
+
verticalalignment="bottom",
|
1570
|
+
)
|
1571
|
+
# draw edge
|
1572
|
+
edge_weight = nx.get_edge_attributes(G, "jaccard_coef").values()
|
1573
|
+
nx.draw_networkx_edges(
|
1574
|
+
G,
|
1575
|
+
pos=pos,
|
1576
|
+
width=list(map(lambda x: x * 10, edge_weight)),
|
1577
|
+
edge_color="#CDDBD4",
|
1578
|
+
)
|
1579
|
+
ax.set_axis_off()
|
1580
|
+
print(f"{gene_sets}(top_term={top_term})")
|
1581
|
+
plot.figsets(title=f"{gene_sets}(top_term={top_term})")
|
1582
|
+
ips.figsave(dir_save + f"prerank_network_{gene_sets}.pdf")
|
1583
|
+
break
|
1584
|
+
except:
|
1585
|
+
print(f"not work {top_term}")
|
1586
|
+
return df_prerank
|
1587
|
+
def plot_prerank(
|
1588
|
+
results_df,
|
1589
|
+
kind="bar", # 'barplot', 'dotplot'
|
1590
|
+
cutoff=0.25,
|
1591
|
+
show_ring=False,
|
1592
|
+
xticklabels_rot=0,
|
1593
|
+
title=None, # 'KEGG'
|
1594
|
+
cmap="coolwarm",
|
1595
|
+
n_top=10,
|
1596
|
+
size=5, # when size is None in network, by "NES"
|
1597
|
+
facecolor=None,# default by "NES"
|
1598
|
+
linewidth=None,# default by "NES"
|
1599
|
+
linecolor=None,# default by "NES"
|
1600
|
+
linealpha=None, # default by "NES"
|
1601
|
+
alpha=None,# default by "NES"
|
1602
|
+
ax=None,
|
1603
|
+
**kwargs,
|
1604
|
+
):
|
1605
|
+
kws_figsets = {}
|
1606
|
+
for k_arg, v_arg in kwargs.items():
|
1607
|
+
if "figset" in k_arg:
|
1608
|
+
kws_figsets = v_arg
|
1609
|
+
kwargs.pop(k_arg, None)
|
1610
|
+
break
|
1611
|
+
if isinstance(cmap, str):
|
1612
|
+
palette = plot.get_color(n_top, cmap=cmap)[::-1]
|
1613
|
+
elif isinstance(cmap, list):
|
1614
|
+
palette = cmap
|
1615
|
+
if n_top < 5:
|
1616
|
+
height_ = 4
|
1617
|
+
elif 5 <= n_top < 10:
|
1618
|
+
height_ = 5
|
1619
|
+
elif 10 <= n_top < 15:
|
1620
|
+
height_ = 6
|
1621
|
+
elif 15 <= n_top < 20:
|
1622
|
+
height_ = 7
|
1623
|
+
elif 20 <= n_top < 30:
|
1624
|
+
height_ = 8
|
1625
|
+
elif 30 <= n_top < 40:
|
1626
|
+
height_ = int(n_top / 5)
|
1627
|
+
else:
|
1628
|
+
height_ = int(n_top / 6)
|
1629
|
+
results_df["-log10(Adjusted P-value)"]=results_df["FDR q-val"].apply(lambda x : -np.log10(x))
|
1630
|
+
results_df["Count"] = results_df["Lead_genes"].apply(lambda x: len(x.split(";")))
|
1631
|
+
#! barplot
|
1632
|
+
if "bar" in kind.lower():
|
1633
|
+
df_=results_df.sort_values(by="-log10(Adjusted P-value)",ascending=False)
|
1634
|
+
if ax is None:
|
1635
|
+
_, ax = plt.subplots(1, 1, figsize=[10, height_])
|
1636
|
+
ax = plot.plotxy(
|
1637
|
+
data=df_.head(n_top),
|
1638
|
+
kind="barplot",
|
1639
|
+
x="-log10(Adjusted P-value)",
|
1640
|
+
y="Term",
|
1641
|
+
hue="Term",
|
1642
|
+
palette=palette,
|
1643
|
+
legend=None,
|
1644
|
+
)
|
1645
|
+
plot.figsets(ax=ax, **kws_figsets)
|
1646
|
+
return ax, df_
|
1647
|
+
|
1648
|
+
#! dotplot
|
1649
|
+
elif "dot" in kind.lower():
|
1650
|
+
#! dotplot
|
1651
|
+
cutoff_curr = cutoff
|
1652
|
+
step = 0.05
|
1653
|
+
cutoff_stop = 0.5
|
1654
|
+
while cutoff_curr <= cutoff_stop:
|
1655
|
+
try:
|
1656
|
+
if cutoff_curr != cutoff:
|
1657
|
+
plt.clf()
|
1658
|
+
ax = gp.dotplot(
|
1659
|
+
results_df,
|
1660
|
+
column="NOM p-val",
|
1661
|
+
show_ring=show_ring,
|
1662
|
+
xticklabels_rot=xticklabels_rot,
|
1663
|
+
title=title,
|
1664
|
+
cmap=cmap,
|
1665
|
+
cutoff=cutoff_curr,
|
1666
|
+
top_term=n_top,
|
1667
|
+
size=size,
|
1668
|
+
figsize=[10, height_],
|
1669
|
+
)
|
1670
|
+
if len(ax.collections) >= n_top:
|
1671
|
+
print(f"cutoff={cutoff_curr} done! ")
|
1672
|
+
break
|
1673
|
+
if cutoff_curr == cutoff_stop:
|
1674
|
+
break
|
1675
|
+
cutoff_curr += step
|
1676
|
+
except Exception as e:
|
1677
|
+
cutoff_curr += step
|
1678
|
+
print(
|
1679
|
+
f"Warning: trying cutoff={cutoff_curr}, cutoff={cutoff_curr-step} failed: {e} "
|
1680
|
+
)
|
1681
|
+
plot.figsets(ax=ax, **kws_figsets)
|
1682
|
+
return ax, results_df
|
1683
|
+
|
1684
|
+
#! barplot with counts
|
1685
|
+
elif "co" in kind.lower():
|
1686
|
+
if ax is None:
|
1687
|
+
_, ax = plt.subplots(1, 1, figsize=[10, height_])
|
1688
|
+
# 从overlap中提取出个数
|
1689
|
+
df_ = results_df.sort_values(by="Count", ascending=False)
|
1690
|
+
ax = plot.plotxy(
|
1691
|
+
data=df_.head(n_top),
|
1692
|
+
kind="barplot",
|
1693
|
+
x="Count",
|
1694
|
+
y="Term",
|
1695
|
+
hue="Term",
|
1696
|
+
palette=palette,
|
1697
|
+
legend=None,
|
1698
|
+
ax=ax,
|
1699
|
+
**kwargs,
|
1700
|
+
)
|
1701
|
+
|
1702
|
+
plot.figsets(ax=ax, **kws_figsets)
|
1703
|
+
return ax, df_
|
1704
|
+
#! scatter with counts
|
1705
|
+
elif "sca" in kind.lower():
|
1706
|
+
if isinstance(cmap, str):
|
1707
|
+
palette = plot.get_color(n_top, cmap=cmap)
|
1708
|
+
elif isinstance(cmap, list):
|
1709
|
+
palette = cmap
|
1710
|
+
if ax is None:
|
1711
|
+
_, ax = plt.subplots(1, 1, figsize=[10, height_])
|
1712
|
+
# 从overlap中提取出个数
|
1713
|
+
df_ = results_df.sort_values(by="Count", ascending=False)
|
1714
|
+
ax = plot.plotxy(
|
1715
|
+
data=df_.head(n_top),
|
1716
|
+
kind="scatter",
|
1717
|
+
x="Count",
|
1718
|
+
y="Term",
|
1719
|
+
hue="Count",
|
1720
|
+
size="Count",
|
1721
|
+
sizes=[10,50],
|
1722
|
+
palette=palette,
|
1723
|
+
legend=None,
|
1724
|
+
ax=ax,
|
1725
|
+
**kwargs,
|
1726
|
+
)
|
1727
|
+
|
1728
|
+
plot.figsets(ax=ax, **kws_figsets)
|
1729
|
+
return ax, df_
|
1730
|
+
elif "net" in kind.lower():
|
1731
|
+
#! network plot
|
1732
|
+
from gseapy import enrichment_map
|
1733
|
+
import networkx as nx
|
1734
|
+
from matplotlib import cm
|
1735
|
+
# try:
|
1736
|
+
if cutoff>=1 or cutoff is None:
|
1737
|
+
print(f"cutoff is {cutoff} => Without applying filter")
|
1738
|
+
nodes, edges = enrichment_map(
|
1739
|
+
df=results_df,
|
1740
|
+
columns="NOM p-val",
|
1741
|
+
cutoff=1.1, # 0.25 when "FDR q-val"; 0.05 when "Nom p-value"
|
1742
|
+
top_term=n_top,
|
1743
|
+
)
|
1744
|
+
else:
|
1745
|
+
cutoff_curr = cutoff
|
1746
|
+
step = 0.05
|
1747
|
+
cutoff_stop = 1.0
|
1748
|
+
while cutoff_curr <= cutoff_stop:
|
1749
|
+
try:
|
1750
|
+
# return two dataframe
|
1751
|
+
nodes, edges = enrichment_map(
|
1752
|
+
df=results_df,
|
1753
|
+
columns="NOM p-val",
|
1754
|
+
cutoff=cutoff_curr, # 0.25 when "FDR q-val"; 0.05 when "Nom p-value"
|
1755
|
+
top_term=n_top,
|
1756
|
+
)
|
1757
|
+
|
1758
|
+
if nodes.shape[0] >= n_top:
|
1759
|
+
print(f"cutoff={cutoff_curr} done! ")
|
1760
|
+
break
|
1761
|
+
if cutoff_curr == cutoff_stop:
|
1762
|
+
break
|
1763
|
+
cutoff_curr += step
|
1764
|
+
except Exception as e:
|
1765
|
+
cutoff_curr += step
|
1766
|
+
print(
|
1767
|
+
f"{e}: trying cutoff={cutoff_curr}"
|
1768
|
+
)
|
1769
|
+
|
1770
|
+
print("size: by 'NES'") if size is None else print("")
|
1771
|
+
print("linewidth: by 'NES'") if linewidth is None else print("")
|
1772
|
+
print("linecolor: by 'NES'") if linecolor is None else print("")
|
1773
|
+
print("linealpha: by 'NES'") if linealpha is None else print("")
|
1774
|
+
print("facecolor: by 'NES'") if facecolor is None else print("")
|
1775
|
+
print("alpha: by '-log10(Adjusted P-value)'") if alpha is None else print("")
|
1776
|
+
edges.sort_values(by="jaccard_coef", ascending=False,inplace=True)
|
1777
|
+
colormap = cm.get_cmap(cmap) # Get the 'coolwarm' colormap
|
1778
|
+
G,ax=plot_ppi(
|
1779
|
+
interactions=edges,
|
1780
|
+
player1="src_name",
|
1781
|
+
player2="targ_name",
|
1782
|
+
weight="jaccard_coef",
|
1783
|
+
size=[
|
1784
|
+
node["NES"] * 300 for _, node in nodes.iterrows()
|
1785
|
+
] if size is None else size, # size nodes by NES
|
1786
|
+
facecolor=[colormap(node["NES"]) for _, node in nodes.iterrows()] if facecolor is None else facecolor, # Color by FDR q-val
|
1787
|
+
linewidth=[node["NES"] * 300 for _, node in nodes.iterrows()] if linewidth is None else linewidth,
|
1788
|
+
linecolor=[node["NES"] * 300 for _, node in nodes.iterrows()] if linecolor is None else linecolor,
|
1789
|
+
linealpha=[node["NES"] * 300 for _, node in nodes.iterrows()] if linealpha is None else linealpha,
|
1790
|
+
alpha=[node["NES"] * 300 for _, node in nodes.iterrows()] if alpha is None else alpha,
|
1791
|
+
**kwargs
|
1792
|
+
)
|
1793
|
+
# except Exception as e:
|
1794
|
+
# print(f"not work {n_top},{e}")
|
1795
|
+
return ax, G, nodes, edges
|
1796
|
+
|
1797
|
+
|
979
1798
|
#! https://string-db.org/help/api/
|
980
1799
|
|
981
1800
|
import pandas as pd
|
@@ -1104,16 +1923,22 @@ def plot_ppi(
|
|
1104
1923
|
n_rank=[5, 10], # Nodes in each rank for the concentric layout
|
1105
1924
|
dist_node = 10, # Distance between each rank of circles
|
1106
1925
|
layout="degree",
|
1107
|
-
size=
|
1926
|
+
size=None,#700,
|
1927
|
+
sizes=(50,500),# min and max of size
|
1108
1928
|
facecolor="skyblue",
|
1109
1929
|
cmap='coolwarm',
|
1110
1930
|
edgecolor="k",
|
1111
1931
|
edgelinewidth=1.5,
|
1112
1932
|
alpha=.5,
|
1933
|
+
alphas=(0.1, 1.0),# min and max of alpha
|
1113
1934
|
marker="o",
|
1114
1935
|
node_hideticks=True,
|
1115
1936
|
linecolor="gray",
|
1937
|
+
line_cmap='coolwarm',
|
1116
1938
|
linewidth=1.5,
|
1939
|
+
linewidths=(0.5,5),# min and max of linewidth
|
1940
|
+
linealpha=1.0,
|
1941
|
+
linealphas=(0.1,1.0),# min and max of linealpha
|
1117
1942
|
linestyle="-",
|
1118
1943
|
line_arrowstyle='-',
|
1119
1944
|
fontsize=10,
|
@@ -1142,7 +1967,7 @@ def plot_ppi(
|
|
1142
1967
|
for col in [player1, player2, weight]:
|
1143
1968
|
if col not in interactions.columns:
|
1144
1969
|
raise ValueError(f"Column '{col}' is missing from the interactions DataFrame.")
|
1145
|
-
|
1970
|
+
interactions.sort_values(by=[weight], inplace=True)
|
1146
1971
|
# Initialize Pyvis network
|
1147
1972
|
net = Network(height="750px", width="100%", bgcolor=bgcolor, font_color=fontcolor)
|
1148
1973
|
net.force_atlas_2based(
|
@@ -1161,34 +1986,71 @@ def plot_ppi(
|
|
1161
1986
|
G = nx.Graph()
|
1162
1987
|
for _, row in interactions.iterrows():
|
1163
1988
|
G.add_edge(row[player1], row[player2], weight=row[weight])
|
1989
|
+
# G = nx.from_pandas_edgelist(interactions, source=player1, target=player2, edge_attr=weight)
|
1990
|
+
|
1164
1991
|
|
1165
1992
|
# Calculate node degrees
|
1166
1993
|
degrees = dict(G.degree())
|
1167
1994
|
norm = Normalize(vmin=min(degrees.values()), vmax=max(degrees.values()))
|
1168
1995
|
colormap = cm.get_cmap(cmap) # Get the 'coolwarm' colormap
|
1169
1996
|
|
1997
|
+
if not ips.isa(facecolor, 'color'):
|
1998
|
+
print("facecolor: based on degrees")
|
1999
|
+
facecolor = [colormap(norm(deg)) for deg in degrees.values()] # Use colormap
|
2000
|
+
num_nodes = G.number_of_nodes()
|
2001
|
+
#* size
|
1170
2002
|
# Set properties based on degrees
|
1171
2003
|
if not isinstance(size, (int,float,list)):
|
2004
|
+
print("size: based on degrees")
|
1172
2005
|
size = [deg * 50 for deg in degrees.values()] # Scale sizes
|
1173
|
-
if
|
1174
|
-
|
1175
|
-
|
1176
|
-
|
1177
|
-
|
1178
|
-
|
1179
|
-
|
1180
|
-
|
1181
|
-
|
1182
|
-
|
1183
|
-
|
1184
|
-
|
2006
|
+
size = (size[:num_nodes] if len(size) > num_nodes else size) if isinstance(size, list) else [size] * num_nodes
|
2007
|
+
if isinstance(size, list) and len(ips.flatten(size,verbose=False))!=1:
|
2008
|
+
# Normalize sizes
|
2009
|
+
min_size, max_size = sizes # Use sizes tuple for min and max values
|
2010
|
+
min_degree, max_degree = min(size), max(size)
|
2011
|
+
if max_degree > min_degree: # Avoid division by zero
|
2012
|
+
size = [
|
2013
|
+
min_size + (max_size - min_size) * (sz - min_degree) / (max_degree - min_degree)
|
2014
|
+
for sz in size
|
2015
|
+
]
|
2016
|
+
else:
|
2017
|
+
# If all values are the same, set them to a default of the midpoint
|
2018
|
+
size = [(min_size + max_size) / 2] * len(size)
|
2019
|
+
|
2020
|
+
#* facecolor
|
2021
|
+
facecolor = (facecolor[:num_nodes] if len(facecolor) > num_nodes else facecolor) if isinstance(facecolor, list) else [facecolor] * num_nodes
|
2022
|
+
# * facealpha
|
2023
|
+
if isinstance(alpha, list):
|
2024
|
+
alpha = (alpha[:num_nodes] if len(alpha) > num_nodes else alpha + [alpha[-1]] * (num_nodes - len(alpha)))
|
2025
|
+
min_alphas, max_alphas = alphas # Use alphas tuple for min and max values
|
2026
|
+
if len(alpha) > 0:
|
2027
|
+
# Normalize alpha based on the specified min and max
|
2028
|
+
min_alpha, max_alpha = min(alpha), max(alpha)
|
2029
|
+
if max_alpha > min_alpha: # Avoid division by zero
|
2030
|
+
alpha = [
|
2031
|
+
min_alphas + (max_alphas - min_alphas) * (ea - min_alpha) / (max_alpha - min_alpha)
|
2032
|
+
for ea in alpha
|
2033
|
+
]
|
2034
|
+
else:
|
2035
|
+
# If all alpha values are the same, set them to the average of min and max
|
2036
|
+
alpha = [(min_alphas + max_alphas) / 2] * len(alpha)
|
2037
|
+
else:
|
2038
|
+
# Default to a full opacity if no edges are provided
|
2039
|
+
alpha = [1.0] * num_nodes
|
2040
|
+
else:
|
2041
|
+
# If alpha is a single value, convert it to a list and normalize it
|
2042
|
+
alpha = [alpha] * num_nodes # Adjust based on alphas
|
2043
|
+
|
2044
|
+
for i, node in enumerate(G.nodes()):
|
1185
2045
|
net.add_node(
|
1186
2046
|
node,
|
1187
2047
|
label=node,
|
1188
|
-
size=size[
|
1189
|
-
color=facecolor[
|
2048
|
+
size=size[i],
|
2049
|
+
color=facecolor[i],
|
2050
|
+
alpha=alpha[i],
|
1190
2051
|
font={"size": fontsize, "color": fontcolor},
|
1191
2052
|
)
|
2053
|
+
print(f'nodes number: {i+1}')
|
1192
2054
|
|
1193
2055
|
for edge in G.edges(data=True):
|
1194
2056
|
net.add_edge(
|
@@ -1198,6 +2060,7 @@ def plot_ppi(
|
|
1198
2060
|
color=edgecolor,
|
1199
2061
|
width=edgelinewidth * edge[2]["weight"],
|
1200
2062
|
)
|
2063
|
+
|
1201
2064
|
layouts = [
|
1202
2065
|
"spring",
|
1203
2066
|
"circular",
|
@@ -1209,7 +2072,8 @@ def plot_ppi(
|
|
1209
2072
|
"degree"
|
1210
2073
|
]
|
1211
2074
|
layout = ips.strcmp(layout, layouts)[0]
|
1212
|
-
print(layout)
|
2075
|
+
print(f"layout:{layout}, or select one in {layouts}")
|
2076
|
+
|
1213
2077
|
# Choose layout
|
1214
2078
|
if layout == "spring":
|
1215
2079
|
pos = nx.spring_layout(G, k=k_value)
|
@@ -1235,7 +2099,9 @@ def plot_ppi(
|
|
1235
2099
|
# Calculate node degrees and sort nodes by degree
|
1236
2100
|
degrees = dict(G.degree())
|
1237
2101
|
sorted_nodes = sorted(degrees.items(), key=lambda x: x[1], reverse=True)
|
1238
|
-
|
2102
|
+
norm = Normalize(vmin=min(degrees.values()), vmax=max(degrees.values()))
|
2103
|
+
colormap = cm.get_cmap(cmap)
|
2104
|
+
|
1239
2105
|
# Create positions for concentric circles based on n_layers and n_rank
|
1240
2106
|
pos = {}
|
1241
2107
|
n_layers=len(n_rank)+1 if n_layers is None else n_layers
|
@@ -1266,8 +2132,8 @@ def plot_ppi(
|
|
1266
2132
|
|
1267
2133
|
# If ax is None, use plt.gca()
|
1268
2134
|
if ax is None:
|
1269
|
-
fig, ax = plt.subplots(1,1,figsize=figsize)
|
1270
|
-
|
2135
|
+
fig, ax = plt.subplots(1,1,figsize=figsize)
|
2136
|
+
|
1271
2137
|
# Draw nodes, edges, and labels with customization options
|
1272
2138
|
nx.draw_networkx_nodes(
|
1273
2139
|
G,
|
@@ -1281,6 +2147,54 @@ def plot_ppi(
|
|
1281
2147
|
hide_ticks=node_hideticks,
|
1282
2148
|
node_shape=marker
|
1283
2149
|
)
|
2150
|
+
|
2151
|
+
#* linewidth
|
2152
|
+
if not isinstance(linewidth, list):
|
2153
|
+
linewidth = [linewidth] * G.number_of_edges()
|
2154
|
+
else:
|
2155
|
+
linewidth = (linewidth[:G.number_of_edges()] if len(linewidth) > G.number_of_edges() else linewidth + [linewidth[-1]] * (G.number_of_edges() - len(linewidth)))
|
2156
|
+
# Normalize linewidth if it is a list
|
2157
|
+
if isinstance(linewidth, list):
|
2158
|
+
min_linewidth, max_linewidth = min(linewidth), max(linewidth)
|
2159
|
+
vmin, vmax = linewidths # Use linewidths tuple for min and max values
|
2160
|
+
if max_linewidth > min_linewidth: # Avoid division by zero
|
2161
|
+
# Scale between vmin and vmax
|
2162
|
+
linewidth = [
|
2163
|
+
vmin + (vmax - vmin) * (lw - min_linewidth) / (max_linewidth - min_linewidth)
|
2164
|
+
for lw in linewidth
|
2165
|
+
]
|
2166
|
+
else:
|
2167
|
+
# If all values are the same, set them to a default of the midpoint
|
2168
|
+
linewidth = [(vmin + vmax) / 2] * len(linewidth)
|
2169
|
+
else:
|
2170
|
+
# If linewidth is a single value, convert it to a list of that value
|
2171
|
+
linewidth = [linewidth] * G.number_of_edges()
|
2172
|
+
#* linecolor
|
2173
|
+
if not isinstance(linecolor, str):
|
2174
|
+
weights = [G[u][v]["weight"] for u, v in G.edges()]
|
2175
|
+
norm = Normalize(vmin=min(weights), vmax=max(weights))
|
2176
|
+
colormap = cm.get_cmap(line_cmap)
|
2177
|
+
linecolor = [colormap(norm(weight)) for weight in weights]
|
2178
|
+
else:
|
2179
|
+
linecolor = [linecolor] * G.number_of_edges()
|
2180
|
+
|
2181
|
+
# * linealpha
|
2182
|
+
if isinstance(linealpha, list):
|
2183
|
+
linealpha = (linealpha[:G.number_of_edges()] if len(linealpha) > G.number_of_edges() else linealpha + [linealpha[-1]] * (G.number_of_edges() - len(linealpha)))
|
2184
|
+
min_alpha, max_alpha = linealphas # Use linealphas tuple for min and max values
|
2185
|
+
if len(linealpha) > 0:
|
2186
|
+
min_linealpha, max_linealpha = min(linealpha), max(linealpha)
|
2187
|
+
if max_linealpha > min_linealpha: # Avoid division by zero
|
2188
|
+
linealpha = [
|
2189
|
+
min_alpha + (max_alpha - min_alpha) * (ea - min_linealpha) / (max_linealpha - min_linealpha)
|
2190
|
+
for ea in linealpha
|
2191
|
+
]
|
2192
|
+
else:
|
2193
|
+
linealpha = [(min_alpha + max_alpha) / 2] * len(linealpha)
|
2194
|
+
else:
|
2195
|
+
linealpha = [1.0] * G.number_of_edges() # 如果设置有误,则将它设置成1.0
|
2196
|
+
else:
|
2197
|
+
linealpha = [linealpha] * G.number_of_edges() # Convert to list if single value
|
1284
2198
|
nx.draw_networkx_edges(
|
1285
2199
|
G,
|
1286
2200
|
pos,
|
@@ -1289,14 +2203,21 @@ def plot_ppi(
|
|
1289
2203
|
width=linewidth,
|
1290
2204
|
style=linestyle,
|
1291
2205
|
arrowstyle=line_arrowstyle,
|
1292
|
-
alpha=
|
2206
|
+
alpha=linealpha
|
1293
2207
|
)
|
2208
|
+
|
1294
2209
|
nx.draw_networkx_labels(
|
1295
2210
|
G, pos, ax=ax, font_size=fontsize, font_color=fontcolor,horizontalalignment=ha,verticalalignment=va
|
1296
2211
|
)
|
1297
2212
|
plot.figsets(ax=ax,**kws_figsets)
|
1298
2213
|
ax.axis("off")
|
1299
|
-
|
2214
|
+
if dir_save:
|
2215
|
+
if not os.path.basename(dir_save):
|
2216
|
+
dir_save="_.html"
|
2217
|
+
net.write_html(dir_save)
|
2218
|
+
nx.write_graphml(G, dir_save.replace(".html",".graphml")) # Export to GraphML
|
2219
|
+
print(f"could be edited in Cytoscape \n{dir_save.replace(".html",".graphml")}")
|
2220
|
+
ips.figsave(dir_save.replace(".html",".pdf"))
|
1300
2221
|
return G,ax
|
1301
2222
|
|
1302
2223
|
|