gwaslab 3.4.45__py3-none-any.whl → 3.4.47__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of gwaslab might be problematic. Click here for more details.

@@ -23,8 +23,8 @@ from gwaslab.viz_aux_reposition_text import adjust_text_position
23
23
  from gwaslab.viz_aux_annotate_plot import annotate_single
24
24
  from gwaslab.viz_plot_qqplot import _plot_qq
25
25
  from gwaslab.hm_harmonize_sumstats import auto_check_vcf_chr_dict
26
- from gwaslab.viz_plot_regionalplot import _plot_regional
27
- from gwaslab.viz_plot_regionalplot import process_vcf
26
+ from gwaslab.viz_plot_regional2 import _plot_regional
27
+ from gwaslab.viz_plot_regional2 import process_vcf
28
28
  from gwaslab.viz_aux_quickfix import _get_largenumber
29
29
  from gwaslab.viz_aux_quickfix import _quick_fix_p_value
30
30
  from gwaslab.viz_aux_quickfix import _quick_fix_pos
@@ -51,6 +51,9 @@ from gwaslab.bd_common_data import get_number_to_chr
51
51
  from gwaslab.bd_common_data import get_recombination_rate
52
52
  from gwaslab.bd_common_data import get_gtf
53
53
  from gwaslab.g_version import _get_version
54
+ from matplotlib.colors import ListedColormap
55
+ from matplotlib.colors import LinearSegmentedColormap
56
+ from matplotlib.colors import to_hex
54
57
  # 20230202 ######################################################################################################
55
58
 
