msasim 25.11.9__cp312-cp312-musllinux_1_2_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
msasim/sailfish.py ADDED
@@ -0,0 +1,580 @@
1
+ import _Sailfish
2
+ import os, warnings, math, operator, time, profile, tempfile, pathlib
3
+ from functools import reduce
4
+ from typing import List, Optional, Dict
5
+ from re import split
6
+ from enum import Enum
7
+
8
+
9
+
# Module-level alias for the extension module's substitution-model codes
# (``_Sailfish.modelCode``), so callers can write e.g. ``MODEL_CODES.NUCJC``
# without importing ``_Sailfish`` themselves.
MODEL_CODES = _Sailfish.modelCode
class SIMULATION_TYPE(Enum):
    """Which kind of alignment the ``Simulator`` produces.

    The member chosen here decides the alphabet handed to the substitution
    model factory and whether substitutions are simulated at all.
    """

    # Indels only; no substitution step is run.
    NOSUBS = 0
    # Indels plus nucleotide substitutions.
    DNA = 1
    # Indels plus amino-acid substitutions.
    PROTEIN = 2
class Distribution:
    """
    Base class for the discrete length distributions handed to the simulator.

    Subclasses build a list of probabilities and pass it to ``set_dist``,
    which validates it and wraps it in a ``_Sailfish.DiscreteDistribution``.
    """

    def set_dist(self, dist):
        """
        Validate ``dist`` and store it as the underlying native distribution.

        inputs:
        dist - iterable of probabilities

        Raises ValueError if the values do not sum to 1 (within 10e-6) or if
        any value lies outside [0, 1].
        """
        # Materialise once: the values are traversed three times below (sum,
        # range check, native copy), which would silently exhaust a
        # one-shot iterator such as a generator expression.
        dist = list(dist)
        total = sum(dist)
        # sum should be "around" 1 (10e-6 == 1e-5)
        epsilon = 10e-6
        if abs(total - 1) > epsilon:
            raise ValueError(f"Sum of the distribution should be 1 for a valid probability distribution. Input received is: {dist}, sum is {total}")
        for x in dist:
            if x < 0 or x > 1:
                raise ValueError(f"Each value of the probabilities should be between 0 to 1. Received a value of {x}")
        self._dist = _Sailfish.DiscreteDistribution(dist)

    def _get_Sailfish_dist(self) -> _Sailfish.DiscreteDistribution:
        """Return the wrapped native ``_Sailfish.DiscreteDistribution``."""
        return self._dist
class CustomDistribution(Distribution):
    """
    Discrete distribution built from probabilities supplied by the caller.

    The values go through ``set_dist`` unchanged, so they must sum to 1 and
    each lie within [0, 1].
    """

    def __init__(self, dist: List[float]):
        # All validation and wrapping happens in set_dist.
        probabilities = dist
        self.set_dist(probabilities)
class GeometricDistribution(Distribution):
    """Geometric length distribution truncated to lengths 1..truncation."""

    def __init__(self, p: float, truncation: int = 150):
        """
        Build a truncated geometric distribution over lengths 1..truncation.

        P(x) = p * (1 - p) ** (x - 1), renormalised so that the truncated
        probabilities sum to 1.

        inputs:
        p - success probability of the geometric distribution, in (0, 1]
        truncation - (optional, by default 150) maximal value of the distribution, >= 1

        Raises ValueError if p is outside (0, 1] or truncation < 1 (both
        previously surfaced as ZeroDivisionError or invalid probabilities).
        """
        if not 0 < p <= 1:
            raise ValueError(f"p should be in the range (0, 1]. Received a value of {p}")
        if truncation < 1:
            raise ValueError(f"truncation should be at least 1. Received a value of {truncation}")
        self.p = p
        self.truncation = truncation
        # Mass of the untruncated distribution on 1..truncation:
        # CDF(truncation) - CDF(0), and CDF(0) == 0.
        norm_factor = 1 - (1 - p) ** truncation
        probabilities = [p * (1 - p) ** (x - 1) / norm_factor for x in range(1, truncation + 1)]
        self.set_dist(probabilities)

    def __repr__(self) -> str:
        return f"Geometric distribution: (p={self.p}, truncation={self.truncation})"
