gwaslab 3.4.46__py3-none-any.whl → 3.4.47__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of gwaslab might be problematic. Click here for more details.

@@ -0,0 +1,792 @@
1
+ import pandas as pd
2
+ import matplotlib.pyplot as plt
3
+ import matplotlib.ticker as ticker
4
+ import matplotlib.patches as patches
5
+ import seaborn as sns
6
+ import numpy as np
7
+ import copy
8
+ import scipy as sp
9
+ from pyensembl import EnsemblRelease
10
+ from allel import GenotypeArray
11
+ from allel import read_vcf
12
+ from allel import rogers_huff_r_between
13
+ import matplotlib as mpl
14
+ from scipy import stats
15
+ from mpl_toolkits.axes_grid1.inset_locator import inset_axes
16
+ from mpl_toolkits.axes_grid1.inset_locator import mark_inset
17
+ from adjustText import adjust_text
18
+ from gtfparse import read_gtf
19
+ from gwaslab.g_Log import Log
20
+ from gwaslab.bd_common_data import get_chr_to_number
21
+ from gwaslab.bd_common_data import get_number_to_chr
22
+ from gwaslab.bd_common_data import get_recombination_rate
23
+ from gwaslab.bd_common_data import get_gtf
24
+ from matplotlib.colors import ListedColormap
25
+ from matplotlib.colors import LinearSegmentedColormap
26
+ import matplotlib.colors
27
+ from matplotlib.colors import Normalize
28
+ from matplotlib.patches import Rectangle
29
+
30
def _plot_regional(
    sumstats,
    fig,
    ax1,
    ax3,
    region,
    vcf_path,
    marker_size,
    fontsize,
    build,
    chrom_df,
    xtick_chr_dict,
    cut_line_color,
    vcf_chr_dict=None,
    gtf_path="default",
    gtf_chr_dict=get_number_to_chr(),
    gtf_gene_name=None,
    rr_path="default",
    rr_header_dict=None,
    rr_chr_dict=get_number_to_chr(),
    rr_lim=(0, 100),
    rr_ylabel=True,
    rr_title=None,
    region_ld_legend=True,
    region_title=None,
    mode="mqq",
    region_step=21,
    region_ref=None,
    region_ref_index_dic=None,
    # region_ref_second=None,
    region_grid=False,
    region_grid_line={"linewidth": 2, "linestyle": "--"},
    region_lead_grid=True,
    region_lead_grid_line={"alpha": 0.5, "linewidth": 2, "linestyle": "--", "color": "#FF0000"},
    region_title_args=None,
    region_hspace=0.02,
    region_ld_threshold=[0.2, 0.4, 0.6, 0.8],
    region_ld_colors=["#E4E4E4", "#020080", "#86CEF9", "#24FF02", "#FDA400", "#FF0000", "#FF0000"],
    region_marker_shapes=None,
    palette=None,
    region_recombination=True,
    region_protein_coding=True,
    region_flank_factor=0.05,
    track_font_family="Arial",
    taf=[4, 0, 0.95, 1, 1],
    # track_n, track_n_offset, font_ratio, exon_ratio, text_offset
    tabix=None,
    chrom="CHR",
    pos="POS",
    verbose=True,
    log=Log(),
):
    """Decorate a regional association plot.

    Marks the reference (lead) variants, adds the LD legend, the optional
    title, the recombination-rate track, the gene track and the x ticks.

    The x axis uses the plotting coordinate ``sumstats["i"]``; it is offset
    from the genomic position ``sumstats[pos]`` by a constant. ``region`` is
    indexed as ``(chromosome, start, end)``.

    Returns
    -------
    tuple
        ``(ax1, ax3, ax4, cbar, lead_snp_is, lead_snp_is_colors)``.

        - ``ax4`` is the recombination-rate twin axis (None when not drawn).
        - ``cbar`` is the LD legend inset axes (None when not drawn).
        - ``lead_snp_is`` / ``lead_snp_is_colors`` are the x coordinates and
          colours of the reference variants that were found.
    """
    # Defaults so the final return is valid even when no region is given or
    # optional tracks are disabled.
    ax4 = None
    cbar = None
    lead_ids = []
    lead_snp_is = []
    lead_snp_is_colors = []
    texts_to_adjust_middle = []
    # Work on a private copy: a per-lead colour is written into it below and
    # the (mutable) default argument must not carry it over to later calls.
    region_lead_grid_line = dict(region_lead_grid_line)

    # pinpoint reference variants ############################################
    if region is not None:
        new_region_ref = []
        for index, region_ref_single in enumerate(region_ref):
            ax1, lead_id_single = _pinpoint_lead(
                sumstats=sumstats,
                ax1=ax1,
                region_ref=region_ref_single,
                lead_color=palette[(index + 1) * 100 + len(region_ld_threshold) + 2],
                marker_size=marker_size,
                region_marker_shapes=region_marker_shapes,
                log=log,
                verbose=verbose,
            )
            if lead_id_single is None:
                # Not found: drop it from BOTH lists so lead_ids[i] and
                # region_ref[i] keep referring to the same variant.
                continue
            lead_ids.append(lead_id_single)
            if region_ref_single is None:
                # Automatically chosen lead: name it by its variant ID.
                if "rsID" in sumstats.columns:
                    new_name = sumstats.loc[lead_id_single, "rsID"]
                elif "SNPID" in sumstats.columns:
                    new_name = sumstats.loc[lead_id_single, "SNPID"]
                else:
                    new_name = "chr{}:{}".format(sumstats.loc[lead_id_single, chrom],
                                                 sumstats.loc[lead_id_single, pos])
                region_ref_index_dic[new_name] = region_ref_index_dic[region_ref_single]
                new_region_ref.append(new_name)
            else:
                new_region_ref.append(region_ref_single)
        region_ref = new_region_ref

    # LD legend ###############################################################
    if (vcf_path is not None) and region_ld_legend:
        ax1, cbar = _add_ld_legend(sumstats=sumstats,
                                   ax1=ax1,
                                   region_ref=region_ref,
                                   region_ld_threshold=region_ld_threshold,
                                   region_ref_index_dic=region_ref_index_dic,
                                   palette=palette)

    if region_title is not None:
        ax1 = _add_region_title(region_title, ax1=ax1,
                                region_title_args=region_title_args if region_title_args is not None else {})

    # recombination rate ######################################################
    if (region is not None) and (region_recombination is True):
        ax4 = _plot_recombination_rate(sumstats=sumstats,
                                       pos=pos,
                                       region=region,
                                       ax1=ax1,
                                       rr_path=rr_path,
                                       rr_chr_dict=rr_chr_dict,
                                       rr_header_dict=rr_header_dict,
                                       build=build,
                                       rr_lim=rr_lim,
                                       rr_ylabel=rr_ylabel)

    # gene track ##############################################################
    if region is not None:
        most_left_snp = sumstats["i"].idxmin()
        # distance between the leftmost variant position and the region left bound
        gene_track_offset = sumstats.loc[most_left_snp, pos] - region[1]
        # the i value corresponding to POS == 0
        gene_track_start_i = sumstats.loc[most_left_snp, "i"] - gene_track_offset - region[1]

        lead_snp_ys = []
        for i, lead_id_single in enumerate(lead_ids):
            lead_snp_ys.append(sumstats.loc[lead_id_single, "scaled_P"])
            lead_snp_is.append(sumstats.loc[lead_id_single, "i"])
            # consistent colour per reference variant
            lead_snp_is_colors.append(
                palette[(region_ref_index_dic[region_ref[i]] + 1) * 100 + len(region_ld_threshold) + 1])

        if gtf_path is not None:
            ax3, texts_to_adjust_middle = _plot_gene_track(
                ax3=ax3,
                fig=fig,
                gtf_path=gtf_path,
                region=region,
                region_flank_factor=region_flank_factor,
                region_protein_coding=region_protein_coding,
                region_lead_grid=region_lead_grid,
                region_lead_grid_line=region_lead_grid_line,
                lead_snp_is=lead_snp_is,
                gene_track_start_i=gene_track_start_i,
                gtf_chr_dict=gtf_chr_dict,
                gtf_gene_name=gtf_gene_name,
                track_font_family=track_font_family,
                taf=taf,
                build=build,
                verbose=verbose,
                log=log)

        # x ticks #############################################################
        region_ticks = list(map('{:.3f}'.format,
                                np.linspace(region[1], region[2], num=region_step).astype("int") / 1000000))
        tick_positions = np.linspace(gene_track_start_i + region[1],
                                     gene_track_start_i + region[2], num=region_step)

        if "r" in mode:
            if gtf_path is not None:
                ax3.set_xticks(tick_positions)
                ax3.set_xticklabels(region_ticks, rotation=45, fontsize=fontsize, family="sans-serif")

            if region_grid == True:
                for tick_i in tick_positions:
                    ax1.axvline(x=tick_i, color=cut_line_color, zorder=1, **region_grid_line)
                    ax3.axvline(x=tick_i, color=cut_line_color, zorder=1, **region_grid_line)

            if region_lead_grid == True:
                for lead_snp_i, lead_snp_y, lead_snp_is_color in zip(lead_snp_is, lead_snp_ys, lead_snp_is_colors):
                    line_kwargs = {**region_lead_grid_line, "color": lead_snp_is_color}
                    ax1.plot([lead_snp_i, lead_snp_i], [0, lead_snp_y], zorder=1, **line_kwargs)
                    ax3.axvline(x=lead_snp_i, zorder=2, **line_kwargs)
        else:
            # Manhattan-only layout: ticks go on the association panel.
            ax1.set_xticks(tick_positions)
            ax1.set_xticklabels(region_ticks, rotation=45, fontsize=fontsize, family="sans-serif")

        ax1.set_xlim([gene_track_start_i + region[1], gene_track_start_i + region[2]])

    # gene track (ax3) label de-overlapping
    if (gtf_path is not None) and ("r" in mode) and len(texts_to_adjust_middle) > 0:
        adjust_text(texts_to_adjust_middle,
                    autoalign=False,
                    only_move={'points': 'x', 'text': 'x', 'objects': 'x'},
                    ax=ax3,
                    precision=0,
                    force_text=(0.1, 0),
                    expand_text=(1, 1),
                    expand_objects=(1, 1),
                    expand_points=(1, 1),
                    va="center",
                    ha='center',
                    avoid_points=False,
                    lim=1000)

    return ax1, ax3, ax4, cbar, lead_snp_is, lead_snp_is_colors
