LineageTree 1.8.0__py3-none-any.whl → 2.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,233 @@
1
+ import unittest
2
+
3
+ import edist.uted as uted
4
+
5
+
6
+ # the purpose of this test is to reproduce results mentioned in Guignard et al. (2020) with the tools used
7
+ # for the results in the publication
8
class TestTreex(unittest.TestCase):
    """Reproduce results of Guignard et al. (2020) with the tool used for the
    publication (``edist.uted.uted``).

    Each tree is described by a list of node labels, an adjacency list
    (``adj[i]`` holds the indices of the children of node ``i``) and an
    attribute dictionary keyed by node label that feeds the local cost.
    """

    @staticmethod
    def _make_costs(attributes1, attributes2):
        """Build the local cost functions handed to ``uted.uted``.

        ``x`` is looked up in ``attributes1`` and ``y`` in ``attributes2``;
        ``None`` stands for a missing node (insertion or deletion).

        Returns
        -------
        tuple of two callables
            ``local_cost``: the attribute itself for an insertion/deletion,
            ``|a1 - a2|`` for a relabelling.
            ``local_cost_normalized``: 1 for an insertion/deletion,
            ``|a1 - a2| / (a1 + a2)`` for a relabelling.
        """

        def local_cost(x, y):
            if x is None and y is None:
                return 0
            if x is None:
                return attributes2[y]
            if y is None:
                return attributes1[x]
            return abs(attributes1[x] - attributes2[y])

        def local_cost_normalized(x, y):
            if x is None and y is None:
                return 0
            if x is None or y is None:
                return 1
            a1 = attributes1[x]
            a2 = attributes2[y]
            return abs(a1 - a2) / (a1 + a2)

        return local_cost, local_cost_normalized

    # simple test case to test the implementation of the Zhang edit distance
    def test_edist_zhang_edit_distance_tree1_tree2(self):
        """Same shape, root and leaf ``c`` values swapped: identity mapping
        costs |20 - 30| + 0 + |30 - 20| = 20."""
        tree1_nodes = ["a", "b", "c"]
        tree1_adj = [[1, 2], [], []]
        tree1_attributes = {"a": 20, "b": 10, "c": 30}
        tree2_nodes = ["a", "b", "c"]
        tree2_adj = [[1, 2], [], []]
        tree2_attributes = {"a": 30, "b": 10, "c": 20}

        local_cost, _ = self._make_costs(tree1_attributes, tree2_attributes)

        edist_result = uted.uted(
            tree1_nodes, tree1_adj, tree2_nodes, tree2_adj, local_cost
        )
        self.assertEqual(20, edist_result)

    # Guignard et al. (2020) Fig. S23
    # https://www.science.org/doi/suppl/10.1126/science.aar5663/suppl_file/aar5663_guignard_sm.pdf
    def test_edist_zhang_edit_distance_tree_guignard_t1_tree_guignard_t2(self):
        """Trees of Fig. S23: absolute distance 22, normalized 4 / 90."""
        t1_nodes = [0, 1, 2, 3, 4, 5, 6, 7, 8]
        t1_adj = [[1, 4], [2, 3], [], [], [5, 6], [], [7, 8], [], []]
        t1_attributes = {
            0: 1,
            1: 1,
            2: 1,
            3: 1,
            4: 9,
            5: 10,
            6: 10,
            7: 10,
            8: 10,
        }
        t2_nodes = [0, 1, 2, 3, 4, 5, 6, 7, 8]
        t2_adj = [[1, 6], [2, 5], [3, 4], [], [], [], [7, 8], [], []]
        t2_attributes = {
            0: 1,
            1: 1,
            2: 2,
            3: 1,
            4: 1,
            5: 1,
            6: 10,
            7: 10,
            8: 10,
        }

        local_cost, local_cost_normalized = self._make_costs(
            t1_attributes, t2_attributes
        )

        self.assertEqual(
            22, uted.uted(t1_nodes, t1_adj, t2_nodes, t2_adj, local_cost)
        )
        # NB: the publication deliberately does not illustrate the optimal tree
        # edit distance, because the primary goal of Fig. S23 is to explain all
        # possible edit operations in one figure.
        # The normalized distance is divided by the sum of all attributes of
        # both trees (53 + 37 = 90). Floats: compare with a tolerance.
        self.assertAlmostEqual(
            4 / 90,
            uted.uted(
                t1_nodes, t1_adj, t2_nodes, t2_adj, local_cost_normalized
            )
            / (sum(t1_attributes.values()) + sum(t2_attributes.values())),
            places=12,
        )

    # a8.0007* of Pm01
    # a8.0008* of Pm01
    # <a href="https://figshare.com/projects/Phallusiamammillata_embryonic_development/64301">Phallusia mammillata
    # embryonic development data</a>
    def test_edist_uted_Pm01a80007_a80008(self):
        """Lineages a8.0007* and a8.0008* of embryo Pm01."""
        t_a80007_nodes = list(range(17))
        t_a80007_adj = [
            [1, 8],
            [2, 5],
            [3, 4],
            [],
            [],
            [6, 7],
            [],
            [],
            [9, 12],
            [10, 11],
            [],
            [],
            [13, 16],
            [14, 15],
            [],
            [],
            [],
        ]
        t_a80007_attributes = {
            0: 36,
            1: 56,
            2: 72,
            3: 6,
            4: 6,
            5: 66,
            6: 12,
            7: 12,
            8: 36,
            9: 49,
            10: 49,
            11: 49,
            12: 46,
            13: 50,
            14: 2,
            15: 2,
            16: 52,
        }
        t_a80008_nodes = list(range(15))
        t_a80008_adj = [
            [1, 8],
            [2, 5],
            [3, 4],
            [],
            [],
            [6, 7],
            [],
            [],
            [9, 12],
            [10, 11],
            [],
            [],
            [13, 14],
            [],
            [],
        ]
        t_a80008_attributes = {
            0: 38,
            1: 39,
            2: 45,
            3: 48,
            4: 48,
            5: 50,
            6: 43,
            7: 43,
            8: 46,
            9: 66,
            10: 20,
            11: 20,
            12: 66,
            13: 20,
            14: 20,
        }

        local_cost, local_cost_normalized = self._make_costs(
            t_a80007_attributes, t_a80008_attributes
        )

        self.assertEqual(
            89,
            uted.uted(
                t_a80007_nodes,
                t_a80007_adj,
                t_a80008_nodes,
                t_a80008_adj,
                local_cost,
            ),
        )
        # ~0.0033; NB: the publication reports 0.04 for this pair (cf. Fig. 3B).
        # The denominator is the sum of all attributes of both trees
        # (601 + 612 = 1213). Floats: compare with a tolerance.
        self.assertAlmostEqual(
            3.9974005474699665 / 1213,
            uted.uted(
                t_a80007_nodes,
                t_a80007_adj,
                t_a80008_nodes,
                t_a80008_adj,
                local_cost_normalized,
            )
            / (
                sum(t_a80007_attributes.values())
                + sum(t_a80008_attributes.values())
            ),
            places=12,
        )
