gwaslab 3.4.46__py3-none-any.whl → 3.4.48__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of gwaslab might be problematic. Click here for more details.

@@ -0,0 +1,838 @@
1
+ import pandas as pd
2
+ import matplotlib.pyplot as plt
3
+ import matplotlib.ticker as ticker
4
+ import matplotlib.patches as patches
5
+ import seaborn as sns
6
+ import numpy as np
7
+ import copy
8
+ import re
9
+ import scipy as sp
10
+ from pyensembl import EnsemblRelease
11
+ from allel import GenotypeArray
12
+ from allel import read_vcf
13
+ from allel import rogers_huff_r_between
14
+ import matplotlib as mpl
15
+ from scipy import stats
16
+ from mpl_toolkits.axes_grid1.inset_locator import inset_axes
17
+ from mpl_toolkits.axes_grid1.inset_locator import mark_inset
18
+ from adjustText import adjust_text
19
+ from gtfparse import read_gtf
20
+ from gwaslab.g_Log import Log
21
+ from gwaslab.bd_common_data import get_chr_to_number
22
+ from gwaslab.bd_common_data import get_number_to_chr
23
+ from gwaslab.bd_common_data import get_recombination_rate
24
+ from gwaslab.bd_common_data import get_gtf
25
+ from matplotlib.colors import ListedColormap
26
+ from matplotlib.colors import LinearSegmentedColormap
27
+ import matplotlib.colors
28
+ from matplotlib.colors import Normalize
29
+ from matplotlib.patches import Rectangle
30
+
31
def _plot_regional(
    sumstats,
    fig,
    ax1,
    ax3,
    region,
    vcf_path,
    marker_size,
    fontsize,
    build,
    chrom_df,
    xtick_chr_dict,
    cut_line_color,
    vcf_chr_dict = None,
    gtf_path="default",
    gtf_chr_dict = get_number_to_chr(),
    gtf_gene_name=None,
    rr_path="default",
    rr_header_dict=None,
    rr_chr_dict = get_number_to_chr(),
    rr_lim = (0,100),
    rr_ylabel = True,
    rr_title=None,
    region_ld_legend=True,
    region_title=None,
    mode="mqq",
    region_step = 21,
    region_ref=None,
    region_ref_index_dic = None,
    #region_ref_second=None,
    region_grid = False,
    region_grid_line = {"linewidth": 2,"linestyle":"--"},
    region_lead_grid = True,
    region_lead_grid_line = {"alpha":0.5,"linewidth" : 2,"linestyle":"--","color":"#FF0000"},
    region_title_args = None,
    region_hspace=0.02,
    region_ld_threshold = [0.2,0.4,0.6,0.8],
    region_ld_colors = ["#E4E4E4","#020080","#86CEF9","#24FF02","#FDA400","#FF0000","#FF0000"],
    region_marker_shapes=None,
    palette=None,
    region_recombination = True,
    region_protein_coding=True,
    region_flank_factor = 0.05,
    track_font_family="Arial",
    taf=[4,0,0.95,1,1],
    # track_n, track_n_offset,font_ratio,exon_ratio,text_offset
    tabix=None,
    chrom="CHR",
    pos="POS",
    verbose=True,
    log=Log()
):
    """Decorate a regional plot: lead markers, LD legend, recombination, gene track.

    The x axis of ``ax1``/``ax3`` is the ``"i"`` column of ``sumstats`` (there
    is a gap between ``i`` and the genomic position), so every genomic
    coordinate is shifted by ``gene_track_start_i`` before plotting.

    ``region`` is indexed as ``region[0]`` (chromosome), ``region[1]`` (left
    bound) and ``region[2]`` (right bound).  When ``region`` is ``None`` no
    drawing is done.  ``region_ref_index_dic`` gains an entry for every
    reference given as ``None`` that is resolved to a variant name here.

    Returns
    -------
    tuple
        ``(ax1, ax3, ax4, cbar, lead_snp_is, lead_snp_is_colors)``.  ``ax4``
        is ``None`` when no recombination track was drawn, ``cbar`` is ``None``
        when no LD legend was drawn, and both lists are empty when ``region``
        is ``None``.
    """
    # Defaults for everything the later sections and the return statement read,
    # so that disabled features cannot leave a name unbound.
    ax4 = None
    cbar = None
    lead_ids = []
    lead_snp_ys = []
    lead_snp_is = []
    lead_snp_is_colors = []
    texts_to_adjust_middle = []

    # if regional plot : pinpoint lead , add color bar ##########################
    if region is not None:
        # pinpoint lead
        for index, region_ref_single in enumerate(region_ref):
            ax1, lead_id_single = _pinpoint_lead(sumstats = sumstats,
                                                 ax1 = ax1,
                                                 region_ref=region_ref_single,
                                                 lead_color = palette[(index+1)*100 + len(region_ld_threshold)+2],
                                                 marker_size= marker_size,
                                                 region_marker_shapes=region_marker_shapes,
                                                 log=log,verbose=verbose)
            # None entries are kept so lead_ids stays aligned with region_ref
            lead_ids.append(lead_id_single)

        # update region_ref to variant rsID or variantID / skip NAs
        new_region_ref = []
        for lead_id_single, region_ref_single in zip(lead_ids, region_ref):
            if lead_id_single is None or region_ref_single is not None:
                new_region_ref.append(region_ref_single)
                continue
            # the reference was unspecified: name it after the variant found
            if "SNPID" in sumstats.columns:
                new_name = sumstats.loc[lead_id_single,"SNPID"]
            elif "rsID" in sumstats.columns:
                new_name = sumstats.loc[lead_id_single,"rsID"]
            else:
                new_name = "chr{}:{}".format(sumstats.loc[lead_id_single,"CHR"] , sumstats.loc[lead_id_single,"POS"])
            new_region_ref.append(new_name)
            region_ref_index_dic[new_name] = region_ref_index_dic[region_ref_single]
        region_ref = new_region_ref

        if (vcf_path is not None) and region_ld_legend:
            ## plot cbar
            ax1, cbar = _add_ld_legend(sumstats=sumstats,
                                       ax1=ax1,
                                       region_ref=region_ref,
                                       region_ld_threshold=region_ld_threshold,
                                       region_ref_index_dic=region_ref_index_dic,
                                       palette=palette)

        if region_title is not None:
            ax1 = _add_region_title(region_title, ax1=ax1,region_title_args=region_title_args )

    ## recombination rate #######################################################
    if (region is not None) and (region_recombination is True):
        ax4 = _plot_recombination_rate(sumstats = sumstats,
                                       pos =pos,
                                       region= region,
                                       ax1 = ax1,
                                       rr_path =rr_path,
                                       rr_chr_dict = rr_chr_dict,
                                       rr_header_dict =rr_header_dict,
                                       build= build,
                                       rr_lim=rr_lim,
                                       rr_ylabel=rr_ylabel)

    ## regional plot : gene track ###############################################
    if region is not None:
        # calculate offset
        most_left_snp = sumstats["i"].idxmin()

        # distance between leftmost variant position to region left bound
        gene_track_offset = sumstats.loc[most_left_snp,pos] - region[1]

        # rebase i to region[1] : the i value when POS=0
        gene_track_start_i = sumstats.loc[most_left_snp,"i"] - gene_track_offset - region[1]

        for i,lead_id_single in enumerate(lead_ids):
            if lead_id_single is not None:
                lead_snp_ys.append(sumstats.loc[lead_id_single,"scaled_P"] )
                lead_snp_is.append(sumstats.loc[lead_id_single,"i"])
                lead_color = palette[(region_ref_index_dic[region_ref[i]]+1)*100 + len(region_ld_threshold) +1] # consistent color
                lead_snp_is_colors.append(lead_color)

        if gtf_path is not None:
            # load gtf
            ax3, texts_to_adjust_middle =_plot_gene_track(
                ax3=ax3,
                fig=fig,
                gtf_path=gtf_path,
                region=region,
                region_flank_factor=region_flank_factor,
                region_protein_coding=region_protein_coding,
                region_lead_grid=region_lead_grid,
                region_lead_grid_line=region_lead_grid_line,
                lead_snp_is=lead_snp_is,
                gene_track_start_i=gene_track_start_i,
                gtf_chr_dict=gtf_chr_dict,
                gtf_gene_name=gtf_gene_name,
                track_font_family=track_font_family,
                taf=taf,
                build=build,
                verbose=verbose,
                log=log)

        ## regional plot - set X tick
        region_ticks = list(map('{:.3f}'.format,np.linspace(region[1], region[2], num=region_step).astype("int")/1000000))
        tick_positions = np.linspace(gene_track_start_i+region[1], gene_track_start_i+region[2], num=region_step)

        if "r" in mode:
            # set x ticks for gene track
            if gtf_path is not None:
                ax3.set_xticks(tick_positions)
                ax3.set_xticklabels(region_ticks,rotation=45,fontsize=fontsize,family="sans-serif")

            if region_grid==True:
                for tick_position in tick_positions:
                    ax1.axvline(x=tick_position, color=cut_line_color,zorder=1,**region_grid_line)
                    ax3.axvline(x=tick_position, color=cut_line_color,zorder=1,**region_grid_line)

            if region_lead_grid==True:
                for lead_snp_i, lead_snp_y, lead_snp_is_color in zip(lead_snp_is, lead_snp_ys , lead_snp_is_colors):
                    # Build a per-line style instead of writing "color" into
                    # region_lead_grid_line: that dict may be the shared default
                    # argument or an object owned by the caller.
                    lead_line_style = dict(region_lead_grid_line)
                    lead_line_style["color"] = lead_snp_is_color
                    ax1.plot([lead_snp_i,lead_snp_i],[0,lead_snp_y], zorder=1,**lead_line_style)
                    ax3.axvline(x=lead_snp_i, zorder=2,**lead_line_style)
        else:
            # set x ticks m plot
            ax1.set_xticks(tick_positions)
            ax1.set_xticklabels(region_ticks,rotation=45,fontsize=fontsize,family="sans-serif")

        ax1.set_xlim([gene_track_start_i+region[1], gene_track_start_i+region[2]])

    # gene track (ax3) text adjustment
    if (gtf_path is not None ) and ("r" in mode):
        if len(texts_to_adjust_middle)>0:
            adjust_text(texts_to_adjust_middle,
                        autoalign=False,
                        only_move={'points':'x', 'text':'x', 'objects':'x'},
                        ax=ax3,
                        precision=0,
                        force_text=(0.1,0),
                        expand_text=(1, 1),
                        expand_objects=(1,1),
                        expand_points=(1,1),
                        va="center",
                        ha='center',
                        avoid_points=False,
                        lim =1000)

    return ax1, ax3, ax4, cbar, lead_snp_is, lead_snp_is_colors