240
+
241
+ # + ###########################################################################################################################################################################
242
+ def _get_lead_id(sumstats=None, region_ref=None, log=None, verbose=True):
243
+ region_ref_to_check = copy.copy(region_ref)
244
+ try:
245
+ if len(region_ref_to_check)>0 and type(region_ref_to_check) is not str:
246
+ region_ref_to_check = region_ref_to_check[0]
247
+ except:
248
+ pass
249
+
250
+ lead_id=None
251
+
252
+ if "rsID" in sumstats.columns:
253
+ lead_id = sumstats.index[sumstats["rsID"] == region_ref_to_check].to_list()
254
+
255
+ if lead_id is None and "SNPID" in sumstats.columns:
256
+ lead_id = sumstats.index[sumstats["SNPID"] == region_ref_to_check].to_list()
257
+
258
+ if type(lead_id) is list:
259
+ if len(lead_id)>0:
260
+ lead_id = int(lead_id[0])
261
+
262
+ if region_ref_to_check is not None:
263
+ if type(lead_id) is list:
264
+ if len(lead_id)==0 :
265
+ log.warning("{} not found.. Skipping..".format(region_ref_to_check))
266
+ #lead_id = sumstats["scaled_P"].idxmax()
267
+ lead_id = None
268
+ return lead_id
269
+ else:
270
+ log.write(" -Reference variant ID: {} - {}".format(region_ref_to_check, lead_id), verbose=verbose)
271
+
272
+ if lead_id is None:
273
+ log.write(" -Extracting lead variant...", verbose=verbose)
274
+ lead_id = sumstats["scaled_P"].idxmax()
275
+
276
+ return lead_id
277
+
278
+ def _pinpoint_lead(sumstats,ax1,region_ref, lead_color, marker_size, log, verbose,region_marker_shapes):
279
+
280
+ if region_ref is None:
281
+ log.write(" -Extracting lead variant..." , verbose=verbose)
282
+ lead_id = sumstats["scaled_P"].idxmax()
283
+ else:
284
+ lead_id = _get_lead_id(sumstats, region_ref, log, verbose)
285
+
286
+ if lead_id is not None:
287
+ ax1.scatter(sumstats.loc[lead_id,"i"],sumstats.loc[lead_id,"scaled_P"],
288
+ color=lead_color,
289
+ zorder=3,
290
+ marker= region_marker_shapes[sumstats.loc[lead_id,"SHAPE"]-1],
291
+ s=marker_size[1]+2,
292
+ edgecolor="black")
293
+
294
+ return ax1, lead_id
295
+ # -############################################################################################################################################################################
296
+ def _add_region_title(region_title, ax1,region_title_args):
297
+ ax1.text(0.015,0.97, region_title, transform=ax1.transAxes, va="top", ha="left", region_ref=None, **region_title_args )
298
+ return ax1
299
+
300
def _add_ld_legend(sumstats, ax1, region_ld_threshold, region_ref, region_ref_index_dic, palette=None, position=1):
    """Draw the LD colour legend as an inset in the upper-right corner of ax1.

    The legend has one row per reference variant and one coloured cell per LD
    interval in ``[0] + region_ld_threshold + [1]``. Cell colours are taken
    from ``palette[(region_ref_index_dic[ref] + 1) * 100 + level]``, where
    ``level`` runs from 1 to the number of intervals. The same keys are
    presumably used for the scatter colours; confirm with the caller.

    ``sumstats`` and ``position`` are unused and kept for interface
    compatibility.

    Returns
    -------
    tuple
        ``(ax1, inset_axes)``.
    """
    row_height = 0.2
    n_refs = len(region_ref)
    axins1 = inset_axes(ax1,
                        width="11%",
                        height="{}%".format(14 + 7 * n_refs),
                        loc='upper right',
                        axes_kwargs={"frameon": True, "facecolor": "white", "zorder": 999999})

    # list() so tuple thresholds work too
    ld_ticks = [0] + list(region_ld_threshold) + [1]

    for group_index, ref in enumerate(region_ref):
        colour_base = (region_ref_index_dic[ref] + 1) * 100
        for level, (lower, upper) in enumerate(zip(ld_ticks[:-1], ld_ticks[1:]), start=1):
            # x spans the LD interval; y is the reference variant's row
            cell = Rectangle((lower, row_height * group_index), upper - lower, row_height,
                             fill=True, color=palette[colour_base + level], linewidth=2)
            axins1.add_patch(cell)

    # y: one labelled row per reference variant
    yticks_position = row_height / 2 + row_height * np.arange(0, n_refs)
    axins1.set_yticks(yticks_position, ["{}".format(x) for x in region_ref])
    axins1.set_ylim(0, row_height * n_refs)

    # x: LD thresholds
    axins1.set_xticks(ticks=ld_ticks)
    axins1.set_xticklabels([str(i) for i in ld_ticks])
    axins1.set_xlim(0, 1)

    axins1.set_aspect('equal', adjustable='box')
    axins1.set_title('LD $r^{2}$ with variant', loc="center", y=-0.2)
    cbar = axins1
    return ax1, cbar
