geney 1.4.40__py3-none-any.whl → 1.4.41__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
geney/__init__.py CHANGED
@@ -4,12 +4,17 @@ from .engines import (
4
4
  sai_predict_probs,
5
5
  run_spliceai_seq,
6
6
  run_splicing_engine,
7
+ predict_splicing,
8
+ adjoin_splicing_outcomes,
7
9
  )
8
10
  from .transcripts import TranscriptLibrary
9
- from .splicing_table import adjoin_splicing_outcomes
10
11
  from .splice_graph import SpliceSimulator
11
- from .pipelines import oncosplice_pipeline_single_transcript
12
- from .samples import *
12
+ from .pipelines import (
13
+ oncosplice_pipeline,
14
+ oncosplice_top_isoform,
15
+ max_splicing_delta,
16
+ oncosplice_pipeline_single_transcript, # backwards compat
17
+ )
13
18
 
14
19
  __all__ = [
15
20
  "Mutation",
@@ -18,8 +23,16 @@ __all__ = [
18
23
  "sai_predict_probs",
19
24
  "run_spliceai_seq",
20
25
  "run_splicing_engine",
21
- "TranscriptLibrary",
26
+ "predict_splicing",
22
27
  "adjoin_splicing_outcomes",
28
+ "TranscriptLibrary",
23
29
  "SpliceSimulator",
30
+ "oncosplice_pipeline",
31
+ "oncosplice_top_isoform",
32
+ "max_splicing_delta",
24
33
  "oncosplice_pipeline_single_transcript",
25
- ]
34
+ ]
35
+
36
+
37
+ mut_id = 'KRAS:12:25227343:G:T'
38
+ epistasis_id = 'KRAS:12:25227343:G:T|KRAS:12:25227344:A:T'
geney/engines.py CHANGED
@@ -2,137 +2,102 @@
2
2
  from __future__ import annotations
3
3
 
4
4
  from typing import Dict, List, Tuple, Optional, Union
5
-
6
5
  import numpy as np
7
6
 
8
- # These are your existing helpers; keep them in separate modules if you want.
9
- # from ._spliceai_utils import one_hot_encode, sai_models # type: ignore
10
- # from ._pangolin_utils import pangolin_predict_probs, pang_models # type: ignore
11
- import torch
12
- from pkg_resources import resource_filename
13
- from pangolin.model import *
14
- import numpy as np
15
- import sys
7
+ # Lazy-loaded model containers (loaded automatically on first use)
8
+ _pang_models = None
9
+ _sai_models = None
10
+ _pang_device = None
11
+ _sai_device = None
12
+
16
13
 
17
- pang_model_nums = [0, 1, 2, 3, 4, 5, 6, 7]
18
- pang_models = []
14
+ def _get_torch_device():
15
+ """Get the best available device for PyTorch."""
16
+ import sys
17
+ import torch
19
18
 
20
- def get_best_device():
21
- """Get the best available device for computation."""
22
19
  if sys.platform == 'darwin' and torch.backends.mps.is_available():
23
20
  try:
24
- # Test MPS availability
25
21
  torch.tensor([1.0], device="mps")
26
22
  return torch.device("mps")
27
23
  except RuntimeError:
28
- print("Warning: MPS not available, falling back to CPU")
29
24
  return torch.device("cpu")
30
25
  elif torch.cuda.is_available():
31
26
  return torch.device("cuda")
32
- else:
33
- return torch.device("cpu")
27
+ return torch.device("cpu")
28
+
29
+
30
+ def _get_tensorflow_device():
31
+ """Get the best available TensorFlow device."""
32
+ import sys
33
+ import tensorflow as tf
34
+
35
+ try:
36
+ if tf.config.list_physical_devices('GPU'):
37
+ return '/GPU:0'
38
+ elif sys.platform == 'darwin' and tf.config.list_physical_devices('MPS'):
39
+ return '/device:GPU:0'
40
+ except Exception:
41
+ pass
42
+ return '/CPU:0'
43
+
44
+
45
+ def _load_pangolin_models():
46
+ """Lazy load Pangolin models."""
47
+ global _pang_models, _pang_device
48
+
49
+ if _pang_models is not None:
50
+ return _pang_models
51
+
52
+ import torch
53
+ from pkg_resources import resource_filename
54
+ from pangolin.model import Pangolin, L, W, AR
55
+
56
+ _pang_device = _get_torch_device()
57
+ print(f"Pangolin loading to {_pang_device}...")
34
58
 
