codeine 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,489 @@
1
+ import math
2
+
3
+ from typing import Dict, NamedTuple, Tuple, List, Optional, TYPE_CHECKING
4
+
5
+ from codeine.constraints.banned import BannedTrackerStateId, BannedTrackerAdvanceResult
6
+ from codeine.constraints.base import ConstraintState
7
+ from codeine.graph.nodes import CodonNode, Node, ContextNode
8
+
9
+ if TYPE_CHECKING:
10
+ from codeine.graph.view import CodonGraphView
11
+
12
+
13
class TraversalState(NamedTuple):
    """
    A graph node paired with a digest of the path taken to reach it.

    The digest records only what matters for the rest of the walk: how far
    any banned sequence has been matched so far, and the accumulated state
    of the path constraint (e.g. nucleotide or codon properties).

    Two different traversal histories that end in an equal TraversalState
    are treated as interchangeable and are collapsed into a single state.
    """
    # Graph node the traversal is currently positioned on.
    node: Node
    # State of the banned-sequence tracker after the choices taken so far.
    banned_tracker_state_id: BannedTrackerStateId
    # State of the path constraint after the choices taken so far. The
    # compiler stores an empty tuple here when no path constraint is active.
    constraint_state: ConstraintState
25
+
26
+
27
class ChoiceResult(NamedTuple):
    """
    Cached outcome of following one graph edge out of one compiled state.

    The "choice" is the edge label: a codon for coding nodes, or the context
    sequence for context nodes.

    A ChoiceResult only makes sense at the graph location it was computed
    for. Its descendant count and log mass are built bottom-up from the
    values already computed for the downstream state.
    """
    # Edge label (codon or context sequence).
    choice: str
    # Number of complete valid sequences reachable after taking this choice.
    descendant_count: int
    # Log of the total mass reachable after taking this choice.
    descendant_log_mass: float
    # ID of the state this choice leads to; None when it leads to the
    # graph's final node.
    next_state_id: Optional[int]
    # True when the choice was taken from a codon node.
    is_coding: bool
41
+
42
+
43
class CompiledView(NamedTuple):
    """
    Immutable snapshot of a compiled CodonGraphView.

    Holds the precomputed tables that make sampling and enumeration fast.
    All per-state tables are indexed by integer state ID.
    """
    # Starting traversal state and its integer ID.
    initial_state: TraversalState
    initial_state_id: int

    # Every registered traversal state; the index is the state ID.
    states: Tuple[TraversalState, ...]

    # Lookup table: state ID -> choice -> ChoiceResult.
    # Serves fast sequence validation and graph traversal in the graph view.
    choices_by_state_id: Tuple[Dict[str, ChoiceResult], ...]

    # Iteration table: state ID -> ChoiceResults in graph order.
    # Serves fast sampling and sequence enumeration in the graph view.
    choice_results_by_state_id: Tuple[Tuple[ChoiceResult, ...], ...]

    # Total number of valid sequences reachable from the initial state.
    n_valid_sequences: int