@@ -0,0 +1,488 @@
1
+ import warnings
2
+ from abc import ABC, abstractmethod
3
+ from enum import Enum
4
+
5
+ import numpy as np
6
+
7
+ from LineageTree import lineageTree
8
+
9
+
10
class TreeApproximationTemplate(ABC):
    """Template class to produce different tree styles to compare lineageTrees.

    To add a new style, inherit from this class (or one of its children) and
    add it to the ``tree_style`` enum, or use it directly in the function
    being called.
    The main products of this class are:
        - a tree constructor (``get_tree``) that produces one dictionary
          containing arbitrary unique labels and one dictionary containing
          the duration of each node;
        - a delta function that handles the cost of comparing nodes to each
          other;
        - a normalization function that returns the length of the tree or
          any integer.
    """

    def __init__(
        self,
        lT: lineageTree,
        root: int,
        downsample: int | None = None,
        end_time: int | None = None,
        time_scale: int = 1,
    ):
        """Build the approximated tree spawned by ``root``.

        Parameters
        ----------
        lT : lineageTree
            The lineage tree to approximate.
        root : int
            The node spawning the subtree to approximate.
        downsample : int, optional
            Downsampling rate, only used by styles that downsample.
        end_time : int, optional
            Last time point taken into account; ``lT.t_e`` when not given.
        time_scale : int, default=1
            Scaling factor used by the styles; truncated to an int.
            ``None`` is treated as 1.

        Raises
        ------
        ValueError
            If ``time_scale`` is smaller than 1 once truncated to an int.
        """
        self.lT: lineageTree = lT
        # Counter for get_next_id: new ids start above every id of lT.
        self.internal_ids = max(self.lT.nodes)
        self.root: int = root
        self.downsample: int | None = downsample
        # `is not None` so that an explicit end_time of 0 is honoured.
        self.end_time: int = (
            end_time if end_time is not None else self.lT.t_e
        )
        if time_scale is None:
            time_scale = 1
        # Validate the value actually stored: e.g. 0.5 would otherwise be
        # truncated to a time_scale of 0 without raising.
        if int(time_scale) < 1:
            raise ValueError("Please use a valid time_scale (Larger than 0)")
        self.time_scale: int = int(time_scale)
        self.tree: tuple = self.get_tree()
        self.edist = self._edist_format(self.tree[0])

    def get_next_id(self) -> int:
        """Return a fresh node id, larger than every id handed out so far."""
        self.internal_ids += 1
        return self.internal_ids

    @staticmethod
    @abstractmethod
    def handle_resolutions(
        time_resolution1: float | int,
        time_resolution2: float | int,
        gcd: int,
        downsample: int,
    ) -> tuple[int | float, int | float]:
        """Handle different time resolutions.

        Parameters
        ----------
        time_resolution1 : int or float
            Time resolution of the first dataset. (Extracted from lT._time_resolution)
        time_resolution2 : int or float
            Time resolution of the second dataset. (Extracted from lT._time_resolution)
        gcd : int
            Greatest common divisor of the two time resolutions (styles use
            it to compute their least common multiple).
        downsample : int
            Requested downsampling rate (only relevant for styles that
            downsample).

        Returns
        -------
        int or float
            The time resolution fix for the first dataset
        int or float
            The time resolution fix for the second dataset
        """

    @abstractmethod
    def get_tree(self) -> tuple[dict, dict]:
        """
        Get a tree version of the tree spawned by the node `self.root`.

        Returns
        -------
        dict mapping an int to a list of int
            an adjacency dictionary where the ids are the ids of the
            cells in the original tree at their first time point
            (except for the cell `self.root` if it was not the first time point).
        dict mapping an int to a float
            life time duration of the key cell `m`
        """

    @abstractmethod
    def delta(
        self,
        x: int,
        y: int,
        corres1: dict[int, int],
        corres2: dict[int, int],
        times1: dict[int, float],
        times2: dict[int, float],
    ) -> int | float:
        """The distance of two nodes inside a tree. Behaves like a staticmethod.
        The corres1/2 and time1/2 should always be provided and will be handled
        accordingly by the specific delta of each tree style.

        Parameters
        ----------
        x : int
            The first node to compare, takes the names provided by the edist.
        y : int
            The second node to compare, takes the names provided by the edist
        corres1 : dict
            Dictionary mapping node1 ids to the corresponding id in the original tree.
        corres2 : dict
            Dictionary mapping node2 ids to the corresponding id in the original tree.
        times1 : dict
            The dictionary of the chain lengths of the tree that n1 is spawned from.
        times2 : dict
            The dictionary of the chain lengths of the tree that n2 is spawned from.

        Returns
        -------
        int or float
            The distance between 'x' and 'y': the length of the node for an
            insertion/deletion, the absolute length difference otherwise.
        """
        if x is None and y is None:
            return 0
        if x is None:
            return times2[corres2[y]]
        if y is None:
            return times1[corres1[x]]
        len_x = times1[corres1[x]]
        len_y = times2[corres2[y]]
        return np.abs(len_x - len_y)

    @abstractmethod
    def get_norm(self, root: int) -> int | float:
        """
        Returns the valid value for normalizing the edit distance.

        Parameters
        ----------
        root : int
            The starting node of the subtree.

        Returns
        -------
        int or float
            The number of nodes of each tree according to each style, or the
            sum of the length of all the nodes in a tree.
        """

    def _edist_format(
        self, adj_dict: dict
    ) -> tuple[list, list[list], dict[int, int]]:
        """Convert an adjacency dictionary into the list format used by edist.

        Nodes are renumbered 0..n-1 in depth-first pre-order, children being
        visited in the order they appear in ``adj_dict``.

        Parameters
        ----------
        adj_dict : dict
            The adjacency dictionary produced by ``get_tree``.

        Returns
        -------
        list of int
            The new node ids.
        list of list of int
            ``adj_list[i]`` holds the new ids of the children of node ``i``.
        dict mapping int to int
            Correspondence from the new ids to the ids of ``adj_dict``.
        """
        children = {child for kids in adj_dict.values() for child in kids}
        roots = set(adj_dict).difference(children)
        nid2list = {}
        list2nid = {}
        nodes = []
        for r in roots:
            stack = [r]
            while stack:
                curr = stack.pop()
                new_id = len(nodes)
                nid2list[curr] = new_id
                list2nid[new_id] = curr
                nodes.append(new_id)
                # Push children reversed so the first child is popped first
                # (pre-order), without the O(n) list prepend.
                stack.extend(reversed(adj_dict.get(curr, [])))
        adj_list = [
            [nid2list[d] for d in adj_dict.get(list2nid[_id], [])]
            for _id in nodes
        ]
        return nodes, adj_list, list2nid
