geney 1.3.79__py2.py3-none-any.whl → 1.4.1__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,693 @@
+ import numpy as np
+ import pandas as pd
+
+ from typing import Dict, Tuple
+
+ from .Gene import Gene
+ from geney.utils.SeqMats import MutSeqMat
+ from collections import defaultdict
+
+
+ def generate_adjacency_list(acceptors, donors, transcript_start, transcript_end, max_distance=50, rev=False):
+     # Append the transcript end to the donors so paths can connect to the end point.
+     # Copy first so the caller's list is not mutated.
+     donors = donors + [(transcript_end, 1)]
+     acceptors = sorted(acceptors, key=lambda x: (x[0], x[1] if not rev else -x[1]), reverse=rev)
+     donors = sorted(donors, key=lambda x: (x[0], x[1] if not rev else -x[1]), reverse=rev)
+
+     # Adjacency list mapping each node to its downstream connections
+     adjacency_list = defaultdict(list)
+
+     # Connect each donor to the nearest acceptor(s) within the distance threshold
+     for d_pos, _ in donors:
+         running_prob = 1
+         for a_pos, a_prob in acceptors:
+             correct_orientation = (a_pos > d_pos and not rev) or (a_pos < d_pos and rev)
+             distance_valid = abs(a_pos - d_pos) <= max_distance
+             if correct_orientation and distance_valid:
+                 in_between_acceptors = sum(d_pos < a < a_pos for a, _ in acceptors) if not rev else sum(a_pos < a < d_pos for a, _ in acceptors)
+                 in_between_donors = sum(d_pos < d < a_pos for d, _ in donors) if not rev else sum(a_pos < d < d_pos for d, _ in donors)
+                 if in_between_donors == 0 or in_between_acceptors == 0:
+                     adjacency_list[(d_pos, 'donor')].append((a_pos, 'acceptor', a_prob))
+                     running_prob -= a_prob
+                 elif running_prob > 0:
+                     adjacency_list[(d_pos, 'donor')].append((a_pos, 'acceptor', a_prob * running_prob))
+                     running_prob -= a_prob
+                 else:
+                     break
+
+     # Connect each acceptor to the nearest donor(s) within the distance threshold
+     for a_pos, a_prob in acceptors:
+         running_prob = 1
+         for d_pos, d_prob in donors:
+             correct_orientation = (d_pos > a_pos and not rev) or (d_pos < a_pos and rev)
+             distance_valid = abs(d_pos - a_pos) <= max_distance
+             if correct_orientation and distance_valid:
+                 in_between_acceptors = sum(a_pos < a < d_pos for a, _ in acceptors) if not rev else sum(d_pos < a < a_pos for a, _ in acceptors)
+                 tag = 'donor' if d_pos != transcript_end else 'transcript_end'
+
+                 if in_between_acceptors == 0:
+                     adjacency_list[(a_pos, 'acceptor')].append((d_pos, tag, d_prob))
+                     running_prob -= d_prob
+                 elif running_prob > 0:
+                     adjacency_list[(a_pos, 'acceptor')].append((d_pos, tag, d_prob * running_prob))
+                     running_prob -= d_prob
+                 else:
+                     break
+
+     # Connect the transcript start to the nearest donor(s) within the distance threshold
+     running_prob = 1
+     for d_pos, d_prob in donors:
+         if ((d_pos > transcript_start and not rev) or (d_pos < transcript_start and rev)) and abs(d_pos - transcript_start) <= max_distance:
+             adjacency_list[(transcript_start, 'transcript_start')].append((d_pos, 'donor', d_prob))
+             running_prob -= d_prob
+         if running_prob <= 0:
+             break
+
+     # Normalize probabilities so each node's outgoing edges sum to 1
+     for k, next_nodes in adjacency_list.items():
+         prob_sum = sum(c for _, _, c in next_nodes)
+         adjacency_list[k] = [(a, b, round(c / prob_sum, 3)) for a, b, c in next_nodes] if prob_sum > 0 else next_nodes
+
+     return adjacency_list
+
+
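+ # Minimal sketch (toy coordinates, not from the package): a forward-strand transcript
+ # spanning 100-300 with one intron between a donor at 150 and an acceptor at 200.
+ #
+ #     adj = generate_adjacency_list([(200, 1.0)], [(150, 1.0)],
+ #                                   transcript_start=100, transcript_end=300,
+ #                                   max_distance=500)
+ #     adj[(100, 'transcript_start')]   # [(150, 'donor', 1.0)]
+ #     adj[(150, 'donor')]              # [(200, 'acceptor', 1.0)]
+ #     adj[(200, 'acceptor')]           # [(300, 'transcript_end', 1.0)]
+
+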
+ def find_all_paths(graph, start, end, path=None, probability=1.0):
+     """Yield every (path, cumulative probability) pair from start to end in the splice graph."""
+     path = (path or []) + [start]  # Extend the path with the current node (avoids a mutable default)
+     if start == end:
+         yield path, probability  # End reached: emit the path and its cumulative probability
+         return
+     if start not in graph:
+         return  # Dead end: the current node has no outgoing edges
+
+     for next_node, node_type, prob in graph[start]:
+         # Recurse into each connected node, multiplying in the edge probability
+         yield from find_all_paths(graph, (next_node, node_type), end, path, probability * prob)
+
+
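+ # Continuing the sketch above: enumerate every start-to-end path and its weight.
+ #
+ #     start = (100, 'transcript_start')
+ #     end = (300, 'transcript_end')
+ #     for path, prob in find_all_paths(adj, start, end):
+ #         print(path, prob)
+ #     # [(100, 'transcript_start'), (150, 'donor'), (200, 'acceptor'), (300, 'transcript_end')] 1.0
+
+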
+ def prepare_splice_sites(acceptors, donors, aberrant_splicing):
+     # Start from the annotated sites, each with probability 1
+     acceptors = {p: 1 for p in acceptors}
+     donors = {p: 1 for p in donors}
+
+     # Overlay the predicted probabilities for missed and newly discovered sites
+     for p, v in aberrant_splicing['missed_donors'].items():
+         donors[p] = v['absolute']
+
+     for p, v in aberrant_splicing['discovered_donors'].items():
+         donors[p] = v['absolute']
+
+     for p, v in aberrant_splicing['missed_acceptors'].items():
+         acceptors[p] = v['absolute']
+
+     for p, v in aberrant_splicing['discovered_acceptors'].items():
+         acceptors[p] = v['absolute']
+
+     acceptors = {int(k): v for k, v in acceptors.items()}
+     donors = {int(k): v for k, v in donors.items()}
+     return list(acceptors.items()), list(donors.items())
+
+
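+ # Minimal sketch (toy values): overlay predicted aberrant sites onto the annotated
+ # ones. The `aberrant_splicing` dict uses the layout produced by
+ # find_transcript_missplicing_seqs below.
+ #
+ #     aberrant = {'missed_donors': {150: {'absolute': 0.2, 'delta': -0.8}},
+ #                 'discovered_donors': {155: {'absolute': 0.7, 'delta': 0.7}},
+ #                 'missed_acceptors': {}, 'discovered_acceptors': {}}
+ #     prepare_splice_sites([200], [150], aberrant)
+ #     # ([(200, 1)], [(150, 0.2), (155, 0.7)])
+
+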
+ def develop_aberrant_splicing(transcript, aberrant_splicing):
+     # Without aberrant splicing, the annotated isoform is the only path, with weight 1
+     if not aberrant_splicing:
+         yield {'acceptors': transcript.acceptors, 'donors': transcript.donors, 'path_weight': 1}
+
+     else:
+         all_acceptors, all_donors = prepare_splice_sites(transcript.acceptors, transcript.donors, aberrant_splicing.missplicing)
+         adj_list = generate_adjacency_list(all_acceptors, all_donors, transcript_start=transcript.transcript_start,
+                                            transcript_end=transcript.transcript_end, rev=transcript.rev,
+                                            max_distance=100000)
+         end_node = (transcript.transcript_end, 'transcript_end')
+         start_node = (transcript.transcript_start, 'transcript_start')
+         for path, prob in find_all_paths(adj_list, start_node, end_node):
+             yield {'acceptors': [p[0] for p in path if p[1] == 'acceptor'],
+                    'donors': [p[0] for p in path if p[1] == 'donor'],
+                    'path_weight': prob}
+
+
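+ # Each yielded isoform candidate is a dict of this shape (values illustrative):
+ #     {'acceptors': [3000, 4500], 'donors': [2800, 4200], 'path_weight': 0.82}
+ # where path_weight is the product of the normalized edge probabilities along the path.
+
+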
+ # Missplicing Detection
+ def find_ss_changes(ref_dct, mut_dct, known_splice_sites, threshold=0.5):
+     '''
+     :param ref_dct: splice-site probabilities per nucleotide (keyed by genomic position) for the reference sequence
+     :param mut_dct: splice-site probabilities per nucleotide (keyed by genomic position) for the mutated sequence
+     :param known_splice_sites: the genomic positions that serve as known splice sites
+     :param threshold: the detection threshold (difference between mutated and reference probabilities)
+     :return: two dictionaries; discovered_pos contains every position that meets the threshold for discovery,
+              and deleted_pos contains every known splice site that meets the threshold for loss
+     '''
+
+     # Delta per position across the union of both key sets
+     new_dict = {v: mut_dct.get(v, 0) - ref_dct.get(v, 0) for v in
+                 set(list(ref_dct.keys()) + list(mut_dct.keys()))}
+
+     discovered_pos = {k: {'delta': round(float(v), 3), 'absolute': round(float(mut_dct[k]), 3), 'reference': round(ref_dct.get(k, 0), 3)}
+                       for k, v in new_dict.items() if v >= threshold and k not in known_splice_sites}
+
+     deleted_pos = {k: {'delta': round(float(v), 3), 'absolute': round(float(mut_dct.get(k, 0)), 3), 'reference': round(ref_dct.get(k, 0), 3)}
+                    for k, v in new_dict.items() if -v >= threshold and k in known_splice_sites}
+
+     return discovered_pos, deleted_pos
+
+
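+ # Worked sketch (toy probabilities): position 105 gains a site while known site 100 loses one.
+ #
+ #     ref = {100: 0.9, 105: 0.0}
+ #     mut = {100: 0.1, 105: 0.8}
+ #     discovered, deleted = find_ss_changes(ref, mut, known_splice_sites=[100], threshold=0.5)
+ #     # discovered == {105: {'delta': 0.8, 'absolute': 0.8, 'reference': 0.0}}
+ #     # deleted == {100: {'delta': -0.8, 'absolute': 0.1, 'reference': 0.9}}
+
+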
+ def run_splicing_engine(seq, engine='spliceai'):
+     # Dispatch to the requested prediction engine; imports are deferred so only the
+     # selected engine's models are loaded.
+     match engine:
+         case 'spliceai':
+             from geney.utils.spliceai_utils import sai_predict_probs, sai_models
+             acceptor_probs, donor_probs = sai_predict_probs(seq, models=sai_models)
+
+         case 'pangolin':
+             from geney.utils.pangolin_utils import pangolin_predict_probs, pang_models
+             donor_probs, acceptor_probs = pangolin_predict_probs(seq, models=pang_models)
+
+         case _:
+             raise ValueError(f"{engine} not implemented")
+     return donor_probs, acceptor_probs
+
+
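+ # Illustrative sketch (assumes the selected engine's models are installed): the input
+ # sequence should include the 5,000-nt context flanks the models expect, and the
+ # returned arrays cover only the central region (input minus both flanks).
+ #
+ #     donor_probs, acceptor_probs = run_splicing_engine('N' * 5000 + seq + 'N' * 5000)
+
+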
+ def find_transcript_splicing(transcript, engine: str = 'spliceai') -> Tuple[Dict[int, float], Dict[int, float]]:
+     """
+     Predict splice-site probabilities for a given transcript using the specified engine.
+     The transcript sequence is expected to carry 5,000 nt of flanking context on each
+     side (the model's required context length); the flanking positions are trimmed
+     after prediction.
+
+     Args:
+         transcript: An object representing a transcript, expected to have:
+             - an `indices` attribute that returns a sequence of positions.
+             - a `seq` attribute that returns the sequence string.
+         engine (str): The prediction engine to use. Supported: 'spliceai', 'pangolin'.
+
+     Returns:
+         (donor_probs, acceptor_probs) as two dictionaries keyed by position with probability values.
+
+     Raises:
+         ValueError: If an unsupported engine is provided.
+         AssertionError: If the length of predicted probabilities does not match the length of indices.
+     """
+     ref_indices = transcript.indices
+     ref_seq = transcript.seq
+     ref_seq_donor_probs, ref_seq_acceptor_probs = run_splicing_engine(ref_seq, engine)
+     # Trim the 5,000-nt context flanks so positions line up with the predictions
+     ref_indices = ref_indices[5000:-5000]
+
+     # Verify lengths
+     assert len(ref_seq_donor_probs) == len(ref_indices), (
+         f"Donor probabilities length ({len(ref_seq_donor_probs)}) does not match "
+         f"indices length ({len(ref_indices)})."
+     )
+     assert len(ref_seq_acceptor_probs) == len(ref_indices), (
+         f"Acceptor probabilities length ({len(ref_seq_acceptor_probs)}) does not match "
+         f"indices length ({len(ref_indices)})."
+     )
+
+     # Create dictionaries sorted by probability in descending order
+     donor_probs = dict(sorted(((i, p) for i, p in zip(ref_indices, ref_seq_donor_probs)),
+                               key=lambda item: item[1], reverse=True))
+
+     acceptor_probs = dict(sorted(((i, p) for i, p in zip(ref_indices, ref_seq_acceptor_probs)),
+                                  key=lambda item: item[1], reverse=True))
+
+     return donor_probs, acceptor_probs
+
+
+ def missplicing_df(mut_id, **kwargs):
+     """Convenience wrapper: return the largest-magnitude splicing delta for a mutation ID."""
+     return find_transcript_missplicing(mut_id, **kwargs).max_delta
+
+
+ def find_transcript_missplicing(mut_id, transcript=None, threshold=0.5, engine='spliceai', organism='hg38'):
+     gene = Gene.from_file(mut_id.split(':')[0], organism=organism)
+     reference_transcript = gene.transcript(transcript) if transcript is not None else gene.transcript()
+     if reference_transcript is None:
+         return Missplicing()
+
+     variant_transcript = reference_transcript.clone()
+     mutations = [MutSeqMat.from_mutid(m) for m in mut_id.split('|')]
+     mutations = [m for m in mutations if m in reference_transcript]
+     if len(mutations) == 0:
+         return Missplicing()
+
+     # Center the prediction window on the mean mutation position
+     center = int(np.mean([m.indices[0] for m in mutations]))
+     for mutation in mutations:
+         variant_transcript.mutate(mutation, inplace=True)
+
+     missplicing = find_transcript_missplicing_seqs(
+         reference_transcript.pre_mrna.get_context(center, 7500, padding='N'),
+         variant_transcript.pre_mrna.get_context(center, 7500, padding='N'),
+         reference_transcript.donors, reference_transcript.acceptors,
+         threshold=threshold, engine=engine)
+     return missplicing
+
+
+ def find_transcript_missplicing_seqs(ref_seq, var_seq, donors, acceptors, threshold=0.5, engine='spliceai'):
+     if ref_seq.seq == var_seq.seq:
+         return Missplicing({'missed_acceptors': {}, 'missed_donors': {}, 'discovered_acceptors': {}, 'discovered_donors': {}})
+
+     ref_seq_donor_probs, ref_seq_acceptor_probs = run_splicing_engine(ref_seq.seq, engine)
+     mut_seq_donor_probs, mut_seq_acceptor_probs = run_splicing_engine(var_seq.seq, engine)
+     # Trim the 5,000-nt model context from each side
+     ref_indices = ref_seq.indices[5000:-5000]
+     mut_indices = var_seq.indices[5000:-5000]
+     visible_donors = np.intersect1d(donors, ref_indices)
+     visible_acceptors = np.intersect1d(acceptors, ref_indices)
+
+     assert len(ref_indices) == len(ref_seq_acceptor_probs), \
+         f'Reference pos ({len(ref_indices)}) not the same as probs ({len(ref_seq_acceptor_probs)})'
+     assert len(mut_indices) == len(mut_seq_acceptor_probs), \
+         f'Mut pos ({len(mut_indices)}) not the same as probs ({len(mut_seq_acceptor_probs)})'
+
+     # Changes are detected at a permissive threshold of 0.1; the caller's `threshold`
+     # is applied later when the Missplicing object filters events.
+     iap, dap = find_ss_changes(dict(zip(ref_indices, ref_seq_acceptor_probs)),
+                                dict(zip(mut_indices, mut_seq_acceptor_probs)),
+                                visible_acceptors,
+                                threshold=0.1)
+
+     assert len(ref_indices) == len(ref_seq_donor_probs), 'Reference pos not the same'
+     assert len(mut_indices) == len(mut_seq_donor_probs), 'Mut pos not the same'
+
+     idp, ddp = find_ss_changes(dict(zip(ref_indices, ref_seq_donor_probs)),
+                                dict(zip(mut_indices, mut_seq_donor_probs)),
+                                visible_donors,
+                                threshold=0.1)
+
+     ref_acceptors = dict(zip(ref_indices, ref_seq_acceptor_probs))
+     ref_donors = dict(zip(ref_indices, ref_seq_donor_probs))
+
+     # Sites absent from the mutant sequence entirely (e.g. inside a deletion) are lost outright
+     lost_acceptors = {int(p): {'absolute': 0.0, 'delta': round(float(-ref_acceptors[p]), 3)} for p in
+                       visible_acceptors if p not in mut_indices and p not in dap}
+     lost_donors = {int(p): {'absolute': 0.0, 'delta': round(float(-ref_donors[p]), 3)} for p in
+                    visible_donors if p not in mut_indices and p not in ddp}
+     dap.update(lost_acceptors)
+     ddp.update(lost_donors)
+
+     missplicing = {'missed_acceptors': dap, 'missed_donors': ddp, 'discovered_acceptors': iap,
+                    'discovered_donors': idp}
+     # Normalize position keys to native ints where possible
+     missplicing = {outk: {float(k): v for k, v in outv.items()} for outk, outv in missplicing.items()}
+     missplicing = {outk: {int(k) if k.is_integer() else k: v for k, v in outv.items()} for outk, outv in
+                    missplicing.items()}
+     return Missplicing(missplicing, threshold=threshold)
+
+
+ def process_pairwise_epistasis_explicit(mid: str, engine: str = 'spliceai') -> pd.DataFrame:
+     """Process pairwise epistasis for a given mutation identifier.
+
+     Args:
+         mid: Mutation identifier string containing exactly two variants joined by '|',
+              in the format "gene:...:lower_pos:...:upper_pos:..."
+         engine: Splicing engine to use ('spliceai' or 'pangolin')
+
+     Returns:
+         DataFrame containing processed splicing probabilities and epistasis features
+     """
+     # Parse the mutation ID and load the gene's default transcript
+     parts = mid.split(':')
+     gene_name, lower_pos, upper_pos = parts[0], int(parts[2]), int(parts[6])
+
+     g = Gene.from_file(gene_name).transcript()
+     if g is None:
+         return pd.DataFrame()
+
+     g.generate_pre_mrna()
+
+     # Calculate bounds with 7,500 nt of padding, handling the reverse strand
+     factor = -1 if g.rev else 1
+     if g.rev:
+         lower_pos, upper_pos = upper_pos, lower_pos
+
+     lb = lower_pos - (factor * 7500)
+     ub = upper_pos + (factor * 7500)
+
+     # Clamp bounds to the transcript
+     if lb not in g.pre_mrna.indices:
+         lb = g.pre_mrna.indices.max() if g.rev else g.pre_mrna.indices.min()
+     if ub not in g.pre_mrna.indices:
+         ub = g.pre_mrna.indices.min() if g.rev else g.pre_mrna.indices.max()
+
+     # Four scenarios: wild type, each single variant, and both variants combined
+     scenarios = ['wild_type'] + mid.split('|') + [mid]
+     donor_probs, acceptor_probs = {}, {}
+
+     for m in scenarios:
+         transcript = g.clone().pre_mrna
+         if m != 'wild_type':
+             mutations = [MutSeqMat.from_mutid(cm) for cm in m.split('|')]
+             if g.rev:
+                 mutations = [mutation.reverse_complement() for mutation in mutations]
+             for mutation in mutations:
+                 if mutation in transcript:
+                     transcript.mutate(mutation, inplace=True)
+
+         donors, acceptors = find_transcript_splicing(transcript[lb:ub], engine=engine)
+         donor_probs[m] = donors
+         acceptor_probs[m] = acceptors
+
+     # Convert to DataFrames (one row per scenario) and clean
+     acceptors_df = pd.DataFrame.from_dict(acceptor_probs, orient='index')
+     donors_df = pd.DataFrame.from_dict(donor_probs, orient='index')
+
+     # Apply rounding and thresholding
+     for df in [acceptors_df, donors_df]:
+         df[:] = df.map(
+             lambda x: 0 if isinstance(x, (int, float)) and abs(x) < 0.01
+             else round(x, 2) if isinstance(x, (int, float))
+             else x
+         ).round(2)
+
+         # Drop constant columns, but keep at least one even if nothing varies
+         if (df.nunique() > 1).any():
+             df.drop(columns=df.columns[df.nunique() <= 1], inplace=True)
+         else:
+             df.drop(columns=df.columns[1:], inplace=True)
+
+     # Add epistasis features; rows 0-3 are wild type, variant 1, variant 2, and the double mutant
+     for df in [donors_df, acceptors_df]:
+         if df.shape[1] > 0:
+             df.loc['residual'] = (df.iloc[3] - df.iloc[0]) - ((df.iloc[1] - df.iloc[0]) + (df.iloc[2] - df.iloc[0]))
+             df.loc['deviation1'] = df.iloc[1] - df.iloc[0]
+             df.loc['deviation2'] = df.iloc[2] - df.iloc[0]
+             df.loc['total_deviation'] = df.iloc[3] - df.iloc[0]
+
+     # Tag site types (0 = donor, 1 = acceptor) and combine
+     if donors_df.shape[1] > 0:
+         donors_df.loc['site_type', :] = 0
+     if acceptors_df.shape[1] > 0:
+         acceptors_df.loc['site_type', :] = 1
+
+     df = pd.concat([acceptors_df, donors_df], axis=1)
+
+     # Add metadata and rename scenario rows
+     df.loc['mut_id'] = mid
+     df.loc['engine'] = engine
+     df.loc['site'] = df.columns
+     df.rename({
+         mid: 'epistasis',
+         mid.split('|')[0]: 'cv1',
+         mid.split('|')[1]: 'cv2'
+     }, inplace=True)
+
+     return df.T
+
+
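+ # The residual row quantifies pairwise epistasis at each site. Worked example with
+ # toy probabilities for one site (wild type 0.9, single variants 0.6 and 0.8,
+ # double mutant 0.1):
+ #     deviation1 = 0.6 - 0.9 = -0.3
+ #     deviation2 = 0.8 - 0.9 = -0.1
+ #     total_deviation = 0.1 - 0.9 = -0.8
+ #     residual = total_deviation - (deviation1 + deviation2) = -0.8 - (-0.4) = -0.4
+ # A residual near 0 means the variants act additively; a large |residual| flags epistasis.
+
+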
+ class Missplicing:
+     def __init__(self, splicing_dict=None, threshold=0.5):
+         """
+         Initialize a Missplicing object.
+
+         Args:
+             splicing_dict (dict, optional): Dictionary containing splicing events and their
+                 details. Defaults to empty event categories. Example:
+                 {
+                     "missed_acceptors": {100: {"absolute": 0.0, "delta": -0.3}, ...},
+                     "missed_donors": { ... },
+                     "discovered_acceptors": { ... },
+                     "discovered_donors": { ... }
+                 }
+             threshold (float): The threshold above which a delta is considered significant.
+         """
+         if splicing_dict is None:
+             splicing_dict = {'missed_acceptors': {}, 'missed_donors': {}, 'discovered_acceptors': {}, 'discovered_donors': {}}
+         self.missplicing = splicing_dict
+         self.threshold = threshold
+
+     def __str__(self):
+         """String representation: the splicing events that pass the threshold."""
+         import pprint
+         return pprint.pformat(self.aberrant_splicing)
+
+     def __bool__(self):
+         """True if any event surpasses the threshold, False otherwise."""
+         return self.first_significant_event() is not None
+
+     def __iter__(self):
+         """
+         Iterate over all delta values from all events. The first yielded value is 0
+         (for compatibility), followed by every delta in self.missplicing.
+         """
+         yield 0
+         for details in self.missplicing.values():
+             for d in details.values():
+                 yield d['delta']
+
+     def __getitem__(self, key):
+         return self.missplicing[key]
+
+     @property
+     def aberrant_splicing(self):
+         """The missplicing events whose |delta| meets or exceeds the current threshold."""
+         return self.filter_by_threshold(self.threshold)
+
+     def filter_by_threshold(self, threshold=None):
+         """
+         Filter self.missplicing to only include events where abs(delta) >= threshold.
+
+         Args:
+             threshold (float, optional): The threshold to apply. Defaults to self.threshold.
+
+         Returns:
+             dict: A new dictionary with filtered events.
+         """
+         if threshold is None:
+             threshold = self.threshold
+         if threshold is None:
+             threshold = 0
+
+         return {
+             event: {
+                 pos: detail for pos, detail in details.items()
+                 if abs(detail['delta']) >= threshold
+             }
+             for event, details in self.missplicing.items()
+         }
+
+     def first_significant_event(self, splicing_dict=None, threshold=None):
+         """
+         Check whether any event surpasses the given threshold.
+
+         Args:
+             splicing_dict (dict, optional): Dictionary to check. Defaults to self.missplicing.
+             threshold (float, optional): Threshold to apply. Defaults to self.threshold.
+
+         Returns:
+             dict or None: The full splicing dictionary if any delta surpasses the threshold,
+                 otherwise None.
+         """
+         if splicing_dict is None:
+             splicing_dict = self.missplicing
+         if threshold is None:
+             threshold = self.threshold
+
+         # Check if any event meets the threshold
+         if any(abs(detail['delta']) >= threshold for details in splicing_dict.values() for detail in details.values()):
+             return splicing_dict
+         return None
+
+     @property
+     def max_delta(self):
+         """
+         The delta with the largest absolute value across all events (sign preserved).
+
+         Returns:
+             float: The largest-magnitude delta, or 0.0 if there are no events.
+         """
+         max_deltas = []
+         for v in self.missplicing.values():
+             max_deltas.extend([sv['delta'] for sv in v.values()])
+         return max(max_deltas, key=abs, default=0.0)
+
+
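+ # Minimal sketch (toy events): thresholding and the signed max delta.
+ #
+ #     ms = Missplicing({'missed_acceptors': {}, 'missed_donors': {100: {'absolute': 0.1, 'delta': -0.8}},
+ #                       'discovered_acceptors': {}, 'discovered_donors': {105: {'absolute': 0.3, 'delta': 0.3}}},
+ #                      threshold=0.5)
+ #     bool(ms)        # True  (|-0.8| >= 0.5)
+ #     ms.max_delta    # -0.8  (largest magnitude, sign preserved)
+ #     ms.aberrant_splicing['discovered_donors']   # {} (0.3 filtered out at threshold 0.5)
+
+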
+ def benchmark_splicing(gene, organism='hg38', engine='spliceai'):
+     gene = Gene(gene, organism=organism)
+     transcript = gene.transcript()
+     if transcript is None or len(transcript.introns) == 0:
+         return None, None, None
+
+     transcript.generate_pre_mrna()
+     predicted_donor_sites, predicted_acceptor_sites = find_transcript_splicing(transcript.pre_mrna, engine=engine)
+     # Take the top-k predictions, where k is the number of annotated introns
+     num_introns = len(transcript.introns)
+     predicted_donors = list(predicted_donor_sites.keys())[:num_introns]
+     predicted_acceptors = list(predicted_acceptor_sites.keys())[:num_introns]
+     correct_donor_preds = [v for v in predicted_donors if v in transcript.donors]
+     correct_acceptor_preds = [v for v in predicted_acceptors if v in transcript.acceptors]
+     # Returns (donor top-k accuracy, acceptor top-k accuracy, intron count)
+     return len(correct_donor_preds) / num_introns, len(correct_acceptor_preds) / num_introns, num_introns
+
+
+ def convert_numpy_to_native(obj):
+     """
+     Recursively convert NumPy data types to native Python types, rounding numbers to 3 decimals.
+     """
+     if isinstance(obj, dict):
+         return {key: convert_numpy_to_native(value) for key, value in obj.items()}
+     elif isinstance(obj, list):
+         return [convert_numpy_to_native(item) for item in obj]
+     elif isinstance(obj, np.generic):  # NumPy scalar types
+         return round(obj.item(), 3)
+     elif isinstance(obj, (int, float)):
+         return round(obj, 3)
+     else:
+         # Leave strings and other non-numeric leaves untouched
+         return obj
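+
+
+ # Minimal sketch: NumPy scalars become rounded native floats, other leaves pass through.
+ #
+ #     convert_numpy_to_native({'delta': np.float64(0.123456), 'note': 'ok'})
+ #     # {'delta': 0.123, 'note': 'ok'}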