241
+
242
+ # + ###########################################################################################################################################################################
243
+ def _get_lead_id(sumstats=None, region_ref=None, log=None, verbose=True):
244
+ region_ref_to_check = copy.copy(region_ref)
245
+ try:
246
+ if len(region_ref_to_check)>0 and type(region_ref_to_check) is not str:
247
+ region_ref_to_check = region_ref_to_check[0]
248
+ except:
249
+ pass
250
+
251
+ lead_id=None
252
+
253
+ if "rsID" in sumstats.columns:
254
+ lead_id = sumstats.index[sumstats["rsID"] == region_ref_to_check].to_list()
255
+
256
+ if lead_id is None and "SNPID" in sumstats.columns:
257
+ lead_id = sumstats.index[sumstats["SNPID"] == region_ref_to_check].to_list()
258
+
259
+ if type(lead_id) is list:
260
+ if len(lead_id)>0:
261
+ lead_id = int(lead_id[0])
262
+
263
+ if region_ref_to_check is not None:
264
+ if type(lead_id) is list:
265
+ if len(lead_id)==0 :
266
+ #try:
267
+ matched_snpid = re.match("(chr)?[0-9]+:[0-9]+:[ATCG]+:[ATCG]+", region_ref_to_check, re.IGNORECASE)
268
+ if matched_snpid is None:
269
+ pass
270
+ else:
271
+ lead_snpid = matched_snpid.group(0).split(":")
272
+ if len(lead_snpid)==4:
273
+ lead_chr= int(lead_snpid[0])
274
+ lead_pos= int(lead_snpid[1])
275
+ lead_ea= lead_snpid[2]
276
+ lead_nea= lead_snpid[3]
277
+ chrpos_match = (sumstats["CHR"] == lead_chr) & (sumstats["POS"] == lead_pos)
278
+ eanea_match = ((sumstats["EA"] == lead_ea) & (sumstats["NEA"] == lead_nea)) | ((sumstats["EA"] == lead_nea) & (sumstats["NEA"] == lead_ea))
279
+ if "rsID" in sumstats.columns:
280
+ lead_id = sumstats.index[chrpos_match&eanea_match].to_list()
281
+ if "SNPID" in sumstats.columns:
282
+ lead_id = sumstats.index[chrpos_match&eanea_match].to_list()
283
+ if type(lead_id) is list:
284
+ if len(lead_id)>0:
285
+ lead_id = int(lead_id[0])
286
+ log.warning("Trying matching variant {} using CHR:POS:EA:NEA to {}... ".format(region_ref_to_check,lead_id))
287
+
288
+ if type(lead_id) is list:
289
+ if len(lead_id)==0 :
290
+ log.warning("Extracting variant: {} not found in sumstats.. Skipping..".format(region_ref_to_check))
291
+ #lead_id = sumstats["scaled_P"].idxmax()
292
+ lead_id = None
293
+ return lead_id
294
+ else:
295
+ log.write(" -Reference variant ID: {} - {}".format(region_ref_to_check, lead_id), verbose=verbose)
296
+
297
+ if lead_id is None:
298
+ log.write(" -Extracting lead variant...", verbose=verbose)
299
+ lead_id = sumstats["scaled_P"].idxmax()
300
+
301
+ return lead_id
302
+
303
+ def _pinpoint_lead(sumstats,ax1,region_ref, lead_color, marker_size, log, verbose, region_marker_shapes):
304
+
305
+ if region_ref is None:
306
+ log.write(" -Extracting lead variant..." , verbose=verbose)
307
+ lead_id = sumstats["scaled_P"].idxmax()
308
+ else:
309
+ lead_id = _get_lead_id(sumstats, region_ref, log, verbose)
310
+
311
+ if lead_id is not None:
312
+ ax1.scatter(sumstats.loc[lead_id,"i"],sumstats.loc[lead_id,"scaled_P"],
313
+ color=lead_color,
314
+ zorder=3,
315
+ marker= region_marker_shapes[sumstats.loc[lead_id,"SHAPE"]-1],
316
+ s=marker_size[1]+2,
317
+ edgecolor="black")
318
+
319
+ return ax1, lead_id
320
+ # -############################################################################################################################################################################
321
+ def _add_region_title(region_title, ax1,region_title_args):
322
+ ax1.text(0.015,0.97, region_title, transform=ax1.transAxes, va="top", ha="left", region_ref=None, **region_title_args )
323
+ return ax1
324
+
325
def _add_ld_legend(sumstats, ax1, region_ld_threshold, region_ref,region_ref_index_dic,palette =None, position=1):
    """Draw the LD colour legend as an inset in the upper-right corner of ``ax1``.

    One coloured rectangle is added per (LD interval, reference variant)
    pair; colours are read from ``palette`` at key
    ``(region_ref_index_dic[ref] + 1) * 100 + interval_index + 1``.
    ``sumstats`` and ``position`` are accepted but not used here.

    Returns ``(ax1, cbar)`` where ``cbar`` is the inset axes.
    """
    n_refs = len(region_ref)

    legend_ax = inset_axes(ax1,
                           width="11%",
                           height="{}%".format( 14 + 7 * n_refs),
                           loc='upper right',
                           axes_kwargs={"frameon":True,"facecolor":"white","zorder":999999})

    # interval boundaries: 0, the thresholds, 1
    ld_ticks = [0]+region_ld_threshold+[1]

    for interval_index, (lower, upper) in enumerate(zip(ld_ticks[:-1], ld_ticks[1:])):
        for ref_row, ref_name in enumerate(region_ref):
            hex_color = palette[(region_ref_index_dic[ref_name]+1)*100 + interval_index+1] # consistent color
            swatch = Rectangle((lower, 0.2*ref_row), 0.2, upper-lower,
                               fill = True, color = hex_color , linewidth = 2)
            legend_ax.add_patch(swatch)

    # y axis: one tick per reference variant, labelled with its name
    legend_ax.set_yticks(0.1 + 0.2 *np.arange(0,n_refs), ["{}".format(x) for x in region_ref])
    legend_ax.set_ylim(0,0.2*n_refs)

    # x axis: LD thresholds
    legend_ax.set_xticks(ticks=ld_ticks)
    legend_ax.set_xticklabels([str(i) for i in ld_ticks])
    legend_ax.set_xlim(0,1)

    legend_ax.set_aspect('equal', adjustable='box')
    legend_ax.set_title('LD $r^{2}$ with variant',loc="center",y=-0.2)
    return ax1, legend_ax