class PoissonDistribution(Distribution):
    """Poisson length distribution restricted to lengths 1..truncation."""

    def __init__(self, p: float, truncation: int = 150):
        """
        Build a zero-truncated Poisson distribution over lengths 1..truncation.

        P(x) = p ** x * e ** -p / x!, renormalised over x = 1..truncation.

        inputs:
        p - rate (mean) parameter of the Poisson distribution, a positive finite number
        truncation - (optional, by default 150) maximal value of the distribution, >= 1

        Raises ValueError if p is not a positive finite number or truncation < 1.
        """
        if not (p > 0 and math.isfinite(p)):
            raise ValueError(f"p should be a positive finite number. Received a value of {p}")
        if truncation < 1:
            raise ValueError(f"truncation should be at least 1. Received a value of {truncation}")
        self.p = p
        self.truncation = truncation

        # Work in log space: 1.0 / x! overflows a float once x exceeds 170,
        # so the direct formula raised OverflowError for larger truncations.
        log_weights = [x * math.log(p) - math.lgamma(x + 1) for x in range(1, truncation + 1)]
        # The common e ** -p factor cancels in the normalisation; subtracting
        # the largest log-weight keeps exp() within floating-point range.
        peak = max(log_weights)
        weights = [math.exp(w - peak) for w in log_weights]
        norm_factor = math.fsum(weights)
        probabilities = [w / norm_factor for w in weights]

        self.set_dist(probabilities)

    def __repr__(self) -> str:
        return f"Poisson distribution: (p={self.p}, truncation={self.truncation})"
class ZipfDistribution(Distribution):
    """Zipf (power-law) length distribution over lengths 1..truncation."""

    def __init__(self, p: float, truncation: int = 150):
        """
        Build a truncated Zipf distribution: P(x) proportional to x ** -p.

        inputs:
        p - exponent of the Zipf distribution
        truncation - (optional, by default 150) maximal value of the distribution, >= 1

        Raises ValueError if truncation < 1.
        """
        if truncation < 1:
            raise ValueError(f"truncation should be at least 1. Received a value of {truncation}")
        self.p = p
        self.truncation = truncation

        # Compute each unnormalised weight once and reuse it for the sum.
        weights = [i ** -p for i in range(1, truncation + 1)]
        norm_factor = sum(weights)
        probabilities = [w / norm_factor for w in weights]

        self.set_dist(probabilities)

    def __repr__(self) -> str:
        return f"Zipf distribution: (p={self.p}, truncation={self.truncation})"
def is_newick(tree: str) -> bool:
    """
    Minimal Newick sanity check: the last non-whitespace token must be ';'.

    Tokenisation adapted from
    https://github.com/ila/Newick-validator/blob/master/Newick_Validator.py

    Returns True when the check passes.

    Raises ValueError if the string is empty/whitespace-only or does not end
    with ';'.
    """
    # Split into tokens, keeping the separators (the pattern is a capturing
    # group). The label class is [A-Za-z]; the former [A-za-z] also matched
    # the punctuation between 'Z' and 'a' in ASCII.
    tokens = split(r'([A-Za-z]+[^A-Za-z,)]+[A-Za-z]+|[0-9.]*[A-Za-z]+[0-9.]+|[0-9.]+\s+[0-9.]+|[0-9.]+|[A-Za-z]+|\(|\)|;|:|,)', tree)

    # Drop empty and whitespace-only pieces (spaces inside labels survive).
    parsed_tokens = [token for token in tokens if token and not token.isspace()]

    # Empty input used to crash with IndexError on parsed_tokens[-1].
    if not parsed_tokens or parsed_tokens[-1] != ';':
        raise ValueError(f"Tree without ; at the end. Tree received: {tree}")
    return True
# TODO: probably unused; candidate for deletion.
class Block:
    """
    Wrapper around a single native ``_Sailfish.Block`` event.

    Used to add insertions or deletions.
    """

    def __init__(self, num1: int, num2: int):
        # Both integers are forwarded unchanged; their meaning is defined by
        # the native Block constructor (TODO confirm, e.g. position/length).
        native_block = _Sailfish.Block(num1, num2)
        self.block = native_block