168
+
169
+
170
class mini_tree(TreeApproximationTemplate):
    """Each branch is converted to a node of length 1. It is useful for
    comparing synchronously developing cells and is extremely fast.
    Mainly used for testing.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    @staticmethod
    def handle_resolutions(
        time_resolution1: float | int,
        time_resolution2: float | int,
        gcd,
        downsample: int,
    ) -> tuple[int | float, int | float]:
        """Every branch has length 1, so resolutions are irrelevant."""
        return (1, 1)

    def get_tree(self):
        """Build the branch-level tree spawned by ``self.root``.

        Returns
        -------
        dict mapping int to list of int
            Adjacency between branches; each branch is keyed by the id its
            traversal started from, and leaves map to an empty list.
        None
            This style does not use node durations.
        """
        if self.end_time is None:
            self.end_time = self.lT.t_e
        out_dict = {}
        self.times = {}
        to_do = [self.root]
        while to_do:
            current = to_do.pop()
            cycle = np.array(self.lT.get_successors(current))
            cycle_times = np.array([self.lT.time[c] for c in cycle])
            cycle = cycle[cycle_times <= self.end_time]
            if cycle.size:
                _next = list(self.lT.successor[cycle[-1]])
                # Same rule as simple_tree: a division whose last node sits at
                # end_time would add daughters lying after end_time.
                if 1 < len(_next) and self.lT.time[cycle[-1]] < self.end_time:
                    out_dict[current] = _next
                    to_do.extend(_next)
                else:
                    out_dict[current] = []
        self.length = len(out_dict)
        return out_dict, None

    def get_norm(self, root) -> int:
        """Number of chains (branches) of the subtree of ``root`` up to end_time."""
        return len(
            self.lT.get_all_chains_of_subtree(root, end_time=self.end_time)
        )

    def delta(self, x, y, corres1, corres2, times1, times2):
        """Unit cost: 1 for inserting or deleting a branch, 0 for a match."""
        if x is None and y is None:
            return 0
        return 1 if x is None or y is None else 0
224
+
225
+
226
class simple_tree(TreeApproximationTemplate):
    """Collapse each branch into one node whose length is the number of time
    points of the cell's life cycle, multiplied by ``time_scale``.

    Fast but imprecise, especially for small trees (trees should preferably
    be at least 100 time points high). Use with CAUTION.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    @staticmethod
    def handle_resolutions(
        time_resolution1: float | int,
        time_resolution2: float | int,
        gcd: int,
        downsample: int,
    ) -> tuple[int | float, int | float]:
        """Keep each dataset's own time resolution."""
        return time_resolution1, time_resolution2

    def get_tree(self) -> tuple[dict, dict]:
        """Build the branch-level tree spawned by ``self.root``.

        Returns
        -------
        dict mapping int to list of int
            Adjacency between branches; leaves map to an empty list.
        dict mapping int to int
            Length of each branch (time points up to end_time times
            ``time_scale``).
        """
        if self.end_time is None:
            self.end_time = self.lT.t_e
        adjacency = {}
        self.times = {}
        pending = [self.root]
        while pending:
            branch_start = pending.pop()
            chain = np.array(self.lT.get_successors(branch_start))
            chain_times = np.array([self.lT.time[node] for node in chain])
            kept = chain[chain_times <= self.end_time]
            if not kept.size:
                continue
            last = kept[-1]
            daughters = list(self.lT.successor[last])
            # A division only counts if it happens before end_time.
            divides_in_window = (
                len(daughters) > 1 and self.lT.time[last] < self.end_time
            )
            if divides_in_window:
                adjacency[branch_start] = daughters
                pending.extend(daughters)
            else:
                adjacency[branch_start] = []
            self.times[branch_start] = len(kept) * self.time_scale
        return adjacency, self.times

    def delta(self, x, y, corres1, corres2, times1, times2):
        """Length-based cost inherited from the template."""
        return super().delta(x, y, corres1, corres2, times1, times2)

    def get_norm(self, root) -> int:
        """Number of nodes of the subtree of ``root``, scaled by time_scale."""
        subtree = self.lT.get_subtree_nodes(root, end_time=self.end_time)
        return len(subtree) * self.time_scale