62
+
63
+
64
class ViewCompiler:
    """
    Compile a CodonGraphView into cached choice, count, and sampling data.

    Compilation walks every reachable TraversalState once, assigns each a
    stable integer ID, and computes bottom-up how many valid sequences (and
    how much log mass) lie below each state and each outgoing choice.
    """

    def __init__(self, view: 'CodonGraphView') -> None:
        """
        Set up the compiler's caches for a graph view.

        Parameters
        ----------
        view
            The graph view to compile. Its ``graph``, ``banned_tracker``,
            ``path_constraint`` and ``pinned_codons`` attributes are read.
        """
        self.view = view
        self.graph = view.graph

        self.banned_tracker = view.banned_tracker
        self.has_banned_tracker = not self.banned_tracker.is_trivial

        self.path_constraint = view.path_constraint
        self.has_path_constraint = self.path_constraint is not None

        # State registry: TraversalState <-> integer state ID.
        self.state_ids: Dict[TraversalState, int] = {}
        self.states: List[TraversalState] = []

        # Dynamic-programming totals:
        #   state ID -> (descendant count, descendant log mass)
        # Avoids repeatedly recomputing subtree sizes and probability masses.
        # An entry stays None until the state has been compiled.
        self.totals_by_state_id: List[Optional[Tuple[int, float]]] = []

        # Banned-sequence tracker transitions:
        #   (node, tracker state, choice) -> tracker result
        # Avoids recomputing tracker advances during compilation.
        self.banned_advance_cache: Dict[Tuple[Node, BannedTrackerStateId, str], BannedTrackerAdvanceResult] = {}

        # Traversal transition table:
        #   state ID -> [(choice, child state ID), ...]
        # Avoids rediscovering successor states during compilation.
        self.child_results_by_state_id: List[Optional[List[Tuple[str, int]]]] = []

        # Compiled graph choices (lookup):
        #   state ID -> choice -> ChoiceResult
        # Used for fast sequence validation and graph traversal in the graph view.
        self.choices_by_state_id: List[Optional[Dict[str, ChoiceResult]]] = []

        # Cached log codon weights, to avoid repeated log calculations.
        # Codons with a non-positive weight are omitted, so a missing key
        # means "zero mass" (see _accumulate_log_mass).
        self.log_codon_weights = {
            codon: math.log(weight)
            for codon, weight in self.graph.cw.weights.items()
            if weight > 0
        }

        # Cached choices available at each node, taking pinned codons into
        # account. The final node has no outgoing choices and is skipped.
        self.choices_by_node = {
            node: tuple(self._get_choices_for_node(node))
            for node in self.graph.nodes
            if node is not self.graph.final_node
        }

    def compile(self) -> CompiledView:
        """
        Compile descendant counts, graph choices, and samplers.

        Returns
        -------
        CompiledView
            A compiled view.
        """
        initial_state = self._initial_state()
        initial_state_id = self._get_or_register_state_id(initial_state)

        self._compile_from(initial_state_id)

        # Any entry that is still None (or empty) is exposed as an empty
        # mapping / empty tuple so consumers never have to handle None.
        choices_by_state_id = tuple(
            choices or {}
            for choices in self.choices_by_state_id
        )
        choice_results_by_state_id = tuple(
            tuple(choices.values()) if choices else ()
            for choices in self.choices_by_state_id
        )

        initial_total = self.totals_by_state_id[initial_state_id]
        # Internal invariant: _compile_from always compiles the initial state.
        assert initial_total is not None

        return CompiledView(
            initial_state=initial_state,
            initial_state_id=initial_state_id,
            states=tuple(self.states),
            n_valid_sequences=initial_total[0],
            choices_by_state_id=choices_by_state_id,
            choice_results_by_state_id=choice_results_by_state_id,
        )

    def _get_or_register_state_id(self, state: TraversalState) -> int:
        """
        Return the stable integer ID for a traversal state, creating one if needed.

        Registering a new state also appends a ``None`` placeholder to every
        per-state table so that all tables stay indexable by the new ID.
        """
        state_id = self.state_ids.get(state)

        if state_id is None:
            state_id = len(self.states)
            self.state_ids[state] = state_id
            self.states.append(state)
            self.totals_by_state_id.append(None)
            self.choices_by_state_id.append(None)
            self.child_results_by_state_id.append(None)

        return state_id

    def _initial_state(self) -> TraversalState:
        """
        Return the starting traversal state.

        The initial state starts at the graph's initial node, with banned-
        tracker state 0 and either the path constraint's initial state or an
        empty tuple when no path constraint is active.

        Returns
        -------
        TraversalState
            Starting state for graph compilation.
        """
        if self.has_path_constraint:
            constraint_state = self.path_constraint.initial_state
        else:
            constraint_state = ()

        # NOTE(review): 0 is assumed to be the tracker's start state - confirm
        # against the banned-tracker implementation.
        banned_tracker_state_id = 0

        return TraversalState(self.graph.initial_node, banned_tracker_state_id, constraint_state)

    def _compile_from(self, initial_state_id: int) -> None:
        """
        Compile every reachable traversal state starting from an initial state ID.

        Uses an explicit depth-first stack so that each non-final state is compiled
        only after its child states have been compiled. Each stack entry is
        ``(state_id, expanded)``: a state is first popped unexpanded, re-pushed
        as expanded beneath its children, and compiled when popped again.

        Parameters
        ----------
        initial_state_id
            ID of the state from which graph compilation should begin.
        """
        stack = [(initial_state_id, False)]

        while stack:
            state_id, expanded = stack.pop()
            node = self.states[state_id].node

            # A state may be queued by several parents; compile it only once.
            if self.totals_by_state_id[state_id] is not None:
                continue

            if node is self.graph.final_node:
                self._compile_final_state(state_id)
                continue

            if not expanded:
                # Revisit this state after all of its children are done.
                stack.append((state_id, True))
                stack.extend(self._uncompiled_children(state_id))
                continue

            self._compile_state(state_id)

    def _compile_final_state(self, state_id: int) -> None:
        """
        Compile a terminal traversal state.

        By the time a terminal state is reached, all graph choices have already
        been processed, including the right context. Choices rejected by the
        banned-sequence tracker or path constraint would not have reached this
        state.

        The only remaining decision is whether the final path-constraint state is
        acceptable. If so, this terminal state contributes one complete sequence
        with log mass 0.0. Otherwise, it contributes no sequences.

        Parameters
        ----------
        state_id
            ID of the terminal traversal state being compiled.
        """
        state = self.states[state_id]

        if not self.has_path_constraint or self.path_constraint.is_satisfied(state.constraint_state):
            total = (1, 0.0)
        else:
            total = (0, -math.inf)

        self.totals_by_state_id[state_id] = total
        # A terminal state has no outgoing choices.
        self.choices_by_state_id[state_id] = {}

    def _compile_state(self, state_id: int) -> None:
        """
        Compile one non-final traversal state.

        For each outgoing graph choice, combine the previously compiled child state
        with the contribution from the current node to produce a ChoiceResult.
        The total descendant count and log mass are then cached for the current state.

        Choices whose child has no valid descendants, or whose log mass is
        -inf, are dropped and do not appear in the compiled choices.

        Parameters
        ----------
        state_id
            ID of the traversal state being compiled.
        """
        node = self.states[state_id].node
        is_coding = isinstance(node, CodonNode)
        final_node = self.graph.final_node

        choice_results: Dict[str, ChoiceResult] = {}
        descendant_count = 0
        descendant_log_masses: List[float] = []

        for choice, child_id in self.child_results_by_state_id[state_id] or ():
            child_total = self.totals_by_state_id[child_id]

            # Uncompiled child: nothing to add.
            if child_total is None:
                continue

            child_count, subtree_log_mass = child_total

            # Dead end: no valid sequence passes through this child.
            if child_count == 0:
                continue

            choice_log_mass = self._accumulate_log_mass(node, choice, subtree_log_mass)

            # Zero-mass choice (e.g. a codon without a positive weight).
            if choice_log_mass == -math.inf:
                continue

            child = self.states[child_id].node

            choice_results[choice] = ChoiceResult(
                choice=choice,
                descendant_count=child_count,
                descendant_log_mass=choice_log_mass,
                # Choices leading to the final node carry no next state.
                next_state_id=None if child is final_node else child_id,
                is_coding=is_coding,
            )
            descendant_count += child_count
            descendant_log_masses.append(choice_log_mass)

        descendant_log_mass = self._sum_log_masses(descendant_log_masses)

        self.choices_by_state_id[state_id] = choice_results
        self.totals_by_state_id[state_id] = (descendant_count, descendant_log_mass)

    def _uncompiled_children(self, state_id: int) -> List[Tuple[int, bool]]:
        """
        Return child state IDs reached by taking each outgoing graph choice.

        Choices rejected by the banned-sequence tracker or path constraint are skipped.
        Only child states that have not yet been compiled are returned.

        As a side effect, the full list of ``(choice, child state ID)`` pairs
        (compiled or not) is stored in ``child_results_by_state_id`` for later
        use by ``_compile_state``.

        Parameters
        ----------
        state_id
            ID of the traversal state whose children should be discovered.

        Returns
        -------
        list of tuple
            Stack entries for child state IDs still needing compilation.
        """
        children = []
        child_results = []

        state = self.states[state_id]
        node = state.node

        for choice in self.choices_by_node[node]:
            child = node.transitions.get(choice)

            # The choice is allowed by the view but has no edge in the graph.
            if child is None:
                continue

            advance = self._advance_banned_tracker(state.banned_tracker_state_id, node, choice)

            if advance.banned:
                continue

            if self.has_path_constraint:
                next_constraint_state = self.path_constraint.advance(state.constraint_state, node.pos, choice)

                # None signals that the path constraint rejects this choice.
                if next_constraint_state is None:
                    continue
            else:
                next_constraint_state = ()

            child_state = TraversalState(child, advance.state_id, next_constraint_state)
            child_id = self._get_or_register_state_id(child_state)
            child_results.append((choice, child_id))

            if self.totals_by_state_id[child_id] is None:
                children.append((child_id, False))

        self.child_results_by_state_id[state_id] = child_results

        return children

    def _get_choices_for_node(self, node: Node) -> List[str]:
        """
        Return the graph choices available from a node in this view.

        Codon nodes respect any pinned codons defined by the view. Context nodes
        always have a single fixed sequence.

        Parameters
        ----------
        node
            Graph node whose available choices are required.

        Returns
        -------
        list of str
            The choices that may be taken from this node in the current view.

        Raises
        ------
        TypeError
            If the node is neither a CodonNode nor a ContextNode.
        """
        if isinstance(node, CodonNode):
            if node.pos in self.view.pinned_codons:
                return self.view.pinned_codons[node.pos]
            return node.codons

        if isinstance(node, ContextNode):
            return [node.sequence]

        # Previously this fell through and returned None, which surfaced as
        # an opaque "'NoneType' object is not iterable" in __init__.
        raise TypeError(f'Unsupported node type for choice lookup: {type(node).__name__}')

    def _advance_banned_tracker(
        self,
        banned_tracker_state_id: BannedTrackerStateId,
        node: Node,
        choice: str,
    ) -> BannedTrackerAdvanceResult:
        """
        Advance the banned-sequence tracker after taking one graph choice.

        Results are cached because the same tracker transition may be encountered
        from many traversal states during compilation.

        Parameters
        ----------
        banned_tracker_state_id
            Current banned-sequence tracker state.
        node
            Current graph node.
        choice
            Graph choice taken from the current node.

        Returns
        -------
        BannedTrackerAdvanceResult
            Whether the choice enters a banned state and the resulting tracker
            state.
        """
        key = (node, banned_tracker_state_id, choice)

        # Single lookup; cached values are result objects, never None.
        cached = self.banned_advance_cache.get(key)
        if cached is not None:
            return cached

        if not self.has_banned_tracker:
            # Trivial tracker: nothing is banned and the state never changes.
            result = BannedTrackerAdvanceResult(banned=False, state_id=banned_tracker_state_id)
        else:
            step = (node.pos, choice)
            result = self.banned_tracker.advance(step, banned_tracker_state_id)

        self.banned_advance_cache[key] = result
        return result

    def _accumulate_log_mass(
        self,
        node: Node,
        choice: str,
        subtree_log_mass: float,
    ) -> float:
        """
        Accumulate the log probability mass contributed by one graph choice.

        The subtree log mass has already been computed for the child state. Codon
        nodes contribute the log of their codon weight, whereas context nodes
        contribute no additional mass.

        Parameters
        ----------
        node
            The graph node from which the choice is taken.
        choice
            The outgoing graph choice.
        subtree_log_mass
            The total log mass reachable from the child state.

        Returns
        -------
        float
            The total log mass reachable after taking this choice, or -inf
            when the codon has no cached (positive) weight.
        """
        if not isinstance(node, CodonNode):
            return subtree_log_mass

        codon_log_weight = self.log_codon_weights.get(choice)

        if codon_log_weight is None:
            return -math.inf

        return codon_log_weight + subtree_log_mass

    def _sum_log_masses(self, log_masses: List[float]) -> float:
        """
        Combine several subtree log masses into a single log mass.

        The calculation is performed using the log-sum-exp trick to avoid numerical
        underflow when the subtree probabilities are extremely small.

        Parameters
        ----------
        log_masses
            Log-space masses to sum.

        Returns
        -------
        float
            The log of the summed masses, or -inf if no finite masses exist.
        """
        if not log_masses:
            return -math.inf

        max_log_mass = max(log_masses)

        # With an infinite maximum, `log_mass - max_log_mass` evaluates to
        # inf - inf = NaN. The sum is already determined by the maximum:
        # -inf means every mass is zero, +inf means the total is unbounded.
        if math.isinf(max_log_mass):
            return max_log_mass

        total_relative_mass = sum(math.exp(log_mass - max_log_mass) for log_mass in log_masses)

        return max_log_mass + math.log(total_relative_mass)
