phasegen 0.0.3b0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- phasegen/__init__.py +225 -0
- phasegen/coalescent_models.py +462 -0
- phasegen/comparison.py +406 -0
- phasegen/demography.py +1066 -0
- phasegen/distributions.py +2928 -0
- phasegen/expm.py +77 -0
- phasegen/inference.py +740 -0
- phasegen/lineage.py +79 -0
- phasegen/locus.py +88 -0
- phasegen/norms.py +114 -0
- phasegen/rewards.py +540 -0
- phasegen/serialization.py +49 -0
- phasegen/spectrum.py +441 -0
- phasegen/state_space.py +924 -0
- phasegen/state_space_old.py +1601 -0
- phasegen/utils.py +45 -0
- phasegen/visualization.py +174 -0
- phasegen-0.0.3b0.dist-info/METADATA +36 -0
- phasegen-0.0.3b0.dist-info/RECORD +20 -0
- phasegen-0.0.3b0.dist-info/WHEEL +4 -0
phasegen/__init__.py
ADDED
|
@@ -0,0 +1,225 @@
|
|
|
1
|
+
"""
|
|
2
|
+
PhaseGen package.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
__author__ = "Janek Sendrowski"
|
|
6
|
+
__contact__ = "sendrowski.janek@gmail.com"
|
|
7
|
+
__date__ = "2023-04-09"
|
|
8
|
+
|
|
9
|
+
__version__ = '0.0.3-beta'
|
|
10
|
+
|
|
11
|
+
import logging
|
|
12
|
+
import os
|
|
13
|
+
import sys
|
|
14
|
+
|
|
15
|
+
import jsonpickle.ext.numpy as jsonpickle_numpy
|
|
16
|
+
from tqdm import tqdm
|
|
17
|
+
|
|
18
|
+
# silence TensorFlow's C++ logging unless the user already chose a level
os.environ.setdefault('TF_CPP_MIN_LOG_LEVEL', '3')

# allow jsonpickle to (de)serialize numpy objects
jsonpickle_numpy.register_handlers()
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class TqdmLoggingHandler(logging.Handler):
    """
    Logging handler that prints records through ``tqdm.write``.

    Going through ``tqdm.write`` lets log lines coexist with active tqdm
    progress bars instead of being interleaved with them.
    """

    def __init__(self, level=logging.NOTSET):
        """
        Create the handler.

        :param level: Minimum level of records this handler processes,
            ``logging.NOTSET`` by default.
        """
        super(TqdmLoggingHandler, self).__init__(level=level)

    def emit(self, record):
        """
        Format ``record`` and write it to ``sys.stderr`` via ``tqdm.write``.

        Any exception raised while formatting or writing is delegated to
        :meth:`logging.Handler.handleError`.

        :param record: The log record to emit.
        """
        try:
            line = self.format(record)

            # stderr is the stream tqdm progress bars write to by default,
            # so routing log output there keeps both working together
            tqdm.write(line, file=sys.stderr)
            self.flush()
        except Exception:
            self.handleError(record)
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
class ColoredFormatter(logging.Formatter):
    """
    Formatter that wraps each formatted record in an ANSI color chosen by the
    record's level and shortens the logger name to its last dotted component.
    """

    def __init__(self, *args, **kwargs):
        """
        Initialize the formatter.

        All arguments are forwarded to :class:`logging.Formatter`.
        """
        super().__init__(*args, **kwargs)

        #: ANSI color escape sequence per level name
        self.colors = {
            "DEBUG": "\033[36m",  # Cyan
            "INFO": "\033[32m",  # Green
            "WARNING": "\033[33m",  # Yellow
            "ERROR": "\033[31m",  # Red
            "CRITICAL": "\033[31m",  # Red
        }

        #: ANSI escape sequence that resets the terminal color
        self.reset = "\033[0m"

    def format(self, record):
        """
        Format the record.

        The logger name is shortened to its last dotted component, so
        ``phasegen.distributions`` becomes ``distributions``. Only the name
        field is shortened; the message text and the passed record itself
        are left untouched.

        :param record: The log record to format.
        :return: The formatted string wrapped in the level's color code.
        """
        # unknown levels get the reset sequence, i.e. no color
        color = self.colors.get(record.levelname, self.reset)

        # format a copy with the shortened name instead of string-replacing the
        # full name in the output, which would also rewrite the logger name
        # wherever it appears inside the message; using a copy also keeps the
        # change from leaking into other handlers that see the same record
        short_record = logging.makeLogRecord(record.__dict__)
        short_record.name = record.name.split('.')[-1]

        formatted = super().format(short_record)

        return f"{color}{formatted}{self.reset}"
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
# package-wide logger shared by all phasegen modules
logger = logging.getLogger('phasegen')

# records are handled only by the handler attached below, never by the root logger
logger.propagate = False

# default verbosity
logger.setLevel(logging.INFO)

# colored "LEVEL:name: message" output
formatter = ColoredFormatter('%(levelname)s:%(name)s: %(message)s')

# route output through tqdm so log lines do not break progress bars
handler = TqdmLoggingHandler()
handler.setFormatter(formatter)
logger.addHandler(handler)
|
|
106
|
+
|
|
107
|
+
from .distributions import PhaseTypeDistribution
|
|
108
|
+
|
|
109
|
+
from .distributions import Coalescent
|
|
110
|
+
|
|
111
|
+
from .demography import (
|
|
112
|
+
Demography,
|
|
113
|
+
Epoch,
|
|
114
|
+
DiscreteRateChanges,
|
|
115
|
+
PopSizeChanges,
|
|
116
|
+
PopSizeChange,
|
|
117
|
+
MigrationRateChanges,
|
|
118
|
+
MigrationRateChange,
|
|
119
|
+
SymmetricMigrationRateChanges,
|
|
120
|
+
PopulationSplit,
|
|
121
|
+
DiscretizedRateChanges,
|
|
122
|
+
DiscretizedRateChange,
|
|
123
|
+
ExponentialPopSizeChanges,
|
|
124
|
+
ExponentialRateChanges
|
|
125
|
+
)
|
|
126
|
+
|
|
127
|
+
from .coalescent_models import (
|
|
128
|
+
CoalescentModel,
|
|
129
|
+
StandardCoalescent,
|
|
130
|
+
BetaCoalescent,
|
|
131
|
+
DiracCoalescent
|
|
132
|
+
)
|
|
133
|
+
|
|
134
|
+
from .state_space import (
|
|
135
|
+
StateSpace,
|
|
136
|
+
DefaultStateSpace,
|
|
137
|
+
BlockCountingStateSpace
|
|
138
|
+
)
|
|
139
|
+
|
|
140
|
+
from .rewards import (
|
|
141
|
+
Reward,
|
|
142
|
+
DefaultReward,
|
|
143
|
+
NonDefaultReward,
|
|
144
|
+
TreeHeightReward,
|
|
145
|
+
TotalTreeHeightReward,
|
|
146
|
+
TotalBranchLengthReward,
|
|
147
|
+
UnfoldedSFSReward,
|
|
148
|
+
FoldedSFSReward,
|
|
149
|
+
CustomReward,
|
|
150
|
+
ProductReward,
|
|
151
|
+
SumReward,
|
|
152
|
+
CombinedReward,
|
|
153
|
+
DemeReward
|
|
154
|
+
)
|
|
155
|
+
|
|
156
|
+
from .spectrum import (
|
|
157
|
+
SFS,
|
|
158
|
+
Spectra,
|
|
159
|
+
SFS2
|
|
160
|
+
)
|
|
161
|
+
|
|
162
|
+
from .inference import Inference
|
|
163
|
+
|
|
164
|
+
from .lineage import LineageConfig
|
|
165
|
+
|
|
166
|
+
from .locus import LocusConfig
|
|
167
|
+
|
|
168
|
+
from .norms import (
|
|
169
|
+
LNorm,
|
|
170
|
+
L1Norm,
|
|
171
|
+
L2Norm,
|
|
172
|
+
LInfNorm,
|
|
173
|
+
PoissonLikelihood
|
|
174
|
+
)
|
|
175
|
+
|
|
176
|
+
from .state_space_old import StateSpace as OldStateSpace
|
|
177
|
+
|
|
178
|
+
# public API exported by ``from phasegen import *`` (order is significant)
__all__ = [
    # distributions
    'PhaseTypeDistribution', 'Coalescent',
    # demography
    'Demography', 'Epoch',
    'PopSizeChanges', 'PopSizeChange',
    'MigrationRateChanges', 'MigrationRateChange', 'SymmetricMigrationRateChanges',
    'PopulationSplit',
    'ExponentialPopSizeChanges', 'ExponentialRateChanges',
    'DiscreteRateChanges', 'DiscretizedRateChange', 'DiscretizedRateChanges',
    # coalescent models
    'StandardCoalescent', 'BetaCoalescent', 'DiracCoalescent',
    # spectra
    'SFS2', 'SFS', 'Spectra',
    # inference
    'Inference',
    # norms
    'LNorm', 'L1Norm', 'L2Norm', 'LInfNorm', 'PoissonLikelihood',
    # rewards
    'Reward', 'TreeHeightReward', 'TotalTreeHeightReward', 'TotalBranchLengthReward',
    'UnfoldedSFSReward', 'FoldedSFSReward', 'CustomReward', 'ProductReward',
    'SumReward', 'DemeReward', 'DefaultReward', 'NonDefaultReward', 'CombinedReward',
    # state spaces
    'StateSpace', 'DefaultStateSpace', 'BlockCountingStateSpace',
    # coalescent model base class and configuration
    'CoalescentModel', 'LineageConfig', 'LocusConfig',
]
|
|
phasegen/coalescent_models.py
ADDED
|
@@ -0,0 +1,462 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Coalescent models.
|
|
3
|
+
"""
|
|
4
|
+
import itertools
|
|
5
|
+
from abc import ABC, abstractmethod
|
|
6
|
+
from typing import List, Tuple, Sequence
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
from scipy.special import comb, beta
|
|
10
|
+
from scipy.stats import binom
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class CoalescentModel(ABC):
    """
    Abstract class for coalescent models.

    Subclasses provide positive transition rates both for the default state
    space (number of lineages) and for the block counting state space
    (number of blocks per block size).
    """

    def get_rate(self, s1: int, s2: int) -> float:
        """
        Get rate for a merger collapsing s1 lineages into s2 lineages.

        A merger strictly reduces the number of lineages, so the rate is 0
        whenever ``s2 >= s1``. This matches :meth:`get_rate_block_counting`,
        which also returns 0 for identical states.

        :param s1: Number of lineages in the first state.
        :param s2: Number of lineages in the second state.
        :return: The rate.
        """
        # no merger if the number of lineages does not decrease; in particular
        # s1 == s2 would otherwise ask the model for a "merger" of one lineage
        if s2 >= s1:
            return 0

        # going from s1 to s2 lineages means s1 - s2 + 1 lineages merge into one
        return self._get_rate(b=s1, k=s1 + 1 - s2)

    def get_rate_block_counting(self, n: int, s1: np.ndarray, s2: np.ndarray) -> float:
        r"""
        Get (positive) rate between two block counting states
        :math:`\{ (a_1,\dots,a_n) \in \mathbb{Z}_{\geq 0}^n : \sum_{i=1}^{n} i \, a_i = n \}`,
        where :math:`a_i` is the number of blocks of size :math:`i`.

        :param n: Number of lineages.
        :param s1: Block configuration 1, a vector of length n.
        :param s2: Block configuration 2, a vector of length n.
        :return: The rate, or 0 if s2 cannot be reached from s1 by a single merger.
        """
        diff = s2 - s1

        # a merger creates exactly one new block, so exactly one class gains one block
        if np.sum(diff == 1) != 1 or n != s1.shape[0]:
            return 0

        # indices of the classes that lost blocks
        where_less = np.where(diff < 0)[0]

        if len(where_less) == 0:
            return 0

        # number of blocks lost per class
        diff_less = -diff[where_less]

        # the new block's size is the summed size of the merged blocks;
        # class index i holds blocks of size i + 1
        i_more = np.dot(where_less + 1, diff_less) - 1

        # the class of the new block must have gained exactly one block
        if diff[i_more] != 1:
            return 0

        # number of blocks per participating class before the merger
        b = s1[where_less]

        # number of blocks per participating class that take part in the merger
        k = b - s2[where_less]

        # the current number of blocks (not the sample size n) is passed on
        return self._get_rate_block_counting(n=s1.sum(), b=b, k=k)

    @abstractmethod
    def _get_timescale(self, N: float) -> float:
        """
        Get the timescale.

        :param N: The effective population size.
        :return: The generation time.
        """
        pass

    @abstractmethod
    def _get_rate(self, b: int, k: int) -> float:
        """
        Get positive rate for a merger of k out of b lineages.
        Negative rates will be inferred later

        :param b: Number of lineages.
        :param k: Number of lineages that merge.
        :return: The rate.
        """
        pass

    @abstractmethod
    def _get_rate_block_counting(self, n: int, b: Sequence[int], k: Sequence[int]) -> float:
        """
        Get positive rate for a merger of k_i out of b_i lineages for all i.
        Negative rates will be inferred later

        :param n: Number of lineages.
        :param b: Number of lineages before merge for blocks that experience a merger.
        :param k: Number of lineages that merge for blocks that experience a merger.
        :return: The rate.
        """
        pass

    @abstractmethod
    def coalesce(self, n: int, blocks: np.ndarray) -> List[Tuple[np.ndarray, float]]:
        """
        Coalesce a state.

        :param n: The total number of lineages.
        :param blocks: The lineages in each block.
        :return: List of coalesced states and their rates.
        """
        pass
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
class StandardCoalescent(CoalescentModel):
    """
    Standard (Kingman) coalescent model, in which only pairs of lineages merge.
    Refer to
    `Msprime docs <https://tskit.dev/msprime/docs/stable/api.html?
    highlight=standard+coalescent#msprime.StandardCoalescent>`__
    for more information.
    """

    def _get_timescale(self, N: float) -> float:
        """
        Get the timescale, which equals the effective population size.

        :param N: The effective population size.
        :return: The generation time.
        """
        return N

    def _get_rate(self, b: int, k: int) -> float:
        """
        Get positive rate for a merger of k out of b lineages.

        :param b: Number of lineages.
        :param k: Number of lineages that merge.
        :return: ``b choose 2`` for a binary merger, 0 otherwise.
        """
        # only binary mergers exist; each of the b choose 2 pairs may merge
        return b * (b - 1) / 2 if k == 2 else 0

    def _get_rate_block_counting(self, n: int, b: Sequence[int], k: Sequence[int]) -> float:
        """
        Get positive rate for a merger of k_i out of b_i lineages for all i.

        :param n: Number of lineages.
        :param b: Number of lineages before merge for blocks that experience a merger.
        :param k: Number of lineages that merge for blocks that experience a merger.
        :return: The rate.
        """
        n_classes = len(b)

        # both merging blocks come from the same class
        if n_classes == 1:
            return self._get_rate(b=b[0], k=k[0])

        # one block from each of two classes: b[0] * b[1] possible pairs
        if n_classes == 2 and k[0] == 1 and k[1] == 1:
            return b[0] * b[1]

        # anything else is not a binary merger
        return 0

    def coalesce(self, n: int, blocks: np.ndarray[int]) -> List[Tuple[np.ndarray, float]]:
        """
        Coalesce a state.

        :param n: The total number of lineages.
        :param blocks: The lineages in each block.
        :return: List of coalesced states and their rates.
        """
        n_blocks = len(blocks)

        # default state space: a single entry holding the number of lineages
        if n_blocks == 1:
            lineages = blocks[0]
            if lineages > 1:
                return [(np.array([lineages - 1]), self._get_rate(b=lineages, k=2))]
            return []

        # block counting state space: class index i holds blocks of size i + 1
        transitions = []
        for i in range(n_blocks):
            # j <= i visits every unordered pair of classes once
            for j in range(i + 1):
                if i == j:
                    if blocks[i] <= 1:
                        continue

                    # two blocks of size i + 1 form one block of size 2 * (i + 1)
                    merged = blocks.copy()
                    merged[i] -= 2
                    merged[2 * i + 1] += 1
                    rate = self._get_rate_block_counting(n=n, b=[blocks[i]], k=[2])
                else:
                    if blocks[i] <= 0 or blocks[j] <= 0:
                        continue

                    # blocks of sizes i + 1 and j + 1 form one of size i + j + 2
                    merged = blocks.copy()
                    merged[i] -= 1
                    merged[j] -= 1
                    merged[i + j + 1] += 1
                    rate = self._get_rate_block_counting(n=n, b=[blocks[i], blocks[j]], k=[1, 1])

                transitions.append((merged, rate))

        return transitions

    def __eq__(self, other):
        """
        Check whether ``other`` is also a standard coalescent.

        :param other: The other coalescent model.
        :return: Whether the two coalescent models are equal.
        """
        return isinstance(other, StandardCoalescent)
