codeine 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
codeine/graph/view.py ADDED
@@ -0,0 +1,781 @@
1
+ import math
2
+ import random
3
+
4
+ from itertools import islice
5
+ from typing import Dict, Generator, List, Optional, Sequence, Union
6
+
7
+ from codeine.constraints.base import PathConstraint
8
+ from codeine.constraints.banned import BannedSequenceTracker
9
+ from codeine.graph.base import CodonGraph, CodonRestriction
10
+ from codeine.graph.nodes import CodonNode
11
+ from codeine.translation.tables import TranslationTable
12
+ from codeine.translation.weights import CodonWeights
13
+ from codeine.utils.display import format_forbidden_motifs, format_count, format_restrictions
14
+ from codeine.utils.sampling import Seedable, Sampler
15
+ from codeine.graph.compile import ViewCompiler
16
+
17
+
18
class CodonGraphView:
    """
    View of a codon graph. The view allows optional temporary constraints such as pinned codons
    and banned nucleotide sequences to be added without affecting the underlying codon graph.

    It is on this object that most operations (counting, sampling, enumeration...) take place.

    Derived traversal state is produced by ``compile()``. Every method that edits a constraint
    marks the view as stale, and the public query methods recompile lazily before answering.
    """

    def __init__(self,
                 graph: CodonGraph,
                 banned_sequences: Optional[Sequence[str]] = None,
                 seed: Seedable = None,
                 ) -> None:
        """
        Constructor for the CodonGraphView

        Parameters
        ----------
        graph
            The underlying codon graph.
        banned_sequences
            Nucleotide sequences that are forbidden in this view.
        seed
            Seed passed to ``random.Random`` to initialise this view's private RNG.

        Raises
        ------
        ValueError
            If any banned sequence is empty after normalisation.
        """
        self._rng = random.Random(seed)

        # The graph must be assigned before validating banned sequences, because
        # validation normalises through the graph's translation table.
        self.graph = graph
        self.pinned_codons: Dict[int, List[str]] = {}
        self.banned_sequences: List[str] = self._validate_banned_sequences(banned_sequences)
        self.path_constraint: Optional[PathConstraint] = None

        self.banned_tracker = BannedSequenceTracker(self.graph, self.banned_sequences)

        # Nothing is compiled yet; the first query triggers compile().
        self._compiled = None
        self._requires_compile = True

        # Placeholders that compile() overwrites with the compiled traversal tables.
        self.initial_state_id = None
        self.choices_by_state_id = ()
        self.choice_results_by_state_id = ()
        self.samplers_by_state_id = []

    def __getitem__(self, index: Union[int, slice]) -> Union[str, List[str]]:
        """
        Return one valid sequence, or a list of valid sequences for a slice.

        Parameters
        ----------
        index
            Zero-based sequence index, or slice of sequence indices. Slices are
            resolved with ``slice.indices``; plain integer indices must be
            non-negative.

        Returns
        -------
        str or list of str
            The indexed valid coding sequence, or a list of valid coding sequences.

        Raises
        ------
        IndexError
            If an integer index is negative or not smaller than the number of
            valid sequences.
        """
        if isinstance(index, slice):
            return self.sequences_at(index)

        return self.sequence_at(index)

    def __iter__(self) -> Generator[str, None, None]:
        """
        Iterate over all valid sequences in this graph view.

        Yields
        ------
        str
            All valid sequences in the graph view, in enumeration order.
        """
        yield from self.enumerate()

    def __contains__(self, seq: str) -> bool:
        """
        Does the given seq exist in this space?

        Returns
        -------
        bool
            True if and only if this is a valid sequence in this space.
        """
        return self.contains(seq)

    def __repr__(self) -> str:
        """
        Build a multi-line, human-readable summary of the view.

        Note that this compiles the view if it is stale, because the summary
        reports the number of valid coding sequences.
        """
        if self._requires_compile:
            self.compile()

        molecule = 'RNA' if self.graph.tt.rna else 'DNA'

        lines = [
            f'{type(self).__name__}',
            '',
            f'Translation table: {self.graph.tt.table_id} ({self.graph.tt.name})',
            f'Molecule type: {molecule}',
            '',
            f'Amino acid sequence ({len(self.aa_seq)} aa)',
            f'{self.aa_seq}',
            ''
        ]

        # Optional sections are only shown when the matching constraint is present.
        if self.graph.codon_restrictions:
            lines += [
                'Codon restrictions:',
                *format_restrictions(
                    self.graph.codon_restrictions,
                    label='restricted positions',
                ),
                '',
            ]

        if self.banned_sequences:
            lines += [
                'Banned sequences:',
                *format_forbidden_motifs(
                    self.banned_sequences,
                ),
                '',
            ]

        if self.pinned_codons:
            lines += [
                'Temporary pins:',
                *format_restrictions(
                    self.pinned_codons,
                    label='pinned positions',
                ),
                '',
            ]

        lines.append(f'Num. valid coding sequences: {format_count(self.n_valid_sequences)}')

        return '\n'.join(lines)

    def contains(self, seq: str) -> bool:
        """
        Check whether a coding sequence is contained in this view.

        Parameters
        ----------
        seq
            The sequence to check

        Returns
        -------
        bool
            True if and only if the sequence is contained in this coding space.
        """
        if self._requires_compile:
            self.compile()

        seq = self.translation_table.normalise_sequence(seq)

        # A valid coding sequence has exactly one codon (3 nt) per amino acid.
        if len(seq) != len(self.graph.aa_seq) * 3:
            return False

        state_id = self.initial_state_id
        choices_by_state_id = self.choices_by_state_id
        states = self._compiled.states

        # Walk the compiled states, following the choice dictated by `seq`.
        # A next_state_id of None marks the end of a complete path.
        while state_id is not None:

            state = states[state_id]
            node = state.node

            if isinstance(node, CodonNode):
                # node.pos is 1-based, so codon k occupies seq[3(k-1):3k].
                start = (node.pos - 1) * 3
                choice = seq[start:start + 3]
            else:
                # Non-codon nodes contribute their own fixed sequence as the choice.
                choice = node.sequence

            result = choices_by_state_id[state_id].get(choice)

            if result is None:
                return False

            state_id = result.next_state_id

        return True

    def sample(self, n: Optional[int] = None) -> Union[str, List[str]]:
        """
        Sample one or more coding sequences from this graph view.

        Parameters
        ----------
        n
            Number of sequences to sample. If omitted, return a single sequence.

        Returns
        -------
        str or list of str
            One sampled coding sequence, or a list of sampled coding sequences.

        Raises
        ------
        ValueError
            If the view contains no valid sequences, or if n is negative.
        """
        if self._requires_compile:
            self.compile()

        if self.n_valid_sequences == 0:
            raise ValueError('Cannot sample from an empty coding space.')

        if n is None:
            return self._sample()

        if n < 0:
            raise ValueError('n must be non-negative.')

        return [self._sample() for _ in range(n)]

    def enumerate(self) -> Generator[str, None, None]:
        """
        Enumerate all valid sequences in this view.

        Yields
        ------
        str
            All valid coding sequences, one by one.
        """
        if self._requires_compile:
            self.compile()

        yield from self._iter_all_sequences()

    def enumerate_range(self, start: int = 0, stop: Optional[int] = None) -> Generator[str, None, None]:
        """
        Enumerate valid sequences from start up to, but not including, stop.

        Because this is a generator, compilation and bounds checking happen when
        the first item is requested, not when the method is called.

        Parameters
        ----------
        start
            The zero-based start from which to begin enumeration
        stop
            The zero-based enumeration stop. Defaults to the number of valid sequences.

        Yields
        ------
        str
            Sequences in the range, one by one.

        Raises
        ------
        IndexError
            If start is negative, stop is smaller than start, or stop exceeds
            the number of valid sequences.
        """
        if self._requires_compile:
            self.compile()

        n_sequences = self.n_valid_sequences

        if stop is None:
            stop = n_sequences

        if start < 0 or stop < start or stop > n_sequences:
            raise IndexError('Enumeration range is out of bounds.')

        if start == stop:
            return

        # Ranges that begin at 0 can use the plain iterator, which does not need
        # to track per-subtree offsets.
        if start == 0 and stop == n_sequences:
            yield from self._iter_all_sequences()
            return

        if start == 0:
            yield from islice(self._iter_all_sequences(), stop)
            return

        yield from self._iter_sequence_range(start, stop)

    def sequence_at(self, index: int) -> str:
        """
        Return the valid sequence at a given index.

        Parameters
        ----------
        index
            Zero-based sequence index.

        Returns
        -------
        str
            The indexed valid coding sequence.

        Raises
        ------
        IndexError
            If index is negative or not smaller than the number of valid sequences.
        """
        if self._requires_compile:
            self.compile()

        if index < 0 or index >= self.n_valid_sequences:
            raise IndexError(
                f'Sequence index {index} out of range for '
                f'{self.n_valid_sequences} valid sequences.'
            )

        return self._sequence_at(index)

    def sequences_at(self, index_slice: slice) -> List[str]:
        """
        Return valid sequences from a slice.

        Parameters
        ----------
        index_slice
            Slice of zero-based sequence indices. It is resolved against the
            number of valid sequences with ``slice.indices``.

        Returns
        -------
        list of str
            The selected valid coding sequences, in slice order.
        """
        if self._requires_compile:
            self.compile()

        n_sequences = self.n_valid_sequences
        start, stop, step = index_slice.indices(n_sequences)

        if start == stop:
            return []

        # Strided slices are served by independent indexed lookups.
        if step != 1:
            return [self.sequence_at(index) for index in range(start, stop, step)]

        if start == 0 and stop == n_sequences:
            return [*self._iter_all_sequences()]

        if start == 0:
            return [*islice(self._iter_all_sequences(), stop)]

        return [*self._iter_sequence_range(start, stop)]

    def copy(self) -> 'CodonGraphView':
        """
        Copy this view and all its constraints and attributes.

        The pinned-codon dict and banned-sequence list are shallow-copied. The
        path constraint, banned tracker and compiled tables are shared by
        reference with the original view.

        Returns
        -------
        CodonGraphView
            A copy of the view.
        """
        # NOTE(review): assumes graph.view() returns a fresh, unconstrained
        # CodonGraphView for the same graph - confirm against CodonGraph.
        view = self.graph.view()
        # Give the copy the same RNG state so it continues the same random stream.
        view._rng.setstate(self._rng.getstate())

        view.pinned_codons = self.pinned_codons.copy()
        view.banned_sequences = self.banned_sequences.copy()
        view.path_constraint = self.path_constraint
        view.banned_tracker = self.banned_tracker

        view._compiled = self._compiled
        view._requires_compile = self._requires_compile

        view.initial_state_id = self.initial_state_id
        view.choices_by_state_id = self.choices_by_state_id
        view.choice_results_by_state_id = self.choice_results_by_state_id
        # Samplers are constructed with the owning view's RNG, so the copy
        # starts with an empty cache and rebuilds them lazily.
        view.samplers_by_state_id = [None] * len(self.samplers_by_state_id)

        return view

    def compile(self) -> None:
        """
        Calculate all graph properties that are derived from its structure plus constraints
        such as pins and banned sequences.

        Remember to do this after editing any constraints!
        """
        compiler = ViewCompiler(self)
        compiled = compiler.compile()

        self._compiled = compiled
        self._requires_compile = False

        self.initial_state_id = compiled.initial_state_id
        self.choices_by_state_id = compiled.choices_by_state_id
        self.choice_results_by_state_id = compiled.choice_results_by_state_id
        # One lazily-filled sampler slot per compiled state.
        self.samplers_by_state_id = [None] * len(compiled.states)

    def pin_codons(self, pinned_codons: Dict[int, CodonRestriction]) -> None:
        """
        Pin (temporarily fix) a codon in this codon graph view

        Existing pins at other positions are kept.

        Parameters
        ----------
        pinned_codons
            A dict specifying which codons to pin, by pos: codon.
        """
        pinned_codons = self.graph.validate_codon_restrictions(pinned_codons)
        self.pinned_codons.update(pinned_codons)
        self._requires_compile = True

    def unpin_codons(self, positions: Sequence[int]) -> None:
        """
        Unpin codon nodes by pos.

        All positions are validated before any pin is removed, so a failed call
        leaves the view unchanged. Positions that are in range but not
        currently pinned are ignored.

        Parameters
        ----------
        positions
            A list of positions to unpin.

        Raises
        ------
        ValueError
            If any position is smaller than 1 or larger than the number of
            codon nodes in the graph.
        """
        # Materialise once so that one-shot iterables survive both passes.
        positions = list(positions)
        n_codon_nodes = len(self.graph.codon_nodes)

        # Validate first: raising midway through the removals would leave the
        # pins modified while the view is still flagged as up to date.
        for pos in positions:
            if pos < 1 or pos > n_codon_nodes:
                raise ValueError(f'Pinned codon position {pos} is out of range.')

        for pos in positions:
            self.pinned_codons.pop(pos, None)

        self._requires_compile = True

    def set_pinned_codons(self, pinned_codons: Dict[int, CodonRestriction]) -> None:
        """
        Pin (temporarily fix) a specified group of codons, leaving all others unpinned.

        Parameters
        ----------
        pinned_codons
            A dict specifying which codons to pin, by pos: codon
        """
        pinned_codons = self.graph.validate_codon_restrictions(pinned_codons)
        self.pinned_codons = dict(pinned_codons)
        self._requires_compile = True

    def clear_pins(self) -> None:
        """
        Remove all codon pins from this graph view
        """
        self.pinned_codons.clear()
        self._requires_compile = True

    def set_banned_sequences(self, banned_sequences: Sequence[str]) -> None:
        """
        Set banned nucleotide sequences for this view.

        Banned-sequence tracking depends only on the graph and banned sequences,
        not on temporary pins, so it is rebuilt only when the banned list changes.

        Parameters
        ----------
        banned_sequences
            The complete new list of banned sequences; it replaces the old one.
        """
        self.banned_sequences = self._validate_banned_sequences(banned_sequences)
        self.banned_tracker = BannedSequenceTracker(self.graph, self.banned_sequences)
        self._requires_compile = True

    def clear_banned_sequences(self) -> None:
        """
        Remove all banned sequence restrictions from this view.
        """
        self.set_banned_sequences([])

    def set_path_constraint(self, path_constraint: Optional[PathConstraint]) -> None:
        """
        Set an additional generic path constraint for this view.

        Pass None to remove any path constraint.
        """
        self.path_constraint = path_constraint
        self._requires_compile = True

    def clear_path_constraint(self) -> None:
        """
        Remove the additional generic path constraint from this view.
        """
        self.set_path_constraint(None)

    @property
    def aa_seq(self) -> str:
        """
        The amino acid sequence.

        Returns
        -------
        str
            The aa seq of the underlying graph.
        """
        return self.graph.aa_seq

    @property
    def translation_table(self) -> TranslationTable:
        """
        The translation table used by the codon graph.
        """
        return self.graph.tt

    @property
    def codon_weights(self) -> CodonWeights:
        """
        The codon weights used by the codon graph.
        """
        return self.graph.cw

    @property
    def codon_restrictions(self) -> Dict[int, CodonRestriction]:
        """
        Any hard-fixed codon restrictions on the codon graph.
        """
        return self.graph.codon_restrictions

    @property
    def context_l(self) -> str:
        """
        The left context sequence.
        """
        return self.graph.context_l

    @property
    def context_r(self) -> str:
        """
        The right context sequence.
        """
        return self.graph.context_r

    @property
    def n_valid_sequences(self) -> int:
        """
        Number of valid coding sequences in this view given all constraints.

        Compiles the view first if it is stale.
        """
        if self._requires_compile:
            self.compile()

        return self._compiled.n_valid_sequences

    def _validate_banned_sequences(self, banned_sequences: Optional[Sequence[str]]) -> List[str]:
        """
        Check the inputted banned sequences make sense, and return normalised versions of them.

        Parameters
        ----------
        banned_sequences
            The list of banned sequences. None (or any empty value) means no
            banned sequences. A single bare string is treated as one banned
            sequence rather than being iterated character by character.

        Returns
        -------
        list of str
            A normalised, de-duplicated, sorted list of banned sequences.

        Raises
        ------
        ValueError
            If any sequence is empty after normalisation.
        """
        banned_sequences = banned_sequences or []

        # A str is itself a sequence of one-character strings; iterating it
        # would silently ban every individual nucleotide it contains.
        if isinstance(banned_sequences, str):
            banned_sequences = [banned_sequences]

        normalised = []
        for sequence in banned_sequences:
            sequence = self.translation_table.normalise_sequence(sequence)

            if len(sequence) == 0:
                raise ValueError('Banned sequences cannot be empty.')

            normalised.append(sequence)

        return sorted(set(normalised))

    def _sampler_for_state_id(self, state_id: int) -> Optional[Sampler]:
        """
        Return the weighted sampler for one compiled traversal state.

        Samplers are created lazily because many compiled states may never be
        visited during sampling.

        Parameters
        ----------
        state_id
            The state ID.

        Returns
        -------
        Sampler or None
            The cached sampler for this state, or None if the state has no
            valid choices.
        """
        sampler = self.samplers_by_state_id[state_id]

        if sampler is not None:
            return sampler

        choice_results = self.choice_results_by_state_id[state_id]

        if not choice_results:
            return None

        # Each sampled item carries what _sample() needs to advance: the choice
        # text, whether it contributes to the output, and the following state.
        runtime_items = []
        runtime_log_masses = []

        for result in choice_results:
            runtime_items.append((result.choice, result.is_coding, result.next_state_id))
            runtime_log_masses.append(result.descendant_log_mass)

        runtime_weights = self._convert_log_masses_to_sampler_weights(runtime_log_masses)
        sampler = Sampler(runtime_items, runtime_weights, rng=self._rng)

        self.samplers_by_state_id[state_id] = sampler

        return sampler

    def _convert_log_masses_to_sampler_weights(self, log_masses: List[float]) -> List[float]:
        """
        Convert subtree log masses into relative weights for sampling.

        The returned weights are proportional to the true subtree probabilities but
        are rescaled to avoid numerical underflow. Only the relative values matter
        for weighted sampling.

        Parameters
        ----------
        log_masses
            Choice masses represented in log space.

        Returns
        -------
        list of float
            Relative non-log weights suitable for weighted sampling. This is
            always a new list; the input is never returned or modified.
        """
        if not log_masses:
            return []

        max_log_mass = max(log_masses)

        # Every choice has zero mass: fall back to equal weights, since
        # subtracting -inf from -inf would produce NaN.
        if max_log_mass == -math.inf:
            return [1.0] * len(log_masses)

        # Shifting by the maximum makes the largest weight exactly 1.0.
        return [math.exp(log_mass - max_log_mass) for log_mass in log_masses]

    def _sample(self) -> str:
        """
        Sample one coding sequence from an already-compiled graph view.

        Returns
        -------
        str
            A sampled sequence.

        Raises
        ------
        ValueError
            If a state with no valid choices is reached.
        """
        state_id = self.initial_state_id
        sequence = []

        while state_id is not None:
            sampler = self._sampler_for_state_id(state_id)

            if sampler is None:
                raise ValueError('Cannot sample from a state with no valid choices.')

            choice, is_coding, state_id = sampler.sample()

            # Only coding choices become part of the returned sequence.
            if is_coding:
                sequence.append(choice)

        return ''.join(sequence)

    def _sequence_at(self, index: int) -> str:
        """
        Return one valid sequence by directly descending through descendant counts.

        The caller is expected to have compiled the view and range-checked the index.

        Parameters
        ----------
        index
            The index of the sequence in the graph.

        Returns
        -------
        str
            The sequence at the desired index.

        Raises
        ------
        RuntimeError
            If the index is not covered by the descendant counts of a state.
        """
        state_id = self.initial_state_id
        choice_results_by_state_id = self.choice_results_by_state_id
        sequence_parts = []

        while state_id is not None:
            results = choice_results_by_state_id[state_id]

            if not results:
                break

            # A non-coding state adds nothing to the output; just follow its
            # first result.
            if not results[0].is_coding:
                state_id = results[0].next_state_id
                continue

            # Pick the choice whose block of descendants contains `index`,
            # making the index relative to that block as earlier ones are skipped.
            for result in results:
                descendant_count = result.descendant_count

                if index < descendant_count:
                    sequence_parts.append(result.choice)
                    state_id = result.next_state_id
                    break

                index -= descendant_count
            else:
                raise RuntimeError('Invalid sequence index traversal state.')

        return ''.join(sequence_parts)

    def _iter_all_sequences(self) -> Generator[str, None, None]:
        """
        Iterate over all valid sequences. Faster than _iter_sequence_range when
        we're starting at 0.

        Yields
        ------
        str
            All valid coding sequences, one by one
        """
        # Each stack entry is:
        #   (
        #       state ID to visit next (None once a path is complete),
        #       number of codons chosen so far on this path,
        #       the codon chosen to reach this entry (None if none was chosen),
        #   )
        choice_results_by_state_id = self.choice_results_by_state_id
        # One shared buffer holds the current path. Writing slot
        # codon_index - 1 overwrites what a previously explored branch left there.
        sequence_parts = [''] * len(self.graph.aa_seq)

        stack = [(self.initial_state_id, 0, None)]

        while stack:
            state_id, codon_index, choice = stack.pop()

            if choice is not None:
                sequence_parts[codon_index - 1] = choice

            if state_id is None:
                yield ''.join(sequence_parts)
                continue

            results = choice_results_by_state_id[state_id]

            if not results:
                continue

            if not results[0].is_coding:
                stack.append((results[0].next_state_id, codon_index, None))
                continue

            next_codon_index = codon_index + 1

            # Push in reverse so that the first choice is popped first.
            for result in reversed(results):
                stack.append((result.next_state_id, next_codon_index, result.choice))

    def _iter_sequence_range(
            self,
            start: int,
            stop: int,
    ) -> Generator[str, None, None]:
        """
        Iterate over valid sequences in a given index range.

        Parameters
        ----------
        start
            0-based index of the first sequence.
        stop
            0-based index one past the final sequence.

        Yields
        ------
        str
            Valid coding sequences in the requested range.
        """
        # Each stack entry is:
        #   (
        #       state ID to visit next (None once a path is complete),
        #       number of codons chosen so far on this path,
        #       the codon chosen to reach this entry (None if none was chosen),
        #       0-based index of the first sequence reachable from that state,
        #   )
        choice_results_by_state_id = self.choice_results_by_state_id
        sequence_parts = [''] * len(self.graph.aa_seq)

        stack = [(self.initial_state_id, 0, None, 0)]

        while stack:
            state_id, codon_index, choice, offset = stack.pop()

            if choice is not None:
                sequence_parts[codon_index - 1] = choice

            if state_id is None:
                if start <= offset < stop:
                    yield ''.join(sequence_parts)
                continue

            results = choice_results_by_state_id[state_id]

            if not results:
                continue

            if not results[0].is_coding:
                stack.append((results[0].next_state_id, codon_index, None, offset))
                continue

            next_codon_index = codon_index + 1
            child_start = offset
            push = []

            # Child i covers indices [child_start, child_stop); keep only the
            # children whose interval overlaps the requested [start, stop).
            for result in results:
                child_stop = child_start + result.descendant_count

                if child_stop > start and child_start < stop:
                    push.append((result, child_start))

                child_start = child_stop

            # Push in reverse so that the lowest-indexed child is popped first.
            for result, child_start in reversed(push):
                stack.append((
                    result.next_state_id,
                    next_codon_index,
                    result.choice,
                    child_start,
                ))