codeine/graph/nodes.py ADDED
@@ -0,0 +1,111 @@
1
+ from typing import Sequence
2
+
3
+
4
class Node:
    """
    Base class for every node in a CodonGraph.
    """

    def __init__(self) -> None:
        """
        Create an unconnected node with no position assigned.
        """
        # Graph position; subclasses that have one overwrite this.
        self.pos = None
        # Outgoing edges, keyed by choice (the compiler looks up
        # ``transitions.get(choice)`` to find the child node).
        self.transitions: dict = {}
        # Nodes with an edge into this one.
        self.parents: set = set()
16
+
17
+
18
class ContextNode(Node):
    """
    Graph node holding a fixed flanking sequence.

    The context is the sequence immediately to the left or to the right of
    the coding sequence; it may be empty.
    """

    def __init__(self, pos: int, sequence: str) -> None:
        """
        Constructor for the ContextNode class.

        Parameters
        ----------
        pos
            The graph position. Left context is 0; right context is len(aa_seq) + 1.
        sequence
            The context sequence contained on this node.
        """
        super().__init__()

        # Identifier derived from the graph position.
        self.id = f'context-{pos}'

        self.pos = pos
        self.sequence = sequence

    def __repr__(self) -> str:
        return f'ContextNode(id={self.id}, pos={self.pos})'
