gwaslab 3.4.46__py3-none-any.whl → 3.4.48__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of gwaslab might be problematic. Click here for more details.

@@ -23,8 +23,8 @@ from gwaslab.viz_aux_reposition_text import adjust_text_position
23
23
  from gwaslab.viz_aux_annotate_plot import annotate_single
24
24
  from gwaslab.viz_plot_qqplot import _plot_qq
25
25
  from gwaslab.hm_harmonize_sumstats import auto_check_vcf_chr_dict
26
- from gwaslab.viz_plot_regionalplot import _plot_regional
27
- from gwaslab.viz_plot_regionalplot import process_vcf
26
+ from gwaslab.viz_plot_regional2 import _plot_regional
27
+ from gwaslab.viz_plot_regional2 import process_vcf
28
28
  from gwaslab.viz_aux_quickfix import _get_largenumber
29
29
  from gwaslab.viz_aux_quickfix import _quick_fix_p_value
30
30
  from gwaslab.viz_aux_quickfix import _quick_fix_pos
@@ -51,6 +51,9 @@ from gwaslab.bd_common_data import get_number_to_chr
51
51
  from gwaslab.bd_common_data import get_recombination_rate
52
52
  from gwaslab.bd_common_data import get_gtf
53
53
  from gwaslab.g_version import _get_version
54
+ from matplotlib.colors import ListedColormap
55
+ from matplotlib.colors import LinearSegmentedColormap
56
+ from matplotlib.colors import to_hex
54
57
  # 20230202 ######################################################################################################
55
58
 
