codeine 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- codeine/__init__.py +15 -0
- codeine/constraints/banned.py +444 -0
- codeine/constraints/base.py +39 -0
- codeine/constraints/mutations.py +115 -0
- codeine/graph/base.py +267 -0
- codeine/graph/compile.py +489 -0
- codeine/graph/nodes.py +111 -0
- codeine/graph/view.py +781 -0
- codeine/motifs/restriction.py +105 -0
- codeine/motifs/validate.py +117 -0
- codeine/space/__init__.py +0 -0
- codeine/space/coding.py +490 -0
- codeine/space/mutation.py +512 -0
- codeine/translation/__init__.py +0 -0
- codeine/translation/data/__init__.py +0 -0
- codeine/translation/data/tables.json +2252 -0
- codeine/translation/data/weights.py +232 -0
- codeine/translation/tables.py +200 -0
- codeine/translation/weights.py +323 -0
- codeine/utils/__init__.py +0 -0
- codeine/utils/dict.py +23 -0
- codeine/utils/display.py +124 -0
- codeine/utils/sampling.py +90 -0
- codeine-0.1.0.dist-info/METADATA +162 -0
- codeine-0.1.0.dist-info/RECORD +28 -0
- codeine-0.1.0.dist-info/WHEEL +5 -0
- codeine-0.1.0.dist-info/licenses/LICENSE +21 -0
- codeine-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,323 @@
|
|
|
1
|
+
from typing import Any, Dict, Optional, Tuple, Union
|
|
2
|
+
|
|
3
|
+
from codeine.translation.tables import TranslationTable
|
|
4
|
+
from codeine.utils.dict import FrozenDict
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
# Codon weights grouped by amino acid: {amino acid: {codon: weight}}.
WeightDict = Dict[str, Dict[str, Union[float, int]]]
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class CodonWeights:
    """
    A class to store codon weights, for example codon usage information for a
    specific organism.

    Input weights are grouped by amino acid:

    .. code-block:: python

        {
            'A': {'GCT': 1.0, 'GCC': 1.0, ...},
            'R': {'CGT': 1.0, 'CGC': 1.0, ...},
            ...
        }

    Stored weights are flat:

    .. code-block:: python

        {
            'GCT': 0.25,
            'GCC': 0.25,
            ...
        }

    Weights are normalised per amino acid, i.e. the stored weights of the
    codons of each amino acid sum to 1.

    Instances are immutable once constructed: setting or deleting an
    attribute raises ``AttributeError``.
    """

    def __init__(
        self,
        weights: WeightDict,
        rna: Optional[bool] = None,
        table: Optional[TranslationTable] = None,
    ) -> None:
        """
        Parameters
        ----------
        weights
            Codon weights grouped by amino acid, for codons in the TranslationTable.

            Example:

            .. code-block:: python

                {
                    'A': {'GCT': 1.0, 'GCC': 1.0, 'GCA': 1.0, 'GCG': 1.0},
                    'R': {'CGT': 1.0, 'CGC': 1.0, 'CGA': 1.0,
                          'CGG': 1.0, 'AGA': 1.0, 'AGG': 1.0},
                    ...
                }

        rna
            Whether to use rna (default is False, i.e. use DNA). If ``table``
            is also given, the two must agree.

        table
            The translation table that defines the expected amino acids and
            codons. If omitted, ``TranslationTable(rna=...)`` is constructed.

        Raises
        ------
        ValueError
            If ``table`` and ``rna`` disagree; if amino acids or codons are
            missing, unknown, or duplicated after normalisation; if a weight
            is negative or not finite; or if the weights of an amino acid do
            not sum to a positive number.
        """
        import math  # local import so the block does not depend on a new file-level import

        # Attribute assignment stays allowed until construction has finished.
        self._locked = False

        if table is None:
            table = TranslationTable(rna=False if rna is None else rna)

        elif rna is not None and table.rna != rna:
            raise ValueError('table and rna specify inconsistent molecule types.')

        expected_aa = set(table.aa_to_codons)

        # Keys are upper-cased below; two keys differing only by case would
        # otherwise silently overwrite each other.
        observed_aa = set()
        for aa in weights:
            aa = aa.upper()
            if aa in observed_aa:
                raise ValueError(f'Duplicate amino acid after normalisation: {aa}')
            observed_aa.add(aa)

        missing_aa = expected_aa - observed_aa
        if missing_aa:
            raise ValueError(f'Missing weights for amino acid(s): {sorted(missing_aa)}')

        extra_aa = observed_aa - expected_aa
        if extra_aa:
            raise ValueError(f'Unknown amino acid(s): {sorted(extra_aa)}')

        weights_flat: Dict[str, float] = {}
        aa_to_codons: Dict[str, Tuple[str, ...]] = {}

        for aa, codon_weights in weights.items():
            aa = aa.upper()

            # Normalise codon spelling via the table, refusing keys that
            # collapse onto the same codon instead of keeping only the last.
            normalised_codon_weights: Dict[str, Union[float, int]] = {}
            for codon, weight in codon_weights.items():
                codon = table.normalise_sequence(codon)
                if codon in normalised_codon_weights:
                    raise ValueError(
                        f'Duplicate codon for amino acid {aa} after normalisation: {codon}'
                    )
                normalised_codon_weights[codon] = weight

            expected_codons = set(table.aa_to_codons[aa])
            observed_codons = set(normalised_codon_weights)

            missing_codons = expected_codons - observed_codons
            if missing_codons:
                raise ValueError(f'Missing weights for codon(s) for amino acid {aa}: {sorted(missing_codons)}')

            extra_codons = observed_codons - expected_codons
            if extra_codons:
                raise ValueError(f'Unknown codon(s) for amino acid {aa}: {sorted(extra_codons)}')

            # Validate every weight before summing, so the error names the
            # offending codon rather than reporting a misleading total.
            for codon, weight in normalised_codon_weights.items():
                if not math.isfinite(weight):
                    raise ValueError(f'Weight for codon {codon} must be a finite number')
                if weight < 0:
                    raise ValueError(f'Weight for codon {codon} cannot be negative')

            total = sum(normalised_codon_weights.values())
            if total <= 0:
                raise ValueError(f'Weights for amino acid {aa} must sum to > 0')

            for codon, weight in normalised_codon_weights.items():
                weights_flat[codon] = float(weight) / total

            # Codon order follows the order of the input mapping.
            aa_to_codons[aa] = tuple(normalised_codon_weights)

        self.rna = table.rna
        self.aa_to_codons = FrozenDict(aa_to_codons)
        self.weights = FrozenDict(weights_flat)

        self._locked = True

    def __setattr__(self, name: str, value: Any) -> None:
        """Reject attribute assignment once the instance has been locked."""
        if getattr(self, '_locked', False):
            raise AttributeError(f'{type(self).__name__} is immutable')
        super().__setattr__(name, value)

    def __delattr__(self, name: str) -> None:
        """Reject attribute deletion once the instance has been locked."""
        if getattr(self, '_locked', False):
            raise AttributeError(f'{type(self).__name__} is immutable')
        super().__delattr__(name)

    def __repr__(self) -> str:
        """Return a multi-line summary: molecule type, then one line of weights per amino acid."""
        molecule = 'RNA' if self.rna else 'DNA'

        lines = [
            'CodonWeights',
            f'Molecule type: {molecule}',
            '',
            'Weights:',
        ]

        for aa in sorted(self.aa_to_codons):
            weights = ' '.join(
                f'{codon}={self.weights[codon]:.3f}'
                for codon in self.aa_to_codons[aa]
            )
            lines.append(f' {aa}: {weights}')

        return '\n'.join(lines)

    def __getitem__(self, codon: str) -> float:
        """Return the normalised weight stored for ``codon`` (looked up as given)."""
        return self.weights[codon]

    def by_aa(self, aa: str) -> Dict[str, float]:
        """
        Return the codon weights corresponding to a particular AA.

        Parameters
        ----------
        aa
            The amino acid of interest (case-insensitive).

        Returns
        -------
        A new dict of normalised codon weights keyed by codon.

        Raises
        ------
        KeyError
            If the amino acid is not present.
        """
        codons = self.aa_to_codons[aa.upper()]
        weights = {codon: self.weights[codon] for codon in codons}
        return weights

    @classmethod
    def uniform(cls, table: Optional[TranslationTable] = None, rna: Optional[bool] = None) -> 'CodonWeights':
        """
        Construct a ``CodonWeights`` object with uniform codon weights for a given translation table.

        Parameters
        ----------
        table
            The reference table. If omitted, a default ``TranslationTable``
            is constructed for the requested molecule type.

        rna
            Whether to use RNA. Must agree with ``table`` when both are given.

        Returns
        -------
        A uniform CodonWeights object.

        Raises
        ------
        ValueError
            If ``table`` and ``rna`` specify different molecule types.
        """
        if table is None:
            table = TranslationTable(rna=False if rna is None else rna)

        elif rna is not None and table.rna != rna:
            raise ValueError('table and rna specify inconsistent molecule types.')

        # Every codon gets the same raw weight; the constructor normalises.
        uniform_weights: WeightDict = {
            aa: {codon: 1.0 for codon in codons}
            for aa, codons in table.aa_to_codons.items()
        }

        return cls(uniform_weights, table=table)

    @classmethod
    def ecoli(cls, rna: Optional[bool] = None) -> 'CodonWeights':
        """
        Construct a ``CodonWeights`` object with codon probabilities corresponding to E. coli.

        Weights are based on codon usage counts from GenScript:

        https://www.genscript.com/tools/codon-frequency-table

        Parameters
        ----------
        rna
            Whether to use RNA.

        Returns
        -------
        A ``CodonWeights`` object corresponding to E. coli.
        """
        # Imported lazily so the data table is only loaded when requested.
        from codeine.translation.data.weights import ECOLI_WEIGHTS
        return cls(ECOLI_WEIGHTS, rna=rna)

    @classmethod
    def yeast(cls, rna: Optional[bool] = None) -> 'CodonWeights':
        """
        Construct a ``CodonWeights`` object with codon probabilities corresponding to 'yeast'.

        Weights are based on codon usage counts from GenScript:

        https://www.genscript.com/tools/codon-frequency-table

        Parameters
        ----------
        rna
            Whether to use RNA.

        Returns
        -------
        A ``CodonWeights`` object corresponding to S. cerevisiae.
        """
        from codeine.translation.data.weights import YEAST_WEIGHTS
        return cls(YEAST_WEIGHTS, rna=rna)

    @classmethod
    def human(cls, rna: Optional[bool] = None) -> 'CodonWeights':
        """
        Construct a ``CodonWeights`` object with codon probabilities corresponding to Human.

        Weights are based on codon usage counts from GenScript:

        https://www.genscript.com/tools/codon-frequency-table

        Parameters
        ----------
        rna
            Whether to use RNA.

        Returns
        -------
        A ``CodonWeights`` object corresponding to Human.
        """
        from codeine.translation.data.weights import HUMAN_WEIGHTS
        return cls(HUMAN_WEIGHTS, rna=rna)

    @classmethod
    def mouse(cls, rna: Optional[bool] = None) -> 'CodonWeights':
        """
        Construct a ``CodonWeights`` object with codon probabilities corresponding to Mouse.

        Weights are based on codon usage counts from GenScript:

        https://www.genscript.com/tools/codon-frequency-table

        Parameters
        ----------
        rna
            Whether to use RNA.

        Returns
        -------
        A ``CodonWeights`` object corresponding to Mouse.
        """
        from codeine.translation.data.weights import MOUSE_WEIGHTS
        return cls(MOUSE_WEIGHTS, rna=rna)

    @classmethod
    def arabidopsis(cls, rna: Optional[bool] = None) -> 'CodonWeights':
        """
        Construct a ``CodonWeights`` object with codon probabilities corresponding to A. thaliana.

        Weights are based on codon usage counts from GenScript:

        https://www.genscript.com/tools/codon-frequency-table

        Parameters
        ----------
        rna
            Whether to use RNA.

        Returns
        -------
        A ``CodonWeights`` object corresponding to Arabidopsis thaliana.
        """
        from codeine.translation.data.weights import ARABIDOPSIS_WEIGHTS
        return cls(ARABIDOPSIS_WEIGHTS, rna=rna)

    @classmethod
    def drosophila(cls, rna: Optional[bool] = None) -> 'CodonWeights':
        """
        Construct a ``CodonWeights`` object with codon probabilities corresponding to D. melanogaster.

        Weights are based on codon usage counts from GenScript:

        https://www.genscript.com/tools/codon-frequency-table

        Parameters
        ----------
        rna
            Whether to use RNA.

        Returns
        -------
        A ``CodonWeights`` object corresponding to Drosophila melanogaster.
        """
        from codeine.translation.data.weights import DROSOPHILA_WEIGHTS
        return cls(DROSOPHILA_WEIGHTS, rna=rna)