52
+
53
+
54
class CodonNode(Node):
    """
    Basic class representing a codon node on the codon graph.
    """

    def __init__(self, pos: int, aa: str, codons: Sequence[str]) -> None:
        """
        Constructor for the CodonNode class.

        Parameters
        ----------
        pos
            The aa position. Positioning is 1-based.
        aa
            The aa identity.
        codons
            The possible codons for this node.
        """
        super().__init__()

        # Basic info. Positioning is 1-based.
        self.pos = pos
        self.aa = aa

        # Identifier built from the amino acid and its position, e.g. 'M1'.
        self.id = f'{aa}{pos}'

        # Stored as a tuple so the node's codon set cannot be mutated through
        # the caller's original sequence.
        self.codons = tuple(codons)

    def __repr__(self) -> str:
        # The codon list is deliberately left out of the repr; the previous
        # implementation still joined it on every call without using it.
        return f'CodonNode(id={self.id}, pos={self.pos})'
94
+
95
+
96
class EndNode(Node):
    """
    Terminal node of the codon graph.

    Reaching it marks successful completion of a graph walk.
    """

    def __init__(self) -> None:
        """
        Constructor for the EndNode class.
        """
        super().__init__()

        # Fixed identifier for the terminal node.
        self.id = 'end'

    def __repr__(self) -> str:
        return f'EndNode(id={self.id})'