56
59
  def mqqplot(insumstats,
@@ -98,13 +101,13 @@ def mqqplot(insumstats,
98
101
  region_ld_threshold = None,
99
102
  region_ld_legend = True,
100
103
  region_ld_colors = None,
101
- region_ld_colors1 = None,
102
- region_ld_colors2 = None,
104
+ region_ld_colors_m = None,
103
105
  region_recombination = True,
104
106
  region_protein_coding = True,
105
107
  region_flank_factor = 0.05,
106
108
  region_anno_bbox_args = None,
107
- cbar_title='LD $r^{2}$',
109
+ region_marker_shapes=None,
110
+ cbar_title='LD $r^{2}$ with variant',
108
111
  cbar_fontsize = None,
109
112
  cbar_font_family = None,
110
113
  track_n=4,
@@ -138,6 +141,7 @@ def mqqplot(insumstats,
138
141
  anno_max_iter=100,
139
142
  arm_offset=50,
140
143
  arm_scale=1,
144
+ anno_height=1,
141
145
  arm_scale_d=None,
142
146
  cut=0,
143
147
  skip=0,
@@ -177,6 +181,7 @@ def mqqplot(insumstats,
177
181
  xpad=None,
178
182
  xpadl=None,
179
183
  xpadr=None,
184
+ xtight=False,
180
185
  chrpad=0.03,
181
186
  drop_chr_start=False,
182
187
  title =None,
@@ -238,20 +243,39 @@ def mqqplot(insumstats,
238
243
  anno_args=dict()
239
244
  if colors is None:
240
245
  colors=["#597FBD","#74BAD3"]
241
- if region_ref2 is not None:
242
- region_ref_second = copy.copy(region_ref2),
246
+
247
+ if region is not None:
248
+ if marker_size == (5,20):
249
+ marker_size=(45,65)
250
+
251
+ # make region_ref a list of ref variants
252
+ if pd.api.types.is_list_like(region_ref):
253
+ if len(region_ref) == 0 :
254
+ region_ref.append(None)
255
+ if region_ref_second is not None:
256
+ region_ref.append(region_ref_second)
257
+ else:
258
+ region_ref = [region_ref]
259
+ if region_ref_second is not None:
260
+ region_ref.append(region_ref_second)
261
+ region_ref_index_dic = {value: index for index,value in enumerate(region_ref)}
262
+
263
+ if region_marker_shapes is None:
264
+ # 9 shapes
265
+ region_marker_shapes = ['o', 's','^','D','*','P','X','h','8']
243
266
  if region_grid_line is None:
244
267
  region_grid_line = {"linewidth": 2,"linestyle":"--"}
245
268
  if region_lead_grid_line is None:
246
269
  region_lead_grid_line = {"alpha":0.5,"linewidth" : 2,"linestyle":"--","color":"#FF0000"}
247
270
  if region_ld_threshold is None:
248
271
  region_ld_threshold = [0.2,0.4,0.6,0.8]
272
+
249
273
  if region_ld_colors is None:
250
274
  region_ld_colors = ["#E4E4E4","#020080","#86CEF9","#24FF02","#FDA400","#FF0000","#FF0000"]
251
- if region_ld_colors1 is None:
252
- region_ld_colors1 = ["#E4E4E4","#F8CFCF","#F5A2A5","#F17474","#EB4445","#E51819","#E51819"]
253
- if region_ld_colors2 is None:
254
- region_ld_colors2 = ["#E4E4E4","#D8E2F2","#AFCBE3","#86B3D4","#5D98C4","#367EB7","#367EB7"]
275
+
276
+ # 7 colors
277
+ region_ld_colors_m = ["#E51819","#367EB7","green","#F07818","#AD5691","yellow","purple"]
278
+
255
279
  if region_title_args is None:
256
280
  region_title_args = {"size":10}
257
281
  if cbar_fontsize is None:
@@ -353,6 +377,7 @@ def mqqplot(insumstats,
353
377
  lines_to_plot = -np.log10(lines_to_plot)
354
378
 
355
379
  vcf_chr_dict = auto_check_vcf_chr_dict(vcf_path, vcf_chr_dict, verbose, log)
380
+
356
381
 
357
382
  # Plotting mode selection : layout ####################################################################
358
383
  # ax1 : manhattanplot / brisbane plot
@@ -435,7 +460,7 @@ def mqqplot(insumstats,
435
460
  region_chr = region[0]
436
461
  region_start = region[1]
437
462
  region_end = region[2]
438
- marker_size=(25,45)
463
+
439
464
  log.write(" -Extract SNPs in region : chr{}:{}-{}...".format(region_chr, region[1], region[2]),verbose=verbose)
440
465
 
441
466
  in_region_snp = (sumstats[chrom]==region_chr) & (sumstats[pos]<region_end) & (sumstats[pos]>region_start)
@@ -529,7 +554,8 @@ def mqqplot(insumstats,
529
554
  cut_log = cut_log,
530
555
  verbose =verbose,
531
556
  lines_to_plot=lines_to_plot,
532
- log = log)
557
+ log = log
558
+ )
533
559
  except:
534
560
  log.warning("No valid data! Please check the input.")
535
561
  return None
@@ -548,7 +574,6 @@ def mqqplot(insumstats,
548
574
  vcf_path=vcf_path,
549
575
  region=region,
550
576
  region_ref=region_ref,
551
- region_ref_second=region_ref_second,
552
577
  log=log ,
553
578
  pos=pos,
554
579
  ea=ea,
@@ -574,35 +599,48 @@ def mqqplot(insumstats,
574
599
  sumstats.loc[sumstats["scaled_P"]>-np.log10(sig_level_plot),"s"]=4
575
600
  sumstats["chr_hue"]=sumstats[chrom].astype("string")
576
601
 
577
- if vcf_path is not None:
602
+ if "r" in mode:
603
+ if vcf_path is None:
604
+ sumstats["LD"]=100
605
+ sumstats["SHAPE"]=1
578
606
  sumstats["chr_hue"]=sumstats["LD"]
607
+
579
608
  ## default seetings
580
609
 
581
610
  palette = sns.color_palette(colors,n_colors=sumstats[chrom].nunique())
582
-
583
611
 
584
612
  legend = None
585
613
  style=None
586
614
  linewidth=0
587
615
  edgecolor="black"
588
616
  # if regional plot assign colors
589
- if vcf_path is not None:
617
+ if "r" in mode:
618
+ #if vcf_path is not None:
590
619
  legend=None
591
620
  linewidth=1
592
- palette = { i:region_ld_colors[i] for i in range(len(region_ld_colors))}
593
- if region_ref_second is not None:
594
- palette = {}
595
- for i in range(len(region_ld_colors)):
596
- palette[i]=region_ld_colors1[i]
597
- palette[100+i]=region_ld_colors2[i]
621
+ if len(region_ref) == 1:
622
+ palette = {100+i:region_ld_colors[i] for i in range(len(region_ld_colors))}
623
+ else:
624
+ palette = {}
625
+ region_color_maps = []
626
+ for group_index, colorgroup in enumerate(region_ld_colors_m):
627
+ color_map_len = len(region_ld_threshold)+2 # default 6
628
+ rgba = LinearSegmentedColormap.from_list("custom", ["white",colorgroup], color_map_len)(range(1,color_map_len)) # skip white
629
+ output_hex_colors=[]
630
+ for i in range(len(rgba)):
631
+ output_hex_colors.append(to_hex(rgba[i]))
632
+ # 1 + 5 + 1
633
+ region_ld_colors_single = [region_ld_colors[0]] + output_hex_colors + [output_hex_colors[-1]]
634
+ region_color_maps.append(region_ld_colors_single)
635
+ # gradient colors
636
+ for i, hex_colors in enumerate(region_color_maps):
637
+ for j, hex_color in enumerate(hex_colors):
638
+ palette[(i+1)*100 + j ] = hex_color
639
+
598
640
  edgecolor="none"
599
- if sumstats["SHAPE"].nunique() >1:
600
- scatter_args["markers"]=['o', 's']
601
- else:
602
- scatter_args["markers"]=['o']
641
+ scatter_args["markers"]= {(i+1):m for i,m in enumerate(region_marker_shapes[:len(region_ref)])}
603
642
  style="SHAPE"
604
-
605
-
643
+
606
644
 
607
645
  ## if highlight
608
646
  highlight_i = pd.DataFrame()
@@ -649,6 +687,7 @@ def mqqplot(insumstats,
649
687
 
650
688
  ## if not highlight
651
689
  else:
690
+ ## density plot
652
691
  if density_color == True:
653
692
  hue = "DENSITY_hue"
654
693
  s = "DENSITY"
@@ -676,6 +715,7 @@ def mqqplot(insumstats,
676
715
  linewidth=linewidth,
677
716
  zorder=2,ax=ax1,edgecolor=edgecolor,**scatter_args)
678
717
  else:
718
+ # major / regional
679
719
  s = "s"
680
720
  hue = 'chr_hue'
681
721
  hue_norm=None
@@ -720,7 +760,7 @@ def mqqplot(insumstats,
720
760
  # if regional plot : pinpoint lead , add color bar ##################################################
721
761
  if (region is not None) and ("r" in mode):
722
762
 
723
- ax1, ax3, ax4, cbar, lead_snp_i, lead_snp_i2 =_plot_regional(
763
+ ax1, ax3, ax4, cbar, lead_snp_is, lead_snp_is_color =_plot_regional(
724
764
  sumstats=sumstats,
725
765
  fig=fig,
726
766
  ax1=ax1,
@@ -744,8 +784,8 @@ def mqqplot(insumstats,
744
784
  rr_ylabel=rr_ylabel,
745
785
  mode=mode,
746
786
  region_step = region_step,
747
- region_ref=region_ref,
748
- region_ref_second=region_ref_second,
787
+ region_ref = region_ref,
788
+ region_ref_index_dic = region_ref_index_dic,
749
789
  region_grid = region_grid,
750
790
  region_grid_line = region_grid_line,
751
791
  region_lead_grid = region_lead_grid,
@@ -756,8 +796,8 @@ def mqqplot(insumstats,
756
796
  region_ld_legend = region_ld_legend,
757
797
  region_ld_threshold = region_ld_threshold,
758
798
  region_ld_colors = region_ld_colors,
759
- region_ld_colors1=region_ld_colors1,
760
- region_ld_colors2=region_ld_colors2,
799
+ palette = palette,
800
+ region_marker_shapes = region_marker_shapes,
761
801
  region_recombination = region_recombination,
762
802
  region_protein_coding=region_protein_coding,
763
803
  region_flank_factor =region_flank_factor,
@@ -771,8 +811,8 @@ def mqqplot(insumstats,
771
811
  )
772
812
 
773
813
  else:
774
- lead_snp_i= None
775
- lead_snp_i2=None
814
+ lead_snp_is =[]
815
+ lead_snp_is_color = []
776
816
 
777
817
  log.write("Finished creating MQQ plot successfully!",verbose=verbose)
778
818
 
@@ -884,8 +924,8 @@ def mqqplot(insumstats,
884
924
  # regional plot cbar
885
925
  if cbar is not None:
886
926
  cbar = _process_cbar(cbar,
887
- cbar_fontsize=fontsize,
888
- cbar_font_family=font_family,
927
+ cbar_fontsize=cbar_fontsize,
928
+ cbar_font_family=cbar_font_family,
889
929
  cbar_title=cbar_title,
890
930
  log=log,
891
931
  verbose=verbose)
@@ -943,6 +983,7 @@ def mqqplot(insumstats,
943
983
  region=region,
944
984
  region_anno_bbox_args=region_anno_bbox_args,
945
985
  skip=skip,
986
+ anno_height=anno_height,
946
987
  snpid=snpid,
947
988
  chrom=chrom,
948
989
  pos=pos,
@@ -1006,7 +1047,7 @@ def mqqplot(insumstats,
1006
1047
  if "qq" in mode:
1007
1048
  ax2.set_ylim(ylim)
1008
1049
 
1009
- ax1 = _add_pad_to_x_axis(ax1, xpad, xpadl, xpadr, sumstats)
1050
+ ax1 = _add_pad_to_x_axis(ax1, xpad, xpadl, xpadr, sumstats, pos, chrpad, xtight, log = log, verbose=verbose)
1010
1051
 
1011
1052
  # Titles
1012
1053
  if title and anno and len(to_annotate)>0:
@@ -1021,7 +1062,7 @@ def mqqplot(insumstats,
1021
1062
  garbage_collect.collect()
1022
1063
  # Return matplotlib figure object #######################################################################################
1023
1064
  if _get_region_lead==True:
1024
- return fig, log, lead_snp_i, lead_snp_i2
1065
+ return fig, log, lead_snp_is, lead_snp_is_color
1025
1066
 
1026
1067
  log.write("Finished creating plot successfully!",verbose=verbose)
1027
1068
  return fig, log
@@ -1031,20 +1072,34 @@ def mqqplot(insumstats,
1031
1072
 
1032
1073
 
1033
1074
 
1034
- def _add_pad_to_x_axis(ax1, xpad, xpadl, xpadr, sumstats):
1075
+ def _add_pad_to_x_axis(ax1, xpad, xpadl, xpadr, sumstats, pos, chrpad, xtight, log, verbose):
1035
1076
 
1036
- if ax1 is not None:
1037
- xmin, xmax = ax1.get_xlim()
1038
-
1039
- if xpad is not None:
1040
- pad = xpad* sumstats["i"].max()
1041
- ax1.set_xlim([xmin - pad, xmin + pad])
1042
- if xpadl is not None:
1043
- pad = xpadl* sumstats["i"].max()
1044
- ax1.set_xlim([xmin - pad,xmax])
1045
- if xpadr is not None:
1046
- pad = xpadr* sumstats["i"].max()
1047
- ax1.set_xlim([xmin, xmax + pad])
1077
+ if xtight==True:
1078
+ log.write(" -Adjusting X padding on both side : tight mode", verbose=verbose)
1079
+ xmax = sumstats["i"].max()
1080
+ xmin= sumstats["i"].min()
1081
+ ax1.set_xlim([xmin, xmax])
1082
+
1083
+ else:
1084
+ chrpad_to_remove = sumstats[pos].max()*chrpad
1085
+ if ax1 is not None:
1086
+ xmin, xmax = ax1.get_xlim()
1087
+ length = xmax - xmin
1088
+
1089
+ if xpad is not None:
1090
+ log.write(" -Adjusting X padding on both side: {}".format(xpad), verbose=verbose)
1091
+ pad = xpad* length #sumstats["i"].max()
1092
+ ax1.set_xlim([xmin - pad + chrpad_to_remove, xmax + pad - chrpad_to_remove])
1093
+ if xpad is None and xpadl is not None:
1094
+ log.write(" -Adjusting X padding on left side: {}".format(xpadl), verbose=verbose)
1095
+ xmin, xmax = ax1.get_xlim()
1096
+ pad = xpadl*length # sumstats["i"].max()
1097
+ ax1.set_xlim([xmin - pad + chrpad_to_remove ,xmax])
1098
+ if xpad is None and xpadr is not None:
1099
+ log.write(" -Adjusting X padding on right side: {}".format(xpadr), verbose=verbose)
1100
+ xmin, xmax = ax1.get_xlim()
1101
+ pad = xpadr*length # sumstats["i"].max()
1102
+ ax1.set_xlim([xmin, xmax + pad - chrpad_to_remove])
1048
1103
 
1049
1104
  return ax1
1050
1105
 
@@ -1271,15 +1326,19 @@ def _process_line(ax1, sig_line, suggestive_sig_line, additional_line, lines_to_
1271
1326
 
1272
1327
  def _process_cbar(cbar, cbar_fontsize, cbar_font_family, cbar_title, log=Log(),verbose=True):
1273
1328
  log.write(" -Processing color bar...",verbose=verbose)
1274
- if type(cbar) == list:
1275
- for cbar_single in cbar:
1276
- cbar_yticklabels = cbar_single.ax.get_yticklabels()
1277
- cbar_single.ax.set_yticklabels(cbar_yticklabels, fontsize=cbar_fontsize, family=cbar_font_family )
1278
- cbar_single.ax.set_title(cbar_title, fontsize=cbar_fontsize, family=cbar_font_family, loc="center",y=-0.2 )
1279
- else:
1280
- cbar_yticklabels = cbar.ax.get_yticklabels()
1281
- cbar.ax.set_yticklabels(cbar_yticklabels, fontsize=cbar_fontsize, family=cbar_font_family )
1282
- cbar.ax.set_title(cbar_title, fontsize=cbar_fontsize, family=cbar_font_family, loc="center",y=-0.2 )
1329
+ #if type(cbar) == list:
1330
+ # for cbar_single in cbar:
1331
+ # cbar_yticklabels = cbar_single.ax.get_yticklabels()
1332
+ # cbar_single.ax.set_yticklabels(cbar_yticklabels, fontsize=cbar_fontsize, family=cbar_font_family )
1333
+ # cbar_single.ax.set_title(cbar_title, fontsize=cbar_fontsize, family=cbar_font_family, loc="center",y=-0.2 )
1334
+ #else:
1335
+
1336
+ cbar_yticklabels = cbar.get_yticklabels()
1337
+ cbar.set_yticklabels(cbar_yticklabels, fontsize=cbar_fontsize, family=cbar_font_family )
1338
+ cbar_xticklabels = cbar.get_xticklabels()
1339
+ cbar.set_xticklabels(cbar_xticklabels, fontsize=cbar_fontsize, family=cbar_font_family )
1340
+
1341
+ cbar.set_title(cbar_title, fontsize=cbar_fontsize, family=cbar_font_family, loc="center", y=1.00 )
1283
1342
  return cbar
1284
1343
 
1285
1344
  def _process_xtick(ax1, chrom_df, xtick_chr_dict, fontsize, font_family, log=Log(),verbose=True):