phasegen 0.0.3b0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
phasegen/__init__.py ADDED
@@ -0,0 +1,225 @@
1
"""
PhaseGen package.

Importing the package registers jsonpickle's NumPy handlers, configures the
``phasegen`` logger and re-exports the public classes listed in ``__all__``.
"""

__author__ = "Janek Sendrowski"
__contact__ = "sendrowski.janek@gmail.com"
__date__ = "2023-04-09"

__version__ = '0.0.3-beta'
10
+
11
+ import logging
12
+ import os
13
+ import sys
14
+
15
+ import jsonpickle.ext.numpy as jsonpickle_numpy
16
+ from tqdm import tqdm
17
+
18
# keep TensorFlow quiet, but respect a log level the user has already chosen
os.environ.setdefault('TF_CPP_MIN_LOG_LEVEL', '3')

# register jsonpickle's NumPy handlers
jsonpickle_numpy.register_handlers()
24
+
25
+
26
class TqdmLoggingHandler(logging.Handler):
    """
    Logging handler that emits records through ``tqdm.write``.
    """

    def __init__(self, level=logging.NOTSET):
        """
        Create the handler.

        :param level: The minimum level of the records handled.
        """
        super().__init__(level)

    def emit(self, record):
        """
        Format a record and write it to stderr via TQDM.

        :param record: The log record to emit.
        """
        try:
            # write to stderr through tqdm so that log lines and
            # progress bars can be displayed together
            tqdm.write(self.format(record), file=sys.stderr)
            self.flush()
        except Exception:
            self.handleError(record)
52
+
53
+
54
class ColoredFormatter(logging.Formatter):
    """
    Formatter that wraps each formatted record in an ANSI color chosen by
    level name and shows only the last component of the logger name.
    """

    def __init__(self, *args, **kwargs):
        """
        Initialize the formatter.

        :param args: Positional arguments passed on to ``logging.Formatter``.
        :param kwargs: Keyword arguments passed on to ``logging.Formatter``.
        """
        super().__init__(*args, **kwargs)

        # ANSI escape sequence per level name
        self.colors = {
            "DEBUG": "\033[36m",  # Cyan
            "INFO": "\033[32m",  # Green
            "WARNING": "\033[33m",  # Yellow
            "ERROR": "\033[31m",  # Red
            "CRITICAL": "\033[31m",  # Red
        }

        # ANSI escape sequence that resets the color
        self.reset = "\033[0m"

    def format(self, record):
        """
        Format the record.

        :param record: The log record.
        :return: The formatted text, wrapped in the color of the record's
            level (no color for unknown level names) and followed by a reset.
        """
        color = self.colors.get(record.levelname, self.reset)

        # Strip the package prefix from the logger name on the record itself
        # rather than by text replacement on the formatted output: a textual
        # replace would also rewrite occurrences of the logger name inside
        # the user's message.
        full_name = record.name
        record.name = full_name.split('.')[-1]

        try:
            formatted = super().format(record)
        finally:
            # other handlers may still process this record
            record.name = full_name

        return f"{color}{formatted}{self.reset}"
87
+
88
+
89
# formatter that colors records by level
formatter = ColoredFormatter('%(levelname)s:%(name)s: %(message)s')

# handler that writes through TQDM
handler = TqdmLoggingHandler()
handler.setFormatter(formatter)