|
|
222
|
+
|
|
223
|
+
|
|
224
|
+
class MultipleMergerCoalescent(CoalescentModel, ABC):
    """
    Base class for multiple merger coalescent models, in which any number of
    lineages (at least two) may merge in a single event.

    :meta private:
    """

    def coalesce(self, n: int, blocks: np.ndarray[int]) -> List[Tuple[np.ndarray, float]]:
        """
        Coalesce a state.

        :param n: The total number of lineages.
        :param blocks: The lineages in each block.
        :return: List of coalesced states and their rates.
        """
        n_blocks = len(blocks)

        # default state space: merging m + 1 of b lineages leaves b - m lineages
        if n_blocks == 1:
            lineages = blocks[0]
            return [
                (np.array([lineages - m]), self._get_rate(b=lineages, k=m + 1))
                for m in range(1, lineages)
            ]

        # block counting state space: class index i holds blocks of size i + 1
        sizes = np.arange(1, n_blocks + 1)
        n_lineages = blocks.sum()
        transitions = []

        # choose how many blocks of each class take part in the merger
        for picked in itertools.product(*[list(range(count + 1)) for count in blocks]):
            picked = np.array(picked)

            # a merger needs at least two blocks
            if picked.sum() <= 1:
                continue

            merged = blocks.copy()
            merged -= picked

            # the new block's size is the summed size of the picked blocks
            merged[picked.dot(sizes) - 1] += 1

            involved = picked > 0
            rate = self._get_rate_block_counting(n=n_lineages, b=blocks[involved], k=picked[involved])
            transitions.append((merged, rate))

        return transitions
