LineageTree 1.4.4__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- LineageTree/__init__.py +3 -2
- LineageTree/lineageTree.py +1340 -260
- LineageTree/lineageTreeManager.py +170 -0
- LineageTree/tree_styles.py +305 -0
- LineageTree/utils.py +211 -0
- {LineageTree-1.4.4.dist-info → LineageTree-1.5.0.dist-info}/METADATA +21 -26
- LineageTree-1.5.0.dist-info/RECORD +10 -0
- {LineageTree-1.4.4.dist-info → LineageTree-1.5.0.dist-info}/WHEEL +1 -1
- LineageTree-1.4.4.dist-info/RECORD +0 -7
- {LineageTree-1.4.4.dist-info → LineageTree-1.5.0.dist-info}/LICENSE +0 -0
- {LineageTree-1.4.4.dist-info → LineageTree-1.5.0.dist-info}/top_level.txt +0 -0
LineageTree/lineageTree.py
CHANGED
@@ -2,35 +2,244 @@
|
|
2
2
|
# This file is subject to the terms and conditions defined in
|
3
3
|
# file 'LICENCE', which is part of this source code package.
|
4
4
|
# Author: Leo Guignard (leo.guignard...@AT@...gmail.com)
|
5
|
-
|
6
5
|
import csv
|
7
6
|
import os
|
8
7
|
import pickle as pkl
|
9
8
|
import struct
|
9
|
+
import warnings
|
10
10
|
import xml.etree.ElementTree as ET
|
11
|
+
from collections.abc import Iterable
|
11
12
|
from functools import partial
|
12
13
|
from itertools import combinations
|
13
14
|
from numbers import Number
|
14
|
-
from
|
15
|
-
|
15
|
+
from pathlib import Path
|
16
|
+
from typing import TextIO, Union
|
17
|
+
|
18
|
+
from .tree_styles import tree_style
|
19
|
+
|
20
|
+
try:
|
21
|
+
from edist import uted
|
22
|
+
except ImportError:
|
23
|
+
warnings.warn(
|
24
|
+
"No edist installed therefore you will not be able to compute the tree edit distance."
|
25
|
+
)
|
26
|
+
import matplotlib.pyplot as plt
|
27
|
+
import networkx as nx
|
16
28
|
import numpy as np
|
17
|
-
from scipy.
|
29
|
+
from scipy.interpolate import InterpolatedUnivariateSpline
|
30
|
+
from scipy.spatial import Delaunay, distance
|
18
31
|
from scipy.spatial import cKDTree as KDTree
|
19
32
|
|
33
|
+
from .utils import hierarchy_pos, postions_of_nx
|
34
|
+
|
20
35
|
|
21
36
|
class lineageTree:
|
37
|
+
def __eq__(self, other):
|
38
|
+
if isinstance(other, lineageTree):
|
39
|
+
return other.successor == self.successor
|
40
|
+
return False
|
41
|
+
|
22
42
|
def get_next_id(self):
|
23
43
|
"""Computes the next authorized id.
|
24
44
|
|
25
45
|
Returns:
|
26
46
|
int: next authorized id
|
27
47
|
"""
|
48
|
+
if self.max_id == -1 and self.nodes:
|
49
|
+
self.max_id = max(self.nodes)
|
28
50
|
if self.next_id == []:
|
29
51
|
self.max_id += 1
|
30
52
|
return self.max_id
|
31
53
|
else:
|
32
54
|
return self.next_id.pop()
|
33
55
|
|
56
|
+
def complete_lineage(self, nodes: Union[int, set] = None):
|
57
|
+
"""Makes all leaf branches longer so that they reach the last timepoint( self.t_e), useful
|
58
|
+
for tree edit distance algorithms.
|
59
|
+
|
60
|
+
Args:
|
61
|
+
nodes (int,set), optional): Which trees should be "completed", if None it will complete the whole dataset. Defaults to None.
|
62
|
+
"""
|
63
|
+
if nodes is None:
|
64
|
+
nodes = set(self.roots)
|
65
|
+
elif isinstance(nodes, int):
|
66
|
+
nodes = {nodes}
|
67
|
+
for node in nodes:
|
68
|
+
sub = set(self.get_sub_tree(node))
|
69
|
+
specific_leaves = sub.intersection(self.leaves)
|
70
|
+
for leaf in specific_leaves:
|
71
|
+
self.add_branch(leaf, self.t_e - self.time[leaf], reverse=True)
|
72
|
+
|
73
|
+
###TODO pos can be callable and stay motionless (copy the position of the succ node, use something like optical flow)
|
74
|
+
def add_branch(
|
75
|
+
self,
|
76
|
+
pred: int,
|
77
|
+
length: int,
|
78
|
+
move_timepoints: bool = True,
|
79
|
+
pos: Union[callable, None] = None,
|
80
|
+
reverse: bool = False,
|
81
|
+
):
|
82
|
+
"""Adds a branch of specific length to a node either as a successor or as a predecessor.
|
83
|
+
If it is placed on top of a tree all the nodes will move timepoints #length down.
|
84
|
+
|
85
|
+
Args:
|
86
|
+
pred (int): Id of the successor (predecessor if reverse is False)
|
87
|
+
length (int): The length of the new branch.
|
88
|
+
pos (np.ndarray, optional): The new position of the branch. Defaults to None.
|
89
|
+
move_timepoints (bool): Moves the ti Important only if reverse= True
|
90
|
+
reverese (bool): If reverse will add a successor branch instead of a predecessor branch
|
91
|
+
Returns:
|
92
|
+
(int): Id of the first node of the sublineage.
|
93
|
+
"""
|
94
|
+
if length == 0:
|
95
|
+
return pred
|
96
|
+
if self.predecessor.get(pred) and not reverse:
|
97
|
+
raise Warning("Cannot add 2 predecessors to a node")
|
98
|
+
time = self.time[pred]
|
99
|
+
original = pred
|
100
|
+
if not reverse:
|
101
|
+
if move_timepoints:
|
102
|
+
nodes_to_move = set(self.get_sub_tree(pred))
|
103
|
+
new_times = {
|
104
|
+
node: self.time[node] + length for node in nodes_to_move
|
105
|
+
}
|
106
|
+
for node in nodes_to_move:
|
107
|
+
old_time = self.time[node]
|
108
|
+
self.time_nodes[old_time].remove(node)
|
109
|
+
self.time_nodes.setdefault(old_time + length, set()).add(
|
110
|
+
node
|
111
|
+
)
|
112
|
+
self.time.update(new_times)
|
113
|
+
for t in range(length - 1, -1, -1):
|
114
|
+
_next = self.add_node(
|
115
|
+
time + t,
|
116
|
+
succ=pred,
|
117
|
+
pos=self.pos[original],
|
118
|
+
reverse=True,
|
119
|
+
)
|
120
|
+
pred = _next
|
121
|
+
else:
|
122
|
+
for t in range(length):
|
123
|
+
_next = self.add_node(
|
124
|
+
time - t,
|
125
|
+
succ=pred,
|
126
|
+
pos=self.pos[original],
|
127
|
+
reverse=True,
|
128
|
+
)
|
129
|
+
pred = _next
|
130
|
+
else:
|
131
|
+
for t in range(length):
|
132
|
+
_next = self.add_node(
|
133
|
+
time + t, succ=pred, pos=self.pos[original], reverse=False
|
134
|
+
)
|
135
|
+
pred = _next
|
136
|
+
self.labels[pred] = "New branch"
|
137
|
+
if self.time[pred] == self.t_b:
|
138
|
+
self.roots.add(pred)
|
139
|
+
self.labels[pred] = "New branch"
|
140
|
+
if original in self.roots and reverse is True:
|
141
|
+
self.roots.add(pred)
|
142
|
+
self.labels[pred] = "New branch"
|
143
|
+
self.roots.remove(original)
|
144
|
+
self.labels.pop(original, -1)
|
145
|
+
self.t_e = max(self.time_nodes)
|
146
|
+
return pred
|
147
|
+
|
148
|
+
def cut_tree(self, root):
|
149
|
+
"""It transforms a lineage that has at least 2 divisions into 2 independent lineages,
|
150
|
+
that spawn from the time point of the first node. (splits a tree into 2)
|
151
|
+
|
152
|
+
Args:
|
153
|
+
root (int): The id of the node, which will be cut.
|
154
|
+
|
155
|
+
Returns:
|
156
|
+
int: The id of the new tree
|
157
|
+
"""
|
158
|
+
cycle = self.get_successors(root)
|
159
|
+
last_cell = cycle[-1]
|
160
|
+
if last_cell in self.successor:
|
161
|
+
new_lT = self.successor[last_cell].pop()
|
162
|
+
self.predecessor.pop(new_lT)
|
163
|
+
label_of_root = self.labels.get(cycle[0], cycle[0])
|
164
|
+
self.labels[cycle[0]] = f"L-Split {label_of_root}"
|
165
|
+
new_tr = self.add_branch(
|
166
|
+
new_lT, len(cycle) + 1, move_timepoints=False
|
167
|
+
)
|
168
|
+
self.roots.add(new_tr)
|
169
|
+
self.labels[new_tr] = f"R-Split {label_of_root}"
|
170
|
+
return new_tr
|
171
|
+
else:
|
172
|
+
raise Warning("No division of the branch")
|
173
|
+
|
174
|
+
def fuse_lineage_tree(
|
175
|
+
self,
|
176
|
+
l1_root: int,
|
177
|
+
l2_root: int,
|
178
|
+
length_l1: int = 0,
|
179
|
+
length_l2: int = 0,
|
180
|
+
length: int = 1,
|
181
|
+
):
|
182
|
+
"""Fuses 2 lineages from the lineagetree object. The 2 lineages that are to be fused can have a longer
|
183
|
+
first node and the node of the resulting lineage can also be longer.
|
184
|
+
|
185
|
+
Args:
|
186
|
+
l1_root (int): Id of the first root
|
187
|
+
l2_root (int): Id of the second root
|
188
|
+
length_l1 (int, optional): The length of the branch that will be added on top of the first lineage. Defaults to 0, which means only one node will be added.
|
189
|
+
length_l2 (int, optional): The length of the branch that will be added on top of the second lineage. Defaults to 0, which means only one node will be added.
|
190
|
+
length (int, optional): The length of the branch that will be added on top of the resulting lineage. Defaults to 1.
|
191
|
+
|
192
|
+
Returns:
|
193
|
+
int: The id of the root of the new lineage.
|
194
|
+
"""
|
195
|
+
if self.predecessor.get(l1_root) or self.predecessor.get(l2_root):
|
196
|
+
raise ValueError("Please select 2 roots.")
|
197
|
+
if self.time[l1_root] != self.time[l2_root]:
|
198
|
+
warnings.warn(
|
199
|
+
"Using lineagetrees that do not exist in the same timepoint. The operation will continue"
|
200
|
+
)
|
201
|
+
new_root1 = self.add_branch(l1_root, length_l1)
|
202
|
+
new_root2 = self.add_branch(l2_root, length_l2)
|
203
|
+
next_root1 = self[new_root1][0]
|
204
|
+
self.remove_nodes(new_root1)
|
205
|
+
self.successor[new_root2].append(next_root1)
|
206
|
+
self.predecessor[next_root1] = [new_root2]
|
207
|
+
new_branch = self.add_branch(new_root2, length)
|
208
|
+
self.labels[new_branch] = f"Fusion of {new_root1} and {new_root2}"
|
209
|
+
return new_branch
|
210
|
+
|
211
|
+
def copy_lineage(self, root):
|
212
|
+
"""
|
213
|
+
Copies the structure of a tree and makes a new with new nodes.
|
214
|
+
Warning does not take into account the predecessor of the root node.
|
215
|
+
|
216
|
+
Args:
|
217
|
+
root (int): The root of the tree to be copied
|
218
|
+
|
219
|
+
Returns:
|
220
|
+
int: The root of the new tree.
|
221
|
+
"""
|
222
|
+
new_nodes = {
|
223
|
+
old_node: self.get_next_id()
|
224
|
+
for old_node in self.get_sub_tree(root)
|
225
|
+
}
|
226
|
+
self.nodes.update(new_nodes.values())
|
227
|
+
for old_node, new_node in new_nodes.items():
|
228
|
+
self.time[new_node] = self.time[old_node]
|
229
|
+
succ = self.successor.get(old_node)
|
230
|
+
if succ:
|
231
|
+
self.successor[new_node] = [new_nodes[n] for n in succ]
|
232
|
+
pred = self.predecessor.get(old_node)
|
233
|
+
if pred:
|
234
|
+
self.predecessor[new_node] = [new_nodes[n] for n in pred]
|
235
|
+
self.pos[new_node] = self.pos[old_node] + 0.5
|
236
|
+
self.time_nodes[self.time[old_node]].add(new_nodes[old_node])
|
237
|
+
new_root = new_nodes[root]
|
238
|
+
self.labels[new_root] = f"Copy of {root}"
|
239
|
+
if self.time[new_root] == 0:
|
240
|
+
self.roots.add(new_root)
|
241
|
+
return new_root
|
242
|
+
|
34
243
|
def add_node(
|
35
244
|
self,
|
36
245
|
t: int = None,
|
@@ -55,109 +264,126 @@ class lineageTree:
|
|
55
264
|
int: id of the new node.
|
56
265
|
"""
|
57
266
|
C_next = self.get_next_id() if nid is None else nid
|
58
|
-
self.time_nodes.setdefault(t,
|
267
|
+
self.time_nodes.setdefault(t, set()).add(C_next)
|
59
268
|
if succ is not None and not reverse:
|
60
269
|
self.successor.setdefault(succ, []).append(C_next)
|
61
270
|
self.predecessor.setdefault(C_next, []).append(succ)
|
62
|
-
self.edges.add((succ, C_next))
|
63
271
|
elif succ is not None:
|
64
272
|
self.predecessor.setdefault(succ, []).append(C_next)
|
65
273
|
self.successor.setdefault(C_next, []).append(succ)
|
66
|
-
self.edges.add((C_next, succ))
|
67
274
|
self.nodes.add(C_next)
|
68
275
|
self.pos[C_next] = pos
|
69
|
-
self.progeny[C_next] = 0
|
70
276
|
self.time[C_next] = t
|
71
277
|
return C_next
|
72
278
|
|
73
|
-
def
|
74
|
-
|
75
|
-
times = {self.time[n] for n in track}
|
76
|
-
for t in times:
|
77
|
-
self.time_nodes[t] = list(
|
78
|
-
set(self.time_nodes[t]).difference(track)
|
79
|
-
)
|
80
|
-
for i, c in enumerate(track):
|
81
|
-
self.pos.pop(c)
|
82
|
-
if i != 0:
|
83
|
-
self.predecessor.pop(c)
|
84
|
-
if i < len(track) - 1:
|
85
|
-
self.successor.pop(c)
|
86
|
-
self.time.pop(c)
|
87
|
-
|
88
|
-
def remove_node(self, c: int) -> tuple:
|
89
|
-
"""Removes a node and update the lineageTree accordingly
|
279
|
+
def remove_nodes(self, group: Union[int, set, list]):
|
280
|
+
"""Removes a group of nodes from the LineageTree
|
90
281
|
|
91
282
|
Args:
|
92
|
-
|
93
|
-
"""
|
94
|
-
|
95
|
-
|
96
|
-
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
-
self.
|
103
|
-
|
104
|
-
|
105
|
-
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
|
110
|
-
|
111
|
-
|
112
|
-
|
113
|
-
|
114
|
-
|
115
|
-
|
116
|
-
|
117
|
-
|
118
|
-
|
119
|
-
def
|
120
|
-
"""
|
121
|
-
|
283
|
+
group (set|list|int): One or more nodes that are to be removed.
|
284
|
+
"""
|
285
|
+
if isinstance(group, int):
|
286
|
+
group = {group}
|
287
|
+
if isinstance(group, list):
|
288
|
+
group = set(group)
|
289
|
+
group = group.intersection(self.nodes)
|
290
|
+
self.nodes.difference_update(group)
|
291
|
+
times = {self.time.pop(n) for n in group}
|
292
|
+
for t in times:
|
293
|
+
self.time_nodes[t] = set(self.time_nodes[t]).difference(group)
|
294
|
+
for node in group:
|
295
|
+
self.pos.pop(node)
|
296
|
+
if self.predecessor.get(node):
|
297
|
+
pred = self.predecessor[node][0]
|
298
|
+
siblings = self.successor.pop(pred, [])
|
299
|
+
if len(siblings) == 2:
|
300
|
+
siblings.remove(node)
|
301
|
+
self.successor[pred] = siblings
|
302
|
+
self.predecessor.pop(node, [])
|
303
|
+
for succ in self.successor.get(node, []):
|
304
|
+
self.predecessor.pop(succ, [])
|
305
|
+
self.successor.pop(node, [])
|
306
|
+
self.labels.pop(node, 0)
|
307
|
+
if node in self.roots:
|
308
|
+
self.roots.remove(node)
|
309
|
+
|
310
|
+
def modify_branch(self, node, new_length):
|
311
|
+
"""Changes the length of a branch, so it adds or removes nodes
|
312
|
+
to make the correct length of the cycle.
|
122
313
|
|
123
314
|
Args:
|
124
|
-
|
125
|
-
|
315
|
+
node (int): Any node of the branch to be modified/
|
316
|
+
new_length (int): The new length of the tree.
|
126
317
|
"""
|
127
|
-
|
128
|
-
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
|
136
|
-
|
137
|
-
|
138
|
-
|
139
|
-
|
140
|
-
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
|
146
|
-
self.
|
147
|
-
|
148
|
-
|
149
|
-
|
318
|
+
if new_length <= 1:
|
319
|
+
warnings.warn("New length should be more than 1")
|
320
|
+
return None
|
321
|
+
cycle = self.get_cycle(node)
|
322
|
+
length = len(cycle)
|
323
|
+
successors = self.successor.get(cycle[-1])
|
324
|
+
if length == 1 and new_length != 1:
|
325
|
+
pred = self.predecessor.pop(node, None)
|
326
|
+
new_node = self.add_branch(
|
327
|
+
node, length=new_length, move_timepoints=True, reverse=False
|
328
|
+
)
|
329
|
+
if pred:
|
330
|
+
self.successor[pred[0]].remove(node)
|
331
|
+
self.successor[pred[0]].append(new_node)
|
332
|
+
elif self.leaves.intersection(cycle) and new_length < length:
|
333
|
+
self.remove_nodes(cycle[new_length:])
|
334
|
+
elif new_length < length:
|
335
|
+
to_remove = length - new_length
|
336
|
+
last_cell = cycle[new_length - 1]
|
337
|
+
subtree = self.get_sub_tree(cycle[-1])[1:]
|
338
|
+
self.remove_nodes(cycle[new_length:])
|
339
|
+
self.successor[last_cell] = successors
|
340
|
+
if successors:
|
341
|
+
for succ in successors:
|
342
|
+
self.predecessor[succ] = [last_cell]
|
343
|
+
for node in subtree:
|
344
|
+
if node not in cycle[new_length - 1 :]:
|
345
|
+
old_time = self.time[node]
|
346
|
+
self.time[node] = old_time - to_remove
|
347
|
+
self.time_nodes[old_time].remove(node)
|
348
|
+
self.time_nodes.setdefault(
|
349
|
+
old_time - to_remove, set()
|
350
|
+
).add(node)
|
351
|
+
elif length < new_length:
|
352
|
+
to_add = new_length - length
|
353
|
+
last_cell = cycle[-1]
|
354
|
+
self.successor.pop(cycle[-2])
|
355
|
+
self.predecessor.pop(last_cell)
|
356
|
+
succ = self.add_branch(
|
357
|
+
last_cell, length=to_add, move_timepoints=True, reverse=False
|
358
|
+
)
|
359
|
+
self.predecessor[succ] = [cycle[-2]]
|
360
|
+
self.successor[cycle[-2]] = [succ]
|
361
|
+
self.time[last_cell] = (
|
362
|
+
self.time[self.predecessor[last_cell][0]] + 1
|
363
|
+
)
|
364
|
+
else:
|
365
|
+
return None
|
150
366
|
|
151
367
|
@property
|
152
368
|
def roots(self):
|
153
369
|
if not hasattr(self, "_roots"):
|
154
|
-
self._roots = set(self.
|
370
|
+
self._roots = set(self.nodes).difference(self.predecessor)
|
155
371
|
return self._roots
|
156
372
|
|
373
|
+
@property
|
374
|
+
def edges(self):
|
375
|
+
return {(k, vi) for k, v in self.successor.items() for vi in v}
|
376
|
+
|
157
377
|
@property
|
158
378
|
def leaves(self):
|
159
379
|
return set(self.predecessor).difference(self.successor)
|
160
380
|
|
381
|
+
@property
|
382
|
+
def labels(self):
|
383
|
+
if not hasattr(self, "_labels"):
|
384
|
+
self._labels = {i: "Unlabeled" for i in self.roots}
|
385
|
+
return self._labels
|
386
|
+
|
161
387
|
def _write_header_am(self, f: TextIO, nb_points: int, length: int):
|
162
388
|
"""Header for Amira .am files"""
|
163
389
|
f.write("# AmiraMesh 3D ASCII 2.0\n")
|
@@ -470,7 +696,7 @@ class lineageTree:
|
|
470
696
|
stroke=svgwrite.rgb(0, 0, 0),
|
471
697
|
)
|
472
698
|
)
|
473
|
-
for si in self
|
699
|
+
for si in self[c_cycle[-1]]:
|
474
700
|
x3, y3 = positions[si]
|
475
701
|
dwg.add(
|
476
702
|
dwg.line(
|
@@ -483,7 +709,7 @@ class lineageTree:
|
|
483
709
|
else:
|
484
710
|
for c in treated_cells:
|
485
711
|
x1, y1 = positions[c]
|
486
|
-
for si in self
|
712
|
+
for si in self[c]:
|
487
713
|
x2, y2 = positions[si]
|
488
714
|
if draw_edges:
|
489
715
|
dwg.add(
|
@@ -535,7 +761,7 @@ class lineageTree:
|
|
535
761
|
start_time = times_to_consider[0]
|
536
762
|
for t in times_to_consider:
|
537
763
|
for id_mother in self.time_nodes[t]:
|
538
|
-
ids_daughters = self
|
764
|
+
ids_daughters = self[id_mother]
|
539
765
|
new_ids_daughters = ids_daughters.copy()
|
540
766
|
for _ in range(sampling - 1):
|
541
767
|
tmp = []
|
@@ -659,12 +885,12 @@ class lineageTree:
|
|
659
885
|
edges_to_use += list(s_edges)
|
660
886
|
else:
|
661
887
|
edges_to_use = []
|
888
|
+
nodes_to_use = set(nodes_to_use)
|
662
889
|
if temporal:
|
663
|
-
|
664
|
-
|
665
|
-
|
666
|
-
|
667
|
-
]
|
890
|
+
for n in nodes_to_use:
|
891
|
+
for d in self.successor.get(n, []):
|
892
|
+
if d in nodes_to_use:
|
893
|
+
edges_to_use.append((n, d))
|
668
894
|
if spatial:
|
669
895
|
edges_to_use += [
|
670
896
|
e for e in s_edges if t_min < self.time[e[0]] < t_max
|
@@ -787,7 +1013,6 @@ class lineageTree:
|
|
787
1013
|
self.time_edges = {}
|
788
1014
|
unique_id = 0
|
789
1015
|
self.nodes = set()
|
790
|
-
self.edges = set()
|
791
1016
|
self.successor = {}
|
792
1017
|
self.predecessor = {}
|
793
1018
|
self.pos = {}
|
@@ -826,7 +1051,6 @@ class lineageTree:
|
|
826
1051
|
M = corres[pred]
|
827
1052
|
self.predecessor[C] = [M]
|
828
1053
|
self.successor.setdefault(M, []).append(C)
|
829
|
-
self.edges.add((M, C))
|
830
1054
|
self.time_edges.setdefault(t, set()).add((M, C))
|
831
1055
|
self.lin.setdefault(lin_id, []).append(C)
|
832
1056
|
self.C_lin[C] = lin_id
|
@@ -962,8 +1186,9 @@ class lineageTree:
|
|
962
1186
|
if "cell_fate" in tmp_data:
|
963
1187
|
self.fates[unique_id] = tmp_data["cell_fate"].get(n, "")
|
964
1188
|
if "cell_barycenter" in tmp_data:
|
965
|
-
self.pos[unique_id] = tmp_data["cell_barycenter"].get(
|
966
|
-
|
1189
|
+
self.pos[unique_id] = tmp_data["cell_barycenter"].get(
|
1190
|
+
n, np.zeros(3)
|
1191
|
+
)
|
967
1192
|
|
968
1193
|
unique_id += 1
|
969
1194
|
if do_surf:
|
@@ -982,9 +1207,7 @@ class lineageTree:
|
|
982
1207
|
self.successor[new_id] = [
|
983
1208
|
self.pkl2lT[ni] for ni in lt[n] if ni in self.pkl2lT
|
984
1209
|
]
|
985
|
-
|
986
|
-
[(new_id, ni) for ni in self.successor[new_id]]
|
987
|
-
)
|
1210
|
+
|
988
1211
|
for ni in self.successor[new_id]:
|
989
1212
|
self.time_edges.setdefault(t - 1, set()).add((new_id, ni))
|
990
1213
|
|
@@ -993,31 +1216,43 @@ class lineageTree:
|
|
993
1216
|
self.max_id = unique_id
|
994
1217
|
|
995
1218
|
# do this in the end of the process, skip lineage tree and whatever is stored already
|
996
|
-
|
1219
|
+
discard = {
|
997
1220
|
"cell_volume",
|
998
1221
|
"cell_fate",
|
999
1222
|
"cell_barycenter",
|
1000
1223
|
"cell_contact_surface",
|
1001
1224
|
"cell_lineage",
|
1002
|
-
|
1225
|
+
"all_cells",
|
1226
|
+
"cell_history",
|
1227
|
+
"problematic_cells",
|
1228
|
+
"cell_labels_in_time",
|
1229
|
+
}
|
1230
|
+
self.specific_properties = []
|
1003
1231
|
for prop_name, prop_values in tmp_data.items():
|
1004
|
-
if not (prop_name in
|
1232
|
+
if not (prop_name in discard or hasattr(self, prop_name)):
|
1005
1233
|
if isinstance(prop_values, dict):
|
1006
1234
|
dictionary = {
|
1007
|
-
self.pkl2lT.get(k, -1): v
|
1235
|
+
self.pkl2lT.get(k, -1): v
|
1236
|
+
for k, v in prop_values.items()
|
1008
1237
|
}
|
1009
1238
|
# is it a regular dictionary or a dictionary with dictionaries inside?
|
1010
1239
|
for key, value in dictionary.items():
|
1011
1240
|
if isinstance(value, dict):
|
1012
1241
|
# rename all ids from old to new
|
1013
1242
|
dictionary[key] = {
|
1014
|
-
self.pkl2lT.get(k, -1): v
|
1243
|
+
self.pkl2lT.get(k, -1): v
|
1244
|
+
for k, v in value.items()
|
1015
1245
|
}
|
1016
1246
|
self.__dict__[prop_name] = dictionary
|
1247
|
+
self.specific_properties.append(prop_name)
|
1017
1248
|
# is any of this necessary? Or does it mean it anyways does not contain
|
1018
1249
|
# information about the id and a simple else: is enough?
|
1019
|
-
elif
|
1250
|
+
elif (
|
1251
|
+
isinstance(prop_values, (list, set, np.ndarray))
|
1252
|
+
and prop_name not in []
|
1253
|
+
):
|
1020
1254
|
self.__dict__[prop_name] = prop_values
|
1255
|
+
self.specific_properties.append(prop_name)
|
1021
1256
|
|
1022
1257
|
# what else could it be?
|
1023
1258
|
|
@@ -1129,7 +1364,6 @@ class lineageTree:
|
|
1129
1364
|
p = None
|
1130
1365
|
self.predecessor.setdefault(c, []).append(p)
|
1131
1366
|
self.successor.setdefault(p, []).append(c)
|
1132
|
-
self.edges.add((p, c))
|
1133
1367
|
self.time_edges.setdefault(t - 1, set()).add((p, c))
|
1134
1368
|
self.max_id = unique_id
|
1135
1369
|
|
@@ -1217,7 +1451,6 @@ class lineageTree:
|
|
1217
1451
|
p = None
|
1218
1452
|
self.predecessor.setdefault(c, []).append(p)
|
1219
1453
|
self.successor.setdefault(p, []).append(c)
|
1220
|
-
self.edges.add((p, c))
|
1221
1454
|
self.time_edges.setdefault(t - 1, set()).add((p, c))
|
1222
1455
|
self.max_id = unique_id
|
1223
1456
|
|
@@ -1241,7 +1474,6 @@ class lineageTree:
|
|
1241
1474
|
self.time_edges = {}
|
1242
1475
|
unique_id = 0
|
1243
1476
|
self.nodes = set()
|
1244
|
-
self.edges = set()
|
1245
1477
|
self.successor = {}
|
1246
1478
|
self.predecessor = {}
|
1247
1479
|
self.pos = {}
|
@@ -1301,7 +1533,6 @@ class lineageTree:
|
|
1301
1533
|
M = self.time_id[(t - 1, M_id)]
|
1302
1534
|
self.successor.setdefault(M, []).append(C)
|
1303
1535
|
self.predecessor.setdefault(C, []).append(M)
|
1304
|
-
self.edges.add((M, C))
|
1305
1536
|
self.time_edges[t].add((M, C))
|
1306
1537
|
else:
|
1307
1538
|
if M_id != -1:
|
@@ -1338,7 +1569,6 @@ class lineageTree:
|
|
1338
1569
|
|
1339
1570
|
mr = MastodonReader(path)
|
1340
1571
|
spots, links = mr.read_tables()
|
1341
|
-
mr.read_tags(spots, links)
|
1342
1572
|
|
1343
1573
|
self.node_name = {}
|
1344
1574
|
|
@@ -1358,7 +1588,6 @@ class lineageTree:
|
|
1358
1588
|
target = e.target_idx
|
1359
1589
|
self.predecessor.setdefault(target, []).append(source)
|
1360
1590
|
self.successor.setdefault(source, []).append(target)
|
1361
|
-
self.edges.add((source, target))
|
1362
1591
|
self.time_edges.setdefault(self.time[source], set()).add(
|
1363
1592
|
(source, target)
|
1364
1593
|
)
|
@@ -1393,14 +1622,13 @@ class lineageTree:
|
|
1393
1622
|
self.nodes.add(unique_id)
|
1394
1623
|
self.time[unique_id] = t
|
1395
1624
|
self.node_name[unique_id] = spot[1]
|
1396
|
-
self.pos[unique_id] = np.array([x, y, z])
|
1625
|
+
self.pos[unique_id] = np.array([x, y, z], dtype=float)
|
1397
1626
|
|
1398
1627
|
for link in links:
|
1399
1628
|
source = int(float(link[4]))
|
1400
1629
|
target = int(float(link[5]))
|
1401
1630
|
self.predecessor.setdefault(target, []).append(source)
|
1402
1631
|
self.successor.setdefault(source, []).append(target)
|
1403
|
-
self.edges.add((source, target))
|
1404
1632
|
self.time_edges.setdefault(self.time[source], set()).add(
|
1405
1633
|
(source, target)
|
1406
1634
|
)
|
@@ -1455,23 +1683,24 @@ class lineageTree:
|
|
1455
1683
|
if attr in self.xml_attributes:
|
1456
1684
|
self.__dict__[attr][cell_id] = eval(cell.attrib[attr])
|
1457
1685
|
|
1458
|
-
self.edges = set()
|
1459
1686
|
tracks = {}
|
1460
1687
|
self.successor = {}
|
1461
1688
|
self.predecessor = {}
|
1462
1689
|
self.track_name = {}
|
1463
1690
|
for track in AllTracks:
|
1464
1691
|
if "TRACK_DURATION" in track.attrib:
|
1465
|
-
t_id, _ =
|
1466
|
-
track.attrib["
|
1692
|
+
t_id, _ = (
|
1693
|
+
int(track.attrib["TRACK_ID"]),
|
1694
|
+
float(track.attrib["TRACK_DURATION"]),
|
1467
1695
|
)
|
1468
1696
|
else:
|
1469
1697
|
t_id = int(track.attrib["TRACK_ID"])
|
1470
1698
|
t_name = track.attrib["name"]
|
1471
1699
|
tracks[t_id] = []
|
1472
1700
|
for edge in track:
|
1473
|
-
s, t =
|
1474
|
-
edge.attrib["
|
1701
|
+
s, t = (
|
1702
|
+
int(edge.attrib["SPOT_SOURCE_ID"]),
|
1703
|
+
int(edge.attrib["SPOT_TARGET_ID"]),
|
1475
1704
|
)
|
1476
1705
|
if s in self.nodes and t in self.nodes:
|
1477
1706
|
if self.time[s] > self.time[t]:
|
@@ -1481,7 +1710,6 @@ class lineageTree:
|
|
1481
1710
|
self.track_name[s] = t_name
|
1482
1711
|
self.track_name[t] = t_name
|
1483
1712
|
tracks[t_id].append((s, t))
|
1484
|
-
self.edges.add((s, t))
|
1485
1713
|
self.t_b = min(self.time_nodes.keys())
|
1486
1714
|
self.t_e = max(self.time_nodes.keys())
|
1487
1715
|
|
@@ -1519,7 +1747,7 @@ class lineageTree:
|
|
1519
1747
|
curr_c = to_treat.pop()
|
1520
1748
|
number_sequence.append(curr_c)
|
1521
1749
|
pos_sequence += list(self.pos[curr_c])
|
1522
|
-
if self
|
1750
|
+
if self[curr_c] == []:
|
1523
1751
|
number_sequence.append(-1)
|
1524
1752
|
elif len(self.successor[curr_c]) == 1:
|
1525
1753
|
to_treat += self.successor[curr_c]
|
@@ -1681,7 +1909,6 @@ class lineageTree:
|
|
1681
1909
|
self.time_edges = time_edges
|
1682
1910
|
self.pos = pos
|
1683
1911
|
self.nodes = set(nodes)
|
1684
|
-
self.edges = set(edges)
|
1685
1912
|
self.t_b = min(time_nodes.keys())
|
1686
1913
|
self.t_e = max(time_nodes.keys())
|
1687
1914
|
self.is_root = is_root
|
@@ -1701,7 +1928,7 @@ class lineageTree:
|
|
1701
1928
|
f.close()
|
1702
1929
|
|
1703
1930
|
@classmethod
|
1704
|
-
def load(clf, fname: str):
|
1931
|
+
def load(clf, fname: str, rm_empty_lists=True):
|
1705
1932
|
"""
|
1706
1933
|
Loading a lineage tree from a ".lT" file.
|
1707
1934
|
|
@@ -1714,6 +1941,18 @@ class lineageTree:
|
|
1714
1941
|
with open(fname, "br") as f:
|
1715
1942
|
lT = pkl.load(f)
|
1716
1943
|
f.close()
|
1944
|
+
if rm_empty_lists:
|
1945
|
+
if [] in lT.successor.values():
|
1946
|
+
for node, succ in lT.successor.items():
|
1947
|
+
if succ == []:
|
1948
|
+
lT.successor.pop(node)
|
1949
|
+
if [] in lT.predecessor.values():
|
1950
|
+
for node, succ in lT.predecessor.items():
|
1951
|
+
if succ == []:
|
1952
|
+
lT.predecessor.pop(node)
|
1953
|
+
lT.t_e = max(lT.time_nodes)
|
1954
|
+
lT.t_b = min(lT.time_nodes)
|
1955
|
+
warnings.warn("Empty lists have been removed")
|
1717
1956
|
return lT
|
1718
1957
|
|
1719
1958
|
def get_idx3d(self, t: int) -> tuple:
|
@@ -1795,7 +2034,9 @@ class lineageTree:
|
|
1795
2034
|
|
1796
2035
|
return self.Gabriel_graph[t]
|
1797
2036
|
|
1798
|
-
def get_predecessors(
|
2037
|
+
def get_predecessors(
|
2038
|
+
self, x: int, depth: int = None, start_time: int = None, end_time=None
|
2039
|
+
) -> list:
|
1799
2040
|
"""Computes the predecessors of the node `x` up to
|
1800
2041
|
`depth` predecessors or the begining of the life of `x`.
|
1801
2042
|
The ordered list of ids is returned.
|
@@ -1806,20 +2047,34 @@ class lineageTree:
|
|
1806
2047
|
Returns:
|
1807
2048
|
[int, ]: list of ids, the last id is `x`
|
1808
2049
|
"""
|
1809
|
-
|
2050
|
+
if not start_time:
|
2051
|
+
start_time = self.t_b
|
2052
|
+
if not end_time:
|
2053
|
+
end_time = self.t_e
|
2054
|
+
unconstrained_cycle = [x]
|
2055
|
+
cycle = [x] if start_time <= self.time[x] <= end_time else []
|
1810
2056
|
acc = 0
|
1811
2057
|
while (
|
1812
|
-
len(
|
1813
|
-
self.successor.get(self.predecessor.get(cycle[0], [-1])[0], [])
|
1814
|
-
)
|
2058
|
+
len(self[self.predecessor.get(unconstrained_cycle[0], [-1])[0]])
|
1815
2059
|
== 1
|
1816
2060
|
and acc != depth
|
2061
|
+
and start_time
|
2062
|
+
<= self.time.get(
|
2063
|
+
self.predecessor.get(unconstrained_cycle[0], [-1])[0], -1
|
2064
|
+
)
|
1817
2065
|
):
|
1818
|
-
|
2066
|
+
unconstrained_cycle.insert(
|
2067
|
+
0, self.predecessor[unconstrained_cycle[0]][0]
|
2068
|
+
)
|
1819
2069
|
acc += 1
|
2070
|
+
if start_time <= self.time[unconstrained_cycle[0]] <= end_time:
|
2071
|
+
cycle.insert(0, unconstrained_cycle[0])
|
2072
|
+
|
1820
2073
|
return cycle
|
1821
2074
|
|
1822
|
-
def get_successors(
|
2075
|
+
def get_successors(
|
2076
|
+
self, x: int, depth: int = None, end_time: int = None
|
2077
|
+
) -> list:
|
1823
2078
|
"""Computes the successors of the node `x` up to
|
1824
2079
|
`depth` successors or the end of the life of `x`.
|
1825
2080
|
The ordered list of ids is returned.
|
@@ -1830,11 +2085,18 @@ class lineageTree:
|
|
1830
2085
|
Returns:
|
1831
2086
|
[int, ]: list of ids, the first id is `x`
|
1832
2087
|
"""
|
2088
|
+
if end_time is None:
|
2089
|
+
end_time = self.t_e
|
1833
2090
|
cycle = [x]
|
1834
2091
|
acc = 0
|
1835
|
-
while
|
2092
|
+
while (
|
2093
|
+
len(self[cycle[-1]]) == 1
|
2094
|
+
and acc != depth
|
2095
|
+
and self.time[cycle[-1]] < end_time
|
2096
|
+
):
|
1836
2097
|
cycle += self.successor[cycle[-1]]
|
1837
2098
|
acc += 1
|
2099
|
+
|
1838
2100
|
return cycle
|
1839
2101
|
|
1840
2102
|
def get_cycle(
|
@@ -1843,12 +2105,14 @@ class lineageTree:
|
|
1843
2105
|
depth: int = None,
|
1844
2106
|
depth_pred: int = None,
|
1845
2107
|
depth_succ: int = None,
|
2108
|
+
end_time: int = None,
|
1846
2109
|
) -> list:
|
1847
2110
|
"""Computes the predecessors and successors of the node `x` up to
|
1848
2111
|
`depth_pred` predecessors plus `depth_succ` successors.
|
1849
2112
|
If the value `depth` is provided and not None,
|
1850
2113
|
`depth_pred` and `depth_succ` are overwriten by `depth`.
|
1851
2114
|
The ordered list of ids is returned.
|
2115
|
+
If all `depth` are None, the full cycle is returned.
|
1852
2116
|
|
1853
2117
|
Args:
|
1854
2118
|
x (int): id of the node to compute
|
@@ -1858,11 +2122,13 @@ class lineageTree:
|
|
1858
2122
|
Returns:
|
1859
2123
|
[int, ]: list of ids
|
1860
2124
|
"""
|
2125
|
+
if end_time is None:
|
2126
|
+
end_time = self.t_e
|
1861
2127
|
if depth is not None:
|
1862
2128
|
depth_pred = depth_succ = depth
|
1863
|
-
return self.get_predecessors(x, depth_pred)[
|
1864
|
-
|
1865
|
-
)
|
2129
|
+
return self.get_predecessors(x, depth_pred, end_time=end_time)[
|
2130
|
+
:-1
|
2131
|
+
] + self.get_successors(x, depth_succ, end_time=end_time)
|
1866
2132
|
|
1867
2133
|
@property
|
1868
2134
|
def all_tracks(self):
|
@@ -1870,6 +2136,29 @@ class lineageTree:
|
|
1870
2136
|
self._all_tracks = self.get_all_tracks()
|
1871
2137
|
return self._all_tracks
|
1872
2138
|
|
2139
|
+
def get_all_branches_of_node(
|
2140
|
+
self, node: int, end_time: int = None
|
2141
|
+
) -> list:
|
2142
|
+
"""Computes all the tracks of the subtree spawn by a given node.
|
2143
|
+
Similar to get_all_tracks().
|
2144
|
+
|
2145
|
+
Args:
|
2146
|
+
node (int, optional): The node that we want to get its branches.
|
2147
|
+
|
2148
|
+
Returns:
|
2149
|
+
([[int, ...], ...]): list of lists containing track cell ids
|
2150
|
+
"""
|
2151
|
+
if not end_time:
|
2152
|
+
end_time = self.t_e
|
2153
|
+
branches = [self.get_successors(node)]
|
2154
|
+
to_do = self[branches[0][-1]].copy()
|
2155
|
+
while to_do:
|
2156
|
+
current = to_do.pop()
|
2157
|
+
track = self.get_cycle(current, end_time=end_time)
|
2158
|
+
branches += [track]
|
2159
|
+
to_do.extend(self[track[-1]])
|
2160
|
+
return branches
|
2161
|
+
|
1873
2162
|
def get_all_tracks(self, force_recompute: bool = False) -> list:
|
1874
2163
|
"""Computes all the tracks of a given lineage tree,
|
1875
2164
|
stores it in `self.all_tracks` and returns it.
|
@@ -1877,17 +2166,42 @@ class lineageTree:
|
|
1877
2166
|
Returns:
|
1878
2167
|
([[int, ...], ...]): list of lists containing track cell ids
|
1879
2168
|
"""
|
1880
|
-
if not hasattr(self, "_all_tracks"):
|
2169
|
+
if not hasattr(self, "_all_tracks") or force_recompute:
|
1881
2170
|
self._all_tracks = []
|
1882
|
-
to_do =
|
2171
|
+
to_do = list(self.roots)
|
1883
2172
|
while len(to_do) != 0:
|
1884
2173
|
current = to_do.pop()
|
1885
2174
|
track = self.get_cycle(current)
|
1886
2175
|
self._all_tracks += [track]
|
1887
|
-
to_do
|
2176
|
+
to_do.extend(self[track[-1]])
|
1888
2177
|
return self._all_tracks
|
1889
2178
|
|
1890
|
-
def
|
2179
|
+
def get_tracks(self, roots: list = None) -> list:
|
2180
|
+
"""Computes the tracks given by the list of nodes `roots` and returns it.
|
2181
|
+
|
2182
|
+
Args:
|
2183
|
+
roots (list): list of ids of the roots to be computed
|
2184
|
+
Returns:
|
2185
|
+
([[int, ...], ...]): list of lists containing track cell ids
|
2186
|
+
"""
|
2187
|
+
if roots is None:
|
2188
|
+
return self.get_all_tracks(force_recompute=True)
|
2189
|
+
else:
|
2190
|
+
tracks = []
|
2191
|
+
to_do = list(roots)
|
2192
|
+
while len(to_do) != 0:
|
2193
|
+
current = to_do.pop()
|
2194
|
+
track = self.get_cycle(current)
|
2195
|
+
tracks.append(track)
|
2196
|
+
to_do.extend(self[track[-1]])
|
2197
|
+
return tracks
|
2198
|
+
|
2199
|
+
def get_sub_tree(
|
2200
|
+
self,
|
2201
|
+
x: Union[int, Iterable],
|
2202
|
+
end_time: Union[int, None] = None,
|
2203
|
+
preorder: bool = False,
|
2204
|
+
) -> list:
|
1891
2205
|
"""Computes the list of cells from the subtree spawned by *x*
|
1892
2206
|
The default output order is breadth first traversal.
|
1893
2207
|
Unless preorder is `True` in that case the order is
|
@@ -1899,16 +2213,24 @@ class lineageTree:
|
|
1899
2213
|
Returns:
|
1900
2214
|
([int, ...]): the ordered list of node ids
|
1901
2215
|
"""
|
1902
|
-
|
2216
|
+
if not end_time:
|
2217
|
+
end_time = self.t_e
|
2218
|
+
if not isinstance(x, Iterable):
|
2219
|
+
to_do = [x]
|
2220
|
+
elif isinstance(x, Iterable):
|
2221
|
+
to_do = list(x)
|
1903
2222
|
sub_tree = []
|
1904
|
-
while
|
1905
|
-
curr = to_do.pop(
|
2223
|
+
while to_do:
|
2224
|
+
curr = to_do.pop()
|
1906
2225
|
succ = self.successor.get(curr, [])
|
2226
|
+
if succ and end_time < self.time.get(curr, end_time):
|
2227
|
+
succ = []
|
2228
|
+
continue
|
1907
2229
|
if preorder:
|
1908
2230
|
to_do = succ + to_do
|
1909
2231
|
else:
|
1910
2232
|
to_do += succ
|
1911
|
-
|
2233
|
+
sub_tree += [curr]
|
1912
2234
|
return sub_tree
|
1913
2235
|
|
1914
2236
|
def compute_spatial_density(
|
@@ -1979,6 +2301,70 @@ class lineageTree:
|
|
1979
2301
|
)
|
1980
2302
|
return self.th_edges
|
1981
2303
|
|
2304
|
+
def main_axes(self, time: int = None):
|
2305
|
+
"""Finds the main axes for a timepoint.
|
2306
|
+
If none will select the timepoint with the highest amound of cells.
|
2307
|
+
|
2308
|
+
Args:
|
2309
|
+
time (int, optional): The timepoint to find the main axes.
|
2310
|
+
If None will find the timepoint
|
2311
|
+
with the largest number of cells.
|
2312
|
+
|
2313
|
+
Returns:
|
2314
|
+
list: A list that contains the array of eigenvalues and eigenvectors.
|
2315
|
+
"""
|
2316
|
+
if time is None:
|
2317
|
+
time = np.argmax(
|
2318
|
+
[len(self.time_nodes[t]) for t in range(int(self.t_e))]
|
2319
|
+
)
|
2320
|
+
pos = np.array([self.pos[node] for node in self.time_nodes[time]])
|
2321
|
+
pos = pos - np.mean(pos, axis=0)
|
2322
|
+
cov = np.cov(np.array(pos).T)
|
2323
|
+
eig_val, eig_vec = np.linalg.eig(cov)
|
2324
|
+
srt = np.argsort(eig_val)[::-1]
|
2325
|
+
self.eig_val, self.eig_vec = eig_val[srt], eig_vec[:, srt]
|
2326
|
+
return eig_val[srt], eig_vec[:, srt]
|
2327
|
+
|
2328
|
+
def scale_embryo(self, scale=1000):
|
2329
|
+
"""Scale the embryo using their eigenvalues.
|
2330
|
+
|
2331
|
+
Args:
|
2332
|
+
scale (int, optional): The resulting scale you want to achieve. Defaults to 1000.
|
2333
|
+
|
2334
|
+
Returns:
|
2335
|
+
float: The scale factor.
|
2336
|
+
"""
|
2337
|
+
eig = self.main_axes()[0]
|
2338
|
+
return scale / (np.sqrt(eig[0]))
|
2339
|
+
|
2340
|
+
@staticmethod
|
2341
|
+
def __rodrigues_rotation_matrix(vector1, vector2=(0, 1, 0)):
|
2342
|
+
"""Calculates the rodrigues matrix of a dataset. It should use vectors from the find_main_axes(eigenvectors) function of LineagTree.
|
2343
|
+
Uses the Rodrigues rotation formula.
|
2344
|
+
|
2345
|
+
Args:
|
2346
|
+
vector1 (list|np.array): The vector that should be rotated to be aligned to the second vector
|
2347
|
+
vector2 (list|np.array, optional): The second vector. Defaults to [1,0,0].
|
2348
|
+
|
2349
|
+
Returns:
|
2350
|
+
np.array: The rotation matrix.
|
2351
|
+
"""
|
2352
|
+
vector1 = vector1 / np.linalg.norm(vector1)
|
2353
|
+
vector2 = vector2 / np.linalg.norm(vector2)
|
2354
|
+
if vector1 @ vector2 == 1:
|
2355
|
+
return np.eye(3)
|
2356
|
+
angle = np.arccos(vector1 @ vector2)
|
2357
|
+
axis = np.cross(vector1, vector2)
|
2358
|
+
axis = axis / np.linalg.norm(axis)
|
2359
|
+
K = np.array(
|
2360
|
+
[
|
2361
|
+
[0, -axis[2], axis[1]],
|
2362
|
+
[axis[2], 0, -axis[0]],
|
2363
|
+
[-axis[1], axis[0], 0],
|
2364
|
+
]
|
2365
|
+
)
|
2366
|
+
return np.eye(3) + np.sin(angle) * K + (1 - np.cos(angle)) * K @ K
|
2367
|
+
|
1982
2368
|
def get_ancestor_at_t(self, n: int, time: int = None):
|
1983
2369
|
"""
|
1984
2370
|
Find the id of the ancestor of a give node `n`
|
@@ -2005,62 +2391,27 @@ class lineageTree:
|
|
2005
2391
|
ancestor = self.predecessor.get(ancestor, [-1])[0]
|
2006
2392
|
return ancestor
|
2007
2393
|
|
2008
|
-
def
|
2009
|
-
"""
|
2010
|
-
Get a "simple" version of the tree spawned by the node `r`
|
2011
|
-
This simple version is just one node per cell (as opposed to
|
2012
|
-
one node per cell per time-point). The life time duration of
|
2013
|
-
a cell `c` is stored in `self.cycle_time` and return by this
|
2014
|
-
function
|
2394
|
+
def get_labelled_ancestor(self, node: int):
|
2395
|
+
"""Finds the first labelled ancestor and returns its ID otherwise returns None
|
2015
2396
|
|
2016
2397
|
Args:
|
2017
|
-
|
2018
|
-
time_resolution (float): the time between two consecutive time points
|
2398
|
+
node (int): The id of the node
|
2019
2399
|
|
2020
2400
|
Returns:
|
2021
|
-
|
2022
|
-
|
2023
|
-
|
2024
|
-
|
2025
|
-
|
2026
|
-
|
2027
|
-
|
2028
|
-
self.
|
2029
|
-
|
2030
|
-
|
2031
|
-
|
2032
|
-
|
2033
|
-
|
2034
|
-
|
2035
|
-
if _next:
|
2036
|
-
out_dict[current] = _next
|
2037
|
-
to_do.extend(_next)
|
2038
|
-
self.cycle_time[current] = len(cycle) * time_resolution
|
2039
|
-
return out_dict, self.cycle_time
|
2040
|
-
|
2041
|
-
@staticmethod
|
2042
|
-
def __edist_format(adj_dict: dict):
|
2043
|
-
inv_adj = {vi: k for k, v in adj_dict.items() for vi in v}
|
2044
|
-
roots = set(adj_dict).difference(inv_adj)
|
2045
|
-
nid2list = {}
|
2046
|
-
list2nid = {}
|
2047
|
-
nodes = []
|
2048
|
-
adj_list = []
|
2049
|
-
curr_id = 0
|
2050
|
-
for r in roots:
|
2051
|
-
to_do = [r]
|
2052
|
-
while to_do:
|
2053
|
-
curr = to_do.pop(0)
|
2054
|
-
nid2list[curr] = curr_id
|
2055
|
-
list2nid[curr_id] = curr
|
2056
|
-
nodes.append(curr_id)
|
2057
|
-
to_do = adj_dict.get(curr, []) + to_do
|
2058
|
-
curr_id += 1
|
2059
|
-
adj_list = [
|
2060
|
-
[nid2list[d] for d in adj_dict.get(list2nid[_id], [])]
|
2061
|
-
for _id in nodes
|
2062
|
-
]
|
2063
|
-
return nodes, adj_list, list2nid
|
2401
|
+
[None,int]: Returns the first ancestor found that has a label otherwise
|
2402
|
+
None.
|
2403
|
+
"""
|
2404
|
+
if node not in self.nodes:
|
2405
|
+
return None
|
2406
|
+
ancestor = node
|
2407
|
+
while (
|
2408
|
+
self.t_b <= self.time.get(ancestor, self.t_b - 1)
|
2409
|
+
and ancestor != -1
|
2410
|
+
):
|
2411
|
+
if ancestor in self.labels:
|
2412
|
+
return ancestor
|
2413
|
+
ancestor = self.predecessor.get(ancestor, [-1])[0]
|
2414
|
+
return
|
2064
2415
|
|
2065
2416
|
def unordered_tree_edit_distances_at_time_t(
|
2066
2417
|
self,
|
@@ -2068,6 +2419,7 @@ class lineageTree:
|
|
2068
2419
|
delta: callable = None,
|
2069
2420
|
norm: callable = None,
|
2070
2421
|
recompute: bool = False,
|
2422
|
+
end_time: int = None,
|
2071
2423
|
) -> dict:
|
2072
2424
|
"""
|
2073
2425
|
Compute all the pairwise unordered tree edit distances from Zhang 996 between the trees spawned at time `t`
|
@@ -2079,6 +2431,8 @@ class lineageTree:
|
|
2079
2431
|
of the tree spawned by `n1` and the number of nodes
|
2080
2432
|
of the tree spawned by `n2` as arguments.
|
2081
2433
|
recompute (bool): if True, forces to recompute the distances (default: False)
|
2434
|
+
end_time (int): The final time point the comparison algorithm will take into account. If None all nodes
|
2435
|
+
will be taken into account.
|
2082
2436
|
|
2083
2437
|
Returns:
|
2084
2438
|
(dict) a dictionary that maps a pair of cell ids at time `t` to their unordered tree edit distance
|
@@ -2092,14 +2446,20 @@ class lineageTree:
|
|
2092
2446
|
for n1, n2 in combinations(roots, 2):
|
2093
2447
|
key = tuple(sorted((n1, n2)))
|
2094
2448
|
self.uted[t][key] = self.unordered_tree_edit_distance(
|
2095
|
-
n1, n2,
|
2449
|
+
n1, n2, end_time=end_time
|
2096
2450
|
)
|
2097
2451
|
return self.uted[t]
|
2098
2452
|
|
2099
2453
|
def unordered_tree_edit_distance(
|
2100
|
-
self,
|
2454
|
+
self,
|
2455
|
+
n1: int,
|
2456
|
+
n2: int,
|
2457
|
+
end_time: int = None,
|
2458
|
+
style="fragmented",
|
2459
|
+
node_lengths: tuple = (1, 5, 7),
|
2101
2460
|
) -> float:
|
2102
2461
|
"""
|
2462
|
+
TODO: Add option for choosing which tree aproximation should be used (Full, simple, comp)
|
2103
2463
|
Compute the unordered tree edit distance from Zhang 1996 between the trees spawned
|
2104
2464
|
by two nodes `n1` and `n2`. The topology of the trees are compared and the matching
|
2105
2465
|
cost is given by the function delta (see edist doc for more information).
|
@@ -2109,48 +2469,178 @@ class lineageTree:
|
|
2109
2469
|
Args:
|
2110
2470
|
n1 (int): id of the first node to compare
|
2111
2471
|
n2 (int): id of the second node to compare
|
2112
|
-
|
2113
|
-
|
2114
|
-
of the tree spawned by `n1` and the number of nodes
|
2115
|
-
of the tree spawned by `n2` as arguments.
|
2472
|
+
tree_style ("mini","simple","fragmented","full"): Which tree approximation is going to be used for the comparisons.
|
2473
|
+
Defaults to "fragmented".
|
2116
2474
|
|
2117
2475
|
Returns:
|
2118
2476
|
(float) The normed unordered tree edit distance
|
2119
2477
|
"""
|
2120
2478
|
|
2121
|
-
|
2122
|
-
|
2123
|
-
|
2479
|
+
tree = tree_style[style].value
|
2480
|
+
tree1 = tree(
|
2481
|
+
lT=self, node_length=node_lengths, end_time=end_time, root=n1
|
2482
|
+
)
|
2483
|
+
tree2 = tree(
|
2484
|
+
lT=self, node_length=node_lengths, end_time=end_time, root=n2
|
2485
|
+
)
|
2486
|
+
delta = tree1.delta
|
2487
|
+
_, times1 = tree1.tree
|
2488
|
+
_, times2 = tree2.tree
|
2489
|
+
(
|
2490
|
+
nodes1,
|
2491
|
+
adj1,
|
2492
|
+
corres1,
|
2493
|
+
) = tree1.edist
|
2494
|
+
(
|
2495
|
+
nodes2,
|
2496
|
+
adj2,
|
2497
|
+
corres2,
|
2498
|
+
) = tree2.edist
|
2499
|
+
if len(nodes1) == len(nodes2) == 0:
|
2500
|
+
return 0
|
2501
|
+
delta_tmp = partial(
|
2502
|
+
delta,
|
2503
|
+
corres1=corres1,
|
2504
|
+
corres2=corres2,
|
2505
|
+
times1=times1,
|
2506
|
+
times2=times2,
|
2507
|
+
)
|
2124
2508
|
|
2125
|
-
|
2126
|
-
|
2127
|
-
|
2128
|
-
len_x = times[corres1[x]]
|
2129
|
-
len_y = times[corres2[y]]
|
2130
|
-
return np.abs(len_x - len_y) / (len_x + len_y)
|
2509
|
+
return uted.uted(nodes1, adj1, nodes2, adj2, delta=delta_tmp) / max(
|
2510
|
+
tree1.get_norm(), tree2.get_norm()
|
2511
|
+
)
|
2131
2512
|
|
2132
|
-
|
2513
|
+
def to_simple_networkx(
|
2514
|
+
self, node: Union[int, list, set, tuple] = None, start_time: int = 0
|
2515
|
+
):
|
2516
|
+
"""
|
2517
|
+
Creates a simple networkx tree graph (every branch is a cell lifetime). This function is to be used for producing nx.graph objects(
|
2518
|
+
they can be used for visualization or other tasks),
|
2519
|
+
so only the start and the end of a branch are calculated, all cells in between are not taken into account.
|
2520
|
+
Args:
|
2521
|
+
start_time (int): From which timepoints are the graphs to be calculated.
|
2522
|
+
For example if start_time is 10, then all trees that begin
|
2523
|
+
on tp 10 or before are calculated.
|
2524
|
+
returns:
|
2525
|
+
G : list(nx.Digraph(),...)
|
2526
|
+
pos : list(dict(id:position))
|
2527
|
+
"""
|
2133
2528
|
|
2134
|
-
|
2135
|
-
|
2529
|
+
if node is None:
|
2530
|
+
mothers = [
|
2531
|
+
root for root in self.roots if self.time[root] <= start_time
|
2532
|
+
]
|
2533
|
+
else:
|
2534
|
+
mothers = node if isinstance(node, (list, set)) else [node]
|
2535
|
+
graph = {}
|
2536
|
+
all_nodes = {}
|
2537
|
+
all_edges = {}
|
2538
|
+
for mom in mothers:
|
2539
|
+
edges = set()
|
2540
|
+
nodes = set()
|
2541
|
+
for branch in self.get_all_branches_of_node(mom):
|
2542
|
+
nodes.update((branch[0], branch[-1]))
|
2543
|
+
if len(branch) > 1:
|
2544
|
+
edges.add((branch[0], branch[-1]))
|
2545
|
+
for suc in self[branch[-1]]:
|
2546
|
+
edges.add((branch[-1], suc))
|
2547
|
+
all_edges[mom] = edges
|
2548
|
+
all_nodes[mom] = nodes
|
2549
|
+
for i, mother in enumerate(mothers):
|
2550
|
+
graph[i] = nx.DiGraph()
|
2551
|
+
graph[i].add_nodes_from(all_nodes[mother])
|
2552
|
+
graph[i].add_edges_from(all_edges[mother])
|
2553
|
+
|
2554
|
+
return graph
|
2555
|
+
|
2556
|
+
def plot_all_lineages(
|
2557
|
+
self,
|
2558
|
+
starting_point: int = 0,
|
2559
|
+
nrows=2,
|
2560
|
+
figsize=(10, 15),
|
2561
|
+
dpi=70,
|
2562
|
+
fontsize=22,
|
2563
|
+
figure=None,
|
2564
|
+
axes=None,
|
2565
|
+
**kwargs,
|
2566
|
+
):
|
2567
|
+
"""Plots all lineages.
|
2136
2568
|
|
2137
|
-
|
2569
|
+
Args:
|
2570
|
+
starting_point (int, optional): Which timepoints and upwards are the graphs to be calculated.
|
2571
|
+
For example if start_time is 10, then all trees that begin
|
2572
|
+
on tp 10 or before are calculated. Defaults to None.
|
2573
|
+
nrows (int): How many rows of plots should be printed.
|
2574
|
+
kwargs: args accepted by networkx
|
2575
|
+
"""
|
2138
2576
|
|
2139
|
-
|
2140
|
-
|
2577
|
+
nrows = int(nrows)
|
2578
|
+
if nrows < 1 or not nrows:
|
2579
|
+
nrows = 1
|
2580
|
+
raise Warning("Number of rows has to be at least 1")
|
2141
2581
|
|
2142
|
-
|
2143
|
-
|
2144
|
-
|
2145
|
-
|
2146
|
-
|
2147
|
-
return 0
|
2148
|
-
delta_tmp = partial(
|
2149
|
-
delta, corres1=corres1, corres2=corres2, times=self.cycle_time
|
2582
|
+
graphs = self.to_simple_networkx(start_time=starting_point)
|
2583
|
+
ncols = int(len(graphs) // nrows) + (+np.sign(len(graphs) % nrows))
|
2584
|
+
pos = postions_of_nx(self, graphs)
|
2585
|
+
figure, axes = plt.subplots(
|
2586
|
+
figsize=figsize, nrows=nrows, ncols=ncols, dpi=dpi, sharey=True
|
2150
2587
|
)
|
2151
|
-
|
2152
|
-
|
2588
|
+
flat_axes = axes.flatten()
|
2589
|
+
ax2root = {}
|
2590
|
+
for i, graph in enumerate(graphs.values()):
|
2591
|
+
nx.draw_networkx(
|
2592
|
+
graph,
|
2593
|
+
pos[i],
|
2594
|
+
with_labels=False,
|
2595
|
+
arrows=False,
|
2596
|
+
**kwargs,
|
2597
|
+
ax=flat_axes[i],
|
2598
|
+
)
|
2599
|
+
root = [n for n, d in graph.in_degree() if d == 0][0]
|
2600
|
+
label = self.labels.get(root, "Unlabeled")
|
2601
|
+
xlim = flat_axes[i].get_xlim()
|
2602
|
+
ylim = flat_axes[i].get_ylim()
|
2603
|
+
x_pos = (xlim[1]) / 10
|
2604
|
+
y_pos = ylim[0] + 15
|
2605
|
+
ax2root[flat_axes[i]] = root
|
2606
|
+
flat_axes[i].text(
|
2607
|
+
x_pos,
|
2608
|
+
y_pos,
|
2609
|
+
label,
|
2610
|
+
fontsize=fontsize,
|
2611
|
+
color="black",
|
2612
|
+
ha="center",
|
2613
|
+
va="center",
|
2614
|
+
bbox={
|
2615
|
+
"facecolor": "white",
|
2616
|
+
"edgecolor": "green",
|
2617
|
+
"boxstyle": "round",
|
2618
|
+
},
|
2619
|
+
)
|
2620
|
+
[figure.delaxes(ax) for ax in axes.flatten() if not ax.has_data()]
|
2621
|
+
return figure, axes, ax2root
|
2622
|
+
|
2623
|
+
def plot_node(self, node, figsize=(4, 7), dpi=150, **kwargs):
|
2624
|
+
"""Plots the subtree spawn by a node.
|
2625
|
+
|
2626
|
+
Args:
|
2627
|
+
node (int): The id of the node that is going to be plotted.
|
2628
|
+
kwargs: args accepted by networkx
|
2629
|
+
"""
|
2630
|
+
graph = self.to_simple_networkx(node)
|
2631
|
+
if len(graph) > 1:
|
2632
|
+
raise Warning("Please enter only one node")
|
2633
|
+
graph = graph[list(graph)[0]]
|
2634
|
+
figure, ax = plt.subplots(nrows=1, ncols=1)
|
2635
|
+
nx.draw_networkx(
|
2636
|
+
graph,
|
2637
|
+
hierarchy_pos(graph, self, node),
|
2638
|
+
with_labels=False,
|
2639
|
+
arrows=False,
|
2640
|
+
ax=ax,
|
2641
|
+
**kwargs,
|
2153
2642
|
)
|
2643
|
+
return figure, ax
|
2154
2644
|
|
2155
2645
|
# def DTW(self, t1, t2, max_w=None, start_delay=None, end_delay=None,
|
2156
2646
|
# metric='euclidian', **kwargs):
|
@@ -2231,11 +2721,584 @@ class lineageTree:
|
|
2231
2721
|
to_do.append(_next)
|
2232
2722
|
elif self.time[_next] == t:
|
2233
2723
|
final_nodes.append(_next)
|
2234
|
-
if not final_nodes:
|
2724
|
+
if not final_nodes:
|
2725
|
+
return list(r)
|
2235
2726
|
return final_nodes
|
2236
2727
|
|
2728
|
+
@staticmethod
|
2729
|
+
def __calculate_diag_line(dist_mat: np.ndarray) -> (float, float):
|
2730
|
+
"""
|
2731
|
+
Calculate the line that centers the band w.
|
2732
|
+
|
2733
|
+
Args:
|
2734
|
+
dist_mat (matrix): distance matrix obtained by the function calculate_dtw
|
2735
|
+
|
2736
|
+
Returns:
|
2737
|
+
(float) Slope
|
2738
|
+
(float) intercept of the line
|
2739
|
+
"""
|
2740
|
+
i, j = dist_mat.shape
|
2741
|
+
x1 = max(0, i - j) / 2
|
2742
|
+
x2 = (i + min(i, j)) / 2
|
2743
|
+
y1 = max(0, j - i) / 2
|
2744
|
+
y2 = (j + min(i, j)) / 2
|
2745
|
+
slope = (y1 - y2) / (x1 - x2)
|
2746
|
+
intercept = y1 - slope * x1
|
2747
|
+
return slope, intercept
|
2748
|
+
|
2749
|
+
# Reference: https://github.com/kamperh/lecture_dtw_notebook/blob/main/dtw.ipynb
|
2750
|
+
def __dp(
|
2751
|
+
self,
|
2752
|
+
dist_mat: np.ndarray,
|
2753
|
+
start_d: int = 0,
|
2754
|
+
back_d: int = 0,
|
2755
|
+
fast: bool = False,
|
2756
|
+
w: int = 0,
|
2757
|
+
centered_band: bool = True,
|
2758
|
+
) -> (((int, int), ...), np.ndarray):
|
2759
|
+
"""
|
2760
|
+
Find DTW minimum cost between two series using dynamic programming.
|
2761
|
+
|
2762
|
+
Args:
|
2763
|
+
dist_mat (matrix): distance matrix obtained by the function calculate_dtw
|
2764
|
+
start_d (int): start delay
|
2765
|
+
back_d (int): end delay
|
2766
|
+
w (int): window constrain
|
2767
|
+
slope (float): to calculate window - givem by the function __calculate_diag_line
|
2768
|
+
intercept (flost): to calculate window - givem by the function __calculate_diag_line
|
2769
|
+
use_absolute (boolean): if the window constraing is calculate by the absolute difference between points (uncentered)
|
2770
|
+
|
2771
|
+
Returns:
|
2772
|
+
(tuple of tuples) Aligment path
|
2773
|
+
(matrix) Cost matrix
|
2774
|
+
"""
|
2775
|
+
N, M = dist_mat.shape
|
2776
|
+
w_limit = max(w, abs(N - M)) # Calculate the Sakoe-Chiba band width
|
2777
|
+
|
2778
|
+
if centered_band:
|
2779
|
+
slope, intercept = self.__calculate_diag_line(dist_mat)
|
2780
|
+
square_root = np.sqrt((slope**2) + 1)
|
2781
|
+
|
2782
|
+
# Initialize the cost matrix
|
2783
|
+
cost_mat = np.full((N + 1, M + 1), np.inf)
|
2784
|
+
cost_mat[0, 0] = 0
|
2785
|
+
|
2786
|
+
# Fill the cost matrix while keeping traceback information
|
2787
|
+
traceback_mat = np.zeros((N, M))
|
2788
|
+
|
2789
|
+
cost_mat[: start_d + 1, 0] = 0
|
2790
|
+
cost_mat[0, : start_d + 1] = 0
|
2791
|
+
|
2792
|
+
cost_mat[N - back_d :, M] = 0
|
2793
|
+
cost_mat[N, M - back_d :] = 0
|
2794
|
+
|
2795
|
+
for i in range(N):
|
2796
|
+
for j in range(M):
|
2797
|
+
if fast and not centered_band:
|
2798
|
+
condition = abs(i - j) <= w_limit
|
2799
|
+
elif fast:
|
2800
|
+
condition = (
|
2801
|
+
abs(slope * i - j + intercept) / square_root <= w_limit
|
2802
|
+
)
|
2803
|
+
else:
|
2804
|
+
condition = True
|
2805
|
+
|
2806
|
+
if condition:
|
2807
|
+
penalty = [
|
2808
|
+
cost_mat[i, j], # match (0)
|
2809
|
+
cost_mat[i, j + 1], # insertion (1)
|
2810
|
+
cost_mat[i + 1, j], # deletion (2)
|
2811
|
+
]
|
2812
|
+
i_penalty = np.argmin(penalty)
|
2813
|
+
cost_mat[i + 1, j + 1] = (
|
2814
|
+
dist_mat[i, j] + penalty[i_penalty]
|
2815
|
+
)
|
2816
|
+
traceback_mat[i, j] = i_penalty
|
2817
|
+
|
2818
|
+
min_index1 = np.argmin(cost_mat[N - back_d :, M])
|
2819
|
+
min_index2 = np.argmin(cost_mat[N, M - back_d :])
|
2820
|
+
|
2821
|
+
if (
|
2822
|
+
cost_mat[N, M - back_d + min_index2]
|
2823
|
+
< cost_mat[N - back_d + min_index1, M]
|
2824
|
+
):
|
2825
|
+
i = N - 1
|
2826
|
+
j = M - back_d + min_index2 - 1
|
2827
|
+
final_cost = cost_mat[i + 1, j + 1]
|
2828
|
+
else:
|
2829
|
+
i = N - back_d + min_index1 - 1
|
2830
|
+
j = M - 1
|
2831
|
+
final_cost = cost_mat[i + 1, j + 1]
|
2832
|
+
|
2833
|
+
path = [(i, j)]
|
2834
|
+
|
2835
|
+
while (
|
2836
|
+
start_d != 0
|
2837
|
+
and ((start_d < i and j > 0) or (i > 0 and start_d < j))
|
2838
|
+
) or (start_d == 0 and (i > 0 or j > 0)):
|
2839
|
+
tb_type = traceback_mat[i, j]
|
2840
|
+
if tb_type == 0:
|
2841
|
+
# Match
|
2842
|
+
i -= 1
|
2843
|
+
j -= 1
|
2844
|
+
elif tb_type == 1:
|
2845
|
+
# Insertion
|
2846
|
+
i -= 1
|
2847
|
+
elif tb_type == 2:
|
2848
|
+
# Deletion
|
2849
|
+
j -= 1
|
2850
|
+
|
2851
|
+
path.append((i, j))
|
2852
|
+
|
2853
|
+
# Strip infinity edges from cost_mat before returning
|
2854
|
+
cost_mat = cost_mat[1:, 1:]
|
2855
|
+
return path[::-1], cost_mat, final_cost
|
2856
|
+
|
2857
|
+
# Reference: https://github.com/nghiaho12/rigid_transform_3D
|
2858
|
+
@staticmethod
|
2859
|
+
def __rigid_transform_3D(A, B):
|
2860
|
+
assert A.shape == B.shape
|
2861
|
+
|
2862
|
+
num_rows, num_cols = A.shape
|
2863
|
+
if num_rows != 3:
|
2864
|
+
raise Exception(
|
2865
|
+
f"matrix A is not 3xN, it is {num_rows}x{num_cols}"
|
2866
|
+
)
|
2867
|
+
|
2868
|
+
num_rows, num_cols = B.shape
|
2869
|
+
if num_rows != 3:
|
2870
|
+
raise Exception(
|
2871
|
+
f"matrix B is not 3xN, it is {num_rows}x{num_cols}"
|
2872
|
+
)
|
2873
|
+
|
2874
|
+
# find mean column wise
|
2875
|
+
centroid_A = np.mean(A, axis=1)
|
2876
|
+
centroid_B = np.mean(B, axis=1)
|
2877
|
+
|
2878
|
+
# ensure centroids are 3x1
|
2879
|
+
centroid_A = centroid_A.reshape(-1, 1)
|
2880
|
+
centroid_B = centroid_B.reshape(-1, 1)
|
2881
|
+
|
2882
|
+
# subtract mean
|
2883
|
+
Am = A - centroid_A
|
2884
|
+
Bm = B - centroid_B
|
2885
|
+
|
2886
|
+
H = Am @ np.transpose(Bm)
|
2887
|
+
|
2888
|
+
# find rotation
|
2889
|
+
U, S, Vt = np.linalg.svd(H)
|
2890
|
+
R = Vt.T @ U.T
|
2891
|
+
|
2892
|
+
# special reflection case
|
2893
|
+
if np.linalg.det(R) < 0:
|
2894
|
+
# print("det(R) < R, reflection detected!, correcting for it ...")
|
2895
|
+
Vt[2, :] *= -1
|
2896
|
+
R = Vt.T @ U.T
|
2897
|
+
|
2898
|
+
t = -R @ centroid_A + centroid_B
|
2899
|
+
|
2900
|
+
return R, t
|
2901
|
+
|
2902
|
+
def __interpolate(
|
2903
|
+
self, track1: list, track2: list, threshold: int
|
2904
|
+
) -> (np.ndarray, np.ndarray):
|
2905
|
+
"""
|
2906
|
+
Interpolate two series that have different lengths
|
2907
|
+
|
2908
|
+
Args:
|
2909
|
+
track1 (list): list of nodes of the first cell cycle to compare
|
2910
|
+
track2 (list): list of nodes of the second cell cycle to compare
|
2911
|
+
threshold (int): set a maximum number of points a track can have
|
2912
|
+
|
2913
|
+
Returns:
|
2914
|
+
(list of list) x, y, z postions for track1
|
2915
|
+
(list of list) x, y, z postions for track2
|
2916
|
+
"""
|
2917
|
+
inter1_pos = []
|
2918
|
+
inter2_pos = []
|
2919
|
+
|
2920
|
+
track1_pos = np.array([self.pos[c_id] for c_id in track1])
|
2921
|
+
track2_pos = np.array([self.pos[c_id] for c_id in track2])
|
2922
|
+
|
2923
|
+
# Both tracks have the same length and size below the threshold - nothing is done
|
2924
|
+
if len(track1) == len(track2) and (
|
2925
|
+
len(track1) <= threshold or len(track2) <= threshold
|
2926
|
+
):
|
2927
|
+
return track1_pos, track2_pos
|
2928
|
+
# Both tracks have the same length but one or more sizes are above the threshold
|
2929
|
+
elif len(track1) > threshold or len(track2) > threshold:
|
2930
|
+
sampling = threshold
|
2931
|
+
# Tracks have different lengths and the sizes are below the threshold
|
2932
|
+
else:
|
2933
|
+
sampling = max(len(track1), len(track2))
|
2934
|
+
|
2935
|
+
for pos in range(3):
|
2936
|
+
track1_interp = InterpolatedUnivariateSpline(
|
2937
|
+
np.linspace(0, 1, len(track1_pos[:, pos])),
|
2938
|
+
track1_pos[:, pos],
|
2939
|
+
k=1,
|
2940
|
+
)
|
2941
|
+
inter1_pos.append(track1_interp(np.linspace(0, 1, sampling)))
|
2942
|
+
|
2943
|
+
track2_interp = InterpolatedUnivariateSpline(
|
2944
|
+
np.linspace(0, 1, len(track2_pos[:, pos])),
|
2945
|
+
track2_pos[:, pos],
|
2946
|
+
k=1,
|
2947
|
+
)
|
2948
|
+
inter2_pos.append(track2_interp(np.linspace(0, 1, sampling)))
|
2949
|
+
|
2950
|
+
return np.column_stack(inter1_pos), np.column_stack(inter2_pos)
|
2951
|
+
|
2952
|
+
def calculate_dtw(
|
2953
|
+
self,
|
2954
|
+
nodes1: int,
|
2955
|
+
nodes2: int,
|
2956
|
+
threshold: int = 1000,
|
2957
|
+
regist: bool = True,
|
2958
|
+
start_d: int = 0,
|
2959
|
+
back_d: int = 0,
|
2960
|
+
fast: bool = False,
|
2961
|
+
w: int = 0,
|
2962
|
+
centered_band: bool = True,
|
2963
|
+
cost_mat_p: bool = False,
|
2964
|
+
) -> (float, tuple, np.ndarray, np.ndarray, np.ndarray):
|
2965
|
+
"""
|
2966
|
+
Calculate DTW distance between two cell cycles
|
2967
|
+
|
2968
|
+
Args:
|
2969
|
+
nodes1 (int): node to compare distance
|
2970
|
+
nodes2 (int): node to compare distance
|
2971
|
+
threshold: set a maximum number of points a track can have
|
2972
|
+
regist (boolean): Rotate and translate trajectories
|
2973
|
+
start_d (int): start delay
|
2974
|
+
back_d (int): end delay
|
2975
|
+
fast (boolean): True if the user wants to run the fast algorithm with window restrains
|
2976
|
+
w (int): window size
|
2977
|
+
centered_band (boolean): if running the fast algorithm, True if the windown is centered
|
2978
|
+
cost_mat_p (boolean): True if print the not normalized cost matrix
|
2979
|
+
|
2980
|
+
Returns:
|
2981
|
+
(float) DTW distance
|
2982
|
+
(tuple of tuples) Aligment path
|
2983
|
+
(matrix) Cost matrix
|
2984
|
+
(list of lists) pos_cycle1: rotated and translated trajectories positions
|
2985
|
+
(list of lists) pos_cycle2: rotated and translated trajectories positions
|
2986
|
+
"""
|
2987
|
+
nodes1_cycle = self.get_cycle(nodes1)
|
2988
|
+
nodes2_cycle = self.get_cycle(nodes2)
|
2989
|
+
|
2990
|
+
interp_cycle1, interp_cycle2 = self.__interpolate(
|
2991
|
+
nodes1_cycle, nodes2_cycle, threshold
|
2992
|
+
)
|
2993
|
+
|
2994
|
+
pos_cycle1 = np.array([self.pos[c_id] for c_id in nodes1_cycle])
|
2995
|
+
pos_cycle2 = np.array([self.pos[c_id] for c_id in nodes2_cycle])
|
2996
|
+
|
2997
|
+
if regist:
|
2998
|
+
R, t = self.__rigid_transform_3D(
|
2999
|
+
np.transpose(interp_cycle1), np.transpose(interp_cycle2)
|
3000
|
+
)
|
3001
|
+
pos_cycle1 = np.transpose(np.dot(R, pos_cycle1.T) + t)
|
3002
|
+
|
3003
|
+
dist_mat = distance.cdist(pos_cycle1, pos_cycle2, "euclidean")
|
3004
|
+
|
3005
|
+
path, cost_mat, final_cost = self.__dp(
|
3006
|
+
dist_mat,
|
3007
|
+
start_d,
|
3008
|
+
back_d,
|
3009
|
+
w=w,
|
3010
|
+
fast=fast,
|
3011
|
+
centered_band=centered_band,
|
3012
|
+
)
|
3013
|
+
cost = final_cost / len(path)
|
3014
|
+
|
3015
|
+
if cost_mat_p:
|
3016
|
+
return cost, path, cost_mat, pos_cycle1, pos_cycle2
|
3017
|
+
else:
|
3018
|
+
return cost, path
|
3019
|
+
|
3020
|
+
def plot_dtw_heatmap(
|
3021
|
+
self,
|
3022
|
+
nodes1: int,
|
3023
|
+
nodes2: int,
|
3024
|
+
threshold: int = 1000,
|
3025
|
+
regist: bool = True,
|
3026
|
+
start_d: int = 0,
|
3027
|
+
back_d: int = 0,
|
3028
|
+
fast: bool = False,
|
3029
|
+
w: int = 0,
|
3030
|
+
centered_band: bool = True,
|
3031
|
+
) -> (float, plt.figure):
|
3032
|
+
"""
|
3033
|
+
Plot DTW cost matrix between two cell cycles in heatmap format
|
3034
|
+
|
3035
|
+
Args:
|
3036
|
+
nodes1 (int): node to compare distance
|
3037
|
+
nodes2 (int): node to compare distance
|
3038
|
+
start_d (int): start delay
|
3039
|
+
back_d (int): end delay
|
3040
|
+
fast (boolean): True if the user wants to run the fast algorithm with window restrains
|
3041
|
+
w (int): window size
|
3042
|
+
centered_band (boolean): if running the fast algorithm, True if the windown is centered
|
3043
|
+
|
3044
|
+
Returns:
|
3045
|
+
(float) DTW distance
|
3046
|
+
(figure) Heatmap of cost matrix with opitimal path
|
3047
|
+
"""
|
3048
|
+
cost, path, cost_mat, pos_cycle1, pos_cycle2 = self.calculate_dtw(
|
3049
|
+
nodes1,
|
3050
|
+
nodes2,
|
3051
|
+
threshold,
|
3052
|
+
regist,
|
3053
|
+
start_d,
|
3054
|
+
back_d,
|
3055
|
+
fast,
|
3056
|
+
w,
|
3057
|
+
centered_band,
|
3058
|
+
cost_mat_p=True,
|
3059
|
+
)
|
3060
|
+
|
3061
|
+
fig = plt.figure(figsize=(8, 6))
|
3062
|
+
ax = fig.add_subplot(1, 1, 1)
|
3063
|
+
im = ax.imshow(
|
3064
|
+
cost_mat, cmap="viridis", origin="lower", interpolation="nearest"
|
3065
|
+
)
|
3066
|
+
plt.colorbar(im)
|
3067
|
+
ax.set_title("Heatmap of DTW Cost Matrix")
|
3068
|
+
ax.set_xlabel("Tree 1")
|
3069
|
+
ax.set_ylabel("tree 2")
|
3070
|
+
x_path, y_path = zip(*path)
|
3071
|
+
ax.plot(y_path, x_path, color="black")
|
3072
|
+
|
3073
|
+
return cost, fig
|
3074
|
+
|
3075
|
+
@staticmethod
|
3076
|
+
def __plot_2d(
|
3077
|
+
pos_cycle1,
|
3078
|
+
pos_cycle2,
|
3079
|
+
nodes1,
|
3080
|
+
nodes2,
|
3081
|
+
ax,
|
3082
|
+
x_idx,
|
3083
|
+
y_idx,
|
3084
|
+
x_label,
|
3085
|
+
y_label,
|
3086
|
+
):
|
3087
|
+
ax.plot(
|
3088
|
+
pos_cycle1[:, x_idx],
|
3089
|
+
pos_cycle1[:, y_idx],
|
3090
|
+
"-",
|
3091
|
+
label=f"root = {nodes1}",
|
3092
|
+
)
|
3093
|
+
ax.plot(
|
3094
|
+
pos_cycle2[:, x_idx],
|
3095
|
+
pos_cycle2[:, y_idx],
|
3096
|
+
"-",
|
3097
|
+
label=f"root = {nodes2}",
|
3098
|
+
)
|
3099
|
+
ax.set_xlabel(x_label)
|
3100
|
+
ax.set_ylabel(y_label)
|
3101
|
+
|
3102
|
+
def plot_dtw_trajectory(
    self,
    nodes1: int,
    nodes2: int,
    threshold: int = 1000,
    regist: bool = True,
    start_d: int = 0,
    back_d: int = 0,
    fast: bool = False,
    w: int = 0,
    centered_band: bool = True,
    projection: str = None,
    alig: bool = False,
) -> (float, plt.figure):
    """
    Plots DTW trajectories alignment between two cell cycles in 2D or 3D

    Args:
        nodes1 (int): node to compare distance
        nodes2 (int): node to compare distance
        threshold (int): set a maximum number of points a track can have
        regist (boolean): Rotate and translate trajectories
        start_d (int): start delay
        back_d (int): end delay
        w (int): window size
        fast (boolean): True if the user wants to run the fast algorithm with window restrains
        centered_band (boolean): if running the fast algorithm, True if the window is centered
        projection (string): specify which 2D to plot ->
            '3d' : for the 3d visualization
            'xy' or None (default) : 2D projection of axis x and y
            'xz' : 2D projection of axis x and z
            'yz' : 2D projection of axis y and z
            'pca' : PCA projection
            ('yx', 'zx' and 'zy' are accepted as aliases of 'xy', 'xz'
            and 'yz' respectively.)
        alig (boolean): True to show alignment on plot (dashed grey lines
            between matched points). Not supported for 'pca'; a
            UserWarning is emitted instead.

    Returns:
        (float) DTW distance
        (figure) Trajectories Plot

    Raises:
        ValueError: if `projection` is not one of the values listed above.
        ImportError: if `projection` is 'pca' and scikit-learn is missing.
    """
    # Axis-aligned 2D projections: (column on the horizontal axis, column on
    # the vertical axis, horizontal label, vertical label).
    xy = (0, 1, "x position", "y position")
    xz = (0, 2, "x position", "z position")
    yz = (1, 2, "y position", "z position")
    planar_projections = {
        None: xy,
        "xy": xy,
        "yx": xy,
        "xz": xz,
        "zx": xz,
        "yz": yz,
        "zy": yz,
    }

    # Validate up front so an invalid projection neither runs the DTW for
    # nothing nor leaves an open, unused matplotlib figure behind.
    if projection not in ("3d", "pca", *planar_projections):
        raise ValueError(
            "Error: available projections are:\n"
            "'3d' : for the 3d visualization\n"
            "'xy' or None (default) : 2D projection of axis x and y\n"
            "'xz' : 2D projection of axis x and z\n"
            "'yz' : 2D projection of axis y and z\n"
            "'pca' : PCA projection"
        )

    if projection == "pca":
        # scikit-learn is optional: fail loudly here rather than later with
        # a NameError on `PCA`.
        try:
            from sklearn.decomposition import PCA
        except ImportError as err:
            raise ImportError(
                "scikit-learn is not installed, the PCA orientation cannot be "
                "used. You can install it with `pip install scikit-learn`."
            ) from err

    (
        distance,
        alignment,
        _cost_mat,
        pos_cycle1,
        pos_cycle2,
    ) = self.calculate_dtw(
        nodes1,
        nodes2,
        threshold,
        regist,
        start_d,
        back_d,
        fast,
        w,
        centered_band,
        cost_mat_p=True,
    )

    fig = plt.figure(figsize=(10, 6))

    if projection == "3d":
        ax = fig.add_subplot(1, 1, 1, projection="3d")
        ax.plot(
            pos_cycle1[:, 0],
            pos_cycle1[:, 1],
            pos_cycle1[:, 2],
            "-",
            label=f"root = {nodes1}",
        )
        ax.plot(
            pos_cycle2[:, 0],
            pos_cycle2[:, 1],
            pos_cycle2[:, 2],
            "-",
            label=f"root = {nodes2}",
        )
        ax.set_ylabel("y position")
        ax.set_xlabel("x position")
        ax.set_zlabel("z position")
    elif projection == "pca":
        ax = fig.add_subplot(1, 1, 1)
        # Fit one PCA on both trajectories so they share the same 2D basis.
        pca = PCA(n_components=2)
        pca.fit(np.vstack([pos_cycle1, pos_cycle2]))
        pos_cycle1_2d = pca.transform(pos_cycle1)
        pos_cycle2_2d = pca.transform(pos_cycle2)

        ax.plot(
            pos_cycle1_2d[:, 0],
            pos_cycle1_2d[:, 1],
            "-",
            label=f"root = {nodes1}",
        )
        ax.plot(
            pos_cycle2_2d[:, 0],
            pos_cycle2_2d[:, 1],
            "-",
            label=f"root = {nodes2}",
        )

        # Label each PCA axis with the original axis that dominates the
        # component and that axis' share of the component's absolute loadings.
        axes = ["x", "y", "z"]
        x_label = axes[np.argmax(np.abs(pca.components_[0]))]
        y_label = axes[np.argmax(np.abs(pca.components_[1]))]
        x_percent = 100 * (
            np.max(np.abs(pca.components_[0]))
            / np.sum(np.abs(pca.components_[0]))
        )
        y_percent = 100 * (
            np.max(np.abs(pca.components_[1]))
            / np.sum(np.abs(pca.components_[1]))
        )
        ax.set_xlabel(f"{x_percent:.0f}% of {x_label} position")
        ax.set_ylabel(f"{y_percent:.0f}% of {y_label} position")
    else:
        ax = fig.add_subplot(1, 1, 1)
        x_idx, y_idx, x_label, y_label = planar_projections[projection]
        self.__plot_2d(
            pos_cycle1,
            pos_cycle2,
            nodes1,
            nodes2,
            ax,
            x_idx,
            y_idx,
            x_label,
            y_label,
        )

    if alig and projection != "pca":
        # One dashed grey segment per matched (i, j) pair. In 2D the segment
        # uses the same columns as the projection being shown.
        for i, j in alignment:
            start, end = pos_cycle1[i], pos_cycle2[j]
            if projection == "3d":
                ax.plot(
                    [start[0], end[0]],
                    [start[1], end[1]],
                    [start[2], end[2]],
                    "--",
                    color="grey",
                )
            else:
                ax.plot(
                    [start[x_idx], end[x_idx]],
                    [start[y_idx], end[y_idx]],
                    "--",
                    color="grey",
                )

    ax.set_aspect("equal")
    ax.legend()
    fig.tight_layout()

    if alig and projection == "pca":
        warnings.warn(
            "Error: not possible to show alignment in PCA projection !",
            UserWarning,
        )

    return distance, fig
|
3299
|
+
|
2237
3300
|
def first_labelling(self):
    """Label every node of time point 0 as ``"Unlabeled"``.

    Replaces ``self.labels`` with a new dict that maps each node in
    ``self.time_nodes[0]`` to the string ``"Unlabeled"``. Raises KeyError
    if ``self.time_nodes`` has no entry for time 0.
    """
    # dict.fromkeys keeps the iteration order of time_nodes[0]; sharing the
    # single immutable string value between all keys is harmless.
    self.labels = dict.fromkeys(self.time_nodes[0], "Unlabeled")
|
2239
3302
|
|
2240
3303
|
def __init__(
|
2241
3304
|
self,
|
@@ -2268,12 +3331,12 @@ class lineageTree:
|
|
2268
3331
|
'TGMM, 'ASTEC', MaMuT', 'TrackMate', 'csv', 'celegans', 'binary'
|
2269
3332
|
default is 'binary'
|
2270
3333
|
"""
|
3334
|
+
self.name = name
|
2271
3335
|
self.time_nodes = {}
|
2272
3336
|
self.time_edges = {}
|
2273
3337
|
self.max_id = -1
|
2274
3338
|
self.next_id = []
|
2275
3339
|
self.nodes = set()
|
2276
|
-
self.edges = set()
|
2277
3340
|
self.successor = {}
|
2278
3341
|
self.predecessor = {}
|
2279
3342
|
self.pos = {}
|
@@ -2281,40 +3344,57 @@ class lineageTree:
|
|
2281
3344
|
self.time = {}
|
2282
3345
|
self.kdtrees = {}
|
2283
3346
|
self.spatial_density = {}
|
2284
|
-
|
2285
|
-
|
2286
|
-
|
2287
|
-
self.xml_attributes = []
|
2288
|
-
else:
|
2289
|
-
self.xml_attributes = xml_attributes
|
2290
|
-
file_type = file_type.lower()
|
2291
|
-
if file_type == "tgmm":
|
2292
|
-
self.read_tgmm_xml(file_format, tb, te, z_mult)
|
2293
|
-
self.t_b = tb
|
2294
|
-
self.t_e = te
|
2295
|
-
elif file_type == "mamut" or file_type == "trackmate":
|
2296
|
-
self.read_from_mamut_xml(file_format)
|
2297
|
-
elif file_type == "celegans":
|
2298
|
-
self.read_from_txt_for_celegans(file_format)
|
2299
|
-
elif file_type == "celegans_cao":
|
2300
|
-
self.read_from_txt_for_celegans_CAO(
|
2301
|
-
file_format, reorder=reorder, shape=shape, raw_size=raw_size
|
2302
|
-
)
|
2303
|
-
elif file_type == "mastodon":
|
2304
|
-
if isinstance(file_format, list) and len(file_format) == 2:
|
2305
|
-
self.read_from_mastodon_csv(file_format)
|
3347
|
+
if file_type and file_format:
|
3348
|
+
if xml_attributes is None:
|
3349
|
+
self.xml_attributes = []
|
2306
3350
|
else:
|
2307
|
-
|
2308
|
-
|
2309
|
-
|
2310
|
-
|
2311
|
-
|
2312
|
-
|
2313
|
-
|
2314
|
-
|
2315
|
-
|
2316
|
-
|
2317
|
-
|
2318
|
-
|
2319
|
-
|
2320
|
-
|
3351
|
+
self.xml_attributes = xml_attributes
|
3352
|
+
file_type = file_type.lower()
|
3353
|
+
if file_type == "tgmm":
|
3354
|
+
self.read_tgmm_xml(file_format, tb, te, z_mult)
|
3355
|
+
self.t_b = tb
|
3356
|
+
self.t_e = te
|
3357
|
+
elif file_type == "mamut" or file_type == "trackmate":
|
3358
|
+
self.read_from_mamut_xml(file_format)
|
3359
|
+
elif file_type == "celegans":
|
3360
|
+
self.read_from_txt_for_celegans(file_format)
|
3361
|
+
elif file_type == "celegans_cao":
|
3362
|
+
self.read_from_txt_for_celegans_CAO(
|
3363
|
+
file_format,
|
3364
|
+
reorder=reorder,
|
3365
|
+
shape=shape,
|
3366
|
+
raw_size=raw_size,
|
3367
|
+
)
|
3368
|
+
elif file_type == "mastodon":
|
3369
|
+
if isinstance(file_format, list) and len(file_format) == 2:
|
3370
|
+
self.read_from_mastodon_csv(file_format)
|
3371
|
+
else:
|
3372
|
+
if isinstance(file_format, list):
|
3373
|
+
file_format = file_format[0]
|
3374
|
+
self.read_from_mastodon(file_format, name)
|
3375
|
+
elif file_type == "astec":
|
3376
|
+
self.read_from_ASTEC(file_format, eigen)
|
3377
|
+
elif file_type == "csv":
|
3378
|
+
self.read_from_csv(file_format, z_mult, link=1, delim=delim)
|
3379
|
+
elif file_format and file_format.endswith(".lT"):
|
3380
|
+
with open(file_format, "br") as f:
|
3381
|
+
tmp = pkl.load(f)
|
3382
|
+
f.close()
|
3383
|
+
self.__dict__.update(tmp.__dict__)
|
3384
|
+
elif file_format is not None:
|
3385
|
+
self.read_from_binary(file_format)
|
3386
|
+
if self.name is None:
|
3387
|
+
try:
|
3388
|
+
self.name = Path(file_format).stem
|
3389
|
+
except:
|
3390
|
+
self.name = Path(file_format[0]).stem
|
3391
|
+
if [] in self.successor.values():
|
3392
|
+
successors = list(self.successor.keys())
|
3393
|
+
for succ in successors:
|
3394
|
+
if self[succ] == []:
|
3395
|
+
self.successor.pop(succ)
|
3396
|
+
if [] in self.predecessor.values():
|
3397
|
+
predecessors = list(self.predecessor.keys())
|
3398
|
+
for succ in predecessors:
|
3399
|
+
if self[succ] == []:
|
3400
|
+
self.predecessor.pop(succ)
|