363
+
364
+ # -############################################################################################################################################################################
365
+ def _plot_recombination_rate(sumstats,pos, region, ax1, rr_path, rr_chr_dict, rr_header_dict, build,rr_lim,rr_ylabel=True):
366
+ ax4 = ax1.twinx()
367
+ most_left_snp = sumstats["i"].idxmin()
368
+
369
+ # the i value when pos=0
370
+ rc_track_offset = sumstats.loc[most_left_snp,"i"]-sumstats.loc[most_left_snp,pos]
371
+
372
+ if rr_path=="default":
373
+ if rr_chr_dict is not None:
374
+ rr_chr = rr_chr_dict[region[0]]
375
+ else:
376
+ rr_chr = str(region[0])
377
+ rc = get_recombination_rate(chrom=rr_chr,build=build)
378
+ else:
379
+ rc = pd.read_csv(rr_path,sep="\t")
380
+ if rr_header_dict is not None:
381
+ rc = rc.rename(columns=rr_header_dict)
382
+
383
+ rc = rc.loc[(rc["Position(bp)"]<region[2]) & (rc["Position(bp)"]>region[1]),:]
384
+ ax4.plot(rc_track_offset+rc["Position(bp)"],rc["Rate(cM/Mb)"],color="#5858FF",zorder=1)
385
+
386
+ ax1.set_zorder(ax4.get_zorder()+1)
387
+ ax1.patch.set_visible(False)
388
+
389
+ if rr_ylabel:
390
+ ax4.set_ylabel("Recombination rate(cM/Mb)")
391
+ if rr_lim!="max":
392
+ ax4.set_ylim(rr_lim[0],rr_lim[1])
393
+ else:
394
+ ax4.set_ylim(0, 1.05 * rc["Rate(cM/Mb)"].max())
395
+ ax4.spines["top"].set_visible(False)
396
+ ax4.spines["top"].set(zorder=1)
397
+ return ax4
398
+
399
+ # -############################################################################################################################################################################
400
def _plot_gene_track(
    ax3,
    fig,
    gtf_path,
    region,
    region_flank_factor,
    region_protein_coding,
    region_lead_grid,
    region_lead_grid_line,
    lead_snp_is,
    gene_track_start_i,
    gtf_chr_dict,gtf_gene_name,
    track_font_family,
    taf,
    build,
    verbose=True,
    log=Log()):
    """Draw genes and exons of ``region`` on ``ax3``.

    Genes are loaded and assigned to rows ("stacks") by ``process_gtf``.
    Every gene is drawn as a thin line with an italic name label whose arrow
    follows the strand; exons are drawn as thick segments on the same row.
    When ``region_lead_grid`` is ``True``, a gene whose span strictly contains
    one of ``lead_snp_is`` is drawn, together with its exons, in
    ``region_lead_grid_line["color"]`` instead of the default colour.

    ``taf`` is indexed as: ``[0]`` minimum number of rows, ``[1]`` row offset,
    ``[2]`` font ratio, ``[3]`` exon line-width ratio, ``[4]`` label y offset.
    ``gene_track_start_i`` is added to every genomic coordinate to obtain the
    x position.

    Returns
    -------
    tuple
        ``(ax3, texts_to_adjust_middle)`` where the list holds the label
        objects of genes lying entirely inside the region.
    """
    default_color = "#020080"

    # load gtf
    log.write(" -Loading gtf files from:" + gtf_path, verbose=verbose)
    uniq_gene_region,exons = process_gtf( gtf_path = gtf_path ,
                                          region = region,
                                          region_flank_factor = region_flank_factor,
                                          build=build,
                                          region_protein_coding=region_protein_coding,
                                          gtf_chr_dict=gtf_chr_dict,
                                          gtf_gene_name=gtf_gene_name)

    n_uniq_stack = uniq_gene_region["stack"].nunique()
    stack_num_to_plot = max(taf[0],n_uniq_stack)
    ax3.set_ylim((-stack_num_to_plot*2-taf[1]*2,2+taf[1]*2))
    ax3.set_yticks([])

    # Derive font size and line widths from the on-screen height of one data
    # unit so that labels and exons scale with the number of rows.
    pixels_per_point = 72/fig.dpi
    pixels_per_track = np.abs(ax3.transData.transform([0,0])[1] - ax3.transData.transform([0,1])[1])
    font_size_in_pixels= taf[2] * pixels_per_track
    font_size_in_points = font_size_in_pixels * pixels_per_point
    linewidth_in_points= pixels_per_track * pixels_per_point
    log.write(" -plotting gene track..", verbose=verbose)

    texts_to_adjust_middle = []

    # genes that contain a lead variant: names and x spans
    sig_gene_names=[]
    sig_gene_lefts=[]
    sig_gene_rights=[]

    for index,row in uniq_gene_region.iterrows():
        gene_left = gene_track_start_i+row["start"]
        gene_right = gene_track_start_i+row["end"]
        gene_y = row["stack"]*2

        if row["strand"][0]=="+":
            gene_anno = row["name"] + "->"
        else:
            gene_anno = "<-" + row["name"]

        gene_color = default_color
        for lead_snp_i in lead_snp_is:
            if region_lead_grid is True and lead_snp_i > gene_left and lead_snp_i < gene_right:
                gene_color=region_lead_grid_line["color"]
                sig_gene_names.append(row["name"])
                sig_gene_lefts.append(gene_left)
                sig_gene_rights.append(gene_right)

        # plot gene line
        ax3.plot((gene_left,gene_right),
                 (gene_y,gene_y),color=gene_color,linewidth=linewidth_in_points/10)

        # plot gene name
        label_style = dict(y=gene_y+taf[4], s=gene_anno, va="center", color="black",
                           style='italic', size=font_size_in_points, family=track_font_family)
        if row["end"] >= region[2]:
            # gene runs past the right bound: pin the label to the right edge
            ax3.text(x=gene_track_start_i+region[2], ha="right", **label_style)
        elif row["start"] <= region[1] :
            # gene starts before the left bound: pin the label to the left edge
            ax3.text(x=gene_track_start_i+region[1], ha="left", **label_style)
        else:
            texts_to_adjust_middle.append(
                ax3.text(x=(gene_left+gene_right)/2, ha="center", **label_style))

    # plot exons
    sig_gene_name_set = set(sig_gene_names)
    sig_gene_spans = list(zip(sig_gene_lefts, sig_gene_rights))
    for index,row in exons.iterrows():
        exon_left = gene_track_start_i+row["start"]
        exon_right = gene_track_start_i+row["end"]

        # An exon is highlighted when it belongs to ANY highlighted gene; the
        # decision must not depend on the order in which genes were collected.
        if not pd.isnull(row["name"]):
            highlighted = row["name"] in sig_gene_name_set
        else:
            # unnamed exon: fall back to containment in a highlighted gene span
            highlighted = any(exon_left > sig_left and exon_right < sig_right
                              for sig_left, sig_right in sig_gene_spans)
        exon_color = region_lead_grid_line["color"] if highlighted else default_color

        ax3.plot((exon_left,exon_right),
                 (row["stack"]*2,row["stack"]*2),linewidth=linewidth_in_points*taf[3],color=exon_color,solid_capstyle="butt")

    log.write(" -Finished plotting gene track..", verbose=verbose)

    return ax3,texts_to_adjust_middle