# package logger: INFO by default, not propagated to the root logger
logger = logging.getLogger('phasegen')
logger.setLevel(logging.INFO)
logger.propagate = False
logger.addHandler(handler)
106
+
107
+ from .distributions import PhaseTypeDistribution
108
+
109
+ from .distributions import Coalescent
110
+
111
+ from .demography import (
112
+ Demography,
113
+ Epoch,
114
+ DiscreteRateChanges,
115
+ PopSizeChanges,
116
+ PopSizeChange,
117
+ MigrationRateChanges,
118
+ MigrationRateChange,
119
+ SymmetricMigrationRateChanges,
120
+ PopulationSplit,
121
+ DiscretizedRateChanges,
122
+ DiscretizedRateChange,
123
+ ExponentialPopSizeChanges,
124
+ ExponentialRateChanges
125
+ )
126
+
127
+ from .coalescent_models import (
128
+ CoalescentModel,
129
+ StandardCoalescent,
130
+ BetaCoalescent,
131
+ DiracCoalescent
132
+ )
133
+
134
+ from .state_space import (
135
+ StateSpace,
136
+ DefaultStateSpace,
137
+ BlockCountingStateSpace
138
+ )
139
+
140
+ from .rewards import (
141
+ Reward,
142
+ DefaultReward,
143
+ NonDefaultReward,
144
+ TreeHeightReward,
145
+ TotalTreeHeightReward,
146
+ TotalBranchLengthReward,
147
+ UnfoldedSFSReward,
148
+ FoldedSFSReward,
149
+ CustomReward,
150
+ ProductReward,
151
+ SumReward,
152
+ CombinedReward,
153
+ DemeReward
154
+ )
155
+
156
+ from .spectrum import (
157
+ SFS,
158
+ Spectra,
159
+ SFS2
160
+ )
161
+
162
+ from .inference import Inference
163
+
164
+ from .lineage import LineageConfig
165
+
166
+ from .locus import LocusConfig
167
+
168
+ from .norms import (
169
+ LNorm,
170
+ L1Norm,
171
+ L2Norm,
172
+ LInfNorm,
173
+ PoissonLikelihood
174
+ )
175
+
176
+ from .state_space_old import StateSpace as OldStateSpace
177
+
178
# Names exported by ``from phasegen import *``.
__all__ = [
    # distributions
    'PhaseTypeDistribution',
    'Coalescent',
    # demography
    'Demography',
    'Epoch',
    'PopSizeChanges',
    'PopSizeChange',
    'MigrationRateChanges',
    'MigrationRateChange',
    'SymmetricMigrationRateChanges',
    'PopulationSplit',
    'ExponentialPopSizeChanges',
    'ExponentialRateChanges',
    'DiscreteRateChanges',
    'DiscretizedRateChange',
    'DiscretizedRateChanges',
    # coalescent models
    'StandardCoalescent',
    'BetaCoalescent',
    'DiracCoalescent',
    # spectrum
    'SFS2',
    'SFS',
    'Spectra',
    # inference
    'Inference',
    # norms
    'LNorm',
    'L1Norm',
    'L2Norm',
    'LInfNorm',
    'PoissonLikelihood',
    # rewards
    'Reward',
    'TreeHeightReward',
    'TotalTreeHeightReward',
    'TotalBranchLengthReward',
    'UnfoldedSFSReward',
    'FoldedSFSReward',
    'CustomReward',
    'ProductReward',
    'SumReward',
    'DemeReward',
    'DefaultReward',
    'NonDefaultReward',
    'CombinedReward',
    # state spaces
    'StateSpace',
    'DefaultStateSpace',
    'BlockCountingStateSpace',
    # coalescent model base class
    'CoalescentModel',
    # lineage and locus configuration
    'LineageConfig',
    'LocusConfig',
]
@@ -0,0 +1,462 @@
1
+ """
2
+ Coalescent models.
3
+ """
4
+ import itertools
5
+ from abc import ABC, abstractmethod
6
+ from typing import List, Tuple, Sequence
7
+
8
+ import numpy as np
9
+ from scipy.special import comb, beta
10
+ from scipy.stats import binom
11
+
12
+
13
class CoalescentModel(ABC):
    """
    Abstract base class for coalescent models.

    Subclasses provide the merger rates and the timescale; this class maps
    transitions between states onto those rates.
    """

    def get_rate(self, s1: int, s2: int) -> float:
        """
        Get the rate of a merger taking ``s1`` lineages to ``s2`` lineages.

        :param s1: Number of lineages in the first state.
        :param s2: Number of lineages in the second state.
        :return: The rate, or 0 if ``s2`` exceeds ``s1``.
        """
        # the number of lineages cannot grow
        if s1 < s2:
            return 0

        # going from s1 to s2 lineages means s1 - s2 + 1 lineages merge into one
        merging = s1 + 1 - s2

        return self._get_rate(b=s1, k=merging)

    def get_rate_block_counting(self, n: int, s1: np.ndarray, s2: np.ndarray) -> float:
        r"""
        Get the (positive) rate between two block counting states
        :math:`{ (a_1,...,a_n) \in \mathbb{Z}^+ : \sum_{i=1}^{n} a_i = n \}`.

        :param n: Number of lineages.
        :param s1: Block configuration 1, a vector of length n.
        :param s2: Block configuration 2, a vector of length n.
        :return: The rate, or 0 if the transition is not a single merger.
        """
        delta = s2 - s1

        # exactly one class may gain a lineage, and the vectors must have length n
        if np.sum(delta == 1) != 1 or n != s1.shape[0]:
            return 0

        # classes that lost lineages
        losing = np.where(delta < 0)[0]

        if len(losing) == 0:
            return 0

        # number of lineages lost per class
        lost = -delta[losing]

        # index of the class the merged lineage must end up in
        gaining = np.dot(losing + 1, lost) - 1

        # that class must be the one that gained exactly one lineage
        if delta[gaining] != 1:
            return 0

        # lineages before the merger and lineages merging, per affected class
        before = s1[losing]
        merged = before - s2[losing]

        return self._get_rate_block_counting(n=s1.sum(), b=before, k=merged)

    @abstractmethod
    def _get_timescale(self, N: float) -> float:
        """
        Get the timescale.

        :param N: The effective population size.
        :return: The timescale for the given population size.
        """

    @abstractmethod
    def _get_rate(self, b: int, k: int) -> float:
        """
        Get the positive rate for a merger of k out of b lineages.
        Negative rates will be inferred later.

        :param b: Number of lineages.
        :param k: Number of lineages that merge.
        :return: The rate.
        """

    @abstractmethod
    def _get_rate_block_counting(self, n: int, b: Sequence[int], k: Sequence[int]) -> float:
        """
        Get the positive rate for a merger of k_i out of b_i lineages for all i.
        Negative rates will be inferred later.

        :param n: Number of lineages.
        :param b: Number of lineages before merge for blocks that experience a merger.
        :param k: Number of lineages that merge for blocks that experience a merger.
        :return: The rate.
        """

    @abstractmethod
    def coalesce(self, n: int, blocks: np.ndarray) -> List[Tuple[np.ndarray, float]]:
        """
        Coalesce a state.

        :param n: The total number of lineages.
        :param blocks: The lineages in each block.
        :return: List of coalesced states and their rates.
        """
