geney 1.2.22__py2.py3-none-any.whl → 1.2.24__py2.py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only.

Potentially problematic release.


This version of geney might be problematic. Click here for more details.

geney/oncosplice.py CHANGED
@@ -1,1279 +1,12 @@
1
- from Bio.Seq import Seq
1
+ import copy
2
+
2
3
  from Bio import pairwise2
3
- from dataclasses import dataclass
4
- from copy import deepcopy
5
4
  import re
6
- import pandas as pd
7
- from pathlib import Path
8
5
  import numpy as np
9
- from geney import config_setup
10
- import networkx as nx
11
- import matplotlib.pyplot as plt
12
- from matplotlib.patches import Rectangle
13
- import seaborn as sns
14
- from collections import namedtuple
15
- print('hellp')
16
- from geney.utils import find_files_by_gene_name, reverse_complement, unload_pickle, contains, unload_json, dump_json #, is_monotonic
17
- from geney.Fasta_segment import Fasta_segment
18
-
19
- #### SpliceAI Modules
20
- import tensorflow as tf
21
- from keras.models import load_model
22
- from pkg_resources import resource_filename
23
- from spliceai.utils import one_hot_encode
24
-
25
# Limit TensorFlow to single-threaded execution so SpliceAI predictions are
# deterministic and do not oversubscribe CPU cores.
tf.config.threading.set_intra_op_parallelism_threads(1)
tf.config.threading.set_inter_op_parallelism_threads(1)

# Load the five SpliceAI ensemble models shipped inside the spliceai package.
sai_paths = ('models/spliceai{}.h5'.format(x) for x in range(1, 6))
sai_models = [load_model(resource_filename('spliceai', x)) for x in sai_paths]

# Load models
import torch
from pkg_resources import resource_filename
from pangolin.model import *

# Pangolin ships 5 replicate weight files for each model number; load all of
# them up front so prediction calls can average across replicates.
# NOTE(review): Pangolin(L, W, AR) relies on L/W/AR being star-imported from
# pangolin.model — confirm those names are exported by that module.
pang_model_nums = [0, 1, 2, 3, 4, 5, 6]
pang_models = []
for i in pang_model_nums:
    for j in range(1, 6):
        model = Pangolin(L, W, AR)
        if torch.cuda.is_available():
            model.cuda()
            weights = torch.load(resource_filename("pangolin","models/final.%s.%s.3" % (j, i)))
        else:
            weights = torch.load(resource_filename("pangolin","models/final.%s.%s.3" % (j, i)),
                                 map_location=torch.device('cpu'))
        model.load_state_dict(weights)
        model.eval()  # inference mode: disables dropout / batch-norm updates
        pang_models.append(model)
50
-
51
-
52
- # def is_monotonic(A):
53
- # x, y = [], []
54
- # x.extend(A)
55
- # y.extend(A)
56
- # x.sort()
57
- # y.sort(reverse=True)
58
- # if (x == A or y == A):
59
- # return True
60
- # return False
61
-
62
-
63
def is_monotonic(A):
    """Return True when A never decreases or never increases (empty/1-element -> True)."""
    non_decreasing = True
    non_increasing = True
    for prev, curr in zip(A, A[1:]):
        if prev > curr:
            non_decreasing = False
        if prev < curr:
            non_increasing = False
    return non_decreasing or non_increasing
65
-
66
-
67
def sai_predict_probs(seq: str, models: list) -> list:
    '''
    Predicts the donor and acceptor junction probability of each
    NT in seq using SpliceAI.

    Let m := 2*sai_mrg_context + L be the input seq length. It is assumed
    that the input seq has the following structure:

        seq = |<sai_mrg_context NTs><L NTs><sai_mrg_context NTs>|

    The returned probability matrix is of size 2xL, where
    the first row is the acceptor probability and the second row
    is the donor probability. These probabilities correspond to the
    middle <L NTs> NTs of the input seq.
    '''
    # One-hot encode and add a batch dimension for Keras.
    x = one_hot_encode(seq)[None, :]
    # Average the predictions of the 5-model SpliceAI ensemble.
    y = np.mean([models[m].predict(x, verbose=0) for m in range(5)], axis=0)
    # Drop column 0 ("no splice site" class) and transpose to (2, L):
    # row 0 = acceptor probabilities, row 1 = donor probabilities.
    return y[0, :, 1:].T
85
-
86
-
87
- ### Variant Modules
88
class Mutation:
    """A single variant parsed from an id of the form 'gene:chrom:pos:ref:alt'.

    Stores the gene name, chromosome, position, normalized ref/alt alleles,
    a filesystem-safe identifier, and the inferred variant type. Mutations
    compare by genomic position (see __lt__).
    """

    def __init__(self, mid):
        '''
        :param mid: mutation id in the format of gene:chrom:pos:ref:alt
        Needs only to store the following properties for a given mutation
            gene: the name of the gene
            chrom: the chromosome reference
            start: the position of the mutation
            file_identifier: some filename that can be used to store related data
            vartype: the variant type

        We want to be able to compare mutations based on location.
        '''

        self.mut_id = mid

        gene, chrom, pos, ref, alt = mid.split(':')
        self.gene = gene
        # BUGFIX: the previous `chrom.strip('chr')` removed any of the
        # characters {c, h, r} from BOTH ends; only a leading 'chr' prefix
        # should be dropped.
        self.chrom = chrom[3:] if chrom.startswith('chr') else chrom
        self.start = int(pos)

        self.file_identifier = self.mut_id.replace(':', '_')
        # Built from the RAW alleles (may still contain '-'), truncated to 6 chars.
        self.file_identifier_short = f'{self.start}_{ref[:6]}_{alt[:6]}'

        # '-' encodes an absent allele (pure insertion/deletion).
        self.ref = ref if ref != '-' else ''
        self.alt = alt if alt != '-' else ''

        # Classify the variant from the normalized allele lengths.
        if len(self.ref) == len(self.alt) == 1:
            self.vartype = 'SNP'
        elif len(self.ref) == len(self.alt) > 1:
            self.vartype = 'SUB'
        elif self.ref and not self.alt:
            self.vartype = 'DEL'
        elif self.alt and not self.ref:
            self.vartype = 'INS'
        else:
            self.vartype = 'INDEL'

    def __str__(self):
        return self.mut_id

    def __repr__(self):
        return f"Mutation({self.mut_id})"

    def __lt__(self, other):
        # Order mutations by genomic position only.
        return self.start < other.start
136
-
137
class Variations:
    '''
    Unlike a Mutation, this represents an epistatic set: a series of mutations
    separated by '|' characters in the id string. The individual variants are
    parsed, sorted by position, and exposed together with aggregate metadata
    (positions, span, file identifier).
    '''
    def __init__(self, epistatic_set):
        # Parse each '|'-separated mutation id and sort by genomic position.
        self.variants = sorted([Mutation(m) for m in epistatic_set.split('|')])
        self.mut_id = epistatic_set
        self.start = self.variants[0].start
        self.positions = [v.start for v in self.variants]
        self.gene = self.variants[0].gene
        # NOTE(review): str.strip('chr') removes any of {c,h,r} from both
        # ends, not just a 'chr' prefix (and Mutation already stripped it).
        self.chrom = self.variants[0].chrom.strip('chr')
        self.file_identifier = f'{self.gene}_{self.chrom}' + '_' + '_'.join(
            [v.file_identifier_short for v in self.variants])
        # Genomic span covered by the set.
        self.range = max(self.positions) - min(self.positions)

    def __str__(self):
        return '|'.join([m.mut_id for m in self.variants])

    def __repr__(self):
        return f"Variation({', '.join([m.mut_id for m in self.variants])})"

    def __iter__(self):
        # Iterator state lives on the instance: not safe for nested iteration.
        self.current_index = 0
        return self

    def __next__(self):
        if self.current_index < len(self.variants):
            x = self.variants[self.current_index]
            self.current_index += 1
            return x
        raise StopIteration

    @property
    def file_identifier_json(self):
        # Suggested on-disk name for data related to this variant set.
        return Path(self.file_identifier + '.json')

    @property
    def as_dict(self):
        # Maps genomic position -> alternate allele.
        return {m.start: m.alt for m in self.variants}

    def verify(self):
        # False when two variants share a genomic position (conflicting edits).
        if len(set(self.positions)) != len(self.variants):
            return False
        return True
182
-
183
-
184
def generate_mut_variant(seq: str, indices: list, mut: Mutation):
    '''
    Apply a single Mutation to a sequence and its genomic-index track.

    :param seq: nucleotide sequence
    :param indices: genomic position of each base in seq (-1 marks padding)
    :param mut: the Mutation to introduce
    :return: (new_seq, new_indices); the inputs are returned unchanged when
             the mutation falls outside the covered positions
    '''
    # Insertions (empty ref) anchor AFTER mut.start, hence the +1 offset.
    offset = 1 if not mut.ref else 0
    check_indices = list(range(mut.start, mut.start + len(mut.ref) + offset))
    # Every affected position must exist among the non-padding indices.
    check1 = all([contains(list(filter((-1).__ne__, indices)), m) for m in check_indices])
    if not check1:
        print(
            f"Mutation {mut} not within transcript bounds: {min(list(filter((-1).__ne__, indices)))} - {max(indices)}.")

        return seq, indices

    rel_start, rel_end = indices.index(mut.start) + offset, indices.index(mut.start) + offset + len(mut.ref)
    acquired_seq = seq[rel_start:rel_end]
    # Sanity check: the stated reference allele should match the genome build.
    check2 = acquired_seq == mut.ref
    if not check2:
        print(f'Reference allele ({mut.ref}) does not match genome_build allele ({acquired_seq}).')

    if len(mut.ref) == len(mut.alt) > 0:
        # Substitution: replacement bases keep real genomic coordinates.
        temp_indices = list(range(mut.start, mut.start + len(mut.ref)))
    else:
        # Insertion/deletion: inserted bases get fractional coordinates
        # (start + k/1000) so they sort between the flanking real positions.
        temp_indices = [indices[indices.index(mut.start)] + v / 1000 for v in list(range(1, len(mut.alt) + 1))]

    new_indices = indices[:rel_start] + temp_indices + indices[rel_end:]
    new_seq = seq[:rel_start] + mut.alt + seq[rel_end:]

    # Sequence and index track must stay aligned and ordered.
    assert len(new_seq) == len(new_indices), f'Error in preserving sequence lengths during variant modification: {mut}, {len(new_seq)}, {len(new_indices)}'
    assert is_monotonic(list(filter((-1).__ne__, new_indices))), f'Modified nucleotide indices are not monotonic.'
    return new_seq, new_indices