338
+
339
+ # -############################################################################################################################################################################
340
+ def _plot_recombination_rate(sumstats,pos, region, ax1, rr_path, rr_chr_dict, rr_header_dict, build,rr_lim,rr_ylabel=True):
341
+ ax4 = ax1.twinx()
342
+ most_left_snp = sumstats["i"].idxmin()
343
+
344
+ # the i value when pos=0
345
+ rc_track_offset = sumstats.loc[most_left_snp,"i"]-sumstats.loc[most_left_snp,pos]
346
+
347
+ if rr_path=="default":
348
+ if rr_chr_dict is not None:
349
+ rr_chr = rr_chr_dict[region[0]]
350
+ else:
351
+ rr_chr = str(region[0])
352
+ rc = get_recombination_rate(chrom=rr_chr,build=build)
353
+ else:
354
+ rc = pd.read_csv(rr_path,sep="\t")
355
+ if rr_header_dict is not None:
356
+ rc = rc.rename(columns=rr_header_dict)
357
+
358
+ rc = rc.loc[(rc["Position(bp)"]<region[2]) & (rc["Position(bp)"]>region[1]),:]
359
+ ax4.plot(rc_track_offset+rc["Position(bp)"],rc["Rate(cM/Mb)"],color="#5858FF",zorder=1)
360
+
361
+ ax1.set_zorder(ax4.get_zorder()+1)
362
+ ax1.patch.set_visible(False)
363
+
364
+ if rr_ylabel:
365
+ ax4.set_ylabel("Recombination rate(cM/Mb)")
366
+ if rr_lim!="max":
367
+ ax4.set_ylim(rr_lim[0],rr_lim[1])
368
+ else:
369
+ ax4.set_ylim(0, 1.05 * rc["Rate(cM/Mb)"].max())
370
+ ax4.spines["top"].set_visible(False)
371
+ ax4.spines["top"].set(zorder=1)
372
+ return ax4
373
+
374
+ # -############################################################################################################################################################################
375
def _plot_gene_track(
    ax3,
    fig,
    gtf_path,
    region,
    region_flank_factor,
    region_protein_coding,
    region_lead_grid,
    region_lead_grid_line,
    lead_snp_is,
    gene_track_start_i,
    gtf_chr_dict,
    gtf_gene_name,
    track_font_family,
    taf,
    build,
    verbose=True,
    log=Log()):
    """Draw genes and exons of ``region`` on ``ax3``.

    Genes come from ``process_gtf``. Each gene sits on row ``stack * 2`` with
    an italic label that carries a strand arrow. When ``region_lead_grid`` is
    True, genes that contain a lead x coordinate (from ``lead_snp_is``) are
    coloured with ``region_lead_grid_line["color"]``, and so are their exons.

    ``taf`` holds ``[track_n, track_n_offset, font_ratio, exon_ratio,
    text_offset]``. ``gene_track_start_i`` is the plotting coordinate that
    corresponds to genomic position 0.

    Returns
    -------
    tuple
        ``(ax3, texts_to_adjust_middle)``. The second item lists the labels of
        genes fully inside the region, for later de-overlapping.
    """
    log.write(" -Loading gtf files from:" + gtf_path, verbose=verbose)
    uniq_gene_region, exons = process_gtf(gtf_path=gtf_path,
                                          region=region,
                                          region_flank_factor=region_flank_factor,
                                          build=build,
                                          region_protein_coding=region_protein_coding,
                                          gtf_chr_dict=gtf_chr_dict,
                                          gtf_gene_name=gtf_gene_name)

    n_uniq_stack = uniq_gene_region["stack"].nunique()
    stack_num_to_plot = max(taf[0], n_uniq_stack)
    ax3.set_ylim((-stack_num_to_plot * 2 - taf[1] * 2, 2 + taf[1] * 2))
    ax3.set_yticks([])
    # 72 / dpi is numerically points-per-pixel (the variable name is historical)
    pixels_per_point = 72 / fig.dpi
    pixels_per_track = np.abs(ax3.transData.transform([0, 0])[1] - ax3.transData.transform([0, 1])[1])
    font_size_in_pixels = taf[2] * pixels_per_track
    font_size_in_points = font_size_in_pixels * pixels_per_point
    linewidth_in_points = pixels_per_track * pixels_per_point
    log.write(" -plotting gene track..", verbose=verbose)

    default_color = "#020080"
    highlight_color = region_lead_grid_line.get("color", "#FF0000")
    text_style = dict(va="center", color="black", style='italic',
                      size=font_size_in_points, family=track_font_family)

    # Genes containing a lead variant, collected over ALL genes (used for exons).
    sig_gene_names = []
    sig_gene_lefts = []
    sig_gene_rights = []
    texts_to_adjust_left = []
    texts_to_adjust_middle = []
    texts_to_adjust_right = []

    for _, row in uniq_gene_region.iterrows():
        gene_left = gene_track_start_i + row["start"]
        gene_right = gene_track_start_i + row["end"]

        if row["strand"][0] == "+":
            gene_anno = row["name"] + "->"
        else:
            gene_anno = "<-" + row["name"]

        gene_color = default_color
        if region_lead_grid is True:
            for lead_snp_i in lead_snp_is:
                if gene_left < lead_snp_i < gene_right:
                    gene_color = highlight_color
                    sig_gene_names.append(row["name"])
                    sig_gene_lefts.append(gene_left)
                    sig_gene_rights.append(gene_right)

        # gene line
        ax3.plot((gene_left, gene_right),
                 (row["stack"] * 2, row["stack"] * 2), color=gene_color, linewidth=linewidth_in_points / 10)

        # gene name: clamp to the region edge when the gene extends past it
        y_text = row["stack"] * 2 + taf[4]
        if row["end"] >= region[2]:
            texts_to_adjust_right.append(ax3.text(x=gene_track_start_i + region[2], y=y_text,
                                                  s=gene_anno, ha="right", **text_style))
        elif row["start"] <= region[1]:
            texts_to_adjust_left.append(ax3.text(x=gene_track_start_i + region[1], y=y_text,
                                                 s=gene_anno, ha="left", **text_style))
        else:
            texts_to_adjust_middle.append(ax3.text(x=(gene_left + gene_right) / 2, y=y_text,
                                                   s=gene_anno, ha="center", **text_style))

    # exons
    sig_gene_name_set = set(sig_gene_names)
    sig_gene_bounds = list(zip(sig_gene_lefts, sig_gene_rights))
    for _, row in exons.iterrows():
        exon_left = gene_track_start_i + row["start"]
        exon_right = gene_track_start_i + row["end"]
        if pd.isnull(row["name"]):
            # unnamed exon: highlight when it lies inside a highlighted gene
            is_sig = any(left < exon_left and exon_right < right for left, right in sig_gene_bounds)
        else:
            is_sig = (region_lead_grid is True) and (row["name"] in sig_gene_name_set)
        exon_color = highlight_color if is_sig else default_color

        ax3.plot((exon_left, exon_right),
                 (row["stack"] * 2, row["stack"] * 2), linewidth=linewidth_in_points * taf[3],
                 color=exon_color, solid_capstyle="butt")

    log.write(" -Finished plotting gene track..", verbose=verbose)

    return ax3, texts_to_adjust_middle