|
|
File without changes
|
codeine/utils/dict.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
from collections.abc import Mapping
|
|
2
|
+
from typing import Any, Iterator
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class FrozenDict(Mapping):
    """
    Read-only mapping backed by a private copy of the supplied data.

    Only the mapping itself is protected: no item-assignment or deletion
    methods are defined, but the stored values are not copied or frozen.
    """

    def __init__(self, data: Mapping) -> None:
        # Take a shallow copy so later changes to the caller's mapping are
        # not reflected in this object.
        snapshot = dict(data)
        self._data = snapshot

    def __getitem__(self, key: Any) -> Any:
        """Look up ``key`` in the underlying dict (raises ``KeyError`` if absent)."""
        backing = self._data
        return backing[key]

    def __iter__(self) -> Iterator:
        """Iterate over the keys in insertion order."""
        backing = self._data
        return iter(backing)

    def __len__(self) -> int:
        """Return the number of stored keys."""
        backing = self._data
        return len(backing)

    def __repr__(self) -> str:
        """Show the contents exactly as a plain dict would be shown."""
        backing = self._data
        return repr(backing)
|
codeine/utils/display.py
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
1
|
+
from typing import TYPE_CHECKING
|
|
2
|
+
|
|
3
|
+
if TYPE_CHECKING:
|
|
4
|
+
from codeine.space.coding import ForbiddenMotif
|
|
5
|
+
|
|
6
|
+
from typing import Dict, List, Sequence, Union
|
|
7
|
+
|
|
8
|
+
from codeine.motifs.restriction import RestrictionSite
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
# The codons allowed at one position: a single codon string or a sequence of codon strings.
CodonRestriction = Union[str, Sequence[str]]
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def format_count(n: int) -> str:
    """
    Format large counts in scientific notation.

    Counts below 10^9 are written in full with thousands separators; larger
    counts are written as ``<mantissa> × 10^<exponent>`` with the mantissa
    rounded to three significant figures.

    Parameters
    ----------
    n
        The count to format.

    Returns
    -------
    The formatted count.
    """
    if n < 10**9:
        return f'{n:,}'

    import math  # local import so the block does not depend on a new file-level import

    # Derive the exponent arithmetically rather than from len(str(n)):
    # converting very large ints to str raises ValueError once the
    # interpreter's integer string conversion limit is exceeded.
    exponent = int(math.log10(n))

    # math.log10 is a float computation; correct a possible off-by-one so
    # that 10**exponent <= n < 10**(exponent + 1).
    if 10 ** exponent > n:
        exponent -= 1
    elif 10 ** (exponent + 1) <= n:
        exponent += 1

    mantissa = f'{n / (10 ** exponent):.3g}'

    # Rounding to three significant figures can carry up to 10 (e.g. 9.999);
    # renormalise so the mantissa stays below 10.
    if mantissa == '10':
        mantissa = '1'
        exponent += 1

    return f'{mantissa} × 10^{exponent}'
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def format_restrictions(
    restrictions: Dict[int, CodonRestriction],
    label: str = 'positions',
    max_lines: int = 8,
) -> List[str]:
    """
    Build the repr lines describing codon restrictions or pins.

    Entries are listed in ascending position order, one line per position.
    When there are more than ``max_lines`` entries, the output starts with a
    count line (using ``label``), shows the first ``max_lines`` entries and
    ends with a line stating how many were left out.
    """
    if not restrictions:
        return [' None']

    def render(entry) -> str:
        # A bare string is a single codon; anything else is a codon sequence.
        position, allowed = entry
        codon_list = [allowed] if isinstance(allowed, str) else allowed
        return f' {position}: {" ".join(codon_list)}'

    ordered = sorted(restrictions.items())
    total = len(ordered)

    if total > max_lines:
        header = f' {total} {label}'
        footer = f' ... {total - max_lines} more'
        return [header, *map(render, ordered[:max_lines]), footer]

    return [render(entry) for entry in ordered]
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def format_forbidden_motifs(
    sequences: Sequence[str],
    max_lines: int = 8,
) -> List[str]:
    """
    Build the repr lines listing banned sequences.

    Sequences are listed in the order given, one per line. When there are
    more than ``max_lines`` of them, the output starts with a count line,
    shows the first ``max_lines`` sequences and ends with a line stating how
    many were left out.
    """
    if not sequences:
        return [' None']

    total = len(sequences)

    if total > max_lines:
        shown = [f' {motif}' for motif in sequences[:max_lines]]
        return [f' {total} sequences'] + shown + [f' ... {total - max_lines} more']

    return [f' {motif}' for motif in sequences]
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def normalise_motif(seq: str, rna: bool) -> str:
    """
    Upper-case a motif and express it in the requested alphabet.

    With ``rna`` true every ``T`` becomes ``U``; otherwise every ``U``
    becomes ``T``.
    """
    upper = seq.upper()
    if rna:
        return upper.replace('T', 'U')
    return upper.replace('U', 'T')
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def format_forbidden_motif(motif: 'ForbiddenMotif', rna: bool) -> str:
    """
    Render one forbidden motif for display.

    A ``RestrictionSite`` is shown as its name followed by its normalised
    motifs in parentheses; any other motif is passed to ``normalise_motif``
    and shown on its own.
    """
    if not isinstance(motif, RestrictionSite):
        return normalise_motif(motif, rna)

    recognised = ', '.join(normalise_motif(seq, rna) for seq in motif.motifs)
    return f'{motif.name} ({recognised})'
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
def format_positions(positions) -> str:
    """
    Format positions nicely.

    Positions are sorted and runs of consecutive values are collapsed into
    ``start-end`` ranges, e.g. ``[3, 1, 2, 7]`` becomes ``'1-3, 7'``.
    Repeated positions are counted once.

    Parameters
    ----------
    positions
        An iterable of integer positions, in any order.

    Returns
    -------
    A comma-separated list of positions and ranges, or ``'None'`` if there
    are no positions.
    """
    positions = sorted(positions)
    if not positions:
        return 'None'

    ranges = []
    start = prev = positions[0]

    for pos in positions[1:]:
        # After sorting pos >= prev, so this extends the current run for the
        # next consecutive value and also absorbs duplicates (pos == prev),
        # which would otherwise start a bogus new range.
        if pos <= prev + 1:
            prev = pos
        else:
            ranges.append((start, prev))
            start = prev = pos

    ranges.append((start, prev))

    return ', '.join(
        str(start) if start == end else f'{start}-{end}'
        for start, end in ranges
    )