213
-
214
-
215
-
216
class Gene:
    """Container for one gene's annotation data and its transcript blueprints.

    Attributes (gene_id, rev, chrm, transcripts, ...) are populated from a
    pickled annotation dictionary located via find_files_by_gene_name.
    """

    def __init__(self, gene_name, variation=None, organism='hg38'):
        '''
        :param gene_name: gene symbol used to locate the annotation pickle
        :param variation: optional epistatic-set string forwarded to Transcripts
        :param organism: genome build key (e.g. 'hg38') into config_setup
        '''
        self.gene_name = gene_name
        self.gene_id = ''
        self.rev = None          # reverse-strand flag (set by load_from_file)
        self.chrm = ''
        self.gene_start = 0
        self.gene_end = 0
        self.transcripts = {}    # tid -> annotation dict
        self.load_from_file(find_files_by_gene_name(gene_name, organism=organism))
        self.variations = variation
        self.primary_tid = None
        self.organism = organism
        # Prefer the annotated primary protein-coding transcript; fall back
        # to the first transcript on record.
        tids = [k for k, v in self.transcripts.items() if v['primary_transcript'] and v['transcript_biotype'] == 'protein_coding']
        if tids:
            self.primary_tid = tids[0]
        else:
            self.primary_tid = list(self.transcripts.keys())[0]

    def __repr__(self):
        return f'Gene(gene_name={self.gene_name})'

    def __len__(self):
        # Number of annotated transcripts.
        return len(self.transcripts)

    def __str__(self):
        return '{gname}, {ntranscripts} transcripts'.format(gname=self.gene_name, ntranscripts=self.__len__())

    def __copy__(self):
        # Shallow copy: shares the transcripts dict with the original.
        cls = self.__class__
        result = cls.__new__(cls)
        result.__dict__.update(self.__dict__)
        return result

    def __deepcopy__(self, memo):
        # Deep copy every attribute; memo prevents cycles.
        cls = self.__class__
        result = cls.__new__(cls)
        memo[id(self)] = result
        for k, v in self.__dict__.items():
            setattr(result, k, deepcopy(v, memo))
        return result

    def __getitem__(self, index):
        # Positional access. NOTE(review): unlike transcript()/run_transcripts(),
        # this does not forward variations/organism — confirm that is intended.
        return Transcript(list(self.transcripts.values())[index])

    def load_from_file(self, file_name):
        """Populate attributes from the pickled annotation dictionary at file_name."""
        if not file_name.exists():
            raise FileNotFoundError(f"File '{file_name}' not found.")
        self.load_from_dict(dict_data=unload_pickle(file_name))
        return self

    def load_from_dict(self, dict_data=None):
        """Copy every key/value pair of dict_data onto this instance."""
        for k, v in dict_data.items():
            setattr(self, k, v)
        return self

    def transcript(self, tid=None):
        """Build the Transcript for tid (defaults to the primary transcript).

        :raises AttributeError: when tid is not annotated for this gene
        """
        if tid is None:
            tid = self.primary_tid

        if tid not in self.transcripts:
            raise AttributeError(f"Transcript '{tid}' not found in gene '{self.gene_name}'.")
        return Transcript(self.transcripts[tid], organism=self.organism, variations=self.variations)

    def run_transcripts(self, primary_transcript=False, protein_coding=False):
        """Yield Transcript objects, optionally filtered to primary and/or
        protein-coding annotations."""
        for tid, annotations in self.transcripts.items():
            if primary_transcript and not annotations['primary_transcript']:
                continue
            if protein_coding and annotations['transcript_biotype'] != 'protein_coding':
                continue

            yield Transcript(self.transcripts[tid], variations=self.variations, organism=self.organism)
289
-
290
-
291
class Transcript:
    """One transcript blueprint: splice sites, translation boundaries, and the
    machinery to build pre-mRNA, mature mRNA, and protein sequences, optionally
    with a set of Variations applied.

    Coordinate conventions: *_pos helpers work in positive-strand (genomic)
    orientation; sense-orientation data is stored on the instance.
    """

    def __init__(self, d=None, variations=None, organism='hg38'):
        '''
        :param d: annotation dictionary whose keys become attributes
        :param variations: optional epistatic-set string ('|'-separated mutation ids)
        :param organism: genome build key into config_setup
        '''
        self.transcript_id = None
        self.transcript_start = None  # transcription
        self.transcript_end = None  # transcription
        self.transcript_upper = None
        self.transcript_lower = None
        self.transcript_biotype = None  # metadata
        self.acceptors, self.donors = [], []  # splicing
        self.TIS, self.TTS = None, None  # translation
        self.transcript_seq, self.transcript_indices = '', []  # sequence data
        self.rev = None  # sequence data
        self.chrm = ''  # sequence data
        self.pre_mrna = ''  # sequence data
        self.orf = ''  # sequence data
        self.protein = ''  # sequence data
        self.log = ''  # sequence data
        self.primary_transcript = None  # sequence data
        self.cons_available = False  # metadata
        self.cons_seq = ''
        self.cons_vector = ''
        self.variations = None
        self.organism = organism
        if variations:
            self.variations = Variations(variations)

        if d:
            self.load_from_dict(d)

        # Unmutated protein-coding transcripts get the full protein pipeline;
        # everything else only materializes the (possibly mutated) pre-mRNA.
        if self.transcript_biotype == 'protein_coding' and variations is None:
            self.generate_protein()

        else:
            self.generate_pre_mrna()

        # Conservation data may carry a trailing stop marker; drop it and the
        # matching final vector entry.
        if '*' in self.cons_seq:
            self.cons_seq = self.cons_seq.replace('*', '')
            self.cons_vector = np.array(self.cons_vector[:-1])

        # Conservation is only usable when it matches the generated protein.
        if self.cons_seq == self.protein and len(self.cons_vector) == len(self.cons_seq):
            self.cons_available = True

        if self.cons_available == False:
            self.cons_vector = np.ones(len(self.protein))


    def __repr__(self):
        return 'Transcript(transcript_id={tid})'.format(tid=self.transcript_id)

    def __len__(self):
        # Length of the mature mRNA sequence.
        return len(self.transcript_seq)

    def __str__(self):
        return 'Transcript {tid}, Transcript Type: ' \
               '{protein_coding}, Primary: {primary}'.format(
            tid=self.transcript_id, protein_coding=self.transcript_biotype.replace('_', ' ').title(),
            primary=self.primary_transcript)

    def __eq__(self, other):
        # Equality is sequence identity only.
        # NOTE(review): defining __eq__ without __hash__ makes instances unhashable.
        return self.transcript_seq == other.transcript_seq

    def __contains__(self, subvalue):
        '''
        :param subvalue: a str (substring of the mature mRNA), an int (genomic
                         position in the mature transcript), or a Variations
                         (any position within the transcript span)
        :return: whether the value is found in / covered by this transcript
        '''
        if isinstance(subvalue, str):
            return subvalue in self.transcript_seq
        elif isinstance(subvalue, int):
            return subvalue in self.transcript_indices
        elif isinstance(subvalue, Variations):
            return any([self.transcript_lower <= p <= self.transcript_upper for p in subvalue.positions])

        else:
            print(
                "Pass an integer to check against the span of the gene's coordinates or a string to check against the "
                "pre-mRNA sequence.")
            return False


    def __deepcopy__(self, memo):
        # Deep copy every attribute; memo prevents cycles.
        cls = self.__class__
        result = cls.__new__(cls)
        memo[id(self)] = result
        for k, v in self.__dict__.items():
            setattr(result, k, deepcopy(v, memo))
        return result

    def load_from_dict(self, data):
        '''
        :param data: a dictionary containing the needed data to construct the transcript
        :return: itself
        '''
        for k, v in data.items():  # add a line here that ensure the dictionary key is a valid item
            setattr(self, k, v)

        # Cache the genomic span regardless of strand orientation.
        self.transcript_upper, self.transcript_lower = max(self.transcript_start, self.transcript_end), min(self.transcript_start, self.transcript_end)
        self.__arrange_boundaries()
        return self

    @property
    def exons(self):
        '''
        :return: a list of tuples (acceptor, donor) in sense orientation; the
                 transcript start/end cap the first/last exon
        '''
        return list(zip([self.transcript_start] + self.acceptors, self.donors + [self.transcript_end]))

    @property
    def exons_pos(self):
        # Exons in positive-strand (genomic) orientation.
        temp = self.exons
        if self.rev:
            temp = [(b, a) for a, b in temp[::-1]]
        return temp

    @property
    def introns(self):
        '''
        :return: a list of tuples where each first position is a boundary of the first intron, and the second position is the boundary of the end of the intron
        '''
        return list(zip([v for v in self.donors if v != self.transcript_end],
                        [v for v in self.acceptors if v != self.transcript_start]))

    @property
    def introns_pos(self):
        # Introns in positive-strand (genomic) orientation.
        temp = self.introns
        if self.rev:
            temp = [(b, a) for a, b in temp[::-1]]
        return temp


    def reset_acceptors(self, acceptors):
        '''
        :param acceptors: replacement acceptor list; call organize() afterwards
        :return: itself
        '''
        self.acceptors = acceptors
        return self

    def reset_donors(self, donors):
        '''
        :param donors: replacement donor list; call organize() afterwards
        :return: itself
        '''
        self.donors = donors
        return self

    def reset_transcription_start(self, pos):
        '''
        :param pos: replacement transcription start; call organize() afterwards
        :return: itself
        '''
        # NOTE(review): sets `transcription_start`, but the rest of the class
        # reads `transcript_start` — confirm this attribute name is intended.
        self.transcription_start = pos
        return self


    def reset_transcription_end(self, pos):
        '''
        :param pos: replacement transcription end; call organize() afterwards
        :return: itself
        '''
        # NOTE(review): sets `transcription_end`, but the rest of the class
        # reads `transcript_end` — confirm this attribute name is intended.
        self.transcription_end = pos
        return self

    def organize(self):
        '''
        In the case that transcript boundaries or exon boundaries are changed,
        this needs to be run to ensure the blueprints are ordered and the mRNA
        is regenerated.
        :return: itself
        :raises ValueError: when the splice-site bookkeeping is inconsistent
        '''
        self.__arrange_boundaries().generate_mature_mrna(inplace=True)
        self.transcript_upper, self.transcript_lower = max(self.transcript_start, self.transcript_end), min(self.transcript_start, self.transcript_end)

        if self.__exon_intron_matchup_flag():
            raise ValueError(f"Unequal number of acceptors and donors.")
        if self.__exon_intron_order_flag():
            raise ValueError(f"Exons / intron order out of position.")
        if self.__transcript_boundary_flag():
            raise ValueError(f"Transcript boundaries must straddle acceptors and donors.")
        return self

    def __arrange_boundaries(self):
        # De-duplicate and order splice sites in sense orientation
        # (descending genomic position for reverse-strand transcripts).
        self.acceptors = list(set(self.acceptors))
        self.donors = list(set(self.donors))
        self.acceptors.sort(reverse=self.rev)
        self.donors.sort(reverse=self.rev)
        return self


    def __exon_coverage_flag(self):
        # True when the summed exon lengths disagree with the mature mRNA length.
        # NOTE(review): currently unused (its call in organize() is disabled).
        if sum([abs(a - b) + 1 for a, b in self.exons]) != len(self):
            return True
        else:
            return False

    def __exon_intron_matchup_flag(self):
        # True when acceptors and donors are unbalanced.
        if len(self.acceptors) != len(self.donors):
            return True
        else:
            return False
    def __exon_intron_order_flag(self):
        # True when an exon's start exceeds its end in genomic orientation.
        # NOTE(review): the `else: return False` exits on the FIRST iteration,
        # so only the first exon is ever checked — confirm whether the loop
        # was meant to inspect every exon.
        for b in self.exons_pos:
            if b[0] > b[1]:
                return True
            else:
                return False
    def __transcript_boundary_flag(self):
        # True when any splice site falls outside the transcript span.
        if len(self.acceptors) == 0 and len(self.donors) == 0:
            return False

        if self.transcript_lower > min(self.acceptors + self.donors) or self.transcript_upper < max(self.acceptors + self.donors):
            return True
        else:
            return False


    @property
    def exonic_indices(self):
        # Flat list of every genomic position covered by an exon.
        return [lst for lsts in [list(range(a, b + 1)) for a, b in self.exons_pos] for lst in lsts]


    # Related to transcript seq generation
    def pull_pre_mrna_pos(self):
        """Read the transcript's genomic span from the per-chromosome FASTA.

        Returns (sequence, indices) in positive-strand orientation.
        """
        fasta_obj = Fasta_segment()
        return fasta_obj.read_segment_endpoints(config_setup[self.organism]['CHROM_SOURCE'] / f'chr{self.chrm}.fasta',
                                                self.transcript_lower,
                                                self.transcript_upper)

    def generate_pre_mrna_pos(self):
        # *_pos functions do not set values into the object.
        seq, indices = self.pull_pre_mrna_pos()
        if self.variations:
            # Apply each variant in position order to the sequence/index track.
            for mutation in self.variations.variants:
                seq, indices = generate_mut_variant(seq, indices, mut=mutation)
        return seq, indices

    def generate_pre_mrna(self, inplace=True):
        """Materialize the (possibly mutated) pre-mRNA in sense orientation."""
        pre_mrna, pre_indices = self.__pos2sense(*self.generate_pre_mrna_pos())
        self.pre_mrna = pre_mrna
        self.pre_indices = pre_indices
        if inplace:
            return self
        return pre_mrna, pre_indices

    def __pos2sense(self, mrna, indices):
        # Positive-strand -> sense orientation (reverse-complement when rev).
        if self.rev:
            mrna = reverse_complement(mrna)
            indices = indices[::-1]
        return mrna, indices

    def __sense2pos(self, mrna, indices):
        # Sense -> positive-strand orientation (the transform is its own inverse).
        if self.rev:
            mrna = reverse_complement(mrna)
            indices = indices[::-1]
        return mrna, indices

    def generate_mature_mrna_pos(self, reset=True):
        """Splice the pre-mRNA down to exons, in positive-strand orientation.

        :param reset: regenerate the pre-mRNA first; otherwise reuse the cached one
        """
        if reset:
            pre_seq_pos, pre_indices_pos = self.generate_pre_mrna_pos()
            self.pre_mrna, self.pre_indices = self.__pos2sense(pre_seq_pos, pre_indices_pos)
        else:
            pre_seq_pos, pre_indices_pos = self.__sense2pos(self.pre_mrna, self.pre_indices)

        # Concatenate each exon's slice of sequence and indices.
        mature_mrna_pos, mature_indices_pos = '', []
        for i, j in self.exons_pos:
            rel_start, rel_end = pre_indices_pos.index(i), pre_indices_pos.index(j)
            mature_mrna_pos += pre_seq_pos[rel_start:rel_end + 1]
            mature_indices_pos.extend(pre_indices_pos[rel_start:rel_end + 1])
        return mature_mrna_pos, mature_indices_pos

    def generate_mature_mrna(self, inplace=True):
        """Build the mature mRNA (sense orientation); store it when inplace."""
        if inplace:
            self.transcript_seq, self.transcript_indices = self.__pos2sense(*self.generate_mature_mrna_pos())
            return self
        return self.__pos2sense(*self.generate_mature_mrna_pos())

    def generate_protein(self, inplace=True, reset=True):
        """Translate from the TIS to the first in-frame stop codon.

        :param inplace: store orf/protein on the instance and return self;
                        otherwise return the protein string
        :param reset: regenerate the mature mRNA first
        :return: self (inplace) or the protein string; '' when no usable TIS
        """
        if reset:
            self.generate_mature_mrna()

        if not self.TIS or self.TIS not in self.transcript_indices:
            return ''

        rel_start = self.transcript_indices.index(self.TIS)
        orf = self.transcript_seq[rel_start:]
        # First in-frame stop codon; fall back to the final full codon, then
        # snap the fallback onto the reading frame.
        first_stop_index = next((i for i in range(0, len(orf) - 2, 3) if orf[i:i + 3] in {"TAG", "TAA", "TGA"}), len(orf)-3)
        while first_stop_index % 3 != 0:
            first_stop_index -= 1

        orf = orf[:first_stop_index + 3]
        protein = str(Seq(orf).translate()).replace('*', '')
        if inplace:
            self.orf = orf
            self.protein = protein
            # Conservation data no longer applies if the protein changed.
            if self.protein != self.cons_seq:
                self.cons_available = False
            return self
        return protein