506
+
507
+ # -############################################################################################################################################################################
508
+ # Helpers
509
+ # -############################################################################################################################################################################
510
def process_vcf(sumstats,
                vcf_path,
                region,
                region_ref,
                #region_ref_second,
                log,
                verbose,
                pos ,
                nea,
                ea,
                region_ld_threshold,
                vcf_chr_dict,
                tabix):
    """Annotate ``sumstats`` with LD (r squared) to each reference variant.

    Genotypes for ``region`` are read from ``vcf_path`` with ``read_vcf``; the
    result is used as a mapping with the keys ``variants/POS``,
    ``variants/REF``, ``variants/ALT`` and ``calldata/GT``.  Variants are
    matched to the reference by position and, when a position occurs more
    than once in the reference, by alleles.

    Columns written to ``sumstats`` (in place):

    * ``REFINDEX`` - index of the matched reference record (missing if none);
    * per reference ``n``: ``RSQ_n`` (r squared), ``LD_n`` (0 = no data,
      1 = below the first threshold, ``k + 2`` = above threshold ``k``,
      ``len(region_ld_threshold) + 2`` = the reference variant itself) and
      ``LEAD_n`` (1 on the reference variant);
    * combined over references, taking for each variant the reference with
      the largest r squared: ``RSQ``, ``LD`` (``100 * (n + 1) + LD_n``) and
      ``SHAPE`` (``n + 1``).

    Returns the same ``sumstats`` object.
    """
    log.write("Start to load reference genotype...", verbose=verbose)
    log.write(" -reference vcf path : "+ vcf_path, verbose=verbose)

    # load genotype data of the targeted region
    ref_genotype = read_vcf(vcf_path,region=vcf_chr_dict[region[0]]+":"+str(region[1])+"-"+str(region[2]),tabix=tabix)
    if ref_genotype is None:
        log.warning("No data was retrieved. Skipping ...")
        ref_genotype=dict()
        ref_genotype["variants/POS"]=np.array([],dtype="int64")
    log.write(" -Retrieving index...", verbose=verbose)
    log.write(" -Ref variants in the region: {}".format(len(ref_genotype["variants/POS"])), verbose=verbose)

    ref_positions = ref_genotype["variants/POS"]

    def match_variant(x):
        """Return the reference index for ``x`` = (POS, NEA, EA), or None."""
        candidates = np.where(ref_positions == x.iloc[0])[0]
        if len(candidates) == 0:
            # no position match
            return None
        if len(candidates) == 1:
            # single position match: accepted without comparing alleles
            return candidates[0]
        # multiple position matches: compare ref and alt in both orientations
        for j in candidates:
            if x.iloc[1] == ref_genotype["variants/REF"][j]:
                if x.iloc[2] in ref_genotype["variants/ALT"][j]:
                    return j
            elif x.iloc[1] in ref_genotype["variants/ALT"][j]:
                if x.iloc[2] == ref_genotype["variants/REF"][j]:
                    return j
        return None

    log.write(" -Matching variants using POS, NEA, EA ...", verbose=verbose)
    sumstats["REFINDEX"] = sumstats[[pos,nea,ea]].apply(match_variant,axis=1)

    # add LD information for every reference variant
    for ref_n, region_ref_single in enumerate(region_ref):

        rsq = "RSQ_{}".format(ref_n)
        ld_single = "LD_{}".format(ref_n)
        lead = "LEAD_{}".format(ref_n)
        sumstats[lead]= 0

        # get lead variant id and pos
        if region_ref_single is None:
            # if not specified, use lead variant
            lead_id = sumstats["scaled_P"].idxmax()
        else:
            # figure out lead variant
            lead_id = _get_lead_id(sumstats, region_ref_single, log, verbose)

        # set when the reference is absent from sumstats but its position and
        # alleles can be parsed from a CHR:POS:A1:A2 style identifier
        lead_series = None
        if lead_id is None:
            matched_snpid = None
            if isinstance(region_ref_single, str):
                matched_snpid = re.match("(chr)?[0-9]+:[0-9]+:[ATCG]+:[ATCG]+",region_ref_single, re.IGNORECASE)

            if matched_snpid is None:
                # nothing to compute LD against for this reference
                sumstats[rsq] = None
                sumstats[rsq] = sumstats[rsq].astype("float")
                sumstats[ld_single] = 0
                continue

            lead_snpid = matched_snpid.group(0).split(":")[1:]
            lead_pos = int(lead_snpid[0])
            lead_snpid[0]= lead_pos
            lead_series = pd.Series(lead_snpid)
        else:
            lead_pos = sumstats.loc[lead_id,pos]

        # locate the lead variant in the reference
        lead_snp_ref_index = None
        if lead_pos in ref_positions:
            if lead_series is None:
                lead_snp_ref_index = match_variant(sumstats.loc[lead_id,[pos,nea,ea]])
            else:
                log.warning("Computing LD: {} not found in sumstats but found in reference...Still Computing...".format(region_ref_single))
                lead_snp_ref_index = match_variant(lead_series)

        # The position can be present while no record has matching alleles;
        # that case is handled like an absent lead instead of indexing the
        # genotype array with None.
        if lead_snp_ref_index is not None:
            # non-na other snp index
            other_snps_ref_index = sumstats["REFINDEX"].dropna().astype("int").values

            # get genotype
            lead_snp_genotype = GenotypeArray([ref_genotype["calldata/GT"][lead_snp_ref_index]]).to_n_alt()
            try:
                if len(set(lead_snp_genotype[0]))==1:
                    log.warning("The variant is mono-allelic in reference VCF. LD can not be calculated.")
            except Exception:
                # the check is advisory only; never let it abort the LD step
                pass
            other_snp_genotype = GenotypeArray(ref_genotype["calldata/GT"][other_snps_ref_index]).to_n_alt()

            log.write(" -Calculating Rsq...", verbose=verbose)

            if len(other_snp_genotype)>1:
                valid_r2= np.power(rogers_huff_r_between(lead_snp_genotype,other_snp_genotype)[0],2)
            else:
                valid_r2= np.power(rogers_huff_r_between(lead_snp_genotype,other_snp_genotype),2)
            sumstats.loc[~sumstats["REFINDEX"].isna(),rsq] = valid_r2
        else:
            log.write(" -Lead SNP not found in reference...", verbose=verbose)
            sumstats[rsq]=None

        sumstats[rsq] = sumstats[rsq].astype("float")
        sumstats[ld_single] = 0

        for index,ld_threshold in enumerate(region_ld_threshold):
            # No data,LD = 0
            # 0, 0.2 LD = 1
            # 1, 0.4 LD = 2
            # 2, 0.6 LD = 3
            # 3, 0.8 LD = 4
            # 4, 1.0 LD = 5
            # lead LD = 6
            if index==0:
                # any variant with an r squared value starts at level 1
                to_change_color = sumstats[rsq]>-1
                sumstats.loc[to_change_color,ld_single] = 1
            to_change_color = sumstats[rsq]>ld_threshold
            sumstats.loc[to_change_color,ld_single] = index+2

        if lead_series is None:
            sumstats.loc[lead_id,ld_single] = len(region_ld_threshold)+2
            sumstats.loc[lead_id,lead] = 1

    # combine the per-reference columns ########################################
    final_shape_col = "SHAPE"
    final_ld_col = "LD"
    final_rsq_col = "RSQ"

    sumstats[final_ld_col] = 0
    sumstats[final_shape_col] = 1
    sumstats[final_rsq_col] = 0.0

    for i in range(len(region_ref)):
        ld_single = "LD_{}".format(i)
        current_rsq = "RSQ_{}".format(i)
        # rows where this reference gives a larger r squared than any before
        a_ngt_b = sumstats[final_rsq_col] < sumstats[current_rsq]
        #set levels with interval=100
        sumstats.loc[a_ngt_b, final_ld_col] = 100 * (i+1) + sumstats.loc[a_ngt_b, ld_single]
        sumstats.loc[a_ngt_b, final_rsq_col] = sumstats.loc[a_ngt_b, current_rsq]
        sumstats.loc[a_ngt_b, final_shape_col] = i + 1

    log.write("Finished loading reference genotype successfully!", verbose=verbose)
    return sumstats