class BlockTree:
    """
    Wrapper around a native ``_Sailfish.BlockTree`` holding the events on
    multiple branches (the entire tree).
    """

    def __init__(self, root_length: int):
        native_tree = _Sailfish.BlockTree(root_length)
        self.blockTree = native_tree

    def print_tree(self) -> str:
        """Return the native tree's textual dump."""
        native_tree = self.blockTree
        return native_tree.print_tree()

    def block_list(self) -> List:
        """Return the block list reported by the native tree."""
        native_tree = self.blockTree
        return native_tree.block_list()
# TODO: BlockTree and BlockTreePython overlap; one of them is probably unused.
class BlockTreePython:
    '''
    Used to contain the events on a multiple branches (entire tree).

    Wraps the mapping {branch name: native block object} produced by the
    native simulator's gen_indels.
    '''
    def __init__(self, branch_block_dict: Dict[str, _Sailfish.Block]):
        self._branch_block_dict = branch_block_dict
        # Shallow copy of the mapping, used for membership checks and lookups.
        self._branch_block_dict_python = dict(branch_block_dict)

    def _get_Sailfish_blocks(self) -> Dict[str, _Sailfish.Block]:
        """Return the original native mapping (as passed to __init__)."""
        return self._branch_block_dict

    def get_branches_str(self) -> Dict[str, str]:
        """Return {branch name: print_tree() text} for every branch."""
        return {name: blocks.print_tree() for name, blocks in self._branch_block_dict.items()}

    def get_specific_branch(self, branch: str) -> str:
        """
        Return the print_tree() text of one branch.

        Raises ValueError if ``branch`` is unknown.
        """
        if branch not in self._branch_block_dict_python:
            raise ValueError(f"branch not in the _branch_block, aviable branches are: {list(self._branch_block_dict_python.keys())}")
        return self._branch_block_dict[branch].print_tree()

    def print_branches(self) -> None:
        """Print every branch name followed by its print_tree() text."""
        for name, blocks in self._branch_block_dict.items():
            print(f"branch = {name}")
            print(blocks.print_tree())

    def block_list(self, branch: Optional[str] = None) -> _Sailfish.Block:
        """
        Return the stored entry for ``branch``.

        ``branch`` used to be read from an undefined name, so every call
        raised NameError; it is now an explicit parameter.

        Raises ValueError if ``branch`` is missing or unknown.
        """
        if branch not in self._branch_block_dict_python:
            raise ValueError(f"branch not in the _branch_block, aviable branches are: {list(self._branch_block_dict_python.keys())}")
        return self._branch_block_dict_python[branch]
class Tree:
    '''
    The tree class for the simulator
    '''
    def __init__(self, input_str: str):
        """
        Build the tree from a Newick string or from a file containing one.

        inputs:
        input_str - path to an existing file holding a Newick tree, or the
                    Newick string itself

        Raises ValueError if the tree text does not pass is_newick (i.e. does
        not end with ';').
        """
        is_from_file = os.path.isfile(input_str)
        if is_from_file:
            # Context manager so the handle is closed promptly; the previous
            # bare open(...).read() left it open until garbage collection.
            with open(input_str, 'r') as tree_file:
                tree_str = tree_file.read()
        else:
            tree_str = input_str
        if not is_newick(tree_str):
            if is_from_file:
                raise ValueError(f"Failed to read tree from file. File path: {input_str}, content: {tree_str}")
            raise ValueError(f"Failed construct tree from string. String received: {tree_str}")
        # The native Tree receives the original argument plus the flag telling
        # it whether that argument is a file path or the Newick text.
        self._tree = _Sailfish.Tree(input_str, is_from_file)
        self._tree_str = tree_str

    def get_num_nodes(self) -> int:
        """Number of nodes in the native tree."""
        return self._tree.num_nodes

    def get_num_leaves(self) -> int:
        """Number of leaves below the root of the native tree."""
        return self._tree.root.num_leaves

    def _get_Sailfish_tree(self) -> _Sailfish.Tree:
        """Return the wrapped native ``_Sailfish.Tree``."""
        return self._tree

    def __repr__(self) -> str:
        return f"{self._tree_str}"