594
-
595
-
596
-
597
- ## Missplicing construction
598
def develop_aberrant_splicing(transcript, aberrant_splicing):
    '''
    Enumerate plausible exon chains (splice paths) for a transcript given
    predicted missplicing events.

    :param transcript: Transcript whose annotated acceptors/donors seed the graph
    :param aberrant_splicing: dict with 'missed_acceptors', 'discovered_acceptors',
           'missed_donors', 'discovered_donors', each mapping position ->
           {'absolute': prob, ...}
    :return: list of dicts with keys 'acceptors', 'donors', 'path_weight'
             (weights normalized to sum ~1; paths below 0.01 are dropped)
    '''
    # Candidate exon starts: annotated acceptors and the transcript start get
    # weight 1; missed/discovered sites carry their predicted probabilities
    # (a missed annotated site overwrites its weight-1 entry).
    exon_starts = {v: 1 for v in transcript.acceptors}
    exon_starts.update({transcript.transcript_start: 1})
    exon_starts.update({s: v['absolute'] for s, v in aberrant_splicing['missed_acceptors'].items()})
    exon_starts.update({s: v['absolute'] for s, v in aberrant_splicing['discovered_acceptors'].items()})

    # Candidate exon ends, built the same way from donors.
    exon_ends = {v: 1 for v in transcript.donors}
    exon_ends.update({transcript.transcript_end: 1})
    exon_ends.update({s: v['absolute'] for s, v in aberrant_splicing['missed_donors'].items()})
    exon_ends.update({s: v['absolute'] for s, v in aberrant_splicing['discovered_donors'].items()})

    # ss_type 0 = donor/end, ss_type 1 = acceptor/start.
    nodes = [SpliceSite(pos=pos, ss_type=0, prob=prob) for pos, prob in exon_ends.items()] + \
            [SpliceSite(pos=pos, ss_type=1, prob=prob) for pos, prob in exon_starts.items()]

    # Keep only sites with non-zero probability, ordered in sense direction.
    nodes = [s for s in nodes if s.prob > 0]
    nodes.sort(key=lambda x: x.pos, reverse=transcript.rev)

    G = nx.DiGraph()
    G.add_nodes_from([n.pos for n in nodes])

    # Connect each site to downstream sites of the opposite type; once a
    # same-type site has been passed ("spread"), later edges are discounted
    # by the probability mass already consumed and dropped at <= 0.
    for i in range(len(nodes)):
        trailing_prob, in_between = 0, []
        for j in range(i + 1, len(nodes)):
            curr_node, next_node = nodes[i], nodes[j]
            spread = curr_node.ss_type in in_between
            in_between.append(next_node.ss_type)
            if curr_node.ss_type != next_node.ss_type:
                if spread:
                    new_prob = next_node.prob - trailing_prob
                    if new_prob <= 0:
                        break
                    G.add_edge(curr_node.pos, next_node.pos)
                    G.edges[curr_node.pos, next_node.pos]['weight'] = new_prob
                    trailing_prob += next_node.prob
                else:
                    G.add_edge(curr_node.pos, next_node.pos)
                    G.edges[curr_node.pos, next_node.pos]['weight'] = next_node.prob
                    trailing_prob += next_node.prob

    # Score every simple path from transcript start to end. The second loop
    # covers the reverse-strand direction (end -> start).
    # NOTE(review): both loops reuse the same integer keys `i`, so if BOTH
    # directions ever yield paths, the second loop silently overwrites the
    # first loop's entries — confirm only one direction is expected per strand.
    new_paths, prob_sum = {}, 0
    for i, path in enumerate(nx.all_simple_paths(G, transcript.transcript_start, transcript.transcript_end)):
        curr_prob = path_weight_mult(G, path, 'weight')
        prob_sum += curr_prob
        new_paths[i] = {
            'acceptors': sorted([p for p in path if p in exon_starts.keys() and p != transcript.transcript_start],
                                reverse=transcript.rev),
            'donors': sorted([p for p in path if p in exon_ends.keys() and p != transcript.transcript_end],
                             reverse=transcript.rev),
            'path_weight': curr_prob}

    for i, path in enumerate(nx.all_simple_paths(G, transcript.transcript_end, transcript.transcript_start)):
        curr_prob = path_weight_mult(G, path, 'weight')
        prob_sum += curr_prob
        new_paths[i] = {
            'acceptors': sorted([p for p in path if p in exon_starts.keys() and p != transcript.transcript_start],
                                reverse=transcript.rev),
            'donors': sorted([p for p in path if p in exon_ends.keys() and p != transcript.transcript_end],
                             reverse=transcript.rev),
            'path_weight': curr_prob}


    # Normalize weights and drop negligible isoforms.
    for i, d in new_paths.items():
        d['path_weight'] = round(d['path_weight'] / prob_sum, 3)
    new_paths = {k: v for k, v in new_paths.items() if v['path_weight'] > 0.01}
    return list(new_paths.values())
663
-
664
-
665
def path_weight_mult(G, path, weight):
    """Multiplicative analogue of networkx's additive path weight.

    :param G: (multi)graph whose edges carry the `weight` attribute
    :param path: sequence of nodes forming a path in G
    :param weight: edge-attribute name to multiply along the path
    :return: product of edge weights (for multigraphs, the minimum parallel
             edge weight is used at each hop)
    :raises nx.NetworkXNoPath: when `path` is not a valid path in G
    """
    if not nx.is_path(G, path):
        raise nx.NetworkXNoPath("path does not exist")
    total = 1
    if G.is_multigraph():
        for u, v in nx.utils.pairwise(path):
            total *= min(attrs[weight] for attrs in G[u][v].values())
    else:
        for u, v in nx.utils.pairwise(path):
            total *= G[u][v][weight]
    return total
676
-
677
@dataclass
class SpliceSite(object):
    """A candidate splice site.

    pos: genomic position
    ss_type: 0 for a donor / exon end, 1 for an acceptor / exon start
    prob: predicted probability of the site being used
    """
    pos: int
    ss_type: int
    prob: float

    def __post_init__(self):
        # No derived fields; kept for future validation hooks.
        pass

    def __lt__(self, other):
        # Order sites by genomic position only.
        return self.pos < other.pos

    def __str__(self):
        # BUGFIX: previously this printed and implicitly returned None, so
        # str(site) raised TypeError. Return the formatted string instead.
        return f"({self.ss_type}, {self.pos}, {self.prob})"
691
-
692
-
693
- # Missplicing Detection
694
def find_ss_changes(ref_dct, mut_dct, known_splice_sites, threshold=0.5):
    '''
    Compare per-position splice probabilities between a reference and a
    mutated sequence.

    :param ref_dct: position -> probability for the reference sequence
    :param mut_dct: position -> probability for the mutated sequence
    :param known_splice_sites: positions that are annotated splice sites
    :param threshold: minimum |delta| between mutated and reference probability
    :return: (discovered_pos, deleted_pos) — novel sites whose probability rose
             by >= threshold, and known sites whose probability fell by >=
             threshold; each maps position -> {'delta', 'absolute'} rounded to
             3 decimals
    '''
    discovered_pos, deleted_pos = {}, {}
    for pos in set(ref_dct) | set(mut_dct):
        delta = mut_dct.get(pos, 0) - ref_dct.get(pos, 0)
        record = {'delta': round(float(delta), 3),
                  'absolute': round(float(mut_dct.get(pos, 0)), 3)}
        if delta >= threshold and pos not in known_splice_sites:
            # Probability gained at an unannotated position: a novel site.
            discovered_pos[pos] = record
        elif -delta >= threshold and pos in known_splice_sites:
            # Probability lost at an annotated position: a missed site.
            deleted_pos[pos] = record

    return discovered_pos, deleted_pos
714
-
715
def run_spliceai_seq(seq, indices, threshold=0):
    """Score a sequence with the SpliceAI ensemble.

    :param seq: nucleotide sequence (5000 'N' flanks are added internally)
    :param indices: genomic position of each base in seq
    :param threshold: minimum probability for a position to be reported
    :return: (acceptor_indices, donor_indices), each mapping position -> probability
    """
    padded = 'N' * 5000 + seq + 'N' * 5000
    probs = sai_predict_probs(padded, sai_models)
    acceptor_indices, donor_indices = {}, {}
    # Row 0 holds acceptor probabilities, row 1 donor probabilities.
    for pos, acc_p, don_p in zip(indices, probs[0, :], probs[1, :]):
        if acc_p >= threshold:
            acceptor_indices[pos] = acc_p
        if don_p >= threshold:
            donor_indices[pos] = don_p
    return acceptor_indices, donor_indices
722
-
723
-
724
def pang_one_hot_encode(seq):
    """One-hot encode a nucleotide sequence for Pangolin.

    :param seq: string over A/C/G/T/N (case-insensitive)
    :return: (len(seq), 4) array; A/C/G/T map to unit rows, N to all zeros
    """
    IN_MAP = np.asarray([[0, 0, 0, 0],
                         [1, 0, 0, 0],
                         [0, 1, 0, 0],
                         [0, 0, 1, 0],
                         [0, 0, 0, 1]])
    # Translate bases to row indices in a single pass ('N' -> 0, the zero row).
    digits = seq.upper().translate(str.maketrans('ACGTN', '12340'))
    codes = np.asarray([int(d) for d in digits])
    return IN_MAP[codes.astype('int8')]
734
-
735
-
736
def get_pos_seq_indices(t):
    """Return a transcript's pre-mRNA and index track in positive-strand
    (genomic) orientation.

    :param t: object exposing pre_mrna, pre_indices, and rev
    :return: (sequence, indices) — reverse-complemented/reversed when t.rev
    """
    if not t.rev:
        return t.pre_mrna, t.pre_indices
    return reverse_complement(t.pre_mrna), t.pre_indices[::-1]