273
+
274
+
275
class downsample_tree(TreeApproximationTemplate):
    """Downsamples a tree so every n nodes are being used as one."""

    def __init__(self, **kwargs):
        # Validate before the base constructor builds the tree: get_tree uses
        # the rate to find the next time point, so a missing (None) value
        # would crash there and a non-positive one is meaningless.
        downsample = kwargs.get("downsample")
        if downsample is None or downsample <= 0:
            raise ValueError("Please use a valid downsampling rate")
        if downsample == 1:
            warnings.warn(
                "Downsampling rate of 1 is identical to the full tree.",
                stacklevel=2,
            )
        super().__init__(**kwargs)

    @staticmethod
    def handle_resolutions(
        time_resolution1: float | int,
        time_resolution2: float | int,
        gcd: int,
        downsample: int,
    ) -> tuple[int | float, int | float]:
        """Check ``downsample`` against both resolutions and derive the scales.

        Raises
        ------
        ValueError
            If ``downsample`` is not a multiple of ``lcm / 10``.
        """
        lcm = time_resolution1 * time_resolution2 / gcd
        # Same test as `downsample % (lcm / 10) != 0`, but without the inexact
        # division by 10 (e.g. lcm / 10 == 0.3 leaves a spurious remainder).
        if (downsample * 10) % lcm != 0:
            raise ValueError(
                f"Use a valid downsampling rate (multiple of {lcm/10})"
            )
        # NOTE(review): the first value divides by time_resolution2 and the
        # second by time_resolution1; confirm this pairing against the caller.
        return (
            downsample / (time_resolution2 / 10),
            downsample / (time_resolution1 / 10),
        )

    def get_tree(self) -> tuple[dict, dict]:
        """Sample the tree every ``downsample / time_scale`` time points.

        Returns
        -------
        dict mapping int to list of int
            Adjacency between sampled nodes; leaves map to an empty list.
        dict mapping int to int
            Weight of each sampled node (always 1).
        """
        self.out_dict = {}
        self.times = {}
        to_do = [self.root]
        while to_do:
            current = to_do.pop()
            _next = self.lT.nodes_at_t(
                r=current,
                t=self.lT.time[current] + (self.downsample / self.time_scale),
            )
            if _next == [current]:
                _next = None
            if _next and self.lT.time[_next[0]] <= self.end_time:
                self.out_dict[current] = _next
                to_do.extend(_next)
            else:
                self.out_dict[current] = []
            self.times[current] = 1  # self.downsample
        return self.out_dict, self.times

    def get_norm(self, root) -> int:  ###Temporary###
        """Number of nodes of the downsampled tree spawned by ``root``."""
        if root == self.root:
            # The constructor already built this tree; no need to rebuild it.
            return len(self.out_dict)
        return len(
            downsample_tree(
                lT=self.lT,
                root=root,
                downsample=self.downsample,
                end_time=self.end_time,
                time_scale=self.time_scale,
            ).out_dict
        )

    def delta(self, x, y, corres1, corres2, times1, times2):
        """Unit cost: 1 for inserting or deleting a node, 0 for a match."""
        if x is None and y is None:
            return 0
        return 1 if x is None or y is None else 0