679
+
680
+ # -############################################################################################################################################################################
681
+
682
def process_gtf(gtf_path,
                region,
                region_flank_factor,
                build,
                region_protein_coding,
                gtf_chr_dict,
                gtf_gene_name):
    """Load gene and exon records overlapping ``region`` and assign plot rows.

    ``gtf_path`` is ``"default"``/``"ensembl"`` or ``"refseq"`` to use
    ``get_gtf``; any other value is passed to ``read_gtf``.  The gene label
    column ``name`` is taken from ``gtf_gene_name`` when given, otherwise from
    ``gene_name`` (ensembl/default) or ``gene_id`` (refseq and user files).
    With ``region_protein_coding`` set to ``True`` only gene/exon records of
    protein-coding genes with a non-empty name are kept.

    Returns
    -------
    tuple
        ``(uniq_gene_region, exons)``: gene records with ``left``/``right``
        (start/end widened by ``region_flank_factor`` times the region width)
        and a ``stack`` column (row number, negated), and exon records with
        the ``stack`` of their gene.
    """
    # chr to string datatype using gtf_chr_dict
    to_query_chrom = gtf_chr_dict[region[0]]
    uses_ensembl = gtf_path =="default" or gtf_path =="ensembl"

    # loading gtf data
    if uses_ensembl:
        gtf = get_gtf(chrom=to_query_chrom, build=build, source="ensembl")
    elif gtf_path =="refseq":
        gtf = get_gtf(chrom=to_query_chrom, build=build, source="refseq")
    else:
        # user-provided gtf
        gtf = read_gtf(gtf_path)
        gtf = gtf.loc[gtf["seqname"]==gtf_chr_dict[region[0]],:]

    # keep records on the chromosome that overlap the region
    overlaps_region = (gtf["seqname"]==to_query_chrom)&(gtf["start"]<region[2])&(gtf["end"]>region[1])
    genes_1mb = gtf.loc[overlaps_region,:].copy()

    # pick the column holding the gene label
    if gtf_gene_name is not None:
        name_column = gtf_gene_name
    elif uses_ensembl:
        name_column = "gene_name"
    else:
        # refseq and user-provided files
        name_column = "gene_id"
    genes_1mb.loc[:,"name"] = genes_1mb[name_column]

    # gene identifier
    genes_1mb.loc[:,"gene"] = genes_1mb["gene_id"]

    # extract protein coding gene
    if region_protein_coding is True:
        is_named_pc_gene = (genes_1mb["feature"]=="gene")& (genes_1mb["gene_biotype"]=="protein_coding") & (genes_1mb["name"]!="")
        pc_genes_1mb_list = genes_1mb.loc[is_named_pc_gene,"name"].values
        genes_1mb = genes_1mb.loc[(genes_1mb["feature"].isin(["exon","gene"])) & (genes_1mb["name"].isin(pc_genes_1mb_list)),:]

    # split into exon and gene records
    exons = genes_1mb.loc[genes_1mb["feature"]=="exon",:].copy()
    uniq_gene_region = genes_1mb.loc[genes_1mb["feature"]=="gene",:].copy()

    # widen every gene by a flank proportional to the region width
    flank = region_flank_factor * (region[2] - region[1])
    uniq_gene_region["left"] = uniq_gene_region["start"]-flank
    uniq_gene_region["right"] = uniq_gene_region["end"]+flank

    # arrange gene track
    genes_by_start = uniq_gene_region.sort_values(["start"])
    stack_dic = assign_stack(genes_by_start.loc[:,["name","left","right"]])

    # map gene to stack and add stack column : minus stack
    uniq_gene_region["stack"] = -uniq_gene_region["name"].map(stack_dic)
    exons.loc[:,"stack"] = -exons.loc[:,"name"].map(stack_dic)

    return uniq_gene_region, exons