35
- device = get_best_device()
36
- print(f"Pangolin loaded to {device}.")
59
+ _pang_models = []
60
+ pang_model_nums = [0, 1, 2, 3, 4, 5, 6, 7]
37
61
 
38
- # Initialize models with improved error handling
39
- try:
40
62
  for i in pang_model_nums:
41
63
  for j in range(1, 6):
42
64
  try:
43
- model = Pangolin(L, W, AR).to(device)
44
-
45
- # Load weights with proper device mapping
65
+ model = Pangolin(L, W, AR).to(_pang_device)
46
66
  model_path = resource_filename("pangolin", f"models/final.{j}.{i}.3")
47
- weights = torch.load(model_path, weights_only=True, map_location=device)
48
-
67
+ weights = torch.load(model_path, weights_only=True, map_location=_pang_device)
49
68
  model.load_state_dict(weights)
50
69
  model.eval()
51
- pang_models.append(model)
52
-
70
+ _pang_models.append(model)
53
71
  except Exception as e:
54
72
  print(f"Warning: Failed to load Pangolin model {j}.{i}: {e}")
55
- continue
56
-
57
- except Exception as e:
58
- print(f"Error initializing Pangolin models: {e}")
59
- pang_models = []
60
73
 
74
+ print(f"Pangolin loaded ({len(_pang_models)} models).")
75
+ return _pang_models
61
76
 
62
- def pang_one_hot_encode(seq: str) -> np.ndarray:
63
- """One-hot encode DNA sequence for Pangolin model.
64
-
65
- Args:
66
- seq: DNA sequence string
67
-
68
- Returns:
69
- One-hot encoded array of shape (len(seq), 4)
70
-
71
- Raises:
72
- ValueError: If sequence contains invalid characters
73
- """
74
- if not isinstance(seq, str):
75
- raise TypeError(f"Expected string, got {type(seq).__name__}")
76
-
77
- IN_MAP = np.asarray([[0, 0, 0, 0], # N or unknown
78
- [1, 0, 0, 0], # A
79
- [0, 1, 0, 0], # C
80
- [0, 0, 1, 0], # G
81
- [0, 0, 0, 1]]) # T
82
-
83
- # Validate sequence
84
- valid_chars = set('ACGTN')
85
- if not all(c.upper() in valid_chars for c in seq):
86
- raise ValueError("Sequence contains invalid characters (only A, C, G, T, N allowed)")
87
-
88
- # Convert to numeric representation
89
- seq = seq.upper().replace('A', '1').replace('C', '2')
90
- seq = seq.replace('G', '3').replace('T', '4').replace('N', '0')
91
-
92
- try:
93
- seq_array = np.asarray(list(map(int, list(seq))))
94
- return IN_MAP[seq_array.astype('int8')]
95
- except (ValueError, IndexError) as e:
96
- raise ValueError(f"Failed to encode sequence: {e}") from e
97
77
 
78
+ def _load_spliceai_models():
79
+ """Lazy load SpliceAI models."""
80
+ global _sai_models, _sai_device
98
81
 
82
+ if _sai_models is not None:
83
+ return _sai_models
99
84
 
85
+ import os
86
+ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
100
87
 
101
- import os
102
- os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
88
+ import sys
89
+ import tensorflow as tf
90
+ from keras.models import load_model
91
+ from importlib import resources
103
92
 
104
- import absl.logging
105
- absl.logging.set_verbosity(absl.logging.ERROR)
93
+ import absl.logging
94
+ absl.logging.set_verbosity(absl.logging.ERROR)
106
95
 
107
- import os
108
- import sys
109
- import tensorflow as tf
110
- import numpy as np
111
- from keras.models import load_model
112
- from importlib import resources
96
+ _sai_device = _get_tensorflow_device()
97
+ print(f"SpliceAI loading to {_sai_device}...")
113
98
 