|
|
262
|
+
|
|
263
|
+
|
|
264
|
+
class BetaCoalescent(MultipleMergerCoalescent):
    """
    Beta coalescent model. Refer to
    `Msprime docs <https://tskit.dev/msprime/docs/stable/api.html?highlight=beta+coalescent#msprime.BetaCoalescent>`__
    for more information.
    """

    def __init__(self, alpha: float, scale_time: bool = True):
        """
        Initialize the beta coalescent model.

        :param alpha: The alpha parameter of the beta coalescent model, ``1 <= alpha < 2``.
        :param scale_time: Whether to scale coalescence time as described in
            `Msprime docs <https://tskit.dev/msprime/docs/stable/api.html?
            highlight=beta+coalescent#msprime.BetaCoalescent>`__. If ``False``, the timescale is set to N.
            Time scaling divides by ``alpha - 1``, so ``alpha = 1`` requires ``scale_time=False``.
        :raises ValueError: If ``alpha`` is not in ``[1, 2)``.
        """
        # alpha = 2 is excluded: the normalizing constant B(alpha, 2 - alpha) = B(2, 0)
        # diverges, making the rates NaN or zero; the chained comparison also rejects NaN
        if not 1 <= alpha < 2:
            raise ValueError("Alpha must be at least 1 and less than 2.")

        #: Whether to scale coalescence time
        self.scale_time: bool = scale_time

        #: The alpha parameter of the beta coalescent model.
        self.alpha: float = alpha

    def _get_base_rate(self, b: int, k: int) -> float:
        """
        Get base rate for a merger of k out of b lineages (without number of ways).

        :param b: The number of lineages before the merger.
        :param k: The number of lineages that merge.
        :return: The rate.
        """
        rate = beta(k - self.alpha, b - k + self.alpha) / beta(self.alpha, 2 - self.alpha)

        return rate

    def _get_timescale(self, N: float) -> float:
        """
        Get the timescale.

        :param N: The effective population size.
        :return: The generation time.
        """
        if not self.scale_time:
            return N

        # NOTE: divides by (alpha - 1), so alpha == 1 cannot be used with scale_time=True
        m = 1 + 1 / 2 ** (self.alpha - 1) / (self.alpha - 1)

        scale = m ** self.alpha * N ** (self.alpha - 1) / self.alpha / beta(2 - self.alpha, self.alpha)

        return scale

    def _get_rate(self, b: int, k: int) -> float:
        """
        Get positive rate for a merger of k out of b lineages.
        Negative rates will be filled in later.

        :param b: The number of lineages before the merger.
        :param k: The number of lineages that merge.
        :return: The rate, 0 unless ``2 <= k <= b``.
        """
        # a merger involves at least two lineages; for k = 1 the base rate would
        # evaluate the beta function at the non-positive argument 1 - alpha
        if k < 2 or k > b:
            return 0

        return comb(b, k, exact=True) * self._get_base_rate(b, k)

    def _get_rate_block_counting(self, n: int, b: Sequence[int], k: Sequence[int]) -> float:
        """
        Get positive rate for a merger of k_i out of b_i lineages for all i.

        :param n: Number of lineages.
        :param b: Number of lineages before merge for blocks that experience a merger.
        :param k: Number of lineages that merge for blocks that experience a merger.
        :return: The rate.
        """
        # number of ways to pick k_i of the b_i blocks in every participating class
        combinations = np.prod([comb(N=b_i, k=k_i, exact=True) for b_i, k_i in zip(b, k)])

        return combinations * self._get_base_rate(b=n, k=sum(k))

    def __eq__(self, other):
        """
        Check if two coalescent models are equal.

        :param other: The other coalescent model.
        :return: Whether the two coalescent models are equal.
        """
        return (
            isinstance(other, BetaCoalescent) and
            self.alpha == other.alpha and
            self.scale_time == other.scale_time
        )