769
+
770
+
771
+ # -############################################################################################################################################################################
772
def assign_stack(uniq_gene_region):
    """Assign every gene to the first row ("stack") where it fits.

    Parameters
    ----------
    uniq_gene_region : pandas.DataFrame
        One row per gene with the columns ``name``, ``left`` and ``right``;
        rows are processed in the order given.

    Returns
    -------
    dict
        Maps ``name`` to a 0-based stack number.  A gene goes to the lowest
        stack in which its ``[left, right]`` interval overlaps no interval
        already placed there; otherwise a new stack is opened.  Intervals that
        merely touch or share a boundary count as overlapping, so every gene
        always receives a stack.  If a name occurs more than once, the last
        row determines the value.
    """
    stacks = []     # stacks[k]: (left, right) intervals already placed in row k
    stack_dic = {}  # gene name -> row number

    for _, row in uniq_gene_region.iterrows():
        left, right = row["left"], row["right"]
        for level, intervals in enumerate(stacks):
            # free only if strictly left or strictly right of every interval
            if all(right < placed_left or left > placed_right
                   for placed_left, placed_right in intervals):
                intervals.append((left, right))
                stack_dic[row["name"]] = level
                break
        else:
            # no existing row can take the gene (or there are no rows yet)
            stacks.append([(left, right)])
            stack_dic[row["name"]] = len(stacks) - 1
    return stack_dic
