py2ls 0.2.4.4__py3-none-any.whl → 0.2.4.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
py2ls/bio.py CHANGED
@@ -324,7 +324,7 @@ def find_condition(data:pd.DataFrame, columns=["characteristics_ch1","title"]):
324
324
  # 详细看看每个信息的有哪些类, 其中有数字的, 要去除
325
325
  for col in columns:
326
326
  print(f"{"="*10} {col} {"="*10}")
327
- display(ips.flatten([ips.ssplit(i, by="numer")[0] for i in data[col]]))
327
+ display(ips.flatten([ips.ssplit(i, by="numer")[0] for i in data[col]],verbose=False))
328
328
 
329
329
  def add_condition(
330
330
  data: pd.DataFrame,
@@ -581,7 +581,7 @@ def batch_effect(
581
581
  return df_corrected
582
582
 
583
583
def get_common_genes(elment1, elment2):
    """Return the genes found in both ``elment1`` and ``elment2``.

    Thin wrapper around ``ips.shared``, called with ``verbose=False``.
    """
    shared = ips.shared(elment1, elment2, verbose=False)
    return shared
586
586
 
587
587
  def counts2expression(
@@ -667,7 +667,7 @@ def counts2expression(
667
667
 
668
668
  length.index=length.index.astype(str).str.strip()
669
669
  counts.columns = counts.columns.astype(str).str.strip()
670
- shared_genes=ips.shared(length.index, counts.columns)
670
+ shared_genes=ips.shared(length.index, counts.columns,verbose=False)
671
671
  length=length.loc[shared_genes]
672
672
  counts=counts.loc[:,shared_genes]
673
673
  columns_org = counts.columns.tolist()
@@ -814,7 +814,11 @@ def counts_deseq(counts_sam_gene: pd.DataFrame,
814
814
  # .reset_index()
815
815
  # .rename(columns={"index": "gene"})
816
816
  # )
817
- return dds, diff,stat_res
817
+ df_norm=pd.DataFrame(dds.layers['normed_counts'])
818
+ df_norm.index=counts_sam_gene.index
819
+ df_norm.columns=counts_sam_gene.columns
820
+ print("res[0]: dds\nres[1]:diff\nres[2]:stat_res\nres[3]:df_normalized")
821
+ return dds, diff, stat_res,df_norm
818
822
 
819
823
  def scope_genes(gene_list: list, scopes:str=None, fields: str = "symbol", species="human"):
820
824
  """
@@ -842,6 +846,7 @@ def scope_genes(gene_list: list, scopes:str=None, fields: str = "symbol", specie
842
846
 
843
847
  def get_enrichr(gene_symbol_list,
844
848
  gene_sets:str,
849
+ download:bool = False,
845
850
  species='Human',
846
851
  dir_save="./",
847
852
  plot_=False,
@@ -854,6 +859,7 @@ def get_enrichr(gene_symbol_list,
854
859
  title=None,# 'KEGG'
855
860
  cutoff=0.05,
856
861
  cmap="coolwarm",
862
+ size=5,
857
863
  **kwargs):
858
864
  """
859
865
  Note: Enrichr uses a list of Entrez gene symbols as input.
@@ -878,16 +884,22 @@ def get_enrichr(gene_symbol_list,
878
884
  lib_support_names = gp.get_library_name()
879
885
  # correct input gene_set name
880
886
  gene_sets_name=ips.strcmp(gene_sets,lib_support_names)[0]
887
+
881
888
  # download it
882
- gene_sets = gp.get_library(name=gene_sets_name, organism=species)
883
- print(f"gene_sets get ready: {gene_sets_name}")
889
+ if download:
890
+ gene_sets = gp.get_library(name=gene_sets_name, organism=species)
891
+ else:
892
+ gene_sets = gene_sets_name # 避免重复下载
893
+ print(f"\ngene_sets get ready: {gene_sets_name}")
884
894
 
885
895
  # gene symbols are uppercase
886
896
  gene_symbol_list=[str(i).upper() for i in gene_symbol_list]
887
897
 
888
898
  # # check how shared genes
889
- if check_shared:
890
- shared_genes=ips.shared(ips.flatten(gene_symbol_list,verbose=False), ips.flatten(gene_sets,verbose=False))
899
+ if check_shared and isinstance(gene_sets, dict):
900
+ shared_genes=ips.shared(ips.flatten(gene_symbol_list,verbose=False),
901
+ ips.flatten(gene_sets,verbose=False),
902
+ verbose=False)
891
903
 
892
904
  #! enrichr
893
905
  try:
@@ -903,13 +915,13 @@ def get_enrichr(gene_symbol_list,
903
915
  return None
904
916
 
905
917
  results_df = enr.results
906
- print(f"got enrichr reslutls; shape: {results_df.shape}")
918
+ print(f"got enrichr reslutls; shape: {results_df.shape}\n")
907
919
  results_df["-log10(Adjusted P-value)"] = -np.log10(results_df["Adjusted P-value"])
908
920
  results_df.sort_values("-log10(Adjusted P-value)", inplace=True, ascending=False)
909
921
 
910
922
  if plot_:
911
923
  if palette is None:
912
- palette=plot.get_color(n_top, cmap="coolwarm")[::-1]
924
+ palette=plot.get_color(n_top, cmap=cmap)[::-1]
913
925
  #! barplot
914
926
  if n_top<5:
915
927
  height_=4
@@ -921,11 +933,12 @@ def get_enrichr(gene_symbol_list,
921
933
  height_=7
922
934
  elif 15<=n_top<20:
923
935
  height_=8
924
- elif 25<=n_top<30:
936
+ elif 20<=n_top<30:
925
937
  height_=9
926
938
  else:
927
939
  height_=int(n_top/3)
928
- plt.figure(figsize=[5, height_])
940
+ plt.figure(figsize=[10, height_])
941
+
929
942
  ax1=plot.plotxy(
930
943
  data=results_df.head(n_top),
931
944
  kind="barplot",
@@ -935,18 +948,17 @@ def get_enrichr(gene_symbol_list,
935
948
  palette=palette,
936
949
  legend=None,
937
950
  )
951
+ plot.figsets(ax=ax1, **kws_figsets)
938
952
  if dir_save:
939
953
  ips.figsave(f"{dir_save} enr_barplot.pdf")
940
- plot.figsets(ax=ax1, **kws_figsets)
941
954
  plt.show()
942
955
 
943
956
  #! dotplot
944
957
  cutoff_curr = cutoff
945
958
  step=0.05
946
959
  cutoff_stop = 0.5
947
- while cutoff_curr <=cutoff_stop:
960
+ while cutoff_curr <= cutoff_stop:
948
961
  try:
949
- print(kws_figsets)
950
962
  if cutoff_curr!=cutoff:
951
963
  plt.clf()
952
964
  ax2 = gp.dotplot(enr.res2d,
@@ -957,7 +969,8 @@ def get_enrichr(gene_symbol_list,
957
969
  cmap=cmap,
958
970
  cutoff=cutoff_curr,
959
971
  top_term=n_top,
960
- figsize=[6, height_])
972
+ size=size,
973
+ figsize=[10, height_])
961
974
  if len(ax2.collections)>=n_top:
962
975
  print(f"cutoff={cutoff_curr} done! ")
963
976
  break
@@ -975,7 +988,813 @@ def get_enrichr(gene_symbol_list,
975
988
 
976
989
  return results_df
977
990
 
991
def plot_enrichr(results_df,
                 kind="bar",  # 'barplot', 'dotplot'
                 cutoff=0.05,
                 show_ring=False,
                 xticklabels_rot=0,
                 title=None,  # 'KEGG'
                 cmap="coolwarm",
                 n_top=10,
                 size=5,
                 ax=None,
                 **kwargs):
    """Plot an Enrichr result table as a bar plot, dot plot or gene-count bar plot.

    Parameters
    ----------
    results_df : pandas.DataFrame
        Enrichr results (e.g. from ``get_enrichr``). The bar plot reads the
        "Term" and "-log10(Adjusted P-value)" columns, the dot plot reads
        "Adjusted P-value", the count plot parses "Overlap".
    kind : str
        Matched by substring (case-insensitive): "bar" -> bar plot of
        -log10(adjusted p); "dot" -> ``gseapy.dotplot``; "count" -> bar plot
        of the number of overlapping genes.
    cutoff : float
        Starting cutoff for the dot plot; it is raised in steps of 0.05 (up
        to 0.5) until at least ``n_top`` terms are drawn.
    show_ring, xticklabels_rot, title, size :
        Forwarded to ``gseapy.dotplot``.
    cmap : str or list
        Colormap name (also used to build the bar palette) or a list of colours.
    n_top : int
        Number of terms to show.
    ax : matplotlib Axes, optional
        Target axes for the bar plots; a new figure (10 inches wide) is
        created when None. The dot plot draws its own figure.
    **kwargs :
        A keyword whose name contains "figset" supplies a dict for ``plot.figsets``.

    Returns
    -------
    tuple
        ``(ax, DataFrame)`` - the axes drawn on and the data that was plotted.

    Raises
    ------
    ValueError
        If ``kind`` matches none of the supported plot kinds.
    """
    # Pull out the "figset*" keyword (a dict of options for plot.figsets).
    kws_figsets = {}
    for k_arg in list(kwargs):
        if "figset" in k_arg:
            kws_figsets = kwargs.pop(k_arg)
            break

    if isinstance(cmap, str):
        palette = plot.get_color(n_top, cmap=cmap)[::-1]
    else:
        # An explicit sequence of colours is used as-is.
        palette = cmap

    if n_top < 15:
        height_ = 3
    elif n_top < 20:
        height_ = 4
    elif n_top < 30:
        height_ = 5
    elif n_top < 40:
        height_ = int(n_top / 6)
    else:
        height_ = int(n_top / 8)

    kind_l = kind.lower()

    #! barplot
    if "bar" in kind_l:
        if ax is None:
            _, ax = plt.subplots(1, 1, figsize=[10, height_])
        ax = plot.plotxy(
            data=results_df.head(n_top),
            kind="barplot",
            x="-log10(Adjusted P-value)",
            y="Term",
            hue="Term",
            palette=palette,
            legend=None,
            ax=ax,  # previously dropped, so a caller-supplied ax was ignored
        )
        plot.figsets(ax=ax, **kws_figsets)
        return ax, results_df

    #! dotplot
    if "dot" in kind_l:
        cutoff_curr = cutoff
        step = 0.05
        cutoff_stop = 0.5
        while cutoff_curr <= cutoff_stop:
            try:
                if cutoff_curr != cutoff:
                    plt.clf()
                ax = gp.dotplot(results_df,
                                column="Adjusted P-value",
                                show_ring=show_ring,
                                xticklabels_rot=xticklabels_rot,
                                title=title,
                                cmap=cmap,
                                cutoff=cutoff_curr,
                                top_term=n_top,
                                size=size,
                                figsize=[10, height_])
                if len(ax.collections) >= n_top:
                    print(f"cutoff={cutoff_curr} done! ")
                    break
                if cutoff_curr == cutoff_stop:
                    break
                # round() keeps repeated += 0.05 from drifting past cutoff_stop
                cutoff_curr = round(cutoff_curr + step, 10)
            except Exception as e:
                failed = cutoff_curr
                cutoff_curr = round(cutoff_curr + step, 10)
                print(f"Warning: trying cutoff={cutoff_curr}, cutoff={failed} failed: {e} ")
        plot.figsets(ax=ax, **kws_figsets)
        return ax, results_df

    #! barplot with counts
    if "count" in kind_l:
        if ax is None:
            _, ax = plt.subplots(1, 1, figsize=[10, height_])
        # "Overlap" is expected to look like "k/n"; k is taken as the gene count.
        # assign() works on a copy so the caller's DataFrame is left untouched.
        df_ = results_df.assign(
            Count=results_df["Overlap"].apply(
                lambda x: int(x.split("/")[0]) if isinstance(x, str) else x)
        ).sort_values(by="Count", ascending=False)
        ax = plot.plotxy(
            data=df_.head(n_top),
            kind="barplot",
            x="Count",
            y="Term",
            hue="Term",
            palette=palette,
            legend=None,
            ax=ax,
        )
        plot.figsets(ax=ax, **kws_figsets)
        return ax, df_

    raise ValueError(
        f"kind={kind!r} is not supported; use one of 'bar', 'dot' or 'count'")
1097
+
1098
def plot_bp_cc_mf(
    deg_gene_list,
    gene_sets=(
        "GO_Biological_Process_2023",
        "GO_Cellular_Component_2023",
        "GO_Molecular_Function_2023",
    ),
    species="human",
    download=False,
    n_top=10,
    plot_=True,
    ax=None,
    palette=None,
    **kwargs,
):
    """Run Enrichr for the BP/CC/MF libraries and plot the top terms by gene count.

    Parameters
    ----------
    deg_gene_list : list
        Gene symbols passed to ``get_enrichr``.
    gene_sets : sequence of str
        Three library names, paired in order with the labels "BP", "CC", "MF".
    species, download :
        Forwarded to ``get_enrichr``.
    n_top : int
        Number of terms (largest "Overlap" count) kept per ontology.
    plot_ : bool
        Draw the combined bar plot (coloured by ontology) when True.
    ax : matplotlib Axes, optional
        Target axes; a new figure is created when None and ``plot_`` is True.
    palette : list, optional
        Colours for the three ontologies; defaults to ``plot.get_color(3)``
        (computed at call time rather than at import time).
    **kwargs :
        Forwarded to ``plot.plotxy``.

    Returns
    -------
    tuple
        ``(ax, df2plot)``. ``ax`` is None when ``plot_`` is False and no ax was given.

    Raises
    ------
    ValueError
        If ``get_enrichr`` returned None for every library.
    """
    if palette is None:
        palette = plot.get_color(3)

    def _top_by_count(res_enrichr, label):
        """Add the gene count parsed from "Overlap", keep the n_top largest, tag the ontology."""
        df = res_enrichr.copy()
        df["Count"] = df["Overlap"].apply(
            lambda x: int(x.split("/")[0]) if isinstance(x, str) else x
        )
        df = df.sort_values(by="Count", ascending=False).head(n_top).copy()
        # Scalar assignment broadcasts, so fewer than n_top rows no longer raises.
        df["Ontology"] = label
        return df

    frames = []
    for label, gene_set in zip(("BP", "CC", "MF"), gene_sets):
        res = get_enrichr(
            deg_gene_list, gene_set, species=species, plot_=False, download=download
        )
        if res is None:  # get_enrichr returns None when the enrichr call fails
            print(f"Warning: enrichr failed for {gene_set}; {label} is skipped")
            continue
        frames.append(_top_by_count(res, label))
    if not frames:
        raise ValueError("get_enrichr returned no results for any of the gene sets")

    # 合并 (merge)
    df2plot = pd.concat(frames)

    display(df2plot)

    def _strip_suffix(term):
        """Drop a trailing "(...)" suffix (e.g. a library id) from one term name."""
        if isinstance(term, str) and term.endswith(")") and "(" in term:
            return term[: term.rfind("(")].rstrip()
        return term

    df2plot["Term"] = df2plot["Term"].apply(_strip_suffix)

    # 作图 (plot)
    if plot_:
        n_rows = len(df2plot)
        if n_rows < 5:
            height_ = 4
        elif n_rows < 10:
            height_ = 5
        elif n_rows < 15:
            height_ = 6
        elif n_rows < 20:
            height_ = 7
        elif n_rows < 30:
            height_ = 8
        elif n_rows < 40:
            height_ = int(n_rows / 4)
        else:
            height_ = int(n_rows / 5)
        if ax is None:
            _, ax = plt.subplots(1, 1, figsize=[10, height_])
        ax = plot.plotxy(
            data=df2plot,
            x="Count",
            y="Term",
            hue="Ontology",
            kind="bar",
            palette=palette,
            ax=ax,
            **kwargs,
        )
    return ax, df2plot
1177
+
1178
def get_library_name(by=None, verbose=False):
    """Return the Enrichr library names known to gseapy, optionally ranked by a query.

    Parameters
    ----------
    by : str, optional
        When given, the names are compared to ``by`` with
        ``ips.strcmp(..., get_rank=True)`` and the flattened result is returned.
    verbose : bool
        When ``by`` is None, print every library name; otherwise forwarded to
        ``ips.strcmp`` / ``ips.flatten``.

    Returns
    -------
    list
        All library names (``by`` is None) or the ranked matches.
    """
    lib_names = gp.get_library_name()
    if by is not None:
        return ips.flatten(
            ips.strcmp(by, lib_names, get_rank=True, verbose=verbose),
            verbose=verbose,
        )
    if verbose:
        # A plain loop: printing is a side effect, not something to collect.
        for name in lib_names:
            print(name)
    return lib_names
1186
+
1187
+
1188
def get_gsva(
    data_gene_samples: pd.DataFrame,  # index(gene),columns(samples)
    gene_sets: str,
    species: str = "Human",
    dir_save: str = "./",
    plot_: bool = False,
    n_top: int = 30,
    check_shared: bool = True,
    cmap="coolwarm",
    min_size=1,
    max_size=1000,
    kcdf="Gaussian",  # 'Gaussian' for continuous data
    method="gsva",
    seed=1,
    **kwargs,
):
    """Score gene sets per sample with ``gseapy.gsva`` and optionally plot the extremes.

    Parameters
    ----------
    data_gene_samples : pandas.DataFrame
        Expression matrix with genes as rows and samples as columns. Gene
        names are upper-cased on an internal copy; the caller's frame is not modified.
    gene_sets : str or dict
        Path to a gene-set file (loaded with ``ips.fload``), an Enrichr
        library name (fuzzy-matched and downloaded with ``gseapy.get_library``),
        or an already-loaded ``{term: genes}`` mapping.
    species : str
        Fuzzy-matched to one of Human, Mouse, Yeast, Fly, Fish, Worm.
    dir_save : str
        Prefix for the saved PDF name; falsy disables saving.
    plot_ : bool
        Draw a bar plot of the ``n_top`` highest and lowest ES terms.
    n_top : int
        Terms per side in the plot.
    check_shared : bool
        Call ``ips.shared`` on the input genes and the gene-set genes.
    cmap : str or list
        Colormap name, a two-colour list (top rows / bottom rows) or a list of colours.
    min_size, max_size, kcdf, method, seed :
        Forwarded to ``gseapy.gsva``.
    **kwargs :
        A keyword whose name contains "figset" supplies a dict for ``plot.figsets``.

    Returns
    -------
    pandas.DataFrame
        ``res2d`` de-duplicated by "Term" (keeping the largest ``|ES|``),
        sorted by ES descending, index reset.
    """
    kws_figsets = {}
    for k_arg in list(kwargs):
        if "figset" in k_arg:
            kws_figsets = kwargs.pop(k_arg)
            break

    species_org = species
    # organism (str) - Select one from {'Human', 'Mouse', 'Yeast', 'Fly', 'Fish', 'Worm'}
    organisms = ["Human", "Mouse", "Yeast", "Fly", "Fish", "Worm"]
    species = ips.strcmp(species, organisms)[0]
    if species_org.lower() != species.lower():
        print(f"species was corrected to {species}, becasue only support {organisms}")

    if isinstance(gene_sets, dict):
        # Already-loaded mapping (the same type gp.get_library returns).
        gene_sets_name = "custom"
    elif os.path.isfile(gene_sets):
        gene_sets_name = os.path.basename(gene_sets)
        gene_sets = ips.fload(gene_sets)
    else:
        lib_support_names = gp.get_library_name()
        # correct input gene_set name
        gene_sets_name = ips.strcmp(gene_sets, lib_support_names)[0]
        # download it
        gene_sets = gp.get_library(name=gene_sets_name, organism=species)
        print(f"gene_sets get ready: {gene_sets_name}")

    # gene symbols are uppercase; set_axis returns a new frame, so the
    # caller's index is not rewritten as a side effect.
    gene_symbol_list = [str(i).upper() for i in data_gene_samples.index]
    data_gene_samples = data_gene_samples.set_axis(gene_symbol_list, axis=0)

    # check how many genes are shared
    if check_shared:
        ips.shared(
            ips.flatten(gene_symbol_list, verbose=False),
            ips.flatten(gene_sets, verbose=False),
            verbose=False,
        )
    gsva_results = gp.gsva(
        data=data_gene_samples,  # matrix should have genes as rows and samples as columns
        gene_sets=gene_sets,
        outdir=None,
        kcdf=kcdf,  # 'Gaussian' for continuous data
        min_size=min_size,
        method=method,
        max_size=max_size,
        verbose=True,
        seed=seed,
    )
    gsva_res = gsva_results.res2d.copy()
    # Keep one row per term: the one with the largest absolute ES.
    gsva_res["ES_abs"] = gsva_res["ES"].apply(np.abs)
    gsva_res = (
        gsva_res.sort_values(by="ES_abs", ascending=False)
        .drop_duplicates(subset="Term")
        .drop(columns="ES_abs")
        .reset_index(drop=True)
    )
    gsva_res = gsva_res.sort_values(by="ES", ascending=False)

    if plot_:
        if gsva_res.shape[0] >= 2 * n_top:
            gsva_res_plot = pd.concat([gsva_res.head(n_top), gsva_res.tail(n_top)])
        else:
            gsva_res_plot = gsva_res

        # Size the palette to the rows actually drawn (fewer than 2*n_top
        # when the result table is short).
        n_plot = gsva_res_plot.shape[0]
        n_head = min(n_top, n_plot)
        if isinstance(cmap, str):
            palette = plot.get_color(n_plot, cmap=cmap)[::-1]
        elif isinstance(cmap, list) and len(cmap) == 2:
            palette = [cmap[0]] * n_head + [cmap[1]] * (n_plot - n_head)
        else:
            palette = cmap

        # ! barplot
        if n_top < 5:
            height_ = 3
        elif n_top < 10:
            height_ = 4
        elif n_top < 15:
            height_ = 5
        elif n_top < 20:
            height_ = 6
        elif n_top < 30:
            height_ = 7
        elif n_top < 40:
            height_ = int(n_top / 3.5)
        else:
            height_ = int(n_top / 3)
        plt.figure(figsize=[10, height_])
        ax2 = plot.plotxy(
            data=gsva_res_plot,
            x="ES",
            y="Term",
            hue="Term",
            palette=palette,
            kind=["bar"],
            figsets=dict(yticklabel=[], ticksloc="b", boxloc="b", ylabel=None),
        )
        # Tick labels are hidden (yticklabel=[]); write each term next to x=0,
        # on the side opposite its bar. Assumes patches are drawn in row order
        # (one bar per unique Term); zip stops at the shorter of the two.
        rows = gsva_res_plot[["Term", "ES"]].itertuples(index=False, name=None)
        for bar, (term, es_value) in zip(ax2.patches, rows):
            positive = es_value > 0
            ax2.annotate(
                term,
                xy=(0, bar.get_y() + bar.get_height() / 2),
                xytext=(-5, 0) if positive else (5, 0),
                textcoords="offset points",
                ha="right" if positive else "left",
                va="center",
                fontsize=10,
                color="black",
            )
        plot.figsets(ax=ax2, **kws_figsets)
        if dir_save:
            ips.figsave(dir_save + f"GSVA_{gene_sets_name}.pdf")
        plt.show()
    return gsva_res.reset_index(drop=True)
1330
+
1331
def plot_gsva(
    gsva_res,  # output from bio.get_gsva()
    n_top=10,
    ax=None,
    x="ES",
    y="Term",
    hue="Term",
    cmap="coolwarm",
    **kwargs,
):
    """Draw a diverging bar plot of GSVA scores with term labels beside the zero line.

    Parameters
    ----------
    gsva_res : pandas.DataFrame
        Table as returned by ``get_gsva``; must contain the ``x`` and ``y`` columns.
    n_top : int
        With at least ``2 * n_top`` rows, only the ``n_top`` highest and
        ``n_top`` lowest rows (by ``x``) are drawn; otherwise all rows.
    ax : matplotlib Axes, optional
        Target axes; a new figure (10 inches wide) is created when None.
    x, y, hue : str
        Column for bar length, column for the label text, column for colour grouping.
    cmap : str or list
        Colormap name, a two-colour list (top rows / bottom rows) or a list of colours.
    **kwargs :
        A keyword whose name contains "figset" supplies a dict for ``plot.figsets``.

    Returns
    -------
    matplotlib Axes
    """
    kws_figsets = {}
    for k_arg in list(kwargs):
        if "figset" in k_arg:
            kws_figsets = kwargs.pop(k_arg)
            break

    # ! barplot
    if n_top < 5:
        height_ = 4
    elif n_top < 10:
        height_ = 5
    elif n_top < 15:
        height_ = 6
    elif n_top < 20:
        height_ = 7
    elif n_top < 30:
        height_ = 8
    elif n_top < 40:
        height_ = int(n_top / 3.5)
    else:
        height_ = int(n_top / 3)
    if ax is None:
        _, ax = plt.subplots(1, 1, figsize=[10, height_])

    gsva_res = gsva_res.sort_values(by=x, ascending=False)
    if gsva_res.shape[0] >= 2 * n_top:
        gsva_res_plot = pd.concat([gsva_res.head(n_top), gsva_res.tail(n_top)])
    else:
        gsva_res_plot = gsva_res

    # Size the palette to the rows actually drawn.
    n_plot = gsva_res_plot.shape[0]
    n_head = min(n_top, n_plot)
    if isinstance(cmap, str):
        palette = plot.get_color(n_plot, cmap=cmap)[::-1]
    elif isinstance(cmap, list) and len(cmap) == 2:
        palette = [cmap[0]] * n_head + [cmap[1]] * (n_plot - n_head)
    else:
        palette = cmap

    ax = plot.plotxy(
        ax=ax,
        data=gsva_res_plot,
        x=x,
        y=y,
        hue=hue,
        palette=palette,
        kind=["bar"],
        figsets=dict(yticklabel=[], ticksloc="b", boxloc="b", ylabel=None),
    )
    # Tick labels are hidden (yticklabel=[]); write each label next to x=0 on
    # the side opposite its bar. Uses the x/y columns the caller chose (the
    # old code hard-coded "Term"/"ES"). Assumes patches follow row order.
    rows = gsva_res_plot[[y, x]].itertuples(index=False, name=None)
    for bar, (term, score) in zip(ax.patches, rows):
        positive = score > 0
        ax.annotate(
            term,
            xy=(0, bar.get_y() + bar.get_height() / 2),
            xytext=(-5, 0) if positive else (5, 0),
            textcoords="offset points",
            ha="right" if positive else "left",
            va="center",
            fontsize=10,
            color="black",
        )
    plot.figsets(ax=ax, **kws_figsets)
    return ax
1418
+
1419
def get_prerank(
    rnk: pd.DataFrame,
    gene_sets: str,
    download: bool = False,
    species="Human",
    threads=8,  # Number of CPU cores to use
    permutation_num=1000,  # Number of permutations for significance
    min_size=1,  # Minimum gene set size
    max_size=2000,  # Maximum gene set size
    seed=1,  # Seed for reproducibility
    verbose=True,  # Verbosity
    dir_save="./",
    plot_=False,
    size=5,
    cutoff=0.25,
    show_ring=False,
    cmap="coolwarm",
    check_shared=True,
    **kwargs,
):
    """Run pre-ranked GSEA with ``gseapy.prerank`` and optionally save summary plots.

    Unlike ``get_enrichr`` (a plain list of gene symbols), ``rnk`` is passed
    unchanged to ``gseapy.prerank`` - presumably a ranked gene table in a
    layout gseapy accepts (TODO: confirm against gseapy docs / callers).

    Parameters
    ----------
    rnk : pandas.DataFrame
        Ranked genes, forwarded as ``rnk=``.
    gene_sets : str
        Path to a gene-set file (loaded with ``ips.fload``) or an Enrichr
        library name, fuzzy-matched against ``gseapy.get_library_name()``.
    download : bool
        For a library name: download it with ``gseapy.get_library`` (True)
        or pass the matched name to gseapy directly (False).
    species : str
        Fuzzy-matched to one of Human, Mouse, Yeast, Fly, Fish, Worm.
    threads, permutation_num, min_size, max_size, seed, verbose :
        Forwarded to ``gseapy.prerank``.
    dir_save : str
        Prefix prepended to the saved PDF file names.
    plot_ : bool
        Save a GSEA plot of the first 7 terms, a dot plot, and an
        enrichment-map network (first ``top_term`` in 5..49 that draws).
    size, cutoff, show_ring, cmap :
        Forwarded to ``gseapy.dotplot``.
    check_shared : bool
        Currently unused; kept for signature compatibility.
    **kwargs :
        A "figset*" keyword is removed; nothing else is used.

    Returns
    -------
    pandas.DataFrame or None
        ``pre_res.res2d``, or None when ``gseapy.prerank`` raised ValueError.
    """
    # Drop a "figset*" keyword for parity with the plotting helpers; unused here.
    for k_arg in list(kwargs):
        if "figset" in k_arg:
            kwargs.pop(k_arg)
            break

    species_org = species
    # organism (str) - Select one from {'Human', 'Mouse', 'Yeast', 'Fly', 'Fish', 'Worm'}
    organisms = ["Human", "Mouse", "Yeast", "Fly", "Fish", "Worm"]
    species = ips.strcmp(species, organisms)[0]
    if species_org.lower() != species.lower():
        print(f"species was corrected to {species}, becasue only support {organisms}")

    if os.path.isfile(gene_sets):
        gene_sets_name = os.path.basename(gene_sets)
        gene_sets = ips.fload(gene_sets)
    else:
        lib_support_names = gp.get_library_name()
        # correct input gene_set name
        gene_sets_name = ips.strcmp(gene_sets, lib_support_names)[0]
        if download:
            gene_sets = gp.get_library(name=gene_sets_name, organism=species)
        else:
            gene_sets = gene_sets_name  # let gseapy resolve the name; avoids re-downloading
    print(f"\ngene_sets get ready: {gene_sets_name}")

    #! prerank
    try:
        pre_res = gp.prerank(
            rnk=rnk,
            gene_sets=gene_sets,
            threads=threads,
            permutation_num=permutation_num,
            min_size=min_size,
            max_size=max_size,
            seed=seed,
            verbose=verbose,
        )
    except ValueError as e:
        print(f"\n{'!'*10} Error {'!'*10}\n{' '*4}{e}\n{'!'*10} Error {'!'*10}")
        return None
    df_prerank = pre_res.res2d
    if not plot_:
        return df_prerank

    # File names and titles use gene_sets_name: gene_sets itself may now be
    # a loaded dict, which must not be interpolated into a path.

    #! gseaplot of the first 7 terms
    terms = df_prerank.Term
    pre_res.plot(terms=terms[:7], figsize=(3, 4))
    ips.figsave(dir_save + f"prerank_gseaplot_{gene_sets_name}.pdf")

    #! dotplot
    from gseapy import dotplot

    dotplot(
        df_prerank,
        column="NOM p-val",
        cmap=cmap,
        size=size,
        figsize=(10, 5),
        cutoff=cutoff,
        show_ring=show_ring,
    )
    ips.figsave(dir_save + f"prerank_dotplot_{gene_sets_name}.pdf")

    #! network plot
    from gseapy import enrichment_map
    import networkx as nx

    for top_term in range(5, 50):
        fig = None
        try:
            nodes, edges = enrichment_map(
                df=df_prerank,
                columns="FDR q-val",
                cutoff=0.25,  # 0.25 when "FDR q-val"; 0.05 when "Nom p-value"
                top_term=top_term,
            )
            G = nx.from_pandas_edgelist(
                edges,
                source="src_idx",
                target="targ_idx",
                edge_attr=["jaccard_coef", "overlap_coef", "overlap_genes"],
            )
            # Nodes without edges are missing from G; skip to a larger top_term.
            if len(nodes) != len(G.nodes):
                raise ValueError(
                    "The size of node_size list does not match the number of nodes in the graph."
                )

            fig, ax = plt.subplots(figsize=(8, 8))
            pos = nx.layout.circular_layout(G)
            # nodelist ties each NES / Hits_ratio value to its own node; the
            # labels below already key on nodes.index, so index == node id.
            nx.draw_networkx_nodes(
                G,
                pos=pos,
                nodelist=list(nodes.index),
                cmap=plt.cm.RdYlBu,
                node_color=list(nodes.NES),
                node_size=list(nodes.Hits_ratio * 1000),
            )
            nx.draw_networkx_labels(
                G,
                pos=pos,
                labels=nodes.Term.to_dict(),
                font_size=8,
                verticalalignment="bottom",
            )
            edge_weight = nx.get_edge_attributes(G, "jaccard_coef").values()
            nx.draw_networkx_edges(
                G,
                pos=pos,
                width=[w * 10 for w in edge_weight],
                edge_color="#CDDBD4",
            )
            ax.set_axis_off()
            print(f"{gene_sets_name}(top_term={top_term})")
            plot.figsets(title=f"{gene_sets_name}(top_term={top_term})")
            ips.figsave(dir_save + f"prerank_network_{gene_sets_name}.pdf")
            break
        except Exception as e:
            # Close the half-drawn figure instead of leaking one per attempt.
            if fig is not None:
                plt.close(fig)
            print(f"not work {top_term}: {e}")
    return df_prerank
1587
def plot_prerank(
    results_df,
    kind="bar",  # 'barplot', 'dotplot'
    cutoff=0.25,
    show_ring=False,
    xticklabels_rot=0,
    title=None,  # 'KEGG'
    cmap="coolwarm",
    n_top=10,
    size=5,  # when size is None in network, by "NES"
    facecolor=None,  # default by "NES"
    linewidth=None,  # default by "NES"
    linecolor=None,  # default by "NES"
    linealpha=None,  # default by "NES"
    alpha=None,  # default by "NES"
    ax=None,
    **kwargs,
):
    """Plot a prerank result table (``get_prerank`` output) in one of several styles.

    Parameters
    ----------
    results_df : pandas.DataFrame
        Needs "Term", "FDR q-val", "NOM p-val" and "Lead_genes" (";"-separated).
        A copy is extended with "-log10(Adjusted P-value)" (from "FDR q-val")
        and "Count" (number of lead genes); the caller's frame is not modified.
    kind : str
        Matched by substring, checked in this order: "bar" (-log10 FDR),
        "dot" (``gseapy.dotplot``), "co" (bar of Count), "sca" (scatter of
        Count), "net" (enrichment-map network drawn with ``plot_ppi``).
    cutoff : float or None
        Start cutoff for "dot"/"net" (raised by 0.05 until ``n_top`` terms
        appear). For "net", None or >= 1 disables filtering.
    show_ring, xticklabels_rot, title :
        Forwarded to ``gseapy.dotplot``.
    cmap : str or list
        Colormap name or list of colours.
    n_top : int
        Number of terms to show.
    size, facecolor, linewidth, linecolor, linealpha, alpha :
        Network styling forwarded to ``plot_ppi``; each one left as None is
        derived from the nodes' NES values.
    ax : matplotlib Axes, optional
        Target axes for "bar"/"co"/"sca"; a new figure is created when None.
    **kwargs :
        A "figset*" keyword supplies a dict for ``plot.figsets``; the rest
        go to ``plot.plotxy`` ("co"/"sca") or ``plot_ppi`` ("net").

    Returns
    -------
    tuple
        ``(ax, DataFrame)`` for bar/dot/count/scatter;
        ``(ax, G, nodes, edges)`` for "net".

    Raises
    ------
    ValueError
        If ``kind`` is unsupported, or every ``enrichment_map`` attempt fails.
    """
    kws_figsets = {}
    for k_arg in list(kwargs):
        if "figset" in k_arg:
            kws_figsets = kwargs.pop(k_arg)
            break

    if isinstance(cmap, str):
        palette = plot.get_color(n_top, cmap=cmap)[::-1]
    else:
        palette = cmap

    if n_top < 5:
        height_ = 4
    elif n_top < 10:
        height_ = 5
    elif n_top < 15:
        height_ = 6
    elif n_top < 20:
        height_ = 7
    elif n_top < 30:
        height_ = 8
    elif n_top < 40:
        height_ = int(n_top / 5)
    else:
        height_ = int(n_top / 6)

    # Work on a copy so the derived columns do not leak into the caller's frame.
    results_df = results_df.copy()
    results_df["-log10(Adjusted P-value)"] = -np.log10(
        results_df["FDR q-val"].astype(float)
    )
    results_df["Count"] = results_df["Lead_genes"].apply(
        lambda x: len(x.split(";")) if isinstance(x, str) else 0
    )
    kind_l = kind.lower()

    #! barplot
    if "bar" in kind_l:
        df_ = results_df.sort_values(by="-log10(Adjusted P-value)", ascending=False)
        if ax is None:
            _, ax = plt.subplots(1, 1, figsize=[10, height_])
        ax = plot.plotxy(
            data=df_.head(n_top),
            kind="barplot",
            x="-log10(Adjusted P-value)",
            y="Term",
            hue="Term",
            palette=palette,
            legend=None,
            ax=ax,  # previously dropped, so a caller-supplied ax was ignored
        )
        plot.figsets(ax=ax, **kws_figsets)
        return ax, df_

    #! dotplot
    if "dot" in kind_l:
        cutoff_curr = cutoff
        step = 0.05
        cutoff_stop = 0.5
        while cutoff_curr <= cutoff_stop:
            try:
                if cutoff_curr != cutoff:
                    plt.clf()
                ax = gp.dotplot(
                    results_df,
                    column="NOM p-val",
                    show_ring=show_ring,
                    xticklabels_rot=xticklabels_rot,
                    title=title,
                    cmap=cmap,
                    cutoff=cutoff_curr,
                    top_term=n_top,
                    size=size,
                    figsize=[10, height_],
                )
                if len(ax.collections) >= n_top:
                    print(f"cutoff={cutoff_curr} done! ")
                    break
                if cutoff_curr == cutoff_stop:
                    break
                # round() stops repeated += 0.05 from drifting past cutoff_stop
                cutoff_curr = round(cutoff_curr + step, 10)
            except Exception as e:
                failed = cutoff_curr
                cutoff_curr = round(cutoff_curr + step, 10)
                print(
                    f"Warning: trying cutoff={cutoff_curr}, cutoff={failed} failed: {e} "
                )
        plot.figsets(ax=ax, **kws_figsets)
        return ax, results_df

    #! barplot with counts
    if "co" in kind_l:
        if ax is None:
            _, ax = plt.subplots(1, 1, figsize=[10, height_])
        df_ = results_df.sort_values(by="Count", ascending=False)
        ax = plot.plotxy(
            data=df_.head(n_top),
            kind="barplot",
            x="Count",
            y="Term",
            hue="Term",
            palette=palette,
            legend=None,
            ax=ax,
            **kwargs,
        )
        plot.figsets(ax=ax, **kws_figsets)
        return ax, df_

    #! scatter with counts
    if "sca" in kind_l:
        # Unlike the bar plots, the scatter palette is not reversed.
        if isinstance(cmap, str):
            palette = plot.get_color(n_top, cmap=cmap)
        else:
            palette = cmap
        if ax is None:
            _, ax = plt.subplots(1, 1, figsize=[10, height_])
        df_ = results_df.sort_values(by="Count", ascending=False)
        ax = plot.plotxy(
            data=df_.head(n_top),
            kind="scatter",
            x="Count",
            y="Term",
            hue="Count",
            size="Count",
            sizes=[10, 50],
            palette=palette,
            legend=None,
            ax=ax,
            **kwargs,
        )
        plot.figsets(ax=ax, **kws_figsets)
        return ax, df_

    #! network plot
    if "net" in kind_l:
        from gseapy import enrichment_map

        # Test None first: `None >= 1` raises TypeError.
        if cutoff is None or cutoff >= 1:
            print(f"cutoff is {cutoff} => Without applying filter")
            nodes, edges = enrichment_map(
                df=results_df,
                columns="NOM p-val",
                cutoff=1.1,  # above any p-value, i.e. no filtering
                top_term=n_top,
            )
        else:
            nodes = edges = None
            cutoff_curr = cutoff
            step = 0.05
            cutoff_stop = 1.0
            while cutoff_curr <= cutoff_stop:
                try:
                    nodes, edges = enrichment_map(
                        df=results_df,
                        columns="NOM p-val",
                        cutoff=cutoff_curr,
                        top_term=n_top,
                    )
                    if nodes.shape[0] >= n_top:
                        print(f"cutoff={cutoff_curr} done! ")
                        break
                    if cutoff_curr == cutoff_stop:
                        break
                    cutoff_curr = round(cutoff_curr + step, 10)
                except Exception as e:
                    cutoff_curr = round(cutoff_curr + step, 10)
                    print(f"{e}: trying cutoff={cutoff_curr}")
            if nodes is None:
                raise ValueError(
                    f"enrichment_map failed for every cutoff from {cutoff} to {cutoff_stop}"
                )

        styling = {
            "size": size,
            "linewidth": linewidth,
            "linecolor": linecolor,
            "linealpha": linealpha,
            "facecolor": facecolor,
            "alpha": alpha,
        }
        for name, value in styling.items():
            if value is None:
                print(f"{name}: by 'NES'")

        edges.sort_values(by="jaccard_coef", ascending=False, inplace=True)
        nes = nodes["NES"].tolist()
        scaled_nes = [v * 300 for v in nes]
        if facecolor is None:
            # matplotlib.cm.get_cmap was removed in matplotlib 3.9; pyplot's remains.
            colormap = plt.get_cmap(cmap)
            facecolor = [colormap(v) for v in nes]
        G, ax = plot_ppi(
            interactions=edges,
            player1="src_name",
            player2="targ_name",
            weight="jaccard_coef",
            size=scaled_nes if size is None else size,
            facecolor=facecolor,
            linewidth=scaled_nes if linewidth is None else linewidth,
            linecolor=scaled_nes if linecolor is None else linecolor,
            linealpha=scaled_nes if linealpha is None else linealpha,
            alpha=scaled_nes if alpha is None else alpha,
            **kwargs,
        )
        return ax, G, nodes, edges

    raise ValueError(
        f"kind={kind!r} is not supported; use 'bar', 'dot', 'count', 'scatter' or 'network'"
    )
1796
+
1797
+
979
1798
  #! https://string-db.org/help/api/
980
1799
 
981
1800
  import pandas as pd
@@ -1104,16 +1923,22 @@ def plot_ppi(
1104
1923
  n_rank=[5, 10], # Nodes in each rank for the concentric layout
1105
1924
  dist_node = 10, # Distance between each rank of circles
1106
1925
  layout="degree",
1107
- size='auto',#700,
1926
+ size=None,#700,
1927
+ sizes=(50,500),# min and max of size
1108
1928
  facecolor="skyblue",
1109
1929
  cmap='coolwarm',
1110
1930
  edgecolor="k",
1111
1931
  edgelinewidth=1.5,
1112
1932
  alpha=.5,
1933
+ alphas=(0.1, 1.0),# min and max of alpha
1113
1934
  marker="o",
1114
1935
  node_hideticks=True,
1115
1936
  linecolor="gray",
1937
+ line_cmap='coolwarm',
1116
1938
  linewidth=1.5,
1939
+ linewidths=(0.5,5),# min and max of linewidth
1940
+ linealpha=1.0,
1941
+ linealphas=(0.1,1.0),# min and max of linealpha
1117
1942
  linestyle="-",
1118
1943
  line_arrowstyle='-',
1119
1944
  fontsize=10,
@@ -1142,7 +1967,7 @@ def plot_ppi(
1142
1967
  for col in [player1, player2, weight]:
1143
1968
  if col not in interactions.columns:
1144
1969
  raise ValueError(f"Column '{col}' is missing from the interactions DataFrame.")
1145
-
1970
+ interactions.sort_values(by=[weight], inplace=True)
1146
1971
  # Initialize Pyvis network
1147
1972
  net = Network(height="750px", width="100%", bgcolor=bgcolor, font_color=fontcolor)
1148
1973
  net.force_atlas_2based(
@@ -1161,34 +1986,71 @@ def plot_ppi(
1161
1986
  G = nx.Graph()
1162
1987
  for _, row in interactions.iterrows():
1163
1988
  G.add_edge(row[player1], row[player2], weight=row[weight])
1989
+ # G = nx.from_pandas_edgelist(interactions, source=player1, target=player2, edge_attr=weight)
1990
+
1164
1991
 
1165
1992
  # Calculate node degrees
1166
1993
  degrees = dict(G.degree())
1167
1994
  norm = Normalize(vmin=min(degrees.values()), vmax=max(degrees.values()))
1168
1995
  colormap = cm.get_cmap(cmap) # Get the 'coolwarm' colormap
1169
1996
 
1997
+ if not ips.isa(facecolor, 'color'):
1998
+ print("facecolor: based on degrees")
1999
+ facecolor = [colormap(norm(deg)) for deg in degrees.values()] # Use colormap
2000
+ num_nodes = G.number_of_nodes()
2001
+ #* size
1170
2002
  # Set properties based on degrees
1171
2003
  if not isinstance(size, (int,float,list)):
2004
+ print("size: based on degrees")
1172
2005
  size = [deg * 50 for deg in degrees.values()] # Scale sizes
1173
- if not ips.isa(facecolor, 'color'):
1174
- facecolor = [colormap(norm(deg)) for deg in degrees.values()] # Use colormap
1175
- if size is None:
1176
- size = [700] * G.number_of_nodes() # Default size for all nodes
1177
- elif isinstance(size, (int, float)):
1178
- size = [size] * G.number_of_nodes() # If a scalar, apply to all nodes
1179
- # else:
1180
- # size = size.tolist() # Ensure size is a list
1181
- if len(size)>G.number_of_nodes():
1182
- size=size[:G.number_of_nodes()]
1183
-
1184
- for node in G.nodes():
2006
+ size = (size[:num_nodes] if len(size) > num_nodes else size) if isinstance(size, list) else [size] * num_nodes
2007
+ if isinstance(size, list) and len(ips.flatten(size,verbose=False))!=1:
2008
+ # Normalize sizes
2009
+ min_size, max_size = sizes # Use sizes tuple for min and max values
2010
+ min_degree, max_degree = min(size), max(size)
2011
+ if max_degree > min_degree: # Avoid division by zero
2012
+ size = [
2013
+ min_size + (max_size - min_size) * (sz - min_degree) / (max_degree - min_degree)
2014
+ for sz in size
2015
+ ]
2016
+ else:
2017
+ # If all values are the same, set them to a default of the midpoint
2018
+ size = [(min_size + max_size) / 2] * len(size)
2019
+
2020
+ #* facecolor
2021
+ facecolor = (facecolor[:num_nodes] if len(facecolor) > num_nodes else facecolor) if isinstance(facecolor, list) else [facecolor] * num_nodes
2022
+ # * facealpha
2023
+ if isinstance(alpha, list):
2024
+ alpha = (alpha[:num_nodes] if len(alpha) > num_nodes else alpha + [alpha[-1]] * (num_nodes - len(alpha)))
2025
+ min_alphas, max_alphas = alphas # Use alphas tuple for min and max values
2026
+ if len(alpha) > 0:
2027
+ # Normalize alpha based on the specified min and max
2028
+ min_alpha, max_alpha = min(alpha), max(alpha)
2029
+ if max_alpha > min_alpha: # Avoid division by zero
2030
+ alpha = [
2031
+ min_alphas + (max_alphas - min_alphas) * (ea - min_alpha) / (max_alpha - min_alpha)
2032
+ for ea in alpha
2033
+ ]
2034
+ else:
2035
+ # If all alpha values are the same, set them to the average of min and max
2036
+ alpha = [(min_alphas + max_alphas) / 2] * len(alpha)
2037
+ else:
2038
+ # Default to a full opacity if no edges are provided
2039
+ alpha = [1.0] * num_nodes
2040
+ else:
2041
+ # If alpha is a single value, convert it to a list and normalize it
2042
+ alpha = [alpha] * num_nodes # Adjust based on alphas
2043
+
2044
+ for i, node in enumerate(G.nodes()):
1185
2045
  net.add_node(
1186
2046
  node,
1187
2047
  label=node,
1188
- size=size[list(G.nodes()).index(node)] if isinstance(size,list) else size[0],
1189
- color=facecolor[list(G.nodes()).index(node)] if isinstance(facecolor,list) else facecolor,
2048
+ size=size[i],
2049
+ color=facecolor[i],
2050
+ alpha=alpha[i],
1190
2051
  font={"size": fontsize, "color": fontcolor},
1191
2052
  )
2053
+ print(f'nodes number: {i+1}')
1192
2054
 
1193
2055
  for edge in G.edges(data=True):
1194
2056
  net.add_edge(
@@ -1198,6 +2060,7 @@ def plot_ppi(
1198
2060
  color=edgecolor,
1199
2061
  width=edgelinewidth * edge[2]["weight"],
1200
2062
  )
2063
+
1201
2064
  layouts = [
1202
2065
  "spring",
1203
2066
  "circular",
@@ -1209,7 +2072,8 @@ def plot_ppi(
1209
2072
  "degree"
1210
2073
  ]
1211
2074
  layout = ips.strcmp(layout, layouts)[0]
1212
- print(layout)
2075
+ print(f"layout:{layout}, or select one in {layouts}")
2076
+
1213
2077
  # Choose layout
1214
2078
  if layout == "spring":
1215
2079
  pos = nx.spring_layout(G, k=k_value)
@@ -1235,7 +2099,9 @@ def plot_ppi(
1235
2099
  # Calculate node degrees and sort nodes by degree
1236
2100
  degrees = dict(G.degree())
1237
2101
  sorted_nodes = sorted(degrees.items(), key=lambda x: x[1], reverse=True)
1238
-
2102
+ norm = Normalize(vmin=min(degrees.values()), vmax=max(degrees.values()))
2103
+ colormap = cm.get_cmap(cmap)
2104
+
1239
2105
  # Create positions for concentric circles based on n_layers and n_rank
1240
2106
  pos = {}
1241
2107
  n_layers=len(n_rank)+1 if n_layers is None else n_layers
@@ -1266,8 +2132,8 @@ def plot_ppi(
1266
2132
 
1267
2133
  # If ax is None, use plt.gca()
1268
2134
  if ax is None:
1269
- fig, ax = plt.subplots(1,1,figsize=figsize)
1270
-
2135
+ fig, ax = plt.subplots(1,1,figsize=figsize)
2136
+
1271
2137
  # Draw nodes, edges, and labels with customization options
1272
2138
  nx.draw_networkx_nodes(
1273
2139
  G,
@@ -1281,6 +2147,54 @@ def plot_ppi(
1281
2147
  hide_ticks=node_hideticks,
1282
2148
  node_shape=marker
1283
2149
  )
2150
+
2151
+ #* linewidth
2152
+ if not isinstance(linewidth, list):
2153
+ linewidth = [linewidth] * G.number_of_edges()
2154
+ else:
2155
+ linewidth = (linewidth[:G.number_of_edges()] if len(linewidth) > G.number_of_edges() else linewidth + [linewidth[-1]] * (G.number_of_edges() - len(linewidth)))
2156
+ # Normalize linewidth if it is a list
2157
+ if isinstance(linewidth, list):
2158
+ min_linewidth, max_linewidth = min(linewidth), max(linewidth)
2159
+ vmin, vmax = linewidths # Use linewidths tuple for min and max values
2160
+ if max_linewidth > min_linewidth: # Avoid division by zero
2161
+ # Scale between vmin and vmax
2162
+ linewidth = [
2163
+ vmin + (vmax - vmin) * (lw - min_linewidth) / (max_linewidth - min_linewidth)
2164
+ for lw in linewidth
2165
+ ]
2166
+ else:
2167
+ # If all values are the same, set them to a default of the midpoint
2168
+ linewidth = [(vmin + vmax) / 2] * len(linewidth)
2169
+ else:
2170
+ # If linewidth is a single value, convert it to a list of that value
2171
+ linewidth = [linewidth] * G.number_of_edges()
2172
+ #* linecolor
2173
+ if not isinstance(linecolor, str):
2174
+ weights = [G[u][v]["weight"] for u, v in G.edges()]
2175
+ norm = Normalize(vmin=min(weights), vmax=max(weights))
2176
+ colormap = cm.get_cmap(line_cmap)
2177
+ linecolor = [colormap(norm(weight)) for weight in weights]
2178
+ else:
2179
+ linecolor = [linecolor] * G.number_of_edges()
2180
+
2181
+ # * linealpha
2182
+ if isinstance(linealpha, list):
2183
+ linealpha = (linealpha[:G.number_of_edges()] if len(linealpha) > G.number_of_edges() else linealpha + [linealpha[-1]] * (G.number_of_edges() - len(linealpha)))
2184
+ min_alpha, max_alpha = linealphas # Use linealphas tuple for min and max values
2185
+ if len(linealpha) > 0:
2186
+ min_linealpha, max_linealpha = min(linealpha), max(linealpha)
2187
+ if max_linealpha > min_linealpha: # Avoid division by zero
2188
+ linealpha = [
2189
+ min_alpha + (max_alpha - min_alpha) * (ea - min_linealpha) / (max_linealpha - min_linealpha)
2190
+ for ea in linealpha
2191
+ ]
2192
+ else:
2193
+ linealpha = [(min_alpha + max_alpha) / 2] * len(linealpha)
2194
+ else:
2195
+ linealpha = [1.0] * G.number_of_edges() # 如果设置有误,则将它设置成1.0
2196
+ else:
2197
+ linealpha = [linealpha] * G.number_of_edges() # Convert to list if single value
1284
2198
  nx.draw_networkx_edges(
1285
2199
  G,
1286
2200
  pos,
@@ -1289,14 +2203,21 @@ def plot_ppi(
1289
2203
  width=linewidth,
1290
2204
  style=linestyle,
1291
2205
  arrowstyle=line_arrowstyle,
1292
- alpha=0.7
2206
+ alpha=linealpha
1293
2207
  )
2208
+
1294
2209
  nx.draw_networkx_labels(
1295
2210
  G, pos, ax=ax, font_size=fontsize, font_color=fontcolor,horizontalalignment=ha,verticalalignment=va
1296
2211
  )
1297
2212
  plot.figsets(ax=ax,**kws_figsets)
1298
2213
  ax.axis("off")
1299
- net.write_html(dir_save)
2214
+ if dir_save:
2215
+ if not os.path.basename(dir_save):
2216
+ dir_save="_.html"
2217
+ net.write_html(dir_save)
2218
+ nx.write_graphml(G, dir_save.replace(".html",".graphml")) # Export to GraphML
2219
+ print(f"could be edited in Cytoscape \n{dir_save.replace(".html",".graphml")}")
2220
+ ips.figsave(dir_save.replace(".html",".pdf"))
1300
2221
  return G,ax
1301
2222
 
1302
2223