118
+
119
+
120
class StandardCoalescent(CoalescentModel):
    """
    Standard (Kingman) coalescent model. Refer to
    `Msprime docs <https://tskit.dev/msprime/docs/stable/api.html?
    highlight=standard+coalescent#msprime.StandardCoalescent>`__
    for more information.
    """

    def _get_timescale(self, N: float) -> float:
        """
        Get the timescale.

        :param N: The effective population size.
        :return: The timescale, which is ``N`` itself for this model.
        """
        return N

    def _get_rate(self, b: int, k: int) -> float:
        """
        Get positive rate for a merger of k out of b lineages.

        :param b: Number of lineages.
        :param k: Number of lineages that merge.
        :return: ``b * (b - 1) / 2`` for a binary merger, 0 otherwise.
        """
        # two lineages can merge with a rate depending on b
        if k == 2:
            return b * (b - 1) / 2

        # no other mergers can happen
        return 0

    def _get_rate_block_counting(self, n: int, b: Sequence[int], k: Sequence[int]) -> float:
        """
        Get positive rate for a merger of k_i out of b_i lineages for all i.

        :param n: Number of lineages.
        :param b: Number of lineages before merge for blocks that experience a merger.
        :param k: Number of lineages that merge for blocks that experience a merger.
        :return: The rate.
        """
        # both merging lineages come from the same class
        if len(b) == 1:
            return self._get_rate(b=b[0], k=k[0])

        # one lineage from each of two classes
        if len(b) == 2:
            if k[0] == 1 and k[1] == 1:
                # same as b[0] choose k[0] times b[1] choose k[1]
                return b[0] * b[1]

        # no other mergers possible
        return 0

    # NOTE: ``blocks`` is annotated as plain ``np.ndarray`` (matching the
    # abstract method). The subscripted form ``np.ndarray[int]`` is evaluated
    # when the class body runs and is rejected by NumPy releases that do not
    # accept a single type argument.
    def coalesce(self, n: int, blocks: np.ndarray) -> List[Tuple[np.ndarray, float]]:
        """
        Coalesce a state.

        :param n: The total number of lineages.
        :param blocks: The lineages in each block.
        :return: List of coalesced states and their rates.
        """
        n_blocks = len(blocks)
        states = []

        # default state space: a single entry holding the number of lineages
        if n_blocks == 1:
            if blocks[0] > 1:
                states.append((np.array([blocks[0] - 1]), self._get_rate(b=blocks[0], k=2)))

            return states

        # block counting state space: visit every unordered pair of classes
        # (j <= i), in the same order as filtering the full product on i >= j
        for i in range(n_blocks):
            for j in range(i + 1):
                if i == j:
                    # two lineages of the same class merge
                    if blocks[i] > 1:
                        new = blocks.copy()
                        new[i] -= 2
                        new[2 * (i + 1) - 1] += 1

                        rate = self._get_rate_block_counting(n=n, b=[blocks[i]], k=[2])

                        states.append((new, rate))

                elif blocks[i] > 0 and blocks[j] > 0:
                    # one lineage from each of two different classes merges
                    new = blocks.copy()
                    new[i] -= 1
                    new[j] -= 1
                    new[i + j + 1] += 1

                    rate = self._get_rate_block_counting(n=n, b=[blocks[i], blocks[j]], k=[1, 1])

                    states.append((new, rate))

        return states

    def __eq__(self, other):
        """
        Check if two coalescent models are equal.

        :param other: The other coalescent model.
        :return: Whether the other object is a standard coalescent as well.
        """
        return isinstance(other, StandardCoalescent)