|
|
@@ -0,0 +1,90 @@
|
|
|
1
|
+
import bisect
|
|
2
|
+
import random
|
|
3
|
+
|
|
4
|
+
from typing import Any, Sequence, Optional, Union
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
# Value types accepted as a seed by Sampler, which hands the seed to random.Random.
Seedable = Union[None, int, float, str, bytes, bytearray]
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class Sampler:
    """
    A precomputed sampler to speed up weighted sampling.

    The cumulative distribution is built once in the constructor, so each
    call to ``sample`` costs one random draw and one binary search.
    """
    def __init__(self,
                 items: Sequence[Any],
                 weights: Optional[Sequence[Union[int, float]]] = None,
                 seed: Optional[Seedable] = None,
                 rng: Optional[random.Random] = None,
                 ) -> None:
        """
        Constructor for the Sampler class.

        Parameters
        ----------
        items
            The items from which to sample.
        weights
            The weights assigned to the items. If omitted, all items are
            equally likely.
        seed
            Seed used to initialise a random number generator on this Sampler.
        rng
            Pre-constructed random number generator to use for sampling.

        Raises
        ------
        ValueError
            If ``items`` is empty, if ``items`` and ``weights`` differ in
            length, if any weight is negative, if both ``seed`` and ``rng``
            are given, or if the weights do not sum to a positive number.
        """
        if len(items) == 0:
            raise ValueError('Items cannot be empty.')

        if weights is None:
            weights = [1] * len(items)

        if len(items) != len(weights):
            raise ValueError('Items and weights must have same length.')

        if any(weight < 0 for weight in weights):
            raise ValueError('Weights cannot be negative.')

        if seed is not None and rng is not None:
            raise ValueError('Provide either seed or rng, not both.')

        # Accumulate once and use the final running total as the normaliser.
        # That makes the last cumulative value exactly 1.0; a separately
        # computed sum() may differ from the running total by rounding, which
        # could leave the final boundary just below 1.
        running_totals = []
        running = 0
        for weight in weights:
            running += weight
            running_totals.append(running)

        total = running_totals[-1]

        # Written as "not >" so that a NaN total is rejected as well.
        if not total > 0:
            raise ValueError('Weights must sum to a positive number.')

        self.items = tuple(items)
        self._single = len(self.items) == 1

        # Always defined, so the attributes exist even for a single item.
        self._rng = rng
        self._cumulative = []

        if self._single:
            # Nothing to draw: sample() returns the only item directly.
            return

        if rng is None:
            # random.Random(None) seeds from system entropy, like random.Random().
            self._rng = random.Random(seed)

        self._cumulative = [value / total for value in running_totals]

    def sample(self):
        """
        Sample the items according to the stored weights.
        If there is only one item, just return that.

        Returns
        -------
        The sampled item.
        """
        if self._single:
            return self.items[0]

        r = self._rng.random()

        # bisect_right selects the first item whose cumulative weight is
        # strictly greater than r, so a zero-weight item (whose interval is
        # empty) is never returned, even when r is exactly 0.0.
        i = bisect.bisect_right(self._cumulative, r)
        return self.items[i]
|