gwaslab 3.4.48__py3-none-any.whl → 3.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of gwaslab might be problematic. Click here for more details.

@@ -9,6 +9,7 @@ from matplotlib.patches import Rectangle
9
9
  from adjustText import adjust_text
10
10
  from gwaslab.viz_aux_save_figure import save_figure
11
11
  from gwaslab.util_in_get_sig import getsig
12
+ from gwaslab.util_in_get_sig import annogene
12
13
  from gwaslab.g_Log import Log
13
14
  from gwaslab.util_in_correct_winnerscurse import wc_correct
14
15
  from gwaslab.util_in_correct_winnerscurse import wc_correct_test
@@ -59,6 +60,7 @@ def compare_effect(path1,
59
60
  xylabel_prefix="Per-allele effect size in ",
60
61
  helper_line_args=None,
61
62
  fontargs=None,
63
+ build="19",
62
64
  r_or_r2="r",
63
65
  #
64
66
  errargs=None,
@@ -77,10 +79,9 @@ def compare_effect(path1,
77
79
  scaled2 = True
78
80
  if is_q_mc=="fdr" or is_q_mc=="bon":
79
81
  is_q = True
80
-
81
82
  if is_q == True:
82
83
  if is_q_mc not in [False,"fdr","bon","non"]:
83
- raise ValueError("Please select either fdr or bon or non for is_q_mc.")
84
+ raise ValueError('Please select either "fdr" or "bon" or "non"/False for is_q_mc.')
84
85
  if save_args is None:
85
86
  save_args = {"dpi":300,"facecolor":"white"}
86
87
  if reg_box is None:
@@ -89,6 +90,8 @@ def compare_effect(path1,
89
90
  sep = ["\t","\t"]
90
91
  if get_lead_args is None:
91
92
  get_lead_args = {}
93
+ if anno=="GENENAME":
94
+ get_lead_args["anno"]=True
92
95
  if errargs is None:
93
96
  errargs={"ecolor":"#cccccc","elinewidth":1}
94
97
  if fontargs is None:
@@ -191,10 +194,12 @@ def compare_effect(path1,
191
194
  ######### 8.1 if a snplist is provided, use the snp list
192
195
  log.write(" -Extract variants in the given list from "+label[0]+"...")
193
196
  sig_list_1 = sumstats.loc[sumstats["SNPID"].isin(snplist),:].copy()
197
+ if anno=="GENENAME":
198
+ sig_list_1 = annogene(sumstats,"SNPID","CHR","POS", build=build, verbose=verbose,**get_lead_args)
194
199
  else:
195
- ######### 8,2 otherwise use the sutomatically detected lead SNPs
200
+ ######### 8,2 otherwise use the automatically detected lead SNPs
196
201
  log.write(" -Extract lead variants from "+label[0]+"...")
197
- sig_list_1 = getsig(sumstats,"SNPID","CHR","POS","P", verbose=verbose,sig_level=sig_level,**get_lead_args)
202
+ sig_list_1 = getsig(sumstats,"SNPID","CHR","POS","P", build=build, verbose=verbose,sig_level=sig_level,**get_lead_args)
198
203
 
199
204
  if drop==True:
200
205
  sig_list_1 = drop_duplicate_and_na(sig_list_1, sort_by="P", log=log ,verbose=verbose)
@@ -235,10 +240,12 @@ def compare_effect(path1,
235
240
  ######### 12.1 if a snplist is provided, use the snp list
236
241
  log.write(" -Extract snps in the given list from "+label[1]+"...")
237
242
  sig_list_2 = sumstats.loc[sumstats["SNPID"].isin(snplist),:].copy()
243
+ if anno=="GENENAME":
244
+ sig_list_2 = annogene(sumstats,"SNPID","CHR","POS", build=build, verbose=verbose,**get_lead_args)
238
245
  else:
239
246
  log.write(" -Extract lead snps from "+label[1]+"...")
240
247
  ######### 12.2 otherwise use the sutomatically detected lead SNPs
241
- sig_list_2 = getsig(sumstats,"SNPID","CHR","POS","P",
248
+ sig_list_2 = getsig(sumstats,"SNPID","CHR","POS","P",build=build,
242
249
  verbose=verbose,sig_level=sig_level,**get_lead_args)
243
250
  if drop==True:
244
251
  sig_list_2 = drop_duplicate_and_na(sig_list_2, sort_by="P", log=log ,verbose=verbose)
@@ -248,6 +255,10 @@ def compare_effect(path1,
248
255
  log.write("Merging snps from "+label[0]+" and "+label[1]+"...")
249
256
 
250
257
  sig_list_merged = pd.merge(sig_list_1,sig_list_2,left_on="SNPID",right_on="SNPID",how="outer",suffixes=('_1', '_2'))
258
+ if anno == "GENENAME":
259
+ sig_list_merged.loc[sig_list_merged["SNPID"].isin((sig_list_1["SNPID"])),"GENENAME"] = sig_list_merged.loc[sig_list_merged["SNPID"].isin((sig_list_1["SNPID"])),"GENE_1"]
260
+ sig_list_merged.loc[~sig_list_merged["SNPID"].isin((sig_list_1["SNPID"])),"GENENAME"] = sig_list_merged.loc[~sig_list_merged["SNPID"].isin((sig_list_1["SNPID"])),"GENE_2"]
261
+ sig_list_merged = sig_list_merged.drop(columns=["GENE_1","GENE_2","LOCATION_1","LOCATION_2"])
251
262
  # SNPID P_1 P_2
252
263
  #0 rs117986209 0.142569 0.394455
253
264
  #1 rs6704312 0.652104 0.143750
@@ -533,7 +544,7 @@ def compare_effect(path1,
533
544
 
534
545
  ########################## Het test############################################################
535
546
  ## heterogeneity test
536
- if (is_q is True):
547
+ if (is_q == True):
537
548
  log.write(" -Calculating Cochran's Q statistics and peform chisq test...", verbose=verbose)
538
549
  if mode=="beta" or mode=="BETA" or mode=="Beta":
539
550
  sig_list_merged = test_q(sig_list_merged,"EFFECT_1","SE_1","EFFECT_2_aligned","SE_2",q_level=q_level,is_q_mc=is_q_mc, log=log, verbose=verbose)
@@ -552,7 +563,7 @@ def compare_effect(path1,
552
563
  log.write(" -Exclude "+str(len(sig_list_merged) -sum(both_eaf_clear))+ " variants with maf <",maf_level, verbose=verbose)
553
564
  sig_list_merged = sig_list_merged.loc[both_eaf_clear,:]
554
565
  # heterogeneity summary
555
- if (is_q is True):
566
+ if (is_q == True):
556
567
  log.write(" -Significant het:" ,len(sig_list_merged.loc[sig_list_merged["HetP"]<0.05,:]), verbose=verbose)
557
568
  log.write(" -All sig:" ,len(sig_list_merged), verbose=verbose)
558
569
  log.write(" -Het rate:" ,len(sig_list_merged.loc[sig_list_merged["HetP"]<0.05,:])/len(sig_list_merged), verbose=verbose)
@@ -633,11 +644,11 @@ def compare_effect(path1,
633
644
  ax.scatter(both["OR_1"],both["OR_2_aligned"],label=label[2],zorder=2,color="#205be6",edgecolors=both["Edge_color"],marker="s",**scatterargs)
634
645
  legend_elements.append(label[2])
635
646
  ## annotation #################################################################################################################
636
- if anno==True:
647
+ if anno==True or anno=="GENENAME":
637
648
  sig_list_toanno = sig_list_merged.dropna(axis=0)
638
649
  if is_q==True and anno_het == True:
639
650
  sig_list_toanno = sig_list_toanno.loc[sig_list_toanno["Edge_color"]=="black",:]
640
-
651
+
641
652
  if mode=="beta":
642
653
  sig_list_toanno = sig_list_toanno.loc[sig_list_toanno["EFFECT_1"].abs() >=anno_min1 ,:]
643
654
  sig_list_toanno = sig_list_toanno.loc[sig_list_toanno["EFFECT_2_aligned"].abs() >=anno_min2 ,:]
@@ -651,22 +662,38 @@ def compare_effect(path1,
651
662
 
652
663
  texts_l=[]
653
664
  texts_r=[]
665
+
666
+ if anno==True:
667
+ log.write("Annotating variants using {}".format("SNPID"), verbose=verbose)
668
+ elif anno=="GENENAME":
669
+ log.write("Annotating variants using {}".format("GENENAME"), verbose=verbose)
670
+
654
671
  for index, row in sig_list_toanno.iterrows():
672
+ log.write("Annotating {}...".format(row), verbose=verbose)
673
+ if anno==True:
674
+ to_anno_text = index
675
+ elif type(anno) is str:
676
+ if not pd.isna(row[anno]):
677
+ to_anno_text = row[anno]
678
+ else:
679
+ to_anno_text = index
680
+
655
681
  if mode=="beta" or mode=="BETA" or mode=="Beta":
656
682
  if row["EFFECT_1"] < row["EFFECT_2_aligned"]:
657
- texts_l.append(plt.text(row["EFFECT_1"], row["EFFECT_2_aligned"],index,ha="right",va="bottom"))
683
+ texts_l.append(plt.text(row["EFFECT_1"], row["EFFECT_2_aligned"],to_anno_text,ha="right",va="bottom"))
658
684
  else:
659
- texts_r.append(plt.text(row["EFFECT_1"], row["EFFECT_2_aligned"],index,ha="left",va="top"))
685
+ texts_r.append(plt.text(row["EFFECT_1"], row["EFFECT_2_aligned"],to_anno_text,ha="left",va="top"))
660
686
  else:
661
687
  if row["OR_1"] < row["OR_2_aligned"]:
662
- texts_l.append(plt.text(row["OR_1"], row["OR_2_aligned"],index, ha='right', va='bottom'))
688
+ texts_l.append(plt.text(row["OR_1"], row["OR_2_aligned"],to_anno_text, ha='right', va='bottom'))
663
689
  else:
664
- texts_r.append(plt.text(row["OR_1"], row["OR_2_aligned"],index, ha='left', va='top'))
665
-
666
- adjust_text(texts_l,autoalign =False,precision =0.001,lim=1000, ha="right",va="bottom", expand_text=(1,1.8) , expand_objects=(0.1,0.1), expand_points=(1.8,1.8) ,force_objects=(0.8,0.8) ,arrowprops=dict(arrowstyle='-|>', color='grey'),ax=ax)
667
- adjust_text(texts_r,autoalign =False,precision =0.001,lim=1000, ha="left",va="top", expand_text=(1,1.8) , expand_objects=(0.1,0.1), expand_points=(1.8,1.8) ,force_objects =(0.8,0.8),arrowprops=dict(arrowstyle='-|>', color='grey'),ax=ax)
668
-
690
+ texts_r.append(plt.text(row["OR_1"], row["OR_2_aligned"],to_anno_text, ha='left', va='top'))
691
+ if len(texts_l)>0:
692
+ adjust_text(texts_l,autoalign =False,precision =0.001,lim=1000, ha="right",va="bottom", expand_text=(1,1.8) , expand_objects=(0.1,0.1), expand_points=(1.8,1.8) ,force_objects=(0.8,0.8) ,arrowprops=dict(arrowstyle='-|>', color='grey'),ax=ax)
693
+ if len(texts_r)>0:
694
+ adjust_text(texts_r,autoalign =False,precision =0.001,lim=1000, ha="left",va="top", expand_text=(1,1.8) , expand_objects=(0.1,0.1), expand_points=(1.8,1.8) ,force_objects =(0.8,0.8),arrowprops=dict(arrowstyle='-|>', color='grey'),ax=ax)
669
695
  elif type(anno) is dict:
696
+ sig_list_toanno = sig_list_merged.dropna(axis=0)
670
697
  # if input is a dict
671
698
  sig_list_toanno = sig_list_toanno.loc[sig_list_toanno.index.isin(list(anno.keys())),:]
672
699
  if is_q==True and anno_het == True:
@@ -696,9 +723,10 @@ def compare_effect(path1,
696
723
  texts_l.append(plt.text(row["OR_1"], row["OR_2_aligned"],anno[index], ha='right', va='bottom'))
697
724
  else:
698
725
  texts_r.append(plt.text(row["OR_1"], row["OR_2_aligned"],anno[index], ha='left', va='top'))
699
-
700
- adjust_text(texts_l,autoalign =False,precision =0.001,lim=1000, ha="right",va="bottom", expand_text=(1,1.8) , expand_objects=(0.1,0.1), expand_points=(1.8,1.8) ,force_objects=(0.8,0.8) ,arrowprops=dict(arrowstyle='-|>', color='grey'),ax=ax)
701
- adjust_text(texts_r,autoalign =False,precision =0.001,lim=1000, ha="left",va="top", expand_text=(1,1.8) , expand_objects=(0.1,0.1), expand_points=(1.8,1.8) ,force_objects =(0.8,0.8),arrowprops=dict(arrowstyle='-|>', color='grey'),ax=ax)
726
+ if len(texts_l)>0:
727
+ adjust_text(texts_l,autoalign =False,precision =0.001,lim=1000, ha="right",va="bottom", expand_text=(1,1.8) , expand_objects=(0.1,0.1), expand_points=(1.8,1.8) ,force_objects=(0.8,0.8) ,arrowprops=dict(arrowstyle='-|>', color='grey'),ax=ax)
728
+ if len(texts_r)>0:
729
+ adjust_text(texts_r,autoalign =False,precision =0.001,lim=1000, ha="left",va="top", expand_text=(1,1.8) , expand_objects=(0.1,0.1), expand_points=(1.8,1.8) ,force_objects =(0.8,0.8),arrowprops=dict(arrowstyle='-|>', color='grey'),ax=ax)
702
730
  #################################################################################################################################
703
731
 
704
732
  # plot x=0,y=0, and a 45 degree line
@@ -290,7 +290,11 @@ def plot_miami2(
290
290
 
291
291
 
292
292
  #####################################################################################################################
293
-
293
+ ax1l, ax1r = ax5.get_xlim()
294
+ ax5l, ax5r = ax1.get_xlim()
295
+ ax1.set_xlim([min(ax1l,ax5l), max(ax1r,ax5r)])
296
+ ax5.set_xlim([min(ax1l,ax5l), max(ax1r,ax5r)])
297
+ #####################################################################################################################
294
298
  ax5.set_xlabel("")
295
299
  #ax5.set_xticks(chrom_df)
296
300
  ax5.set_xticklabels([])
@@ -139,7 +139,7 @@ def mqqplot(insumstats,
139
139
  anno_gtf_path=None,
140
140
  anno_adjust=False,
141
141
  anno_max_iter=100,
142
- arm_offset=50,
142
+ arm_offset=None,
143
143
  arm_scale=1,
144
144
  anno_height=1,
145
145
  arm_scale_d=None,
@@ -291,7 +291,7 @@ def mqqplot(insumstats,
291
291
  if maf_bin_colors is None:
292
292
  maf_bin_colors = ["#f0ad4e","#5cb85c", "#5bc0de","#000042"]
293
293
  if save_args is None:
294
- save_args = {"dpi":300,"facecolor":"white"}
294
+ save_args = {"dpi":400,"facecolor":"white"}
295
295
  if highlight is None:
296
296
  highlight = list()
297
297
  if highlight_anno_args is None:
@@ -329,6 +329,20 @@ def mqqplot(insumstats,
329
329
  fig_args["dpi"]=72
330
330
  scatter_args["rasterized"]=True
331
331
  qq_scatter_args["rasterized"]=True
332
+ else:
333
+ fig_args["dpi"] = save_args["dpi"]
334
+
335
+ # configure dpi if saving the plot
336
+ fig_args, scatter_args, qq_scatter_args, save_args = _configure_fig_save_kwargs(save = save,
337
+ fig_args = fig_args,
338
+ scatter_args = scatter_args,
339
+ qq_scatter_args = qq_scatter_args,
340
+ save_args = save_args)
341
+
342
+
343
+ if len(anno_d) > 0 and arm_offset is None:
344
+ # in pixels
345
+ arm_offset = fig_args["dpi"] * repel_force * fig_args["figsize"][0]*0.5
332
346
 
333
347
  log.write("Start to create MQQ plot...{}:".format(_get_version()),verbose=verbose)
334
348
  log.write(" -Genomic coordinates version: {}...".format(build),verbose=verbose)
@@ -401,7 +415,7 @@ def mqqplot(insumstats,
401
415
  if mode=="b":
402
416
  sig_level=1,
403
417
  sig_line=False,
404
- windowsizekb = 100000000
418
+ #windowsizekb = 100000000
405
419
  mode="mb"
406
420
  scatter_args={"marker":"s"}
407
421
  marker_size= (marker_size[1],marker_size[1])
@@ -522,8 +536,12 @@ def mqqplot(insumstats,
522
536
  pos=pos,
523
537
  verbose=verbose,
524
538
  log=log)
539
+
540
+ lines_to_plot = pd.Series(lines_to_plot.to_list() + [bmean, bmedian])
541
+
525
542
  else:
526
543
  bmean, bmedian=0,0
544
+
527
545
  # P value conversion #####################################################################################################
528
546
 
529
547
  # add raw_P and scaled_P
@@ -956,7 +974,7 @@ def mqqplot(insumstats,
956
974
  ax1.set_title(mtitle,fontsize=title_fontsize,family=font_family)
957
975
  log.write("Finished processing figure arts.",verbose=verbose)
958
976
 
959
- # Add annotation arrows and texts
977
+ ## Add annotation arrows and texts
960
978
  log.write("Start to annotate variants...",verbose=verbose)
961
979
  ax1 = annotate_single(
962
980
  sumstats=sumstats,
@@ -1055,7 +1073,8 @@ def mqqplot(insumstats,
1055
1073
  fig.suptitle(title , fontsize = title_fontsize ,x=0.5, y=1.05)
1056
1074
  else:
1057
1075
  fig.suptitle(title , fontsize = title_fontsize, x=0.5,y=1)
1058
-
1076
+ ## Add annotation arrows and texts
1077
+
1059
1078
  # Saving figure
1060
1079
  save_figure(fig = fig, save = save, keyword=mode, save_args=save_args, log = log, verbose=verbose)
1061
1080
 
@@ -1069,7 +1088,31 @@ def mqqplot(insumstats,
1069
1088
 
1070
1089
  ##############################################################################################################################################################################
1071
1090
 
1091
+ def _configure_fig_save_kwargs(save=None,
1092
+ fig_args=None,
1093
+ scatter_args=None,
1094
+ qq_scatter_args=None,
1095
+ save_args=None):
1096
+ if fig_args is None:
1097
+ fig_args = dict()
1098
+ if scatter_args is None:
1099
+ scatter_args = dict()
1100
+ if qq_scatter_args is None:
1101
+ qq_scatter_args = dict()
1102
+ if save_args is None:
1103
+ save_args = dict()
1072
1104
 
1105
+ if save is not None:
1106
+ if type(save) is not bool:
1107
+ if len(save)>3:
1108
+ if save[-3:]=="pdf" or save[-3:]=="svg":
1109
+ # to save as vectorized plot
1110
+ fig_args["dpi"]=72
1111
+ scatter_args["rasterized"]=True
1112
+ qq_scatter_args["rasterized"]=True
1113
+ else:
1114
+ fig_args["dpi"] = save_args["dpi"]
1115
+ return fig_args, scatter_args, qq_scatter_args, save_args
1073
1116
 
1074
1117
 
1075
1118
  def _add_pad_to_x_axis(ax1, xpad, xpadl, xpadr, sumstats, pos, chrpad, xtight, log, verbose):
@@ -1104,12 +1147,6 @@ def _add_pad_to_x_axis(ax1, xpad, xpadl, xpadr, sumstats, pos, chrpad, xtight, l
1104
1147
  return ax1
1105
1148
 
1106
1149
 
1107
-
1108
-
1109
-
1110
-
1111
-
1112
-
1113
1150
  ##############################################################################################################################################################################
1114
1151
  def _configure_cols_to_use(insumstats, snpid, chrom, pos, ea, nea, eaf, p, mlog10p,scaled, mode,stratified,anno, anno_set, anno_alias,_chrom_df_for_i,highlight ,pinpoint,density_color):
1115
1152
  usecols=[]
@@ -1287,9 +1324,10 @@ def _process_density(sumstats, mode, bwindowsizekb, chrom, pos, verbose, log):
1287
1324
  else:
1288
1325
  break
1289
1326
  df = pd.DataFrame(stack,columns=["SNPID","TCHR+POS","DENSITY"])
1290
- sumstats["DENSITY"] = df["DENSITY"].values
1291
- bmean=sumstats["DENSITY"].mean()
1292
- bmedian=sumstats["DENSITY"].median()
1327
+ sumstats["DENSITY"] = df["DENSITY"].astype("Float64").values
1328
+
1329
+ bmean=sumstats.drop_duplicates(subset="SNPID")["DENSITY"].mean()
1330
+ bmedian=sumstats.drop_duplicates(subset="SNPID")["DENSITY"].median()
1293
1331
  elif "b" in mode and "DENSITY" in sumstats.columns:
1294
1332
  bmean=sumstats["DENSITY"].mean()
1295
1333
  bmedian=sumstats["DENSITY"].median()
@@ -1305,6 +1343,7 @@ def _process_line(ax1, sig_line, suggestive_sig_line, additional_line, lines_to_
1305
1343
  linestyle="--",
1306
1344
  color=sig_line_color,
1307
1345
  zorder=1)
1346
+
1308
1347
  if suggestive_sig_line is True:
1309
1348
  suggestive_sig_line = ax1.axhline(y=lines_to_plot[1],
1310
1349
  linewidth = sc_linewidth,
@@ -1312,15 +1351,20 @@ def _process_line(ax1, sig_line, suggestive_sig_line, additional_line, lines_to_
1312
1351
  color=suggestive_sig_line_color,
1313
1352
  zorder=1)
1314
1353
  if additional_line is not None:
1315
- for index, level in enumerate(lines_to_plot[2:].values):
1354
+ for index, level in enumerate(lines_to_plot[2:2+len(additional_line)].values):
1316
1355
  ax1.axhline(y=level,
1317
1356
  linewidth = sc_linewidth,
1318
1357
  linestyle="--",
1319
1358
  color=additional_line_color[index%len(additional_line_color)],
1320
1359
  zorder=1)
1321
- if "b" in mode:
1360
+ if "b" in mode:
1361
+ bmean = lines_to_plot.iat[-2]
1362
+ bmedian = lines_to_plot.iat[-1]
1322
1363
  # for brisbane plot, add median and mean line
1364
+ log.write(" -Plotting horizontal line ( mean DENISTY): y = {}".format(bmean),verbose=verbose)
1323
1365
  meanline = ax1.axhline(y=bmean, linewidth = sc_linewidth,linestyle="-",color=sig_line_color,zorder=1000)
1366
+
1367
+ log.write(" -Plotting horizontal line ( median DENISTY): y = {}".format(bmedian),verbose=verbose)
1324
1368
  medianline = ax1.axhline(y=bmedian, linewidth = sc_linewidth,linestyle="--",color=sig_line_color,zorder=1000)
1325
1369
  return ax1
1326
1370
 
@@ -1441,10 +1485,16 @@ def _process_layout(mode, figax, fig_args, mqqratio, region_hspace):
1441
1485
  ax2 = None
1442
1486
  plt.subplots_adjust(hspace=region_hspace)
1443
1487
  elif mode =="b" :
1444
- fig_args["figsize"] = (15,5)
1445
- fig, ax1 = plt.subplots(1, 1,**fig_args)
1446
- ax2 = None
1447
- ax3 = None
1488
+ if figax is not None:
1489
+ fig = figax[0]
1490
+ ax1 = figax[1]
1491
+ ax3 = None
1492
+ ax2 = None
1493
+ else:
1494
+ fig_args["figsize"] = (15,5)
1495
+ fig, ax1 = plt.subplots(1, 1,**fig_args)
1496
+ ax2 = None
1497
+ ax3 = None
1448
1498
  else:
1449
1499
  raise ValueError("Please select one from the 5 modes: mqq/qqm/m/qq/r/b")
1450
1500
  ax4=None
@@ -0,0 +1,260 @@
1
+ import pandas as pd
2
+ import matplotlib.pyplot as plt
3
+ import seaborn as sns
4
+ import numpy as np
5
+ import scipy as sp
6
+ from gwaslab.viz_aux_quickfix import _quick_assign_i_with_rank
7
+ from gwaslab.viz_aux_quickfix import _get_largenumber
8
+ from gwaslab.viz_aux_quickfix import _quick_fix_p_value
9
+ from gwaslab.viz_aux_quickfix import _quick_fix_pos
10
+ from gwaslab.viz_aux_quickfix import _quick_fix_chr
11
+ from gwaslab.viz_aux_quickfix import _quick_fix_eaf
12
+ from gwaslab.viz_aux_quickfix import _quick_fix_mlog10p
13
+ from gwaslab.viz_aux_quickfix import _dropna_in_cols
14
+ from gwaslab.viz_plot_mqqplot import _process_p_value
15
+ from gwaslab.viz_plot_mqqplot import _configure_fig_save_kwargs
16
+ from gwaslab.viz_plot_mqqplot import mqqplot
17
+ from gwaslab.viz_aux_save_figure import save_figure
18
+ from gwaslab.g_Log import Log
19
+ import copy
20
+ from gwaslab.bd_common_data import get_chr_to_number
21
+ from gwaslab.bd_common_data import get_number_to_chr
22
+ from gwaslab.g_version import _get_version
23
+
24
+ def _gwheatmap(
25
+ insumstats,
26
+ chrom="CHR",
27
+ pos="POS",
28
+ ref_chrom="REF_CHR",
29
+ ref_pos="REF_START",
30
+ p="P",
31
+ scaled=False,
32
+ sizes = (10,50),
33
+ alpha=0.5,
34
+ mlog10p="MLOG10P",
35
+ snpid="SNPID",
36
+ eaf=None,
37
+ group="CIS/TRANS",
38
+ ea="EA",
39
+ nea="NEA",
40
+ colors=None,
41
+ check = True,
42
+ chr_dict = None,
43
+ xchrpad = 0,
44
+ ychrpad=0,
45
+ use_rank = False,
46
+ xtick_chr_dict=None,
47
+ ytick_chr_dict=None,
48
+ fontsize=10,
49
+ add_b =False,
50
+ log=Log(),
51
+ fig_kwargs=None,
52
+ scatter_kwargs=None,
53
+ height_ratios=None,
54
+ hspace = 0.1,
55
+ font_family="Arial",
56
+ cis_windowsizekb=100,
57
+ verbose=True,
58
+ save=True,
59
+ save_kwargs=None,
60
+ grid_linewidth=1,
61
+ grid_linecolor="grey",
62
+ **mqq_kwargs
63
+ ):
64
+ log.write("Start to create genome-wide scatter plot...{}:".format(_get_version()),verbose=verbose)
65
+ if height_ratios is None:
66
+ height_ratios = [1, 2]
67
+ if xtick_chr_dict is None:
68
+ xtick_chr_dict = get_number_to_chr()
69
+ if ytick_chr_dict is None:
70
+ ytick_chr_dict = get_number_to_chr()
71
+ if chr_dict is None:
72
+ chr_dict = get_chr_to_number()
73
+ if colors is None:
74
+ colors=["#CB132D","#597FBD"]
75
+ if fig_kwargs is None:
76
+ fig_kwargs= dict(figsize=(15,15))
77
+ if save_kwargs is None:
78
+ save_kwargs = {"dpi":300,"facecolor":"white"}
79
+ if scatter_kwargs is None:
80
+ scatter_kwargs = {}
81
+
82
+ fig_kwargs, scatter_kwargs, qq_scatter_args, save_kwargs = _configure_fig_save_kwargs(save=save,
83
+ fig_args = fig_kwargs,
84
+ scatter_args = scatter_kwargs,
85
+ qq_scatter_args = dict(),
86
+ save_args = save_kwargs)
87
+
88
+ sumstats = insumstats.copy()
89
+
90
+ # Data QC and format
91
+ if check ==True:
92
+ sumstats[pos] = _quick_fix_pos(sumstats[pos])
93
+ sumstats[chrom] = _quick_fix_chr(sumstats[chrom], chr_dict=chr_dict)
94
+ sumstats[ref_pos] = _quick_fix_pos(sumstats[ref_pos])
95
+ sumstats[ref_chrom] = _quick_fix_chr(sumstats[ref_chrom], chr_dict=chr_dict)
96
+ sumstats = _dropna_in_cols(sumstats, [pos, chrom, ref_pos, ref_chrom], log=log, verbose=verbose)
97
+
98
+ # dropna
99
+ sumstats = sumstats.sort_values(by=group)
100
+
101
+ if scaled is True:
102
+ sumstats["raw_P"] = pd.to_numeric(sumstats[mlog10p], errors='coerce')
103
+ else:
104
+ sumstats["raw_P"] = sumstats[p].astype("float64")
105
+
106
+ sumstats = _process_p_value(sumstats=sumstats,
107
+ mode="m",
108
+ p=p,
109
+ mlog10p=mlog10p,
110
+ scaled=scaled,
111
+ log=log,
112
+ verbose=verbose )
113
+
114
+
115
+
116
+ if add_b ==False:
117
+ fig, ax1 = plt.subplots(**fig_kwargs)
118
+ else:
119
+ fig, (ax2, ax1) = plt.subplots( nrows=2 ,sharex=True, gridspec_kw={'height_ratios': height_ratios }, **fig_kwargs)
120
+ plt.subplots_adjust(hspace=hspace)
121
+
122
+ ## assign i for variants
123
+ sumstats, chrom_df_x = _quick_assign_i_with_rank(sumstats,
124
+ chrpad=xchrpad,
125
+ use_rank=use_rank,
126
+ chrom=chrom,
127
+ pos=pos,
128
+ verbose=verbose)
129
+ chrom_df_b = chrom_df_x
130
+ sumstats = sumstats.rename(columns={"i":"i_x"})
131
+ add_x_unique = list(sumstats["_ADD"].unique())
132
+
133
+ ## determine grouping methods for Y
134
+ ## assign i for Y group
135
+ sumstats, chrom_df_y = _quick_assign_i_with_rank(sumstats,
136
+ chrpad=ychrpad,
137
+ use_rank=use_rank,
138
+ chrom=ref_chrom,
139
+ pos=ref_pos,
140
+ verbose=verbose)
141
+
142
+ sumstats = sumstats.rename(columns={"i":"i_y"})
143
+ add_y_unique = list(sumstats["_ADD"].unique())
144
+
145
+ if add_b == True:
146
+ sumstats["i"] = sumstats["i_x"]
147
+ fig,log = mqqplot(sumstats,
148
+ chrom=chrom,
149
+ pos=pos,
150
+ p=p,
151
+ mlog10p=mlog10p,
152
+ snpid=snpid,
153
+ scaled=scaled,
154
+ log=log,
155
+ mode="b",
156
+ figax=(fig,ax2),
157
+ _chrom_df_for_i = chrom_df_b,
158
+ _invert=False,
159
+ _if_quick_qc=False,
160
+ **mqq_kwargs
161
+ )
162
+ ##
163
+ #min_xy = min(min(sumstats["i_x"]),min(sumstats["i_y"]))
164
+ #max_xy = max(max(sumstats["i_x"]),max(sumstats["i_y"]))
165
+
166
+ ## determine color
167
+
168
+ ## determine dot size
169
+
170
+ ## plot
171
+ legend = True
172
+ style=None
173
+ linewidth=0
174
+ edgecolor="black"
175
+
176
+ palette = sns.color_palette(colors,n_colors=sumstats[group].nunique())
177
+
178
+ #for index,g in enumerate(sumstats[group].unique()):
179
+ #
180
+ # palette = sns.color_palette("dark:{}".format(colors[index]), as_cmap=True)
181
+ #
182
+ # plot = sns.scatterplot(data=sumstats.loc[sumstats[group]==g,:], x='i_x', y='i_y',
183
+ # hue="scaled_P",
184
+ # palette=palette,
185
+ # size="scaled_P",
186
+ # alpha=alpha,
187
+ # sizes=sizes,
188
+ # legend=legend,
189
+ # style=style,
190
+ # linewidth=linewidth,
191
+ # edgecolor = edgecolor,
192
+ # zorder=2,
193
+ # ax=ax1)
194
+
195
+ plot = sns.scatterplot(data=sumstats, x='i_x', y='i_y',
196
+ hue=group,
197
+ palette=palette,
198
+ size="scaled_P",
199
+ alpha=alpha,
200
+ sizes=sizes,
201
+ legend=legend,
202
+ style=style,
203
+ linewidth=linewidth,
204
+ edgecolor = edgecolor,
205
+ zorder=2,
206
+ ax=ax1)
207
+
208
+ handles, labels = ax1.get_legend_handles_labels()
209
+ new_labels = []
210
+ ncol = len(labels)
211
+ for i in labels:
212
+ if i==group:
213
+ new_labels.append("Group")
214
+ elif i=="scaled_P":
215
+ new_labels.append("$-log_{10}(P)$")
216
+ else:
217
+ new_labels.append(i)
218
+
219
+ ax1.legend(labels = new_labels, handles=handles, loc="lower center", bbox_to_anchor=(.45, -0.17),
220
+ ncol=ncol, scatterpoints=2, title=None, frameon=False)
221
+
222
+ ## add vertical line
223
+ for i in add_x_unique:
224
+ ax1.axvline(x = i+0.5, linewidth = grid_linewidth,color=grid_linecolor,zorder=1000 )
225
+ for i in add_y_unique:
226
+ ax1.axhline(y = i+0.5, linewidth = grid_linewidth,color=grid_linecolor,zorder=1000 )
227
+
228
+
229
+ ## add X tick label
230
+ ax1 = _process_xtick(ax1, chrom_df_x, xtick_chr_dict, fontsize, font_family, log=log,verbose=True)
231
+ ## add Y tick label
232
+ ax1 = _process_ytick(ax1, chrom_df_y, ytick_chr_dict, fontsize, font_family, log=log,verbose=True)
233
+
234
+ ## set x y lim
235
+ ax1.set_ylim([0.5,sumstats["i_y"].max()+1])
236
+ ax1.set_xlim([0.5,sumstats["i_x"].max()+1])
237
+
238
+ ## set x y label
239
+
240
+ xlabel = "pQTL position"
241
+ ax1.set_xlabel(xlabel,fontsize=fontsize,family=font_family)
242
+ ylabel = "location of the gene encoding the target protein"
243
+ ax1.set_ylabel(ylabel,fontsize=fontsize,family=font_family)
244
+
245
+ save_figure(fig = fig, save = save, keyword="gwheatmap", save_args=save_kwargs, log = log, verbose=verbose)
246
+
247
+ return fig, log
248
+
249
+ ################################################################################################################
250
+ def _process_xtick(ax1, chrom_df, xtick_chr_dict, fontsize, font_family, log=Log(),verbose=True):
251
+ log.write(" -Processing X ticks...",verbose=verbose)
252
+ ax1.set_xticks(chrom_df.astype("float64"))
253
+ ax1.set_xticklabels(chrom_df.index.astype("Int64").map(xtick_chr_dict),fontsize=fontsize,family=font_family)
254
+ return ax1
255
+
256
+ def _process_ytick(ax1, chrom_df, ytick_chr_dict, fontsize, font_family, log=Log(),verbose=True):
257
+ log.write(" -Processing Y ticks...",verbose=verbose)
258
+ ax1.set_yticks(chrom_df.astype("float64"))
259
+ ax1.set_yticklabels(chrom_df.index.astype("Int64").map(ytick_chr_dict),fontsize=fontsize,family=font_family)
260
+ return ax1
@@ -95,6 +95,10 @@ def plot_stacked_mqq(objects,
95
95
  if "family" not in title_args.keys():
96
96
  title_args["family"] = "Arial"
97
97
  # create figure and axes ##################################################################################################################
98
+ #
99
+ # subplot_height : subplot height
100
+ # figsize : Width, height in inches
101
+
98
102
  if mode=="r":
99
103
  if len(vcfs)==1:
100
104
  vcfs = vcfs *len(sumstats_list)
@@ -107,14 +111,17 @@ def plot_stacked_mqq(objects,
107
111
  else:
108
112
  height_ratios = [1 for i in range(n_plot_plus_gene_track-1)]+[gene_track_height]
109
113
 
110
- fig_args["figsize"] = [16,subplot_height*n_plot_plus_gene_track]
114
+ if "figsize" not in fig_args.keys():
115
+ fig_args["figsize"] = [16,subplot_height*n_plot_plus_gene_track]
116
+
111
117
  fig, axes = plt.subplots(n_plot_plus_gene_track, 1, sharex=True,
112
118
  gridspec_kw={'height_ratios': height_ratios},
113
119
  **fig_args)
114
120
  plt.subplots_adjust(hspace=region_hspace)
115
121
  elif mode=="m":
116
122
  n_plot = len(sumstats_list)
117
- fig_args["figsize"] = [10,subplot_height*n_plot]
123
+ if "figsize" not in fig_args.keys():
124
+ fig_args["figsize"] = [10,subplot_height*n_plot]
118
125
  fig, axes = plt.subplots(n_plot, 1, sharex=True,
119
126
  gridspec_kw={'height_ratios': [1 for i in range(n_plot)]},
120
127
  **fig_args)
@@ -122,8 +129,8 @@ def plot_stacked_mqq(objects,
122
129
  vcfs = [None for i in range(n_plot)]
123
130
  elif mode=="mqq":
124
131
  n_plot = len(objects)
125
- #
126
- fig_args["figsize"] = [10,subplot_height*n_plot]
132
+ if "figsize" not in fig_args.keys():
133
+ fig_args["figsize"] = [10,subplot_height*n_plot]
127
134
  fig, axes = plt.subplots(n_plot, 2, sharex=True,
128
135
  gridspec_kw={'height_ratios': [1 for i in range(n_plot-1)],
129
136
  'width_ratios':[mqqratio,1]},