222
+
223
+
224
class MultipleMergerCoalescent(CoalescentModel, ABC):
    """
    Base class for multiple merger coalescent models.

    :meta private:
    """

    # NOTE: ``blocks`` is annotated as plain ``np.ndarray`` (matching the
    # abstract method). The subscripted form ``np.ndarray[int]`` is evaluated
    # when the class body runs and is rejected by NumPy releases that do not
    # accept a single type argument.
    def coalesce(self, n: int, blocks: np.ndarray) -> List[Tuple[np.ndarray, float]]:
        """
        Coalesce a state.

        :param n: The total number of lineages.
        :param blocks: The lineages in each block.
        :return: List of coalesced states and their rates.
        """
        n_blocks = len(blocks)
        states = []

        # default state space: any number of lineages from 2 up to all of them may merge
        if n_blocks == 1:
            for k in range(1, blocks[0]):
                states.append((np.array([blocks[0] - k]), self._get_rate(b=blocks[0], k=k + 1)))

            return states

        # block sizes by index, used to locate the class of the merged lineage
        sizes = np.arange(1, n_blocks + 1)

        # block counting state space: enumerate how many lineages are taken
        # from each class (named ``merging`` to avoid shadowing the
        # module-level ``comb`` import)
        for merging in itertools.product(*(range(count + 1) for count in blocks)):
            merging = np.array(merging)

            # at least two lineages are needed for a merger
            if merging.sum() > 1:
                new = blocks.copy()
                new -= merging
                new[merging.dot(sizes) - 1] += 1

                involved = merging > 0
                rate = self._get_rate_block_counting(n=blocks.sum(), b=blocks[involved], k=merging[involved])

                states.append((new, rate))

        return states
262
+
263
+
264
class BetaCoalescent(MultipleMergerCoalescent):
    """
    Beta coalescent model. Refer to
    `Msprime docs <https://tskit.dev/msprime/docs/stable/api.html?highlight=beta+coalescent#msprime.BetaCoalescent>`__
    for more information.
    """

    def __init__(self, alpha: float, scale_time: bool = True):
        """
        Create a beta coalescent model.

        :param alpha: The alpha parameter of the beta coalescent model.
        :param scale_time: Whether to scale coalescence time as described in
            `Msprime docs <https://tskit.dev/msprime/docs/stable/api.html?
            highlight=beta+coalescent#msprime.BetaCoalescent>`__. If ``False``, the timescale is set to N.
        :raises ValueError: If ``alpha`` is smaller than 1 or larger than 2.
        """
        if alpha < 1 or alpha > 2:
            raise ValueError("Alpha must be between 1 and 2.")

        #: The alpha parameter of the beta coalescent model.
        self.alpha: float = alpha

        #: Whether to scale coalescence time
        self.scale_time: bool = scale_time

    def _get_base_rate(self, b: int, k: int) -> float:
        """
        Get the rate for one particular set of k out of b lineages merging,
        i.e. without the number of ways of choosing the lineages.

        :param b: The number of lineages before the merger.
        :param k: The number of lineages that merge.
        :return: The rate.
        """
        numerator = beta(k - self.alpha, b - k + self.alpha)
        denominator = beta(self.alpha, 2 - self.alpha)

        return numerator / denominator

    def _get_timescale(self, N: float) -> float:
        """
        Get the timescale.

        :param N: The effective population size.
        :return: ``N`` if time scaling is disabled, the scaled value otherwise.
        """
        if not self.scale_time:
            return N

        alpha = self.alpha

        m = 1 + 1 / 2 ** (alpha - 1) / (alpha - 1)

        return m ** alpha * N ** (alpha - 1) / alpha / beta(2 - alpha, alpha)

    def _get_rate(self, b: int, k: int) -> float:
        """
        Get positive rate for a merger of k out of b lineages.
        Negative rates will be filled in later.

        :param b: The number of lineages before the merger.
        :param k: The number of lineages that merge.
        :return: The rate, or 0 if ``k`` is below 1 or above ``b``.
        """
        # NOTE(review): k == 1 passes this guard although a single lineage
        # cannot merge with itself - confirm no caller relies on that value
        if k > b or k < 1:
            return 0

        # number of ways of choosing the lineages times the rate per choice
        n_ways = comb(b, k, exact=True)

        return n_ways * self._get_base_rate(b, k)

    def _get_rate_block_counting(self, n: int, b: Sequence[int], k: Sequence[int]) -> float:
        """
        Get positive rate for a merger of k_i out of b_i lineages for all i.

        :param n: Number of lineages.
        :param b: Number of lineages before merge for blocks that experience a merger.
        :param k: Number of lineages that merge for blocks that experience a merger.
        :return: The rate.
        """
        # ways of choosing the merging lineages within each affected class
        ways_per_class = [comb(N=b_i, k=k_i, exact=True) for b_i, k_i in zip(b, k)]
        n_ways = np.prod(ways_per_class)

        return n_ways * self._get_base_rate(b=n, k=sum(k))

    def __eq__(self, other):
        """
        Check if two coalescent models are equal.

        :param other: The other coalescent model.
        :return: Whether the other object is a beta coalescent with the same
            ``alpha`` and ``scale_time``.
        """
        if not isinstance(other, BetaCoalescent):
            return False

        return self.alpha == other.alpha and self.scale_time == other.scale_time