99
+ _sai_models = []
114
100
 
115
- # Force device selection with error handling
116
- def get_best_tensorflow_device():
117
- """Get the best available TensorFlow device."""
118
- try:
119
- # Try GPU first
120
- if tf.config.list_physical_devices('GPU'):
121
- return '/GPU:0'
122
- # Try MPS on macOS
123
- elif sys.platform == 'darwin' and tf.config.list_physical_devices('MPS'):
124
- return '/device:GPU:0'
125
- else:
126
- return '/CPU:0'
127
- except Exception as e:
128
- print(f"Warning: Device selection failed, using CPU: {e}")
129
- return '/CPU:0'
130
-
131
- device = get_best_tensorflow_device()
132
-
133
- # Model loading paths with error handling
134
- def load_spliceai_models():
135
- """Load SpliceAI models with proper error handling."""
136
101
  try:
137
102
  if sys.platform == 'darwin':
138
103
  model_filenames = [f"models/spliceai{i}.h5" for i in range(1, 6)]
@@ -140,168 +105,312 @@ def load_spliceai_models():
140
105
  else:
141
106
  model_paths = [f"/tamir2/nicolaslynn/tools/SpliceAI/spliceai/models/spliceai{i}.h5"
142
107
  for i in range(1, 6)]
143
-
144
- # Load models onto correct device
145
- models = []
146
- with tf.device(device):
108
+
109
+ with tf.device(_sai_device):
147
110
  for i, model_path in enumerate(model_paths):
148
111
  try:
149
112
  model = load_model(str(model_path))
150
- models.append(model)
113
+ _sai_models.append(model)
151
114
  except Exception as e:
152
115
  print(f"Warning: Failed to load SpliceAI model {i+1}: {e}")
153
- continue
154
-
155
- if not models:
156
- raise RuntimeError("No SpliceAI models could be loaded")
157
-
158
- return models
159
-
160
116
  except Exception as e:
161
117
  print(f"Error loading SpliceAI models: {e}")
162
- return []
163
118
 
164
- sai_models = load_spliceai_models()
119
+ print(f"SpliceAI loaded ({len(_sai_models)} models).")
120
+ return _sai_models
121
+
122
+
123
+ def pang_one_hot_encode(seq: str) -> np.ndarray:
124
+ """One-hot encode DNA sequence for Pangolin model."""
125
+ if not isinstance(seq, str):
126
+ raise TypeError(f"Expected string, got {type(seq).__name__}")
127
+
128
+ IN_MAP = np.asarray([[0, 0, 0, 0], # N
129
+ [1, 0, 0, 0], # A
130
+ [0, 1, 0, 0], # C
131
+ [0, 0, 1, 0], # G
132
+ [0, 0, 0, 1]]) # T
133
+
134
+ valid_chars = set('ACGTN')
135
+ if not all(c.upper() in valid_chars for c in seq):
136
+ raise ValueError("Sequence contains invalid characters")
137
+
138
+ seq = seq.upper().replace('A', '1').replace('C', '2')
139
+ seq = seq.replace('G', '3').replace('T', '4').replace('N', '0')
165
140
 
141
+ seq_array = np.asarray(list(map(int, list(seq))))
142
+ return IN_MAP[seq_array.astype('int8')]
166
143
 
167
- print(f"SpliceAI loaded to {device}.")
168
144
 
169
145
  def one_hot_encode(seq: str) -> np.ndarray:
170
- """One-hot encode DNA sequence for SpliceAI model.
171
-
172
- Args:
173
- seq: DNA sequence string
174
-
175
- Returns:
176
- One-hot encoded array of shape (len(seq), 4)
177
-
178
- Raises:
179
- ValueError: If sequence contains invalid characters
180
- """
146
+ """One-hot encode DNA sequence for SpliceAI model."""
181
147
  if not isinstance(seq, str):
182
148
  raise TypeError(f"Expected string, got {type(seq).__name__}")
183
-
184
- # Validate sequence
149
+
185
150
  valid_chars = set('ACGTN')