742
-
743
-
744
def pangolin_predict_probs(true_seq, models):
    '''
    Score every position of a sequence with the Pangolin ensemble.

    :param true_seq: nucleotide sequence (5000 'N' flanks are added internally)
    :param models: flat list of Pangolin models, 5 replicates per head,
                   ordered to match model_nums below
    :return: (donor_probs, acceptor_probs) lists of per-position scores
    '''
    model_nums = [0, 2, 4, 6]
    # Maps head number -> row of the model's output tensor to read.
    INDEX_MAP = {0: 1, 1: 2, 2: 4, 3: 5, 4: 7, 5: 8, 6: 10, 7: 11}

    seq = 'N'*5000 + true_seq + 'N'*5000
    # A position only counts as a site when it carries the canonical
    # dinucleotide: AG immediately upstream (acceptor), GT downstream (donor).
    acceptor_dinucleotide = np.array([true_seq[i - 2:i] == 'AG' for i in range(len(true_seq))])
    donor_dinucleotide = np.array([true_seq[i + 1:i + 3] == 'GT' for i in range(len(true_seq))])

    # (L, 4) one-hot -> (1, 4, L) float tensor.
    seq = pang_one_hot_encode(seq).T
    seq = torch.from_numpy(np.expand_dims(seq, axis=0)).float()

    if torch.cuda.is_available():
        seq = seq.to(torch.device("cuda"))

    scores = []
    for j, model_num in enumerate(model_nums):
        score = []
        # Average across 5 models
        for model in models[5 * j:5 * j + 5]:
            with torch.no_grad():
                score.append(model(seq)[0][INDEX_MAP[model_num], :].cpu().numpy())

        scores.append(np.mean(score, axis=0))

    # Elementwise max across heads.
    splicing_pred = np.array(scores).max(axis=0)
    donor_probs = [splicing_pred[i] * donor_dinucleotide[i] for i in range(len(true_seq))]
    acceptor_probs = [splicing_pred[i] * acceptor_dinucleotide[i] for i in range(len(true_seq))]
    # NOTE(review): the lists above already have length len(true_seq), yet the
    # return slices off 5000 elements from each end — this assumes the model
    # output (and hence splicing_pred) spans the padded sequence. Confirm the
    # Pangolin output length; if it equals len(true_seq) these slices
    # mis-trim the results.
    return donor_probs[5000:-5000], acceptor_probs[5000:-5000]