478
+
479
+ # -############################################################################################################################################################################
480
+ # Helpers
481
+ # -############################################################################################################################################################################
482
def process_vcf(sumstats,
                vcf_path,
                region,
                region_ref,
                # region_ref_second,
                log,
                verbose,
                pos,
                nea,
                ea,
                region_ld_threshold,
                vcf_chr_dict,
                tabix):
    """Annotate ``sumstats`` with LD (r^2) to each reference variant.

    Genotypes of ``region`` (``(chromosome, start, end)``) are read from
    ``vcf_path``. Each sumstats row is matched to a VCF record by position and
    alleles. Rogers-Huff r^2 is then computed against every reference variant
    in ``region_ref``.

    Columns added in place, for each reference ``n``:

    - ``REFINDEX``: the matched VCF record index.
    - ``RSQ_n``: r^2 with reference ``n``.
    - ``LD_n``: an LD bin.
      - 0 means no data.
      - 1 to ``len(region_ld_threshold) + 1`` are the threshold bins.
      - ``len(region_ld_threshold) + 2`` marks the reference itself.
    - ``LEAD_n``: 1 on the reference row.

    Combined columns:

    - ``RSQ``: the highest r^2 over all references.
    - ``LD``: ``100 * (n + 1) + LD_n`` for the reference giving that maximum.
    - ``SHAPE``: ``n + 1`` for that reference, else 1.

    Returns
    -------
    The same (mutated) ``sumstats``.
    """
    log.write("Start to load reference genotype...", verbose=verbose)
    log.write(" -reference vcf path : " + vcf_path, verbose=verbose)

    # chromosome label used in the VCF; plain string when no mapping is given
    vcf_chrom = vcf_chr_dict[region[0]] if vcf_chr_dict is not None else str(region[0])
    ref_genotype = read_vcf(vcf_path, region=vcf_chrom + ":" + str(region[1]) + "-" + str(region[2]), tabix=tabix)
    if ref_genotype is None:
        log.warning("No data was retrieved. Skipping ...")
        ref_genotype = dict()
        ref_genotype["variants/POS"] = np.array([], dtype="int64")
    log.write(" -Retrieving index...", verbose=verbose)
    log.write(" -Ref variants in the region: {}".format(len(ref_genotype["variants/POS"])), verbose=verbose)

    ref_pos = ref_genotype["variants/POS"]

    def _match_variant(x):
        """Return the VCF record index for x = (POS, NEA, EA), or None.

        A unique position match is accepted without checking alleles. For
        multi-allelic positions the alleles must agree in either orientation.
        """
        candidates = np.where(ref_pos == x.iloc[0])[0]
        if len(candidates) == 0:
            return None
        if len(candidates) == 1:
            return candidates[0]
        for j in candidates:
            if x.iloc[1] == ref_genotype["variants/REF"][j]:
                if x.iloc[2] in ref_genotype["variants/ALT"][j]:
                    return j
            elif x.iloc[1] in ref_genotype["variants/ALT"][j]:
                if x.iloc[2] == ref_genotype["variants/REF"][j]:
                    return j
        return None

    log.write(" -Matching variants using POS, NEA, EA ...", verbose=verbose)
    sumstats["REFINDEX"] = sumstats[[pos, nea, ea]].apply(lambda x: _match_variant(x), axis=1)

    # LD with each reference variant ##########################################
    for ref_n, region_ref_single in enumerate(region_ref):
        rsq = "RSQ_{}".format(ref_n)
        ld_single = "LD_{}".format(ref_n)
        lead = "LEAD_{}".format(ref_n)
        sumstats[lead] = 0

        if region_ref_single is None:
            lead_id = sumstats["scaled_P"].idxmax()
        else:
            lead_id = _get_lead_id(sumstats, region_ref_single, log, verbose)

        if lead_id is None:
            sumstats[rsq] = None
            sumstats[rsq] = sumstats[rsq].astype("float")
            sumstats[ld_single] = 0
            continue

        lead_pos = sumstats.loc[lead_id, pos]
        lead_snp_ref_index = None
        if lead_pos in ref_pos:
            # may still be None when the alleles do not match any record
            lead_snp_ref_index = _match_variant(sumstats.loc[lead_id, [pos, nea, ea]])

        if lead_snp_ref_index is not None:
            other_snps_ref_index = sumstats["REFINDEX"].dropna().astype("int").values
            lead_snp_genotype = GenotypeArray([ref_genotype["calldata/GT"][lead_snp_ref_index]]).to_n_alt()
            if len(set(lead_snp_genotype[0])) == 1:
                log.warning("The variant is mono-allelic in reference VCF. LD can not be calculated.")
            other_snp_genotype = GenotypeArray(ref_genotype["calldata/GT"][other_snps_ref_index]).to_n_alt()

            log.write(" -Calculating Rsq...", verbose=verbose)
            # ravel: rogers_huff_r_between yields a (1, n) array, or possibly a
            # scalar for a single pair; flatten to one value per matched variant.
            valid_r2 = np.power(np.ravel(rogers_huff_r_between(lead_snp_genotype, other_snp_genotype)), 2)
            sumstats.loc[~sumstats["REFINDEX"].isna(), rsq] = valid_r2
        else:
            log.write(" -Lead SNP not found in reference...", verbose=verbose)
            sumstats[rsq] = None

        sumstats[rsq] = sumstats[rsq].astype("float")
        sumstats[ld_single] = 0

        # No data -> LD 0; thresholds [t0, t1, ...] -> bins 1, 2, ...; lead -> len+2
        for index, ld_threshold in enumerate(region_ld_threshold):
            if index == 0:
                sumstats.loc[sumstats[rsq] > -1, ld_single] = 1
            sumstats.loc[sumstats[rsq] > ld_threshold, ld_single] = index + 2

        sumstats.loc[lead_id, ld_single] = len(region_ld_threshold) + 2
        sumstats.loc[lead_id, lead] = 1

    # combine: keep the reference with the highest r^2 ########################
    final_shape_col = "SHAPE"
    final_ld_col = "LD"
    final_rsq_col = "RSQ"

    sumstats[final_ld_col] = 0
    sumstats[final_shape_col] = 1
    sumstats[final_rsq_col] = 0.0

    for i in range(len(region_ref)):
        ld_single = "LD_{}".format(i)
        current_rsq = "RSQ_{}".format(i)
        a_ngt_b = sumstats[final_rsq_col] < sumstats[current_rsq]
        # levels are spaced by 100 per reference variant
        sumstats.loc[a_ngt_b, final_ld_col] = 100 * (i + 1) + sumstats.loc[a_ngt_b, ld_single]
        sumstats.loc[a_ngt_b, final_rsq_col] = sumstats.loc[a_ngt_b, current_rsq]
        sumstats.loc[a_ngt_b, final_shape_col] = i + 1

    log.write("Finished loading reference genotype successfully!", verbose=verbose)
    return sumstats