344
+
345
+
346
class normalized_simple_tree(simple_tree):
    """``simple_tree`` variant with normalized costs.

    Inserting or deleting a branch costs 1 and matching two branches of
    lengths ``l1`` and ``l2`` costs ``|l1 - l2| / (l1 + l2)``; the
    normalization value is the number of branches of the subtree.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def delta(self, x, y, corres1, corres2, times1, times2):
        """Relative length difference of two branches (1 for indels)."""
        if x is None and y is None:
            return 0
        if x is None or y is None:
            return 1
        length_x = times1[corres1[x]]
        length_y = times2[corres2[y]]
        return abs(length_x - length_y) / (length_x + length_y)

    def get_norm(self, root) -> int:
        """Number of chains (branches) of the subtree of ``root`` up to end_time."""
        chains = self.lT.get_all_chains_of_subtree(root, end_time=self.end_time)
        return len(chains)
365
+
366
+
367
class full_tree(TreeApproximationTemplate):
    """No approximation: the whole tree is used. Perfect accuracy, but heavy
    on RAM and speed. Not recommended for use in napari.

    When ``time_scale`` > 1, each node is followed by ``time_scale - 1``
    placeholder nodes so that every time point weighs ``time_scale`` nodes.
    """

    def _edist_format(
        self, adj_dict: dict
    ) -> tuple[list, list[list], dict[int, int]]:
        """Format the custom tree style into the format needed by edist.

        .. warning:: Modifying this function might break your code.

        Parameters
        ----------
        adj_dict : dict
            The adjacency dictionary produced by 'get_tree'

        Returns
        -------
        list[int]
            The list of the new nodes to be used for edist
        list[list]
            The adjacency list of these nodes
        dict[int,int]
            The correspondence between the nodes used in edist and
            lineageTree; placeholder nodes map to the lineageTree node they
            were inserted after.
        """
        nodes, adj_list, list2nid = super()._edist_format(adj_dict)
        # Build the replacements first: list2nid must not change while it is
        # being iterated.
        placeholders = {
            edist_id: self.corres_added_nodes[tree_id]
            for edist_id, tree_id in list2nid.items()
            if tree_id in self.corres_added_nodes
        }
        list2nid.update(placeholders)
        return nodes, adj_list, list2nid

    @staticmethod
    def handle_resolutions(
        time_resolution1: float | int,
        time_resolution2: float | int,
        gcd: int,
        downsample: int,
    ) -> tuple[int | float, int | float]:
        """Return ``(lcm / time_resolution2, lcm / time_resolution1)``,
        i.e. ``(time_resolution1 / gcd, time_resolution2 / gcd)``, or
        ``(1, 1)`` when both resolutions are equal."""
        if time_resolution1 == time_resolution2:
            return (1, 1)
        lcm = time_resolution1 * time_resolution2 / gcd
        return (
            lcm / time_resolution2,
            lcm / time_resolution1,
        )

    def _add_placeholders(self, node: int) -> int:
        """Chain ``time_scale - 1`` new placeholder nodes after ``node``.

        Each placeholder is recorded in ``corres_added_nodes`` as standing
        for ``node``. Returns the last node of the chain (``node`` itself
        when ``time_scale`` is 1).
        """
        last = node
        for _ in range(self.time_scale - 1):
            new_id = int(self.get_next_id())
            self.out_dict[last] = [new_id]
            self.corres_added_nodes[new_id] = node
            last = new_id
        return last

    def get_tree(self) -> tuple[dict, dict]:
        """Build the full tree spawned by ``self.root`` up to end_time.

        Returns
        -------
        dict mapping int to list of int
            Adjacency of every node (placeholders included); leaves map to
            an empty list.
        dict
            Always empty: this style does not use node durations.
        """
        self.out_dict = {}
        self.times = {}
        self.corres_added_nodes = {}
        to_do = [self.root]
        while to_do:
            current = to_do.pop()
            _next = list(self.lT.successor[current])
            tail = self._add_placeholders(current)
            if _next and self.lT.time[_next[0]] <= self.end_time:
                self.out_dict[tail] = _next
                to_do.extend(_next)
            else:
                self.out_dict[tail] = []
        return self.out_dict, self.times

    def get_norm(self, root) -> int:
        """Number of nodes of the subtree of ``root``, scaled by time_scale."""
        return (
            len(self.lT.get_subtree_nodes(root, end_time=self.end_time))
            * self.time_scale
        )

    def delta(self, x, y, corres1, corres2, times1, times2):
        """Unit cost: 1 for inserting or deleting a node, 0 for a match."""
        if x is None and y is None:
            return 0
        return 1 if x is None or y is None else 0
477
+
478
+
479
class tree_style(Enum):
    """Available tree approximation styles; each member's value is the class
    implementing the style."""

    mini = mini_tree
    simple = simple_tree
    normalized_simple = normalized_simple_tree
    downsampled = downsample_tree
    full = full_tree

    @classmethod
    def list_names(cls):
        """Return the names of all styles, in definition order."""
        return [style.name for style in cls]