class SimProtocol:
    '''
    The simulator protocol, sets the different distribution, tree and root length.

    Per-branch settings (indel rates and indel length distributions) are kept
    as lists with one entry per branch; the number of branches is the number
    of tree nodes minus one.
    '''
    def __init__(self, tree = None,
                 root_seq_size: int = 100,
                 deletion_rate: float = 0.0,
                 insertion_rate: float = 0.0,
                 deletion_dist: Distribution = ZipfDistribution(1.7, 50),
                 insertion_dist: Distribution = ZipfDistribution(1.7, 50),
                 minimum_seq_size: int = 100,
                 seed: int = 0,
                 ):
        """
        inputs:
        tree - a Tree, a Newick string, or a path to a file containing a tree
        root_seq_size - root sequence length (see set_sequence_size)
        deletion_rate / insertion_rate - rate applied to every branch
        deletion_dist / insertion_dist - length distribution applied to every
            branch. NOTE: the defaults are single instances created once at
            class-definition time and shared by every protocol using them.
        minimum_seq_size - forwarded to set_min_sequence_size
        seed - forwarded to set_seed

        Raises ValueError if ``tree`` is neither a Tree nor a str.
        """
        if isinstance(tree, Tree):
            self._tree = tree
        elif isinstance(tree, str):
            self._tree = Tree(tree)
        else:
            raise ValueError(f"please provide one of the following: (1) a newick format of a tree; (2) a path to a file containing a tree; (3) or a tree created by the Tree class")

        self._num_branches = self._tree.get_num_nodes() - 1
        self._sim = _Sailfish.SimProtocol(self._tree._get_Sailfish_tree())
        self.set_seed(seed)
        self.set_sequence_size(root_seq_size)
        # Initial values; the rate setters below recompute both flags.
        self._is_deletion_rate_zero = not deletion_rate
        self._is_insertion_rate_zero = not insertion_rate
        self.set_deletion_rates(deletion_rate=deletion_rate)
        self.set_insertion_rates(insertion_rate=insertion_rate)
        self.set_deletion_length_distributions(deletion_dist=deletion_dist)
        self.set_insertion_length_distributions(insertion_dist=insertion_dist)
        self.set_min_sequence_size(min_sequence_size=minimum_seq_size)

    def get_tree(self) -> Tree:
        return self._tree

    def _get_Sailfish_tree(self) -> _Sailfish.Tree:
        return self._tree._get_Sailfish_tree()

    def _get_root(self):
        return self._tree._get_Sailfish_tree().root

    def get_num_branches(self) -> int:
        return self._num_branches

    def _check_branch_num(self, branch_num: int) -> None:
        """Raise ValueError unless 0 <= branch_num < number of branches."""
        # Negative indices used to slip through: Python lists silently wrap
        # them and the native getters receive an out-of-range index.
        if not 0 <= branch_num < self._num_branches:
            raise ValueError(f"The branch number should be between 0 to {self._num_branches} (not included). Received value of {branch_num}")

    def set_seed(self, seed: int) -> None:
        self._seed = seed
        self._sim.set_seed(seed)

    def get_seed(self) -> int:
        return self._seed

    def set_sequence_size(self, sequence_size: int) -> None:
        self._sim.set_sequence_size(sequence_size)
        self._root_seq_size = sequence_size

    def get_sequence_size(self) -> int:
        return self._root_seq_size

    def set_min_sequence_size(self, min_sequence_size: int) -> None:
        self._sim.set_minimum_sequence_size(min_sequence_size)
        self._min_seq_size = min_sequence_size

    def set_insertion_rates(self, insertion_rate: Optional[float] = None, insertion_rates: Optional[List[float]] = None) -> None:
        """
        Set insertion rates: one value for all branches, or one per branch.

        Raises ValueError if neither argument is given or the list length
        differs from the number of branches.
        """
        if insertion_rate is not None:
            self.insertion_rates = [insertion_rate] * self._num_branches
        elif insertion_rates:
            if len(insertion_rates) != self._num_branches:
                raise ValueError(f"The length of the insertaion rates should be equal to the number of branches in the tree. The insertion_rates length is {len(insertion_rates)} and the number of branches is {self._num_branches}. You can pass a single value as insertion_rate which will be used for all branches.")
            # Copy so later mutation of the caller's list cannot desync us.
            self.insertion_rates = list(insertion_rates)
        else:
            raise ValueError(f"please provide one of the following: insertion_rate (a single value used for all branches), or a insertion_rates (a list of values, each corresponding to a different branch)")

        # Recompute from scratch: the flag used to be only ever cleared, so
        # resetting all rates to zero left indel generation switched on.
        self._is_insertion_rate_zero = not any(self.insertion_rates)
        self._sim.set_insertion_rates(self.insertion_rates)

    def get_insertion_rate(self, branch_num: int) -> float:
        self._check_branch_num(branch_num)
        return self._sim.get_insertion_rate(branch_num)

    def get_all_insertion_rates(self) -> Dict:
        return {i: self.get_insertion_rate(i) for i in range(self._num_branches)}

    def set_deletion_rates(self, deletion_rate: Optional[float] = None, deletion_rates: Optional[List[float]] = None) -> None:
        """
        Set deletion rates: one value for all branches, or one per branch.

        Raises ValueError if neither argument is given or the list length
        differs from the number of branches.
        """
        if deletion_rate is not None:
            self.deletion_rates = [deletion_rate] * self._num_branches
        elif deletion_rates:
            if len(deletion_rates) != self._num_branches:
                raise ValueError(f"The length of the deletion rates should be equal to the number of branches in the tree. The deletion_rates length is {len(deletion_rates)} and the number of branches is {self._num_branches}. You can pass a single value as deletion_rate which will be used for all branches.")
            # Copy so later mutation of the caller's list cannot desync us.
            self.deletion_rates = list(deletion_rates)
        else:
            raise ValueError(f"please provide one of the following: deletion_rate (a single value used for all branches), or a deletion_rates (a list of values, each corresponding to a different branch)")

        # Recompute from scratch (see set_insertion_rates).
        self._is_deletion_rate_zero = not any(self.deletion_rates)
        self._sim.set_deletion_rates(self.deletion_rates)

    def get_deletion_rate(self, branch_num: int) -> float:
        self._check_branch_num(branch_num)
        return self._sim.get_deletion_rate(branch_num)

    def get_all_deletion_rates(self) -> Dict:
        return {i: self.get_deletion_rate(i) for i in range(self._num_branches)}

    def set_insertion_length_distributions(self, insertion_dist: Optional[Distribution] = None, insertion_dists: Optional[List[Distribution]] = None) -> None:
        """
        Set insertion length distributions: one for all branches, or one per branch.

        Raises ValueError if neither argument is given or the list length
        differs from the number of branches.
        """
        if insertion_dist is not None:
            self.insertion_dists = [insertion_dist] * self._num_branches
        elif insertion_dists:
            if len(insertion_dists) != self._num_branches:
                raise ValueError(f"The length of the insertion dists should be equal to the number of branches in the tree. The insertion_dists length is {len(insertion_dists)} and the number of branches is {self._num_branches}. You can pass a single value as insertion_dist which will be used for all branches.")
            self.insertion_dists = list(insertion_dists)
        else:
            raise ValueError(f"please provide one of the following: insertion_dist (a single distribution used for all branches), or a insertion_dists (a list of distributions, each corresponding to a different branch)")

        self._sim.set_insertion_length_distributions([dist._get_Sailfish_dist() for dist in self.insertion_dists])

    def get_insertion_length_distribution(self, branch_num: int) -> Distribution:
        self._check_branch_num(branch_num)
        return self.insertion_dists[branch_num]

    def get_all_insertion_length_distribution(self) -> Dict:
        return {i: self.get_insertion_length_distribution(i) for i in range(self._num_branches)}

    def set_deletion_length_distributions(self, deletion_dist: Optional[Distribution] = None, deletion_dists: Optional[List[Distribution]] = None) -> None:
        """
        Set deletion length distributions: one for all branches, or one per branch.

        Raises ValueError if neither argument is given or the list length
        differs from the number of branches.
        """
        if deletion_dist is not None:
            self.deletion_dists = [deletion_dist] * self._num_branches
        elif deletion_dists:
            if len(deletion_dists) != self._num_branches:
                raise ValueError(f"The length of the deletion dists should be equal to the number of branches in the tree. The deletion_dists length is {len(deletion_dists)} and the number of branches is {self._num_branches}. You can pass a single value as deletion_dist which will be used for all branches.")
            self.deletion_dists = list(deletion_dists)
        else:
            raise ValueError(f"please provide one of the following: deletion_dist (a single distribution used for all branches), or a deletion_dists (a list of distributions, each corresponding to a different branch)")

        self._sim.set_deletion_length_distributions([dist._get_Sailfish_dist() for dist in self.deletion_dists])

    def get_deletion_length_distribution(self, branch_num: int) -> Distribution:
        self._check_branch_num(branch_num)
        return self.deletion_dists[branch_num]

    def get_all_deletion_length_distribution(self) -> Dict:
        return {i: self.get_deletion_length_distribution(i) for i in range(self._num_branches)}