633
+
634
+ # -############################################################################################################################################################################
635
+
636
def process_gtf(gtf_path,
                region,
                region_flank_factor,
                build,
                region_protein_coding,
                gtf_chr_dict,
                gtf_gene_name):
    """Load gene and exon records overlapping ``region`` and assign display rows.

    Parameters
    ----------
    gtf_path : str
        ``"default"``/``"ensembl"`` or ``"refseq"`` load bundled annotations
        via ``get_gtf``. Any other value is a GTF file read with ``read_gtf``.
    region : sequence
        ``(chromosome, start, end)``. Records with ``start < end_of_region``
        and ``end > start_of_region`` are kept.
    region_flank_factor : float
        Each gene's ``left``/``right`` is widened by this fraction of the
        region length before rows are assigned.
    region_protein_coding : bool
        True keeps only "gene"/"exon" records of protein-coding genes with a
        non-empty name.
    gtf_chr_dict : dict
        Maps the region chromosome to the GTF seqname.
    gtf_gene_name : str or None
        Column used as the gene label. None selects "gene_name" for
        ensembl/default and "gene_id" otherwise.

    Returns
    -------
    tuple
        ``(uniq_gene_region, exons)``. Both have a negative ``stack`` column;
        ``uniq_gene_region`` also has ``left``/``right`` boundaries.
    """
    seqname = gtf_chr_dict[region[0]]
    region_start, region_end = region[1], region[2]

    if gtf_path == "default" or gtf_path == "ensembl":
        gtf = get_gtf(chrom=seqname, build=build, source="ensembl")
    elif gtf_path == "refseq":
        gtf = get_gtf(chrom=seqname, build=build, source="refseq")
    else:
        # user-provided GTF file
        gtf = read_gtf(gtf_path)
        gtf = gtf.loc[gtf["seqname"] == seqname, :]

    overlaps = (gtf["seqname"] == seqname) & (gtf["start"] < region_end) & (gtf["end"] > region_start)
    genes_1mb = gtf.loc[overlaps, :].copy()

    if gtf_gene_name is not None:
        label_column = gtf_gene_name
    elif gtf_path == "default" or gtf_path == "ensembl":
        label_column = "gene_name"
    else:
        # refseq and user-provided GTFs are labelled by gene_id
        label_column = "gene_id"
    genes_1mb.loc[:, "name"] = genes_1mb[label_column]
    genes_1mb.loc[:, "gene"] = genes_1mb["gene_id"]

    if region_protein_coding is True:
        is_pc_gene = ((genes_1mb["feature"] == "gene")
                      & (genes_1mb["gene_biotype"] == "protein_coding")
                      & (genes_1mb["name"] != ""))
        pc_names = genes_1mb.loc[is_pc_gene, "name"].values
        keep = genes_1mb["feature"].isin(["exon", "gene"]) & genes_1mb["name"].isin(pc_names)
        genes_1mb = genes_1mb.loc[keep, :]

    is_exon = genes_1mb["feature"] == "exon"
    is_gene = genes_1mb["feature"] == "gene"
    exons = genes_1mb.loc[is_exon, :].copy()
    uniq_gene_region = genes_1mb.loc[is_gene, :].copy()

    flank = region_flank_factor * (region_end - region_start)
    uniq_gene_region["left"] = uniq_gene_region["start"] - flank
    uniq_gene_region["right"] = uniq_gene_region["end"] + flank

    by_start = uniq_gene_region.sort_values(["start"]).loc[:, ["name", "left", "right"]]
    stack_dic = assign_stack(by_start)

    # rows are drawn downward, hence the negative stack numbers
    uniq_gene_region["stack"] = -uniq_gene_region["name"].map(stack_dic)
    exons.loc[:, "stack"] = -exons.loc[:, "name"].map(stack_dic)

    return uniq_gene_region, exons