56
59
  def mqqplot(insumstats,
@@ -98,13 +101,13 @@ def mqqplot(insumstats,
98
101
  region_ld_threshold = None,
99
102
  region_ld_legend = True,
100
103
  region_ld_colors = None,
101
- region_ld_colors1 = None,
102
- region_ld_colors2 = None,
104
+ region_ld_colors_m = None,
103
105
  region_recombination = True,
104
106
  region_protein_coding = True,
105
107
  region_flank_factor = 0.05,
106
108
  region_anno_bbox_args = None,
107
- cbar_title='LD $r^{2}$',
109
+ region_marker_shapes=None,
110
+ cbar_title='LD $r^{2}$ with variant',
108
111
  cbar_fontsize = None,
109
112
  cbar_font_family = None,
110
113
  track_n=4,
@@ -133,6 +136,7 @@ def mqqplot(insumstats,
133
136
  anno_style="right",
134
137
  anno_fixed_arm_length=None,
135
138
  anno_source = "ensembl",
139
+ anno_gtf_path=None,
136
140
  anno_adjust=False,
137
141
  anno_max_iter=100,
138
142
  arm_offset=50,
@@ -237,20 +241,39 @@ def mqqplot(insumstats,
237
241
  anno_args=dict()
238
242
  if colors is None:
239
243
  colors=["#597FBD","#74BAD3"]
240
- if region_ref2 is not None:
241
- region_ref_second = copy.copy(region_ref2),
244
+
245
+ if region is not None:
246
+ if marker_size == (5,20):
247
+ marker_size=(45,65)
248
+
249
+ # make region_ref a list of ref variants
250
+ if pd.api.types.is_list_like(region_ref):
251
+ if len(region_ref) == 0 :
252
+ region_ref.append(None)
253
+ if region_ref_second is not None:
254
+ region_ref.append(region_ref_second)
255
+ else:
256
+ region_ref = [region_ref]
257
+ if region_ref_second is not None:
258
+ region_ref.append(region_ref_second)
259
+ region_ref_index_dic = {value: index for index,value in enumerate(region_ref)}
260
+
261
+ if region_marker_shapes is None:
262
+ # 9 shapes
263
+ region_marker_shapes = ['o', 's','^','D','*','P','X','h','8']
242
264
  if region_grid_line is None:
243
265
  region_grid_line = {"linewidth": 2,"linestyle":"--"}
244
266
  if region_lead_grid_line is None:
245
267
  region_lead_grid_line = {"alpha":0.5,"linewidth" : 2,"linestyle":"--","color":"#FF0000"}
246
268
  if region_ld_threshold is None:
247
269
  region_ld_threshold = [0.2,0.4,0.6,0.8]
270
+
248
271
  if region_ld_colors is None:
249
272
  region_ld_colors = ["#E4E4E4","#020080","#86CEF9","#24FF02","#FDA400","#FF0000","#FF0000"]
250
- if region_ld_colors1 is None:
251
- region_ld_colors1 = ["#E4E4E4","#F8CFCF","#F5A2A5","#F17474","#EB4445","#E51819","#E51819"]
252
- if region_ld_colors2 is None:
253
- region_ld_colors2 = ["#E4E4E4","#D8E2F2","#AFCBE3","#86B3D4","#5D98C4","#367EB7","#367EB7"]
273
+
274
+ # 7 colors
275
+ region_ld_colors_m = ["#E51819","#367EB7","green","#F07818","#AD5691","yellow","purple"]
276
+
254
277
  if region_title_args is None:
255
278
  region_title_args = {"size":10}
256
279
  if cbar_fontsize is None:
@@ -352,6 +375,7 @@ def mqqplot(insumstats,
352
375
  lines_to_plot = -np.log10(lines_to_plot)
353
376
 
354
377
  vcf_chr_dict = auto_check_vcf_chr_dict(vcf_path, vcf_chr_dict, verbose, log)
378
+
355
379
 
356
380
  # Plotting mode selection : layout ####################################################################
357
381
  # ax1 : manhattanplot / brisbane plot
@@ -434,7 +458,7 @@ def mqqplot(insumstats,
434
458
  region_chr = region[0]
435
459
  region_start = region[1]
436
460
  region_end = region[2]
437
- marker_size=(25,45)
461
+
438
462
  log.write(" -Extract SNPs in region : chr{}:{}-{}...".format(region_chr, region[1], region[2]),verbose=verbose)
439
463
 
440
464
  in_region_snp = (sumstats[chrom]==region_chr) & (sumstats[pos]<region_end) & (sumstats[pos]>region_start)
@@ -547,7 +571,6 @@ def mqqplot(insumstats,
547
571
  vcf_path=vcf_path,
548
572
  region=region,
549
573
  region_ref=region_ref,
550
- region_ref_second=region_ref_second,
551
574
  log=log ,
552
575
  pos=pos,
553
576
  ea=ea,
@@ -588,17 +611,27 @@ def mqqplot(insumstats,
588
611
  if vcf_path is not None:
589
612
  legend=None
590
613
  linewidth=1
591
- palette = { i:region_ld_colors[i] for i in range(len(region_ld_colors))}
592
- if region_ref_second is not None:
593
- palette = {}
594
- for i in range(len(region_ld_colors)):
595
- palette[i]=region_ld_colors1[i]
596
- palette[100+i]=region_ld_colors2[i]
614
+ if len(region_ref) == 1:
615
+ palette = {100+i:region_ld_colors[i] for i in range(len(region_ld_colors))}
616
+ else:
617
+ palette = {}
618
+ region_color_maps = []
619
+ for group_index, colorgroup in enumerate(region_ld_colors_m):
620
+ color_map_len = len(region_ld_threshold)+2 # default 6
621
+ rgba = LinearSegmentedColormap.from_list("custom", ["white",colorgroup], color_map_len)(range(1,color_map_len)) # skip white
622
+ output_hex_colors=[]
623
+ for i in range(len(rgba)):
624
+ output_hex_colors.append(to_hex(rgba[i]))
625
+ # 1 + 5 + 1
626
+ region_ld_colors_single = [region_ld_colors[0]] + output_hex_colors + [output_hex_colors[-1]]
627
+ region_color_maps.append(region_ld_colors_single)
628
+ # gradient colors
629
+ for i, hex_colors in enumerate(region_color_maps):
630
+ for j, hex_color in enumerate(hex_colors):
631
+ palette[(i+1)*100 + j ] = hex_color
632
+
597
633
  edgecolor="none"
598
- if sumstats["SHAPE"].nunique() >1:
599
- scatter_args["markers"]=['o', 's']
600
- else:
601
- scatter_args["markers"]=['o']
634
+ scatter_args["markers"]= region_marker_shapes[:len(region_ref)]
602
635
  style="SHAPE"
603
636
 
604
637
 
@@ -648,6 +681,7 @@ def mqqplot(insumstats,
648
681
 
649
682
  ## if not highlight
650
683
  else:
684
+ ## density plot
651
685
  if density_color == True:
652
686
  hue = "DENSITY_hue"
653
687
  s = "DENSITY"
@@ -675,6 +709,7 @@ def mqqplot(insumstats,
675
709
  linewidth=linewidth,
676
710
  zorder=2,ax=ax1,edgecolor=edgecolor,**scatter_args)
677
711
  else:
712
+ # major / regional
678
713
  s = "s"
679
714
  hue = 'chr_hue'
680
715
  hue_norm=None
@@ -719,7 +754,7 @@ def mqqplot(insumstats,
719
754
  # if regional plot : pinpoint lead , add color bar ##################################################
720
755
  if (region is not None) and ("r" in mode):
721
756
 
722
- ax1, ax3, ax4, cbar, lead_snp_i, lead_snp_i2 =_plot_regional(
757
+ ax1, ax3, ax4, cbar, lead_snp_is, lead_snp_is_color =_plot_regional(
723
758
  sumstats=sumstats,
724
759
  fig=fig,
725
760
  ax1=ax1,
@@ -743,8 +778,8 @@ def mqqplot(insumstats,
743
778
  rr_ylabel=rr_ylabel,
744
779
  mode=mode,
745
780
  region_step = region_step,
746
- region_ref=region_ref,
747
- region_ref_second=region_ref_second,
781
+ region_ref = region_ref,
782
+ region_ref_index_dic = region_ref_index_dic,
748
783
  region_grid = region_grid,
749
784
  region_grid_line = region_grid_line,
750
785
  region_lead_grid = region_lead_grid,
@@ -755,8 +790,8 @@ def mqqplot(insumstats,
755
790
  region_ld_legend = region_ld_legend,
756
791
  region_ld_threshold = region_ld_threshold,
757
792
  region_ld_colors = region_ld_colors,
758
- region_ld_colors1=region_ld_colors1,
759
- region_ld_colors2=region_ld_colors2,
793
+ palette = palette,
794
+ region_marker_shapes = region_marker_shapes,
760
795
  region_recombination = region_recombination,
761
796
  region_protein_coding=region_protein_coding,
762
797
  region_flank_factor =region_flank_factor,
@@ -770,8 +805,8 @@ def mqqplot(insumstats,
770
805
  )
771
806
 
772
807
  else:
773
- lead_snp_i= None
774
- lead_snp_i2=None
808
+ lead_snp_is =[]
809
+ lead_snp_is_color = []
775
810
 
776
811
  log.write("Finished creating MQQ plot successfully!",verbose=verbose)
777
812
 
@@ -816,6 +851,7 @@ def mqqplot(insumstats,
816
851
  log=log,
817
852
  build=build,
818
853
  source=anno_source,
854
+ gtf_path=anno_gtf_path,
819
855
  verbose=verbose).rename(columns={"GENE":"Annotation"})
820
856
  log.write("Finished extracting variants for annotation...",verbose=verbose)
821
857
 
@@ -882,8 +918,8 @@ def mqqplot(insumstats,
882
918
  # regional plot cbar
883
919
  if cbar is not None:
884
920
  cbar = _process_cbar(cbar,
885
- cbar_fontsize=fontsize,
886
- cbar_font_family=font_family,
921
+ cbar_fontsize=cbar_fontsize,
922
+ cbar_font_family=cbar_font_family,
887
923
  cbar_title=cbar_title,
888
924
  log=log,
889
925
  verbose=verbose)
@@ -1019,7 +1055,7 @@ def mqqplot(insumstats,
1019
1055
  garbage_collect.collect()
1020
1056
  # Return matplotlib figure object #######################################################################################
1021
1057
  if _get_region_lead==True:
1022
- return fig, log, lead_snp_i, lead_snp_i2
1058
+ return fig, log, lead_snp_is, lead_snp_is_color
1023
1059
 
1024
1060
  log.write("Finished creating plot successfully!",verbose=verbose)
1025
1061
  return fig, log
@@ -1269,15 +1305,19 @@ def _process_line(ax1, sig_line, suggestive_sig_line, additional_line, lines_to_
1269
1305
 
1270
1306
  def _process_cbar(cbar, cbar_fontsize, cbar_font_family, cbar_title, log=Log(),verbose=True):
1271
1307
  log.write(" -Processing color bar...",verbose=verbose)
1272
- if type(cbar) == list:
1273
- for cbar_single in cbar:
1274
- cbar_yticklabels = cbar_single.ax.get_yticklabels()
1275
- cbar_single.ax.set_yticklabels(cbar_yticklabels, fontsize=cbar_fontsize, family=cbar_font_family )
1276
- cbar_single.ax.set_title(cbar_title, fontsize=cbar_fontsize, family=cbar_font_family, loc="center",y=-0.2 )
1277
- else:
1278
- cbar_yticklabels = cbar.ax.get_yticklabels()
1279
- cbar.ax.set_yticklabels(cbar_yticklabels, fontsize=cbar_fontsize, family=cbar_font_family )
1280
- cbar.ax.set_title(cbar_title, fontsize=cbar_fontsize, family=cbar_font_family, loc="center",y=-0.2 )
1308
+ #if type(cbar) == list:
1309
+ # for cbar_single in cbar:
1310
+ # cbar_yticklabels = cbar_single.ax.get_yticklabels()
1311
+ # cbar_single.ax.set_yticklabels(cbar_yticklabels, fontsize=cbar_fontsize, family=cbar_font_family )
1312
+ # cbar_single.ax.set_title(cbar_title, fontsize=cbar_fontsize, family=cbar_font_family, loc="center",y=-0.2 )
1313
+ #else:
1314
+
1315
+ cbar_yticklabels = cbar.get_yticklabels()
1316
+ cbar.set_yticklabels(cbar_yticklabels, fontsize=cbar_fontsize, family=cbar_font_family )
1317
+ cbar_xticklabels = cbar.get_xticklabels()
1318
+ cbar.set_xticklabels(cbar_xticklabels, fontsize=cbar_fontsize, family=cbar_font_family )
1319
+
1320
+ cbar.set_title(cbar_title, fontsize=cbar_fontsize, family=cbar_font_family, loc="center", y=1.00 )
1281
1321
  return cbar
1282
1322
 
1283
1323
  def _process_xtick(ax1, chrom_df, xtick_chr_dict, fontsize, font_family, log=Log(),verbose=True):