812
+
813
def closest_gene(x,data,chrom="CHR",pos="POS",maxiter=20000,step=50):
    """Find the gene(s) at, or closest to, the position of variant ``x``.

    Parameters
    ----------
    x : mapping
        Provides ``x[chrom]`` (contig) and ``x[pos]`` (position).
    data : object
        Provides ``gene_names_at_locus(contig=..., position=...)`` returning a
        sequence of gene names.
    maxiter, step : int
        The search probes ``position - k*step`` and ``position + k*step`` for
        ``k = 0 .. maxiter-1``.

    Returns
    -------
    tuple
        ``(distance, names)`` with ``names`` joined by commas.  ``distance``
        is 0 when the variant lies in a gene, negative when the gene was found
        at a smaller position and positive at a larger one.  When only one
        side hits, the distance is the multiple of ``step`` at which it was
        found; when both sides hit at the same step, the last interval is
        scanned base by base and the exact distance of the nearer gene is
        returned (the smaller-position side wins ties).  If nothing is found,
        ``(last probed distance, "intergenic")`` is returned.
    """
    contig = x[chrom]
    position = x[pos]

    def genes_at(offset):
        return data.gene_names_at_locus(contig=contig, position=position + offset)

    gene = genes_at(0)
    if len(gene) > 0:
        return 0, ",".join(gene)

    distance = 0  # stays defined even when maxiter <= 0
    for i in range(maxiter):
        distance = i * step
        gene_u = genes_at(-distance)
        gene_d = genes_at(distance)

        if len(gene_u) > 0 and len(gene_d) > 0:
            # Both sides hit within the same step: scan the interval after the
            # previous (empty) probe to decide which side is really nearer.
            previous = max(i - 1, 0) * step
            for j in range(1, step + 1):
                exact = previous + j
                found_u = genes_at(-exact)
                if len(found_u) > 0:
                    return -exact, ",".join(found_u)
                found_d = genes_at(exact)
                if len(found_d) > 0:
                    return exact, ",".join(found_d)
            # not reachable with a consistent ``data``; keep the coarse answer
            return -distance, ",".join(gene_u)
        if len(gene_u) > 0:
            return -distance, ",".join(gene_u)
        if len(gene_d) > 0:
            return distance, ",".join(gene_d)

    return distance, "intergenic"