LineageTree 1.4.3__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- LineageTree/__init__.py +3 -2
- LineageTree/lineageTree.py +1351 -262
- LineageTree/lineageTreeManager.py +170 -0
- LineageTree/tree_styles.py +305 -0
- LineageTree/utils.py +211 -0
- {LineageTree-1.4.3.dist-info → LineageTree-1.5.0.dist-info}/METADATA +21 -26
- LineageTree-1.5.0.dist-info/RECORD +10 -0
- {LineageTree-1.4.3.dist-info → LineageTree-1.5.0.dist-info}/WHEEL +1 -1
- LineageTree-1.4.3.dist-info/RECORD +0 -7
- {LineageTree-1.4.3.dist-info → LineageTree-1.5.0.dist-info}/LICENSE +0 -0
- {LineageTree-1.4.3.dist-info → LineageTree-1.5.0.dist-info}/top_level.txt +0 -0
LineageTree/lineageTree.py
CHANGED
@@ -2,35 +2,244 @@
|
|
2
2
|
# This file is subject to the terms and conditions defined in
|
3
3
|
# file 'LICENCE', which is part of this source code package.
|
4
4
|
# Author: Leo Guignard (leo.guignard...@AT@...gmail.com)
|
5
|
-
|
6
5
|
import csv
|
7
6
|
import os
|
8
7
|
import pickle as pkl
|
9
8
|
import struct
|
9
|
+
import warnings
|
10
10
|
import xml.etree.ElementTree as ET
|
11
|
+
from collections.abc import Iterable
|
11
12
|
from functools import partial
|
12
13
|
from itertools import combinations
|
13
14
|
from numbers import Number
|
14
|
-
from
|
15
|
-
|
15
|
+
from pathlib import Path
|
16
|
+
from typing import TextIO, Union
|
17
|
+
|
18
|
+
from .tree_styles import tree_style
|
19
|
+
|
20
|
+
try:
|
21
|
+
from edist import uted
|
22
|
+
except ImportError:
|
23
|
+
warnings.warn(
|
24
|
+
"No edist installed therefore you will not be able to compute the tree edit distance."
|
25
|
+
)
|
26
|
+
import matplotlib.pyplot as plt
|
27
|
+
import networkx as nx
|
16
28
|
import numpy as np
|
17
|
-
from scipy.
|
29
|
+
from scipy.interpolate import InterpolatedUnivariateSpline
|
30
|
+
from scipy.spatial import Delaunay, distance
|
18
31
|
from scipy.spatial import cKDTree as KDTree
|
19
32
|
|
33
|
+
from .utils import hierarchy_pos, postions_of_nx
|
34
|
+
|
20
35
|
|
21
36
|
class lineageTree:
|
37
|
+
def __eq__(self, other):
|
38
|
+
if isinstance(other, lineageTree):
|
39
|
+
return other.successor == self.successor
|
40
|
+
return False
|
41
|
+
|
22
42
|
def get_next_id(self):
|
23
43
|
"""Computes the next authorized id.
|
24
44
|
|
25
45
|
Returns:
|
26
46
|
int: next authorized id
|
27
47
|
"""
|
48
|
+
if self.max_id == -1 and self.nodes:
|
49
|
+
self.max_id = max(self.nodes)
|
28
50
|
if self.next_id == []:
|
29
51
|
self.max_id += 1
|
30
52
|
return self.max_id
|
31
53
|
else:
|
32
54
|
return self.next_id.pop()
|
33
55
|
|
56
|
+
def complete_lineage(self, nodes: Union[int, set] = None):
|
57
|
+
"""Makes all leaf branches longer so that they reach the last timepoint( self.t_e), useful
|
58
|
+
for tree edit distance algorithms.
|
59
|
+
|
60
|
+
Args:
|
61
|
+
nodes (int | set, optional): Root(s) of the trees that should be "completed"; if None, all the trees of the dataset (self.roots) are completed. Defaults to None.
|
62
|
+
"""
|
63
|
+
if nodes is None:
|
64
|
+
nodes = set(self.roots)
|
65
|
+
elif isinstance(nodes, int):
|
66
|
+
nodes = {nodes}
|
67
|
+
for node in nodes:
|
68
|
+
sub = set(self.get_sub_tree(node))
|
69
|
+
specific_leaves = sub.intersection(self.leaves)
|
70
|
+
for leaf in specific_leaves:
|
71
|
+
self.add_branch(leaf, self.t_e - self.time[leaf], reverse=True)
|
72
|
+
|
73
|
+
###TODO pos can be callable and stay motionless (copy the position of the succ node, use something like optical flow)
|
74
|
+
def add_branch(
|
75
|
+
self,
|
76
|
+
pred: int,
|
77
|
+
length: int,
|
78
|
+
move_timepoints: bool = True,
|
79
|
+
pos: Union[callable, None] = None,
|
80
|
+
reverse: bool = False,
|
81
|
+
):
|
82
|
+
"""Adds a branch of specific length to a node either as a successor or as a predecessor.
|
83
|
+
If it is placed on top of a tree, all the nodes of that tree are moved `length` timepoints later (when move_timepoints is True).
|
84
|
+
|
85
|
+
Args:
|
86
|
+
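pred (int): Id of the node the branch is attached to. The branch is added before it (as predecessors) if reverse is False, after it (as successors) if reverse is True.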
|
87
|
+
length (int): The length of the new branch.
|
88
|
+
pos (callable, optional): Currently unused; every new node copies the position of `pred`. Defaults to None.
|
89
|
+
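move_timepoints (bool): If True, the nodes of the subtree of `pred` are shifted `length` timepoints later to make room for the new branch. Only used when reverse is False.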
|
90
|
+
reverse (bool): If True, a successor branch is added instead of a predecessor branch.
|
91
|
+
Returns:
|
92
|
+
(int): Id of the first node of the sublineage.
|
93
|
+
"""
|
94
|
+
if length == 0:
|
95
|
+
return pred
|
96
|
+
if self.predecessor.get(pred) and not reverse:
|
97
|
+
raise Warning("Cannot add 2 predecessors to a node")
|
98
|
+
time = self.time[pred]
|
99
|
+
original = pred
|
100
|
+
if not reverse:
|
101
|
+
if move_timepoints:
|
102
|
+
nodes_to_move = set(self.get_sub_tree(pred))
|
103
|
+
new_times = {
|
104
|
+
node: self.time[node] + length for node in nodes_to_move
|
105
|
+
}
|
106
|
+
for node in nodes_to_move:
|
107
|
+
old_time = self.time[node]
|
108
|
+
self.time_nodes[old_time].remove(node)
|
109
|
+
self.time_nodes.setdefault(old_time + length, set()).add(
|
110
|
+
node
|
111
|
+
)
|
112
|
+
self.time.update(new_times)
|
113
|
+
for t in range(length - 1, -1, -1):
|
114
|
+
_next = self.add_node(
|
115
|
+
time + t,
|
116
|
+
succ=pred,
|
117
|
+
pos=self.pos[original],
|
118
|
+
reverse=True,
|
119
|
+
)
|
120
|
+
pred = _next
|
121
|
+
else:
|
122
|
+
for t in range(length):
|
123
|
+
_next = self.add_node(
|
124
|
+
time - t,
|
125
|
+
succ=pred,
|
126
|
+
pos=self.pos[original],
|
127
|
+
reverse=True,
|
128
|
+
)
|
129
|
+
pred = _next
|
130
|
+
else:
|
131
|
+
for t in range(length):
|
132
|
+
_next = self.add_node(
|
133
|
+
time + t, succ=pred, pos=self.pos[original], reverse=False
|
134
|
+
)
|
135
|
+
pred = _next
|
136
|
+
self.labels[pred] = "New branch"
|
137
|
+
if self.time[pred] == self.t_b:
|
138
|
+
self.roots.add(pred)
|
139
|
+
self.labels[pred] = "New branch"
|
140
|
+
if original in self.roots and reverse is True:
|
141
|
+
self.roots.add(pred)
|
142
|
+
self.labels[pred] = "New branch"
|
143
|
+
self.roots.remove(original)
|
144
|
+
self.labels.pop(original, -1)
|
145
|
+
self.t_e = max(self.time_nodes)
|
146
|
+
return pred
|
147
|
+
|
148
|
+
def cut_tree(self, root):
|
149
|
+
"""It transforms a lineage that has at least 2 divisions into 2 independent lineages,
|
150
|
+
that spawn from the time point of the first node. (splits a tree into 2)
|
151
|
+
|
152
|
+
Args:
|
153
|
+
root (int): The id of the node, which will be cut.
|
154
|
+
|
155
|
+
Returns:
|
156
|
+
int: The id of the new tree
|
157
|
+
"""
|
158
|
+
cycle = self.get_successors(root)
|
159
|
+
last_cell = cycle[-1]
|
160
|
+
if last_cell in self.successor:
|
161
|
+
new_lT = self.successor[last_cell].pop()
|
162
|
+
self.predecessor.pop(new_lT)
|
163
|
+
label_of_root = self.labels.get(cycle[0], cycle[0])
|
164
|
+
self.labels[cycle[0]] = f"L-Split {label_of_root}"
|
165
|
+
new_tr = self.add_branch(
|
166
|
+
new_lT, len(cycle) + 1, move_timepoints=False
|
167
|
+
)
|
168
|
+
self.roots.add(new_tr)
|
169
|
+
self.labels[new_tr] = f"R-Split {label_of_root}"
|
170
|
+
return new_tr
|
171
|
+
else:
|
172
|
+
raise Warning("No division of the branch")
|
173
|
+
|
174
|
+
def fuse_lineage_tree(
|
175
|
+
self,
|
176
|
+
l1_root: int,
|
177
|
+
l2_root: int,
|
178
|
+
length_l1: int = 0,
|
179
|
+
length_l2: int = 0,
|
180
|
+
length: int = 1,
|
181
|
+
):
|
182
|
+
"""Fuses 2 lineages from the lineagetree object. The 2 lineages that are to be fused can have a longer
|
183
|
+
first node and the node of the resulting lineage can also be longer.
|
184
|
+
|
185
|
+
Args:
|
186
|
+
l1_root (int): Id of the first root
|
187
|
+
l2_root (int): Id of the second root
|
188
|
+
length_l1 (int, optional): The length of the branch that will be added on top of the first lineage. Defaults to 0, which means only one node will be added.
|
189
|
+
length_l2 (int, optional): The length of the branch that will be added on top of the second lineage. Defaults to 0, which means only one node will be added.
|
190
|
+
length (int, optional): The length of the branch that will be added on top of the resulting lineage. Defaults to 1.
|
191
|
+
|
192
|
+
Returns:
|
193
|
+
int: The id of the root of the new lineage.
|
194
|
+
"""
|
195
|
+
if self.predecessor.get(l1_root) or self.predecessor.get(l2_root):
|
196
|
+
raise ValueError("Please select 2 roots.")
|
197
|
+
if self.time[l1_root] != self.time[l2_root]:
|
198
|
+
warnings.warn(
|
199
|
+
"Using lineagetrees that do not exist in the same timepoint. The operation will continue"
|
200
|
+
)
|
201
|
+
new_root1 = self.add_branch(l1_root, length_l1)
|
202
|
+
new_root2 = self.add_branch(l2_root, length_l2)
|
203
|
+
next_root1 = self[new_root1][0]
|
204
|
+
self.remove_nodes(new_root1)
|
205
|
+
self.successor[new_root2].append(next_root1)
|
206
|
+
self.predecessor[next_root1] = [new_root2]
|
207
|
+
new_branch = self.add_branch(new_root2, length)
|
208
|
+
self.labels[new_branch] = f"Fusion of {new_root1} and {new_root2}"
|
209
|
+
return new_branch
|
210
|
+
|
211
|
+
def copy_lineage(self, root):
|
212
|
+
"""
|
213
|
+
Copies the structure of a tree and makes a new one with new nodes.
|
214
|
+
Warning: does not take into account the predecessor of the root node.
|
215
|
+
|
216
|
+
Args:
|
217
|
+
root (int): The root of the tree to be copied
|
218
|
+
|
219
|
+
Returns:
|
220
|
+
int: The root of the new tree.
|
221
|
+
"""
|
222
|
+
new_nodes = {
|
223
|
+
old_node: self.get_next_id()
|
224
|
+
for old_node in self.get_sub_tree(root)
|
225
|
+
}
|
226
|
+
self.nodes.update(new_nodes.values())
|
227
|
+
for old_node, new_node in new_nodes.items():
|
228
|
+
self.time[new_node] = self.time[old_node]
|
229
|
+
succ = self.successor.get(old_node)
|
230
|
+
if succ:
|
231
|
+
self.successor[new_node] = [new_nodes[n] for n in succ]
|
232
|
+
pred = self.predecessor.get(old_node)
|
233
|
+
if pred:
|
234
|
+
self.predecessor[new_node] = [new_nodes[n] for n in pred]
|
235
|
+
self.pos[new_node] = self.pos[old_node] + 0.5
|
236
|
+
self.time_nodes[self.time[old_node]].add(new_nodes[old_node])
|
237
|
+
new_root = new_nodes[root]
|
238
|
+
self.labels[new_root] = f"Copy of {root}"
|
239
|
+
if self.time[new_root] == 0:
|
240
|
+
self.roots.add(new_root)
|
241
|
+
return new_root
|
242
|
+
|
34
243
|
def add_node(
|
35
244
|
self,
|
36
245
|
t: int = None,
|
@@ -55,109 +264,126 @@ class lineageTree:
|
|
55
264
|
int: id of the new node.
|
56
265
|
"""
|
57
266
|
C_next = self.get_next_id() if nid is None else nid
|
58
|
-
self.time_nodes.setdefault(t,
|
267
|
+
self.time_nodes.setdefault(t, set()).add(C_next)
|
59
268
|
if succ is not None and not reverse:
|
60
269
|
self.successor.setdefault(succ, []).append(C_next)
|
61
270
|
self.predecessor.setdefault(C_next, []).append(succ)
|
62
|
-
self.edges.add((succ, C_next))
|
63
271
|
elif succ is not None:
|
64
272
|
self.predecessor.setdefault(succ, []).append(C_next)
|
65
273
|
self.successor.setdefault(C_next, []).append(succ)
|
66
|
-
self.edges.add((C_next, succ))
|
67
274
|
self.nodes.add(C_next)
|
68
275
|
self.pos[C_next] = pos
|
69
|
-
self.progeny[C_next] = 0
|
70
276
|
self.time[C_next] = t
|
71
277
|
return C_next
|
72
278
|
|
73
|
-
def
|
74
|
-
|
75
|
-
times = {self.time[n] for n in track}
|
76
|
-
for t in times:
|
77
|
-
self.time_nodes[t] = list(
|
78
|
-
set(self.time_nodes[t]).difference(track)
|
79
|
-
)
|
80
|
-
for i, c in enumerate(track):
|
81
|
-
self.pos.pop(c)
|
82
|
-
if i != 0:
|
83
|
-
self.predecessor.pop(c)
|
84
|
-
if i < len(track) - 1:
|
85
|
-
self.successor.pop(c)
|
86
|
-
self.time.pop(c)
|
87
|
-
|
88
|
-
def remove_node(self, c: int) -> tuple:
|
89
|
-
"""Removes a node and update the lineageTree accordingly
|
279
|
+
def remove_nodes(self, group: Union[int, set, list]):
|
280
|
+
"""Removes a group of nodes from the LineageTree
|
90
281
|
|
91
282
|
Args:
|
92
|
-
|
93
|
-
"""
|
94
|
-
|
95
|
-
|
96
|
-
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
-
self.
|
103
|
-
|
104
|
-
|
105
|
-
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
|
110
|
-
|
111
|
-
|
112
|
-
|
113
|
-
|
114
|
-
|
115
|
-
|
116
|
-
|
117
|
-
|
118
|
-
|
119
|
-
def
|
120
|
-
"""
|
121
|
-
|
283
|
+
group (set|list|int): One or more nodes that are to be removed.
|
284
|
+
"""
|
285
|
+
if isinstance(group, int):
|
286
|
+
group = {group}
|
287
|
+
if isinstance(group, list):
|
288
|
+
group = set(group)
|
289
|
+
group = group.intersection(self.nodes)
|
290
|
+
self.nodes.difference_update(group)
|
291
|
+
times = {self.time.pop(n) for n in group}
|
292
|
+
for t in times:
|
293
|
+
self.time_nodes[t] = set(self.time_nodes[t]).difference(group)
|
294
|
+
for node in group:
|
295
|
+
self.pos.pop(node)
|
296
|
+
if self.predecessor.get(node):
|
297
|
+
pred = self.predecessor[node][0]
|
298
|
+
siblings = self.successor.pop(pred, [])
|
299
|
+
if len(siblings) == 2:
|
300
|
+
siblings.remove(node)
|
301
|
+
self.successor[pred] = siblings
|
302
|
+
self.predecessor.pop(node, [])
|
303
|
+
for succ in self.successor.get(node, []):
|
304
|
+
self.predecessor.pop(succ, [])
|
305
|
+
self.successor.pop(node, [])
|
306
|
+
self.labels.pop(node, 0)
|
307
|
+
if node in self.roots:
|
308
|
+
self.roots.remove(node)
|
309
|
+
|
310
|
+
def modify_branch(self, node, new_length):
|
311
|
+
"""Changes the length of a branch, so it adds or removes nodes
|
312
|
+
to make the correct length of the cycle.
|
122
313
|
|
123
314
|
Args:
|
124
|
-
|
125
|
-
|
315
|
+
node (int): Any node of the branch to be modified.
|
316
|
+
new_length (int): The new length of the branch (number of nodes in the cycle); must be greater than 1.
|
126
317
|
"""
|
127
|
-
|
128
|
-
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
|
136
|
-
|
137
|
-
|
138
|
-
|
139
|
-
|
140
|
-
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
|
146
|
-
self.
|
147
|
-
|
148
|
-
|
149
|
-
|
318
|
+
if new_length <= 1:
|
319
|
+
warnings.warn("New length should be more than 1")
|
320
|
+
return None
|
321
|
+
cycle = self.get_cycle(node)
|
322
|
+
length = len(cycle)
|
323
|
+
successors = self.successor.get(cycle[-1])
|
324
|
+
if length == 1 and new_length != 1:
|
325
|
+
pred = self.predecessor.pop(node, None)
|
326
|
+
new_node = self.add_branch(
|
327
|
+
node, length=new_length, move_timepoints=True, reverse=False
|
328
|
+
)
|
329
|
+
if pred:
|
330
|
+
self.successor[pred[0]].remove(node)
|
331
|
+
self.successor[pred[0]].append(new_node)
|
332
|
+
elif self.leaves.intersection(cycle) and new_length < length:
|
333
|
+
self.remove_nodes(cycle[new_length:])
|
334
|
+
elif new_length < length:
|
335
|
+
to_remove = length - new_length
|
336
|
+
last_cell = cycle[new_length - 1]
|
337
|
+
subtree = self.get_sub_tree(cycle[-1])[1:]
|
338
|
+
self.remove_nodes(cycle[new_length:])
|
339
|
+
self.successor[last_cell] = successors
|
340
|
+
if successors:
|
341
|
+
for succ in successors:
|
342
|
+
self.predecessor[succ] = [last_cell]
|
343
|
+
for node in subtree:
|
344
|
+
if node not in cycle[new_length - 1 :]:
|
345
|
+
old_time = self.time[node]
|
346
|
+
self.time[node] = old_time - to_remove
|
347
|
+
self.time_nodes[old_time].remove(node)
|
348
|
+
self.time_nodes.setdefault(
|
349
|
+
old_time - to_remove, set()
|
350
|
+
).add(node)
|
351
|
+
elif length < new_length:
|
352
|
+
to_add = new_length - length
|
353
|
+
last_cell = cycle[-1]
|
354
|
+
self.successor.pop(cycle[-2])
|
355
|
+
self.predecessor.pop(last_cell)
|
356
|
+
succ = self.add_branch(
|
357
|
+
last_cell, length=to_add, move_timepoints=True, reverse=False
|
358
|
+
)
|
359
|
+
self.predecessor[succ] = [cycle[-2]]
|
360
|
+
self.successor[cycle[-2]] = [succ]
|
361
|
+
self.time[last_cell] = (
|
362
|
+
self.time[self.predecessor[last_cell][0]] + 1
|
363
|
+
)
|
364
|
+
else:
|
365
|
+
return None
|
150
366
|
|
151
367
|
@property
|
152
368
|
def roots(self):
|
153
369
|
if not hasattr(self, "_roots"):
|
154
|
-
self._roots = set(self.
|
370
|
+
self._roots = set(self.nodes).difference(self.predecessor)
|
155
371
|
return self._roots
|
156
372
|
|
373
|
+
@property
|
374
|
+
def edges(self):
|
375
|
+
return {(k, vi) for k, v in self.successor.items() for vi in v}
|
376
|
+
|
157
377
|
@property
|
158
378
|
def leaves(self):
|
159
379
|
return set(self.predecessor).difference(self.successor)
|
160
380
|
|
381
|
+
@property
|
382
|
+
def labels(self):
|
383
|
+
if not hasattr(self, "_labels"):
|
384
|
+
self._labels = {i: "Unlabeled" for i in self.roots}
|
385
|
+
return self._labels
|
386
|
+
|
161
387
|
def _write_header_am(self, f: TextIO, nb_points: int, length: int):
|
162
388
|
"""Header for Amira .am files"""
|
163
389
|
f.write("# AmiraMesh 3D ASCII 2.0\n")
|
@@ -470,7 +696,7 @@ class lineageTree:
|
|
470
696
|
stroke=svgwrite.rgb(0, 0, 0),
|
471
697
|
)
|
472
698
|
)
|
473
|
-
for si in self
|
699
|
+
for si in self[c_cycle[-1]]:
|
474
700
|
x3, y3 = positions[si]
|
475
701
|
dwg.add(
|
476
702
|
dwg.line(
|
@@ -483,7 +709,7 @@ class lineageTree:
|
|
483
709
|
else:
|
484
710
|
for c in treated_cells:
|
485
711
|
x1, y1 = positions[c]
|
486
|
-
for si in self
|
712
|
+
for si in self[c]:
|
487
713
|
x2, y2 = positions[si]
|
488
714
|
if draw_edges:
|
489
715
|
dwg.add(
|
@@ -535,7 +761,7 @@ class lineageTree:
|
|
535
761
|
start_time = times_to_consider[0]
|
536
762
|
for t in times_to_consider:
|
537
763
|
for id_mother in self.time_nodes[t]:
|
538
|
-
ids_daughters = self
|
764
|
+
ids_daughters = self[id_mother]
|
539
765
|
new_ids_daughters = ids_daughters.copy()
|
540
766
|
for _ in range(sampling - 1):
|
541
767
|
tmp = []
|
@@ -659,12 +885,12 @@ class lineageTree:
|
|
659
885
|
edges_to_use += list(s_edges)
|
660
886
|
else:
|
661
887
|
edges_to_use = []
|
888
|
+
nodes_to_use = set(nodes_to_use)
|
662
889
|
if temporal:
|
663
|
-
|
664
|
-
|
665
|
-
|
666
|
-
|
667
|
-
]
|
890
|
+
for n in nodes_to_use:
|
891
|
+
for d in self.successor.get(n, []):
|
892
|
+
if d in nodes_to_use:
|
893
|
+
edges_to_use.append((n, d))
|
668
894
|
if spatial:
|
669
895
|
edges_to_use += [
|
670
896
|
e for e in s_edges if t_min < self.time[e[0]] < t_max
|
@@ -787,7 +1013,6 @@ class lineageTree:
|
|
787
1013
|
self.time_edges = {}
|
788
1014
|
unique_id = 0
|
789
1015
|
self.nodes = set()
|
790
|
-
self.edges = set()
|
791
1016
|
self.successor = {}
|
792
1017
|
self.predecessor = {}
|
793
1018
|
self.pos = {}
|
@@ -826,7 +1051,6 @@ class lineageTree:
|
|
826
1051
|
M = corres[pred]
|
827
1052
|
self.predecessor[C] = [M]
|
828
1053
|
self.successor.setdefault(M, []).append(C)
|
829
|
-
self.edges.add((M, C))
|
830
1054
|
self.time_edges.setdefault(t, set()).add((M, C))
|
831
1055
|
self.lin.setdefault(lin_id, []).append(C)
|
832
1056
|
self.C_lin[C] = lin_id
|
@@ -924,7 +1148,12 @@ class lineageTree:
|
|
924
1148
|
|
925
1149
|
# make sure these are all named like they are in tmp_data (or change dictionary above)
|
926
1150
|
self.name = {}
|
927
|
-
|
1151
|
+
if "cell_volume" in tmp_data:
|
1152
|
+
self.volume = {}
|
1153
|
+
if "cell_fate" in tmp_data:
|
1154
|
+
self.fates = {}
|
1155
|
+
if "cell_barycenter" in tmp_data:
|
1156
|
+
self.pos = {}
|
928
1157
|
self.lT2pkl = {}
|
929
1158
|
self.pkl2lT = {}
|
930
1159
|
self.contact = {}
|
@@ -955,8 +1184,11 @@ class lineageTree:
|
|
955
1184
|
if "cell_volume" in tmp_data:
|
956
1185
|
self.volume[unique_id] = tmp_data["cell_volume"].get(n, 0.0)
|
957
1186
|
if "cell_fate" in tmp_data:
|
958
|
-
self.fates = {}
|
959
1187
|
self.fates[unique_id] = tmp_data["cell_fate"].get(n, "")
|
1188
|
+
if "cell_barycenter" in tmp_data:
|
1189
|
+
self.pos[unique_id] = tmp_data["cell_barycenter"].get(
|
1190
|
+
n, np.zeros(3)
|
1191
|
+
)
|
960
1192
|
|
961
1193
|
unique_id += 1
|
962
1194
|
if do_surf:
|
@@ -975,9 +1207,7 @@ class lineageTree:
|
|
975
1207
|
self.successor[new_id] = [
|
976
1208
|
self.pkl2lT[ni] for ni in lt[n] if ni in self.pkl2lT
|
977
1209
|
]
|
978
|
-
|
979
|
-
[(new_id, ni) for ni in self.successor[new_id]]
|
980
|
-
)
|
1210
|
+
|
981
1211
|
for ni in self.successor[new_id]:
|
982
1212
|
self.time_edges.setdefault(t - 1, set()).add((new_id, ni))
|
983
1213
|
|
@@ -986,30 +1216,43 @@ class lineageTree:
|
|
986
1216
|
self.max_id = unique_id
|
987
1217
|
|
988
1218
|
# do this in the end of the process, skip lineage tree and whatever is stored already
|
1219
|
+
discard = {
|
1220
|
+
"cell_volume",
|
1221
|
+
"cell_fate",
|
1222
|
+
"cell_barycenter",
|
1223
|
+
"cell_contact_surface",
|
1224
|
+
"cell_lineage",
|
1225
|
+
"all_cells",
|
1226
|
+
"cell_history",
|
1227
|
+
"problematic_cells",
|
1228
|
+
"cell_labels_in_time",
|
1229
|
+
}
|
1230
|
+
self.specific_properties = []
|
989
1231
|
for prop_name, prop_values in tmp_data.items():
|
990
|
-
if hasattr(self, prop_name):
|
991
|
-
continue
|
992
|
-
else:
|
1232
|
+
if not (prop_name in discard or hasattr(self, prop_name)):
|
993
1233
|
if isinstance(prop_values, dict):
|
994
1234
|
dictionary = {
|
995
|
-
self.pkl2lT
|
1235
|
+
self.pkl2lT.get(k, -1): v
|
1236
|
+
for k, v in prop_values.items()
|
996
1237
|
}
|
997
1238
|
# is it a regular dictionary or a dictionary with dictionaries inside?
|
998
1239
|
for key, value in dictionary.items():
|
999
1240
|
if isinstance(value, dict):
|
1000
1241
|
# rename all ids from old to new
|
1001
1242
|
dictionary[key] = {
|
1002
|
-
self.pkl2lT
|
1243
|
+
self.pkl2lT.get(k, -1): v
|
1244
|
+
for k, v in value.items()
|
1003
1245
|
}
|
1004
1246
|
self.__dict__[prop_name] = dictionary
|
1247
|
+
self.specific_properties.append(prop_name)
|
1005
1248
|
# is any of this necessary? Or does it mean it anyways does not contain
|
1006
1249
|
# information about the id and a simple else: is enough?
|
1007
1250
|
elif (
|
1008
|
-
|
1009
|
-
|
1010
|
-
or prop_values.isinstance(np.array)
|
1251
|
+
isinstance(prop_values, (list, set, np.ndarray))
|
1252
|
+
and prop_name not in []
|
1011
1253
|
):
|
1012
1254
|
self.__dict__[prop_name] = prop_values
|
1255
|
+
self.specific_properties.append(prop_name)
|
1013
1256
|
|
1014
1257
|
# what else could it be?
|
1015
1258
|
|
@@ -1121,7 +1364,6 @@ class lineageTree:
|
|
1121
1364
|
p = None
|
1122
1365
|
self.predecessor.setdefault(c, []).append(p)
|
1123
1366
|
self.successor.setdefault(p, []).append(c)
|
1124
|
-
self.edges.add((p, c))
|
1125
1367
|
self.time_edges.setdefault(t - 1, set()).add((p, c))
|
1126
1368
|
self.max_id = unique_id
|
1127
1369
|
|
@@ -1209,7 +1451,6 @@ class lineageTree:
|
|
1209
1451
|
p = None
|
1210
1452
|
self.predecessor.setdefault(c, []).append(p)
|
1211
1453
|
self.successor.setdefault(p, []).append(c)
|
1212
|
-
self.edges.add((p, c))
|
1213
1454
|
self.time_edges.setdefault(t - 1, set()).add((p, c))
|
1214
1455
|
self.max_id = unique_id
|
1215
1456
|
|
@@ -1233,7 +1474,6 @@ class lineageTree:
|
|
1233
1474
|
self.time_edges = {}
|
1234
1475
|
unique_id = 0
|
1235
1476
|
self.nodes = set()
|
1236
|
-
self.edges = set()
|
1237
1477
|
self.successor = {}
|
1238
1478
|
self.predecessor = {}
|
1239
1479
|
self.pos = {}
|
@@ -1293,7 +1533,6 @@ class lineageTree:
|
|
1293
1533
|
M = self.time_id[(t - 1, M_id)]
|
1294
1534
|
self.successor.setdefault(M, []).append(C)
|
1295
1535
|
self.predecessor.setdefault(C, []).append(M)
|
1296
|
-
self.edges.add((M, C))
|
1297
1536
|
self.time_edges[t].add((M, C))
|
1298
1537
|
else:
|
1299
1538
|
if M_id != -1:
|
@@ -1330,7 +1569,6 @@ class lineageTree:
|
|
1330
1569
|
|
1331
1570
|
mr = MastodonReader(path)
|
1332
1571
|
spots, links = mr.read_tables()
|
1333
|
-
mr.read_tags(spots, links)
|
1334
1572
|
|
1335
1573
|
self.node_name = {}
|
1336
1574
|
|
@@ -1350,7 +1588,6 @@ class lineageTree:
|
|
1350
1588
|
target = e.target_idx
|
1351
1589
|
self.predecessor.setdefault(target, []).append(source)
|
1352
1590
|
self.successor.setdefault(source, []).append(target)
|
1353
|
-
self.edges.add((source, target))
|
1354
1591
|
self.time_edges.setdefault(self.time[source], set()).add(
|
1355
1592
|
(source, target)
|
1356
1593
|
)
|
@@ -1385,14 +1622,13 @@ class lineageTree:
|
|
1385
1622
|
self.nodes.add(unique_id)
|
1386
1623
|
self.time[unique_id] = t
|
1387
1624
|
self.node_name[unique_id] = spot[1]
|
1388
|
-
self.pos[unique_id] = np.array([x, y, z])
|
1625
|
+
self.pos[unique_id] = np.array([x, y, z], dtype=float)
|
1389
1626
|
|
1390
1627
|
for link in links:
|
1391
1628
|
source = int(float(link[4]))
|
1392
1629
|
target = int(float(link[5]))
|
1393
1630
|
self.predecessor.setdefault(target, []).append(source)
|
1394
1631
|
self.successor.setdefault(source, []).append(target)
|
1395
|
-
self.edges.add((source, target))
|
1396
1632
|
self.time_edges.setdefault(self.time[source], set()).add(
|
1397
1633
|
(source, target)
|
1398
1634
|
)
|
@@ -1447,23 +1683,24 @@ class lineageTree:
|
|
1447
1683
|
if attr in self.xml_attributes:
|
1448
1684
|
self.__dict__[attr][cell_id] = eval(cell.attrib[attr])
|
1449
1685
|
|
1450
|
-
self.edges = set()
|
1451
1686
|
tracks = {}
|
1452
1687
|
self.successor = {}
|
1453
1688
|
self.predecessor = {}
|
1454
1689
|
self.track_name = {}
|
1455
1690
|
for track in AllTracks:
|
1456
1691
|
if "TRACK_DURATION" in track.attrib:
|
1457
|
-
t_id, _ =
|
1458
|
-
track.attrib["
|
1692
|
+
t_id, _ = (
|
1693
|
+
int(track.attrib["TRACK_ID"]),
|
1694
|
+
float(track.attrib["TRACK_DURATION"]),
|
1459
1695
|
)
|
1460
1696
|
else:
|
1461
1697
|
t_id = int(track.attrib["TRACK_ID"])
|
1462
1698
|
t_name = track.attrib["name"]
|
1463
1699
|
tracks[t_id] = []
|
1464
1700
|
for edge in track:
|
1465
|
-
s, t =
|
1466
|
-
edge.attrib["
|
1701
|
+
s, t = (
|
1702
|
+
int(edge.attrib["SPOT_SOURCE_ID"]),
|
1703
|
+
int(edge.attrib["SPOT_TARGET_ID"]),
|
1467
1704
|
)
|
1468
1705
|
if s in self.nodes and t in self.nodes:
|
1469
1706
|
if self.time[s] > self.time[t]:
|
@@ -1473,7 +1710,6 @@ class lineageTree:
|
|
1473
1710
|
self.track_name[s] = t_name
|
1474
1711
|
self.track_name[t] = t_name
|
1475
1712
|
tracks[t_id].append((s, t))
|
1476
|
-
self.edges.add((s, t))
|
1477
1713
|
self.t_b = min(self.time_nodes.keys())
|
1478
1714
|
self.t_e = max(self.time_nodes.keys())
|
1479
1715
|
|
@@ -1511,7 +1747,7 @@ class lineageTree:
|
|
1511
1747
|
curr_c = to_treat.pop()
|
1512
1748
|
number_sequence.append(curr_c)
|
1513
1749
|
pos_sequence += list(self.pos[curr_c])
|
1514
|
-
if self
|
1750
|
+
if self[curr_c] == []:
|
1515
1751
|
number_sequence.append(-1)
|
1516
1752
|
elif len(self.successor[curr_c]) == 1:
|
1517
1753
|
to_treat += self.successor[curr_c]
|
@@ -1673,7 +1909,6 @@ class lineageTree:
|
|
1673
1909
|
self.time_edges = time_edges
|
1674
1910
|
self.pos = pos
|
1675
1911
|
self.nodes = set(nodes)
|
1676
|
-
self.edges = set(edges)
|
1677
1912
|
self.t_b = min(time_nodes.keys())
|
1678
1913
|
self.t_e = max(time_nodes.keys())
|
1679
1914
|
self.is_root = is_root
|
@@ -1693,7 +1928,7 @@ class lineageTree:
|
|
1693
1928
|
f.close()
|
1694
1929
|
|
1695
1930
|
@classmethod
|
1696
|
-
def load(clf, fname: str):
|
1931
|
+
def load(clf, fname: str, rm_empty_lists=True):
|
1697
1932
|
"""
|
1698
1933
|
Loading a lineage tree from a ".lT" file.
|
1699
1934
|
|
@@ -1706,6 +1941,18 @@ class lineageTree:
|
|
1706
1941
|
with open(fname, "br") as f:
|
1707
1942
|
lT = pkl.load(f)
|
1708
1943
|
f.close()
|
1944
|
+
if rm_empty_lists:
|
1945
|
+
if [] in lT.successor.values():
|
1946
|
+
for node, succ in lT.successor.items():
|
1947
|
+
if succ == []:
|
1948
|
+
lT.successor.pop(node)
|
1949
|
+
if [] in lT.predecessor.values():
|
1950
|
+
for node, succ in lT.predecessor.items():
|
1951
|
+
if succ == []:
|
1952
|
+
lT.predecessor.pop(node)
|
1953
|
+
lT.t_e = max(lT.time_nodes)
|
1954
|
+
lT.t_b = min(lT.time_nodes)
|
1955
|
+
warnings.warn("Empty lists have been removed")
|
1709
1956
|
return lT
|
1710
1957
|
|
1711
1958
|
def get_idx3d(self, t: int) -> tuple:
|
@@ -1787,7 +2034,9 @@ class lineageTree:
|
|
1787
2034
|
|
1788
2035
|
return self.Gabriel_graph[t]
|
1789
2036
|
|
1790
|
-
def get_predecessors(
|
2037
|
+
def get_predecessors(
|
2038
|
+
self, x: int, depth: int = None, start_time: int = None, end_time=None
|
2039
|
+
) -> list:
|
1791
2040
|
"""Computes the predecessors of the node `x` up to
|
1792
2041
|
`depth` predecessors or the beginning of the life of `x`.
|
1793
2042
|
The ordered list of ids is returned.
|
@@ -1798,20 +2047,34 @@ class lineageTree:
|
|
1798
2047
|
Returns:
|
1799
2048
|
[int, ]: list of ids, the last id is `x`
|
1800
2049
|
"""
|
1801
|
-
|
2050
|
+
if not start_time:
|
2051
|
+
start_time = self.t_b
|
2052
|
+
if not end_time:
|
2053
|
+
end_time = self.t_e
|
2054
|
+
unconstrained_cycle = [x]
|
2055
|
+
cycle = [x] if start_time <= self.time[x] <= end_time else []
|
1802
2056
|
acc = 0
|
1803
2057
|
while (
|
1804
|
-
len(
|
1805
|
-
self.successor.get(self.predecessor.get(cycle[0], [-1])[0], [])
|
1806
|
-
)
|
2058
|
+
len(self[self.predecessor.get(unconstrained_cycle[0], [-1])[0]])
|
1807
2059
|
== 1
|
1808
2060
|
and acc != depth
|
2061
|
+
and start_time
|
2062
|
+
<= self.time.get(
|
2063
|
+
self.predecessor.get(unconstrained_cycle[0], [-1])[0], -1
|
2064
|
+
)
|
1809
2065
|
):
|
1810
|
-
|
2066
|
+
unconstrained_cycle.insert(
|
2067
|
+
0, self.predecessor[unconstrained_cycle[0]][0]
|
2068
|
+
)
|
1811
2069
|
acc += 1
|
2070
|
+
if start_time <= self.time[unconstrained_cycle[0]] <= end_time:
|
2071
|
+
cycle.insert(0, unconstrained_cycle[0])
|
2072
|
+
|
1812
2073
|
return cycle
|
1813
2074
|
|
1814
|
-
def get_successors(
|
2075
|
+
def get_successors(
|
2076
|
+
self, x: int, depth: int = None, end_time: int = None
|
2077
|
+
) -> list:
|
1815
2078
|
"""Computes the successors of the node `x` up to
|
1816
2079
|
`depth` successors or the end of the life of `x`.
|
1817
2080
|
The ordered list of ids is returned.
|
@@ -1822,11 +2085,18 @@ class lineageTree:
|
|
1822
2085
|
Returns:
|
1823
2086
|
[int, ]: list of ids, the first id is `x`
|
1824
2087
|
"""
|
2088
|
+
if end_time is None:
|
2089
|
+
end_time = self.t_e
|
1825
2090
|
cycle = [x]
|
1826
2091
|
acc = 0
|
1827
|
-
while
|
2092
|
+
while (
|
2093
|
+
len(self[cycle[-1]]) == 1
|
2094
|
+
and acc != depth
|
2095
|
+
and self.time[cycle[-1]] < end_time
|
2096
|
+
):
|
1828
2097
|
cycle += self.successor[cycle[-1]]
|
1829
2098
|
acc += 1
|
2099
|
+
|
1830
2100
|
return cycle
|
1831
2101
|
|
1832
2102
|
def get_cycle(
|
@@ -1835,12 +2105,14 @@ class lineageTree:
|
|
1835
2105
|
depth: int = None,
|
1836
2106
|
depth_pred: int = None,
|
1837
2107
|
depth_succ: int = None,
|
2108
|
+
end_time: int = None,
|
1838
2109
|
) -> list:
|
1839
2110
|
"""Computes the predecessors and successors of the node `x` up to
|
1840
2111
|
`depth_pred` predecessors plus `depth_succ` successors.
|
1841
2112
|
If the value `depth` is provided and not None,
|
1842
2113
|
`depth_pred` and `depth_succ` are overwritten by `depth`.
|
1843
2114
|
The ordered list of ids is returned.
|
2115
|
+
If all `depth` are None, the full cycle is returned.
|
1844
2116
|
|
1845
2117
|
Args:
|
1846
2118
|
x (int): id of the node to compute
|
@@ -1850,11 +2122,13 @@ class lineageTree:
|
|
1850
2122
|
Returns:
|
1851
2123
|
[int, ]: list of ids
|
1852
2124
|
"""
|
2125
|
+
if end_time is None:
|
2126
|
+
end_time = self.t_e
|
1853
2127
|
if depth is not None:
|
1854
2128
|
depth_pred = depth_succ = depth
|
1855
|
-
return self.get_predecessors(x, depth_pred)[
|
1856
|
-
|
1857
|
-
)
|
2129
|
+
return self.get_predecessors(x, depth_pred, end_time=end_time)[
|
2130
|
+
:-1
|
2131
|
+
] + self.get_successors(x, depth_succ, end_time=end_time)
|
1858
2132
|
|
1859
2133
|
@property
|
1860
2134
|
def all_tracks(self):
|
@@ -1862,6 +2136,29 @@ class lineageTree:
|
|
1862
2136
|
self._all_tracks = self.get_all_tracks()
|
1863
2137
|
return self._all_tracks
|
1864
2138
|
|
2139
|
+
def get_all_branches_of_node(
|
2140
|
+
self, node: int, end_time: int = None
|
2141
|
+
) -> list:
|
2142
|
+
"""Computes all the tracks of the subtree spawn by a given node.
|
2143
|
+
Similar to get_all_tracks().
|
2144
|
+
|
2145
|
+
Args:
|
2146
|
+
node (int, optional): The node that we want to get its branches.
|
2147
|
+
|
2148
|
+
Returns:
|
2149
|
+
([[int, ...], ...]): list of lists containing track cell ids
|
2150
|
+
"""
|
2151
|
+
if not end_time:
|
2152
|
+
end_time = self.t_e
|
2153
|
+
branches = [self.get_successors(node)]
|
2154
|
+
to_do = self[branches[0][-1]].copy()
|
2155
|
+
while to_do:
|
2156
|
+
current = to_do.pop()
|
2157
|
+
track = self.get_cycle(current, end_time=end_time)
|
2158
|
+
branches += [track]
|
2159
|
+
to_do.extend(self[track[-1]])
|
2160
|
+
return branches
|
2161
|
+
|
1865
2162
|
def get_all_tracks(self, force_recompute: bool = False) -> list:
|
1866
2163
|
"""Computes all the tracks of a given lineage tree,
|
1867
2164
|
stores it in `self.all_tracks` and returns it.
|
@@ -1869,17 +2166,42 @@ class lineageTree:
|
|
1869
2166
|
Returns:
|
1870
2167
|
([[int, ...], ...]): list of lists containing track cell ids
|
1871
2168
|
"""
|
1872
|
-
if not hasattr(self, "_all_tracks"):
|
2169
|
+
if not hasattr(self, "_all_tracks") or force_recompute:
|
1873
2170
|
self._all_tracks = []
|
1874
|
-
to_do =
|
2171
|
+
to_do = list(self.roots)
|
1875
2172
|
while len(to_do) != 0:
|
1876
2173
|
current = to_do.pop()
|
1877
2174
|
track = self.get_cycle(current)
|
1878
2175
|
self._all_tracks += [track]
|
1879
|
-
to_do
|
2176
|
+
to_do.extend(self[track[-1]])
|
1880
2177
|
return self._all_tracks
|
1881
2178
|
|
1882
|
-
def
|
2179
|
+
def get_tracks(self, roots: list = None) -> list:
|
2180
|
+
"""Computes the tracks given by the list of nodes `roots` and returns it.
|
2181
|
+
|
2182
|
+
Args:
|
2183
|
+
roots (list): list of ids of the roots to be computed
|
2184
|
+
Returns:
|
2185
|
+
([[int, ...], ...]): list of lists containing track cell ids
|
2186
|
+
"""
|
2187
|
+
if roots is None:
|
2188
|
+
return self.get_all_tracks(force_recompute=True)
|
2189
|
+
else:
|
2190
|
+
tracks = []
|
2191
|
+
to_do = list(roots)
|
2192
|
+
while len(to_do) != 0:
|
2193
|
+
current = to_do.pop()
|
2194
|
+
track = self.get_cycle(current)
|
2195
|
+
tracks.append(track)
|
2196
|
+
to_do.extend(self[track[-1]])
|
2197
|
+
return tracks
|
2198
|
+
|
2199
|
+
def get_sub_tree(
|
2200
|
+
self,
|
2201
|
+
x: Union[int, Iterable],
|
2202
|
+
end_time: Union[int, None] = None,
|
2203
|
+
preorder: bool = False,
|
2204
|
+
) -> list:
|
1883
2205
|
"""Computes the list of cells from the subtree spawned by *x*
|
1884
2206
|
The default output order is breadth first traversal.
|
1885
2207
|
Unless preorder is `True` in that case the order is
|
@@ -1891,16 +2213,24 @@ class lineageTree:
|
|
1891
2213
|
Returns:
|
1892
2214
|
([int, ...]): the ordered list of node ids
|
1893
2215
|
"""
|
1894
|
-
|
2216
|
+
if not end_time:
|
2217
|
+
end_time = self.t_e
|
2218
|
+
if not isinstance(x, Iterable):
|
2219
|
+
to_do = [x]
|
2220
|
+
elif isinstance(x, Iterable):
|
2221
|
+
to_do = list(x)
|
1895
2222
|
sub_tree = []
|
1896
|
-
while
|
1897
|
-
curr = to_do.pop(
|
2223
|
+
while to_do:
|
2224
|
+
curr = to_do.pop()
|
1898
2225
|
succ = self.successor.get(curr, [])
|
2226
|
+
if succ and end_time < self.time.get(curr, end_time):
|
2227
|
+
succ = []
|
2228
|
+
continue
|
1899
2229
|
if preorder:
|
1900
2230
|
to_do = succ + to_do
|
1901
2231
|
else:
|
1902
2232
|
to_do += succ
|
1903
|
-
|
2233
|
+
sub_tree += [curr]
|
1904
2234
|
return sub_tree
|
1905
2235
|
|
1906
2236
|
def compute_spatial_density(
|
@@ -1971,6 +2301,70 @@ class lineageTree:
|
|
1971
2301
|
)
|
1972
2302
|
return self.th_edges
|
1973
2303
|
|
2304
|
+
def main_axes(self, time: int = None):
|
2305
|
+
"""Finds the main axes for a timepoint.
|
2306
|
+
If none will select the timepoint with the highest amound of cells.
|
2307
|
+
|
2308
|
+
Args:
|
2309
|
+
time (int, optional): The timepoint to find the main axes.
|
2310
|
+
If None will find the timepoint
|
2311
|
+
with the largest number of cells.
|
2312
|
+
|
2313
|
+
Returns:
|
2314
|
+
list: A list that contains the array of eigenvalues and eigenvectors.
|
2315
|
+
"""
|
2316
|
+
if time is None:
|
2317
|
+
time = np.argmax(
|
2318
|
+
[len(self.time_nodes[t]) for t in range(int(self.t_e))]
|
2319
|
+
)
|
2320
|
+
pos = np.array([self.pos[node] for node in self.time_nodes[time]])
|
2321
|
+
pos = pos - np.mean(pos, axis=0)
|
2322
|
+
cov = np.cov(np.array(pos).T)
|
2323
|
+
eig_val, eig_vec = np.linalg.eig(cov)
|
2324
|
+
srt = np.argsort(eig_val)[::-1]
|
2325
|
+
self.eig_val, self.eig_vec = eig_val[srt], eig_vec[:, srt]
|
2326
|
+
return eig_val[srt], eig_vec[:, srt]
|
2327
|
+
|
2328
|
+
def scale_embryo(self, scale=1000):
|
2329
|
+
"""Scale the embryo using their eigenvalues.
|
2330
|
+
|
2331
|
+
Args:
|
2332
|
+
scale (int, optional): The resulting scale you want to achieve. Defaults to 1000.
|
2333
|
+
|
2334
|
+
Returns:
|
2335
|
+
float: The scale factor.
|
2336
|
+
"""
|
2337
|
+
eig = self.main_axes()[0]
|
2338
|
+
return scale / (np.sqrt(eig[0]))
|
2339
|
+
|
2340
|
+
@staticmethod
|
2341
|
+
def __rodrigues_rotation_matrix(vector1, vector2=(0, 1, 0)):
|
2342
|
+
"""Calculates the rodrigues matrix of a dataset. It should use vectors from the find_main_axes(eigenvectors) function of LineagTree.
|
2343
|
+
Uses the Rodrigues rotation formula.
|
2344
|
+
|
2345
|
+
Args:
|
2346
|
+
vector1 (list|np.array): The vector that should be rotated to be aligned to the second vector
|
2347
|
+
vector2 (list|np.array, optional): The second vector. Defaults to [1,0,0].
|
2348
|
+
|
2349
|
+
Returns:
|
2350
|
+
np.array: The rotation matrix.
|
2351
|
+
"""
|
2352
|
+
vector1 = vector1 / np.linalg.norm(vector1)
|
2353
|
+
vector2 = vector2 / np.linalg.norm(vector2)
|
2354
|
+
if vector1 @ vector2 == 1:
|
2355
|
+
return np.eye(3)
|
2356
|
+
angle = np.arccos(vector1 @ vector2)
|
2357
|
+
axis = np.cross(vector1, vector2)
|
2358
|
+
axis = axis / np.linalg.norm(axis)
|
2359
|
+
K = np.array(
|
2360
|
+
[
|
2361
|
+
[0, -axis[2], axis[1]],
|
2362
|
+
[axis[2], 0, -axis[0]],
|
2363
|
+
[-axis[1], axis[0], 0],
|
2364
|
+
]
|
2365
|
+
)
|
2366
|
+
return np.eye(3) + np.sin(angle) * K + (1 - np.cos(angle)) * K @ K
|
2367
|
+
|
1974
2368
|
def get_ancestor_at_t(self, n: int, time: int = None):
|
1975
2369
|
"""
|
1976
2370
|
Find the id of the ancestor of a give node `n`
|
@@ -1997,62 +2391,27 @@ class lineageTree:
|
|
1997
2391
|
ancestor = self.predecessor.get(ancestor, [-1])[0]
|
1998
2392
|
return ancestor
|
1999
2393
|
|
2000
|
-
def
|
2001
|
-
"""
|
2002
|
-
Get a "simple" version of the tree spawned by the node `r`
|
2003
|
-
This simple version is just one node per cell (as opposed to
|
2004
|
-
one node per cell per time-point). The life time duration of
|
2005
|
-
a cell `c` is stored in `self.cycle_time` and return by this
|
2006
|
-
function
|
2394
|
+
def get_labelled_ancestor(self, node: int):
|
2395
|
+
"""Finds the first labelled ancestor and returns its ID otherwise returns None
|
2007
2396
|
|
2008
2397
|
Args:
|
2009
|
-
|
2010
|
-
time_resolution (float): the time between two consecutive time points
|
2398
|
+
node (int): The id of the node
|
2011
2399
|
|
2012
2400
|
Returns:
|
2013
|
-
|
2014
|
-
|
2015
|
-
|
2016
|
-
|
2017
|
-
|
2018
|
-
|
2019
|
-
|
2020
|
-
self.
|
2021
|
-
|
2022
|
-
|
2023
|
-
|
2024
|
-
|
2025
|
-
|
2026
|
-
|
2027
|
-
if _next:
|
2028
|
-
out_dict[current] = _next
|
2029
|
-
to_do.extend(_next)
|
2030
|
-
self.cycle_time[current] = len(cycle) * time_resolution
|
2031
|
-
return out_dict, self.cycle_time
|
2032
|
-
|
2033
|
-
@staticmethod
|
2034
|
-
def __edist_format(adj_dict: dict):
|
2035
|
-
inv_adj = {vi: k for k, v in adj_dict.items() for vi in v}
|
2036
|
-
roots = set(adj_dict).difference(inv_adj)
|
2037
|
-
nid2list = {}
|
2038
|
-
list2nid = {}
|
2039
|
-
nodes = []
|
2040
|
-
adj_list = []
|
2041
|
-
curr_id = 0
|
2042
|
-
for r in roots:
|
2043
|
-
to_do = [r]
|
2044
|
-
while to_do:
|
2045
|
-
curr = to_do.pop(0)
|
2046
|
-
nid2list[curr] = curr_id
|
2047
|
-
list2nid[curr_id] = curr
|
2048
|
-
nodes.append(curr_id)
|
2049
|
-
to_do = adj_dict.get(curr, []) + to_do
|
2050
|
-
curr_id += 1
|
2051
|
-
adj_list = [
|
2052
|
-
[nid2list[d] for d in adj_dict.get(list2nid[_id], [])]
|
2053
|
-
for _id in nodes
|
2054
|
-
]
|
2055
|
-
return nodes, adj_list, list2nid
|
2401
|
+
[None,int]: Returns the first ancestor found that has a label otherwise
|
2402
|
+
None.
|
2403
|
+
"""
|
2404
|
+
if node not in self.nodes:
|
2405
|
+
return None
|
2406
|
+
ancestor = node
|
2407
|
+
while (
|
2408
|
+
self.t_b <= self.time.get(ancestor, self.t_b - 1)
|
2409
|
+
and ancestor != -1
|
2410
|
+
):
|
2411
|
+
if ancestor in self.labels:
|
2412
|
+
return ancestor
|
2413
|
+
ancestor = self.predecessor.get(ancestor, [-1])[0]
|
2414
|
+
return
|
2056
2415
|
|
2057
2416
|
def unordered_tree_edit_distances_at_time_t(
|
2058
2417
|
self,
|
@@ -2060,9 +2419,10 @@ class lineageTree:
|
|
2060
2419
|
delta: callable = None,
|
2061
2420
|
norm: callable = None,
|
2062
2421
|
recompute: bool = False,
|
2422
|
+
end_time: int = None,
|
2063
2423
|
) -> dict:
|
2064
2424
|
"""
|
2065
|
-
Compute all the pairwise unordered tree edit distances from Zhang
|
2425
|
+
Compute all the pairwise unordered tree edit distances from Zhang 996 between the trees spawned at time `t`
|
2066
2426
|
|
2067
2427
|
Args:
|
2068
2428
|
t (int): time to look at
|
@@ -2071,6 +2431,8 @@ class lineageTree:
|
|
2071
2431
|
of the tree spawned by `n1` and the number of nodes
|
2072
2432
|
of the tree spawned by `n2` as arguments.
|
2073
2433
|
recompute (bool): if True, forces to recompute the distances (default: False)
|
2434
|
+
end_time (int): The final time point the comparison algorithm will take into account. If None all nodes
|
2435
|
+
will be taken into account.
|
2074
2436
|
|
2075
2437
|
Returns:
|
2076
2438
|
(dict) a dictionary that maps a pair of cell ids at time `t` to their unordered tree edit distance
|
@@ -2084,14 +2446,20 @@ class lineageTree:
|
|
2084
2446
|
for n1, n2 in combinations(roots, 2):
|
2085
2447
|
key = tuple(sorted((n1, n2)))
|
2086
2448
|
self.uted[t][key] = self.unordered_tree_edit_distance(
|
2087
|
-
n1, n2,
|
2449
|
+
n1, n2, end_time=end_time
|
2088
2450
|
)
|
2089
2451
|
return self.uted[t]
|
2090
2452
|
|
2091
2453
|
def unordered_tree_edit_distance(
|
2092
|
-
self,
|
2454
|
+
self,
|
2455
|
+
n1: int,
|
2456
|
+
n2: int,
|
2457
|
+
end_time: int = None,
|
2458
|
+
style="fragmented",
|
2459
|
+
node_lengths: tuple = (1, 5, 7),
|
2093
2460
|
) -> float:
|
2094
2461
|
"""
|
2462
|
+
TODO: Add option for choosing which tree aproximation should be used (Full, simple, comp)
|
2095
2463
|
Compute the unordered tree edit distance from Zhang 1996 between the trees spawned
|
2096
2464
|
by two nodes `n1` and `n2`. The topology of the trees are compared and the matching
|
2097
2465
|
cost is given by the function delta (see edist doc for more information).
|
@@ -2101,48 +2469,178 @@ class lineageTree:
|
|
2101
2469
|
Args:
|
2102
2470
|
n1 (int): id of the first node to compare
|
2103
2471
|
n2 (int): id of the second node to compare
|
2104
|
-
|
2105
|
-
|
2106
|
-
of the tree spawned by `n1` and the number of nodes
|
2107
|
-
of the tree spawned by `n2` as arguments.
|
2472
|
+
tree_style ("mini","simple","fragmented","full"): Which tree approximation is going to be used for the comparisons.
|
2473
|
+
Defaults to "fragmented".
|
2108
2474
|
|
2109
2475
|
Returns:
|
2110
2476
|
(float) The normed unordered tree edit distance
|
2111
2477
|
"""
|
2112
2478
|
|
2113
|
-
|
2114
|
-
|
2115
|
-
|
2479
|
+
tree = tree_style[style].value
|
2480
|
+
tree1 = tree(
|
2481
|
+
lT=self, node_length=node_lengths, end_time=end_time, root=n1
|
2482
|
+
)
|
2483
|
+
tree2 = tree(
|
2484
|
+
lT=self, node_length=node_lengths, end_time=end_time, root=n2
|
2485
|
+
)
|
2486
|
+
delta = tree1.delta
|
2487
|
+
_, times1 = tree1.tree
|
2488
|
+
_, times2 = tree2.tree
|
2489
|
+
(
|
2490
|
+
nodes1,
|
2491
|
+
adj1,
|
2492
|
+
corres1,
|
2493
|
+
) = tree1.edist
|
2494
|
+
(
|
2495
|
+
nodes2,
|
2496
|
+
adj2,
|
2497
|
+
corres2,
|
2498
|
+
) = tree2.edist
|
2499
|
+
if len(nodes1) == len(nodes2) == 0:
|
2500
|
+
return 0
|
2501
|
+
delta_tmp = partial(
|
2502
|
+
delta,
|
2503
|
+
corres1=corres1,
|
2504
|
+
corres2=corres2,
|
2505
|
+
times1=times1,
|
2506
|
+
times2=times2,
|
2507
|
+
)
|
2116
2508
|
|
2117
|
-
|
2118
|
-
|
2119
|
-
|
2120
|
-
len_x = times[corres1[x]]
|
2121
|
-
len_y = times[corres2[y]]
|
2122
|
-
return np.abs(len_x - len_y) / (len_x + len_y)
|
2509
|
+
return uted.uted(nodes1, adj1, nodes2, adj2, delta=delta_tmp) / max(
|
2510
|
+
tree1.get_norm(), tree2.get_norm()
|
2511
|
+
)
|
2123
2512
|
|
2124
|
-
|
2513
|
+
def to_simple_networkx(
|
2514
|
+
self, node: Union[int, list, set, tuple] = None, start_time: int = 0
|
2515
|
+
):
|
2516
|
+
"""
|
2517
|
+
Creates a simple networkx tree graph (every branch is a cell lifetime). This function is to be used for producing nx.graph objects(
|
2518
|
+
they can be used for visualization or other tasks),
|
2519
|
+
so only the start and the end of a branch are calculated, all cells in between are not taken into account.
|
2520
|
+
Args:
|
2521
|
+
start_time (int): From which timepoints are the graphs to be calculated.
|
2522
|
+
For example if start_time is 10, then all trees that begin
|
2523
|
+
on tp 10 or before are calculated.
|
2524
|
+
returns:
|
2525
|
+
G : list(nx.Digraph(),...)
|
2526
|
+
pos : list(dict(id:position))
|
2527
|
+
"""
|
2125
2528
|
|
2126
|
-
|
2127
|
-
|
2529
|
+
if node is None:
|
2530
|
+
mothers = [
|
2531
|
+
root for root in self.roots if self.time[root] <= start_time
|
2532
|
+
]
|
2533
|
+
else:
|
2534
|
+
mothers = node if isinstance(node, (list, set)) else [node]
|
2535
|
+
graph = {}
|
2536
|
+
all_nodes = {}
|
2537
|
+
all_edges = {}
|
2538
|
+
for mom in mothers:
|
2539
|
+
edges = set()
|
2540
|
+
nodes = set()
|
2541
|
+
for branch in self.get_all_branches_of_node(mom):
|
2542
|
+
nodes.update((branch[0], branch[-1]))
|
2543
|
+
if len(branch) > 1:
|
2544
|
+
edges.add((branch[0], branch[-1]))
|
2545
|
+
for suc in self[branch[-1]]:
|
2546
|
+
edges.add((branch[-1], suc))
|
2547
|
+
all_edges[mom] = edges
|
2548
|
+
all_nodes[mom] = nodes
|
2549
|
+
for i, mother in enumerate(mothers):
|
2550
|
+
graph[i] = nx.DiGraph()
|
2551
|
+
graph[i].add_nodes_from(all_nodes[mother])
|
2552
|
+
graph[i].add_edges_from(all_edges[mother])
|
2553
|
+
|
2554
|
+
return graph
|
2555
|
+
|
2556
|
+
def plot_all_lineages(
|
2557
|
+
self,
|
2558
|
+
starting_point: int = 0,
|
2559
|
+
nrows=2,
|
2560
|
+
figsize=(10, 15),
|
2561
|
+
dpi=70,
|
2562
|
+
fontsize=22,
|
2563
|
+
figure=None,
|
2564
|
+
axes=None,
|
2565
|
+
**kwargs,
|
2566
|
+
):
|
2567
|
+
"""Plots all lineages.
|
2128
2568
|
|
2129
|
-
|
2569
|
+
Args:
|
2570
|
+
starting_point (int, optional): Which timepoints and upwards are the graphs to be calculated.
|
2571
|
+
For example if start_time is 10, then all trees that begin
|
2572
|
+
on tp 10 or before are calculated. Defaults to None.
|
2573
|
+
nrows (int): How many rows of plots should be printed.
|
2574
|
+
kwargs: args accepted by networkx
|
2575
|
+
"""
|
2130
2576
|
|
2131
|
-
|
2132
|
-
|
2577
|
+
nrows = int(nrows)
|
2578
|
+
if nrows < 1 or not nrows:
|
2579
|
+
nrows = 1
|
2580
|
+
raise Warning("Number of rows has to be at least 1")
|
2133
2581
|
|
2134
|
-
|
2135
|
-
|
2136
|
-
|
2137
|
-
|
2138
|
-
|
2139
|
-
return 0
|
2140
|
-
delta_tmp = partial(
|
2141
|
-
delta, corres1=corres1, corres2=corres2, times=self.cycle_time
|
2582
|
+
graphs = self.to_simple_networkx(start_time=starting_point)
|
2583
|
+
ncols = int(len(graphs) // nrows) + (+np.sign(len(graphs) % nrows))
|
2584
|
+
pos = postions_of_nx(self, graphs)
|
2585
|
+
figure, axes = plt.subplots(
|
2586
|
+
figsize=figsize, nrows=nrows, ncols=ncols, dpi=dpi, sharey=True
|
2142
2587
|
)
|
2143
|
-
|
2144
|
-
|
2588
|
+
flat_axes = axes.flatten()
|
2589
|
+
ax2root = {}
|
2590
|
+
for i, graph in enumerate(graphs.values()):
|
2591
|
+
nx.draw_networkx(
|
2592
|
+
graph,
|
2593
|
+
pos[i],
|
2594
|
+
with_labels=False,
|
2595
|
+
arrows=False,
|
2596
|
+
**kwargs,
|
2597
|
+
ax=flat_axes[i],
|
2598
|
+
)
|
2599
|
+
root = [n for n, d in graph.in_degree() if d == 0][0]
|
2600
|
+
label = self.labels.get(root, "Unlabeled")
|
2601
|
+
xlim = flat_axes[i].get_xlim()
|
2602
|
+
ylim = flat_axes[i].get_ylim()
|
2603
|
+
x_pos = (xlim[1]) / 10
|
2604
|
+
y_pos = ylim[0] + 15
|
2605
|
+
ax2root[flat_axes[i]] = root
|
2606
|
+
flat_axes[i].text(
|
2607
|
+
x_pos,
|
2608
|
+
y_pos,
|
2609
|
+
label,
|
2610
|
+
fontsize=fontsize,
|
2611
|
+
color="black",
|
2612
|
+
ha="center",
|
2613
|
+
va="center",
|
2614
|
+
bbox={
|
2615
|
+
"facecolor": "white",
|
2616
|
+
"edgecolor": "green",
|
2617
|
+
"boxstyle": "round",
|
2618
|
+
},
|
2619
|
+
)
|
2620
|
+
[figure.delaxes(ax) for ax in axes.flatten() if not ax.has_data()]
|
2621
|
+
return figure, axes, ax2root
|
2622
|
+
|
2623
|
+
def plot_node(self, node, figsize=(4, 7), dpi=150, **kwargs):
|
2624
|
+
"""Plots the subtree spawn by a node.
|
2625
|
+
|
2626
|
+
Args:
|
2627
|
+
node (int): The id of the node that is going to be plotted.
|
2628
|
+
kwargs: args accepted by networkx
|
2629
|
+
"""
|
2630
|
+
graph = self.to_simple_networkx(node)
|
2631
|
+
if len(graph) > 1:
|
2632
|
+
raise Warning("Please enter only one node")
|
2633
|
+
graph = graph[list(graph)[0]]
|
2634
|
+
figure, ax = plt.subplots(nrows=1, ncols=1)
|
2635
|
+
nx.draw_networkx(
|
2636
|
+
graph,
|
2637
|
+
hierarchy_pos(graph, self, node),
|
2638
|
+
with_labels=False,
|
2639
|
+
arrows=False,
|
2640
|
+
ax=ax,
|
2641
|
+
**kwargs,
|
2145
2642
|
)
|
2643
|
+
return figure, ax
|
2146
2644
|
|
2147
2645
|
# def DTW(self, t1, t2, max_w=None, start_delay=None, end_delay=None,
|
2148
2646
|
# metric='euclidian', **kwargs):
|
@@ -2223,10 +2721,584 @@ class lineageTree:
|
|
2223
2721
|
to_do.append(_next)
|
2224
2722
|
elif self.time[_next] == t:
|
2225
2723
|
final_nodes.append(_next)
|
2724
|
+
if not final_nodes:
|
2725
|
+
return list(r)
|
2226
2726
|
return final_nodes
|
2227
2727
|
|
2728
|
+
@staticmethod
|
2729
|
+
def __calculate_diag_line(dist_mat: np.ndarray) -> (float, float):
|
2730
|
+
"""
|
2731
|
+
Calculate the line that centers the band w.
|
2732
|
+
|
2733
|
+
Args:
|
2734
|
+
dist_mat (matrix): distance matrix obtained by the function calculate_dtw
|
2735
|
+
|
2736
|
+
Returns:
|
2737
|
+
(float) Slope
|
2738
|
+
(float) intercept of the line
|
2739
|
+
"""
|
2740
|
+
i, j = dist_mat.shape
|
2741
|
+
x1 = max(0, i - j) / 2
|
2742
|
+
x2 = (i + min(i, j)) / 2
|
2743
|
+
y1 = max(0, j - i) / 2
|
2744
|
+
y2 = (j + min(i, j)) / 2
|
2745
|
+
slope = (y1 - y2) / (x1 - x2)
|
2746
|
+
intercept = y1 - slope * x1
|
2747
|
+
return slope, intercept
|
2748
|
+
|
2749
|
+
# Reference: https://github.com/kamperh/lecture_dtw_notebook/blob/main/dtw.ipynb
|
2750
|
+
def __dp(
|
2751
|
+
self,
|
2752
|
+
dist_mat: np.ndarray,
|
2753
|
+
start_d: int = 0,
|
2754
|
+
back_d: int = 0,
|
2755
|
+
fast: bool = False,
|
2756
|
+
w: int = 0,
|
2757
|
+
centered_band: bool = True,
|
2758
|
+
) -> (((int, int), ...), np.ndarray):
|
2759
|
+
"""
|
2760
|
+
Find DTW minimum cost between two series using dynamic programming.
|
2761
|
+
|
2762
|
+
Args:
|
2763
|
+
dist_mat (matrix): distance matrix obtained by the function calculate_dtw
|
2764
|
+
start_d (int): start delay
|
2765
|
+
back_d (int): end delay
|
2766
|
+
w (int): window constrain
|
2767
|
+
slope (float): to calculate window - givem by the function __calculate_diag_line
|
2768
|
+
intercept (flost): to calculate window - givem by the function __calculate_diag_line
|
2769
|
+
use_absolute (boolean): if the window constraing is calculate by the absolute difference between points (uncentered)
|
2770
|
+
|
2771
|
+
Returns:
|
2772
|
+
(tuple of tuples) Aligment path
|
2773
|
+
(matrix) Cost matrix
|
2774
|
+
"""
|
2775
|
+
N, M = dist_mat.shape
|
2776
|
+
w_limit = max(w, abs(N - M)) # Calculate the Sakoe-Chiba band width
|
2777
|
+
|
2778
|
+
if centered_band:
|
2779
|
+
slope, intercept = self.__calculate_diag_line(dist_mat)
|
2780
|
+
square_root = np.sqrt((slope**2) + 1)
|
2781
|
+
|
2782
|
+
# Initialize the cost matrix
|
2783
|
+
cost_mat = np.full((N + 1, M + 1), np.inf)
|
2784
|
+
cost_mat[0, 0] = 0
|
2785
|
+
|
2786
|
+
# Fill the cost matrix while keeping traceback information
|
2787
|
+
traceback_mat = np.zeros((N, M))
|
2788
|
+
|
2789
|
+
cost_mat[: start_d + 1, 0] = 0
|
2790
|
+
cost_mat[0, : start_d + 1] = 0
|
2791
|
+
|
2792
|
+
cost_mat[N - back_d :, M] = 0
|
2793
|
+
cost_mat[N, M - back_d :] = 0
|
2794
|
+
|
2795
|
+
for i in range(N):
|
2796
|
+
for j in range(M):
|
2797
|
+
if fast and not centered_band:
|
2798
|
+
condition = abs(i - j) <= w_limit
|
2799
|
+
elif fast:
|
2800
|
+
condition = (
|
2801
|
+
abs(slope * i - j + intercept) / square_root <= w_limit
|
2802
|
+
)
|
2803
|
+
else:
|
2804
|
+
condition = True
|
2805
|
+
|
2806
|
+
if condition:
|
2807
|
+
penalty = [
|
2808
|
+
cost_mat[i, j], # match (0)
|
2809
|
+
cost_mat[i, j + 1], # insertion (1)
|
2810
|
+
cost_mat[i + 1, j], # deletion (2)
|
2811
|
+
]
|
2812
|
+
i_penalty = np.argmin(penalty)
|
2813
|
+
cost_mat[i + 1, j + 1] = (
|
2814
|
+
dist_mat[i, j] + penalty[i_penalty]
|
2815
|
+
)
|
2816
|
+
traceback_mat[i, j] = i_penalty
|
2817
|
+
|
2818
|
+
min_index1 = np.argmin(cost_mat[N - back_d :, M])
|
2819
|
+
min_index2 = np.argmin(cost_mat[N, M - back_d :])
|
2820
|
+
|
2821
|
+
if (
|
2822
|
+
cost_mat[N, M - back_d + min_index2]
|
2823
|
+
< cost_mat[N - back_d + min_index1, M]
|
2824
|
+
):
|
2825
|
+
i = N - 1
|
2826
|
+
j = M - back_d + min_index2 - 1
|
2827
|
+
final_cost = cost_mat[i + 1, j + 1]
|
2828
|
+
else:
|
2829
|
+
i = N - back_d + min_index1 - 1
|
2830
|
+
j = M - 1
|
2831
|
+
final_cost = cost_mat[i + 1, j + 1]
|
2832
|
+
|
2833
|
+
path = [(i, j)]
|
2834
|
+
|
2835
|
+
while (
|
2836
|
+
start_d != 0
|
2837
|
+
and ((start_d < i and j > 0) or (i > 0 and start_d < j))
|
2838
|
+
) or (start_d == 0 and (i > 0 or j > 0)):
|
2839
|
+
tb_type = traceback_mat[i, j]
|
2840
|
+
if tb_type == 0:
|
2841
|
+
# Match
|
2842
|
+
i -= 1
|
2843
|
+
j -= 1
|
2844
|
+
elif tb_type == 1:
|
2845
|
+
# Insertion
|
2846
|
+
i -= 1
|
2847
|
+
elif tb_type == 2:
|
2848
|
+
# Deletion
|
2849
|
+
j -= 1
|
2850
|
+
|
2851
|
+
path.append((i, j))
|
2852
|
+
|
2853
|
+
# Strip infinity edges from cost_mat before returning
|
2854
|
+
cost_mat = cost_mat[1:, 1:]
|
2855
|
+
return path[::-1], cost_mat, final_cost
|
2856
|
+
|
2857
|
+
# Reference: https://github.com/nghiaho12/rigid_transform_3D
|
2858
|
+
@staticmethod
|
2859
|
+
def __rigid_transform_3D(A, B):
|
2860
|
+
assert A.shape == B.shape
|
2861
|
+
|
2862
|
+
num_rows, num_cols = A.shape
|
2863
|
+
if num_rows != 3:
|
2864
|
+
raise Exception(
|
2865
|
+
f"matrix A is not 3xN, it is {num_rows}x{num_cols}"
|
2866
|
+
)
|
2867
|
+
|
2868
|
+
num_rows, num_cols = B.shape
|
2869
|
+
if num_rows != 3:
|
2870
|
+
raise Exception(
|
2871
|
+
f"matrix B is not 3xN, it is {num_rows}x{num_cols}"
|
2872
|
+
)
|
2873
|
+
|
2874
|
+
# find mean column wise
|
2875
|
+
centroid_A = np.mean(A, axis=1)
|
2876
|
+
centroid_B = np.mean(B, axis=1)
|
2877
|
+
|
2878
|
+
# ensure centroids are 3x1
|
2879
|
+
centroid_A = centroid_A.reshape(-1, 1)
|
2880
|
+
centroid_B = centroid_B.reshape(-1, 1)
|
2881
|
+
|
2882
|
+
# subtract mean
|
2883
|
+
Am = A - centroid_A
|
2884
|
+
Bm = B - centroid_B
|
2885
|
+
|
2886
|
+
H = Am @ np.transpose(Bm)
|
2887
|
+
|
2888
|
+
# find rotation
|
2889
|
+
U, S, Vt = np.linalg.svd(H)
|
2890
|
+
R = Vt.T @ U.T
|
2891
|
+
|
2892
|
+
# special reflection case
|
2893
|
+
if np.linalg.det(R) < 0:
|
2894
|
+
# print("det(R) < R, reflection detected!, correcting for it ...")
|
2895
|
+
Vt[2, :] *= -1
|
2896
|
+
R = Vt.T @ U.T
|
2897
|
+
|
2898
|
+
t = -R @ centroid_A + centroid_B
|
2899
|
+
|
2900
|
+
return R, t
|
2901
|
+
|
2902
|
+
def __interpolate(
|
2903
|
+
self, track1: list, track2: list, threshold: int
|
2904
|
+
) -> (np.ndarray, np.ndarray):
|
2905
|
+
"""
|
2906
|
+
Interpolate two series that have different lengths
|
2907
|
+
|
2908
|
+
Args:
|
2909
|
+
track1 (list): list of nodes of the first cell cycle to compare
|
2910
|
+
track2 (list): list of nodes of the second cell cycle to compare
|
2911
|
+
threshold (int): set a maximum number of points a track can have
|
2912
|
+
|
2913
|
+
Returns:
|
2914
|
+
(list of list) x, y, z postions for track1
|
2915
|
+
(list of list) x, y, z postions for track2
|
2916
|
+
"""
|
2917
|
+
inter1_pos = []
|
2918
|
+
inter2_pos = []
|
2919
|
+
|
2920
|
+
track1_pos = np.array([self.pos[c_id] for c_id in track1])
|
2921
|
+
track2_pos = np.array([self.pos[c_id] for c_id in track2])
|
2922
|
+
|
2923
|
+
# Both tracks have the same length and size below the threshold - nothing is done
|
2924
|
+
if len(track1) == len(track2) and (
|
2925
|
+
len(track1) <= threshold or len(track2) <= threshold
|
2926
|
+
):
|
2927
|
+
return track1_pos, track2_pos
|
2928
|
+
# Both tracks have the same length but one or more sizes are above the threshold
|
2929
|
+
elif len(track1) > threshold or len(track2) > threshold:
|
2930
|
+
sampling = threshold
|
2931
|
+
# Tracks have different lengths and the sizes are below the threshold
|
2932
|
+
else:
|
2933
|
+
sampling = max(len(track1), len(track2))
|
2934
|
+
|
2935
|
+
for pos in range(3):
|
2936
|
+
track1_interp = InterpolatedUnivariateSpline(
|
2937
|
+
np.linspace(0, 1, len(track1_pos[:, pos])),
|
2938
|
+
track1_pos[:, pos],
|
2939
|
+
k=1,
|
2940
|
+
)
|
2941
|
+
inter1_pos.append(track1_interp(np.linspace(0, 1, sampling)))
|
2942
|
+
|
2943
|
+
track2_interp = InterpolatedUnivariateSpline(
|
2944
|
+
np.linspace(0, 1, len(track2_pos[:, pos])),
|
2945
|
+
track2_pos[:, pos],
|
2946
|
+
k=1,
|
2947
|
+
)
|
2948
|
+
inter2_pos.append(track2_interp(np.linspace(0, 1, sampling)))
|
2949
|
+
|
2950
|
+
return np.column_stack(inter1_pos), np.column_stack(inter2_pos)
|
2951
|
+
|
2952
|
+
def calculate_dtw(
|
2953
|
+
self,
|
2954
|
+
nodes1: int,
|
2955
|
+
nodes2: int,
|
2956
|
+
threshold: int = 1000,
|
2957
|
+
regist: bool = True,
|
2958
|
+
start_d: int = 0,
|
2959
|
+
back_d: int = 0,
|
2960
|
+
fast: bool = False,
|
2961
|
+
w: int = 0,
|
2962
|
+
centered_band: bool = True,
|
2963
|
+
cost_mat_p: bool = False,
|
2964
|
+
) -> (float, tuple, np.ndarray, np.ndarray, np.ndarray):
|
2965
|
+
"""
|
2966
|
+
Calculate DTW distance between two cell cycles
|
2967
|
+
|
2968
|
+
Args:
|
2969
|
+
nodes1 (int): node to compare distance
|
2970
|
+
nodes2 (int): node to compare distance
|
2971
|
+
threshold: set a maximum number of points a track can have
|
2972
|
+
regist (boolean): Rotate and translate trajectories
|
2973
|
+
start_d (int): start delay
|
2974
|
+
back_d (int): end delay
|
2975
|
+
fast (boolean): True if the user wants to run the fast algorithm with window restrains
|
2976
|
+
w (int): window size
|
2977
|
+
centered_band (boolean): if running the fast algorithm, True if the windown is centered
|
2978
|
+
cost_mat_p (boolean): True if print the not normalized cost matrix
|
2979
|
+
|
2980
|
+
Returns:
|
2981
|
+
(float) DTW distance
|
2982
|
+
(tuple of tuples) Aligment path
|
2983
|
+
(matrix) Cost matrix
|
2984
|
+
(list of lists) pos_cycle1: rotated and translated trajectories positions
|
2985
|
+
(list of lists) pos_cycle2: rotated and translated trajectories positions
|
2986
|
+
"""
|
2987
|
+
nodes1_cycle = self.get_cycle(nodes1)
|
2988
|
+
nodes2_cycle = self.get_cycle(nodes2)
|
2989
|
+
|
2990
|
+
interp_cycle1, interp_cycle2 = self.__interpolate(
|
2991
|
+
nodes1_cycle, nodes2_cycle, threshold
|
2992
|
+
)
|
2993
|
+
|
2994
|
+
pos_cycle1 = np.array([self.pos[c_id] for c_id in nodes1_cycle])
|
2995
|
+
pos_cycle2 = np.array([self.pos[c_id] for c_id in nodes2_cycle])
|
2996
|
+
|
2997
|
+
if regist:
|
2998
|
+
R, t = self.__rigid_transform_3D(
|
2999
|
+
np.transpose(interp_cycle1), np.transpose(interp_cycle2)
|
3000
|
+
)
|
3001
|
+
pos_cycle1 = np.transpose(np.dot(R, pos_cycle1.T) + t)
|
3002
|
+
|
3003
|
+
dist_mat = distance.cdist(pos_cycle1, pos_cycle2, "euclidean")
|
3004
|
+
|
3005
|
+
path, cost_mat, final_cost = self.__dp(
|
3006
|
+
dist_mat,
|
3007
|
+
start_d,
|
3008
|
+
back_d,
|
3009
|
+
w=w,
|
3010
|
+
fast=fast,
|
3011
|
+
centered_band=centered_band,
|
3012
|
+
)
|
3013
|
+
cost = final_cost / len(path)
|
3014
|
+
|
3015
|
+
if cost_mat_p:
|
3016
|
+
return cost, path, cost_mat, pos_cycle1, pos_cycle2
|
3017
|
+
else:
|
3018
|
+
return cost, path
|
3019
|
+
|
3020
|
+
def plot_dtw_heatmap(
|
3021
|
+
self,
|
3022
|
+
nodes1: int,
|
3023
|
+
nodes2: int,
|
3024
|
+
threshold: int = 1000,
|
3025
|
+
regist: bool = True,
|
3026
|
+
start_d: int = 0,
|
3027
|
+
back_d: int = 0,
|
3028
|
+
fast: bool = False,
|
3029
|
+
w: int = 0,
|
3030
|
+
centered_band: bool = True,
|
3031
|
+
) -> (float, plt.figure):
|
3032
|
+
"""
|
3033
|
+
Plot DTW cost matrix between two cell cycles in heatmap format
|
3034
|
+
|
3035
|
+
Args:
|
3036
|
+
nodes1 (int): node to compare distance
|
3037
|
+
nodes2 (int): node to compare distance
|
3038
|
+
start_d (int): start delay
|
3039
|
+
back_d (int): end delay
|
3040
|
+
fast (boolean): True if the user wants to run the fast algorithm with window restrains
|
3041
|
+
w (int): window size
|
3042
|
+
centered_band (boolean): if running the fast algorithm, True if the windown is centered
|
3043
|
+
|
3044
|
+
Returns:
|
3045
|
+
(float) DTW distance
|
3046
|
+
(figure) Heatmap of cost matrix with opitimal path
|
3047
|
+
"""
|
3048
|
+
cost, path, cost_mat, pos_cycle1, pos_cycle2 = self.calculate_dtw(
|
3049
|
+
nodes1,
|
3050
|
+
nodes2,
|
3051
|
+
threshold,
|
3052
|
+
regist,
|
3053
|
+
start_d,
|
3054
|
+
back_d,
|
3055
|
+
fast,
|
3056
|
+
w,
|
3057
|
+
centered_band,
|
3058
|
+
cost_mat_p=True,
|
3059
|
+
)
|
3060
|
+
|
3061
|
+
fig = plt.figure(figsize=(8, 6))
|
3062
|
+
ax = fig.add_subplot(1, 1, 1)
|
3063
|
+
im = ax.imshow(
|
3064
|
+
cost_mat, cmap="viridis", origin="lower", interpolation="nearest"
|
3065
|
+
)
|
3066
|
+
plt.colorbar(im)
|
3067
|
+
ax.set_title("Heatmap of DTW Cost Matrix")
|
3068
|
+
ax.set_xlabel("Tree 1")
|
3069
|
+
ax.set_ylabel("tree 2")
|
3070
|
+
x_path, y_path = zip(*path)
|
3071
|
+
ax.plot(y_path, x_path, color="black")
|
3072
|
+
|
3073
|
+
return cost, fig
|
3074
|
+
|
3075
|
+
@staticmethod
|
3076
|
+
def __plot_2d(
|
3077
|
+
pos_cycle1,
|
3078
|
+
pos_cycle2,
|
3079
|
+
nodes1,
|
3080
|
+
nodes2,
|
3081
|
+
ax,
|
3082
|
+
x_idx,
|
3083
|
+
y_idx,
|
3084
|
+
x_label,
|
3085
|
+
y_label,
|
3086
|
+
):
|
3087
|
+
ax.plot(
|
3088
|
+
pos_cycle1[:, x_idx],
|
3089
|
+
pos_cycle1[:, y_idx],
|
3090
|
+
"-",
|
3091
|
+
label=f"root = {nodes1}",
|
3092
|
+
)
|
3093
|
+
ax.plot(
|
3094
|
+
pos_cycle2[:, x_idx],
|
3095
|
+
pos_cycle2[:, y_idx],
|
3096
|
+
"-",
|
3097
|
+
label=f"root = {nodes2}",
|
3098
|
+
)
|
3099
|
+
ax.set_xlabel(x_label)
|
3100
|
+
ax.set_ylabel(y_label)
|
3101
|
+
|
3102
|
+
def plot_dtw_trajectory(
    self,
    nodes1: int,
    nodes2: int,
    threshold: int = 1000,
    regist: bool = True,
    start_d: int = 0,
    back_d: int = 0,
    fast: bool = False,
    w: int = 0,
    centered_band: bool = True,
    projection: str = None,
    alig: bool = False,
) -> (float, plt.figure):
    """
    Plots DTW trajectories alignment between two cell cycles in 2D or 3D

    Args:
        nodes1 (int): node to compare distance
        nodes2 (int): node to compare distance
        threshold (int): set a maximum number of points a track can have
        regist (boolean): Rotate and translate trajectories
        start_d (int): start delay
        back_d (int): end delay
        w (int): window size
        fast (boolean): True if the user wants to run the fast algorithm with window restrains
        centered_band (boolean): if running the fast algorithm, True if the window is centered
        projection (string): specify which 2D to plot ->
            '3d' : for the 3d visualization
            'xy' or None (default) : 2D projection of axis x and y
            'xz' : 2D projection of axis x and z
            'yz' : 2D projection of axis y and z
            'pca' : PCA projection
        alig (boolean): True to show alignment on plot

    Returns:
        (float) DTW distance
        (figure) Trajectories Plot

    Raises:
        ValueError: if `projection` is not one of the values listed above
        ImportError: if `projection` is 'pca' and scikit-learn is not installed
    """
    # (x column, y column, x label, y label) of every axis-aligned projection.
    # The reversed spellings ('yx', 'zx', 'zy') are aliases of the same plot.
    planar = {
        None: (0, 1, "x position", "y position"),
        "xy": (0, 1, "x position", "y position"),
        "yx": (0, 1, "x position", "y position"),
        "xz": (0, 2, "x position", "z position"),
        "zx": (0, 2, "x position", "z position"),
        "yz": (1, 2, "y position", "z position"),
        "zy": (1, 2, "y position", "z position"),
    }

    # Validate before doing any work so that a bad argument does not leave
    # an orphaned pyplot figure behind. Tuple membership relies on `==`
    # only, so unhashable values are rejected with ValueError as well.
    if projection not in (*planar, "3d", "pca"):
        raise ValueError(
            "Error: available projections are:\n"
            "    '3d' : for the 3d visualization\n"
            "    'xy' or None (default) : 2D projection of axis x and y\n"
            "    'xz' : 2D projection of axis x and z\n"
            "    'yz' : 2D projection of axis y and z\n"
            "    'pca' : PCA projection"
        )

    if projection == "pca":
        # scikit-learn is an optional dependency: fail with an explicit
        # error instead of a NameError further down.
        try:
            from sklearn.decomposition import PCA
        except ImportError as err:
            raise ImportError(
                "scikit-learn is not installed, the PCA orientation cannot be used. "
                "You can install scikit-learn with `pip install scikit-learn`"
            ) from err

    (
        distance,
        alignment,
        cost_mat,
        pos_cycle1,
        pos_cycle2,
    ) = self.calculate_dtw(
        nodes1,
        nodes2,
        threshold,
        regist,
        start_d,
        back_d,
        fast,
        w,
        centered_band,
        cost_mat_p=True,
    )

    fig = plt.figure(figsize=(10, 6))

    if projection == "3d":
        ax = fig.add_subplot(1, 1, 1, projection="3d")
        for track, root in ((pos_cycle1, nodes1), (pos_cycle2, nodes2)):
            ax.plot(
                track[:, 0],
                track[:, 1],
                track[:, 2],
                "-",
                label=f"root = {root}",
            )
        ax.set_ylabel("y position")
        ax.set_xlabel("x position")
        ax.set_zlabel("z position")
    elif projection == "pca":
        ax = fig.add_subplot(1, 1, 1)

        # Fit one PCA on both trajectories so they share the same plane
        pca = PCA(n_components=2)
        pca.fit(np.vstack([pos_cycle1, pos_cycle2]))
        for track, root in ((pos_cycle1, nodes1), (pos_cycle2, nodes2)):
            track_2d = pca.transform(track)
            ax.plot(
                track_2d[:, 0],
                track_2d[:, 1],
                "-",
                label=f"root = {root}",
            )

        # Label each principal axis with the spatial axis that weighs the
        # most in it, and the share of that axis in the component.
        axes = ["x", "y", "z"]
        weights_x = np.abs(pca.components_[0])
        weights_y = np.abs(pca.components_[1])
        x_label = axes[np.argmax(weights_x)]
        y_label = axes[np.argmax(weights_y)]
        x_percent = 100 * (np.max(weights_x) / np.sum(weights_x))
        y_percent = 100 * (np.max(weights_y) / np.sum(weights_y))
        ax.set_xlabel(f"{x_percent:.0f}% of {x_label} position")
        ax.set_ylabel(f"{y_percent:.0f}% of {y_label} position")
    else:
        ax = fig.add_subplot(1, 1, 1)
        x_idx, y_idx, x_label, y_label = planar[projection]
        self.__plot_2d(
            pos_cycle1,
            pos_cycle2,
            nodes1,
            nodes2,
            ax,
            x_idx,
            y_idx,
            x_label,
            y_label,
        )

    if alig and projection != "pca":
        # Draw the alignment segments in the same coordinates as the
        # trajectories themselves (e.g. x/z for the 'xz' projection).
        if projection == "3d":
            dims = (0, 1, 2)
        else:
            dims = planar[projection][:2]
        for i, j in alignment:
            start, end = pos_cycle1[i], pos_cycle2[j]
            segment = [[start[d], end[d]] for d in dims]
            ax.plot(*segment, linestyle="--", color="grey")

    ax.set_aspect("equal")
    ax.legend()
    fig.tight_layout()

    if alig and projection == "pca":
        warnings.warn(
            "Error: not possible to show alignment in PCA projection !",
            UserWarning,
        )

    return distance, fig
|
3299
|
+
|
2228
3300
|
def first_labelling(self):
    """
    Reset `self.labels` so that every node of the first time point is
    mapped to "Unlabeled".

    Time point 0 is used when it is a key of `self.time_nodes`; otherwise
    the smallest time key present is used. When `self.time_nodes` is empty
    the labels are set to an empty dictionary instead of raising KeyError.
    """
    if 0 in self.time_nodes:
        first_time = 0
    else:
        # `None` is never a time key, so an empty tree falls through to ()
        first_time = min(self.time_nodes, default=None)
    first_nodes = self.time_nodes.get(first_time, ())
    self.labels = dict.fromkeys(first_nodes, "Unlabeled")
|
2230
3302
|
|
2231
3303
|
def __init__(
|
2232
3304
|
self,
|
@@ -2259,12 +3331,12 @@ class lineageTree:
|
|
2259
3331
|
'TGMM, 'ASTEC', MaMuT', 'TrackMate', 'csv', 'celegans', 'binary'
|
2260
3332
|
default is 'binary'
|
2261
3333
|
"""
|
3334
|
+
self.name = name
|
2262
3335
|
self.time_nodes = {}
|
2263
3336
|
self.time_edges = {}
|
2264
3337
|
self.max_id = -1
|
2265
3338
|
self.next_id = []
|
2266
3339
|
self.nodes = set()
|
2267
|
-
self.edges = set()
|
2268
3340
|
self.successor = {}
|
2269
3341
|
self.predecessor = {}
|
2270
3342
|
self.pos = {}
|
@@ -2272,40 +3344,57 @@ class lineageTree:
|
|
2272
3344
|
self.time = {}
|
2273
3345
|
self.kdtrees = {}
|
2274
3346
|
self.spatial_density = {}
|
2275
|
-
|
2276
|
-
|
2277
|
-
|
2278
|
-
self.xml_attributes = []
|
2279
|
-
else:
|
2280
|
-
self.xml_attributes = xml_attributes
|
2281
|
-
file_type = file_type.lower()
|
2282
|
-
if file_type == "tgmm":
|
2283
|
-
self.read_tgmm_xml(file_format, tb, te, z_mult)
|
2284
|
-
self.t_b = tb
|
2285
|
-
self.t_e = te
|
2286
|
-
elif file_type == "mamut" or file_type == "trackmate":
|
2287
|
-
self.read_from_mamut_xml(file_format)
|
2288
|
-
elif file_type == "celegans":
|
2289
|
-
self.read_from_txt_for_celegans(file_format)
|
2290
|
-
elif file_type == "celegans_cao":
|
2291
|
-
self.read_from_txt_for_celegans_CAO(
|
2292
|
-
file_format, reorder=reorder, shape=shape, raw_size=raw_size
|
2293
|
-
)
|
2294
|
-
elif file_type == "mastodon":
|
2295
|
-
if isinstance(file_format, list) and len(file_format) == 2:
|
2296
|
-
self.read_from_mastodon_csv(file_format)
|
3347
|
+
if file_type and file_format:
|
3348
|
+
if xml_attributes is None:
|
3349
|
+
self.xml_attributes = []
|
2297
3350
|
else:
|
2298
|
-
|
2299
|
-
|
2300
|
-
|
2301
|
-
|
2302
|
-
|
2303
|
-
|
2304
|
-
|
2305
|
-
|
2306
|
-
|
2307
|
-
|
2308
|
-
|
2309
|
-
|
2310
|
-
|
2311
|
-
|
3351
|
+
self.xml_attributes = xml_attributes
|
3352
|
+
file_type = file_type.lower()
|
3353
|
+
if file_type == "tgmm":
|
3354
|
+
self.read_tgmm_xml(file_format, tb, te, z_mult)
|
3355
|
+
self.t_b = tb
|
3356
|
+
self.t_e = te
|
3357
|
+
elif file_type == "mamut" or file_type == "trackmate":
|
3358
|
+
self.read_from_mamut_xml(file_format)
|
3359
|
+
elif file_type == "celegans":
|
3360
|
+
self.read_from_txt_for_celegans(file_format)
|
3361
|
+
elif file_type == "celegans_cao":
|
3362
|
+
self.read_from_txt_for_celegans_CAO(
|
3363
|
+
file_format,
|
3364
|
+
reorder=reorder,
|
3365
|
+
shape=shape,
|
3366
|
+
raw_size=raw_size,
|
3367
|
+
)
|
3368
|
+
elif file_type == "mastodon":
|
3369
|
+
if isinstance(file_format, list) and len(file_format) == 2:
|
3370
|
+
self.read_from_mastodon_csv(file_format)
|
3371
|
+
else:
|
3372
|
+
if isinstance(file_format, list):
|
3373
|
+
file_format = file_format[0]
|
3374
|
+
self.read_from_mastodon(file_format, name)
|
3375
|
+
elif file_type == "astec":
|
3376
|
+
self.read_from_ASTEC(file_format, eigen)
|
3377
|
+
elif file_type == "csv":
|
3378
|
+
self.read_from_csv(file_format, z_mult, link=1, delim=delim)
|
3379
|
+
elif file_format and file_format.endswith(".lT"):
|
3380
|
+
with open(file_format, "br") as f:
|
3381
|
+
tmp = pkl.load(f)
|
3382
|
+
f.close()
|
3383
|
+
self.__dict__.update(tmp.__dict__)
|
3384
|
+
elif file_format is not None:
|
3385
|
+
self.read_from_binary(file_format)
|
3386
|
+
if self.name is None:
|
3387
|
+
try:
|
3388
|
+
self.name = Path(file_format).stem
|
3389
|
+
except:
|
3390
|
+
self.name = Path(file_format[0]).stem
|
3391
|
+
if [] in self.successor.values():
|
3392
|
+
successors = list(self.successor.keys())
|
3393
|
+
for succ in successors:
|
3394
|
+
if self[succ] == []:
|
3395
|
+
self.successor.pop(succ)
|
3396
|
+
if [] in self.predecessor.values():
|
3397
|
+
predecessors = list(self.predecessor.keys())
|
3398
|
+
for succ in predecessors:
|
3399
|
+
if self[succ] == []:
|
3400
|
+
self.predecessor.pop(succ)
|