LineageTree 1.8.0__py3-none-any.whl → 2.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- LineageTree/__init__.py +27 -2
- LineageTree/legacy/export_csv.py +70 -0
- LineageTree/legacy/to_lineajea.py +30 -0
- LineageTree/legacy/to_motile.py +36 -0
- LineageTree/lineageTree.py +2268 -1467
- LineageTree/lineageTreeManager.py +749 -70
- LineageTree/loaders.py +942 -864
- LineageTree/test/test_lineageTree.py +634 -0
- LineageTree/test/test_uted.py +233 -0
- LineageTree/tree_approximation.py +488 -0
- LineageTree/utils.py +103 -103
- {LineageTree-1.8.0.dist-info → lineagetree-2.0.1.dist-info}/METADATA +30 -34
- lineagetree-2.0.1.dist-info/RECORD +16 -0
- {LineageTree-1.8.0.dist-info → lineagetree-2.0.1.dist-info}/WHEEL +1 -1
- LineageTree/tree_styles.py +0 -334
- LineageTree-1.8.0.dist-info/RECORD +0 -11
- {LineageTree-1.8.0.dist-info → lineagetree-2.0.1.dist-info/licenses}/LICENSE +0 -0
- {LineageTree-1.8.0.dist-info → lineagetree-2.0.1.dist-info}/top_level.txt +0 -0
@@ -1,9 +1,17 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
1
3
|
import os
|
2
4
|
import pickle as pkl
|
3
5
|
import warnings
|
6
|
+
from collections.abc import Callable
|
4
7
|
from functools import partial
|
8
|
+
from typing import TYPE_CHECKING, Literal
|
5
9
|
|
10
|
+
import matplotlib.colors as mcolors
|
6
11
|
import numpy as np
|
12
|
+
from matplotlib import colormaps
|
13
|
+
|
14
|
+
from .tree_approximation import tree_style
|
7
15
|
|
8
16
|
try:
|
9
17
|
from edist import uted
|
@@ -12,51 +20,97 @@ except ImportError:
|
|
12
20
|
"No edist installed therefore you will not be able to compute the tree edit distance.",
|
13
21
|
stacklevel=2,
|
14
22
|
)
|
23
|
+
import matplotlib.pyplot as plt
|
24
|
+
from edist import uted
|
25
|
+
|
15
26
|
from LineageTree import lineageTree
|
27
|
+
from LineageTree.tree_approximation import TreeApproximationTemplate
|
28
|
+
|
29
|
+
from .utils import convert_style_to_number
|
16
30
|
|
17
|
-
|
31
|
+
if TYPE_CHECKING:
|
32
|
+
from edist.alignment import Alignment
|
18
33
|
|
19
34
|
|
20
35
|
class lineageTreeManager:
|
36
|
+
norm_dict = {"max": max, "sum": sum, None: lambda x: 1}

def __init__(self):
    """Create an empty manager holding no lineage trees."""
    # name -> lineageTree mapping of every managed tree.
    self.lineagetrees = {}
    # Counter handed out (and advanced) by ``__next__``.
    self.lineageTree_counter = 0
    self.registered = {}
    # Alignment cache grouped by comparison parameters; filled by the
    # comparison methods and emptied by ``clear_comparisons``.
    self._comparisons = {}

def __next__(self):
    """Hand out the current counter value and advance the counter by one.

    Returns
    -------
    int
        The value ``lineageTree_counter`` held before this call.
    """
    issued = self.lineageTree_counter
    self.lineageTree_counter = issued + 1
    return issued

def __len__(self):
    """Number of lineage trees currently held by the manager.

    Returns
    -------
    int
        The number of trees inside the manager.
    """
    return len(self.lineagetrees)

def __iter__(self):
    """Iterate over the managed lineage trees.

    Yields
    ------
    tuple
        ``(name, lineageTree)`` pairs, in the insertion order of
        ``lineagetrees``.
    """
    for name, tree in self.lineagetrees.items():
        yield name, tree