|
|
356
|
+
|
|
357
|
+
|
|
358
|
+
class DiracCoalescent(MultipleMergerCoalescent):
    """
    Dirac coalescent model. Refer to
    `Msprime docs <https://tskit.dev/msprime/docs/stable/api.html?highlight=dirac+coalescent#msprime.DiracCoalescent>`__
    for more information.
    """

    def __init__(self, psi: float, c: float, scale_time: bool = True):
        """
        Initialize the Dirac coalescent model.

        :param psi: The fraction of the population replaced by offspring in one large reproduction event
        :param c: The rate of potential multiple merger events, must be non-negative.
        :param scale_time: Whether to scale coalescence time as described in
            `Msprime docs <https://tskit.dev/msprime/docs/stable/api.html?
            highlight=dirac+coalescent#msprime.DiracCoalescent>`__. If ``False``, the timescale is set to N.
        :raises ValueError: If ``psi`` is not in ``(0, 1)`` or ``c`` is negative.
        """
        super().__init__()

        if not 0 < psi < 1:
            raise ValueError("Psi must be between 0 and 1.")

        # c multiplies the multiple-merger probabilities, so a negative c
        # would produce negative (invalid) rates
        if c < 0:
            raise ValueError("C must be non-negative.")

        #: The fraction of the population replaced by offspring in one large reproduction event
        self.psi = psi

        #: The rate of potential multiple merger events.
        self.c = c

        #: Whether to scale coalescence time
        self.scale_time: bool = scale_time

        #: The standard coalescent model
        self._standard = StandardCoalescent()

    def _get_timescale(self, N: float) -> float:
        """
        Get the timescale.

        :param N: The effective population size.
        :return: The generation time.
        """
        if not self.scale_time:
            return N

        return N ** 2

    def _get_rate(self, b: int, k: int) -> float:
        """
        Get positive rate for a merger of k out of b lineages.
        Negative rates will be filled in later.

        :param b: The number of lineages before the merger.
        :param k: The number of lineages that merge.
        :return: The rate, 0 unless ``2 <= k <= b``.
        """
        # a merger involves at least two lineages; without this guard k = 0 or 1
        # would receive the nonzero rate c * P(Binomial(b, psi) = k)
        if k < 2 or k > b:
            return 0

        # rate of binary merger
        rate_binary = self._standard._get_rate(b=b, k=k)

        # probability of multiple merger of k out of b lineages
        p_psi = binom.pmf(k=k, n=b, p=self.psi)

        # rate of multiple merger
        rate_multi = p_psi * self.c

        return rate_binary + rate_multi

    def _get_rate_block_counting(self, n: int, b: Sequence[int], k: Sequence[int]) -> float:
        """
        Get positive rate for a merger of k_i out of b_i lineages for all i.

        :param n: Number of lineages.
        :param b: Number of lineages before merge for blocks that experience a merger.
        :param k: Number of lineages that merge for blocks that experience a merger.
        :return: The rate.
        """
        # rate of binary merger
        rate_binary = self._standard._get_rate_block_counting(n=n, b=b, k=k)

        # probability that exactly k_i lineages of each participating class take part
        p_psi = np.prod([binom.pmf(k=k[i], n=b[i], p=self.psi) for i in range(len(k))])

        # ... and that none of the lineages in non-participating classes do
        if sum(b) < n:
            p_psi *= binom.pmf(k=0, n=n - sum(b), p=self.psi)

        # rate of multiple merger
        rate_multi = p_psi * self.c

        rate = rate_binary + rate_multi

        return rate

    def __eq__(self, other):
        """
        Check if two coalescent models are equal.

        :param other: The other coalescent model.
        :return: Whether the two coalescent models are equal.
        """
        return (
            isinstance(other, DiracCoalescent) and
            self.psi == other.psi and
            self.c == other.c and
            self.scale_time == other.scale_time
        )
|