723
+
724
+
725
+ # -############################################################################################################################################################################
726
def assign_stack(uniq_gene_region):
    """Assign each gene to a non-overlapping display row ("stack").

    Parameters
    ----------
    uniq_gene_region : pandas.DataFrame
        One row per gene with columns "name", "left" and "right". These are
        the gene boundaries already widened by the label flank.

    Returns
    -------
    dict
        Mapping from gene name to stack index, where 0 is the top row. A name
        that appears more than once keeps the stack of its first occurrence.

    Notes
    -----
    Genes are placed greedily in order of their left boundary, using a stable
    sort so that already-sorted input keeps its order. Each gene goes to the
    first stack whose right-most occupied coordinate lies strictly left of the
    gene's left boundary; otherwise a new stack is opened. Every gene
    therefore receives a stack, including genes with identical or touching
    boundaries.
    """
    stack_ends = []  # right-most occupied coordinate of each stack
    stack_dic = {}   # gene name -> stack index

    ordered = uniq_gene_region.sort_values("left", kind="stable")
    for name, left, right in zip(ordered["name"], ordered["left"], ordered["right"]):
        if name in stack_dic:
            continue
        for i, end in enumerate(stack_ends):
            if left > end:
                stack_ends[i] = right
                stack_dic[name] = i
                break
        else:
            stack_ends.append(right)
            stack_dic[name] = len(stack_ends) - 1
    return stack_dic