def __getitem__(self, key):
    """Look up a managed lineage tree by its name.

    Raises
    ------
    KeyError
        If no tree is registered under ``key``.
    """
    if key not in self.lineagetrees:
        raise KeyError(f"'{key}' not found in the manager")
    return self.lineagetrees[key]
|
68
|
+
|
30
69
|
@property
def gcd(self) -> int:
    """Greatest common divisor of the time resolutions of the managed trees.

    Returns
    -------
    int
        ``np.gcd.reduce`` over every tree's ``_time_resolution`` when the
        manager holds two or more trees, and ``1`` when it holds exactly one.

    Raises
    ------
    ValueError
        If the manager holds no lineage tree.
    """
    count = len(self)
    if not count:
        raise ValueError(
            "You cannot calculate the greatest common divisor of time resolutions with an empty manager."
        )
    if count == 1:
        # NOTE(review): a lone tree yields 1 rather than its own resolution,
        # unlike the multi-tree case; confirm this is intended by the
        # resolution handling of the tree approximations.
        return 1
    resolutions = [embryo._time_resolution for embryo in self.lineagetrees.values()]
    return np.gcd.reduce(resolutions)
|
38
90
|
|
39
|
-
def add(
|
40
|
-
self, other_tree: lineageTree, name: str = "", classification: str = ""
|
41
|
-
):
|
91
|
+
def add(self, other_tree: lineageTree, name: str = ""):
|
42
92
|
"""Function that adds a new lineagetree object to the class.
|
43
93
|
Can be added either by .add or by using the + operator. If a name is
|
44
94
|
specified it will also add it as this specific name, otherwise it will
|
45
95
|
use the already existing name of the lineagetree.
|
46
96
|
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
97
|
+
Parameters
|
98
|
+
----------
|
99
|
+
other_tree : LineageTree
|
100
|
+
Thelineagetree to be added.
|
101
|
+
name : str, default=""
|
102
|
+
Then name of the lineagetree to be added, defaults to ''.
|
103
|
+
(Usually lineageTrees have the name of the path they are read from,
|
104
|
+
so this is going to be the name most of the times.)
|
51
105
|
"""
|
52
|
-
if isinstance(other_tree, lineageTree)
|
106
|
+
if isinstance(other_tree, lineageTree):
|
53
107
|
for tree in self.lineagetrees.values():
|
54
108
|
if tree == other_tree:
|
55
109
|
return False
|
56
110
|
if name:
|
57
111
|
self.lineagetrees[name] = other_tree
|
58
112
|
else:
|
59
|
-
if
|
113
|
+
if other_tree.name:
|
60
114
|
name = other_tree.name
|
61
115
|
self.lineagetrees[name] = other_tree
|
62
116
|
else:
|
@@ -74,8 +128,10 @@ class lineageTreeManager:
|
|
74
128
|
def write(self, fname: str):
|
75
129
|
"""Saves the manager
|
76
130
|
|
77
|
-
|
78
|
-
|
131
|
+
Parameters
|
132
|
+
----------
|
133
|
+
fname : str
|
134
|
+
The path and name of the file that is to be saved.
|
79
135
|
"""
|
80
136
|
if os.path.splitext(fname)[-1] != ".ltM":
|
81
137
|
fname = os.path.extsep.join((fname, "ltM"))
|
@@ -86,83 +142,119 @@ class lineageTreeManager:
|
|
86
142
|
def remove_embryo(self, key):
    """Remove a lineage tree from the manager.

    Parameters
    ----------
    key : str
        The name of the lineage tree to be removed.

    Notes
    -----
    Unknown keys are ignored: no exception is raised when no lineage tree is
    registered under ``key``.
    """
    # ``pop`` with a default turns removal of an unknown key into a no-op.
    self.lineagetrees.pop(key, None)
|
96
156
|
|
97
157
|
@classmethod
def load(cls, fname: str) -> lineageTreeManager:
    """Load a lineage tree manager from a ".ltM" file written by ``write``.

    Parameters
    ----------
    fname : str
        Path to and name of the file to read.

    Returns
    -------
    lineageTreeManager
        The unpickled manager.

    Warnings
    --------
    The file is read with ``pickle``, which can execute arbitrary code while
    loading. Only load files from trusted sources.
    """
    # The ``with`` block closes the file; no explicit ``close`` is needed.
    with open(fname, "rb") as f:
        return pkl.load(f)