356
+
357
+
358
class DiracCoalescent(MultipleMergerCoalescent):
    """
    Dirac coalescent model. Refer to
    `Msprime docs <https://tskit.dev/msprime/docs/stable/api.html?highlight=dirac+coalescent#msprime.DiracCoalescent>`__
    for more information.
    """

    def __init__(self, psi: float, c: float, scale_time: bool = True):
        """
        Create a Dirac coalescent model.

        :param psi: The fraction of the population replaced by offspring in one large reproduction event
        :param c: The rate of potential multiple merger events.
        :param scale_time: Whether to scale coalescence time as described in
            `Msprime docs <https://tskit.dev/msprime/docs/stable/api.html?
            highlight=dirac+coalescent#msprime.DiracCoalescent>`__. If ``False``, the timescale is set to N.
        :raises ValueError: If ``psi`` is not strictly between 0 and 1.
        """
        super().__init__()

        if not 0 < psi < 1:
            raise ValueError("Psi must be between 0 and 1.")

        #: The standard coalescent model, used for the binary merger rates
        self._standard = StandardCoalescent()

        #: The fraction of the population replaced by offspring in one large reproduction event
        self.psi = psi

        #: The rate of potential multiple merger events.
        self.c = c

        #: Whether to scale coalescence time
        self.scale_time: bool = scale_time

    def _get_timescale(self, N: float) -> float:
        """
        Get the timescale.

        :param N: The effective population size.
        :return: ``N ** 2`` if time scaling is enabled, ``N`` otherwise.
        """
        return N ** 2 if self.scale_time else N

    def _get_rate(self, b: int, k: int) -> float:
        """
        Get positive rate for a merger of k out of b lineages.
        Negative rates will be filled in later.

        :param b: The number of lineages before the merger.
        :param k: The number of lineages that merge.
        :return: The sum of the binary merger rate and the multiple merger rate.
        """
        # contribution of the standard coalescent
        binary = self._standard._get_rate(b=b, k=k)

        # binomial probability that k out of b lineages take part, scaled by c
        multi = binom.pmf(k=k, n=b, p=self.psi) * self.c

        return binary + multi

    def _get_rate_block_counting(self, n: int, b: Sequence[int], k: Sequence[int]) -> float:
        """
        Get positive rate for a merger of k_i out of b_i lineages for all i.

        :param n: Number of lineages.
        :param b: Number of lineages before merge for blocks that experience a merger.
        :param k: Number of lineages that merge for blocks that experience a merger.
        :return: The sum of the binary merger rate and the multiple merger rate.
        """
        # contribution of the standard coalescent
        binary = self._standard._get_rate_block_counting(n=n, b=b, k=k)

        # probability that k_i out of b_i lineages take part, for every affected class
        p_psi = np.prod([binom.pmf(k=k_i, n=b[i], p=self.psi) for i, k_i in enumerate(k)])

        # lineages in classes not listed in b must not take part at all
        n_affected = sum(b)

        if n_affected < n:
            p_psi *= binom.pmf(k=0, n=n - n_affected, p=self.psi)

        return binary + p_psi * self.c

    def __eq__(self, other):
        """
        Check if two coalescent models are equal.

        :param other: The other coalescent model.
        :return: Whether the other object is a Dirac coalescent with the same
            ``psi``, ``c`` and ``scale_time``.
        """
        if not isinstance(other, DiracCoalescent):
            return False

        return (
            self.psi == other.psi and
            self.c == other.c and
            self.scale_time == other.scale_time
        )