186
151
  if not all(c.upper() in valid_chars for c in seq):
187
- raise ValueError("Sequence contains invalid characters (only A, C, G, T, N allowed)")
188
-
189
- encoding_map = np.asarray([[0, 0, 0, 0], # N or unknown
152
+ raise ValueError("Sequence contains invalid characters")
153
+
154
+ encoding_map = np.asarray([[0, 0, 0, 0], # N
190
155
  [1, 0, 0, 0], # A
191
156
  [0, 1, 0, 0], # C
192
157
  [0, 0, 1, 0], # G
193
158
  [0, 0, 0, 1]]) # T
194
159
 
195
- # Convert to numeric representation
196
160
  seq = seq.upper().replace('A', '\x01').replace('C', '\x02')
197
161
  seq = seq.replace('G', '\x03').replace('T', '\x04').replace('N', '\x00')
198
162
 
199
- try:
200
- return encoding_map[np.frombuffer(seq.encode('latin1'), np.int8) % 5]
201
- except Exception as e:
202
- raise ValueError(f"Failed to encode sequence: {e}") from e
163
+ return encoding_map[np.frombuffer(seq.encode('latin1'), np.int8) % 5]
203
164
 
204
165
 
205
- def sai_predict_probs(seq: str, models: list) -> tuple[np.ndarray, np.ndarray]:
206
- """
207
- Predict donor and acceptor probabilities for each nt in seq using SpliceAI.
208
- Returns (acceptor_probs, donor_probs) as np.ndarray of shape (L,).
166
+ def pangolin_predict_probs(seq: str, models: list = None) -> Tuple[List[float], List[float]]:
167
+ """Predict splice site probabilities using Pangolin.
168
+
169
+ Pangolin outputs shape (1, 12, seq_len) where:
170
+ - 12 channels = 4 tissues × 3 prediction types
171
+ - For each tissue: [site_usage, acceptor_gain, donor_gain] or similar
172
+
173
+ We aggregate by taking max across tissues.
209
174
  """
175
+ import torch
176
+
177
+ if models is None:
178
+ models = _load_pangolin_models()
179
+
210
180
  if not models:
211
- raise ValueError("No SpliceAI models loaded")
181
+ raise ValueError("No Pangolin models loaded")
212
182
 
213
- if not isinstance(seq, str):
214
- raise TypeError(f"Expected string, got {type(seq).__name__}")
183
+ x = pang_one_hot_encode(seq)
184
+ x = torch.tensor(x.T[None, :, :], dtype=torch.float32, device=_pang_device)
185
+
186
+ preds = []
187
+ with torch.no_grad():
188
+ for model in models:
189
+ pred = model(x)
190
+ preds.append(pred.cpu().numpy())
191
+
192
+ y = np.mean(preds, axis=0) # Shape: (1, 12, seq_len)
193
+
194
+ # Pangolin has 12 channels organized as:
195
+ # Indices 0,3,6,9: site usage scores for 4 tissues
196
+ # Indices 1,4,7,10: acceptor gain scores for 4 tissues
197
+ # Indices 2,5,8,11: donor gain scores for 4 tissues
198
+ # Take max across the 4 tissues for each type
199
+
200
+ # Acceptor: max of channels 1, 4, 7, 10
201
+ acceptor_channels = y[0, [1, 4, 7, 10], :] # (4, seq_len)
202
+ acceptor_probs = np.max(acceptor_channels, axis=0).tolist()
203
+
204
+ # Donor: max of channels 2, 5, 8, 11
205
+ donor_channels = y[0, [2, 5, 8, 11], :] # (4, seq_len)
206
+ donor_probs = np.max(donor_channels, axis=0).tolist()
207
+
208
+ return donor_probs, acceptor_probs
209
+
210
+
211
+ def sai_predict_probs(seq: str, models: list = None) -> Tuple[np.ndarray, np.ndarray]:
212
+ """Predict donor and acceptor probabilities using SpliceAI."""
213
+ if models is None:
214
+ models = _load_spliceai_models()
215
+
216
+ if not models:
217
+ raise ValueError("No SpliceAI models loaded")
215
218
 
216
219
  if len(seq) < 1000:
217
220
  raise ValueError(f"Sequence too short: {len(seq)} (expected >= 1000)")
218
221
 
219
- try:
220
- x = one_hot_encode(seq)[None, :]
221
- preds = []
222
- for i, model in enumerate(models):
223
- try:
224
- pred = model.predict(x, verbose=0)
225
- preds.append(pred)
226
- except Exception as e:
227
- print(f"Warning: SpliceAI model {i+1} failed: {e}")
228
- if not preds:
229
- raise RuntimeError("All SpliceAI model predictions failed")
222
+ x = one_hot_encode(seq)[None, :].astype(np.float32)
230
223
 
231
- y = np.mean(preds, axis=0) # (1, L, 3)
232
- y = y[0, :, 1:].T # (2, L) -> [acceptor, donor]
233
- return y[0, :], y[1, :]
224
+ # Use direct model call instead of .predict() to avoid Jupyter kernel issues
225
+ preds = []
226
+ for model in models:
227
+ pred = model(x, training=False)
228
+ if hasattr(pred, 'numpy'):
229
+ pred = pred.numpy()
230
+ preds.append(pred)
234
231
 
235
- except Exception as e:
236
- raise RuntimeError(f"SpliceAI prediction failed: {e}") from e
232
+ y = np.mean(preds, axis=0)
233
+ y = y[0, :, 1:].T
234
+ return y[0, :], y[1, :]
237
235
 
238
236
 
239
237
  def run_spliceai_seq(
240
238
  seq: str,
241
239
  indices: Union[List[int], np.ndarray],
242
240
  threshold: float = 0.0,
243
- ) -> tuple[Dict[int, float], Dict[int, float]]:
244
- """
245
- Run SpliceAI on seq and return donor / acceptor sites above threshold.
246
- Returns (donor_indices, acceptor_indices) as dict[pos -> prob]
247
- """
248
- if not isinstance(seq, str):
249
- raise TypeError(f"Expected string sequence, got {type(seq).__name__}")
250
-
251
- if not isinstance(indices, (list, np.ndarray)):
252
- raise TypeError(f"Expected list or array for indices, got {type(indices).__name__}")
253
-
241
+ ) -> Tuple[Dict[int, float], Dict[int, float]]:
242
+ """Run SpliceAI on seq and return donor/acceptor sites above threshold."""
254
243
  if len(indices) != len(seq):
255
244
  raise ValueError(f"indices length ({len(indices)}) must match sequence length ({len(seq)})")
256
245
 
257
- if not isinstance(threshold, (int, float)):
258
- raise TypeError(f"Threshold must be numeric, got {type(threshold).__name__}")
246
+ acc_probs, don_probs = sai_predict_probs(seq)
247
+ acceptor = {pos: p for pos, p in zip(indices, acc_probs) if p >= threshold}
248
+ donor = {pos: p for pos, p in zip(indices, don_probs) if p >= threshold}
249
+ return donor, acceptor
259
250
 
260
- try:
261
- acc_probs, don_probs = sai_predict_probs(seq, models=sai_models)
262
- acceptor = {pos: p for pos, p in zip(indices, acc_probs) if p >= threshold}
263
- donor = {pos: p for pos, p in zip(indices, don_probs) if p >= threshold}
264
- return donor, acceptor
265
- except Exception as e:
266
- raise RuntimeError(f"SpliceAI sequence analysis failed: {e}") from e
251
+
252
+ def _generate_random_sequence(length: int) -> str:
253
+ """Generate a random DNA sequence of given length."""
254
+ import random
255
+ return ''.join(random.choices('ACGT', k=length))
267
256
 
268
257
 
269
258
  def run_splicing_engine(
270
259
  seq: Optional[str] = None,
271
260
  engine: str = "spliceai",
272
261
  ) -> Tuple[List[float], List[float]]:
262
+ """Run specified splicing engine to predict splice site probabilities."""
263
+ if seq is None:
264
+ seq = _generate_random_sequence(15_001)
265
+
266
+ if not isinstance(seq, str) or not seq:
267
+ raise ValueError("Sequence must be a non-empty string")
268
+
269
+ valid_chars = set("ACGTN")
270
+ if not all(c.upper() in valid_chars for c in seq):
271
+ raise ValueError("Sequence contains invalid nucleotides")
272
+
273
+ match engine:
274
+ case "spliceai":
275
+ acc, don = sai_predict_probs(seq)
276
+ return don.tolist(), acc.tolist()
277
+ case "pangolin":
278
+ return pangolin_predict_probs(seq)
279
+ case _:
280
+ raise ValueError(f"Engine '{engine}' not implemented. Available: 'spliceai', 'pangolin'")
281
+
282
+
283
+ # ------------------------------------------------------------------------------
284
+ # Higher-level prediction utilities (formerly in splicing_table.py)
285
+ # ------------------------------------------------------------------------------
286
+
287
+ def predict_splicing(s, position: int, engine: str = 'spliceai', context: int = 7500):
273
288
  """
274
- Run specified splicing engine to predict splice site probabilities.
289
+ Predict splicing probabilities at a given position using the specified engine.
290
+
291
+ Args:
292
+ s: Sequence object with .seq, .index, .clone(), .rev attributes
293
+ position: The genomic position to predict splicing probabilities for.
294
+ engine: The prediction engine to use. Supported: 'spliceai', 'pangolin'.
295
+ context: The length of the target central region (default: 7500).
275
296
 
276
297
  Returns:
277
- (donor_probs, acceptor_probs) as lists
298
+ pd.DataFrame with position index and columns: donor_prob, acceptor_prob, nucleotides
278
299
  """
279
- from .utils import generate_random_sequence # type: ignore
300
+ import pandas as pd
280
301
 
281
- if seq is None:
282
- seq = generate_random_sequence(15_001)
302
+ if position < s.index.min() or position > s.index.max():
303
+ raise ValueError(f"Position {position} is outside sequence bounds [{s.index.min()}, {s.index.max()}]")
283
304
 
284
- if not isinstance(seq, str):
285
- raise TypeError(f"Sequence must be string, got {type(seq).__name__}")
286
- if not seq:
287
- raise ValueError("Sequence cannot be empty")
305
+ target = s.clone(position - context, position + context)
306
+
307
+ if len(target.seq) == 0:
308
+ raise ValueError(f"No sequence data found around position {position} with context {context}")
309
+
310
+ seq, indices = target.seq, target.index
311
+
312
+ if len(indices) == 0:
313
+ raise ValueError(f"No indices found in sequence around position {position}")
314
+
315
+ rel_pos = np.abs(indices - position).argmin()
316
+ left_missing, right_missing = max(0, context - rel_pos), max(0, context - (len(seq) - rel_pos))
317
+
318
+ if left_missing > 0 or right_missing > 0:
319
+ step = -1 if s.rev else 1
320
+
321
+ if left_missing > 0:
322
+ left_pad = np.arange(indices[0] - step * left_missing, indices[0], step)
323
+ else:
324
+ left_pad = np.array([], dtype=indices.dtype)
325
+
326
+ if right_missing > 0:
327
+ right_pad = np.arange(indices[-1] + step, indices[-1] + step * (right_missing + 1), step)
328
+ else:
329
+ right_pad = np.array([], dtype=indices.dtype)
330
+
331
+ seq = 'N' * left_missing + seq + 'N' * right_missing
332
+ indices = np.concatenate([left_pad, indices, right_pad])
333
+
334
+ donor_probs, acceptor_probs = run_splicing_engine(seq=seq, engine=engine)
335
+
336
+ seq = seq[5000:-5000]
337
+ indices = indices[5000:-5000]
338
+ expected_len = len(seq)
339
+
340
+ if len(donor_probs) != expected_len:
341
+ if len(donor_probs) > expected_len:
342
+ offset = (len(donor_probs) - expected_len) // 2
343
+ donor_probs = donor_probs[offset:offset + expected_len]
344
+ acceptor_probs = acceptor_probs[offset:offset + expected_len]
345
+ else:
346
+ pad_len = expected_len - len(donor_probs)
347
+ donor_probs = donor_probs + [0.0] * pad_len
348
+ acceptor_probs = acceptor_probs + [0.0] * pad_len
349
+
350
+ df = pd.DataFrame({
351
+ 'position': indices,
352
+ 'donor_prob': donor_probs,
353
+ 'acceptor_prob': acceptor_probs,
354
+ 'nucleotides': list(seq)
355
+ }).set_index('position').round(3)
356
+
357
+ df.attrs['name'] = s.name
358
+ return df
288
359
 
289
- valid_chars = set("ACGTN")
290
- if not all(c.upper() in valid_chars for c in seq):
291
- raise ValueError("Sequence contains invalid nucleotides (only A, C, G, T, N allowed)")
360
+
361
+ def adjoin_splicing_outcomes(
362
+ splicing_predictions: Dict[str, 'pd.DataFrame'],
363
+ transcript: Optional[object] = None,
364
+ ) -> 'pd.DataFrame':
365
+ """
366
+ Combine splicing predictions for multiple mutations into a multi-index DataFrame.
367
+
368
+ Args:
369
+ splicing_predictions: {label -> DF with 'donor_prob','acceptor_prob','nucleotides'}
370
+ transcript: optional transcript (must have .acceptors, .donors, .rev)
371
+ """
372
+ import pandas as pd
373
+
374
+ if not splicing_predictions:
375
+ raise ValueError("splicing_predictions cannot be empty")
376
+
377
+ dfs = []
378
+ for label, df in splicing_predictions.items():
379
+ if not isinstance(df, pd.DataFrame):
380
+ raise TypeError(f"Expected DataFrame for '{label}', got {type(df).__name__}")
381
+
382
+ required_cols = ["donor_prob", "acceptor_prob", "nucleotides"]
383
+ missing = [c for c in required_cols if c not in df.columns]
384
+ if missing:
385
+ raise ValueError(f"DataFrame for '{label}' missing required columns: {missing}")
386
+
387
+ var_df = df.rename(
388
+ columns={
389
+ "donor_prob": ("donors", f"{label}_prob"),
390
+ "acceptor_prob": ("acceptors", f"{label}_prob"),
391
+ "nucleotides": ("nts", f"{label}"),
392
+ }
393
+ )
394
+ dfs.append(var_df)
292
395
 
293
396
  try:
294
- match engine:
295
- case "spliceai":
296
- acc, don = sai_predict_probs(seq, models=sai_models)
297
- donor_probs, acceptor_probs = don.tolist(), acc.tolist()
298
- case "spliceai-pytorch":
299
- raise ValueError("spliceai-pytorch engine has been removed. Use 'spliceai' instead.")
300
- case "pangolin":
301
- donor_probs, acceptor_probs = pangolin_predict_probs(seq, models=pang_models)
302
- case _:
303
- raise ValueError(f"Engine '{engine}' not implemented. Available: 'spliceai', 'pangolin'")
304
- except ImportError as e:
305
- raise ImportError(f"Failed to import engine '{engine}': {e}") from e
306
-
307
- return donor_probs, acceptor_probs
397
+ full_df = pd.concat(dfs, axis=1)
398
+ except Exception as e:
399
+ raise ValueError(f"Failed to concatenate DataFrames: {e}") from e
400
+
401
+ if not isinstance(full_df.columns, pd.MultiIndex):
402
+ full_df.columns = pd.MultiIndex.from_tuples(full_df.columns)
403
+
404
+ if transcript is not None:
405
+ full_df[("acceptors", "annotated")] = full_df.apply(
406
+ lambda row: row.name in transcript.acceptors, axis=1
407
+ )
408
+ full_df[("donors", "annotated")] = full_df.apply(
409
+ lambda row: row.name in transcript.donors, axis=1
410
+ )
411
+ full_df.sort_index(axis=1, level=0, inplace=True)
412
+ full_df.sort_index(ascending=not transcript.rev, inplace=True)
413
+ else:
414
+ full_df.sort_index(axis=1, level=0, inplace=True)
415
+
416
+ return full_df