class Msa:
    '''
    Python handle on a native ``_Sailfish.Msa`` built by the simulator.

    Every method delegates straight to the wrapped native object.
    '''
    def __init__(self, species_dict: Dict[str, BlockTree], root_node, save_list: List[bool]):
        # NOTE(review): Simulator.simulate also passes an int as the first
        # argument on its no-indel path, so the annotation is not exhaustive;
        # all three arguments are forwarded untouched.
        native_msa = _Sailfish.Msa(species_dict, root_node, save_list)
        self._msa = native_msa

    def generate_msas(self, node):
        """Forward ``node`` to the native generate_msas."""
        native_msa = self._msa
        native_msa.generate_msas(node)

    def get_length(self) -> int:
        """Alignment length as reported by the native object."""
        native_msa = self._msa
        return native_msa.length()

    def get_num_sequences(self) -> int:
        """Number of sequences as reported by the native object."""
        native_msa = self._msa
        return native_msa.num_sequences()

    def fill_substitutions(self, sequenceContainer) -> None:
        """Hand the simulator's substitution output to the native MSA."""
        native_msa = self._msa
        native_msa.fill_substitutions(sequenceContainer)

    def print_msa(self) -> str:
        native_msa = self._msa
        return native_msa.print_msa()

    def print_indels(self) -> str:
        native_msa = self._msa
        return native_msa.print_indels()

    def get_msa(self) -> str:
        """Return the alignment text produced by the native get_msa_string."""
        native_msa = self._msa
        return native_msa.get_msa_string()

    def write_msa(self, file_path) -> None:
        """Ask the native object to write the alignment to ``file_path``."""
        native_msa = self._msa
        native_msa.write_msa(file_path)