766
+
767
def closest_gene(x, data, chrom="CHR", pos="POS", maxiter=20000, step=50):
    """Find the gene(s) overlapping, or nearest to, a variant.

    Parameters
    ----------
    x : mapping or pandas.Series
        Provides ``x[chrom]`` (contig) and ``x[pos]`` (position).
    data :
        Object exposing ``gene_names_at_locus(contig=..., position=...)``
        returning a list of gene names (e.g. a pyensembl genome).
    maxiter, step : int
        The search probes both sides at distances ``0, step, 2*step, ...``,
        for up to ``maxiter`` probes.

    Returns
    -------
    tuple
        ``(distance, genes)``, where ``genes`` is a comma-joined string.
        ``distance`` is 0 when the variant lies in a gene, negative when the
        gene is upstream (lower position) and positive when it is downstream.

        - If only one side is hit at a probe, that probe's (coarse) distance
          is returned.
        - If both sides are hit at the same probe, the last step is rescanned
          base by base to find the truly closer side; upstream wins ties.
        - If nothing is found, ``(last distance probed, "intergenic")`` is
          returned.
    """
    contig = x[chrom]
    position = x[pos]

    def genes_at(p):
        return data.gene_names_at_locus(contig=contig, position=p)

    gene = genes_at(position)
    if len(gene) > 0:
        return 0, ",".join(gene)

    distance = 0
    for i in range(maxiter):
        distance = i * step
        gene_u = genes_at(position - distance)
        gene_d = genes_at(position + distance)
        if len(gene_u) > 0 and len(gene_d) > 0:
            # Both sides reached within this step: refine from the previous step.
            # The scan includes i*step, where both sides are known to hit, so it
            # always returns.
            previous = (i - 1) * step
            for offset in range(step + 1):
                d = previous + offset
                gene_u = genes_at(position - d)
                if len(gene_u) > 0:
                    return -d, ",".join(gene_u)
                gene_d = genes_at(position + d)
                if len(gene_d) > 0:
                    return d, ",".join(gene_d)
        elif len(gene_u) > 0:
            return -distance, ",".join(gene_u)
        elif len(gene_d) > 0:
            return +distance, ",".join(gene_d)
    return distance, "intergenic"