|
112
175
|
|
113
|
-
def
|
176
|
+
def __cross_lineage_edit_backtrace(
|
114
177
|
self,
|
115
178
|
n1: int,
|
116
|
-
embryo_1,
|
117
|
-
end_time1: int,
|
179
|
+
embryo_1: str,
|
118
180
|
n2: int,
|
119
|
-
embryo_2,
|
120
|
-
|
121
|
-
|
181
|
+
embryo_2: str,
|
182
|
+
end_time1: int | None = None,
|
183
|
+
end_time2: int | None = None,
|
184
|
+
style: (
|
185
|
+
Literal["simple", "normalized_simple", "full", "downsampled"]
|
186
|
+
| type[TreeApproximationTemplate]
|
187
|
+
) = "simple",
|
188
|
+
norm: Literal["max", "sum", None] = "max",
|
122
189
|
downsample: int = 2,
|
123
|
-
|
124
|
-
|
190
|
+
) -> dict[
|
191
|
+
str,
|
192
|
+
Alignment
|
193
|
+
| tuple[TreeApproximationTemplate, TreeApproximationTemplate],
|
194
|
+
]:
|
125
195
|
"""Compute the unordered tree edit distance from Zhang 1996 between the trees spawned
|
126
196
|
by two nodes `n1` from lineagetree1 and `n2` lineagetree2. The topology of the trees
|
127
197
|
are compared and the matching cost is given by the function delta (see edist doc for
|
128
198
|
more information).The distance is normed by the function norm that takes the two list
|
129
199
|
of nodes spawned by the trees `n1` and `n2`.
|
130
200
|
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
|
136
|
-
|
137
|
-
|
138
|
-
|
139
|
-
|
201
|
+
Parameters
|
202
|
+
----------
|
203
|
+
n1 : int
|
204
|
+
Node of the first Lineagetree
|
205
|
+
embryo_1 : str
|
206
|
+
The key/name of the first Lineagetree
|
207
|
+
n2 : int
|
208
|
+
The key/name of the first Lineagetree
|
209
|
+
embryo_2 : str
|
210
|
+
Node of the second Lineagetree
|
211
|
+
end_time1 : int, optional
|
212
|
+
The final time point the comparison algorithm will take into account for the first dataset.
|
213
|
+
If None or not provided all nodes will be taken into account.
|
214
|
+
end_time2 : int, optional
|
215
|
+
The final time point the comparison algorithm will take into account for the second dataset.
|
216
|
+
If None or not provided all nodes will be taken into account.
|
217
|
+
style : {"simple", "normalized_simple", "full", "downsampled"} or TreeApproximationTemplate subclass, default="simple"
|
218
|
+
The approximation used to calculate the tree.
|
219
|
+
norm : {"max","sum", "None"}, default="max"
|
220
|
+
The normalization method used (Not important for this function)
|
221
|
+
downsample : int, default==2
|
222
|
+
The downsample factor for the downsampled tree approximation.
|
223
|
+
Used only when `style="downsampled"`.
|
140
224
|
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
|
150
|
-
|
151
|
-
|
152
|
-
|
153
|
-
|
154
|
-
|
155
|
-
|
156
|
-
|
157
|
-
|
158
|
-
|
159
|
-
]
|
225
|
+
Returns
|
226
|
+
-------
|
227
|
+
dict mapping str to Alignment or tuple of [TreeApproximationTemplate, TreeApproximationTemplate]
|
228
|
+
- 'alignment'
|
229
|
+
The alignment between the nodes by the subtrees spawned by the nodes n1,n2 and the normalization function.`
|
230
|
+
- 'trees'
|
231
|
+
A list of the two trees that have been mapped to each other.
|
232
|
+
"""
|
233
|
+
if (
|
234
|
+
self[embryo_1].time_resolution <= 0
|
235
|
+
or self[embryo_2].time_resolution <= 0
|
236
|
+
):
|
237
|
+
raise Warning("Resolution cannot be <=0 ")
|
238
|
+
parameters = (
|
239
|
+
(end_time1, end_time2),
|
240
|
+
convert_style_to_number(style, downsample),
|
241
|
+
)
|
242
|
+
n1_embryo, n2_embryo = sorted(
|
243
|
+
((n1, embryo_1), (n2, embryo_2)), key=lambda x: x[0]
|
244
|
+
)
|
245
|
+
self._comparisons.setdefault(parameters, {})
|
246
|
+
if isinstance(style, str):
|
247
|
+
tree = tree_style[style].value
|
248
|
+
elif issubclass(style, TreeApproximationTemplate):
|
249
|
+
tree = style
|
160
250
|
else:
|
161
|
-
|
162
|
-
|
163
|
-
|
164
|
-
]
|
165
|
-
|
251
|
+
raise Warning("Use a valid approximation.")
|
252
|
+
time_res = tree.handle_resolutions(
|
253
|
+
time_resolution1=self.lineagetrees[embryo_1]._time_resolution,
|
254
|
+
time_resolution2=self.lineagetrees[embryo_2]._time_resolution,
|
255
|
+
gcd=self.gcd,
|
256
|
+
downsample=downsample,
|
257
|
+
)
|
166
258
|
tree1 = tree(
|
167
259
|
lT=self.lineagetrees[embryo_1],
|
168
260
|
downsample=downsample,
|
@@ -184,8 +276,11 @@ class lineageTreeManager:
|
|
184
276
|
nodes1, adj1, corres1 = tree1.edist
|
185
277
|
nodes2, adj2, corres2 = tree2.edist
|
186
278
|
if len(nodes1) == len(nodes2) == 0:
|
187
|
-
|
188
|
-
|
279
|
+
self._comparisons[parameters][(n1_embryo, n2_embryo)] = {
|
280
|
+
"alignment": (),
|
281
|
+
"trees": (),
|
282
|
+
}
|
283
|
+
return self._comparisons[parameters][(n1_embryo, n2_embryo)]
|
189
284
|
delta_tmp = partial(
|
190
285
|
delta,
|
191
286
|
corres1=corres1,
|
@@ -193,6 +288,590 @@ class lineageTreeManager:
|
|
193
288
|
corres2=corres2,
|
194
289
|
times2=times2,
|
195
290
|
)
|
196
|
-
|
197
|
-
|
291
|
+
btrc = uted.uted_backtrace(nodes1, adj1, nodes2, adj2, delta=delta_tmp)
|
292
|
+
|
293
|
+
self._comparisons[parameters][(n1_embryo, n2_embryo)] = {
|
294
|
+
"alignment": btrc,
|
295
|
+
"trees": (tree1, tree2),
|
296
|
+
}
|
297
|
+
return self._comparisons[parameters][(n1_embryo, n2_embryo)]
|
298
|
+
|
299
|
+
def __calculate_distance_of_sub_tree(
    self,
    node1: int,
    lT1: lineageTree,
    node2: int,
    lT2: lineageTree,
    alignment: Alignment,
    corres1: dict,
    corres2: dict,
    delta_tmp: Callable,
    norm: Callable,
    norm1: int | float,
    norm2: int | float,
) -> float:
    """Distance between the subtrees rooted at ``node1`` and ``node2``, read off an existing alignment.

    The distance is NOT recomputed from scratch. Every entry of ``alignment``
    contributes its ``delta_tmp`` cost when its left side maps into the
    subtree of ``node1`` or its right side maps into the subtree of
    ``node2``. The sum is divided by ``norm([norm1, norm2])``.

    TODO: this approximation is expected to change.

    Parameters
    ----------
    node1 : int
        The root of the first subtree.
    lT1 : lineageTree
        The dataset in which the first lineage exists.
    node2 : int
        The root of the second subtree.
    lT2 : lineageTree
        The dataset in which the second lineage exists.
    alignment : Alignment
        The alignment computed for the enclosing comparison.
    corres1 : dict
        Correspondence from alignment indices to node ids of the first lineage.
    corres2 : dict
        Correspondence from alignment indices to node ids of the second lineage.
    delta_tmp : Callable
        The cost function used for the comparison.
    norm : Callable
        Normalisation applied to ``[norm1, norm2]``.
    norm1 : int or float
        The norm of the first subtree.
    norm2 : int or float
        The norm of the second subtree.

    Returns
    -------
    float
        The normalised cost of the alignment restricted to the two subtrees.
    """
    sub_tree_1 = set(lT1.get_subtree_nodes(node1))
    sub_tree_2 = set(lT2.get_subtree_nodes(node2))
    res = 0
    for m in alignment:
        # -1 marks a missing side of an alignment entry; the -1 default
        # presumably never matches a real node id (TODO confirm).
        if (
            corres1.get(m._left, -1) in sub_tree_1
            or corres2.get(m._right, -1) in sub_tree_2
        ):
            res += delta_tmp(
                m._left if m._left != -1 else None,
                m._right if m._right != -1 else None,
            )
    return res / norm([norm1, norm2])
|
361
|
+
|
362
|
+
def clear_comparisons(self):
    """Drop every cached comparison stored on the manager."""
    # Emptied in place, so existing references to the cache see the reset.
    self._comparisons.clear()
|
364
|
+
|
365
|
+
def cross_lineage_edit_distance(
    self,
    n1: int,
    embryo_1: str,
    n2: int,
    embryo_2: str,
    end_time1: int | None = None,
    end_time2: int | None = None,
    norm: Literal["max", "sum", None] = "max",
    style: (
        Literal["simple", "normalized_simple", "full", "downsampled"]
        | type[TreeApproximationTemplate]
    ) = "simple",
    downsample: int = 2,
    return_norms: bool = False,
) -> float | tuple[float, tuple[float, float]]:
    """Compute the normalised unordered tree edit distance between two subtrees.

    Uses the unordered tree edit distance of Zhang (1996), as implemented by
    ``edist``, between the subtree spawned by ``n1`` in ``embryo_1`` and the
    subtree spawned by ``n2`` in ``embryo_2``. The topology of the trees is
    compared, and the matching cost is given by the ``delta`` function of the
    chosen tree approximation (see the edist documentation). Four built-in
    approximations are available; users may also supply their own
    ``TreeApproximationTemplate`` subclass.

    Parameters
    ----------
    n1 : int
        Id of the first node to compare.
    embryo_1 : str
        Name of the first lineage tree (a key of ``self.lineagetrees``).
    n2 : int
        Id of the second node to compare.
    embryo_2 : str
        Name of the second lineage tree (a key of ``self.lineagetrees``).
    end_time1 : int, optional
        The final time point taken into account for the first dataset.
        If None, all nodes are taken into account.
    end_time2 : int, optional
        The final time point taken into account for the second dataset.
        If None, all nodes are taken into account.
    norm : {"max", "sum", None}, default="max"
        Normalisation applied to the two subtree norms. ``None`` divides by 1.
    style : {"simple", "normalized_simple", "full", "downsampled"} or TreeApproximationTemplate subclass, default="simple"
        Which tree approximation is used for the comparison.
    downsample : int, default=2
        The downsample factor for the downsampled tree approximation.
        Used only when ``style="downsampled"``.
    return_norms : bool, default=False
        If True, return the raw cost together with the two subtree norms
        instead of the normalised distance (mainly used by the napari plugin).

    Returns
    -------
    float
        The edit distance divided by ``norm`` applied to the two subtree
        norms. This is the default.
    tuple of (float, tuple of (float, float))
        ``(cost, (norm1, norm2))`` when ``return_norms`` is True.
        Two empty subtrees give a cost of 0.0 and norms of (0.0, 0.0).

    Raises
    ------
    ValueError
        If ``norm`` is not a supported normalisation method.
    """
    # Validate cheaply before running the (potentially expensive) alignment.
    if norm not in self.norm_dict:
        raise ValueError(
            "Select a viable normalization method (max, sum, None)"
        )
    # NOTE(review): the backtrace caches its result under (node, embryo)
    # pairs sorted by node id, but the stored trees keep the orientation of
    # the call that produced them. A cached entry may therefore be reversed
    # relative to this call, so it is recomputed here (and re-cached).
    tmp = self.__cross_lineage_edit_backtrace(
        n1,
        embryo_1,
        n2,
        embryo_2,
        end_time1,
        end_time2,
        style,
        norm,
        downsample,
    )
    n_cached = sum(len(group) for group in self._comparisons.values())
    if n_cached > 100:
        warnings.warn(
            "More than 100 comparisons are saved, use clear_comparisons() to delete them.",
            stacklevel=2,
        )
    if not tmp["trees"]:
        # Both subtrees are empty up to the requested end times: nothing to edit.
        return (0.0, (0.0, 0.0)) if return_norms else 0.0
    btrc = tmp["alignment"]
    tree1, tree2 = tmp["trees"]
    _, times1 = tree1.tree
    _, times2 = tree2.tree
    nodes1, _, corres1 = tree1.edist
    nodes2, _, corres2 = tree2.edist
    delta_tmp = partial(
        tree1.delta,
        corres1=corres1,
        corres2=corres2,
        times1=times1,
        times2=times2,
    )
    cost = btrc.cost(nodes1, nodes2, delta_tmp)
    norm_values = (tree1.get_norm(n1), tree2.get_norm(n2))
    if return_norms:
        return cost, norm_values
    return cost / self.norm_dict[norm](norm_values)
|
490
|
+
|
491
|
+
def plot_tree_distance_graphs(
    self,
    n1: int,
    embryo_1: str,
    n2: int,
    embryo_2: str,
    end_time1: int | None = None,
    end_time2: int | None = None,
    norm: Literal["max", "sum", None] = "max",
    style: (
        Literal["simple", "normalized_simple", "full", "downsampled"]
        | type[TreeApproximationTemplate]
    ) = "simple",
    downsample: int = 2,
    colormap: str = "cool",
    default_color: str = "black",
    size: float = 10,
    lw: float = 0.3,
    ax: np.ndarray | None = None,
) -> tuple[plt.Figure, plt.Axes]:
    """Plot two compared subtrees, coloured by the quality of their matching.

    For each matched pair of nodes, the normalised distance between the
    subtrees they root is mapped through ``colormap`` on a [0, 1] scale. It is
    applied to the matched node and to the last node of its chain in both
    lineages. Nodes without a computed colour are left to ``plot_subtree``
    via ``default_color``.

    Parameters
    ----------
    n1 : int
        Id of the first node to compare.
    embryo_1 : str
        The name of the first embryo.
    n2 : int
        Id of the second node to compare.
    embryo_2 : str
        The name of the second embryo.
    end_time1 : int, optional
        The final time point taken into account for the first dataset.
        If None, all nodes are taken into account.
    end_time2 : int, optional
        The final time point taken into account for the second dataset.
        If None, all nodes are taken into account.
    norm : {"max", "sum", None}, default="max"
        The normalisation method to use.
    style : {"simple", "normalized_simple", "full", "downsampled"} or TreeApproximationTemplate subclass, default="simple"
        Which tree approximation is used for the comparison.
    downsample : int, default=2
        The downsample factor for the downsampled tree approximation.
        Used only when ``style="downsampled"``.
    colormap : str, default="cool"
        Name of the matplotlib colormap used for matched nodes.
    default_color : str, default="black"
        Colour passed to ``plot_subtree`` for nodes without a computed colour.
    size : float, default=10
        The size of the nodes.
    lw : float, default=0.3
        The width of the edges.
    ax : np.ndarray, optional
        A pair of axes: ``ax[0]`` receives the first subtree and ``ax[1]``
        the second. If None, a new 1x2 set of axes is created.

    Returns
    -------
    plt.Figure
        The figure owning ``ax[0]``.
    plt.Axes
        The axes used for the two plots.

    Raises
    ------
    Warning
        If ``norm`` is not a supported normalisation method.
    ValueError
        If both subtrees are empty up to the requested end times.
    """
    # Validate cheaply before running the (potentially expensive) alignment.
    if norm not in self.norm_dict:
        raise Warning(
            "Select a viable normalization method (max, sum, None)"
        )
    # NOTE(review): cached backtraces are keyed by (node, embryo) pairs sorted
    # by node id, while the stored trees keep the orientation of the call that
    # produced them. A cached entry may be reversed for this call, so it is
    # recomputed here (the backtrace re-caches it).
    tmp = self.__cross_lineage_edit_backtrace(
        n1,
        embryo_1,
        n2,
        embryo_2,
        end_time1,
        end_time2,
        style,
        norm,
        downsample,
    )
    if not tmp["trees"]:
        raise ValueError(
            "Nothing to plot: both subtrees are empty up to the requested end times."
        )
    btrc = tmp["alignment"]
    tree1, tree2 = tmp["trees"]
    lT1, lT2 = tree1.lT, tree2.lT
    _, times1 = tree1.tree
    _, times2 = tree2.tree
    *_, corres1 = tree1.edist
    *_, corres2 = tree2.edist
    delta_tmp = partial(
        tree1.delta,
        corres1=corres1,
        corres2=corres2,
        times1=times1,
        times2=times2,
    )
    normalise = self.norm_dict[norm]

    def sub_tree_distance(node_1, node_2):
        # Normalised distance of the alignment restricted to the two subtrees.
        return self.__calculate_distance_of_sub_tree(
            node_1,
            lT1,
            node_2,
            lT2,
            btrc,
            corres1,
            corres2,
            delta_tmp,
            normalise,
            tree1.get_norm(node_1),
            tree2.get_norm(node_2),
        )

    colors1 = {}
    colors2 = {}
    if style not in ("full", "downsampled"):
        for m in btrc:
            # -1 marks an unmatched side; only matched pairs are coloured.
            if m._left == -1 or m._right == -1:
                continue
            # Index instead of pop() so the chain returned by the lineage
            # tree is never mutated.
            chain1 = lT1.get_chain_of_node(corres1[m._left])
            chain2 = lT2.get_chain_of_node(corres2[m._right])
            node_1, l_node_1 = chain1[0], chain1[-1]
            node_2, l_node_2 = chain2[0], chain2[-1]
            value = sub_tree_distance(node_1, node_2)
            colors1[node_1] = colors1[l_node_1] = value
            colors2[node_2] = colors2[l_node_2] = value
    else:
        for m in btrc:
            if m._left == -1 or m._right == -1:
                continue
            node_1 = lT1.get_chain_of_node(corres1[m._left])[0]
            node_2 = lT2.get_chain_of_node(corres2[m._right])[0]
            chain1 = lT1.get_chain_of_node(node_1)
            chain2 = lT2.get_chain_of_node(node_2)
            # NOTE(review): ``and`` binds tighter than ``or``. This keeps the
            # original grouping, in which the "not yet coloured" guard only
            # applies to the second operand. Confirm whether
            # ``(first or second) and guard`` was intended.
            if chain1[0] == node_1 or (
                chain2[0] == node_2
                and (node_1 not in colors1 or node_2 not in colors2)
            ):
                value = sub_tree_distance(node_1, node_2)
                colors1[node_1] = value
                colors2[node_2] = value
                colors1[chain1[-1]] = value
                colors2[chain2[-1]] = value
    if ax is None:
        _, ax = plt.subplots(nrows=1, ncols=2)
    cmap = colormaps[colormap]
    c_norm = mcolors.Normalize(0, 1)
    colors1 = {c: cmap(c_norm(v)) for c, v in colors1.items()}
    colors2 = {c: cmap(c_norm(v)) for c, v in colors2.items()}
    lT1.plot_subtree(
        lT1.get_ancestor_at_t(n1),
        end_time=end_time1,
        size=size,
        color_of_nodes=colors1,
        color_of_edges=colors1,
        default_color=default_color,
        lw=lw,
        ax=ax[0],
    )
    lT2.plot_subtree(
        lT2.get_ancestor_at_t(n2),
        end_time=end_time2,
        size=size,
        color_of_nodes=colors2,
        color_of_edges=colors2,
        default_color=default_color,
        lw=lw,
        ax=ax[1],
    )
    return ax[0].get_figure(), ax
|
718
|
+
|
719
|
+
def labelled_mappings(
    self,
    n1: int,
    embryo_1: str,
    n2: int,
    embryo_2: str,
    end_time1: int | None = None,
    end_time2: int | None = None,
    norm: Literal["max", "sum", None] = "max",
    style: (
        Literal["simple", "normalized_simple", "full", "downsampled"]
        | type[TreeApproximationTemplate]
    ) = "simple",
    downsample: int = 2,
) -> dict[str, list[tuple]]:
    """Return the labels (or ids) of the matched and unmatched nodes of a comparison.

    Parameters
    ----------
    n1 : int
        Id of the first node to compare.
    embryo_1 : str
        The name of the first lineage.
    n2 : int
        Id of the second node to compare.
    embryo_2 : str
        The name of the second lineage.
    end_time1 : int, optional
        The final time point taken into account for the first dataset.
        If None, all nodes are taken into account.
    end_time2 : int, optional
        The final time point taken into account for the second dataset.
        If None, all nodes are taken into account.
    norm : {"max", "sum", None}, default="max"
        The normalisation method (validated and forwarded to the alignment).
    style : {"simple", "normalized_simple", "full", "downsampled"} or TreeApproximationTemplate subclass, default="simple"
        Which tree approximation is used for the comparison.
    downsample : int, default=2
        The downsample factor for the downsampled tree approximation.
        Used only when ``style="downsampled"``.

    Returns
    -------
    dict mapping str to lists of tuples
        - 'matched': ``(label_in_lineage_1, label_in_lineage_2)`` pairs of the
          aligned nodes.
        - 'unmatched': ``(label, lineage_name)`` pairs of the nodes left
          unaligned.

        A node without a label is reported by its id. Both lists are empty
        when both subtrees are empty up to the requested end times.

    Raises
    ------
    Warning
        If ``norm`` is not a supported normalisation method.
    """
    # Validate cheaply before running the (potentially expensive) alignment.
    if norm not in self.norm_dict:
        raise Warning(
            "Select a viable normalization method (max, sum, None)"
        )
    # NOTE(review): cached backtraces are keyed by (node, embryo) pairs sorted
    # by node id, while the stored trees keep the orientation of the call that
    # produced them. A cached entry may be reversed for this call, so it is
    # recomputed here (the backtrace re-caches it).
    tmp = self.__cross_lineage_edit_backtrace(
        n1,
        embryo_1,
        n2,
        embryo_2,
        end_time1,
        end_time2,
        style,
        norm,
        downsample,
    )
    if not tmp["trees"]:
        return {"matched": [], "unmatched": []}
    btrc = tmp["alignment"]
    tree1, tree2 = tmp["trees"]
    lT1, lT2 = tree1.lT, tree2.lT
    *_, corres1 = tree1.edist
    *_, corres2 = tree2.edist
    # Except for the "full"/"downsampled" styles, matched nodes are reported
    # by the first node of their chain.
    use_chain_start = style not in ("full", "downsampled")
    matched = []
    unmatched = []
    for m in btrc:
        # -1 marks the missing side of an alignment entry.
        if m._left != -1 and m._right != -1:
            node_1 = corres1[m._left]
            node_2 = corres2[m._right]
            if use_chain_start:
                # Indexing (rather than pop()) leaves the returned chain intact.
                node_1 = lT1.get_chain_of_node(node_1)[0]
                node_2 = lT2.get_chain_of_node(node_2)[0]
            matched.append(
                (
                    lT1.labels.get(node_1, node_1),
                    lT2.labels.get(node_2, node_2),
                )
            )
        elif m._left != -1:
            tmp_node = lT1.get_chain_of_node(corres1.get(m._left, "-"))[0]
            unmatched.append((lT1.labels.get(tmp_node, tmp_node), lT1.name))
        else:
            tmp_node = lT2.get_chain_of_node(corres2.get(m._right, "-"))[0]
            unmatched.append((lT2.labels.get(tmp_node, tmp_node), lT2.name))
    return {"matched": matched, "unmatched": unmatched}
|