class Simulator:
    '''
    Simulate MSAs based on SimProtocol
    '''
    def __init__(self, simProtocol: Optional[SimProtocol] = None, simulation_type: Optional[SIMULATION_TYPE] = None):
        """
        inputs:
        simProtocol - protocol to simulate. When omitted, a default protocol is
            built: 3-leaf tree, Poisson(10, 100) indel lengths, indel rates
            0.05, root length 100, minimum length 1.
        simulation_type - SIMULATION_TYPE member. When omitted, an indel-only
            (NOSUBS) simulation is run.

        Raises ValueError if the protocol fails verification or the
        simulation type is unknown.
        """
        if simProtocol is None:
            warnings.warn(f"initalized a simulator without simProtocol -> using a default protocol with Tree = '(A:0.01,B:0.5,C:0.03);' and root length of 100")
            # default simulation values
            poisson = PoissonDistribution(10, 100)
            simProtocol = SimProtocol(tree="(A:0.01,B:0.5,C:0.03);")
            simProtocol.set_insertion_length_distributions(poisson)
            simProtocol.set_deletion_length_distributions(poisson)
            simProtocol.set_insertion_rates(0.05)
            simProtocol.set_deletion_rates(0.05)
            simProtocol.set_sequence_size(100)
            simProtocol.set_min_sequence_size(1)

        # verify sim_protocol
        if not self._verify_sim_protocol(simProtocol):
            raise ValueError(f"failed to verify simProtocol")
        self._simProtocol = simProtocol
        # DNA and indel-only runs both use the nucleotide simulator.
        if simulation_type == SIMULATION_TYPE.PROTEIN:
            self._simulator = _Sailfish.AminoSimulator(self._simProtocol._sim)
        else:
            self._simulator = _Sailfish.NucleotideSimulator(self._simProtocol._sim)

        if not simulation_type:
            warnings.warn(f"simulation type not provided -> running indel only simulation")
            simulation_type = SIMULATION_TYPE.NOSUBS

        if simulation_type == SIMULATION_TYPE.PROTEIN:
            self._alphabet = _Sailfish.alphabetCode.AMINOACID
        elif simulation_type == SIMULATION_TYPE.DNA:
            self._alphabet = _Sailfish.alphabetCode.NUCLEOTIDE
        elif simulation_type == SIMULATION_TYPE.NOSUBS:
            self._alphabet = _Sailfish.alphabetCode.NULLCODE
        else:
            raise ValueError(f"unknown simulation type, please provde one of the following: {[e.name for e in SIMULATION_TYPE]}")

        self._simulation_type = simulation_type
        self._is_sub_model_init = False

    def _verify_sim_protocol(self, simProtocol) -> bool:
        """
        Sanity-check ``simProtocol`` before it is used.

        Returns True when every check passes; raises ValueError otherwise.
        """
        if not simProtocol.get_tree():
            raise ValueError(f"protocol miss tree, please provide when initalizing the simProtocol")
        if not simProtocol.get_sequence_size():
            raise ValueError(f"protocol miss root length, please provide -> simProtocol.set_sequence_size(int)")
        if not simProtocol.get_insertion_length_distribution(0):
            raise ValueError(f"protocol miss insertion length distribution, please provide -> simProtocol.set_insertion_length_distributions(Distribution)")
        if not simProtocol.get_deletion_length_distribution(0):
            raise ValueError(f"protocol miss deletion length distribution, please provide -> simProtocol.set_deletion_length_distributions(Distribution)")
        # Rates can be set per branch, so every branch is checked; only
        # branch 0 used to be inspected.
        for rate in simProtocol.get_all_insertion_rates().values():
            if rate < 0:
                raise ValueError(f"please provide a non negative value for insertion rate, provided value of: {rate} -> simProtocol.set_insertion_rates(float)")
        for rate in simProtocol.get_all_deletion_rates().values():
            if rate < 0:
                raise ValueError(f"please provide a non negative value for deletion rate, provided value of: {rate} -> simProtocol.set_deletion_rates(float)")
        return True

    def reset_sim(self):
        """Placeholder: currently does nothing."""
        # TODO, complete
        pass

    def _init_sub_model(self) -> None:
        """Configure a default substitution model (WAG for protein, JC otherwise, gamma(1.0, 1))."""
        self._model_factory = _Sailfish.modelFactory(self._simProtocol._get_Sailfish_tree())
        self._model_factory.set_alphabet(self._alphabet)
        if self._simulation_type == SIMULATION_TYPE.PROTEIN:
            warnings.warn(f"replacement matrix not provided -> running with default parameters: WAG model")
            self._model_factory.set_replacement_model(_Sailfish.modelCode.WAG)
        else:
            warnings.warn(f"replacement matrix not provided -> running with default parameters: JC model")
            self._model_factory.set_replacement_model(_Sailfish.modelCode.NUCJC)
        self._model_factory.set_gamma_parameters(1.0, 1)

        self._simulator.init_substitution_sim(self._model_factory)
        self._is_sub_model_init = True

    def set_replacement_model(
            self,
            model: _Sailfish.modelCode,
            amino_model_file: Optional[pathlib.Path] = None,
            model_parameters: Optional[List] = None,
            gamma_parameters_alpha : float = 1.0,
            gamma_parameters_categories: int = 1,
            invariant_sites_proportion: float = 0.0,
            site_rate_correlation: float = 0.0,
            ) -> None:
        """
        Configure the substitution model and initialise substitution simulation.

        Protein runs take no model_parameters (a CUSTOM model may read
        amino_model_file). For nucleotides, NUCJC takes no model_parameters
        and every other model requires them.

        Raises ValueError on a missing model, a non-positive or non-integral
        gamma_parameters_categories, or misplaced/missing model_parameters.
        """
        if not model:
            raise ValueError(f"please provide a substitution model from the the following list: {_Sailfish.modelCode}")
        if int(gamma_parameters_categories) != gamma_parameters_categories or gamma_parameters_categories < 1:
            raise ValueError(f"gamma_parameters_catergories has to be a positive int value: received value of {gamma_parameters_categories}")

        is_protein = self._simulation_type == SIMULATION_TYPE.PROTEIN
        # Validate everything before replacing self._model_factory, so a
        # rejected call leaves the previously configured factory in place.
        if is_protein:
            if model_parameters:
                raise ValueError(f"no model parameters are used in protein, recevied value of: {model_parameters}")
        elif model == MODEL_CODES.NUCJC:
            if model_parameters:
                raise ValueError(f"no model parameters in JC model, recevied value of: {model_parameters}")
        elif not model_parameters:
            raise ValueError(f"please provide a model parameters")

        model_factory = _Sailfish.modelFactory(self._simProtocol._get_Sailfish_tree())
        model_factory.set_alphabet(self._alphabet)
        model_factory.set_replacement_model(model)
        if is_protein:
            if model == MODEL_CODES.CUSTOM and amino_model_file:
                model_factory.set_amino_replacement_model_file(str(amino_model_file))
        elif model != MODEL_CODES.NUCJC:
            model_factory.set_model_parameters(model_parameters)

        model_factory.set_gamma_parameters(gamma_parameters_alpha, gamma_parameters_categories)
        model_factory.set_invariant_sites_proportion(invariant_sites_proportion)
        model_factory.set_site_rate_correlation(site_rate_correlation)

        self._model_factory = model_factory
        self._simulator.init_substitution_sim(self._model_factory)

        self._is_sub_model_init = True

    def gen_indels(self) -> BlockTreePython:
        return BlockTreePython(self._simulator.gen_indels())

    def get_sequences_to_save(self) -> List[bool]:
        return self._simulator.get_saved_nodes_mask()

    def save_root_sequence(self):
        self._simulator.save_root_sequence()

    def save_all_nodes_sequences(self):
        self._simulator.save_all_nodes_sequences()

    def gen_substitutions(self, length: int):
        """Generate substitutions, initialising a default model on first use."""
        if not self._is_sub_model_init:
            self._init_sub_model()
        return self._simulator.gen_substitutions(length)

    def simulate(self, times: int = 1) -> List[Msa]:
        """Run ``times`` simulations and return one Msa per run."""
        Msas = []
        for _ in range(times):
            if self._simProtocol._is_insertion_rate_zero and self._simProtocol._is_deletion_rate_zero:
                # No indels possible: build the MSA directly from the root length.
                msa = Msa(sum(self.get_sequences_to_save()),
                          self._simProtocol.get_sequence_size(),
                          self.get_sequences_to_save())
            else:
                blocktree = self.gen_indels()
                msa = Msa(blocktree._get_Sailfish_blocks(),
                          self._simProtocol._get_root(),
                          self.get_sequences_to_save())

            if self._simulation_type != SIMULATION_TYPE.NOSUBS:
                substitutions = self.gen_substitutions(msa.get_length())
                msa.fill_substitutions(substitutions)

            Msas.append(msa)
        return Msas

    def simulate_low_memory(self, output_file_path: pathlib.Path) -> None:
        """
        Run one simulation and write the alignment to ``output_file_path``.

        Nothing is returned; the result lives only in the output file.
        """
        msa = None
        if not (self._simProtocol._is_insertion_rate_zero and self._simProtocol._is_deletion_rate_zero):
            blocktree = self.gen_indels()
            msa = Msa(blocktree._get_Sailfish_blocks(),
                      self._simProtocol._get_root(),
                      self.get_sequences_to_save())
            self._simulator.set_aligned_sequence_map(msa._msa)

        if self._simulation_type != SIMULATION_TYPE.NOSUBS:
            # Same lazy default-model initialisation as gen_substitutions;
            # it was previously skipped on this path.
            if not self._is_sub_model_init:
                self._init_sub_model()
            self._simulator.gen_substitutions_to_file(self._simProtocol.get_sequence_size(),
                                                      str(output_file_path))
        else:
            if msa is None:
                # Indel-free, substitution-free run: msa used to be unbound
                # here (UnboundLocalError). Build it as simulate() does.
                save_mask = self.get_sequences_to_save()
                msa = Msa(sum(save_mask),
                          self._simProtocol.get_sequence_size(),
                          save_mask)
            msa.write_msa(str(output_file_path))

    def __call__(self) -> Msa:
        return self.simulate(1)[0]

    def save_rates(self, is_save: bool) -> None:
        self._simulator.save_site_rates(is_save)

    def get_rates(self) -> List[float]:
        return self._simulator.get_site_rates()