773
-
774
-
775
- def find_transcript_missplicing(mutations, ref_transcript, var_transcript, context=7500, threshold=0.5,
776
- engine='spliceai'):
777
- positions = mutations.positions
778
- end_positions = [m.start + len(m.ref) for m in mutations.variants]
779
- positions.extend(end_positions)
780
- center = int(np.mean(positions) // 1)
781
-
782
- seq_start_pos, seq_end_pos = center - context, center + context
783
- transcript_start, transcript_end, rev = ref_transcript.transcript_lower, ref_transcript.transcript_upper, ref_transcript.rev
784
-
785
- # Generate reference sequence data
786
- ref_seq, ref_indices = get_pos_seq_indices(ref_transcript)
787
- center_index = ref_indices.index(center)
788
- start_cutoff = ref_indices.index(seq_start_pos) if seq_start_pos in ref_indices else 0
789
- end_cutoff = ref_indices.index(seq_end_pos) if seq_end_pos in ref_indices else len(ref_indices)
790
- start_pad, end_pad = max(0, context - (center_index - start_cutoff)), max(0, context - (end_cutoff - center_index))
791
- ref_seq = 'N' * start_pad + ref_seq[start_cutoff:end_cutoff] + 'N' * end_pad
792
- ref_indices = [-1] * start_pad + ref_indices[start_cutoff:end_cutoff] + [-1] * end_pad
793
-
794
- # Generate mutation sequence data
795
- mut_seq, mut_indices = get_pos_seq_indices(var_transcript)
796
- start_cutoff = mut_indices.index(seq_start_pos) if seq_start_pos in mut_indices else 0
797
- end_cutoff = mut_indices.index(seq_end_pos) if seq_end_pos in mut_indices else len(mut_indices)
798
- start_pad, end_pad = max(0, context - (center_index - start_cutoff)), max(0, context - (end_cutoff - center_index))
799
- mut_seq = 'N' * start_pad + mut_seq[start_cutoff:end_cutoff] + 'N' * end_pad
800
- mut_indices = [-1] * start_pad + mut_indices[start_cutoff:end_cutoff] + [-1] * end_pad
801
- # print(f"Mut and Ref are equal: {mut_seq == ref_seq}")
802
-
803
- copy_mut_indices = mut_indices.copy()
804
-
805
- if rev:
806
- ref_seq = reverse_complement(ref_seq)
807
- mut_seq = reverse_complement(mut_seq)
808
- ref_indices = ref_indices[::-1]
809
- mut_indices = mut_indices[::-1]
810
-
811
- if engine == 'spliceai':
812
- ref_seq_probs_temp = sai_predict_probs(ref_seq, sai_models)
813
- mut_seq_probs_temp = sai_predict_probs(mut_seq, sai_models)
814
- ref_seq_acceptor_probs, ref_seq_donor_probs = ref_seq_probs_temp[0, :], ref_seq_probs_temp[1, :]
815
- mut_seq_acceptor_probs, mut_seq_donor_probs = mut_seq_probs_temp[0, :], mut_seq_probs_temp[1, :]
816
- ref_indices, mut_indices = ref_indices[5000:-5000], mut_indices[5000:-5000]
817
-
818
-
819
- elif engine == 'pangolin':
820
- ref_seq_donor_probs, ref_seq_acceptor_probs = pangolin_predict_probs(ref_seq, models=pang_models)
821
- mut_seq_donor_probs, mut_seq_acceptor_probs = pangolin_predict_probs(mut_seq, models=pang_models)
822
- ref_indices, mut_indices = ref_indices[5000:-5000], mut_indices[5000:-5000]
823
-
824
- else:
825
- raise ValueError(f"{engine} not implemented")
826
-
827
- visible_donors = np.intersect1d(ref_transcript.donors, ref_indices)
828
- visible_acceptors = np.intersect1d(ref_transcript.acceptors, ref_indices)
829
- # print(ref_indices.index(visible_donors[0]), ref_seq_donor_probs[ref_indices.index(visible_donors[0])], mut_seq_donor_probs[mut_indices.index(visible_donors[0])])
830
-
831
- # print(len(ref_seq_donor_probs), len(ref_seq_acceptor_probs), len(mut_seq_donor_probs), len(mut_seq_acceptor_probs), len(ref_indices), len(mut_indices))
832
- # print(ref_seq_donor_probs)
833
-
834
- assert len(ref_indices) == len(ref_seq_acceptor_probs), 'Reference pos not the same'
835
- assert len(mut_indices) == len(mut_seq_acceptor_probs), 'Mut pos not the same'
836
-
837
- iap, dap = find_ss_changes({p: v for p, v in list(zip(ref_indices, ref_seq_acceptor_probs))},
838
- {p: v for p, v in list(zip(mut_indices, mut_seq_acceptor_probs))},
839
- visible_acceptors,
840
- threshold=threshold)
841
-
842
- assert len(ref_indices) == len(ref_seq_donor_probs), 'Reference pos not the same'
843
- assert len(mut_indices) == len(mut_seq_donor_probs), 'Mut pos not the same'
844
-
845
- idp, ddp = find_ss_changes({p: v for p, v in list(zip(ref_indices, ref_seq_donor_probs))},
846
- {p: v for p, v in list(zip(mut_indices, mut_seq_donor_probs))},
847
- visible_donors,
848
- threshold=threshold)
849
-
850
- ref_acceptors = {a: b for a, b in list(zip(ref_indices, ref_seq_acceptor_probs))}
851
- ref_donors = {a: b for a, b in list(zip(ref_indices, ref_seq_donor_probs))}
852
-
853
- lost_acceptors = {int(p): {'absolute': np.float64(0), 'delta': round(float(-ref_acceptors[p]), 3)} for p in
854
- visible_acceptors if p not in mut_indices and p not in dap}
855
- lost_donors = {int(p): {'absolute': np.float64(0), 'delta': round(float(-ref_donors[p]), 3)} for p in visible_donors
856
- if p not in mut_indices and p not in ddp}
857
- dap.update(lost_acceptors)
858
- ddp.update(lost_donors)
859
-
860
- missplicing = {'missed_acceptors': dap, 'missed_donors': ddp, 'discovered_acceptors': iap, 'discovered_donors': idp}
861
- missplicing = {outk: {float(k): v for k, v in outv.items()} for outk, outv in missplicing.items()}
862
- temp = {outk: {int(k) if k.is_integer() else k: v for k, v in outv.items()} for outk, outv in missplicing.items()}
863
- # print(temp)
864
- return temp
865
-
866
-
867
- # def run_spliceai_transcript(mutations, transcript_data, sai_mrg_context=5000, min_coverage=2500, sai_threshold=0.5, engine='spliceai'):
868
- # positions = mutations.positions
869
- # end_positions = [m.start + len(m.ref) for m in mutations.variants]
870
- # positions.extend(end_positions)
871
- #
872
- # seq_start_pos = min(positions) - sai_mrg_context - min_coverage
873
- # seq_end_pos = max(positions) + sai_mrg_context + min_coverage
874
- #
875
- # fasta_obj = Fasta_segment()
876
- # ref_seq, ref_indices = fasta_obj.read_segment_endpoints(
877
- # config_setup[transcript_data.organism]['CHROM_SOURCE'] / f'chr{mutations.chrom}.fasta',
878
- # seq_start_pos,
879
- # seq_end_pos)
880
- #
881
- # transcript_start, transcript_end, rev = transcript_data.transcript_lower, transcript_data.transcript_upper, transcript_data.rev
882
- #
883
- # # visible_donors = np.intersect1d(transcript_data.donors, ref_indices)
884
- # # visible_acceptors = np.intersect1d(transcript_data.acceptors, ref_indices)
885
- #
886
- # start_pad = ref_indices.index(transcript_start) if transcript_start in ref_indices else 0
887
- # end_cutoff = ref_indices.index(transcript_end) if transcript_end in ref_indices else len(ref_indices)
888
- # end_pad = len(ref_indices) - end_cutoff
889
- # ref_seq = 'N' * start_pad + ref_seq[start_pad:end_cutoff] + 'N' * end_pad
890
- # ref_indices = [-1] * start_pad + ref_indices[start_pad:end_cutoff] + [-1] * end_pad
891
- # mut_seq, mut_indices = ref_seq, ref_indices
892
- #
893
- # for mut in mutations:
894
- # mut_seq, mut_indices = generate_mut_variant(seq=mut_seq, indices=mut_indices, mut=mut)
895
- #
896
- # ref_indices = ref_indices[sai_mrg_context:-sai_mrg_context]
897
- # mut_indices = mut_indices[sai_mrg_context:-sai_mrg_context]
898
- # copy_mut_indices = mut_indices.copy()
899
- #
900
- # visible_donors = np.intersect1d(transcript_data.donors, ref_indices)
901
- # visible_acceptors = np.intersect1d(transcript_data.acceptors, ref_indices)
902
- #
903
- # if rev:
904
- # ref_seq = reverse_complement(ref_seq)
905
- # mut_seq = reverse_complement(mut_seq)
906
- # ref_indices = ref_indices[::-1]
907
- # mut_indices = mut_indices[::-1]
908
- #
909
- # if engine == 'spliceai':
910
- # ref_seq_probs_temp = sai_predict_probs(ref_seq, sai_models)
911
- # mut_seq_probs_temp = sai_predict_probs(mut_seq, sai_models)
912
- # ref_seq_acceptor_probs, ref_seq_donor_probs = ref_seq_probs_temp[0, :], ref_seq_probs_temp[1, :]
913
- # mut_seq_acceptor_probs, mut_seq_donor_probs = mut_seq_probs_temp[0, :], mut_seq_probs_temp[1, :]
914
- #
915
- # elif engine == 'pangolin':
916
- # ref_seq_donor_probs, ref_seq_acceptor_probs = pangolin_predict_probs(ref_seq, pangolin_models=pang_models)
917
- # mut_seq_donor_probs, mut_seq_acceptor_probs = pangolin_predict_probs(mut_seq, pangolin_models=pang_models)
918
- #
919
- # else:
920
- # raise ValueError(f"{engine} not implemented")
921
- #
922
- # assert len(ref_indices) == len(ref_seq_acceptor_probs), 'Reference pos not the same'
923
- # assert len(mut_indices) == len(mut_seq_acceptor_probs), 'Mut pos not the same'
924
- #
925
- # iap, dap = find_ss_changes({p: v for p, v in list(zip(ref_indices, ref_seq_acceptor_probs))},
926
- # {p: v for p, v in list(zip(mut_indices, mut_seq_acceptor_probs))},
927
- # visible_acceptors,
928
- # threshold=sai_threshold)
929
- #
930
- # assert len(ref_indices) == len(ref_seq_donor_probs), 'Reference pos not the same'
931
- # assert len(mut_indices) == len(mut_seq_donor_probs), 'Mut pos not the same'
932
- #
933
- # idp, ddp = find_ss_changes({p: v for p, v in list(zip(ref_indices, ref_seq_donor_probs))},
934
- # {p: v for p, v in list(zip(mut_indices, mut_seq_donor_probs))},
935
- # visible_donors,
936
- # threshold=sai_threshold)
937
- #
938
- # ref_acceptors = {a: b for a, b in list(zip(ref_indices, ref_seq_acceptor_probs))}
939
- # ref_donors = {a: b for a, b in list(zip(ref_indices, ref_seq_donor_probs))}
940
- #
941
- # lost_acceptors = {int(p): {'absolute': np.float64(0), 'delta': round(float(-ref_acceptors[p]), 3)} for p in visible_acceptors if p not in mut_indices and p not in dap}
942
- # lost_donors = {int(p): {'absolute': np.float64(0), 'delta': round(float(-ref_donors[p]), 3)} for p in visible_donors if p not in mut_indices and p not in ddp}
943
- # dap.update(lost_acceptors)
944
- # ddp.update(lost_donors)
945
- #
946
- # missplicing = {'missed_acceptors': dap, 'missed_donors': ddp, 'discovered_acceptors': iap, 'discovered_donors': idp}
947
- # missplicing = {outk: {float(k): v for k, v in outv.items()} for outk, outv in missplicing.items()}
948
- # return {outk: {int(k) if k.is_integer() else k: v for k, v in outv.items()} for outk, outv in missplicing.items()}
949
-
950
-
951
- # def run_spliceai(mutations, gene_data, sai_mrg_context=5000, min_coverage=2500, sai_threshold=0.5):
952
- # positions = mutations.positions
953
- # seq_start_pos = min(positions) - sai_mrg_context - min_coverage
954
- # seq_end_pos = max(positions) + sai_mrg_context + min_coverage
955
- #
956
- # fasta_obj = Fasta_segment()
957
- # ref_seq, ref_indices = fasta_obj.read_segment_endpoints(
958
- # config_setup['CHROM_SOURCE'] / f'chr{mutations.chrom}.fasta',
959
- # seq_start_pos,
960
- # seq_end_pos)
961
- #
962
- # gene_start, gene_end, rev = gene_data.gene_start, gene_data.gene_end, gene_data.rev
963
- #
964
- # mrna_acceptors = sorted(list(set([lst for lsts in
965
- # [mrna.get('acceptors', []) for mrna in gene_data.transcripts.values() if
966
- # mrna['transcript_biotype'] == 'protein_coding'] for lst in lsts])))
967
- #
968
- # mrna_donors = sorted(list(set([lst for lsts in
969
- # [mrna.get('donors', []) for mrna in gene_data.transcripts.values() if
970
- # mrna['transcript_biotype'] == 'protein_coding'] for lst in lsts])))
971
- #
972
- # visible_donors = np.intersect1d(mrna_donors, ref_indices)
973
- # visible_acceptors = np.intersect1d(mrna_acceptors, ref_indices)
974
- #
975
- # start_pad = ref_indices.index(gene_start) if gene_start in ref_indices else 0
976
- # end_cutoff = ref_indices.index(gene_end) if gene_end in ref_indices else len(ref_indices) # - 1
977
- # end_pad = len(ref_indices) - end_cutoff
978
- # ref_seq = 'N' * start_pad + ref_seq[start_pad:end_cutoff] + 'N' * end_pad
979
- # ref_indices = [-1] * start_pad + ref_indices[start_pad:end_cutoff] + [-1] * end_pad
980
- # mut_seq, mut_indices = ref_seq, ref_indices
981
- #
982
- # for mut in mutations:
983
- # mut_seq, mut_indices = generate_mut_variant(seq=mut_seq, indices=mut_indices, mut=mut)
984
- #
985
- # ref_indices = ref_indices[sai_mrg_context:-sai_mrg_context]
986
- # mut_indices = mut_indices[sai_mrg_context:-sai_mrg_context]
987
- #
988
- # copy_mut_indices = mut_indices.copy()
989
- # if rev:
990
- # ref_seq = reverse_complement(ref_seq)
991
- # mut_seq = reverse_complement(mut_seq)
992
- # ref_indices = ref_indices[::-1]
993
- # mut_indices = mut_indices[::-1]
994
- #
995
- # ref_seq_probs_temp = sai_predict_probs(ref_seq, sai_models)
996
- # mut_seq_probs_temp = sai_predict_probs(mut_seq, sai_models)
997
- #
998
- # ref_seq_acceptor_probs, ref_seq_donor_probs = ref_seq_probs_temp[0, :], ref_seq_probs_temp[1, :]
999
- # mut_seq_acceptor_probs, mut_seq_donor_probs = mut_seq_probs_temp[0, :], mut_seq_probs_temp[1, :]
1000
- #
1001
- # assert len(ref_indices) == len(ref_seq_acceptor_probs), 'Reference pos not the same'
1002
- # assert len(mut_indices) == len(mut_seq_acceptor_probs), 'Mut pos not the same'
1003
- #
1004
- # iap, dap = find_ss_changes({p: v for p, v in list(zip(ref_indices, ref_seq_acceptor_probs))},
1005
- # {p: v for p, v in list(zip(mut_indices, mut_seq_acceptor_probs))},
1006
- # visible_acceptors,
1007
- # threshold=sai_threshold)
1008
- #
1009
- # assert len(ref_indices) == len(ref_seq_donor_probs), 'Reference pos not the same'
1010
- # assert len(mut_indices) == len(mut_seq_donor_probs), 'Mut pos not the same'
1011
- #
1012
- # idp, ddp = find_ss_changes({p: v for p, v in list(zip(ref_indices, ref_seq_donor_probs))},
1013
- # {p: v for p, v in list(zip(mut_indices, mut_seq_donor_probs))},
1014
- # visible_donors,
1015
- # threshold=sai_threshold)
1016
- #
1017
- # # lost_acceptors = {p: {'absolute': 0, 'delta': -1} for p in gene_data.acceptors if not contains(copy_mut_indices, p)}
1018
- # # lost_donors = {p: {'absolute': 0, 'delta': -1} for p in gene_data.donors if not contains(copy_mut_indices, p)}
1019
- # # dap.update(lost_acceptors)
1020
- # # ddp.update(lost_donors)
1021
- # missplicing = {'missed_acceptors': dap, 'missed_donors': ddp, 'discovered_acceptors': iap, 'discovered_donors': idp}
1022
- # missplicing = {outk: {float(k): v for k, v in outv.items()} for outk, outv in missplicing.items()}
1023
- #
1024
- # return {outk: {int(k) if k.is_integer() else k: v for k, v in outv.items()} for outk, outv in missplicing.items()}
1025
-
1026
-
1027
- class PredictSpliceAI:
1028
- def __init__(self, mutation, gene_data,
1029
- threshold=0.5, force=False, save_results=False, sai_mrg_context=5000, min_coverage=2500, engine='spliceai', organism='hg38'):
1030
- self.modification = mutation
1031
- self.threshold = threshold
1032
- self.transcript_id = gene_data.transcript_id
1033
- self.spliceai_db = config_setup[gene_data.organism]['MISSPLICING_PATH'] / f'spliceai_epistatic'
1034
- self.missplicing = {}
1035
-
1036
- if self.prediction_file_exists() and not force: # need to do a check for the filename length
1037
- self.missplicing = self.load_sai_predictions()
1038
-
1039
- if not self.missplicing:
1040
- # else:
1041
- # if isinstance(gene_data, Gene):
1042
- # self.missplicing = run_spliceai(self.modification, gene_data=gene_data, sai_mrg_context=sai_mrg_context, min_coverage=min_coverage, sai_threshold=0.1)
1043
- # if save_results:
1044
- # self.save_sai_predictions()
1045
- #
1046
- # elif isinstance(gene_data, Transcript):
1047
-
1048
- # self.missplicing = run_spliceai_transcript(self.modification, transcript_data=gene_data, sai_mrg_context=sai_mrg_context, min_coverage=min_coverage, sai_threshold=0.1)
1049
- # print(f"RUNNING: {mutation.mut_id}")
1050
- ref_transcript, var_transcript = Gene(mutation.mut_id.split(':')[0], organism=organism).transcript(gene_data.transcript_id), Gene(mutation.mut_id.split(':')[0], mutation.mut_id, organism=organism).transcript(gene_data.transcript_id)
1051
- # print(f"Second check : {ref_transcript.pre_mrna == var_transcript.pre_mrna}")
1052
- self.missplicing = find_transcript_missplicing(self.modification, ref_transcript, var_transcript, context=sai_mrg_context+min_coverage, threshold=threshold,
1053
- engine=engine)
1054
- if save_results:
1055
- self.save_sai_predictions()
1056
-
1057
-
1058
- def __repr__(self):
1059
- return f'Missplicing({self.modification.mut_id}) --> {self.missplicing}'
1060
-
1061
- def __str__(self):
1062
- return self.aberrant_splicing
1063
- def __bool__(self):
1064
- for event, details in self.aberrant_splicing.items():
1065
- if details:
1066
- return True
1067
- return False
1068
-
1069
- def __eq__(self, alt_splicing):
1070
- flag, _ = check_splicing_difference(self.missplicing, alt_splicing, self.threshold)
1071
- return not flag
1072
-
1073
- def __iter__(self):
1074
- penetrances = [abs(d_in['delta']) for d in self.missplicing.values() for d_in in d.values()] + [0]
1075
- return iter(penetrances)
1076
-
1077
- @property
1078
- def aberrant_splicing(self):
1079
- return self.apply_sai_threshold(self.missplicing, self.threshold)
1080
-
1081
- @property
1082
- def prediction_file(self):
1083
- return self.spliceai_db / self.modification.gene / self.modification.file_identifier_json
1084
-
1085
- def prediction_file_exists(self):
1086
- return self.prediction_file.exists()
1087
-
1088
- def load_sai_predictions(self):
1089
- missplicing = unload_json(self.prediction_file)
1090
- if self.transcript_id in missplicing:
1091
- missplicing = missplicing[self.transcript_id]
1092
- else:
1093
- return {}
1094
-
1095
- missplicing = {outk: {float(k): v for k, v in outv.items()} for outk, outv in missplicing.items()}
1096
- missplicing = {outk: {int(k) if k.is_integer() or 'missed' in outk else k: v for k, v in outv.items()} for
1097
- outk, outv in
1098
- missplicing.items()}
1099
- return missplicing
1100
-
1101
- def save_sai_predictions(self):
1102
- self.prediction_file.parent.mkdir(parents=True, exist_ok=True)
1103
- if self.prediction_file_exists():
1104
- missplicing = unload_json(self.prediction_file)
1105
- missplicing[self.transcript_id] = self.missplicing
1106
-
1107
- else:
1108
- missplicing = {self.transcript_id: self.missplicing}
1109
-
1110
- # print(missplicing)
1111
- dump_json(self.prediction_file, missplicing)
1112
-
1113
- def apply_sai_threshold(self, splicing_dict=None, threshold=None):
1114
- splicing_dict = self.missplicing if not splicing_dict else splicing_dict
1115
- threshold = self.threshold if not threshold else threshold
1116
- new_dict = {}
1117
- for event, details in splicing_dict.items():
1118
- for e, d in details.items():
1119
- if abs(d['delta']) >= threshold:
1120
- return splicing_dict
1121
- # new_dict[event] = {} #{k: v for k, v in details.items() if abs(v['delta']) >= threshold}
1122
- return new_dict
1123
-
1124
-
1125
- def apply_sai_threshold_primary(self, splicing_dict=None, threshold=None):
1126
- splicing_dict = self.missplicing if not splicing_dict else splicing_dict
1127
- threshold = self.threshold if not threshold else threshold
1128
- new_dict = {}
1129
- for event, details in splicing_dict.items():
1130
- new_dict_in = {}
1131
- for e, d in details.items():
1132
- if abs(d['delta']) >= threshold:
1133
- new_dict_in[e] = d
1134
- new_dict[event] = new_dict_in
1135
- return new_dict
1136
-
1137
- def get_max_missplicing_delta(self):
1138
- max_delta = 0
1139
- for event, details in self.missplicing.items():
1140
- for e, d in details.items():
1141
- if abs(d['delta']) > max_delta:
1142
- max_delta = abs(d['delta'])
1143
- return max_delta
1144
-
1145
-
1146
- def check_splicing_difference(missplicing1, missplicing2, threshold=None):
1147
- flag = False
1148
- true_differences = {}
1149
- for event in ['missed_acceptors', 'missed_donors']:
1150
- td = {}
1151
- dct1 = missplicing1[event]
1152
- dct2 = missplicing2[event]
1153
- for k in list(set(list(dct1.keys()) + list(dct2.keys()))):
1154
- diff = abs(dct1.get(k, {'delta': 0})['delta']) - abs(dct2.get(k, {'delta': 0})['delta'])
1155
- if abs(diff) >= threshold:
1156
- flag = True
1157
- td[k] = diff
1158
- true_differences[event] = td
1159
-
1160
- for event in ['discovered_acceptors', 'discovered_donors']:
1161
- td = {}
1162
- dct1 = missplicing1[event]
1163
- dct2 = missplicing2[event]
1164
- for k in list(set(list(dct1.keys()) + list(dct2.keys()))):
1165
- diff = abs(dct1.get(k, {'delta': 0})['delta']) - abs(dct2.get(k, {'delta': 0})['delta'])
1166
- if abs(diff) >= threshold:
1167
- flag = True
1168
- td[k] = diff
1169
- true_differences[event] = td
1170
-
1171
- return flag, true_differences
1172
-
1173
-
1174
- # Annotating
1175
- def OncospliceAnnotator(reference_transcript, variant_transcript, mut):
1176
- affected_exon, affected_intron, distance_from_5, distance_from_3 = find_splice_site_proximity(mut,
1177
- reference_transcript)
1178
-
1179
- report = {}
1180
- report['primary_transcript'] = reference_transcript.primary_transcript
1181
- report['transcript_id'] = reference_transcript.transcript_id
1182
- report['mut_id'] = mut.mut_id
1183
- report['cons_available'] = int(reference_transcript.cons_available)
1184
- report['protein_coding'] = reference_transcript.transcript_biotype
1185
-
1186
- report['reference_mrna'] = reference_transcript.transcript_seq
1187
- report['reference_cds_start'] = reference_transcript.TIS
1188
- report['reference_pre_mrna'] = reference_transcript.pre_mrna
1189
- report[
1190
- 'reference_orf'] = reference_transcript.orf # pre_mrna[reference_transcript.transcript_indices.index(reference_transcript.TIS):reference_transcript.transcript_indices.index(reference_transcript.TTS)]
1191
- report['reference_protein'] = reference_transcript.protein
1192
- report['reference_protein_length'] = len(reference_transcript.protein)
1193
-
1194
- report['variant_mrna'] = variant_transcript.transcript_seq
1195
- report['variant_cds_start'] = variant_transcript.TIS
1196
- report[
1197
- 'variant_pre_mrna'] = variant_transcript.pre_mrna # pre_mrna[variant_transcript.transcript_indices.index(variant_transcript.TIS):variant_transcript.transcript_indices.index(variant_transcript.TTS)]
1198
- report['variant_orf'] = variant_transcript.orf
1199
- report['variant_protein'] = variant_transcript.protein
1200
- report['variant_protein_length'] = len(variant_transcript.protein)
1201
-
1202
- descriptions = define_missplicing_events(reference_transcript, variant_transcript)
1203
- # print(descriptions)
1204
- report['exon_changes'] = '|'.join([v for v in descriptions if v])
1205
- report['splicing_codes'] = summarize_missplicing_event(*descriptions)
1206
- report['affected_exon'] = affected_exon
1207
- report['affected_intron'] = affected_intron
1208
- report['mutation_distance_from_5'] = distance_from_5
1209
- report['mutation_distance_from_3'] = distance_from_3
1210
- return report
1211
-
1212
-
1213
- def find_splice_site_proximity(mut, transcript):
1214
-
1215
- for i, (ex_start, ex_end) in enumerate(transcript.exons):
1216
- if min(ex_start, ex_end) <= mut.start <= max(ex_start, ex_end):
1217
- return i + 1, None, abs(mut.start - ex_start), abs(mut.start - ex_end)
1218
-
1219
- for i, (in_start, in_end) in enumerate(transcript.introns):
1220
- if min(in_start, in_end) <= mut.start <= max(in_start, in_end):
1221
- return None, i + 1, abs(mut.start - in_end), abs(mut.start - in_start)
1222
-
1223
- return None, None, np.inf, np.inf
1224
-
1225
- def define_missplicing_events(ref, var):
1226
- ref_introns, ref_exons = ref.introns, ref.exons
1227
- var_introns, var_exons = var.introns, var.exons
1228
-
1229
- num_ref_exons = len(ref_exons)
1230
- num_ref_introns = len(ref_introns)
1231
-
1232
- partial_exon_skipping = ','.join(
1233
- [f'Exon {exon_count + 1}/{num_ref_exons} truncated: {(t1, t2)} --> {(s1, s2)}' for (s1, s2) in var_exons for
1234
- exon_count, (t1, t2) in enumerate(ref_exons)
1235
- if (not ref.rev and ((s1 == t1 and s2 < t2) or (s1 > t1 and s2 == t2)))
1236
- or (ref.rev and ((s1 == t1 and s2 > t2) or (s1 < t1 and s2 == t2)))])
1237
-
1238
- partial_intron_retention = ','.join(
1239
- [f'Intron {intron_count + 1}/{num_ref_introns} partially retained: {(t1, t2)} --> {(s1, s2)}' for (s1, s2)
1240
- in var_introns for intron_count, (t1, t2) in enumerate(ref_introns)
1241
- if (not ref.rev and ((s1 == t1 and s2 < t2) or (s1 > t1 and s2 == t2)))
1242
- or (ref.rev and ((s1 == t1 and s2 > t2) or (s1 < t1 and s2 == t2)))])
1243
-
1244
- exon_skipping = ','.join(
1245
- [f'Exon {exon_count + 1}/{num_ref_exons} skipped: {(t1, t2)}' for exon_count, (t1, t2) in enumerate(ref_exons)
1246
- if t1 not in var.acceptors and t2 not in var.donors])
1247
-
1248
- novel_exons = ','.join([f'Novel Exon: {(t1, t2)}' for (t1, t2) in var_exons if
1249
- t1 not in ref.acceptors and t2 not in ref.donors])
1250
-
1251
- intron_retention = ','.join(
1252
- [f'Intron {intron_count + 1}/{num_ref_introns} retained: {(t1, t2)}' for intron_count, (t1, t2) in
1253
- enumerate(ref_introns)
1254
- if t1 not in var.donors and t2 not in var.acceptors])
1255
-
1256
- return partial_exon_skipping, partial_intron_retention, exon_skipping, novel_exons, intron_retention
1257
-
1258
-
1259
- def summarize_missplicing_event(pes, pir, es, ne, ir):
1260
- event = []
1261
- if pes:
1262
- event.append('PES')
1263
- if es:
1264
- event.append('ES')
1265
- if pir:
1266
- event.append('PIR')
1267
- if ir:
1268
- event.append('IR')
1269
- if ne:
1270
- event.append('NE')
1271
- if len(event) >= 1:
1272
- return ','.join(event)
1273
- # elif len(event) == 1:
1274
- # return event[0]
1275
- else:
1276
- return '-'
6
+ import pandas as pd
7
+ from .splicing_utils import find_transcript_missplicing, develop_aberrant_splicing, Missplicing
8
+ from .seqmat_utils import *
9
+ from .mutation_utils import *
1277
10
 
1278
11
 
1279
12
  ### Scoring
@@ -1518,49 +251,163 @@ def moving_average_conv(vector, window_size, factor=1):
1518
251
 
1519
252
  return np.convolve(vector, np.ones(window_size), mode='same') / window_size
1520
253
 
1521
- def oncosplice(mut_id, sai_threshold=0.5, protein_coding=True, primary_transcript=False, window_length=13, save_spliceai_results=False, force_spliceai=False, organism='hg38', engine='spliceai'):
1522
- mutation = Variations(mut_id)
1523
- # try:
1524
- reference_gene = Gene(mutation.gene, organism=organism)
1525
- # except FileNotFoundError:
1526
- # return pd.DataFrame()
1527
254
 
1528
- reference_gene_proteins = {g.protein: g.transcript_id for g in reference_gene.run_transcripts()}
1529
- mutated_gene = Gene(mutation.gene, mut_id, organism=organism)
255
+
256
+
257
+
258
+ def find_splice_site_proximity(pos, transcript):
259
+
260
+ for i, (ex_start, ex_end) in enumerate(transcript.exons):
261
+ if min(ex_start, ex_end) <= pos <= max(ex_start, ex_end):
262
+ return i + 1, None, abs(pos - ex_start), abs(pos - ex_end)
263
+
264
+ for i, (in_start, in_end) in enumerate(transcript.introns):
265
+ if min(in_start, in_end) <= pos <= max(in_start, in_end):
266
+ return None, i + 1, abs(pos - in_end), abs(pos - in_start)
267
+
268
+ return None, None, np.inf, np.inf
269
+
270
+ def define_missplicing_events(ref, var):
271
+ ref_introns, ref_exons = ref.introns, ref.exons
272
+ var_introns, var_exons = var.introns, var.exons
273
+
274
+ num_ref_exons = len(ref_exons)
275
+ num_ref_introns = len(ref_introns)
276
+
277
+ partial_exon_skipping = ','.join(
278
+ [f'Exon {exon_count + 1}/{num_ref_exons} truncated: {(t1, t2)} --> {(s1, s2)}' for (s1, s2) in var_exons for
279
+ exon_count, (t1, t2) in enumerate(ref_exons)
280
+ if (not ref.rev and ((s1 == t1 and s2 < t2) or (s1 > t1 and s2 == t2)))
281
+ or (ref.rev and ((s1 == t1 and s2 > t2) or (s1 < t1 and s2 == t2)))])
282
+
283
+ partial_intron_retention = ','.join(
284
+ [f'Intron {intron_count + 1}/{num_ref_introns} partially retained: {(t1, t2)} --> {(s1, s2)}' for (s1, s2)
285
+ in var_introns for intron_count, (t1, t2) in enumerate(ref_introns)
286
+ if (not ref.rev and ((s1 == t1 and s2 < t2) or (s1 > t1 and s2 == t2)))
287
+ or (ref.rev and ((s1 == t1 and s2 > t2) or (s1 < t1 and s2 == t2)))])
288
+
289
+ exon_skipping = ','.join(
290
+ [f'Exon {exon_count + 1}/{num_ref_exons} skipped: {(t1, t2)}' for exon_count, (t1, t2) in enumerate(ref_exons)
291
+ if t1 not in var.acceptors and t2 not in var.donors])
292
+
293
+ novel_exons = ','.join([f'Novel Exon: {(t1, t2)}' for (t1, t2) in var_exons if
294
+ t1 not in ref.acceptors and t2 not in ref.donors])
295
+
296
+ intron_retention = ','.join(
297
+ [f'Intron {intron_count + 1}/{num_ref_introns} retained: {(t1, t2)}' for intron_count, (t1, t2) in
298
+ enumerate(ref_introns)
299
+ if t1 not in var.donors and t2 not in var.acceptors])
300
+
301
+ return partial_exon_skipping, partial_intron_retention, exon_skipping, novel_exons, intron_retention
302
+
303
+
304
+ def summarize_missplicing_event(pes, pir, es, ne, ir):
305
+ event = []
306
+ if pes:
307
+ event.append('PES')
308
+ if es:
309
+ event.append('ES')
310
+ if pir:
311
+ event.append('PIR')
312
+ if ir:
313
+ event.append('IR')
314
+ if ne:
315
+ event.append('NE')
316
+ if len(event) >= 1:
317
+ return ','.join(event)
318
+ # elif len(event) == 1:
319
+ # return event[0]
320
+ else:
321
+ return '-'
322
+
323
+
324
+ # Annotating
325
+ def OncospliceAnnotator(reference_transcript, variant_transcript, mut):
326
+ affected_exon, affected_intron, distance_from_5, distance_from_3 = find_splice_site_proximity(mut.indices[0],
327
+ reference_transcript)
328
+
329
+ report = {}
330
+
331
+ report['primary_transcript'] = reference_transcript.primary_transcript
332
+ report['transcript_id'] = reference_transcript.transcript_id
333
+ # report['mut_id'] = mut.mut_id
334
+ report['cons_available'] = int(reference_transcript.cons_available)
335
+ # report['protein_coding'] = reference_transcript.transcript_biotype
336
+
337
+ # report['reference_mrna'] = reference_transcript.transcript_seq
338
+ report['reference_cds_start'] = reference_transcript.TIS
339
+ # report['reference_pre_mrna'] = reference_transcript.pre_mrna
340
+ # report[
341
+ # 'reference_orf'] = reference_transcript.orf # pre_mrna[reference_transcript.transcript_indices.index(reference_transcript.TIS):reference_transcript.transcript_indices.index(reference_transcript.TTS)]
342
+ report['reference_protein'] = reference_transcript.protein
343
+ report['reference_protein_length'] = len(reference_transcript.protein)
344
+
345
+ # report['variant_mrna'] = variant_transcript.transcript_seq
346
+ report['variant_cds_start'] = variant_transcript.TIS
347
+ # report[
348
+ # 'variant_pre_mrna'] = variant_transcript.pre_mrna # pre_mrna[variant_transcript.transcript_indices.index(variant_transcript.TIS):variant_transcript.transcript_indices.index(variant_transcript.TTS)]
349
+ # report['variant_orf'] = variant_transcript.orf
350
+ report['variant_protein'] = variant_transcript.protein
351
+ report['variant_protein_length'] = len(variant_transcript.protein)
352
+
353
+ descriptions = define_missplicing_events(reference_transcript, variant_transcript)
354
+ # print(descriptions)
355
+ report['exon_changes'] = '|'.join([v for v in descriptions if v])
356
+ report['splicing_codes'] = summarize_missplicing_event(*descriptions)
357
+ report['affected_exon'] = affected_exon
358
+ report['affected_intron'] = affected_intron
359
+ report['mutation_distance_from_5'] = distance_from_5
360
+ report['mutation_distance_from_3'] = distance_from_3
361
+ return report
362
+
363
+
364
+ def oncosplice(mut_id, splicing_threshold=0.5, protein_coding=True, primary_transcript=False, window_length=13, organism='hg38', engine='spliceai', domains=None):
365
+ gene = Gene(mut_id.split(':')[0], organism=organism)
366
+ mutation = get_mutation(mut_id, rev=gene.rev)
1530
367
 
1531
368
  results = []
1532
- for variant in mutated_gene.run_transcripts(protein_coding=protein_coding, primary_transcript=primary_transcript):
1533
- reference = reference_gene.transcript(variant.transcript_id)
1534
- if mutation not in reference or reference.protein == '' or len(reference.protein) < window_length:
369
+ for tid, transcript in gene.run_transcripts(protein_coding=protein_coding, primary_transcript=primary_transcript):
370
+ if not transcript.cons_available:
371
+ continue
372
+
373
+ if mutation not in transcript:
374
+ results.append({'transcript_id': transcript.transcript_id})
1535
375
  continue
1536
376
 
1537
- cons_vector = transform_conservation_vector(reference.cons_vector, window=window_length)
1538
- missplicing_obj = PredictSpliceAI(mutation, reference, threshold=sai_threshold, force=force_spliceai, save_results=save_spliceai_results, engine=engine, organism=organism)
1539
- missplicing = missplicing_obj.apply_sai_threshold_primary(threshold=sai_threshold)
377
+ transcript.generate_pre_mrna()
378
+ transcript.cons_vector = transform_conservation_vector(transcript.cons_vector, window=window_length)
379
+ transcript.generate_mature_mrna().generate_protein(inplace=True, domains=domains)
380
+ ref_protein, cons_vector = transcript.protein, transcript.cons_vector
381
+ reference_transcript = copy.deepcopy(transcript)
382
+
383
+ assert len(ref_protein) == len(cons_vector), f"Protein ({len(ref_protein)}) and conservation vector ({len(cons_vector)}) must be same length. {ref_protein}, \n>{cons_vector}\n>{transcript.cons_seq}"
1540
384
 
1541
- for i, new_boundaries in enumerate(develop_aberrant_splicing(variant, missplicing)):
1542
- variant_isoform = deepcopy(variant)
1543
- variant_isoform.reset_acceptors(acceptors=new_boundaries['acceptors']).reset_donors(donors=new_boundaries['donors']).organize().generate_protein()
1544
- alignment = get_logical_alignment(reference.protein, variant_isoform.protein)
385
+ missplicing = Missplicing(find_transcript_missplicing(transcript, mutation, engine=engine), threshold=splicing_threshold)
386
+ transcript.pre_mrna += mutation
387
+
388
+ for i, new_boundaries in enumerate(develop_aberrant_splicing(transcript, missplicing.aberrant_splicing)):
389
+ transcript.acceptors = new_boundaries['acceptors']
390
+ transcript.donors = new_boundaries['donors']
391
+ transcript.generate_mature_mrna().generate_protein()
392
+
393
+ alignment = get_logical_alignment(reference_transcript.protein, transcript.protein)
1545
394
  deleted, inserted = find_indels_with_mismatches_as_deletions(alignment.seqA, alignment.seqB)
1546
- modified_positions = find_modified_positions(len(reference.protein), deleted, inserted)
395
+ modified_positions = find_modified_positions(len(ref_protein), deleted, inserted)
1547
396
  temp_cons = np.convolve(cons_vector * modified_positions, np.ones(window_length)) / window_length
1548
397
  affected_cons_scores = max(temp_cons)
1549
398
  percentile = (
1550
399
  sorted(cons_vector).index(next(x for x in sorted(cons_vector) if x >= affected_cons_scores)) / len(
1551
400
  cons_vector))
1552
401
 
1553
- report = OncospliceAnnotator(reference, variant_isoform, mutation)
1554
- report['original_cons'] = reference.cons_vector
402
+ report = OncospliceAnnotator(reference_transcript, transcript, mutation)
403
+ report['mut_id'] = mut_id
1555
404
  report['oncosplice_score'] = affected_cons_scores
1556
405
  report['percentile'] = percentile
1557
- report['modified_positions'] = modified_positions
1558
- report['cons_vector'] = cons_vector
1559
406
  report['isoform_id'] = i
1560
407
  report['isoform_prevalence'] = new_boundaries['path_weight']
1561
- report['full_missplicing'] = missplicing
1562
- report['missplicing'] = max(missplicing_obj)
1563
- report['reference_resemblance'] = reference_gene_proteins.get(variant_isoform.protein, None)
408
+ report['full_missplicing'] = missplicing.aberrant_splicing
409
+ report['missplicing'] = max(missplicing)
410
+ # report['reference_resemblance'] = reference_gene_proteins.get(variant_isoform.protein, None)
1564
411
  results.append(report)
1565
412
 
1566
413
  report = pd.DataFrame(results)
@@ -1568,264 +415,71 @@ def oncosplice(mut_id, sai_threshold=0.5, protein_coding=True, primary_transcrip
1568
415
 
1569
416
 
1570
417
 
1571
- ### Graphical Stuff
1572
- def create_figure_story(epistasis, to_file=None):
1573
- g = epistasis.split(':')[0]
1574
- out = oncosplice(epistasis, annotate=True)
1575
- out = out[out.cons_available==1]
1576
-
1577
- for _, row in out.iterrows():
1578
- max_length = 0
1579
- pos = 0
1580
- for i, k in row.deletions.items():
1581
- if len(k) > max_length:
1582
- pos = i
1583
- max_length = len(k)
1584
418
 
1585
- if max_length > 5:
1586
- del_reg = [pos, pos + max_length]
1587
- else:
1588
- del_reg = None
419
+ def oncosplice_prototype(mut_id, splicing_threshold=0.5, protein_coding=True, primary_transcript=False, window_length=13, organism='hg38', engine='spliceai', domains=None):
420
+ import requests
421
+ import threading
1589
422
 
1590
- if row.oncosplice_score == 0:
1591
- mutation_loc = None
1592
- else:
1593
- mutation_loc = pos
423
+ def background_request(url, result):
424
+ return {'data': 'success'}
1594
425
 
1595
- plot_conservation(tid=row.transcript_id,
1596
- gene=f'{g}, {row.transcript_id}.{row.isoform}',
1597
- mutation_loc=mutation_loc,
1598
- target_region=del_reg, mut_name='Epistasis',
1599
- domain_annotations=get_annotations(row.transcript_id, 300),
1600
- to_file=to_file)
426
+ gene = Gene(mut_id.split(':')[0], organism=organism)
1601
427
 
428
+ domains = {}
429
+ request_thread = threading.Thread(target=background_request, args=(gene.transcript_ids, domains))
430
+ request_thread.start()
1602
431
 
432
+ mutation = get_mutation(mut_id, rev=gene.rev)
1603
433
 
1604
- def plot_conservation(gene_name, tid, gene='', mutation_loc=None, target_region=None, mut_name='Mutation', domain_annotations=[]):
1605
- """
1606
- Plots conservation vectors with protein domain visualization and Rate4Site scores.
1607
-
1608
- Parameters:
1609
- tid (str): Transcript identifier.
1610
- gene (str): Gene name.
1611
- mutation_loc (int): Position of the mutation.
1612
- target_region (tuple): Start and end positions of the target region.
1613
- mut_name (str): Name of the mutation.
1614
- domain_annotations (list): List of tuples for domain annotations (start, end, label).
1615
- """
1616
- # Access conservation data
1617
- _, cons_vec = unload_pickle(gene_name)['tid']['cons_vector']
1618
-
1619
- if not cons_vec:
1620
- raise ValueError("The conservation vector is empty.")
1621
-
1622
- sns.set_theme(style="white")
1623
- fig, ax = plt.subplots(figsize=(max(15, len(cons_vec)/10), 3)) # Dynamic figure size
1624
-
1625
- # Plotting the conservation vectors in the main plot
1626
- plot_conservation_vectors(ax, cons_vec)
1627
-
1628
- # Setting up primary axis for the main plot
1629
- setup_primary_axis(ax, gene, len(cons_vec))
1630
-
1631
- # Create a separate axes for protein domain visualization
1632
- domain_ax = create_domain_axes(fig, len(cons_vec))
1633
-
1634
- # Draw protein domains
1635
- plot_protein_domains(domain_ax, domain_annotations, len(cons_vec))
1636
-
1637
- # Plotting Rate4Site scores on secondary y-axis
1638
- plot_rate4site_scores(ax, cons_vec)
1639
-
1640
- # Plotting mutation location and target region, if provided
1641
- plot_mutation_and_target_region(ax, mutation_loc, target_region, mut_name)
1642
-
1643
- plt.show()
1644
-
1645
- def plot_conservation_vectors(ax, cons_vec):
1646
- """Plots transformed conservation vectors."""
1647
- temp = transform_conservation_vector(cons_vec, 76) # Larger window
1648
- temp /= max(temp)
1649
- ax.plot(list(range(len(temp))), temp, c='b', label='Estimated Functional Residues')
1650
-
1651
- temp = transform_conservation_vector(cons_vec, 6) # Smaller window
1652
- temp /= max(temp)
1653
- ax.plot(list(range(len(temp))), temp, c='k', label='Estimated Functional Domains')
1654
-
1655
- def setup_primary_axis(ax, gene, length):
1656
- """Configures the primary axis of the plot."""
1657
- ax.set_xlabel(f'AA Position - {gene}', weight='bold')
1658
- ax.set_xlim(0, length)
1659
- ax.set_ylim(0, 1)
1660
- ax.set_ylabel('Relative Importance', weight='bold')
1661
- ax.tick_params(axis='y')
1662
- ax.spines['right'].set_visible(False)
1663
- ax.spines['top'].set_visible(False)
1664
-
1665
- def create_domain_axes(fig, length):
1666
- """Creates an axis for protein domain visualization."""
1667
- domain_ax_height = 0.06
1668
- domain_ax = fig.add_axes([0.125, 0.95, 0.775, domain_ax_height])
1669
- domain_ax.set_xlim(0, length)
1670
- domain_ax.set_xticks([])
1671
- domain_ax.set_yticks([])
1672
- for spine in domain_ax.spines.values():
1673
- spine.set_visible(False)
1674
- return domain_ax
1675
-
1676
- def plot_protein_domains(ax, domain_annotations, length):
1677
- """Plots protein domain annotations."""
1678
- ax.add_patch(Rectangle((0, 0), length, 0.9, facecolor='lightgray', edgecolor='none'))
1679
- for domain in domain_annotations:
1680
- start, end, label = domain
1681
- ax.add_patch(Rectangle((start, 0), end - start, 0.9, facecolor='orange', edgecolor='none', alpha=0.5))
1682
- ax.text((start + end) / 2, 2.1, label, ha='center', va='center', color='black', size=8)
1683
-
1684
- def plot_rate4site_scores(ax, cons_vec):
1685
- """Plots Rate4Site scores on a secondary y-axis."""
1686
- ax2 = ax.twinx()
1687
- c = np.array(cons_vec)
1688
- c = c + abs(min(c))
1689
- c = c/max(c)
1690
- ax2.set_ylim(min(c), max(c)*1.1)
1691
- ax2.scatter(list(range(len(c))), c, color='green', label='Rate4Site Scores', alpha=0.4)
1692
- ax2.set_ylabel('Rate4Site Normalized', color='green', weight='bold')
1693
- ax2.tick_params(axis='y', labelcolor='green')
1694
- ax2.spines['right'].set_visible(True)
1695
- ax2.spines['top'].set_visible(False)
1696
-
1697
- def plot_mutation_and_target_region(ax, mutation_loc, target_region, mut_name):
1698
- """Highlights mutation location and target region, if provided."""
1699
- if mutation_loc is not None:
1700
- ax.axvline(x=mutation_loc, ymax=1, color='r', linestyle='--', alpha=0.7)
1701
- ax.text(mutation_loc, 1.04, mut_name, color='r', weight='bold', ha='center')
1702
-
1703
- if target_region is not None:
1704
- ax.add_patch(Rectangle((target_region[0], 0), target_region[1] - target_region[0], 1, alpha=0.25, facecolor='gray'))
1705
- center_loc = target_region[0] + 0.5 * (target_region[1] - target_region[0])
1706
- ax.text(center_loc, 1.04, 'Deleted Region', ha='center', va='center', color='gray', weight='bold')
1707
-
1708
-
1709
- def merge_overlapping_regions(df):
1710
- """
1711
- Merges overlapping regions in a DataFrame.
1712
-
1713
- Parameters:
1714
- df (pd.DataFrame): DataFrame with columns 'start', 'end', 'name'
1715
-
1716
- Returns:
1717
- List: List of merged regions as namedtuples (start, end, combined_name)
1718
- """
1719
- if df.empty:
1720
- return []
434
+ results = []
435
+ for tid, transcript in gene.run_transcripts(protein_coding=protein_coding, primary_transcript=primary_transcript):
436
+ if not transcript.cons_available:
437
+ continue
1721
438
 
1722
- Region = namedtuple('Region', ['start', 'end', 'combined_name'])
1723
- df = df.sort_values(by='start')
1724
- merged_regions = []
1725
- current_region = None
439
+ if mutation not in transcript:
440
+ results.append({'transcript_id': transcript.transcript_id})
441
+ continue
1726
442
 
1727
- for _, row in df.iterrows():
1728
- start, end, name = row['start'], row['end'], row['name'].replace('_', ' ')
1729
- if current_region is None:
1730
- current_region = Region(start, end, [name])
1731
- elif start <= current_region.end:
1732
- current_region = Region(current_region.start, max(current_region.end, end), current_region.combined_name + [name])
1733
- else:
1734
- merged_regions.append(current_region._replace(combined_name=', '.join(current_region.combined_name)))
1735
- current_region = Region(start, end, [name])
443
+ transcript.generate_pre_mrna()
444
+ transcript.cons_vector = transform_conservation_vector(transcript.cons_vector, window=window_length)
445
+ transcript.generate_mature_mrna().generate_protein(inplace=True, domains=domains)
446
+ ref_protein, cons_vector = transcript.protein, transcript.cons_vector
447
+ reference_transcript = copy.deepcopy(transcript)
1736
448
 
1737
- if current_region:
1738
- merged_regions.append(current_region._replace(combined_name=', '.join(current_region.combined_name)))
449
+ assert len(ref_protein) == len(cons_vector), f"Protein ({len(ref_protein)}) and conservation vector ({len(cons_vector)} must be same length."
1739
450
 
1740
- # Assuming split_text is a function that splits the text appropriately.
1741
- merged_regions = [Region(a, b, split_text(c, 35)) for a, b, c in merged_regions]
1742
- return merged_regions
451
+ missplicing = Missplicing(find_transcript_missplicing(transcript, mutation, engine=engine), threshold=splicing_threshold)
452
+ transcript.pre_mrna += mutation
1743
453
 
454
+ for i, new_boundaries in enumerate(develop_aberrant_splicing(transcript, missplicing.aberrant_splicing)):
455
+ transcript.acceptors = new_boundaries['acceptors']
456
+ transcript.donors = new_boundaries['donors']
457
+ transcript.generate_mature_mrna().generate_protein()
1744
458
 
1745
- def split_text(text, width):
1746
- """
1747
- Splits a text into lines with a maximum specified width.
459
+ alignment = get_logical_alignment(reference_transcript.protein, transcript.protein)
460
+ deleted, inserted = find_indels_with_mismatches_as_deletions(alignment.seqA, alignment.seqB)
461
+ modified_positions = find_modified_positions(len(ref_protein), deleted, inserted)
462
+ temp_cons = np.convolve(cons_vector * modified_positions, np.ones(window_length)) / window_length
463
+ affected_cons_scores = max(temp_cons)
464
+ percentile = (
465
+ sorted(cons_vector).index(next(x for x in sorted(cons_vector) if x >= affected_cons_scores)) / len(
466
+ cons_vector))
1748
467
 
1749
- Parameters:
1750
- text (str): The text to be split.
1751
- width (int): Maximum width of each line.
468
+ report = OncospliceAnnotator(reference_transcript, transcript, mutation)
469
+ report['mut_id'] = mut_id
470
+ report['oncosplice_score'] = affected_cons_scores
471
+ report['percentile'] = percentile
472
+ report['isoform_id'] = i
473
+ report['isoform_prevalence'] = new_boundaries['path_weight']
474
+ report['full_missplicing'] = missplicing.aberrant_splicing
475
+ report['missplicing'] = max(missplicing)
476
+ # report['reference_resemblance'] = reference_gene_proteins.get(variant_isoform.protein, None)
477
+ results.append(report)
1752
478
 
1753
- Returns:
1754
- str: The text split into lines of specified width.
1755
- """
1756
- lines = re.findall('.{1,' + str(width) + '}', text)
1757
- return '\n'.join(lines)
479
+ report = pd.DataFrame(results)
480
+ return report
1758
481
 
1759
- def get_annotations(target_gene, w=500):
1760
- PROTEIN_ANNOTATIONS = {}
1761
- temp = PROTEIN_ANNOTATIONS[(PROTEIN_ANNOTATIONS['Transcript stable ID'] == PROTEIN_ANNOTATIONS[target_gene]) & (PROTEIN_ANNOTATIONS.length < w)].drop_duplicates(subset=['Interpro Short Description'], keep='first')
1762
- return merge_overlapping_regions(temp)
1763
482
 
1764
483
 
1765
- # def plot_conservation(tid, gene='', mutation_loc=None, target_region=None, mut_name='Mutation', domain_annotations=[], to_file=None):
1766
- # _, cons_vec = access_conservation_data(tid)
1767
- #
1768
- # sns.set_theme(style="white")
1769
- # fig, ax = plt.subplots(figsize=(15, 3)) # Adjusted figure size for better layout
1770
- #
1771
- # # Plotting the conservation vectors in the main plot
1772
- # temp = transform_conservation_vector(cons_vec, 76)
1773
- # temp /= max(temp)
1774
- # ax.plot(list(range(len(temp))), temp, c='b', label='Estimated Functional Residues')
1775
- # temp = transform_conservation_vector(cons_vec, 6)
1776
- # temp /= max(temp)
1777
- # ax.plot(list(range(len(temp))), temp, c='k', label='Estimated Functional Domains')
1778
- #
1779
- # # Setting up primary axis for the main plot
1780
- # ax.set_xlabel(f'AA Position - {gene}', weight='bold')
1781
- # ax.set_xlim(0, len(cons_vec))
1782
- # ax.set_ylim(0, 1) # Set y-limit to end at 1
1783
- # ax.set_ylabel('Relative Importance', weight='bold')
1784
- # ax.tick_params(axis='y')
1785
- # ax.spines['right'].set_visible(False)
1786
- # ax.spines['top'].set_visible(False)
1787
- #
1788
- # # Create a separate axes for protein domain visualization above the main plot
1789
- # domain_ax_height = 0.06 # Adjust for thinner protein diagram
1790
- # domain_ax = fig.add_axes([0.125, 0.95, 0.775, domain_ax_height]) # Position higher above the main plot
1791
- # domain_ax.set_xlim(0, len(cons_vec))
1792
- # domain_ax.set_xticks([])
1793
- # domain_ax.set_yticks([])
1794
- # domain_ax.spines['top'].set_visible(False)
1795
- # domain_ax.spines['right'].set_visible(False)
1796
- # domain_ax.spines['left'].set_visible(False)
1797
- # domain_ax.spines['bottom'].set_visible(False)
1798
- #
1799
- # # Draw the full-length protein as a base rectangle
1800
- # domain_ax.add_patch(Rectangle((0, 0), len(cons_vec), 0.9, facecolor='lightgray', edgecolor='none'))
1801
- #
1802
- # # Overlay domain annotations
1803
- # for domain in domain_annotations:
1804
- # start, end, label = domain
1805
- # domain_ax.add_patch(Rectangle((start, 0), end - start, 0.9, facecolor='orange', edgecolor='none', alpha=0.5))
1806
- # domain_ax.text((start + end) / 2, 2.1, label, ha='center', va='center', color='black', size=8)
1807
- #
1808
- # # Plotting Rate4Site scores on secondary y-axis
1809
- # ax2 = ax.twinx()
1810
- # c = np.array(cons_vec)
1811
- # c = c + abs(min(c))
1812
- # c = c/max(c)
1813
- # ax2.set_ylim(min(c), max(c)*1.1)
1814
- # ax2.scatter(list(range(len(c))), c, color='green', label='Rate4Site Scores', alpha=0.4)
1815
- # ax2.set_ylabel('Rate4Site Normalized', color='green', weight='bold')
1816
- # ax2.tick_params(axis='y', labelcolor='green')
1817
- # ax2.spines['right'].set_visible(True)
1818
- # ax2.spines['top'].set_visible(False)
1819
- #
1820
- # # Plotting mutation location and target region
1821
- # if mutation_loc is not None:
1822
- # ax.axvline(x=mutation_loc, ymax=1,color='r', linestyle='--', alpha=0.7)
1823
- # ax.text(mutation_loc, 1.04, mut_name, color='r', weight='bold', ha='center')
1824
- #
1825
- # if target_region is not None:
1826
- # ax.add_patch(Rectangle((target_region[0], 0), target_region[1] - target_region[0], 1, alpha=0.25, facecolor='gray'))
1827
- # center_loc = target_region[0] + 0.5 * (target_region[1] - target_region[0])
1828
- # ax.text(center_loc, 1.04, 'Deleted Region', ha='center', va='center', color='gray', weight='bold')
1829
- #
1830
- # plt.show()
1831
- #
484
+ if __name__ == '__main__':
485
+ pass