LineageTree 1.8.0__py3-none-any.whl → 2.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- LineageTree/__init__.py +27 -2
- LineageTree/legacy/export_csv.py +70 -0
- LineageTree/legacy/to_lineajea.py +30 -0
- LineageTree/legacy/to_motile.py +36 -0
- LineageTree/lineageTree.py +2298 -1471
- LineageTree/lineageTreeManager.py +767 -79
- LineageTree/loaders.py +942 -864
- LineageTree/test/test_lineageTree.py +634 -0
- LineageTree/test/test_uted.py +233 -0
- LineageTree/tree_approximation.py +488 -0
- LineageTree/utils.py +103 -103
- lineagetree-2.0.3.dist-info/METADATA +130 -0
- lineagetree-2.0.3.dist-info/RECORD +16 -0
- {LineageTree-1.8.0.dist-info → lineagetree-2.0.3.dist-info}/WHEEL +1 -1
- LineageTree/tree_styles.py +0 -334
- LineageTree-1.8.0.dist-info/METADATA +0 -156
- LineageTree-1.8.0.dist-info/RECORD +0 -11
- {LineageTree-1.8.0.dist-info → lineagetree-2.0.3.dist-info/licenses}/LICENSE +0 -0
- {LineageTree-1.8.0.dist-info → lineagetree-2.0.3.dist-info}/top_level.txt +0 -0
LineageTree/lineageTree.py
CHANGED
@@ -2,643 +2,542 @@
|
|
2
2
|
# This file is subject to the terms and conditions defined in
|
3
3
|
# file 'LICENCE', which is part of this source code package.
|
4
4
|
# Author: Leo Guignard (leo.guignard...@AT@...gmail.com)
|
5
|
+
|
6
|
+
from __future__ import annotations
|
7
|
+
|
8
|
+
import importlib.metadata
|
5
9
|
import os
|
6
10
|
import pickle as pkl
|
7
11
|
import struct
|
8
12
|
import warnings
|
9
|
-
from collections.abc import Iterable
|
10
|
-
from functools import partial
|
13
|
+
from collections.abc import Callable, Iterable, Sequence
|
14
|
+
from functools import partial, wraps
|
11
15
|
from itertools import combinations
|
12
16
|
from numbers import Number
|
13
|
-
from
|
14
|
-
from typing import
|
15
|
-
|
16
|
-
|
17
|
-
from .tree_styles import tree_style
|
18
|
-
|
19
|
-
try:
|
20
|
-
from edist import uted
|
21
|
-
except ImportError:
|
22
|
-
warnings.warn(
|
23
|
-
"No edist installed therefore you will not be able to compute the tree edit distance.",
|
24
|
-
stacklevel=2,
|
25
|
-
)
|
17
|
+
from types import MappingProxyType
|
18
|
+
from typing import TYPE_CHECKING, Literal
|
19
|
+
|
20
|
+
import matplotlib.colors as mcolors
|
26
21
|
import matplotlib.pyplot as plt
|
27
22
|
import numpy as np
|
23
|
+
import svgwrite
|
24
|
+
from edist import uted
|
25
|
+
from matplotlib import colormaps
|
26
|
+
from matplotlib.collections import LineCollection
|
27
|
+
from packaging.version import Version
|
28
28
|
from scipy.interpolate import InterpolatedUnivariateSpline
|
29
|
-
from scipy.
|
30
|
-
from scipy.spatial import
|
29
|
+
from scipy.sparse import dok_array
|
30
|
+
from scipy.spatial import Delaunay, KDTree, distance
|
31
31
|
|
32
|
+
from . import __version__
|
33
|
+
from .tree_approximation import TreeApproximationTemplate, tree_style
|
32
34
|
from .utils import (
|
33
|
-
|
35
|
+
convert_style_to_number,
|
36
|
+
create_links_and_chains,
|
34
37
|
hierarchical_pos,
|
35
38
|
)
|
36
39
|
|
40
|
+
if TYPE_CHECKING:
|
41
|
+
from edist.alignment import Alignment
|
42
|
+
|
43
|
+
|
44
|
+
class dynamic_property(property):
|
45
|
+
def __init__(
|
46
|
+
self, fget=None, fset=None, fdel=None, doc=None, protected_name=None
|
47
|
+
):
|
48
|
+
super().__init__(fget, fset, fdel, doc)
|
49
|
+
self.protected_name = protected_name
|
50
|
+
|
51
|
+
def __set_name__(self, owner, name):
|
52
|
+
self.name = name
|
53
|
+
if self.protected_name is None:
|
54
|
+
self.protected_name = f"_{name}"
|
55
|
+
if not hasattr(owner, "_protected_dynamic_properties"):
|
56
|
+
owner._protected_dynamic_properties = []
|
57
|
+
owner._protected_dynamic_properties.append(self.protected_name)
|
58
|
+
if not hasattr(owner, "_dynamic_properties"):
|
59
|
+
owner._dynamic_properties = []
|
60
|
+
owner._dynamic_properties += [name, self.protected_name]
|
61
|
+
setattr(owner, self.protected_name, None)
|
62
|
+
|
63
|
+
def __get__(self, instance, owner):
|
64
|
+
if instance is None:
|
65
|
+
return self
|
66
|
+
instance._has_been_reset = False
|
67
|
+
if getattr(instance, self.protected_name) is None:
|
68
|
+
value = super().__get__(instance, owner)
|
69
|
+
setattr(instance, self.protected_name, value)
|
70
|
+
return value
|
71
|
+
else:
|
72
|
+
return getattr(instance, self.protected_name)
|
73
|
+
|
74
|
+
|
75
|
+
class lineageTree:
|
76
|
+
norm_dict = {"max": max, "sum": sum, None: lambda x: 1}
|
77
|
+
|
78
|
+
def modifier(wrapped_func):
|
79
|
+
@wraps(wrapped_func)
|
80
|
+
def raising_flag(self, *args, **kwargs):
|
81
|
+
should_reset = (
|
82
|
+
not hasattr(self, "_has_been_reset")
|
83
|
+
or not self._has_been_reset
|
84
|
+
)
|
85
|
+
out_func = wrapped_func(self, *args, **kwargs)
|
86
|
+
if should_reset:
|
87
|
+
for prop in self._protected_dynamic_properties:
|
88
|
+
self.__dict__[prop] = None
|
89
|
+
self._has_been_reset = True
|
90
|
+
return out_func
|
91
|
+
|
92
|
+
return raising_flag
|
93
|
+
|
94
|
+
def __check_cc_cycles(self, n: int) -> tuple[bool, set[int]]:
|
95
|
+
"""Check if the connected component of a given node `n` has a cycle.
|
96
|
+
|
97
|
+
Returns
|
98
|
+
-------
|
99
|
+
bool
|
100
|
+
True if the tree has cycles, False otherwise.
|
101
|
+
set of int
|
102
|
+
The set of nodes that have been checked.
|
103
|
+
"""
|
104
|
+
to_do = [n]
|
105
|
+
no_cycle = True
|
106
|
+
already_done = set()
|
107
|
+
while to_do and no_cycle:
|
108
|
+
current = to_do.pop(-1)
|
109
|
+
if current not in already_done:
|
110
|
+
already_done.add(current)
|
111
|
+
else:
|
112
|
+
no_cycle = False
|
113
|
+
to_do.extend(self._successor[current])
|
114
|
+
to_do = list(self._predecessor[n])
|
115
|
+
while to_do and no_cycle:
|
116
|
+
current = to_do.pop(-1)
|
117
|
+
if current not in already_done:
|
118
|
+
already_done.add(current)
|
119
|
+
else:
|
120
|
+
no_cycle = False
|
121
|
+
to_do.extend(self._predecessor[current])
|
122
|
+
return not no_cycle, already_done
|
123
|
+
|
124
|
+
def __check_for_cycles(self) -> bool:
|
125
|
+
"""Check if the tree has cycles.
|
126
|
+
|
127
|
+
Returns
|
128
|
+
-------
|
129
|
+
bool
|
130
|
+
True if the tree has cycles, False otherwise.
|
131
|
+
"""
|
132
|
+
to_do = set(self.nodes)
|
133
|
+
found_cycle = False
|
134
|
+
while to_do and not found_cycle:
|
135
|
+
current = to_do.pop()
|
136
|
+
found_cycle, done = self.__check_cc_cycles(current)
|
137
|
+
to_do.difference_update(done)
|
138
|
+
return found_cycle
|
37
139
|
|
38
|
-
|
39
|
-
def __eq__(self, other):
|
140
|
+
def __eq__(self, other) -> bool:
|
40
141
|
if isinstance(other, lineageTree):
|
41
|
-
return
|
42
|
-
|
142
|
+
return (
|
143
|
+
other._successor == self._successor
|
144
|
+
and other._predecessor == self._predecessor
|
145
|
+
and other._time == self._time
|
146
|
+
)
|
147
|
+
else:
|
148
|
+
return False
|
43
149
|
|
44
|
-
def get_next_id(self):
|
45
|
-
"""Computes the next authorized id.
|
150
|
+
def get_next_id(self) -> int:
|
151
|
+
"""Computes the next authorized id and assign it.
|
46
152
|
|
47
|
-
Returns
|
48
|
-
|
153
|
+
Returns
|
154
|
+
-------
|
155
|
+
int
|
156
|
+
next authorized id
|
49
157
|
"""
|
50
|
-
if self.max_id == -1 and self.nodes:
|
51
|
-
self.max_id = max(self.nodes)
|
52
|
-
if self.next_id == []:
|
158
|
+
if not hasattr(self, "max_id") or (self.max_id == -1 and self.nodes):
|
159
|
+
self.max_id = max(self.nodes) if len(self.nodes) else 0
|
160
|
+
if not hasattr(self, "next_id") or self.next_id == []:
|
53
161
|
self.max_id += 1
|
54
162
|
return self.max_id
|
55
163
|
else:
|
56
164
|
return self.next_id.pop()
|
57
165
|
|
58
|
-
def complete_lineage(self, nodes: Union[int, set] = None):
|
59
|
-
"""Makes all leaf branches longer so that they reach the last timepoint( self.t_e), useful
|
60
|
-
for tree edit distance algorithms.
|
61
|
-
|
62
|
-
Args:
|
63
|
-
nodes (int,set), optional): Which trees should be "completed", if None it will complete the whole dataset. Defaults to None.
|
64
|
-
"""
|
65
|
-
if nodes is None:
|
66
|
-
nodes = set(self.roots)
|
67
|
-
elif isinstance(nodes, int):
|
68
|
-
nodes = {nodes}
|
69
|
-
for node in nodes:
|
70
|
-
sub = set(self.get_sub_tree(node))
|
71
|
-
specific_leaves = sub.intersection(self.leaves)
|
72
|
-
for leaf in specific_leaves:
|
73
|
-
self.add_branch(
|
74
|
-
leaf,
|
75
|
-
(self.t_e - self.time[leaf]),
|
76
|
-
reverse=True,
|
77
|
-
move_timepoints=True,
|
78
|
-
)
|
79
|
-
|
80
166
|
###TODO pos can be callable and stay motionless (copy the position of the succ node, use something like optical flow)
|
81
|
-
|
167
|
+
@modifier
|
168
|
+
def add_chain(
|
82
169
|
self,
|
83
|
-
|
170
|
+
node: int,
|
84
171
|
length: int,
|
85
|
-
|
86
|
-
pos:
|
87
|
-
|
88
|
-
|
89
|
-
"""Adds a branch of specific length to a node either as a successor or as a predecessor.
|
172
|
+
downstream: bool,
|
173
|
+
pos: Callable | None = None,
|
174
|
+
) -> int:
|
175
|
+
"""Adds a chain of specific length to a node either as a successor or as a predecessor.
|
90
176
|
If it is placed on top of a tree all the nodes will move timepoints #length down.
|
91
177
|
|
92
|
-
|
93
|
-
|
94
|
-
|
95
|
-
|
96
|
-
|
97
|
-
|
98
|
-
|
99
|
-
|
178
|
+
Parameters
|
179
|
+
----------
|
180
|
+
node : int
|
181
|
+
Id of the successor (predecessor if `downstream==False`)
|
182
|
+
length : int
|
183
|
+
The length of the new chain.
|
184
|
+
downstream : bool, default=True
|
185
|
+
If `True` will create a chain that goes forwards in time otherwise backwards.
|
186
|
+
pos : np.ndarray, optional
|
187
|
+
The new position of the chain. Defaults to None.
|
188
|
+
|
189
|
+
Returns
|
190
|
+
-------
|
191
|
+
int
|
192
|
+
Id of the first node of the sublineage.
|
100
193
|
"""
|
101
194
|
if length == 0:
|
102
|
-
return
|
103
|
-
if
|
104
|
-
raise
|
105
|
-
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
|
110
|
-
new_times = {
|
111
|
-
node: self.time[node] + length for node in nodes_to_move
|
112
|
-
}
|
113
|
-
for node in nodes_to_move:
|
114
|
-
old_time = self.time[node]
|
115
|
-
self.time_nodes[old_time].remove(node)
|
116
|
-
self.time_nodes.setdefault(old_time + length, set()).add(
|
117
|
-
node
|
118
|
-
)
|
119
|
-
self.time.update(new_times)
|
120
|
-
for t in range(length - 1, -1, -1):
|
121
|
-
_next = self.add_node(
|
122
|
-
time + t,
|
123
|
-
succ=pred,
|
124
|
-
pos=self.pos[original],
|
125
|
-
reverse=True,
|
126
|
-
)
|
127
|
-
pred = _next
|
128
|
-
else:
|
129
|
-
for t in range(length):
|
130
|
-
_next = self.add_node(
|
131
|
-
time - t,
|
132
|
-
succ=pred,
|
133
|
-
pos=self.pos[original],
|
134
|
-
reverse=True,
|
135
|
-
)
|
136
|
-
pred = _next
|
195
|
+
return node
|
196
|
+
if length < 1:
|
197
|
+
raise ValueError("Length cannot be <1")
|
198
|
+
if downstream:
|
199
|
+
for _ in range(int(length)):
|
200
|
+
old_node = node
|
201
|
+
node = self._add_node(pred=[old_node])
|
202
|
+
self._time[node] = self._time[old_node] + 1
|
137
203
|
else:
|
138
|
-
|
139
|
-
|
140
|
-
|
141
|
-
|
142
|
-
|
143
|
-
reverse=False,
|
204
|
+
if self._predecessor[node]:
|
205
|
+
raise Warning("The node already has a predecessor.")
|
206
|
+
if self._time[node] - length < self.t_b:
|
207
|
+
raise Warning(
|
208
|
+
"A node cannot created outside the lower bound of the dataset. (It is possible to change it by lT.t_b = int(...))"
|
144
209
|
)
|
145
|
-
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
|
150
|
-
|
151
|
-
|
152
|
-
|
153
|
-
|
154
|
-
|
155
|
-
|
156
|
-
|
157
|
-
|
158
|
-
|
159
|
-
|
160
|
-
|
161
|
-
|
162
|
-
|
163
|
-
|
164
|
-
|
165
|
-
"""
|
166
|
-
|
167
|
-
|
168
|
-
|
169
|
-
|
170
|
-
|
171
|
-
|
172
|
-
|
173
|
-
new_tr = self.add_branch(new_lT, len(cycle), move_timepoints=False)
|
174
|
-
self.roots.add(new_tr)
|
175
|
-
self.labels[new_tr] = f"R-Split {label_of_root}"
|
176
|
-
return new_tr
|
177
|
-
else:
|
178
|
-
raise Warning("No division of the branch")
|
179
|
-
|
180
|
-
def fuse_lineage_tree(
|
181
|
-
self,
|
182
|
-
l1_root: int,
|
183
|
-
l2_root: int,
|
184
|
-
length_l1: int = 0,
|
185
|
-
length_l2: int = 0,
|
186
|
-
length: int = 1,
|
187
|
-
):
|
188
|
-
"""Fuses 2 lineages from the lineagetree object. The 2 lineages that are to be fused can have a longer
|
189
|
-
first node and the node of the resulting lineage can also be longer.
|
190
|
-
|
191
|
-
Args:
|
192
|
-
l1_root (int): Id of the first root
|
193
|
-
l2_root (int): Id of the second root
|
194
|
-
length_l1 (int, optional): The length of the branch that will be added on top of the first lineage. Defaults to 0, which means only one node will be added.
|
195
|
-
length_l2 (int, optional): The length of the branch that will be added on top of the second lineage. Defaults to 0, which means only one node will be added.
|
196
|
-
length (int, optional): The length of the branch that will be added on top of the resulting lineage. Defaults to 1.
|
197
|
-
|
198
|
-
Returns:
|
199
|
-
int: The id of the root of the new lineage.
|
200
|
-
"""
|
201
|
-
if self.predecessor.get(l1_root) or self.predecessor.get(l2_root):
|
202
|
-
raise ValueError("Please select 2 roots.")
|
203
|
-
if self.time[l1_root] != self.time[l2_root]:
|
204
|
-
warnings.warn(
|
205
|
-
"Using lineagetrees that do not exist in the same timepoint. The operation will continue"
|
206
|
-
)
|
207
|
-
new_root1 = self.add_branch(l1_root, length_l1)
|
208
|
-
new_root2 = self.add_branch(l2_root, length_l2)
|
209
|
-
next_root1 = self[new_root1][0]
|
210
|
-
self.remove_nodes(new_root1)
|
211
|
-
self.successor[new_root2].append(next_root1)
|
212
|
-
self.predecessor[next_root1] = [new_root2]
|
213
|
-
new_branch = self.add_branch(
|
214
|
-
new_root2,
|
215
|
-
length - 1,
|
216
|
-
)
|
217
|
-
self.labels[new_branch] = f"Fusion of {new_root1} and {new_root2}"
|
218
|
-
return new_branch
|
219
|
-
|
220
|
-
def copy_lineage(self, root):
|
221
|
-
"""
|
222
|
-
Copies the structure of a tree and makes a new with new nodes.
|
223
|
-
Warning does not take into account the predecessor of the root node.
|
224
|
-
|
225
|
-
Args:
|
226
|
-
root (int): The root of the tree to be copied
|
210
|
+
for _ in range(int(length)):
|
211
|
+
old_node = node
|
212
|
+
node = self._add_node(succ=[old_node])
|
213
|
+
self._time[node] = self._time[old_node] - 1
|
214
|
+
return node
|
215
|
+
|
216
|
+
@modifier
|
217
|
+
def add_root(self, t: int, pos: list | None = None) -> int:
|
218
|
+
"""Adds a root to a specific timepoint.
|
219
|
+
|
220
|
+
Parameters
|
221
|
+
----------
|
222
|
+
t :int
|
223
|
+
The timepoint the node is going to be added.
|
224
|
+
pos : list
|
225
|
+
The position of the new node.
|
226
|
+
Returns
|
227
|
+
-------
|
228
|
+
int
|
229
|
+
The id of the new root.
|
230
|
+
"""
|
231
|
+
C_next = self.get_next_id()
|
232
|
+
self._successor[C_next] = ()
|
233
|
+
self._predecessor[C_next] = ()
|
234
|
+
self._time[C_next] = t
|
235
|
+
self.pos[C_next] = pos if isinstance(pos, list) else []
|
236
|
+
self._changed_roots = True
|
237
|
+
return C_next
|
227
238
|
|
228
|
-
|
229
|
-
int: The root of the new tree.
|
230
|
-
"""
|
231
|
-
new_nodes = {
|
232
|
-
old_node: self.get_next_id()
|
233
|
-
for old_node in self.get_sub_tree(root)
|
234
|
-
}
|
235
|
-
self.nodes.update(new_nodes.values())
|
236
|
-
for old_node, new_node in new_nodes.items():
|
237
|
-
self.time[new_node] = self.time[old_node]
|
238
|
-
succ = self.successor.get(old_node)
|
239
|
-
if succ:
|
240
|
-
self.successor[new_node] = [new_nodes[n] for n in succ]
|
241
|
-
pred = self.predecessor.get(old_node)
|
242
|
-
if pred:
|
243
|
-
self.predecessor[new_node] = [new_nodes[n] for n in pred]
|
244
|
-
self.pos[new_node] = self.pos[old_node] + 0.5
|
245
|
-
self.time_nodes[self.time[old_node]].add(new_nodes[old_node])
|
246
|
-
new_root = new_nodes[root]
|
247
|
-
self.labels[new_root] = f"Copy of {root}"
|
248
|
-
if self.time[new_root] == 0:
|
249
|
-
self.roots.add(new_root)
|
250
|
-
return new_root
|
251
|
-
|
252
|
-
def add_node(
|
239
|
+
def _add_node(
|
253
240
|
self,
|
254
|
-
|
255
|
-
|
256
|
-
pos: np.ndarray = None,
|
257
|
-
nid: int = None,
|
258
|
-
reverse: bool = False,
|
241
|
+
succ: list | None = None,
|
242
|
+
pred: list | None = None,
|
243
|
+
pos: np.ndarray | None = None,
|
244
|
+
nid: int | None = None,
|
259
245
|
) -> int:
|
260
|
-
"""Adds a node to the
|
261
|
-
|
262
|
-
|
263
|
-
|
264
|
-
|
265
|
-
|
266
|
-
|
267
|
-
|
268
|
-
|
269
|
-
|
270
|
-
|
271
|
-
|
272
|
-
|
273
|
-
|
274
|
-
|
246
|
+
"""Adds a node to the LineageTree object that is either a successor or a predecessor of another node.
|
247
|
+
Does not handle time! You cannot enter both a successor and a predecessor.
|
248
|
+
|
249
|
+
Parameters
|
250
|
+
----------
|
251
|
+
succ : list
|
252
|
+
list of ids of the nodes the new node is a successor to
|
253
|
+
pred : list
|
254
|
+
list of ids of the nodes the new node is a predecessor to
|
255
|
+
pos : np.ndarray, optional
|
256
|
+
position of the new node
|
257
|
+
nid : int, optional
|
258
|
+
id value of the new node, to be used carefully,
|
259
|
+
if None is provided the new id is automatically computed.
|
260
|
+
|
261
|
+
Returns
|
262
|
+
-------
|
263
|
+
int
|
264
|
+
id of the new node.
|
265
|
+
"""
|
266
|
+
if not succ and not pred:
|
267
|
+
raise Warning(
|
268
|
+
"Please enter a successor or a predecessor, otherwise use the add_roots() function."
|
269
|
+
)
|
275
270
|
C_next = self.get_next_id() if nid is None else nid
|
276
|
-
|
277
|
-
|
278
|
-
|
279
|
-
|
280
|
-
|
281
|
-
self.
|
282
|
-
|
283
|
-
|
284
|
-
|
285
|
-
|
271
|
+
if succ:
|
272
|
+
self._successor[C_next] = succ
|
273
|
+
for suc in succ:
|
274
|
+
self._predecessor[suc] = (C_next,)
|
275
|
+
else:
|
276
|
+
self._successor[C_next] = ()
|
277
|
+
if pred:
|
278
|
+
self._predecessor[C_next] = pred
|
279
|
+
self._successor[pred[0]] = self._successor.setdefault(
|
280
|
+
pred[0], ()
|
281
|
+
) + (C_next,)
|
282
|
+
else:
|
283
|
+
self._predecessor[C_next] = ()
|
284
|
+
if isinstance(pos, list):
|
285
|
+
self.pos[C_next] = pos
|
286
286
|
return C_next
|
287
287
|
|
288
|
-
|
288
|
+
@modifier
|
289
|
+
def remove_nodes(self, group: int | set | list) -> None:
|
289
290
|
"""Removes a group of nodes from the LineageTree
|
290
291
|
|
291
|
-
|
292
|
-
|
292
|
+
Parameters
|
293
|
+
----------
|
294
|
+
group : set of int or list of int or int
|
295
|
+
One or more nodes that are to be removed.
|
293
296
|
"""
|
294
|
-
if isinstance(group, int):
|
297
|
+
if isinstance(group, int | float):
|
295
298
|
group = {group}
|
296
299
|
if isinstance(group, list):
|
297
300
|
group = set(group)
|
298
|
-
group =
|
299
|
-
self.nodes.difference_update(group)
|
300
|
-
times = {self.time.pop(n) for n in group}
|
301
|
-
for t in times:
|
302
|
-
self.time_nodes[t] = set(self.time_nodes[t]).difference(group)
|
301
|
+
group = self.nodes.intersection(group)
|
303
302
|
for node in group:
|
304
|
-
self.
|
305
|
-
|
306
|
-
|
307
|
-
|
308
|
-
|
309
|
-
|
310
|
-
|
311
|
-
|
312
|
-
|
313
|
-
|
314
|
-
|
315
|
-
|
316
|
-
|
317
|
-
|
318
|
-
|
319
|
-
|
320
|
-
|
321
|
-
|
322
|
-
|
323
|
-
Args:
|
324
|
-
node (int): Any node of the branch to be modified/
|
325
|
-
new_length (int): The new length of the tree.
|
326
|
-
"""
|
327
|
-
if new_length <= 1:
|
328
|
-
warnings.warn("New length should be more than 1", stacklevel=2)
|
329
|
-
return None
|
330
|
-
cycle = self.get_cycle(node)
|
331
|
-
length = len(cycle)
|
332
|
-
successors = self.successor.get(cycle[-1])
|
333
|
-
if length == 1 and new_length != 1:
|
334
|
-
pred = self.predecessor.pop(node, None)
|
335
|
-
new_node = self.add_branch(
|
336
|
-
node,
|
337
|
-
length=new_length - 1,
|
338
|
-
move_timepoints=True,
|
339
|
-
reverse=False,
|
340
|
-
)
|
341
|
-
if pred:
|
342
|
-
self.successor[pred[0]].remove(node)
|
343
|
-
self.successor[pred[0]].append(new_node)
|
344
|
-
elif self.leaves.intersection(cycle) and new_length < length:
|
345
|
-
self.remove_nodes(cycle[new_length:])
|
346
|
-
elif new_length < length:
|
347
|
-
to_remove = length - new_length
|
348
|
-
last_cell = cycle[new_length - 1]
|
349
|
-
subtree = self.get_sub_tree(cycle[-1])[1:]
|
350
|
-
self.remove_nodes(cycle[new_length:])
|
351
|
-
self.successor[last_cell] = successors
|
352
|
-
if successors:
|
353
|
-
for succ in successors:
|
354
|
-
self.predecessor[succ] = [last_cell]
|
355
|
-
for node in subtree:
|
356
|
-
if node not in cycle[new_length - 1 :]:
|
357
|
-
old_time = self.time[node]
|
358
|
-
self.time[node] = old_time - to_remove
|
359
|
-
self.time_nodes[old_time].remove(node)
|
360
|
-
self.time_nodes.setdefault(
|
361
|
-
old_time - to_remove, set()
|
362
|
-
).add(node)
|
363
|
-
elif length < new_length:
|
364
|
-
to_add = new_length - length
|
365
|
-
last_cell = cycle[-1]
|
366
|
-
self.successor.pop(cycle[-2])
|
367
|
-
self.predecessor.pop(last_cell)
|
368
|
-
succ = self.add_branch(
|
369
|
-
last_cell, length=to_add, move_timepoints=True, reverse=False
|
370
|
-
)
|
371
|
-
self.predecessor[succ] = [cycle[-2]]
|
372
|
-
self.successor[cycle[-2]] = [succ]
|
373
|
-
self.time[last_cell] = (
|
374
|
-
self.time[self.predecessor[last_cell][0]] + 1
|
375
|
-
)
|
376
|
-
else:
|
377
|
-
return None
|
378
|
-
|
379
|
-
@property
|
380
|
-
def time_resolution(self):
|
381
|
-
if not hasattr(self, "_time_resolution"):
|
382
|
-
self.time_resolution = 1
|
383
|
-
return self._time_resolution / 10
|
384
|
-
|
385
|
-
@time_resolution.setter
|
386
|
-
def time_resolution(self, time_resolution):
|
387
|
-
if time_resolution is not None:
|
388
|
-
self._time_resolution = int(time_resolution * 10)
|
389
|
-
else:
|
390
|
-
warnings.warn("Time resolution set to default 0", stacklevel=2)
|
391
|
-
self._time_resolution = 10
|
392
|
-
|
393
|
-
@property
|
394
|
-
def depth(self):
|
395
|
-
if not hasattr(self, "_depth"):
|
396
|
-
self._depth = {}
|
397
|
-
for leaf in self.leaves:
|
398
|
-
self._depth[leaf] = 1
|
399
|
-
while leaf in self.predecessor:
|
400
|
-
parent = self.predecessor[leaf][0]
|
401
|
-
current_depth = self._depth.get(parent, 0)
|
402
|
-
self._depth[parent] = max(
|
403
|
-
self._depth[leaf] + 1, current_depth
|
404
|
-
)
|
405
|
-
leaf = parent
|
406
|
-
for root in self.roots - set(self._depth):
|
407
|
-
self._depth[root] = 1
|
408
|
-
return self._depth
|
303
|
+
for attr in self.__dict__:
|
304
|
+
attr_value = self.__getattribute__(attr)
|
305
|
+
if isinstance(attr_value, dict) and attr not in [
|
306
|
+
"successor",
|
307
|
+
"predecessor",
|
308
|
+
"_successor",
|
309
|
+
"_predecessor",
|
310
|
+
]:
|
311
|
+
attr_value.pop(node, ())
|
312
|
+
if self._predecessor.get(node):
|
313
|
+
self._successor[self._predecessor[node][0]] = tuple(
|
314
|
+
set(
|
315
|
+
self._successor[self._predecessor[node][0]]
|
316
|
+
).difference(group)
|
317
|
+
)
|
318
|
+
for p_node in self._successor.get(node, []):
|
319
|
+
self._predecessor[p_node] = ()
|
320
|
+
self._predecessor.pop(node, ())
|
321
|
+
self._successor.pop(node, ())
|
409
322
|
|
410
323
|
@property
|
411
|
-
def
|
412
|
-
|
324
|
+
def successor(self) -> MappingProxyType[int, tuple[int]]:
|
325
|
+
"""The successor of the tree."""
|
326
|
+
if not hasattr(self, "_protected_successor"):
|
327
|
+
self._protected_successor = MappingProxyType(self._successor)
|
328
|
+
return self._protected_successor
|
413
329
|
|
414
330
|
@property
|
415
|
-
def
|
416
|
-
|
331
|
+
def predecessor(self) -> MappingProxyType[int, tuple[int]]:
|
332
|
+
"""The predecessor of the tree."""
|
333
|
+
if not hasattr(self, "_protected_predecessor"):
|
334
|
+
self._protected_predecessor = MappingProxyType(self._predecessor)
|
335
|
+
return self._protected_predecessor
|
417
336
|
|
418
337
|
@property
|
419
|
-
def
|
420
|
-
|
338
|
+
def time(self) -> MappingProxyType[int, int]:
|
339
|
+
"""The time of the tree."""
|
340
|
+
if not hasattr(self, "_protected_time"):
|
341
|
+
self._protected_time = MappingProxyType(self._time)
|
342
|
+
return self._protected_time
|
343
|
+
|
344
|
+
@dynamic_property
|
345
|
+
def t_b(self) -> int:
|
346
|
+
"""The first timepoint of the tree."""
|
347
|
+
return min(self._time.values())
|
348
|
+
|
349
|
+
@dynamic_property
|
350
|
+
def t_e(self) -> int:
|
351
|
+
"""The last timepoint of the tree."""
|
352
|
+
return max(self._time.values())
|
353
|
+
|
354
|
+
@dynamic_property
|
355
|
+
def nodes(self) -> frozenset[int]:
|
356
|
+
"""Nodes of the tree"""
|
357
|
+
return frozenset(self._successor.keys())
|
358
|
+
|
359
|
+
@dynamic_property
|
360
|
+
def number_of_nodes(self) -> int:
|
361
|
+
return len(self.nodes)
|
362
|
+
|
363
|
+
@dynamic_property
|
364
|
+
def depth(self) -> dict[int, int]:
|
365
|
+
"""The depth of each node in the tree."""
|
366
|
+
_depth = {}
|
367
|
+
for leaf in self.leaves:
|
368
|
+
_depth[leaf] = 1
|
369
|
+
while leaf in self._predecessor and self._predecessor[leaf]:
|
370
|
+
parent = self._predecessor[leaf][0]
|
371
|
+
current_depth = _depth.get(parent, 0)
|
372
|
+
_depth[parent] = max(_depth[leaf] + 1, current_depth)
|
373
|
+
leaf = parent
|
374
|
+
for root in self.roots - set(_depth):
|
375
|
+
_depth[root] = 1
|
376
|
+
return _depth
|
377
|
+
|
378
|
+
@dynamic_property
|
379
|
+
def roots(self) -> frozenset[int]:
|
380
|
+
"""Set of roots of the tree"""
|
381
|
+
return frozenset({s for s, p in self._predecessor.items() if p == ()})
|
382
|
+
|
383
|
+
@dynamic_property
|
384
|
+
def leaves(self) -> frozenset[int]:
|
385
|
+
"""Set of leaves"""
|
386
|
+
return frozenset({p for p, s in self._successor.items() if s == ()})
|
387
|
+
|
388
|
+
@dynamic_property
|
389
|
+
def edges(self) -> tuple[tuple[int, int]]:
|
390
|
+
"""Set of edges"""
|
391
|
+
return tuple((p, si) for p, s in self._successor.items() for si in s)
|
421
392
|
|
422
393
|
@property
|
423
|
-
def labels(self):
|
394
|
+
def labels(self) -> dict[int, str]:
|
395
|
+
"""The labels of the nodes."""
|
424
396
|
if not hasattr(self, "_labels"):
|
425
|
-
if hasattr(self, "
|
397
|
+
if hasattr(self, "node_name"):
|
426
398
|
self._labels = {
|
427
|
-
i: self.
|
399
|
+
i: self.node_name.get(i, "Unlabeled") for i in self.roots
|
428
400
|
}
|
429
401
|
else:
|
430
402
|
self._labels = {
|
431
403
|
root: "Unlabeled"
|
432
404
|
for root in self.roots
|
433
405
|
for leaf in self.find_leaves(root)
|
434
|
-
if abs(self.
|
406
|
+
if abs(self._time[leaf] - self._time[root])
|
435
407
|
>= abs(self.t_e - self.t_b) / 4
|
436
408
|
}
|
437
409
|
return self._labels
|
438
410
|
|
439
|
-
|
440
|
-
|
441
|
-
|
442
|
-
|
443
|
-
|
444
|
-
f.write("define POINT %d\n" % ((length) * nb_points))
|
445
|
-
f.write("Parameters {\n")
|
446
|
-
f.write('\tContentType "HxSpatialGraph"\n')
|
447
|
-
f.write("}\n")
|
448
|
-
|
449
|
-
f.write("VERTEX { float[3] VertexCoordinates } @1\n")
|
450
|
-
f.write("EDGE { int[2] EdgeConnectivity } @2\n")
|
451
|
-
f.write("EDGE { int NumEdgePoints } @3\n")
|
452
|
-
f.write("POINT { float[3] EdgePointCoordinates } @4\n")
|
453
|
-
f.write("VERTEX { float Vcolor } @5\n")
|
454
|
-
f.write("VERTEX { int Vbool } @6\n")
|
455
|
-
f.write("EDGE { float Ecolor } @7\n")
|
456
|
-
f.write("VERTEX { int Vbool2 } @8\n")
|
457
|
-
|
458
|
-
def write_to_am(
|
459
|
-
self,
|
460
|
-
path_format: str,
|
461
|
-
t_b: int = None,
|
462
|
-
t_e: int = None,
|
463
|
-
length: int = 5,
|
464
|
-
manual_labels: dict = None,
|
465
|
-
default_label: int = 5,
|
466
|
-
new_pos: np.ndarray = None,
|
467
|
-
):
|
468
|
-
"""Writes a lineageTree into an Amira readable data (.am format).
|
469
|
-
|
470
|
-
Args:
|
471
|
-
path_format (str): path to the output. It should contain 1 %03d where the time step will be entered
|
472
|
-
t_b (int): first time point to write (if None, min(LT.to_take_time) is taken)
|
473
|
-
t_e (int): last time point to write (if None, max(LT.to_take_time) is taken)
|
474
|
-
note, if there is no 'to_take_time' attribute, self.time_nodes
|
475
|
-
is considered instead (historical)
|
476
|
-
length (int): length of the track to print (how many time before).
|
477
|
-
manual_labels ({id: label, }): dictionary that maps cell ids to
|
478
|
-
default_label (int): default value for the manual label
|
479
|
-
new_pos ({id: [x, y, z]}): dictionary that maps a 3D position to a cell ID.
|
480
|
-
if new_pos == None (default) then self.pos is considered.
|
481
|
-
"""
|
482
|
-
if not hasattr(self, "to_take_time"):
|
483
|
-
self.to_take_time = self.time_nodes
|
484
|
-
if t_b is None:
|
485
|
-
t_b = min(self.to_take_time.keys())
|
486
|
-
if t_e is None:
|
487
|
-
t_e = max(self.to_take_time.keys())
|
488
|
-
if new_pos is None:
|
489
|
-
new_pos = self.pos
|
490
|
-
|
491
|
-
if manual_labels is None:
|
492
|
-
manual_labels = {}
|
493
|
-
for t in range(t_b, t_e + 1):
|
494
|
-
with open(path_format % t, "w") as f:
|
495
|
-
nb_points = len(self.to_take_time[t])
|
496
|
-
self._write_header_am(f, nb_points, length)
|
497
|
-
points_v = {}
|
498
|
-
for C in self.to_take_time[t]:
|
499
|
-
C_tmp = C
|
500
|
-
positions = []
|
501
|
-
for _ in range(length):
|
502
|
-
C_tmp = self.predecessor.get(C_tmp, [C_tmp])[0]
|
503
|
-
positions.append(new_pos[C_tmp])
|
504
|
-
points_v[C] = positions
|
505
|
-
|
506
|
-
f.write("@1\n")
|
507
|
-
for C in self.to_take_time[t]:
|
508
|
-
f.write("{:f} {:f} {:f}\n".format(*tuple(points_v[C][0])))
|
509
|
-
f.write("{:f} {:f} {:f}\n".format(*tuple(points_v[C][-1])))
|
510
|
-
|
511
|
-
f.write("@2\n")
|
512
|
-
for i, _ in enumerate(self.to_take_time[t]):
|
513
|
-
f.write("%d %d\n" % (2 * i, 2 * i + 1))
|
514
|
-
|
515
|
-
f.write("@3\n")
|
516
|
-
for _ in self.to_take_time[t]:
|
517
|
-
f.write("%d\n" % (length))
|
518
|
-
|
519
|
-
f.write("@4\n")
|
520
|
-
for C in self.to_take_time[t]:
|
521
|
-
for p in points_v[C]:
|
522
|
-
f.write("{:f} {:f} {:f}\n".format(*tuple(p)))
|
523
|
-
|
524
|
-
f.write("@5\n")
|
525
|
-
for C in self.to_take_time[t]:
|
526
|
-
f.write(f"{manual_labels.get(C, default_label):f}\n")
|
527
|
-
f.write(f"{0:f}\n")
|
528
|
-
|
529
|
-
f.write("@6\n")
|
530
|
-
for C in self.to_take_time[t]:
|
531
|
-
f.write(
|
532
|
-
"%d\n"
|
533
|
-
% (
|
534
|
-
int(
|
535
|
-
manual_labels.get(C, default_label)
|
536
|
-
!= default_label
|
537
|
-
)
|
538
|
-
)
|
539
|
-
)
|
540
|
-
f.write("%d\n" % (0))
|
541
|
-
|
542
|
-
f.write("@7\n")
|
543
|
-
for C in self.to_take_time[t]:
|
544
|
-
f.write(
|
545
|
-
f"{np.linalg.norm(points_v[C][0] - points_v[C][-1]):f}\n"
|
546
|
-
)
|
547
|
-
|
548
|
-
f.write("@8\n")
|
549
|
-
for _ in self.to_take_time[t]:
|
550
|
-
f.write("%d\n" % (1))
|
551
|
-
f.write("%d\n" % (0))
|
552
|
-
f.close()
|
411
|
+
@property
|
412
|
+
def time_resolution(self) -> float:
|
413
|
+
if not hasattr(self, "_time_resolution"):
|
414
|
+
self._time_resolution = 0
|
415
|
+
return self._time_resolution / 10
|
553
416
|
|
554
|
-
|
555
|
-
|
417
|
+
@time_resolution.setter
|
418
|
+
def time_resolution(self, time_resolution) -> None:
|
419
|
+
if time_resolution is not None and time_resolution > 0:
|
420
|
+
self._time_resolution = int(time_resolution * 10)
|
421
|
+
else:
|
422
|
+
warnings.warn("Time resolution set to default 0", stacklevel=2)
|
423
|
+
self._time_resolution = 0
|
424
|
+
|
425
|
+
def __setstate__(self, state):
|
426
|
+
if "_successor" not in state:
|
427
|
+
state["_successor"] = state["successor"]
|
428
|
+
if "_predecessor" not in state:
|
429
|
+
state["_predecessor"] = state["predecessor"]
|
430
|
+
if "_time" not in state:
|
431
|
+
state["_time"] = state["time"]
|
432
|
+
self.__dict__.update(state)
|
433
|
+
|
434
|
+
def _get_height(self, c: int, done: dict) -> float:
|
435
|
+
"""Recursively computes the height of a node within a tree times a space factor.
|
556
436
|
This function is specific to the function write_to_svg.
|
557
437
|
|
558
|
-
|
559
|
-
|
560
|
-
|
561
|
-
|
562
|
-
|
438
|
+
Parameters
|
439
|
+
----------
|
440
|
+
c : int
|
441
|
+
id of a node in a lineage tree from which the height will be computed from
|
442
|
+
done : dict mapping int to list of two int
|
443
|
+
a dictionary that maps a node id to its vertical and horizontal position
|
444
|
+
|
445
|
+
Returns
|
446
|
+
-------
|
447
|
+
float
|
448
|
+
the height of the node `c`
|
563
449
|
"""
|
564
450
|
if c in done:
|
565
451
|
return done[c][0]
|
566
452
|
else:
|
567
453
|
P = np.mean(
|
568
|
-
[self._get_height(di, done) for di in self.
|
454
|
+
[self._get_height(di, done) for di in self._successor[c]]
|
569
455
|
)
|
570
|
-
done[c] = [P, self.vert_space_factor * self.
|
456
|
+
done[c] = [P, self.vert_space_factor * self._time[c]]
|
571
457
|
return P
|
572
458
|
|
573
459
|
def write_to_svg(
|
574
460
|
self,
|
575
461
|
file_name: str,
|
576
|
-
roots: list = None,
|
462
|
+
roots: list | None = None,
|
577
463
|
draw_nodes: bool = True,
|
578
464
|
draw_edges: bool = True,
|
579
|
-
order_key:
|
465
|
+
order_key: Callable | None = None,
|
580
466
|
vert_space_factor: float = 0.5,
|
581
467
|
horizontal_space: float = 1,
|
582
|
-
node_size:
|
583
|
-
stroke_width:
|
468
|
+
node_size: Callable | str | None = None,
|
469
|
+
stroke_width: Callable | None = None,
|
584
470
|
factor: float = 1.0,
|
585
|
-
node_color:
|
586
|
-
stroke_color:
|
587
|
-
positions: dict = None,
|
588
|
-
node_color_map:
|
589
|
-
|
590
|
-
):
|
591
|
-
##### remove background? default True background value? default 1
|
592
|
-
|
471
|
+
node_color: Callable | str | None = None,
|
472
|
+
stroke_color: Callable | None = None,
|
473
|
+
positions: dict | None = None,
|
474
|
+
node_color_map: Callable | str | None = None,
|
475
|
+
) -> None:
|
593
476
|
"""Writes the lineage tree to an SVG file.
|
594
477
|
Node and edges coloring and size can be provided.
|
595
478
|
|
596
|
-
|
597
|
-
|
598
|
-
|
599
|
-
|
600
|
-
|
601
|
-
|
602
|
-
|
603
|
-
|
604
|
-
|
605
|
-
|
606
|
-
|
607
|
-
|
608
|
-
|
609
|
-
|
610
|
-
|
611
|
-
|
612
|
-
|
613
|
-
|
614
|
-
|
615
|
-
|
616
|
-
|
617
|
-
|
618
|
-
|
619
|
-
|
479
|
+
Parameters
|
480
|
+
----------
|
481
|
+
file_name : str
|
482
|
+
filesystem filename valid for `open()`
|
483
|
+
roots : list of int, defaults to `self.roots`
|
484
|
+
list of node ids to be drawn. If `None` or not provided all the nodes will be drawn. Default `None`
|
485
|
+
draw_nodes : bool, default True
|
486
|
+
wether to print the nodes or not
|
487
|
+
draw_edges : bool, default True
|
488
|
+
wether to print the edges or not
|
489
|
+
order_key : Callable, optional
|
490
|
+
function that would work for the attribute `key=` for the `sort`/`sorted` function
|
491
|
+
vert_space_factor : float, default=0.5
|
492
|
+
the vertical position of a node is its time. `vert_space_factor` is a
|
493
|
+
multiplier to space more or less nodes in time
|
494
|
+
horizontal_space : float, default=1
|
495
|
+
space between two consecutive nodes
|
496
|
+
node_size : Callable or str, optional
|
497
|
+
a function that maps a node id to a `float` value that will determine the
|
498
|
+
radius of the node. The default function return the constant value `vertical_space_factor/2.1`
|
499
|
+
If a string is given instead and it is a property of the tree,
|
500
|
+
the the size will be mapped according to the property
|
501
|
+
stroke_width : Callable, optional
|
502
|
+
a function that maps a node id to a `float` value that will determine the
|
503
|
+
width of the daughter edge. The default function return the constant value `vertical_space_factor/2.1`
|
504
|
+
factor : float, default=1.0
|
505
|
+
scaling factor for nodes positions, default 1
|
506
|
+
node_color : Callable or str, optional
|
507
|
+
a function that maps a node id to a triplet between 0 and 255.
|
508
|
+
The triplet will determine the color of the node. If a string is given instead and it is a property
|
509
|
+
of the tree, the the color will be mapped according to the property
|
510
|
+
node_color_map : Callable or str, optional
|
511
|
+
the name of the colormap to use to color the nodes, or a colormap function
|
512
|
+
stroke_color : Callable, optional
|
513
|
+
a function that maps a node id to a triplet between 0 and 255.
|
514
|
+
The triplet will determine the color of the stroke of the inward edge.
|
515
|
+
positions : dict mapping int to list of two float, optional
|
516
|
+
dictionary that maps a node id to a 2D position.
|
517
|
+
Default `None`. If provided it will be used to position the nodes.
|
620
518
|
"""
|
621
|
-
import svgwrite
|
622
519
|
|
623
520
|
def normalize_values(v, nodes, _range, shift, mult):
|
624
521
|
min_ = np.percentile(v, 1)
|
625
522
|
max_ = np.percentile(v, 99)
|
626
523
|
values = _range * ((v - min_) / (max_ - min_)) + shift
|
627
|
-
values_dict_nodes = dict(zip(nodes, values))
|
524
|
+
values_dict_nodes = dict(zip(nodes, values, strict=True))
|
628
525
|
return lambda x: values_dict_nodes[x] * mult
|
629
526
|
|
630
527
|
if roots is None:
|
631
528
|
roots = self.roots
|
632
529
|
if hasattr(self, "image_label"):
|
633
|
-
roots = [
|
530
|
+
roots = [node for node in roots if self.image_label[node] != 1]
|
634
531
|
|
635
532
|
if node_size is None:
|
636
533
|
|
637
534
|
def node_size(x):
|
638
535
|
return vert_space_factor / 2.1
|
639
536
|
|
640
|
-
|
641
|
-
values = np.array(
|
537
|
+
else:
|
538
|
+
values = np.array(
|
539
|
+
[self._successor[node_size][c] for c in self.nodes]
|
540
|
+
)
|
642
541
|
node_size = normalize_values(
|
643
542
|
values, self.nodes, 0.5, 0.5, vert_space_factor / 2.1
|
644
543
|
)
|
@@ -653,18 +552,19 @@ class lineageTree(lineageTreeLoaders):
|
|
653
552
|
return 0, 0, 0
|
654
553
|
|
655
554
|
elif isinstance(node_color, str) and node_color in self.__dict__:
|
656
|
-
|
657
|
-
from matplotlib import colormaps
|
555
|
+
from matplotlib import colormaps
|
658
556
|
|
659
|
-
|
660
|
-
|
661
|
-
|
662
|
-
|
663
|
-
values = np.array(
|
557
|
+
if node_color_map in colormaps:
|
558
|
+
cmap = colormaps[node_color_map]
|
559
|
+
else:
|
560
|
+
cmap = colormaps["viridis"]
|
561
|
+
values = np.array(
|
562
|
+
[self._successor[node_color][c] for c in self.nodes]
|
563
|
+
)
|
664
564
|
normed_vals = normalize_values(values, self.nodes, 1, 0, 1)
|
665
565
|
|
666
566
|
def node_color(x):
|
667
|
-
return [k * 255 for k in
|
567
|
+
return [k * 255 for k in cmap(normed_vals(x))[:-1]]
|
668
568
|
|
669
569
|
coloring_edges = stroke_color is not None
|
670
570
|
if not coloring_edges:
|
@@ -673,24 +573,25 @@ class lineageTree(lineageTreeLoaders):
|
|
673
573
|
return 0, 0, 0
|
674
574
|
|
675
575
|
elif isinstance(stroke_color, str) and stroke_color in self.__dict__:
|
676
|
-
|
677
|
-
from matplotlib import colormaps
|
576
|
+
from matplotlib import colormaps
|
678
577
|
|
679
|
-
|
680
|
-
|
681
|
-
|
682
|
-
|
683
|
-
values = np.array(
|
578
|
+
if node_color_map in colormaps:
|
579
|
+
cmap = colormaps[node_color_map]
|
580
|
+
else:
|
581
|
+
cmap = colormaps["viridis"]
|
582
|
+
values = np.array(
|
583
|
+
[self._successor[stroke_color][c] for c in self.nodes]
|
584
|
+
)
|
684
585
|
normed_vals = normalize_values(values, self.nodes, 1, 0, 1)
|
685
586
|
|
686
587
|
def stroke_color(x):
|
687
|
-
return [k * 255 for k in
|
588
|
+
return [k * 255 for k in cmap(normed_vals(x))[:-1]]
|
688
589
|
|
689
590
|
prev_x = 0
|
690
591
|
self.vert_space_factor = vert_space_factor
|
691
592
|
if order_key is not None:
|
692
593
|
roots.sort(key=order_key)
|
693
|
-
|
594
|
+
treated_nodes = []
|
694
595
|
|
695
596
|
pos_given = positions is not None
|
696
597
|
if not pos_given:
|
@@ -701,25 +602,26 @@ class lineageTree(lineageTreeLoaders):
|
|
701
602
|
[0.0, 0.0],
|
702
603
|
]
|
703
604
|
* len(self.nodes),
|
704
|
-
|
605
|
+
strict=True,
|
606
|
+
),
|
705
607
|
)
|
706
608
|
for _i, r in enumerate(roots):
|
707
609
|
r_leaves = []
|
708
610
|
to_do = [r]
|
709
611
|
while len(to_do) != 0:
|
710
612
|
curr = to_do.pop(0)
|
711
|
-
|
712
|
-
if
|
613
|
+
treated_nodes += [curr]
|
614
|
+
if not self._successor[curr]:
|
713
615
|
if order_key is not None:
|
714
|
-
to_do += sorted(self.
|
616
|
+
to_do += sorted(self._successor[curr], key=order_key)
|
715
617
|
else:
|
716
|
-
to_do += self.
|
618
|
+
to_do += self._successor[curr]
|
717
619
|
else:
|
718
620
|
r_leaves += [curr]
|
719
621
|
r_pos = {
|
720
622
|
leave: [
|
721
623
|
prev_x + horizontal_space * (1 + j),
|
722
|
-
self.vert_space_factor * self.
|
624
|
+
self.vert_space_factor * self._time[leave],
|
723
625
|
]
|
724
626
|
for j, leave in enumerate(r_leaves)
|
725
627
|
}
|
@@ -734,12 +636,12 @@ class lineageTree(lineageTreeLoaders):
|
|
734
636
|
size=factor * np.max(list(positions.values()), axis=0),
|
735
637
|
)
|
736
638
|
if draw_edges and not draw_nodes and not coloring_edges:
|
737
|
-
to_do = set(
|
639
|
+
to_do = set(treated_nodes)
|
738
640
|
while len(to_do) > 0:
|
739
641
|
curr = to_do.pop()
|
740
|
-
|
741
|
-
x1, y1 = positions[
|
742
|
-
x2, y2 = positions[
|
642
|
+
c_chain = self.get_chain_of_node(curr)
|
643
|
+
x1, y1 = positions[c_chain[0]]
|
644
|
+
x2, y2 = positions[c_chain[-1]]
|
743
645
|
dwg.add(
|
744
646
|
dwg.line(
|
745
647
|
(factor * x1, factor * y1),
|
@@ -747,7 +649,7 @@ class lineageTree(lineageTreeLoaders):
|
|
747
649
|
stroke=svgwrite.rgb(0, 0, 0),
|
748
650
|
)
|
749
651
|
)
|
750
|
-
for si in self[
|
652
|
+
for si in self._successor[c_chain[-1]]:
|
751
653
|
x3, y3 = positions[si]
|
752
654
|
dwg.add(
|
753
655
|
dwg.line(
|
@@ -756,11 +658,11 @@ class lineageTree(lineageTreeLoaders):
|
|
756
658
|
stroke=svgwrite.rgb(0, 0, 0),
|
757
659
|
)
|
758
660
|
)
|
759
|
-
to_do.difference_update(
|
661
|
+
to_do.difference_update(c_chain)
|
760
662
|
else:
|
761
|
-
for c in
|
663
|
+
for c in treated_nodes:
|
762
664
|
x1, y1 = positions[c]
|
763
|
-
for si in self[c]:
|
665
|
+
for si in self._successor[c]:
|
764
666
|
x2, y2 = positions[si]
|
765
667
|
if draw_edges:
|
766
668
|
dwg.add(
|
@@ -771,7 +673,7 @@ class lineageTree(lineageTreeLoaders):
|
|
771
673
|
stroke_width=svgwrite.pt(stroke_width(si)),
|
772
674
|
)
|
773
675
|
)
|
774
|
-
for c in
|
676
|
+
for c in treated_nodes:
|
775
677
|
x1, y1 = positions[c]
|
776
678
|
if draw_nodes:
|
777
679
|
dwg.add(
|
@@ -788,49 +690,58 @@ class lineageTree(lineageTreeLoaders):
|
|
788
690
|
fname: str,
|
789
691
|
t_min: int = -1,
|
790
692
|
t_max: int = np.inf,
|
791
|
-
nodes_to_use: list = None,
|
693
|
+
nodes_to_use: list[int] | None = None,
|
792
694
|
temporal: bool = True,
|
793
|
-
spatial: str = None,
|
695
|
+
spatial: str | None = None,
|
794
696
|
write_layout: bool = True,
|
795
|
-
node_properties: dict = None,
|
697
|
+
node_properties: dict | None = None,
|
796
698
|
Names: bool = False,
|
797
|
-
):
|
699
|
+
) -> None:
|
798
700
|
"""Write a lineage tree into an understable tulip file.
|
799
701
|
|
800
|
-
|
801
|
-
|
802
|
-
|
803
|
-
|
804
|
-
|
805
|
-
|
806
|
-
|
807
|
-
|
808
|
-
|
809
|
-
|
810
|
-
|
811
|
-
|
812
|
-
|
813
|
-
|
814
|
-
|
815
|
-
|
816
|
-
|
817
|
-
|
818
|
-
|
819
|
-
|
820
|
-
|
821
|
-
|
822
|
-
|
702
|
+
Parameters
|
703
|
+
----------
|
704
|
+
fname : str
|
705
|
+
path to the tulip file to create
|
706
|
+
t_min : int, default=-1
|
707
|
+
minimum time to consider
|
708
|
+
t_max : int, default=np.inf
|
709
|
+
maximum time to consider
|
710
|
+
nodes_to_use : list of int, optional
|
711
|
+
list of nodes to show in the graph,
|
712
|
+
if `None` then self.nodes is used
|
713
|
+
(taking into account `t_min` and `t_max`)
|
714
|
+
temporal : bool, default=True
|
715
|
+
True if the temporal links should be printed
|
716
|
+
spatial : str, optional
|
717
|
+
Build spatial edges from a spatial neighbourhood graph.
|
718
|
+
The graph has to be computed before running this function
|
719
|
+
'ball': neighbours at a given distance,
|
720
|
+
'kn': k-nearest neighbours,
|
721
|
+
'GG': gabriel graph,
|
722
|
+
None: no spatial edges are writen.
|
723
|
+
Default None
|
724
|
+
write_layout : bool, default=True
|
725
|
+
write the spatial position as layout if True
|
726
|
+
do not write spatial position otherwise
|
727
|
+
node_properties : dict mapping str to list of dict of properties and its default value, optional
|
728
|
+
a dictionary of properties to write
|
729
|
+
To a key representing the name of the property is
|
730
|
+
paired a dictionary that maps a node id to a property
|
731
|
+
and a default value for this property
|
732
|
+
Names : bool, default=True
|
733
|
+
Only works with ASTEC outputs, True to sort the nodes by their names
|
823
734
|
"""
|
824
735
|
|
825
736
|
def format_names(names_which_matter):
|
826
|
-
"""Return an ensured formated
|
737
|
+
"""Return an ensured formated node names"""
|
827
738
|
tmp = {}
|
828
739
|
for k, v in names_which_matter.items():
|
829
740
|
tmp[k] = (
|
830
741
|
v.split(".")[0][0]
|
831
|
-
+ "
|
742
|
+
+ "{:02d}".format(int(v.split(".")[0][1:]))
|
832
743
|
+ "."
|
833
|
-
+ "
|
744
|
+
+ "{:04d}".format(int(v.split(".")[1][:-1]))
|
834
745
|
+ v.split(".")[1][-1]
|
835
746
|
)
|
836
747
|
return tmp
|
@@ -857,20 +768,20 @@ class lineageTree(lineageTreeLoaders):
|
|
857
768
|
if not nodes_to_use:
|
858
769
|
if t_max != np.inf or t_min > -1:
|
859
770
|
nodes_to_use = [
|
860
|
-
n for n in self.nodes if t_min < self.
|
771
|
+
n for n in self.nodes if t_min < self._time[n] <= t_max
|
861
772
|
]
|
862
773
|
edges_to_use = []
|
863
774
|
if temporal:
|
864
775
|
edges_to_use += [
|
865
776
|
e
|
866
777
|
for e in self.edges
|
867
|
-
if t_min < self.
|
778
|
+
if t_min < self._time[e[0]] < t_max
|
868
779
|
]
|
869
780
|
if spatial:
|
870
781
|
edges_to_use += [
|
871
782
|
e
|
872
783
|
for e in s_edges
|
873
|
-
if t_min < self.
|
784
|
+
if t_min < self._time[e[0]] < t_max
|
874
785
|
]
|
875
786
|
else:
|
876
787
|
nodes_to_use = list(self.nodes)
|
@@ -884,12 +795,12 @@ class lineageTree(lineageTreeLoaders):
|
|
884
795
|
nodes_to_use = set(nodes_to_use)
|
885
796
|
if temporal:
|
886
797
|
for n in nodes_to_use:
|
887
|
-
for d in self.
|
798
|
+
for d in self._successor[n]:
|
888
799
|
if d in nodes_to_use:
|
889
800
|
edges_to_use.append((n, d))
|
890
801
|
if spatial:
|
891
802
|
edges_to_use += [
|
892
|
-
e for e in s_edges if t_min < self.
|
803
|
+
e for e in s_edges if t_min < self._time[e[0]] < t_max
|
893
804
|
]
|
894
805
|
nodes_to_use = set(nodes_to_use)
|
895
806
|
if Names:
|
@@ -907,12 +818,12 @@ class lineageTree(lineageTreeLoaders):
|
|
907
818
|
for k, v in node_properties[Names][0].items():
|
908
819
|
if (
|
909
820
|
len(
|
910
|
-
self.
|
911
|
-
self.
|
821
|
+
self._successor.get(
|
822
|
+
self._predecessor.get(k, [-1])[0], ()
|
912
823
|
)
|
913
824
|
)
|
914
825
|
!= 1
|
915
|
-
or self.
|
826
|
+
or self._time[k] == t_min + 1
|
916
827
|
):
|
917
828
|
tmp_names[k] = v
|
918
829
|
node_properties[Names][0] = tmp_names
|
@@ -942,7 +853,7 @@ class lineageTree(lineageTreeLoaders):
|
|
942
853
|
f.write('\t(default "0" "0")\n')
|
943
854
|
for n in nodes_to_use:
|
944
855
|
f.write(
|
945
|
-
"\t(node " + str(n) + ' "' + str(self.
|
856
|
+
"\t(node " + str(n) + ' "' + str(self._time[n]) + '")\n'
|
946
857
|
)
|
947
858
|
f.write(")\n")
|
948
859
|
|
@@ -989,7 +900,9 @@ class lineageTree(lineageTreeLoaders):
|
|
989
900
|
f.write(")")
|
990
901
|
f.close()
|
991
902
|
|
992
|
-
def to_binary(
|
903
|
+
def to_binary(
|
904
|
+
self, fname: str, starting_points: list[int] | None = None
|
905
|
+
) -> None:
|
993
906
|
"""Writes the lineage tree (a forest) as a binary structure
|
994
907
|
(assuming it is a binary tree, it would not work for *n* ary tree with 2 < *n*).
|
995
908
|
The binary file is composed of 3 sequences of numbers and
|
@@ -1004,36 +917,37 @@ class lineageTree(lineageTreeLoaders):
|
|
1004
917
|
The *time_sequence* is stored as a list of unsigned short (0 -> 2^(8*2)-1)
|
1005
918
|
The *pos_sequence* is stored as a list of double.
|
1006
919
|
|
1007
|
-
|
1008
|
-
|
1009
|
-
|
1010
|
-
|
920
|
+
Parameters
|
921
|
+
----------
|
922
|
+
fname : str
|
923
|
+
name of the binary file
|
924
|
+
starting_points : list of int, optional
|
925
|
+
list of the roots to be written.
|
926
|
+
If `None`, all roots are written, default value, None
|
1011
927
|
"""
|
1012
928
|
if starting_points is None:
|
1013
|
-
starting_points =
|
1014
|
-
c for c in self.successor if self.predecessor.get(c, []) == []
|
1015
|
-
]
|
929
|
+
starting_points = list(self.roots)
|
1016
930
|
number_sequence = [-1]
|
1017
931
|
pos_sequence = []
|
1018
932
|
time_sequence = []
|
1019
933
|
for c in starting_points:
|
1020
|
-
time_sequence.append(self.
|
934
|
+
time_sequence.append(self._time.get(c, 0))
|
1021
935
|
to_treat = [c]
|
1022
|
-
while to_treat
|
936
|
+
while to_treat:
|
1023
937
|
curr_c = to_treat.pop()
|
1024
938
|
number_sequence.append(curr_c)
|
1025
939
|
pos_sequence += list(self.pos[curr_c])
|
1026
|
-
if self[curr_c] ==
|
940
|
+
if self._successor[curr_c] == ():
|
1027
941
|
number_sequence.append(-1)
|
1028
|
-
elif len(self.
|
1029
|
-
to_treat += self.
|
942
|
+
elif len(self._successor[curr_c]) == 1:
|
943
|
+
to_treat += self._successor[curr_c]
|
1030
944
|
else:
|
1031
945
|
number_sequence.append(-2)
|
1032
|
-
to_treat += self.
|
946
|
+
to_treat += self._successor[curr_c]
|
1033
947
|
remaining_nodes = set(self.nodes) - set(number_sequence)
|
1034
948
|
|
1035
949
|
for c in remaining_nodes:
|
1036
|
-
time_sequence.append(self.
|
950
|
+
time_sequence.append(self._time.get(c, 0))
|
1037
951
|
number_sequence.append(c)
|
1038
952
|
pos_sequence += list(self.pos[c])
|
1039
953
|
number_sequence.append(-1)
|
@@ -1048,50 +962,269 @@ class lineageTree(lineageTreeLoaders):
|
|
1048
962
|
|
1049
963
|
f.close()
|
1050
964
|
|
1051
|
-
def write(self, fname: str):
|
1052
|
-
"""
|
1053
|
-
Write a lineage tree on disk as an .lT file.
|
965
|
+
def write(self, fname: str) -> None:
|
966
|
+
"""Write a lineage tree on disk as an .lT file.
|
1054
967
|
|
1055
|
-
|
1056
|
-
|
968
|
+
Parameters
|
969
|
+
----------
|
970
|
+
fname : str
|
971
|
+
path to and name of the file to save
|
1057
972
|
"""
|
1058
|
-
if os.path.splitext(fname)[-1] != ".
|
973
|
+
if os.path.splitext(fname)[-1].upper() != ".LT":
|
1059
974
|
fname = os.path.extsep.join((fname, "lT"))
|
975
|
+
if hasattr(self, "_protected_predecessor"):
|
976
|
+
del self._protected_predecessor
|
977
|
+
if hasattr(self, "_protected_successor"):
|
978
|
+
del self._protected_successor
|
979
|
+
if hasattr(self, "_protected_time"):
|
980
|
+
del self._protected_time
|
1060
981
|
with open(fname, "bw") as f:
|
1061
982
|
pkl.dump(self, f)
|
1062
983
|
f.close()
|
1063
984
|
|
1064
985
|
@classmethod
|
1065
|
-
def load(clf, fname: str
|
1066
|
-
"""
|
1067
|
-
Loading a lineage tree from a ".lT" file.
|
986
|
+
def load(clf, fname: str):
|
987
|
+
"""Loading a lineage tree from a '.lT' file.
|
1068
988
|
|
1069
|
-
|
1070
|
-
|
989
|
+
Parameters
|
990
|
+
----------
|
991
|
+
fname : str
|
992
|
+
path to and name of the file to read
|
1071
993
|
|
1072
|
-
Returns
|
1073
|
-
|
994
|
+
Returns
|
995
|
+
-------
|
996
|
+
LineageTree
|
997
|
+
loaded file
|
1074
998
|
"""
|
1075
999
|
with open(fname, "br") as f:
|
1076
1000
|
lT = pkl.load(f)
|
1077
1001
|
f.close()
|
1002
|
+
if not hasattr(lT, "__version__") or Version(lT.__version__) < Version(
|
1003
|
+
"2.0.0"
|
1004
|
+
):
|
1005
|
+
properties = {
|
1006
|
+
prop_name: prop
|
1007
|
+
for prop_name, prop in lT.__dict__.items()
|
1008
|
+
if (isinstance(prop, dict) or prop_name == "_time_resolution")
|
1009
|
+
and prop_name
|
1010
|
+
not in [
|
1011
|
+
"successor",
|
1012
|
+
"predecessor",
|
1013
|
+
"time",
|
1014
|
+
"_successor",
|
1015
|
+
"_predecessor",
|
1016
|
+
"_time",
|
1017
|
+
"pos",
|
1018
|
+
"labels",
|
1019
|
+
]
|
1020
|
+
+ lineageTree._dynamic_properties
|
1021
|
+
+ lineageTree._protected_dynamic_properties
|
1022
|
+
}
|
1023
|
+
lT = lineageTree(
|
1024
|
+
successor=lT._successor,
|
1025
|
+
time=lT._time,
|
1026
|
+
pos=lT.pos,
|
1027
|
+
name=lT.name if hasattr(lT, "name") else None,
|
1028
|
+
**properties,
|
1029
|
+
)
|
1078
1030
|
if not hasattr(lT, "time_resolution"):
|
1079
|
-
lT.time_resolution =
|
1031
|
+
lT.time_resolution = 1
|
1032
|
+
|
1080
1033
|
return lT
|
1081
1034
|
|
1082
|
-
def
|
1083
|
-
|
1084
|
-
|
1085
|
-
|
1086
|
-
|
1087
|
-
|
1088
|
-
|
1089
|
-
|
1090
|
-
|
1091
|
-
|
1092
|
-
|
1035
|
+
def get_predecessors(
|
1036
|
+
self,
|
1037
|
+
x: int,
|
1038
|
+
depth: int | None = None,
|
1039
|
+
start_time: int | None = None,
|
1040
|
+
end_time: int | None = None,
|
1041
|
+
) -> list[int]:
|
1042
|
+
"""Computes the predecessors of the node `x` up to
|
1043
|
+
`depth` predecessors or the begining of the life of `x`.
|
1044
|
+
The ordered list of ids is returned.
|
1045
|
+
|
1046
|
+
Parameters
|
1047
|
+
----------
|
1048
|
+
x : int
|
1049
|
+
id of the node to compute
|
1050
|
+
depth : int
|
1051
|
+
maximum number of predecessors to return
|
1052
|
+
|
1053
|
+
Returns
|
1054
|
+
-------
|
1055
|
+
list of int
|
1056
|
+
list of ids, the last id is `x`
|
1057
|
+
"""
|
1058
|
+
if not start_time:
|
1059
|
+
start_time = self.t_b
|
1060
|
+
if not end_time:
|
1061
|
+
end_time = self.t_e
|
1062
|
+
unconstrained_chain = [x]
|
1063
|
+
chain = [x] if start_time <= self._time[x] <= end_time else []
|
1064
|
+
acc = 0
|
1065
|
+
while (
|
1066
|
+
acc != depth
|
1067
|
+
and start_time < self._time[unconstrained_chain[0]]
|
1068
|
+
and (
|
1069
|
+
self._predecessor[unconstrained_chain[0]] != ()
|
1070
|
+
and ( # Please dont change very important even if it looks weird.
|
1071
|
+
len(
|
1072
|
+
self._successor[
|
1073
|
+
self._predecessor[unconstrained_chain[0]][0]
|
1074
|
+
]
|
1075
|
+
)
|
1076
|
+
== 1
|
1077
|
+
)
|
1078
|
+
)
|
1079
|
+
):
|
1080
|
+
unconstrained_chain.insert(
|
1081
|
+
0, self._predecessor[unconstrained_chain[0]][0]
|
1082
|
+
)
|
1083
|
+
acc += 1
|
1084
|
+
if start_time <= self._time[unconstrained_chain[0]] <= end_time:
|
1085
|
+
chain.insert(0, unconstrained_chain[0])
|
1086
|
+
|
1087
|
+
return chain
|
1088
|
+
|
1089
|
+
def get_successors(
|
1090
|
+
self, x: int, depth: int | None = None, end_time: int | None = None
|
1091
|
+
) -> list[int]:
|
1092
|
+
"""Computes the successors of the node `x` up to
|
1093
|
+
`depth` successors or the end of the life of `x`.
|
1094
|
+
The ordered list of ids is returned.
|
1095
|
+
|
1096
|
+
Parameters
|
1097
|
+
----------
|
1098
|
+
x : int
|
1099
|
+
id of the node to compute
|
1100
|
+
depth : int, optional
|
1101
|
+
maximum number of predecessors to return
|
1102
|
+
end_time : int, optional
|
1103
|
+
maximum time to consider
|
1104
|
+
|
1105
|
+
Returns
|
1106
|
+
-------
|
1107
|
+
list of int
|
1108
|
+
list of ids, the first id is `x`
|
1109
|
+
"""
|
1110
|
+
if end_time is None:
|
1111
|
+
end_time = self.t_e
|
1112
|
+
chain = [x]
|
1113
|
+
acc = 0
|
1114
|
+
while (
|
1115
|
+
len(self._successor[chain[-1]]) == 1
|
1116
|
+
and acc != depth
|
1117
|
+
and self._time[chain[-1]] < end_time
|
1118
|
+
):
|
1119
|
+
chain += self._successor[chain[-1]]
|
1120
|
+
acc += 1
|
1121
|
+
|
1122
|
+
return chain
|
1123
|
+
|
1124
|
+
def get_chain_of_node(
|
1125
|
+
self,
|
1126
|
+
x: int,
|
1127
|
+
depth: int | None = None,
|
1128
|
+
depth_pred: int | None = None,
|
1129
|
+
depth_succ: int | None = None,
|
1130
|
+
end_time: int | None = None,
|
1131
|
+
) -> list[int]:
|
1132
|
+
"""Computes the predecessors and successors of the node `x` up to
|
1133
|
+
`depth_pred` predecessors plus `depth_succ` successors.
|
1134
|
+
If the value `depth` is provided and not None,
|
1135
|
+
`depth_pred` and `depth_succ` are overwriten by `depth`.
|
1136
|
+
The ordered list of ids is returned.
|
1137
|
+
If all `depth` are None, the full chain is returned.
|
1138
|
+
|
1139
|
+
Parameters
|
1140
|
+
----------
|
1141
|
+
x : int
|
1142
|
+
id of the node to compute
|
1143
|
+
depth : int, optional
|
1144
|
+
maximum number of predecessors and successor to return
|
1145
|
+
depth_pred : int, optional
|
1146
|
+
maximum number of predecessors to return
|
1147
|
+
depth_succ : int, optional
|
1148
|
+
maximum number of successors to return
|
1149
|
+
|
1150
|
+
Returns
|
1151
|
+
-------
|
1152
|
+
list of int
|
1153
|
+
list of node ids
|
1093
1154
|
"""
|
1094
|
-
|
1155
|
+
if end_time is None:
|
1156
|
+
end_time = self.t_e
|
1157
|
+
if depth is not None:
|
1158
|
+
depth_pred = depth_succ = depth
|
1159
|
+
return self.get_predecessors(x, depth_pred, end_time=end_time)[
|
1160
|
+
:-1
|
1161
|
+
] + self.get_successors(x, depth_succ, end_time=end_time)
|
1162
|
+
|
1163
|
+
@dynamic_property
|
1164
|
+
def all_chains(self) -> list[list[int]]:
|
1165
|
+
"""List of all chains in the tree, ordered in depth-first search."""
|
1166
|
+
return self._compute_all_chains()
|
1167
|
+
|
1168
|
+
@dynamic_property
|
1169
|
+
def time_nodes(self):
|
1170
|
+
_time_nodes = {}
|
1171
|
+
for c, t in self._time.items():
|
1172
|
+
_time_nodes.setdefault(t, set()).add(c)
|
1173
|
+
return _time_nodes
|
1174
|
+
|
1175
|
+
def m(self, i, j):
|
1176
|
+
if (i, j) not in self._tmp_parenting:
|
1177
|
+
if i == j: # the distance to the node itself is 0
|
1178
|
+
self._tmp_parenting[(i, j)] = 0
|
1179
|
+
self._parenting[i, j] = self._tmp_parenting[(i, j)]
|
1180
|
+
elif not self._predecessor[
|
1181
|
+
j
|
1182
|
+
]: # j and i are note connected so the distance if inf
|
1183
|
+
self._tmp_parenting[(i, j)] = np.inf
|
1184
|
+
else: # the distance between i and j is the distance between i and pred(j) + 1
|
1185
|
+
self._tmp_parenting[(i, j)] = (
|
1186
|
+
self.m(i, self._predecessor[j][0]) + 1
|
1187
|
+
)
|
1188
|
+
self._parenting[i, j] = self._tmp_parenting[(i, j)]
|
1189
|
+
self._parenting[j, i] = -self._tmp_parenting[(i, j)]
|
1190
|
+
return self._tmp_parenting[(i, j)]
|
1191
|
+
|
1192
|
+
@property
|
1193
|
+
def parenting(self):
|
1194
|
+
if not hasattr(self, "_parenting"):
|
1195
|
+
self._parenting = dok_array((max(self.nodes) + 1,) * 2)
|
1196
|
+
self._tmp_parenting = {}
|
1197
|
+
for i, j in combinations(self.nodes, 2):
|
1198
|
+
if self._time[j] < self.time[i]:
|
1199
|
+
i, j = j, i
|
1200
|
+
self._tmp_parenting[(i, j)] = self.m(i, j)
|
1201
|
+
del self._tmp_parenting
|
1202
|
+
return self._parenting
|
1203
|
+
|
1204
|
+
def get_idx3d(self, t: int) -> tuple[KDTree, np.ndarray]:
|
1205
|
+
"""Get a 3d kdtree for the dataset at time `t`.
|
1206
|
+
The kdtree is stored in `self.kdtrees[t]` and returned.
|
1207
|
+
The correspondancy list is also returned.
|
1208
|
+
|
1209
|
+
Parameters
|
1210
|
+
----------
|
1211
|
+
t : int
|
1212
|
+
time
|
1213
|
+
|
1214
|
+
Returns
|
1215
|
+
-------
|
1216
|
+
KDTree
|
1217
|
+
The KDTree corresponding to the lineage tree at time `t`
|
1218
|
+
np.ndarray
|
1219
|
+
The correspondancy list in the KDTree.
|
1220
|
+
If the query in the kdtree gives you the value `i`,
|
1221
|
+
then it corresponds to the id in the tree `to_check_self[i]`
|
1222
|
+
"""
|
1223
|
+
to_check_self = list(self.nodes_at_t(t=t))
|
1224
|
+
|
1225
|
+
if not hasattr(self, "kdtrees"):
|
1226
|
+
self.kdtrees = {}
|
1227
|
+
|
1095
1228
|
if t not in self.kdtrees:
|
1096
1229
|
data_corres = {}
|
1097
1230
|
data = []
|
@@ -1104,16 +1237,21 @@ class lineageTree(lineageTreeLoaders):
|
|
1104
1237
|
idx3d = self.kdtrees[t]
|
1105
1238
|
return idx3d, np.array(to_check_self)
|
1106
1239
|
|
1107
|
-
def get_gabriel_graph(self, t: int) -> dict:
|
1108
|
-
"""Build the Gabriel graph of the given graph for time point `t
|
1109
|
-
The Garbiel graph is then stored in self.Gabriel_graph and returned
|
1110
|
-
|
1240
|
+
def get_gabriel_graph(self, t: int) -> dict[int, set[int]]:
|
1241
|
+
"""Build the Gabriel graph of the given graph for time point `t`.
|
1242
|
+
The Garbiel graph is then stored in `self.Gabriel_graph` and returned.
|
1243
|
+
|
1244
|
+
.. warning:: the graph is not recomputed if already computed, even if the point cloud has changed
|
1111
1245
|
|
1112
|
-
|
1113
|
-
|
1114
|
-
|
1115
|
-
|
1116
|
-
|
1246
|
+
Parameters
|
1247
|
+
----------
|
1248
|
+
t : int
|
1249
|
+
time
|
1250
|
+
|
1251
|
+
Returns
|
1252
|
+
-------
|
1253
|
+
dict of int to set of int
|
1254
|
+
A dictionary that maps a node to the set of its neighbors
|
1117
1255
|
"""
|
1118
1256
|
if not hasattr(self, "Gabriel_graph"):
|
1119
1257
|
self.Gabriel_graph = {}
|
@@ -1158,178 +1296,110 @@ class lineageTree(lineageTreeLoaders):
|
|
1158
1296
|
|
1159
1297
|
return self.Gabriel_graph[t]
|
1160
1298
|
|
1161
|
-
def
|
1162
|
-
self,
|
1163
|
-
) -> list:
|
1164
|
-
"""Computes the
|
1165
|
-
|
1166
|
-
The ordered list of ids is returned.
|
1299
|
+
def get_all_chains_of_subtree(
|
1300
|
+
self, node: int, end_time: int | None = None
|
1301
|
+
) -> list[list[int]]:
|
1302
|
+
"""Computes all the chains of the subtree spawn by a given node.
|
1303
|
+
Similar to get_all_chains().
|
1167
1304
|
|
1168
|
-
|
1169
|
-
|
1170
|
-
|
1171
|
-
|
1172
|
-
|
1173
|
-
|
1174
|
-
if not start_time:
|
1175
|
-
start_time = self.t_b
|
1176
|
-
if not end_time:
|
1177
|
-
end_time = self.t_e
|
1178
|
-
unconstrained_cycle = [x]
|
1179
|
-
cycle = [x] if start_time <= self.time[x] <= end_time else []
|
1180
|
-
acc = 0
|
1181
|
-
while (
|
1182
|
-
len(self[self.predecessor.get(unconstrained_cycle[0], [-1])[0]])
|
1183
|
-
== 1
|
1184
|
-
and acc != depth
|
1185
|
-
and start_time
|
1186
|
-
<= self.time.get(
|
1187
|
-
self.predecessor.get(unconstrained_cycle[0], [-1])[0], -1
|
1188
|
-
)
|
1189
|
-
):
|
1190
|
-
unconstrained_cycle.insert(
|
1191
|
-
0, self.predecessor[unconstrained_cycle[0]][0]
|
1192
|
-
)
|
1193
|
-
acc += 1
|
1194
|
-
if start_time <= self.time[unconstrained_cycle[0]] <= end_time:
|
1195
|
-
cycle.insert(0, unconstrained_cycle[0])
|
1196
|
-
|
1197
|
-
return cycle
|
1198
|
-
|
1199
|
-
def get_successors(
|
1200
|
-
self, x: int, depth: int = None, end_time: int = None
|
1201
|
-
) -> list:
|
1202
|
-
"""Computes the successors of the node `x` up to
|
1203
|
-
`depth` successors or the end of the life of `x`.
|
1204
|
-
The ordered list of ids is returned.
|
1305
|
+
Parameters
|
1306
|
+
----------
|
1307
|
+
node : int
|
1308
|
+
The node from which we want to get its chains.
|
1309
|
+
end_time : int, optional
|
1310
|
+
The time at which we want to stop the chains.
|
1205
1311
|
|
1206
|
-
|
1207
|
-
|
1208
|
-
|
1209
|
-
|
1210
|
-
[int, ]: list of ids, the first id is `x`
|
1211
|
-
"""
|
1212
|
-
if end_time is None:
|
1213
|
-
end_time = self.t_e
|
1214
|
-
cycle = [x]
|
1215
|
-
acc = 0
|
1216
|
-
while (
|
1217
|
-
len(self[cycle[-1]]) == 1
|
1218
|
-
and acc != depth
|
1219
|
-
and self.time[cycle[-1]] < end_time
|
1220
|
-
):
|
1221
|
-
cycle += self.successor[cycle[-1]]
|
1222
|
-
acc += 1
|
1223
|
-
|
1224
|
-
return cycle
|
1225
|
-
|
1226
|
-
def get_cycle(
|
1227
|
-
self,
|
1228
|
-
x: int,
|
1229
|
-
depth: int = None,
|
1230
|
-
depth_pred: int = None,
|
1231
|
-
depth_succ: int = None,
|
1232
|
-
end_time: int = None,
|
1233
|
-
) -> list:
|
1234
|
-
"""Computes the predecessors and successors of the node `x` up to
|
1235
|
-
`depth_pred` predecessors plus `depth_succ` successors.
|
1236
|
-
If the value `depth` is provided and not None,
|
1237
|
-
`depth_pred` and `depth_succ` are overwriten by `depth`.
|
1238
|
-
The ordered list of ids is returned.
|
1239
|
-
If all `depth` are None, the full cycle is returned.
|
1240
|
-
|
1241
|
-
Args:
|
1242
|
-
x (int): id of the node to compute
|
1243
|
-
depth (int): maximum number of predecessors and successor to return
|
1244
|
-
depth_pred (int): maximum number of predecessors to return
|
1245
|
-
depth_succ (int): maximum number of successors to return
|
1246
|
-
Returns:
|
1247
|
-
[int, ]: list of ids
|
1248
|
-
"""
|
1249
|
-
if end_time is None:
|
1250
|
-
end_time = self.t_e
|
1251
|
-
if depth is not None:
|
1252
|
-
depth_pred = depth_succ = depth
|
1253
|
-
return self.get_predecessors(x, depth_pred, end_time=end_time)[
|
1254
|
-
:-1
|
1255
|
-
] + self.get_successors(x, depth_succ, end_time=end_time)
|
1256
|
-
|
1257
|
-
@property
|
1258
|
-
def all_tracks(self):
|
1259
|
-
if not hasattr(self, "_all_tracks"):
|
1260
|
-
self._all_tracks = self.get_all_tracks()
|
1261
|
-
return self._all_tracks
|
1262
|
-
|
1263
|
-
def get_all_branches_of_node(
|
1264
|
-
self, node: int, end_time: int = None
|
1265
|
-
) -> list:
|
1266
|
-
"""Computes all the tracks of the subtree spawn by a given node.
|
1267
|
-
Similar to get_all_tracks().
|
1268
|
-
|
1269
|
-
Args:
|
1270
|
-
node (int, optional): The node that we want to get its branches.
|
1271
|
-
|
1272
|
-
Returns:
|
1273
|
-
([[int, ...], ...]): list of lists containing track cell ids
|
1312
|
+
Returns
|
1313
|
+
-------
|
1314
|
+
list of list of int
|
1315
|
+
list of chains
|
1274
1316
|
"""
|
1275
1317
|
if not end_time:
|
1276
1318
|
end_time = self.t_e
|
1277
|
-
|
1278
|
-
to_do = list(self[
|
1319
|
+
chains = [self.get_successors(node)]
|
1320
|
+
to_do = list(self._successor[chains[0][-1]])
|
1279
1321
|
while to_do:
|
1280
1322
|
current = to_do.pop()
|
1281
|
-
|
1282
|
-
|
1283
|
-
|
1284
|
-
|
1285
|
-
|
1286
|
-
|
1287
|
-
|
1288
|
-
|
1289
|
-
|
1290
|
-
|
1291
|
-
|
1292
|
-
|
1293
|
-
|
1294
|
-
|
1295
|
-
|
1296
|
-
|
1297
|
-
|
1298
|
-
|
1299
|
-
|
1300
|
-
|
1301
|
-
|
1302
|
-
|
1303
|
-
return
|
1304
|
-
|
1305
|
-
def
|
1306
|
-
|
1307
|
-
|
1308
|
-
|
1309
|
-
|
1310
|
-
|
1311
|
-
|
1312
|
-
|
1313
|
-
|
1314
|
-
|
1315
|
-
|
1316
|
-
|
1317
|
-
|
1318
|
-
|
1319
|
-
|
1320
|
-
|
1321
|
-
|
1322
|
-
|
1323
|
-
|
1324
|
-
|
1325
|
-
|
1323
|
+
chain = self.get_successors(current, end_time=end_time)
|
1324
|
+
if self._time[chain[-1]] <= end_time:
|
1325
|
+
chains += [chain]
|
1326
|
+
to_do += self._successor[chain[-1]]
|
1327
|
+
return chains
|
1328
|
+
|
1329
|
+
def _compute_all_chains(self) -> list[list[int]]:
|
1330
|
+
"""Computes all the chains of a given lineage tree,
|
1331
|
+
stores it in `self.all_chains` and returns it.
|
1332
|
+
|
1333
|
+
Returns
|
1334
|
+
-------
|
1335
|
+
list of list of int
|
1336
|
+
list of chains
|
1337
|
+
"""
|
1338
|
+
all_chains = []
|
1339
|
+
to_do = sorted(self.roots, key=self.time.get, reverse=True)
|
1340
|
+
while len(to_do) != 0:
|
1341
|
+
current = to_do.pop()
|
1342
|
+
chain = self.get_chain_of_node(current)
|
1343
|
+
all_chains += [chain]
|
1344
|
+
to_do.extend(self._successor[chain[-1]])
|
1345
|
+
return all_chains
|
1346
|
+
|
1347
|
+
def __get_chains( # TODO: Probably should be removed, might be used by DTW. Might also be a @dynamic_property
|
1348
|
+
self, nodes: Iterable | int | None = None
|
1349
|
+
) -> dict[int, list[list[int]]]:
|
1350
|
+
"""Returns all the chains in the subtrees spawned by each of the given nodes.
|
1351
|
+
|
1352
|
+
Parameters
|
1353
|
+
----------
|
1354
|
+
nodes : Iterable or int, optional
|
1355
|
+
id or Iterable of ids of the nodes to be computed, if `None` all roots are used
|
1356
|
+
|
1357
|
+
Returns
|
1358
|
+
-------
|
1359
|
+
dict mapping int to list of Chain
|
1360
|
+
dictionary mapping the node ids to a list of chains
|
1361
|
+
"""
|
1362
|
+
all_chains = self.all_chains
|
1363
|
+
if nodes is None:
|
1364
|
+
nodes = self.roots
|
1365
|
+
if not isinstance(nodes, Iterable):
|
1366
|
+
nodes = [nodes]
|
1367
|
+
output_chains = {}
|
1368
|
+
for n in nodes:
|
1369
|
+
starting_node = self.get_predecessors(n)[0]
|
1370
|
+
found = False
|
1371
|
+
done = False
|
1372
|
+
starting_time = self.time[n]
|
1373
|
+
i = 0
|
1374
|
+
current_chain = []
|
1375
|
+
while not done and i < len(all_chains):
|
1376
|
+
curr_found = all_chains[i][0] == starting_node
|
1377
|
+
found = found or curr_found
|
1378
|
+
if found:
|
1379
|
+
done = (
|
1380
|
+
self.time[all_chains[i][0]] <= starting_time
|
1381
|
+
) and not curr_found
|
1382
|
+
if not done:
|
1383
|
+
if curr_found:
|
1384
|
+
current_chain.append(self.get_successors(n))
|
1385
|
+
else:
|
1386
|
+
current_chain.append(all_chains[i])
|
1387
|
+
i += 1
|
1388
|
+
output_chains[n] = current_chain
|
1389
|
+
return output_chains
|
1390
|
+
|
1391
|
+
def find_leaves(self, roots: int | Iterable) -> set[int]:
|
1326
1392
|
"""Finds the leaves of a tree spawned by one or more nodes.
|
1327
1393
|
|
1328
|
-
|
1329
|
-
|
1394
|
+
Parameters
|
1395
|
+
----------
|
1396
|
+
roots : int or Iterable
|
1397
|
+
The roots of the trees spawning the leaves
|
1330
1398
|
|
1331
|
-
Returns
|
1332
|
-
|
1399
|
+
Returns
|
1400
|
+
-------
|
1401
|
+
set
|
1402
|
+
The leaves of one or more trees.
|
1333
1403
|
"""
|
1334
1404
|
if not isinstance(roots, Iterable):
|
1335
1405
|
to_do = [roots]
|
@@ -1338,28 +1408,34 @@ class lineageTree(lineageTreeLoaders):
|
|
1338
1408
|
leaves = set()
|
1339
1409
|
while to_do:
|
1340
1410
|
curr = to_do.pop()
|
1341
|
-
succ = self.
|
1411
|
+
succ = self._successor[curr]
|
1342
1412
|
if not succ:
|
1343
1413
|
leaves.add(curr)
|
1344
1414
|
to_do += succ
|
1345
1415
|
return leaves
|
1346
1416
|
|
1347
|
-
def
|
1417
|
+
def get_subtree_nodes(
|
1348
1418
|
self,
|
1349
|
-
x:
|
1350
|
-
end_time:
|
1419
|
+
x: int | Iterable,
|
1420
|
+
end_time: int | None = None,
|
1351
1421
|
preorder: bool = False,
|
1352
|
-
) -> list:
|
1353
|
-
"""Computes the list of
|
1354
|
-
The default output order is
|
1422
|
+
) -> list[int]:
|
1423
|
+
"""Computes the list of nodes from the subtree spawned by *x*
|
1424
|
+
The default output order is Breadth First Traversal.
|
1355
1425
|
Unless preorder is `True` in that case the order is
|
1356
|
-
Depth
|
1426
|
+
Depth First Traversal (DFT) preordered.
|
1427
|
+
|
1428
|
+
Parameters
|
1429
|
+
----------
|
1430
|
+
x : int
|
1431
|
+
id of root node
|
1432
|
+
preorder : bool, default=False
|
1433
|
+
if True the output preorder is DFT
|
1357
1434
|
|
1358
|
-
|
1359
|
-
|
1360
|
-
|
1361
|
-
|
1362
|
-
([int, ...]): the ordered list of node ids
|
1435
|
+
Returns
|
1436
|
+
-------
|
1437
|
+
list of int
|
1438
|
+
the ordered list of node ids
|
1363
1439
|
"""
|
1364
1440
|
if not end_time:
|
1365
1441
|
end_time = self.t_e
|
@@ -1367,233 +1443,258 @@ class lineageTree(lineageTreeLoaders):
|
|
1367
1443
|
to_do = [x]
|
1368
1444
|
elif isinstance(x, Iterable):
|
1369
1445
|
to_do = list(x)
|
1370
|
-
|
1446
|
+
subtree = []
|
1371
1447
|
while to_do:
|
1372
1448
|
curr = to_do.pop()
|
1373
|
-
succ = self.
|
1374
|
-
if succ and end_time < self.
|
1449
|
+
succ = self._successor[curr]
|
1450
|
+
if succ and end_time < self._time.get(curr, end_time):
|
1375
1451
|
succ = []
|
1376
1452
|
continue
|
1377
1453
|
if preorder:
|
1378
1454
|
to_do = succ + to_do
|
1379
1455
|
else:
|
1380
1456
|
to_do += succ
|
1381
|
-
|
1382
|
-
return
|
1457
|
+
subtree += [curr]
|
1458
|
+
return subtree
|
1383
1459
|
|
1384
1460
|
def compute_spatial_density(
|
1385
|
-
self, t_b: int = None, t_e: int = None, th: float = 50
|
1386
|
-
) -> dict:
|
1387
|
-
"""Computes the spatial density of
|
1388
|
-
The
|
1389
|
-
|
1390
|
-
|
1391
|
-
|
1392
|
-
|
1393
|
-
|
1394
|
-
|
1395
|
-
|
1396
|
-
|
1397
|
-
|
1398
|
-
|
1461
|
+
self, t_b: int | None = None, t_e: int | None = None, th: float = 50
|
1462
|
+
) -> dict[int, float]:
|
1463
|
+
"""Computes the spatial density of nodes between `t_b` and `t_e`.
|
1464
|
+
The results is stored in `self.spatial_density` and returned.
|
1465
|
+
|
1466
|
+
Parameters
|
1467
|
+
----------
|
1468
|
+
t_b : int, optional
|
1469
|
+
starting time to look at, default first time point
|
1470
|
+
t_e : int, optional
|
1471
|
+
ending time to look at, default last time point
|
1472
|
+
th : float, default=50
|
1473
|
+
size of the neighbourhood
|
1474
|
+
|
1475
|
+
Returns
|
1476
|
+
-------
|
1477
|
+
dict mapping int to float
|
1478
|
+
dictionary that maps a node id to its spatial density
|
1479
|
+
"""
|
1480
|
+
if not hasattr(self, "spatial_density"):
|
1481
|
+
self.spatial_density = {}
|
1399
1482
|
s_vol = 4 / 3.0 * np.pi * th**3
|
1400
|
-
|
1483
|
+
if t_b is None:
|
1484
|
+
t_b = self.t_b
|
1485
|
+
if t_e is None:
|
1486
|
+
t_e = self.t_e
|
1487
|
+
time_range = set(range(t_b, t_e)).intersection(self._time.values())
|
1401
1488
|
for t in time_range:
|
1402
1489
|
idx3d, nodes = self.get_idx3d(t)
|
1403
1490
|
nb_ni = [
|
1404
1491
|
(len(ni) - 1) / s_vol
|
1405
1492
|
for ni in idx3d.query_ball_tree(idx3d, th)
|
1406
1493
|
]
|
1407
|
-
self.spatial_density.update(dict(zip(nodes, nb_ni)))
|
1494
|
+
self.spatial_density.update(dict(zip(nodes, nb_ni, strict=True)))
|
1408
1495
|
return self.spatial_density
|
1409
1496
|
|
1410
|
-
def compute_k_nearest_neighbours(self, k: int = 10) -> dict:
|
1497
|
+
def compute_k_nearest_neighbours(self, k: int = 10) -> dict[int, set[int]]:
|
1411
1498
|
"""Computes the k-nearest neighbors
|
1412
1499
|
Writes the output in the attribute `kn_graph`
|
1413
1500
|
and returns it.
|
1414
1501
|
|
1415
|
-
|
1416
|
-
|
1417
|
-
|
1418
|
-
|
1419
|
-
|
1502
|
+
Parameters
|
1503
|
+
----------
|
1504
|
+
k : float
|
1505
|
+
number of nearest neighours
|
1506
|
+
|
1507
|
+
Returns
|
1508
|
+
-------
|
1509
|
+
dict mapping int to set of int
|
1510
|
+
dictionary that maps
|
1511
|
+
a node id to its `k` nearest neighbors
|
1420
1512
|
"""
|
1421
1513
|
self.kn_graph = {}
|
1422
|
-
for t
|
1423
|
-
|
1424
|
-
|
1425
|
-
|
1426
|
-
|
1427
|
-
|
1428
|
-
|
1514
|
+
for t in set(self._time.values()):
|
1515
|
+
nodes = self.nodes_at_t(t)
|
1516
|
+
if 1 < len(nodes):
|
1517
|
+
use_k = k if k < len(nodes) else len(nodes)
|
1518
|
+
idx3d, nodes = self.get_idx3d(t)
|
1519
|
+
pos = [self.pos[c] for c in nodes]
|
1520
|
+
_, neighbs = idx3d.query(pos, use_k)
|
1521
|
+
out = dict(
|
1522
|
+
zip(
|
1523
|
+
nodes,
|
1524
|
+
map(set, nodes[neighbs]),
|
1525
|
+
strict=True,
|
1526
|
+
)
|
1527
|
+
)
|
1528
|
+
self.kn_graph.update(out)
|
1529
|
+
else:
|
1530
|
+
n = nodes.pop
|
1531
|
+
self.kn_graph.update({n: {n}})
|
1429
1532
|
return self.kn_graph
|
1430
1533
|
|
1431
|
-
def compute_spatial_edges(self, th: int = 50) -> dict:
|
1534
|
+
def compute_spatial_edges(self, th: int = 50) -> dict[int, set[int]]:
|
1432
1535
|
"""Computes the neighbors at a distance `th`
|
1433
1536
|
Writes the output in the attribute `th_edge`
|
1434
1537
|
and returns it.
|
1435
1538
|
|
1436
|
-
|
1437
|
-
|
1438
|
-
|
1439
|
-
|
1440
|
-
|
1539
|
+
Parameters
|
1540
|
+
----------
|
1541
|
+
th : float, default=50
|
1542
|
+
distance to consider neighbors
|
1543
|
+
|
1544
|
+
Returns
|
1545
|
+
-------
|
1546
|
+
dict mapping int to set of int
|
1547
|
+
dictionary that maps a node id to its neighbors at a distance `th`
|
1441
1548
|
"""
|
1442
1549
|
self.th_edges = {}
|
1443
|
-
for t
|
1550
|
+
for t in set(self._time.values()):
|
1551
|
+
nodes = self.nodes_at_t(t)
|
1444
1552
|
idx3d, nodes = self.get_idx3d(t)
|
1445
1553
|
neighbs = idx3d.query_ball_tree(idx3d, th)
|
1446
|
-
out = dict(
|
1554
|
+
out = dict(
|
1555
|
+
zip(nodes, [set(nodes[ni]) for ni in neighbs], strict=True)
|
1556
|
+
)
|
1447
1557
|
self.th_edges.update(
|
1448
1558
|
{k: v.difference([k]) for k, v in out.items()}
|
1449
1559
|
)
|
1450
1560
|
return self.th_edges
|
1451
1561
|
|
1452
|
-
def
|
1453
|
-
"""
|
1454
|
-
If none will select the timepoint with the highest amound of cells.
|
1455
|
-
|
1456
|
-
Args:
|
1457
|
-
time (int, optional): The timepoint to find the main axes.
|
1458
|
-
If None will find the timepoint
|
1459
|
-
with the largest number of cells.
|
1460
|
-
|
1461
|
-
Returns:
|
1462
|
-
list: A list that contains the array of eigenvalues and eigenvectors.
|
1463
|
-
"""
|
1464
|
-
if time is None:
|
1465
|
-
time = max(self.time_nodes, key=lambda x: len(self.time_nodes[x]))
|
1466
|
-
pos = np.array([self.pos[node] for node in self.time_nodes[time]])
|
1467
|
-
pos = pos - np.mean(pos, axis=0)
|
1468
|
-
cov = np.cov(np.array(pos).T)
|
1469
|
-
eig_val, eig_vec = np.linalg.eig(cov)
|
1470
|
-
srt = np.argsort(eig_val)[::-1]
|
1471
|
-
self.eig_val, self.eig_vec = eig_val[srt], eig_vec[:, srt]
|
1472
|
-
return eig_val[srt], eig_vec[:, srt]
|
1473
|
-
|
1474
|
-
def scale_embryo(self, scale=1000):
|
1475
|
-
"""Scale the embryo using their eigenvalues.
|
1476
|
-
|
1477
|
-
Args:
|
1478
|
-
scale (int, optional): The resulting scale you want to achieve. Defaults to 1000.
|
1479
|
-
|
1480
|
-
Returns:
|
1481
|
-
float: The scale factor.
|
1482
|
-
"""
|
1483
|
-
eig = self.main_axes()[0]
|
1484
|
-
return scale / (np.sqrt(eig[0]))
|
1485
|
-
|
1486
|
-
@staticmethod
|
1487
|
-
def __rodrigues_rotation_matrix(vector1, vector2=(0, 1, 0)):
|
1488
|
-
"""Calculates the rodrigues matrix of a dataset. It should use vectors from the find_main_axes(eigenvectors) function of LineagTree.
|
1489
|
-
Uses the Rodrigues rotation formula.
|
1490
|
-
|
1491
|
-
Args:
|
1492
|
-
vector1 (list|np.array): The vector that should be rotated to be aligned to the second vector
|
1493
|
-
vector2 (list|np.array, optional): The second vector. Defaults to [1,0,0].
|
1494
|
-
|
1495
|
-
Returns:
|
1496
|
-
np.array: The rotation matrix.
|
1497
|
-
"""
|
1498
|
-
vector1 = vector1 / np.linalg.norm(vector1)
|
1499
|
-
vector2 = vector2 / np.linalg.norm(vector2)
|
1500
|
-
if vector1 @ vector2 == 1:
|
1501
|
-
return np.eye(3)
|
1502
|
-
angle = np.arccos(vector1 @ vector2)
|
1503
|
-
axis = np.cross(vector1, vector2)
|
1504
|
-
axis = axis / np.linalg.norm(axis)
|
1505
|
-
K = np.array(
|
1506
|
-
[
|
1507
|
-
[0, -axis[2], axis[1]],
|
1508
|
-
[axis[2], 0, -axis[0]],
|
1509
|
-
[-axis[1], axis[0], 0],
|
1510
|
-
]
|
1511
|
-
)
|
1512
|
-
return np.eye(3) + np.sin(angle) * K + (1 - np.cos(angle)) * K @ K
|
1513
|
-
|
1514
|
-
def get_ancestor_at_t(self, n: int, time: int = None):
|
1515
|
-
"""
|
1516
|
-
Find the id of the ancestor of a give node `n`
|
1562
|
+
def get_ancestor_at_t(self, n: int, time: int | None = None) -> int:
|
1563
|
+
"""Find the id of the ancestor of a give node `n`
|
1517
1564
|
at a given time `time`.
|
1518
1565
|
|
1519
|
-
If there is no ancestor, returns
|
1520
|
-
If time is None return the root of the
|
1566
|
+
If there is no ancestor, returns `None`
|
1567
|
+
If time is None return the root of the subtree that spawns
|
1521
1568
|
the node n.
|
1522
1569
|
|
1523
|
-
|
1524
|
-
|
1525
|
-
|
1526
|
-
|
1527
|
-
|
1528
|
-
|
1529
|
-
|
1530
|
-
|
1531
|
-
|
1570
|
+
Parameters
|
1571
|
+
----------
|
1572
|
+
n : int
|
1573
|
+
node for which to look the ancestor
|
1574
|
+
time : int, optional
|
1575
|
+
time at which the ancestor has to be found.
|
1576
|
+
If `None` the ancestor at the first time point
|
1577
|
+
will be found.
|
1578
|
+
|
1579
|
+
Returns
|
1580
|
+
-------
|
1581
|
+
int
|
1582
|
+
the id of the ancestor at time `time`,
|
1583
|
+
`-1` if there is no ancestor.
|
1532
1584
|
"""
|
1533
1585
|
if n not in self.nodes:
|
1534
|
-
return
|
1586
|
+
return -1
|
1535
1587
|
if time is None:
|
1536
1588
|
time = self.t_b
|
1537
1589
|
ancestor = n
|
1538
1590
|
while (
|
1539
|
-
time < self.
|
1591
|
+
time < self._time.get(ancestor, self.t_b - 1)
|
1592
|
+
and self._predecessor[ancestor]
|
1540
1593
|
):
|
1541
|
-
ancestor = self.
|
1542
|
-
|
1594
|
+
ancestor = self._predecessor[ancestor][0]
|
1595
|
+
if self._time.get(ancestor, self.t_b - 1) == time:
|
1596
|
+
return ancestor
|
1597
|
+
else:
|
1598
|
+
return -1
|
1543
1599
|
|
1544
|
-
def get_labelled_ancestor(self, node: int):
|
1545
|
-
"""Finds the first labelled ancestor and returns its ID otherwise returns
|
1600
|
+
def get_labelled_ancestor(self, node: int) -> int:
|
1601
|
+
"""Finds the first labelled ancestor and returns its ID otherwise returns -1
|
1546
1602
|
|
1547
|
-
|
1548
|
-
|
1603
|
+
Parameters
|
1604
|
+
----------
|
1605
|
+
node : int
|
1606
|
+
The id of the node
|
1549
1607
|
|
1550
|
-
Returns
|
1551
|
-
|
1552
|
-
|
1608
|
+
Returns
|
1609
|
+
-------
|
1610
|
+
int
|
1611
|
+
Returns the first ancestor found that has a label otherwise `-1`.
|
1553
1612
|
"""
|
1554
1613
|
if node not in self.nodes:
|
1555
|
-
return
|
1614
|
+
return -1
|
1556
1615
|
ancestor = node
|
1557
1616
|
while (
|
1558
|
-
self.t_b <= self.
|
1617
|
+
self.t_b <= self._time.get(ancestor, self.t_b - 1)
|
1559
1618
|
and ancestor != -1
|
1560
1619
|
):
|
1561
1620
|
if ancestor in self.labels:
|
1562
1621
|
return ancestor
|
1563
|
-
ancestor = self.
|
1564
|
-
return
|
1622
|
+
ancestor = self._predecessor.get(ancestor, [-1])[0]
|
1623
|
+
return -1
|
1624
|
+
|
1625
|
+
def get_ancestor_with_attribute(self, node: int, attribute: str) -> int:
|
1626
|
+
"""General purpose function to help with searching the first ancestor that has an attribute.
|
1627
|
+
Similar to get_labeled_ancestor and may make it redundant.
|
1628
|
+
|
1629
|
+
Parameters
|
1630
|
+
----------
|
1631
|
+
node : int
|
1632
|
+
The id of the node
|
1633
|
+
|
1634
|
+
Returns
|
1635
|
+
-------
|
1636
|
+
int
|
1637
|
+
Returns the first ancestor found that has an attribute otherwise `-1`.
|
1638
|
+
"""
|
1639
|
+
attr_dict = self.__getattribute__(attribute)
|
1640
|
+
if not isinstance(attr_dict, dict):
|
1641
|
+
raise ValueError("Please select a dict attribute")
|
1642
|
+
if node not in self.nodes:
|
1643
|
+
return -1
|
1644
|
+
if node in attr_dict:
|
1645
|
+
return node
|
1646
|
+
if node in self.roots:
|
1647
|
+
return -1
|
1648
|
+
ancestor = (node,)
|
1649
|
+
while ancestor and ancestor != [-1]:
|
1650
|
+
ancestor = ancestor[0]
|
1651
|
+
if ancestor in attr_dict:
|
1652
|
+
return ancestor
|
1653
|
+
ancestor = self._predecessor.get(ancestor, [-1])
|
1654
|
+
return -1
|
1565
1655
|
|
1566
1656
|
def unordered_tree_edit_distances_at_time_t(
|
1567
1657
|
self,
|
1568
1658
|
t: int,
|
1569
|
-
end_time: int = None,
|
1570
|
-
style
|
1659
|
+
end_time: int | None = None,
|
1660
|
+
style: (
|
1661
|
+
Literal["simple", "full", "downsampled", "normalized_simple"]
|
1662
|
+
| type[TreeApproximationTemplate]
|
1663
|
+
) = "simple",
|
1571
1664
|
downsample: int = 2,
|
1572
|
-
|
1665
|
+
norm: Literal["max", "sum", None] = "max",
|
1573
1666
|
recompute: bool = False,
|
1574
|
-
) -> dict:
|
1575
|
-
"""
|
1576
|
-
|
1577
|
-
|
1578
|
-
|
1579
|
-
|
1580
|
-
|
1581
|
-
|
1582
|
-
|
1583
|
-
|
1584
|
-
|
1585
|
-
|
1586
|
-
|
1587
|
-
|
1588
|
-
|
1589
|
-
|
1667
|
+
) -> dict[tuple[int, int], float]:
|
1668
|
+
"""Compute all the pairwise unordered tree edit distances from Zhang 996 between the trees spawned at time `t`
|
1669
|
+
|
1670
|
+
Parameters
|
1671
|
+
----------
|
1672
|
+
t : int
|
1673
|
+
time to look at
|
1674
|
+
end_time : int
|
1675
|
+
The final time point the comparison algorithm will take into account.
|
1676
|
+
If None all nodes will be taken into account.
|
1677
|
+
style : {"simple", "full", "downsampled", "normalized_simple"} or TreeApproximationTemplate subclass, default="simple"
|
1678
|
+
Which tree approximation is going to be used for the comparisons.
|
1679
|
+
downsample : int, default=2
|
1680
|
+
The downsample factor for the downsampled tree approximation.
|
1681
|
+
Used only when `style="downsampled"`.
|
1682
|
+
norm : {"max", "sum"}, default="max"
|
1683
|
+
The normalization method to use.
|
1684
|
+
recompute : bool, default=False
|
1685
|
+
If True, forces to recompute the distances
|
1686
|
+
|
1687
|
+
Returns
|
1688
|
+
-------
|
1689
|
+
dict mapping a tuple of tuple that contains 2 ints to float
|
1690
|
+
a dictionary that maps a pair of node ids at time `t` to their unordered tree edit distance
|
1590
1691
|
"""
|
1591
1692
|
if not hasattr(self, "uted"):
|
1592
1693
|
self.uted = {}
|
1593
1694
|
elif t in self.uted and not recompute:
|
1594
1695
|
return self.uted[t]
|
1595
1696
|
self.uted[t] = {}
|
1596
|
-
roots = self.
|
1697
|
+
roots = self.nodes_at_t(t=t)
|
1597
1698
|
for n1, n2 in combinations(roots, 2):
|
1598
1699
|
key = tuple(sorted((n1, n2)))
|
1599
1700
|
self.uted[t][key] = self.unordered_tree_edit_distance(
|
@@ -1602,37 +1703,132 @@ class lineageTree(lineageTreeLoaders):
|
|
1602
1703
|
end_time=end_time,
|
1603
1704
|
style=style,
|
1604
1705
|
downsample=downsample,
|
1605
|
-
|
1706
|
+
norm=norm,
|
1606
1707
|
)
|
1607
1708
|
return self.uted[t]
|
1608
1709
|
|
1609
|
-
def
|
1710
|
+
def __calculate_distance_of_sub_tree(
|
1711
|
+
self,
|
1712
|
+
node1: int,
|
1713
|
+
node2: int,
|
1714
|
+
alignment: Alignment,
|
1715
|
+
corres1: dict[int, int],
|
1716
|
+
corres2: dict[int, int],
|
1717
|
+
delta_tmp: Callable,
|
1718
|
+
norm: Callable,
|
1719
|
+
norm1: int | float,
|
1720
|
+
norm2: int | float,
|
1721
|
+
) -> float:
|
1722
|
+
"""Calculates the distance of the subtree of each node matched in a comparison.
|
1723
|
+
DOES NOT CALCULATE THE DISTANCE FROM SCRATCH BUT USING THE ALIGNMENT.
|
1724
|
+
TODO ITS BOUND TO CHANGE
|
1725
|
+
Parameters
|
1726
|
+
----------
|
1727
|
+
node1 : int
|
1728
|
+
The root of the first subtree
|
1729
|
+
node2 : int
|
1730
|
+
The root of the second subtree
|
1731
|
+
alignment : Alignment
|
1732
|
+
The alignment of the subtree
|
1733
|
+
corres1 : dict
|
1734
|
+
The correspndance dictionary of the first lineage
|
1735
|
+
corres2 : dict
|
1736
|
+
The correspondance dictionary of the second lineage
|
1737
|
+
delta_tmp : Callable
|
1738
|
+
The delta function for the comparisons
|
1739
|
+
norm : Callable
|
1740
|
+
How should the lineages be normalized
|
1741
|
+
norm1 : int or float
|
1742
|
+
The result of the normalization of the first tree
|
1743
|
+
norm2 : int or float
|
1744
|
+
The result of the normalization of the second tree
|
1745
|
+
|
1746
|
+
Returns
|
1747
|
+
-------
|
1748
|
+
float
|
1749
|
+
The result of the comparison of the subtree
|
1750
|
+
"""
|
1751
|
+
sub_tree_1 = set(self.get_subtree_nodes(node1))
|
1752
|
+
sub_tree_2 = set(self.get_subtree_nodes(node2))
|
1753
|
+
res = 0
|
1754
|
+
for m in alignment:
|
1755
|
+
if (
|
1756
|
+
corres1.get(m._left, -1) in sub_tree_1
|
1757
|
+
or corres2.get(m._right, -1) in sub_tree_2
|
1758
|
+
):
|
1759
|
+
res += delta_tmp(
|
1760
|
+
m._left if m._left != -1 else None,
|
1761
|
+
m._right if m._right != -1 else None,
|
1762
|
+
)
|
1763
|
+
return res / norm([norm1, norm2])
|
1764
|
+
|
1765
|
+
def clear_comparisons(self):
|
1766
|
+
self._comparisons.clear()
|
1767
|
+
|
1768
|
+
def __unordereded_backtrace(
|
1610
1769
|
self,
|
1611
1770
|
n1: int,
|
1612
1771
|
n2: int,
|
1613
|
-
end_time: int = None,
|
1614
|
-
norm:
|
1615
|
-
style
|
1772
|
+
end_time: int | None = None,
|
1773
|
+
norm: Literal["max", "sum", None] = "max",
|
1774
|
+
style: (
|
1775
|
+
Literal["simple", "normalized_simple", "full", "downsampled"]
|
1776
|
+
| type[TreeApproximationTemplate]
|
1777
|
+
) = "simple",
|
1616
1778
|
downsample: int = 2,
|
1617
|
-
) ->
|
1779
|
+
) -> dict[
|
1780
|
+
str,
|
1781
|
+
Alignment
|
1782
|
+
| tuple[TreeApproximationTemplate, TreeApproximationTemplate],
|
1783
|
+
]:
|
1618
1784
|
"""
|
1619
|
-
Compute the unordered tree edit
|
1785
|
+
Compute the unordered tree edit backtrace from Zhang 1996 between the trees spawned
|
1620
1786
|
by two nodes `n1` and `n2`. The topology of the trees are compared and the matching
|
1621
1787
|
cost is given by the function delta (see edist doc for more information).
|
1622
|
-
The distance is normed by the function norm that takes the two list of nodes
|
1623
|
-
spawned by the trees `n1` and `n2`.
|
1624
1788
|
|
1625
|
-
|
1626
|
-
|
1627
|
-
|
1628
|
-
|
1629
|
-
|
1630
|
-
|
1631
|
-
|
1632
|
-
|
1633
|
-
|
1634
|
-
|
1635
|
-
|
1789
|
+
Parameters
|
1790
|
+
----------
|
1791
|
+
n1 : int
|
1792
|
+
id of the first node to compare
|
1793
|
+
n2 : int
|
1794
|
+
id of the second node to compare
|
1795
|
+
end_time : int
|
1796
|
+
The final time point the comparison algorithm will take into account.
|
1797
|
+
If None all nodes will be taken into account.
|
1798
|
+
norm : {"max", "sum"}, default="max"
|
1799
|
+
The normalization method to use.
|
1800
|
+
style : {"simple", "full", "downsampled", "normalized_simple"} or TreeApproximationTemplate subclass, default="simple"
|
1801
|
+
Which tree approximation is going to be used for the comparisons.
|
1802
|
+
downsample : int, default=2
|
1803
|
+
The downsample factor for the downsampled tree approximation.
|
1804
|
+
Used only when `style="downsampled"`.
|
1805
|
+
|
1806
|
+
Returns
|
1807
|
+
-------
|
1808
|
+
dict mapping str to Alignment or tuple of [TreeApproximationTemplate, TreeApproximationTemplate]
|
1809
|
+
- 'alignment'
|
1810
|
+
The alignment between the nodes by the subtrees spawned by the nodes n1,n2 and the normalization function.
|
1811
|
+
- 'trees'
|
1812
|
+
A list of the two trees that have been mapped to each other.
|
1813
|
+
"""
|
1814
|
+
|
1815
|
+
parameters = (
|
1816
|
+
end_time,
|
1817
|
+
convert_style_to_number(style=style, downsample=downsample),
|
1818
|
+
)
|
1819
|
+
n1, n2 = sorted([n1, n2])
|
1820
|
+
self._comparisons.setdefault(parameters, {})
|
1821
|
+
if len(self._comparisons) > 100:
|
1822
|
+
warnings.warn(
|
1823
|
+
"More than 100 comparisons are saved, use clear_comparisons() to delete them.",
|
1824
|
+
stacklevel=2,
|
1825
|
+
)
|
1826
|
+
if isinstance(style, str):
|
1827
|
+
tree = tree_style[style].value
|
1828
|
+
elif issubclass(style, TreeApproximationTemplate):
|
1829
|
+
tree = style
|
1830
|
+
else:
|
1831
|
+
raise ValueError("Please use a valid approximation.")
|
1636
1832
|
tree1 = tree(
|
1637
1833
|
lT=self,
|
1638
1834
|
downsample=downsample,
|
@@ -1661,7 +1857,11 @@ class lineageTree(lineageTreeLoaders):
|
|
1661
1857
|
corres2,
|
1662
1858
|
) = tree2.edist
|
1663
1859
|
if len(nodes1) == len(nodes2) == 0:
|
1664
|
-
|
1860
|
+
self._comparisons[parameters][(n1, n2)] = {
|
1861
|
+
"alignment": (),
|
1862
|
+
"trees": (),
|
1863
|
+
}
|
1864
|
+
return self._comparisons[parameters][(n1, n2)]
|
1665
1865
|
delta_tmp = partial(
|
1666
1866
|
delta,
|
1667
1867
|
corres1=corres1,
|
@@ -1669,126 +1869,538 @@ class lineageTree(lineageTreeLoaders):
|
|
1669
1869
|
times1=times1,
|
1670
1870
|
times2=times2,
|
1671
1871
|
)
|
1672
|
-
|
1673
|
-
|
1674
|
-
|
1675
|
-
|
1676
|
-
|
1677
|
-
|
1872
|
+
btrc = uted.uted_backtrace(nodes1, adj1, nodes2, adj2, delta=delta_tmp)
|
1873
|
+
|
1874
|
+
self._comparisons[parameters][(n1, n2)] = {
|
1875
|
+
"alignment": btrc,
|
1876
|
+
"trees": (tree1, tree2),
|
1877
|
+
}
|
1878
|
+
return self._comparisons[parameters][(n1, n2)]
|
1879
|
+
|
1880
|
+
def plot_tree_distance_graphs(
|
1881
|
+
self,
|
1882
|
+
n1: int,
|
1883
|
+
n2: int,
|
1884
|
+
end_time: int | None = None,
|
1885
|
+
norm: Literal["max", "sum", None] = "max",
|
1886
|
+
style: (
|
1887
|
+
Literal["simple", "normalized_simple", "full", "downsampled"]
|
1888
|
+
| type[TreeApproximationTemplate]
|
1889
|
+
) = "simple",
|
1890
|
+
downsample: int = 2,
|
1891
|
+
colormap: str = "cool",
|
1892
|
+
default_color: str = "black",
|
1893
|
+
size: float = 10,
|
1894
|
+
lw: float = 0.3,
|
1895
|
+
ax: list[plt.Axes] | None = None,
|
1896
|
+
) -> tuple[plt.figure, plt.Axes]:
|
1897
|
+
"""
|
1898
|
+
Plots the subtrees compared and colors them according to the quality of the matching of their subtree.
|
1899
|
+
|
1900
|
+
Parameters
|
1901
|
+
----------
|
1902
|
+
n1 : int
|
1903
|
+
id of the first node to compare
|
1904
|
+
n2 : int
|
1905
|
+
id of the second node to compare
|
1906
|
+
end_time : int
|
1907
|
+
The final time point the comparison algorithm will take into account.
|
1908
|
+
If None all nodes will be taken into account.
|
1909
|
+
norm : {"max", "sum"}, default="max"
|
1910
|
+
The normalization method to use.
|
1911
|
+
style : {"simple", "full", "downsampled", "normalized_simple} or TreeApproximationTemplate subclass, default="simple"
|
1912
|
+
Which tree approximation is going to be used for the comparisons.
|
1913
|
+
downsample : int, default=2
|
1914
|
+
The downsample factor for the downsampled tree approximation.
|
1915
|
+
Used only when `style="downsampled"`.
|
1916
|
+
colormap : str, default="cool"
|
1917
|
+
The colormap used for matched nodes, defaults to "cool"
|
1918
|
+
default_color : str
|
1919
|
+
The color of the unmatched nodes, defaults to "black"
|
1920
|
+
size : float
|
1921
|
+
The size of the nodes, defaults to 10
|
1922
|
+
lw : float
|
1923
|
+
The width of the edges, defaults to 0.3
|
1924
|
+
ax : np.ndarray, optional
|
1925
|
+
The axes used, if not provided another set of axes is produced, defaults to None
|
1926
|
+
|
1927
|
+
Returns
|
1928
|
+
-------
|
1929
|
+
plt.Figure
|
1930
|
+
The figure of the plot
|
1931
|
+
plt.Axes
|
1932
|
+
The axes of the plot
|
1933
|
+
"""
|
1934
|
+
parameters = (
|
1935
|
+
end_time,
|
1936
|
+
convert_style_to_number(style=style, downsample=downsample),
|
1937
|
+
)
|
1938
|
+
n1, n2 = sorted([n1, n2])
|
1939
|
+
self._comparisons.setdefault(parameters, {})
|
1940
|
+
if self._comparisons[parameters].get((n1, n2)):
|
1941
|
+
tmp = self._comparisons[parameters][(n1, n2)]
|
1942
|
+
else:
|
1943
|
+
tmp = self.__unordereded_backtrace(
|
1944
|
+
n1, n2, end_time, norm, style, downsample
|
1945
|
+
)
|
1946
|
+
btrc: Alignment = tmp["alignment"]
|
1947
|
+
tree1, tree2 = tmp["trees"]
|
1948
|
+
_, times1 = tree1.tree
|
1949
|
+
_, times2 = tree2.tree
|
1950
|
+
(
|
1951
|
+
*_,
|
1952
|
+
corres1,
|
1953
|
+
) = tree1.edist
|
1954
|
+
(
|
1955
|
+
*_,
|
1956
|
+
corres2,
|
1957
|
+
) = tree2.edist
|
1958
|
+
delta_tmp = partial(
|
1959
|
+
tree1.delta,
|
1960
|
+
corres1=corres1,
|
1961
|
+
corres2=corres2,
|
1962
|
+
times1=times1,
|
1963
|
+
times2=times2,
|
1964
|
+
)
|
1965
|
+
|
1966
|
+
if norm not in self.norm_dict:
|
1678
1967
|
raise Warning(
|
1679
1968
|
"Select a viable normalization method (max, sum, None)"
|
1680
1969
|
)
|
1681
|
-
|
1682
|
-
|
1683
|
-
|
1970
|
+
matched_right = []
|
1971
|
+
matched_left = []
|
1972
|
+
colors = {}
|
1973
|
+
if style not in ("full", "downsampled"):
|
1974
|
+
for m in btrc:
|
1975
|
+
if m._left != -1 and m._right != -1:
|
1976
|
+
cyc1 = self.get_chain_of_node(corres1[m._left])
|
1977
|
+
if len(cyc1) > 1:
|
1978
|
+
node_1, *_, l_node_1 = cyc1
|
1979
|
+
matched_left.append(node_1)
|
1980
|
+
matched_left.append(l_node_1)
|
1981
|
+
elif len(cyc1) == 1:
|
1982
|
+
node_1 = l_node_1 = cyc1.pop()
|
1983
|
+
matched_left.append(node_1)
|
1984
|
+
|
1985
|
+
cyc2 = self.get_chain_of_node(corres2[m._right])
|
1986
|
+
if len(cyc2) > 1:
|
1987
|
+
node_2, *_, l_node_2 = cyc2
|
1988
|
+
matched_right.append(node_2)
|
1989
|
+
matched_right.append(l_node_2)
|
1990
|
+
|
1991
|
+
elif len(cyc2) == 1:
|
1992
|
+
node_2 = l_node_2 = cyc2.pop()
|
1993
|
+
matched_right.append(node_2)
|
1994
|
+
|
1995
|
+
colors[node_1] = self.__calculate_distance_of_sub_tree(
|
1996
|
+
node_1,
|
1997
|
+
node_2,
|
1998
|
+
btrc,
|
1999
|
+
corres1,
|
2000
|
+
corres2,
|
2001
|
+
delta_tmp,
|
2002
|
+
self.norm_dict[norm],
|
2003
|
+
tree1.get_norm(node_1),
|
2004
|
+
tree2.get_norm(node_2),
|
2005
|
+
)
|
2006
|
+
colors[node_2] = colors[node_1]
|
2007
|
+
colors[l_node_1] = colors[node_1]
|
2008
|
+
colors[l_node_2] = colors[node_2]
|
2009
|
+
else:
|
2010
|
+
for m in btrc:
|
2011
|
+
if m._left != -1 and m._right != -1:
|
2012
|
+
node_1 = corres1[m._left]
|
2013
|
+
node_2 = corres2[m._right]
|
2014
|
+
|
2015
|
+
if (
|
2016
|
+
self.get_chain_of_node(node_1)[0] == node_1
|
2017
|
+
or self.get_chain_of_node(node_2)[0] == node_2
|
2018
|
+
and (node_1 not in colors or node_2 not in colors)
|
2019
|
+
):
|
2020
|
+
matched_left.append(node_1)
|
2021
|
+
l_node_1 = self.get_chain_of_node(node_1)[-1]
|
2022
|
+
matched_left.append(l_node_1)
|
2023
|
+
matched_right.append(node_2)
|
2024
|
+
l_node_2 = self.get_chain_of_node(node_2)[-1]
|
2025
|
+
matched_right.append(l_node_2)
|
2026
|
+
colors[node_1] = self.__calculate_distance_of_sub_tree(
|
2027
|
+
node_1,
|
2028
|
+
node_2,
|
2029
|
+
btrc,
|
2030
|
+
corres1,
|
2031
|
+
corres2,
|
2032
|
+
delta_tmp,
|
2033
|
+
self.norm_dict[norm],
|
2034
|
+
tree1.get_norm(node_1),
|
2035
|
+
tree2.get_norm(node_2),
|
2036
|
+
)
|
2037
|
+
colors[l_node_1] = colors[node_1]
|
2038
|
+
colors[node_2] = colors[node_1]
|
2039
|
+
colors[l_node_2] = colors[node_1]
|
2040
|
+
if ax is None:
|
2041
|
+
fig, ax = plt.subplots(nrows=1, ncols=2, sharey=True)
|
2042
|
+
cmap = colormaps[colormap]
|
2043
|
+
c_norm = mcolors.Normalize(0, 1)
|
2044
|
+
colors = {c: cmap(c_norm(v)) for c, v in colors.items()}
|
2045
|
+
self.plot_subtree(
|
2046
|
+
self.get_ancestor_at_t(n1),
|
2047
|
+
end_time=end_time,
|
2048
|
+
size=size,
|
2049
|
+
selected_nodes=matched_left,
|
2050
|
+
color_of_nodes=colors,
|
2051
|
+
selected_edges=matched_left,
|
2052
|
+
color_of_edges=colors,
|
2053
|
+
default_color=default_color,
|
2054
|
+
lw=lw,
|
2055
|
+
ax=ax[0],
|
2056
|
+
)
|
2057
|
+
self.plot_subtree(
|
2058
|
+
self.get_ancestor_at_t(n2),
|
2059
|
+
end_time=end_time,
|
2060
|
+
size=size,
|
2061
|
+
selected_nodes=matched_right,
|
2062
|
+
color_of_nodes=colors,
|
2063
|
+
selected_edges=matched_right,
|
2064
|
+
color_of_edges=colors,
|
2065
|
+
default_color=default_color,
|
2066
|
+
lw=lw,
|
2067
|
+
ax=ax[1],
|
2068
|
+
)
|
2069
|
+
return ax[0].get_figure(), ax
|
2070
|
+
|
2071
|
+
def labelled_mappings(
|
2072
|
+
self,
|
2073
|
+
n1: int,
|
2074
|
+
n2: int,
|
2075
|
+
end_time: int | None = None,
|
2076
|
+
norm: Literal["max", "sum", None] = "max",
|
2077
|
+
style: (
|
2078
|
+
Literal["simple", "normalized_simple", "full", "downsampled"]
|
2079
|
+
| type[TreeApproximationTemplate]
|
2080
|
+
) = "simple",
|
2081
|
+
downsample: int = 2,
|
2082
|
+
) -> dict[str, list[str]]:
|
2083
|
+
"""
|
2084
|
+
Returns the labels or IDs of all the nodes in the subtrees compared.
|
2085
|
+
|
2086
|
+
|
2087
|
+
Parameters
|
2088
|
+
----------
|
2089
|
+
n1 : int
|
2090
|
+
id of the first node to compare
|
2091
|
+
n2 : int
|
2092
|
+
id of the second node to compare
|
2093
|
+
end_time : int, optional
|
2094
|
+
The final time point the comparison algorithm will take into account.
|
2095
|
+
If None or not provided all nodes will be taken into account.
|
2096
|
+
norm : {"max", "sum"}, default="max"
|
2097
|
+
The normalization method to use, defaults to 'max'.
|
2098
|
+
style : {"simple", "full", "downsampled", "normalized_simple} or TreeApproximationTemplate subclass, default="simple"
|
2099
|
+
Which tree approximation is going to be used for the comparisons, defaults to 'simple'.
|
2100
|
+
downsample : int, default=2
|
2101
|
+
The downsample factor for the downsampled tree approximation.
|
2102
|
+
Used only when `style="downsampled"`.
|
2103
|
+
|
2104
|
+
Returns
|
2105
|
+
-------
|
2106
|
+
dict mapping str to list of str
|
2107
|
+
- 'matched' The labels of the matched nodes of the alignment.
|
2108
|
+
- 'unmatched' The labels of the unmatched nodes of the alginment.
|
2109
|
+
"""
|
2110
|
+
parameters = (
|
2111
|
+
end_time,
|
2112
|
+
convert_style_to_number(style=style, downsample=downsample),
|
2113
|
+
)
|
2114
|
+
n1, n2 = sorted([n1, n2])
|
2115
|
+
self._comparisons.setdefault(parameters, {})
|
2116
|
+
if self._comparisons[parameters].get((n1, n2)):
|
2117
|
+
tmp = self._comparisons[parameters][(n1, n2)]
|
2118
|
+
else:
|
2119
|
+
tmp = self.__unordereded_backtrace(
|
2120
|
+
n1, n2, end_time, norm, style, downsample
|
2121
|
+
)
|
2122
|
+
btrc = tmp["alignment"]
|
2123
|
+
tree1, tree2 = tmp["trees"]
|
2124
|
+
|
2125
|
+
(
|
2126
|
+
*_,
|
2127
|
+
corres1,
|
2128
|
+
) = tree1.edist
|
2129
|
+
(
|
2130
|
+
*_,
|
2131
|
+
corres2,
|
2132
|
+
) = tree2.edist
|
2133
|
+
|
2134
|
+
if norm not in self.norm_dict:
|
2135
|
+
raise Warning(
|
2136
|
+
"Select a viable normalization method (max, sum, None)"
|
2137
|
+
)
|
2138
|
+
matched = []
|
2139
|
+
unmatched = []
|
2140
|
+
if style not in ("full", "downsampled"):
|
2141
|
+
for m in btrc:
|
2142
|
+
if m._left != -1 and m._right != -1:
|
2143
|
+
cyc1 = self.get_chain_of_node(corres1[m._left])
|
2144
|
+
if len(cyc1) > 1:
|
2145
|
+
node_1, *_ = cyc1
|
2146
|
+
elif len(cyc1) == 1:
|
2147
|
+
node_1 = cyc1.pop()
|
2148
|
+
cyc2 = self.get_chain_of_node(corres2[m._right])
|
2149
|
+
if len(cyc2) > 1:
|
2150
|
+
node_2, *_ = cyc2
|
2151
|
+
elif len(cyc2) == 1:
|
2152
|
+
node_2 = cyc2.pop()
|
2153
|
+
matched.append(
|
2154
|
+
(
|
2155
|
+
self.labels.get(node_1, node_1),
|
2156
|
+
self.labels.get(node_2, node_2),
|
2157
|
+
)
|
2158
|
+
)
|
2159
|
+
|
2160
|
+
else:
|
2161
|
+
if m._left != -1:
|
2162
|
+
node_1 = self.get_chain_of_node(
|
2163
|
+
corres1.get(m._left, "-")
|
2164
|
+
)[0]
|
2165
|
+
else:
|
2166
|
+
node_1 = self.get_chain_of_node(
|
2167
|
+
corres2.get(m._right, "-")
|
2168
|
+
)[0]
|
2169
|
+
unmatched.append(self.labels.get(node_1, node_1))
|
2170
|
+
else:
|
2171
|
+
for m in btrc:
|
2172
|
+
if m._left != -1 and m._right != -1:
|
2173
|
+
node_1 = corres1[m._left]
|
2174
|
+
node_2 = corres2[m._right]
|
2175
|
+
matched.append(
|
2176
|
+
(
|
2177
|
+
self.labels.get(node_1, node_1),
|
2178
|
+
self.labels.get(node_2, node_2),
|
2179
|
+
)
|
2180
|
+
)
|
2181
|
+
else:
|
2182
|
+
if m._left != -1:
|
2183
|
+
node_1 = corres1[m._left]
|
2184
|
+
else:
|
2185
|
+
node_1 = corres2[m._right]
|
2186
|
+
unmatched.append(self.labels.get(node_1, node_1))
|
2187
|
+
return {"matched": matched, "unmatched": unmatched}
|
2188
|
+
|
2189
|
+
def unordered_tree_edit_distance(
|
2190
|
+
self,
|
2191
|
+
n1: int,
|
2192
|
+
n2: int,
|
2193
|
+
end_time: int | None = None,
|
2194
|
+
norm: Literal["max", "sum", None] = "max",
|
2195
|
+
style: (
|
2196
|
+
Literal["simple", "normalized_simple", "full", "downsampled"]
|
2197
|
+
| type[TreeApproximationTemplate]
|
2198
|
+
) = "simple",
|
2199
|
+
downsample: int = 2,
|
2200
|
+
return_norms: bool = False,
|
2201
|
+
) -> float | tuple[float, tuple[float, float]]:
|
2202
|
+
"""
|
2203
|
+
Compute the unordered tree edit distance from Zhang 1996 between the trees spawned
|
2204
|
+
by two nodes `n1` and `n2`. The topology of the trees are compared and the matching
|
2205
|
+
cost is given by the function delta (see edist doc for more information).
|
2206
|
+
The distance is normed by the function norm that takes the two list of nodes
|
2207
|
+
spawned by the trees `n1` and `n2`.
|
2208
|
+
|
2209
|
+
Parameters
|
2210
|
+
----------
|
2211
|
+
n1 : int
|
2212
|
+
id of the first node to compare
|
2213
|
+
n2 : int
|
2214
|
+
id of the second node to compare
|
2215
|
+
end_time : int, optional
|
2216
|
+
The final time point the comparison algorithm will take into account.
|
2217
|
+
If None or not provided all nodes will be taken into account.
|
2218
|
+
norm : {"max", "sum"}, default="max"
|
2219
|
+
The normalization method to use, defaults to 'max'.
|
2220
|
+
style : {"simple", "normalized_simple", "full", "downsampled"} or TreeApproximationTemplate subclass, default="simple"
|
2221
|
+
Which tree approximation is going to be used for the comparisons.
|
2222
|
+
downsample : int, default=2
|
2223
|
+
The downsample factor for the downsampled tree approximation.
|
2224
|
+
Used only when `style="downsampled"`.
|
2225
|
+
|
2226
|
+
Returns
|
2227
|
+
-------
|
2228
|
+
float
|
2229
|
+
The normalized unordered tree edit distance between `n1` and `n2`
|
2230
|
+
"""
|
2231
|
+
parameters = (
|
2232
|
+
end_time,
|
2233
|
+
convert_style_to_number(style=style, downsample=downsample),
|
2234
|
+
)
|
2235
|
+
n1, n2 = sorted([n1, n2])
|
2236
|
+
self._comparisons.setdefault(parameters, {})
|
2237
|
+
if self._comparisons[parameters].get((n1, n2)):
|
2238
|
+
tmp = self._comparisons[parameters][(n1, n2)]
|
2239
|
+
else:
|
2240
|
+
tmp = self.__unordereded_backtrace(
|
2241
|
+
n1, n2, end_time, norm, style, downsample
|
2242
|
+
)
|
2243
|
+
btrc = tmp["alignment"]
|
2244
|
+
tree1, tree2 = tmp["trees"]
|
2245
|
+
_, times1 = tree1.tree
|
2246
|
+
_, times2 = tree2.tree
|
2247
|
+
(
|
2248
|
+
nodes1,
|
2249
|
+
adj1,
|
2250
|
+
corres1,
|
2251
|
+
) = tree1.edist
|
2252
|
+
(
|
2253
|
+
nodes2,
|
2254
|
+
adj2,
|
2255
|
+
corres2,
|
2256
|
+
) = tree2.edist
|
2257
|
+
delta_tmp = partial(
|
2258
|
+
tree1.delta,
|
2259
|
+
corres1=corres1,
|
2260
|
+
corres2=corres2,
|
2261
|
+
times1=times1,
|
2262
|
+
times2=times2,
|
2263
|
+
)
|
2264
|
+
|
2265
|
+
if norm not in self.norm_dict:
|
2266
|
+
raise ValueError(
|
2267
|
+
"Select a viable normalization method (max, sum, None)"
|
2268
|
+
)
|
2269
|
+
cost = btrc.cost(nodes1, nodes2, delta_tmp)
|
2270
|
+
norm_values = (tree1.get_norm(n1), tree2.get_norm(n2))
|
2271
|
+
if return_norms:
|
2272
|
+
return cost, norm_values
|
2273
|
+
return cost / self.norm_dict[norm](norm_values)
|
1684
2274
|
|
1685
2275
|
@staticmethod
|
1686
2276
|
def __plot_nodes(
|
1687
|
-
hier
|
1688
|
-
|
2277
|
+
hier: dict,
|
2278
|
+
selected_nodes: set,
|
2279
|
+
color: str | dict | list,
|
2280
|
+
size: int | float,
|
2281
|
+
ax: plt.Axes,
|
2282
|
+
default_color: str = "black",
|
2283
|
+
**kwargs,
|
2284
|
+
) -> None:
|
1689
2285
|
"""
|
1690
2286
|
Private method that plots the nodes of the tree.
|
1691
2287
|
"""
|
1692
|
-
|
1693
|
-
|
1694
|
-
|
1695
|
-
|
1696
|
-
|
1697
|
-
|
1698
|
-
|
1699
|
-
|
1700
|
-
|
1701
|
-
|
1702
|
-
)
|
1703
|
-
if selected_nodes.intersection(hier.keys()):
|
1704
|
-
hier_selected = np.array(
|
1705
|
-
[v for k, v in hier.items() if k in selected_nodes]
|
1706
|
-
)
|
1707
|
-
ax.scatter(
|
1708
|
-
*hier_selected.T, s=size, zorder=10, color=color, **kwargs
|
1709
|
-
)
|
2288
|
+
|
2289
|
+
if isinstance(color, dict):
|
2290
|
+
color = [color.get(k, default_color) for k in hier]
|
2291
|
+
elif isinstance(color, str | list):
|
2292
|
+
color = [
|
2293
|
+
color if node in selected_nodes else default_color
|
2294
|
+
for node in hier
|
2295
|
+
]
|
2296
|
+
hier_pos = np.array(list(hier.values()))
|
2297
|
+
ax.scatter(*hier_pos.T, s=size, zorder=10, color=color, **kwargs)
|
1710
2298
|
|
1711
2299
|
@staticmethod
|
1712
2300
|
def __plot_edges(
|
1713
|
-
hier,
|
1714
|
-
lnks_tms,
|
1715
|
-
selected_edges,
|
1716
|
-
color,
|
1717
|
-
|
1718
|
-
|
2301
|
+
hier: dict,
|
2302
|
+
lnks_tms: dict,
|
2303
|
+
selected_edges: Iterable,
|
2304
|
+
color: str | dict | list,
|
2305
|
+
lw: float,
|
2306
|
+
ax: plt.Axes,
|
2307
|
+
default_color: str = "black",
|
1719
2308
|
**kwargs,
|
1720
|
-
):
|
2309
|
+
) -> None:
|
1721
2310
|
"""
|
1722
2311
|
Private method that plots the edges of the tree.
|
1723
2312
|
"""
|
1724
|
-
|
2313
|
+
if isinstance(color, dict):
|
2314
|
+
selected_edges = color.keys()
|
2315
|
+
lines = []
|
2316
|
+
c = []
|
1725
2317
|
for pred, succs in lnks_tms["links"].items():
|
1726
|
-
for
|
1727
|
-
|
1728
|
-
|
1729
|
-
|
1730
|
-
|
1731
|
-
|
1732
|
-
|
1733
|
-
|
1734
|
-
|
1735
|
-
|
1736
|
-
|
1737
|
-
|
2318
|
+
for suc in succs:
|
2319
|
+
lines.append(
|
2320
|
+
[
|
2321
|
+
[hier[suc][0], hier[suc][1]],
|
2322
|
+
[hier[pred][0], hier[pred][1]],
|
2323
|
+
]
|
2324
|
+
)
|
2325
|
+
if pred in selected_edges:
|
2326
|
+
if isinstance(color, str | list):
|
2327
|
+
c.append(color)
|
2328
|
+
elif isinstance(color, dict):
|
2329
|
+
c.append(color[pred])
|
2330
|
+
else:
|
2331
|
+
c.append(default_color)
|
2332
|
+
lc = LineCollection(lines, colors=c, linewidth=lw, **kwargs)
|
2333
|
+
ax.add_collection(lc)
|
1738
2334
|
|
1739
2335
|
def draw_tree_graph(
|
1740
2336
|
self,
|
1741
|
-
hier,
|
1742
|
-
lnks_tms,
|
1743
|
-
selected_nodes=None,
|
1744
|
-
selected_edges=None,
|
1745
|
-
color_of_nodes="magenta",
|
1746
|
-
color_of_edges=
|
1747
|
-
size=10,
|
1748
|
-
|
1749
|
-
|
2337
|
+
hier: dict[int, tuple[int, int]],
|
2338
|
+
lnks_tms: dict[str, dict[int, list | int]],
|
2339
|
+
selected_nodes: list | set | None = None,
|
2340
|
+
selected_edges: list | set | None = None,
|
2341
|
+
color_of_nodes: str | dict = "magenta",
|
2342
|
+
color_of_edges: str | dict = "magenta",
|
2343
|
+
size: int | float = 10,
|
2344
|
+
lw: float = 0.3,
|
2345
|
+
ax: plt.Axes | None = None,
|
2346
|
+
default_color: str = "black",
|
1750
2347
|
**kwargs,
|
1751
|
-
):
|
2348
|
+
) -> tuple[plt.Figure, plt.Axes]:
|
1752
2349
|
"""Function to plot the tree graph.
|
1753
2350
|
|
1754
|
-
|
1755
|
-
|
1756
|
-
|
1757
|
-
|
1758
|
-
|
1759
|
-
|
1760
|
-
|
1761
|
-
|
1762
|
-
|
1763
|
-
|
1764
|
-
|
1765
|
-
|
1766
|
-
|
1767
|
-
|
1768
|
-
|
1769
|
-
|
2351
|
+
Parameters
|
2352
|
+
----------
|
2353
|
+
hier : dict mapping int to tuple of int
|
2354
|
+
Dictionary that contains the positions of all nodes.
|
2355
|
+
lnks_tms : dict mapping string to dictionaries mapping int to list or int
|
2356
|
+
- 'links' : conatains the hierarchy of the nodes (only start and end of each chain)
|
2357
|
+
- 'times' : contains the distance between the start and the end of each chain.
|
2358
|
+
selected_nodes : list or set, optional
|
2359
|
+
Which nodes are to be selected (Painted with a different color, according to 'color_'of_nodes')
|
2360
|
+
selected_edges : list or set, optional
|
2361
|
+
Which edges are to be selected (Painted with a different color, according to 'color_'of_edges')
|
2362
|
+
color_of_nodes : str, default="magenta"
|
2363
|
+
Color of selected nodes
|
2364
|
+
color_of_edges : str, default="magenta"
|
2365
|
+
Color of selected edges
|
2366
|
+
size : int, default=10
|
2367
|
+
Size of the nodes, defaults to 10
|
2368
|
+
lw : float, default=0.3
|
2369
|
+
The width of the edges of the tree graph, defaults to 0.3
|
2370
|
+
ax : plt.Axes, optional
|
2371
|
+
Plot the graph on existing ax. If not provided or None a new ax is going to be created.
|
2372
|
+
default_color : str, default="black"
|
2373
|
+
Default color of nodes
|
2374
|
+
|
2375
|
+
Returns
|
2376
|
+
-------
|
2377
|
+
plt.Figure
|
2378
|
+
The matplotlib figure
|
2379
|
+
plt.Axes
|
2380
|
+
The matplotlib ax
|
1770
2381
|
"""
|
1771
2382
|
if selected_nodes is None:
|
1772
2383
|
selected_nodes = []
|
1773
2384
|
if selected_edges is None:
|
1774
2385
|
selected_edges = []
|
1775
2386
|
if ax is None:
|
1776
|
-
|
2387
|
+
_, ax = plt.subplots()
|
1777
2388
|
else:
|
1778
2389
|
ax.clear()
|
1779
2390
|
if not isinstance(selected_nodes, set):
|
1780
2391
|
selected_nodes = set(selected_nodes)
|
1781
2392
|
if not isinstance(selected_edges, set):
|
1782
2393
|
selected_edges = set(selected_edges)
|
1783
|
-
|
1784
|
-
|
1785
|
-
|
1786
|
-
|
1787
|
-
|
1788
|
-
|
1789
|
-
|
1790
|
-
|
1791
|
-
|
2394
|
+
if 0 < size:
|
2395
|
+
self.__plot_nodes(
|
2396
|
+
hier,
|
2397
|
+
selected_nodes,
|
2398
|
+
color_of_nodes,
|
2399
|
+
size=size,
|
2400
|
+
ax=ax,
|
2401
|
+
default_color=default_color,
|
2402
|
+
**kwargs,
|
2403
|
+
)
|
1792
2404
|
if not color_of_edges:
|
1793
2405
|
color_of_edges = color_of_nodes
|
1794
2406
|
self.__plot_edges(
|
@@ -1796,62 +2408,107 @@ class lineageTree(lineageTreeLoaders):
|
|
1796
2408
|
lnks_tms,
|
1797
2409
|
selected_edges,
|
1798
2410
|
color_of_edges,
|
2411
|
+
lw,
|
1799
2412
|
ax,
|
1800
2413
|
default_color=default_color,
|
1801
2414
|
**kwargs,
|
1802
2415
|
)
|
2416
|
+
ax.autoscale()
|
2417
|
+
plt.draw()
|
1803
2418
|
ax.get_yaxis().set_visible(False)
|
1804
2419
|
ax.get_xaxis().set_visible(False)
|
1805
2420
|
return ax.get_figure(), ax
|
1806
2421
|
|
1807
|
-
def
|
2422
|
+
def _create_dict_of_plots(
|
2423
|
+
self,
|
2424
|
+
node: int | Iterable[int] | None = None,
|
2425
|
+
start_time: int | None = None,
|
2426
|
+
end_time: int | None = None,
|
2427
|
+
) -> dict[int, dict]:
|
1808
2428
|
"""Generates a dictionary of graphs where the keys are the index of the graph and
|
1809
|
-
the values are the graphs themselves which are produced by
|
1810
|
-
|
1811
|
-
|
1812
|
-
|
1813
|
-
|
1814
|
-
|
1815
|
-
|
1816
|
-
|
1817
|
-
|
2429
|
+
the values are the graphs themselves which are produced by `create_links_and_chains`
|
2430
|
+
|
2431
|
+
Parameters
|
2432
|
+
----------
|
2433
|
+
node : int or Iterable of int, optional
|
2434
|
+
The id of the node/nodes to produce the simple graphs, if not provided or None will
|
2435
|
+
calculate the dicts for every root that starts before 'start_time'
|
2436
|
+
start_time : int, optional
|
2437
|
+
Important only if there are no nodes it will produce the graph of every
|
2438
|
+
root that starts before or at start time. If not provided or None the 'start_time' defaults to the start of the dataset.
|
2439
|
+
end_time : int, optional
|
2440
|
+
The last timepoint to be considered, if not provided or None the last timepoint of the
|
2441
|
+
dataset (t_e) is considered.
|
2442
|
+
|
2443
|
+
Returns
|
2444
|
+
-------
|
2445
|
+
dict mapping int to dict
|
2446
|
+
The keys are just index values 0-n and the values are the graphs produced.
|
1818
2447
|
"""
|
1819
2448
|
if start_time is None:
|
1820
2449
|
start_time = self.t_b
|
2450
|
+
if end_time is None:
|
2451
|
+
end_time = self.t_e
|
1821
2452
|
if node is None:
|
1822
2453
|
mothers = [
|
1823
|
-
root for root in self.roots if self.
|
2454
|
+
root for root in self.roots if self._time[root] <= start_time
|
1824
2455
|
]
|
2456
|
+
elif isinstance(node, Iterable):
|
2457
|
+
mothers = node
|
1825
2458
|
else:
|
1826
|
-
mothers =
|
2459
|
+
mothers = [node]
|
1827
2460
|
return {
|
1828
|
-
i:
|
2461
|
+
i: create_links_and_chains(self, mother, end_time=end_time)
|
1829
2462
|
for i, mother in enumerate(mothers)
|
1830
2463
|
}
|
1831
2464
|
|
1832
2465
|
def plot_all_lineages(
|
1833
2466
|
self,
|
1834
|
-
nodes: list = None,
|
1835
|
-
last_time_point_to_consider: int = None,
|
1836
|
-
nrows=2,
|
1837
|
-
figsize=(10, 15),
|
1838
|
-
dpi=100,
|
1839
|
-
fontsize=15,
|
1840
|
-
axes=None,
|
1841
|
-
vert_gap=1,
|
2467
|
+
nodes: list | None = None,
|
2468
|
+
last_time_point_to_consider: int | None = None,
|
2469
|
+
nrows: int = 2,
|
2470
|
+
figsize: tuple[int, int] = (10, 15),
|
2471
|
+
dpi: int = 100,
|
2472
|
+
fontsize: int = 15,
|
2473
|
+
axes: plt.Axes | None = None,
|
2474
|
+
vert_gap: int = 1,
|
1842
2475
|
**kwargs,
|
1843
|
-
):
|
2476
|
+
) -> tuple[plt.Figure, plt.Axes, dict[plt.Axes, int]]:
|
1844
2477
|
"""Plots all lineages.
|
1845
2478
|
|
1846
|
-
|
1847
|
-
|
1848
|
-
|
1849
|
-
|
1850
|
-
|
1851
|
-
|
1852
|
-
|
2479
|
+
Parameters
|
2480
|
+
----------
|
2481
|
+
nodes : list, optional
|
2482
|
+
The nodes spawning the graphs to be plotted.
|
2483
|
+
last_time_point_to_consider : int, optional
|
2484
|
+
Which timepoints and upwards are the graphs to be plotted.
|
2485
|
+
For example if start_time is 10, then all trees that begin
|
2486
|
+
on tp 10 or before are calculated. Defaults to None, where
|
2487
|
+
it will plot all the roots that exist on `self.t_b`.
|
2488
|
+
nrows : int, default=2
|
2489
|
+
How many rows of plots should be printed.
|
2490
|
+
figsize : tuple, default=(10, 15)
|
2491
|
+
The size of the figure.
|
2492
|
+
dpi : int, default=100
|
2493
|
+
The dpi of the figure.
|
2494
|
+
fontsize : int, default=15
|
2495
|
+
The fontsize of the labels.
|
2496
|
+
axes : plt.Axes, optional
|
2497
|
+
The axes to plot the graphs on. If None or not provided new axes are going to be created.
|
2498
|
+
vert_gap : int, default=1
|
2499
|
+
space between the nodes, defaults to 1
|
2500
|
+
**kwargs:
|
2501
|
+
kwargs accepted by matplotlib.pyplot.plot, matplotlib.pyplot.scatter
|
2502
|
+
|
2503
|
+
Returns
|
2504
|
+
-------
|
2505
|
+
plt.Figure
|
2506
|
+
The figure
|
2507
|
+
plt.Axes
|
2508
|
+
The axes
|
2509
|
+
dict of plt.Axes to int
|
2510
|
+
A dictionary that maps the axes to the root of the tree.
|
1853
2511
|
"""
|
1854
|
-
|
1855
2512
|
nrows = int(nrows)
|
1856
2513
|
if last_time_point_to_consider is None:
|
1857
2514
|
last_time_point_to_consider = self.t_b
|
@@ -1860,17 +2517,18 @@ class lineageTree(lineageTreeLoaders):
|
|
1860
2517
|
raise Warning("Number of rows has to be at least 1")
|
1861
2518
|
if nodes:
|
1862
2519
|
graphs = {
|
1863
|
-
i: self.
|
2520
|
+
i: self._create_dict_of_plots(node)
|
2521
|
+
for i, node in enumerate(nodes)
|
1864
2522
|
}
|
1865
2523
|
else:
|
1866
|
-
graphs = self.
|
2524
|
+
graphs = self._create_dict_of_plots(
|
1867
2525
|
start_time=last_time_point_to_consider
|
1868
2526
|
)
|
1869
2527
|
pos = {
|
1870
2528
|
i: hierarchical_pos(
|
1871
2529
|
g,
|
1872
2530
|
g["root"],
|
1873
|
-
ycenter=-int(self.
|
2531
|
+
ycenter=-int(self._time[g["root"]]),
|
1874
2532
|
vert_gap=vert_gap,
|
1875
2533
|
)
|
1876
2534
|
for i, g in graphs.items()
|
@@ -1925,135 +2583,147 @@ class lineageTree(lineageTreeLoaders):
|
|
1925
2583
|
[figure.delaxes(ax) for ax in axes.flatten() if not ax.has_data()]
|
1926
2584
|
return axes.flatten()[0].get_figure(), axes, ax2root
|
1927
2585
|
|
1928
|
-
def
|
2586
|
+
def plot_subtree(
|
1929
2587
|
self,
|
1930
|
-
node,
|
1931
|
-
|
1932
|
-
|
1933
|
-
|
1934
|
-
|
1935
|
-
|
1936
|
-
|
2588
|
+
node: int,
|
2589
|
+
end_time: int | None = None,
|
2590
|
+
figsize: tuple[int, int] = (4, 7),
|
2591
|
+
dpi: int = 150,
|
2592
|
+
vert_gap: int = 2,
|
2593
|
+
selected_nodes: list | None = None,
|
2594
|
+
selected_edges: list | None = None,
|
2595
|
+
color_of_nodes: str | dict = "magenta",
|
2596
|
+
color_of_edges: str | dict = "magenta",
|
2597
|
+
size: int | float = 10,
|
2598
|
+
lw: float = 0.1,
|
2599
|
+
default_color: str = "black",
|
2600
|
+
ax: plt.Axes | None = None,
|
2601
|
+
) -> tuple[plt.Figure, plt.Axes]:
|
1937
2602
|
"""Plots the subtree spawn by a node.
|
1938
2603
|
|
1939
|
-
|
1940
|
-
|
1941
|
-
|
1942
|
-
|
1943
|
-
|
2604
|
+
Parameters
|
2605
|
+
----------
|
2606
|
+
node : int
|
2607
|
+
The id of the node that is going to be plotted.
|
2608
|
+
end_time : int, optional
|
2609
|
+
The last timepoint to be considered, if None or not provided the last timepoint of the dataset (t_e) is considered.
|
2610
|
+
figsize : tuple of 2 ints, default=(4,7)
|
2611
|
+
The size of the figure, deafults to (4,7)
|
2612
|
+
vert_gap : int, default=2
|
2613
|
+
The verical gap of a node when it divides, defaults to 2.
|
2614
|
+
dpi : int, default=150
|
2615
|
+
The dpi of the figure, defaults to 150
|
2616
|
+
selected_nodes : list, optional
|
2617
|
+
The nodes that are selected by the user to be colored in a different color, defaults to None
|
2618
|
+
selected_edges : list, optional
|
2619
|
+
The edges that are selected by the user to be colored in a different color, defaults to None
|
2620
|
+
color_of_nodes : str, default="magenta"
|
2621
|
+
The color of the nodes to be colored, except the default colored ones, defaults to "magenta"
|
2622
|
+
color_of_edges : str, default="magenta"
|
2623
|
+
The color of the edges to be colored, except the default colored ones, defaults to "magenta"
|
2624
|
+
size : int, default=10
|
2625
|
+
The size of the nodes, defaults to 10
|
2626
|
+
lw : float, default=0.1
|
2627
|
+
The widthe of the edges of the tree graph, defaults to 0.1
|
2628
|
+
default_color : str, default="black"
|
2629
|
+
The default color of nodes and edges, defaults to "black"
|
2630
|
+
ax : plt.Axes, optional
|
2631
|
+
The ax where the plot is going to be applied, if not provided or None new axes will be created.
|
2632
|
+
|
2633
|
+
Returns
|
2634
|
+
-------
|
2635
|
+
plt.Figure
|
2636
|
+
The matplotlib figure
|
2637
|
+
plt.Axes
|
2638
|
+
The matplotlib axes
|
2639
|
+
|
2640
|
+
Raises
|
2641
|
+
------
|
2642
|
+
Warning
|
2643
|
+
If more than one nodes are received
|
2644
|
+
"""
|
2645
|
+
graph = self._create_dict_of_plots(node, end_time=end_time)
|
1944
2646
|
if len(graph) > 1:
|
1945
|
-
raise Warning(
|
2647
|
+
raise Warning(
|
2648
|
+
"Please use lT.plot_all_lineages(nodes) for plotting multiple nodes."
|
2649
|
+
)
|
1946
2650
|
graph = graph[0]
|
1947
2651
|
if not ax:
|
1948
|
-
|
1949
|
-
nrows=1, ncols=1, figsize=figsize, dpi=dpi
|
1950
|
-
)
|
2652
|
+
_, ax = plt.subplots(nrows=1, ncols=1, figsize=figsize, dpi=dpi)
|
1951
2653
|
self.draw_tree_graph(
|
1952
2654
|
hier=hierarchical_pos(
|
1953
2655
|
graph,
|
1954
2656
|
graph["root"],
|
1955
2657
|
vert_gap=vert_gap,
|
1956
|
-
ycenter=-int(self.
|
2658
|
+
ycenter=-int(self._time[node]),
|
1957
2659
|
),
|
2660
|
+
selected_edges=selected_edges,
|
2661
|
+
selected_nodes=selected_nodes,
|
2662
|
+
color_of_edges=color_of_edges,
|
2663
|
+
color_of_nodes=color_of_nodes,
|
2664
|
+
default_color=default_color,
|
2665
|
+
size=size,
|
2666
|
+
lw=lw,
|
1958
2667
|
lnks_tms=graph,
|
1959
2668
|
ax=ax,
|
1960
2669
|
)
|
1961
2670
|
return ax.get_figure(), ax
|
1962
2671
|
|
1963
|
-
|
1964
|
-
|
1965
|
-
|
1966
|
-
|
1967
|
-
|
1968
|
-
|
1969
|
-
|
1970
|
-
|
1971
|
-
|
1972
|
-
|
1973
|
-
|
1974
|
-
|
1975
|
-
|
1976
|
-
|
1977
|
-
|
1978
|
-
|
1979
|
-
|
1980
|
-
|
1981
|
-
|
1982
|
-
|
1983
|
-
|
1984
|
-
|
1985
|
-
|
1986
|
-
|
1987
|
-
# Returns:
|
1988
|
-
# float: the dynamic time warping distance between the two tracks
|
1989
|
-
# """
|
1990
|
-
# from scipy.sparse import
|
1991
|
-
# pos_t1 = [self.pos[ti] for ti in t1]
|
1992
|
-
# pos_t2 = [self.pos[ti] for ti in t2]
|
1993
|
-
# distance_matrix = np.zeros((len(t1), len(t2))) + np.inf
|
1994
|
-
|
1995
|
-
# c = distance.cdist(exp_data, num_data, metric=metric, **kwargs)
|
1996
|
-
|
1997
|
-
# d = np.zeros(c.shape)
|
1998
|
-
# d[0, 0] = c[0, 0]
|
1999
|
-
# n, m = c.shape
|
2000
|
-
# for i in range(1, n):
|
2001
|
-
# d[i, 0] = d[i-1, 0] + c[i, 0]
|
2002
|
-
# for j in range(1, m):
|
2003
|
-
# d[0, j] = d[0, j-1] + c[0, j]
|
2004
|
-
# for i in range(1, n):
|
2005
|
-
# for j in range(1, m):
|
2006
|
-
# d[i, j] = c[i, j] + min((d[i-1, j], d[i, j-1], d[i-1, j-1]))
|
2007
|
-
# return d[-1, -1], d
|
2008
|
-
|
2009
|
-
def __getitem__(self, item):
|
2010
|
-
if isinstance(item, str):
|
2011
|
-
return self.__dict__[item]
|
2012
|
-
elif np.issubdtype(type(item), np.integer):
|
2013
|
-
return self.successor.get(item, [])
|
2014
|
-
else:
|
2015
|
-
raise KeyError(
|
2016
|
-
"Only integer or string are valid key for lineageTree"
|
2017
|
-
)
|
2018
|
-
|
2019
|
-
def get_cells_at_t_from_root(self, r: [int, list], t: int = None) -> list:
|
2020
|
-
"""
|
2021
|
-
Returns the list of cells at time `t` that are spawn by the node(s) `r`.
|
2022
|
-
|
2023
|
-
Args:
|
2024
|
-
r (int | list): id or list of ids of the spawning node
|
2025
|
-
t (int): target time, if None goes as far as possible
|
2026
|
-
(default None)
|
2027
|
-
|
2028
|
-
Returns:
|
2029
|
-
(list) list of nodes at time `t` spawned by `r`
|
2030
|
-
"""
|
2031
|
-
if not isinstance(r, list):
|
2672
|
+
def nodes_at_t(
|
2673
|
+
self,
|
2674
|
+
t: int,
|
2675
|
+
r: int | Iterable[int] | None = None,
|
2676
|
+
) -> list[int]:
|
2677
|
+
"""
|
2678
|
+
Returns the list of nodes at time `t` that are spawn by the node(s) `r`.
|
2679
|
+
|
2680
|
+
Parameters
|
2681
|
+
----------
|
2682
|
+
t : int
|
2683
|
+
target time, if `None` goes as far as possible
|
2684
|
+
r : int or Iterable of int, optional
|
2685
|
+
id or list of ids of the spawning node
|
2686
|
+
|
2687
|
+
Returns
|
2688
|
+
-------
|
2689
|
+
list of int
|
2690
|
+
list of ids of the nodes at time `t` spawned by `r`
|
2691
|
+
"""
|
2692
|
+
if not r and r != 0:
|
2693
|
+
r = {root for root in self.roots if self.time[root] <= t}
|
2694
|
+
if isinstance(r, int):
|
2032
2695
|
r = [r]
|
2696
|
+
if t is None:
|
2697
|
+
t = self.t_e
|
2033
2698
|
to_do = list(r)
|
2034
2699
|
final_nodes = []
|
2035
2700
|
while len(to_do) > 0:
|
2036
2701
|
curr = to_do.pop()
|
2037
|
-
for _next in self[curr]:
|
2038
|
-
if self.
|
2702
|
+
for _next in self._successor[curr]:
|
2703
|
+
if self._time[_next] < t:
|
2039
2704
|
to_do.append(_next)
|
2040
|
-
elif self.
|
2705
|
+
elif self._time[_next] == t:
|
2041
2706
|
final_nodes.append(_next)
|
2042
2707
|
if not final_nodes:
|
2043
2708
|
return list(r)
|
2044
2709
|
return final_nodes
|
2045
2710
|
|
2046
2711
|
@staticmethod
|
2047
|
-
def __calculate_diag_line(dist_mat: np.ndarray) ->
|
2712
|
+
def __calculate_diag_line(dist_mat: np.ndarray) -> tuple[float, float]:
|
2048
2713
|
"""
|
2049
2714
|
Calculate the line that centers the band w.
|
2050
2715
|
|
2051
|
-
|
2052
|
-
|
2716
|
+
Parameters
|
2717
|
+
----------
|
2718
|
+
dist_mat : np.ndarray
|
2719
|
+
distance matrix obtained by the function calculate_dtw
|
2053
2720
|
|
2054
|
-
|
2055
|
-
|
2056
|
-
|
2721
|
+
Returns
|
2722
|
+
-------
|
2723
|
+
float
|
2724
|
+
The slope of the curve
|
2725
|
+
float
|
2726
|
+
The intercept of the curve
|
2057
2727
|
"""
|
2058
2728
|
i, j = dist_mat.shape
|
2059
2729
|
x1 = max(0, i - j) / 2
|
@@ -2073,22 +2743,33 @@ class lineageTree(lineageTreeLoaders):
|
|
2073
2743
|
fast: bool = False,
|
2074
2744
|
w: int = 0,
|
2075
2745
|
centered_band: bool = True,
|
2076
|
-
) ->
|
2746
|
+
) -> tuple[list[int], np.ndarray, float]:
|
2077
2747
|
"""
|
2078
2748
|
Find DTW minimum cost between two series using dynamic programming.
|
2079
2749
|
|
2080
|
-
|
2081
|
-
|
2082
|
-
|
2083
|
-
|
2084
|
-
|
2085
|
-
|
2086
|
-
|
2087
|
-
|
2088
|
-
|
2089
|
-
|
2090
|
-
|
2091
|
-
|
2750
|
+
Parameters
|
2751
|
+
----------
|
2752
|
+
dist_mat : np.ndarray
|
2753
|
+
distance matrix obtained by the function calculate_dtw
|
2754
|
+
start_d : int, default=0
|
2755
|
+
start delay
|
2756
|
+
back_d : int, default=0
|
2757
|
+
end delay
|
2758
|
+
fast : bool, default=False
|
2759
|
+
if `True`, the algorithm will use a faster version but might not find the optimal alignment
|
2760
|
+
w : int, default=0
|
2761
|
+
window constrain
|
2762
|
+
centered_band : bool, default=True
|
2763
|
+
if `True`, the band will be centered around the diagonal
|
2764
|
+
|
2765
|
+
Returns
|
2766
|
+
-------
|
2767
|
+
tuple of tuples of int
|
2768
|
+
Aligment path
|
2769
|
+
np.ndarray
|
2770
|
+
cost matrix
|
2771
|
+
float
|
2772
|
+
optimal cost
|
2092
2773
|
"""
|
2093
2774
|
N, M = dist_mat.shape
|
2094
2775
|
w_limit = max(w, abs(N - M)) # Calculate the Sakoe-Chiba band width
|
@@ -2209,7 +2890,6 @@ class lineageTree(lineageTreeLoaders):
|
|
2209
2890
|
|
2210
2891
|
# special reflection case
|
2211
2892
|
if np.linalg.det(R) < 0:
|
2212
|
-
# print("det(R) < R, reflection detected!, correcting for it ...")
|
2213
2893
|
Vt[2, :] *= -1
|
2214
2894
|
R = Vt.T @ U.T
|
2215
2895
|
|
@@ -2218,52 +2898,59 @@ class lineageTree(lineageTreeLoaders):
|
|
2218
2898
|
return R, t
|
2219
2899
|
|
2220
2900
|
def __interpolate(
|
2221
|
-
self,
|
2222
|
-
) ->
|
2901
|
+
self, chain1: list, chain2: list, threshold: int
|
2902
|
+
) -> tuple[np.ndarray, np.ndarray]:
|
2223
2903
|
"""
|
2224
2904
|
Interpolate two series that have different lengths
|
2225
2905
|
|
2226
|
-
|
2227
|
-
|
2228
|
-
|
2229
|
-
|
2230
|
-
|
2231
|
-
|
2232
|
-
|
2233
|
-
|
2906
|
+
Parameters
|
2907
|
+
----------
|
2908
|
+
chain1 : list of int
|
2909
|
+
list of nodes of the first chain to compare
|
2910
|
+
chain2 : list of int
|
2911
|
+
list of nodes of the second chain to compare
|
2912
|
+
threshold : int
|
2913
|
+
set a maximum number of points a chain can have
|
2914
|
+
|
2915
|
+
Returns
|
2916
|
+
-------
|
2917
|
+
list of np.ndarray
|
2918
|
+
`x`, `y`, `z` postions for `chain1`
|
2919
|
+
list of np.ndarray
|
2920
|
+
`x`, `y`, `z` postions for `chain2`
|
2234
2921
|
"""
|
2235
2922
|
inter1_pos = []
|
2236
2923
|
inter2_pos = []
|
2237
2924
|
|
2238
|
-
|
2239
|
-
|
2925
|
+
chain1_pos = np.array([self.pos[c_id] for c_id in chain1])
|
2926
|
+
chain2_pos = np.array([self.pos[c_id] for c_id in chain2])
|
2240
2927
|
|
2241
|
-
# Both
|
2242
|
-
if len(
|
2243
|
-
len(
|
2928
|
+
# Both chains have the same length and size below the threshold - nothing is done
|
2929
|
+
if len(chain1) == len(chain2) and (
|
2930
|
+
len(chain1) <= threshold or len(chain2) <= threshold
|
2244
2931
|
):
|
2245
|
-
return
|
2246
|
-
# Both
|
2247
|
-
elif len(
|
2932
|
+
return chain1_pos, chain2_pos
|
2933
|
+
# Both chains have the same length but one or more sizes are above the threshold
|
2934
|
+
elif len(chain1) > threshold or len(chain2) > threshold:
|
2248
2935
|
sampling = threshold
|
2249
|
-
#
|
2936
|
+
# chains have different lengths and the sizes are below the threshold
|
2250
2937
|
else:
|
2251
|
-
sampling = max(len(
|
2938
|
+
sampling = max(len(chain1), len(chain2))
|
2252
2939
|
|
2253
2940
|
for pos in range(3):
|
2254
|
-
|
2255
|
-
np.linspace(0, 1, len(
|
2256
|
-
|
2941
|
+
chain1_interp = InterpolatedUnivariateSpline(
|
2942
|
+
np.linspace(0, 1, len(chain1_pos[:, pos])),
|
2943
|
+
chain1_pos[:, pos],
|
2257
2944
|
k=1,
|
2258
2945
|
)
|
2259
|
-
inter1_pos.append(
|
2946
|
+
inter1_pos.append(chain1_interp(np.linspace(0, 1, sampling)))
|
2260
2947
|
|
2261
|
-
|
2262
|
-
np.linspace(0, 1, len(
|
2263
|
-
|
2948
|
+
chain2_interp = InterpolatedUnivariateSpline(
|
2949
|
+
np.linspace(0, 1, len(chain2_pos[:, pos])),
|
2950
|
+
chain2_pos[:, pos],
|
2264
2951
|
k=1,
|
2265
2952
|
)
|
2266
|
-
inter2_pos.append(
|
2953
|
+
inter2_pos.append(chain2_interp(np.linspace(0, 1, sampling)))
|
2267
2954
|
|
2268
2955
|
return np.column_stack(inter1_pos), np.column_stack(inter2_pos)
|
2269
2956
|
|
@@ -2279,46 +2966,66 @@ class lineageTree(lineageTreeLoaders):
|
|
2279
2966
|
w: int = 0,
|
2280
2967
|
centered_band: bool = True,
|
2281
2968
|
cost_mat_p: bool = False,
|
2282
|
-
) -> (
|
2283
|
-
|
2284
|
-
|
2285
|
-
|
2286
|
-
Args:
|
2287
|
-
nodes1 (int): node to compare distance
|
2288
|
-
nodes2 (int): node to compare distance
|
2289
|
-
threshold: set a maximum number of points a track can have
|
2290
|
-
regist (boolean): Rotate and translate trajectories
|
2291
|
-
start_d (int): start delay
|
2292
|
-
back_d (int): end delay
|
2293
|
-
fast (boolean): True if the user wants to run the fast algorithm with window restrains
|
2294
|
-
w (int): window size
|
2295
|
-
centered_band (boolean): if running the fast algorithm, True if the windown is centered
|
2296
|
-
cost_mat_p (boolean): True if print the not normalized cost matrix
|
2297
|
-
|
2298
|
-
Returns:
|
2299
|
-
(float) DTW distance
|
2300
|
-
(tuple of tuples) Aligment path
|
2301
|
-
(matrix) Cost matrix
|
2302
|
-
(list of lists) pos_cycle1: rotated and translated trajectories positions
|
2303
|
-
(list of lists) pos_cycle2: rotated and translated trajectories positions
|
2969
|
+
) -> (
|
2970
|
+
tuple[float, tuple, np.ndarray, np.ndarray, np.ndarray]
|
2971
|
+
| tuple[float, tuple]
|
2972
|
+
):
|
2304
2973
|
"""
|
2305
|
-
|
2306
|
-
|
2307
|
-
|
2308
|
-
|
2309
|
-
|
2974
|
+
Calculate DTW distance between two chains
|
2975
|
+
|
2976
|
+
Parameters
|
2977
|
+
----------
|
2978
|
+
nodes1 : int
|
2979
|
+
node to compare distance
|
2980
|
+
nodes2 : int
|
2981
|
+
node to compare distance
|
2982
|
+
threshold : int, default=1000
|
2983
|
+
set a maximum number of points a chain can have
|
2984
|
+
regist : bool, default=True
|
2985
|
+
Rotate and translate trajectories
|
2986
|
+
start_d : int, default=0
|
2987
|
+
start delay
|
2988
|
+
back_d : int, default=0
|
2989
|
+
end delay
|
2990
|
+
fast : bool, default=False
|
2991
|
+
if `True`, the algorithm will use a faster version but might not find the optimal alignment
|
2992
|
+
w : int, default=0
|
2993
|
+
window size
|
2994
|
+
centered_band : bool, default=True
|
2995
|
+
when running the fast algorithm, `True` if the windown is centered
|
2996
|
+
cost_mat_p : bool, default=False
|
2997
|
+
True if print the not normalized cost matrix
|
2998
|
+
|
2999
|
+
Returns
|
3000
|
+
-------
|
3001
|
+
float
|
3002
|
+
DTW distance
|
3003
|
+
tuple of tuples
|
3004
|
+
Aligment path
|
3005
|
+
matrix
|
3006
|
+
Cost matrix
|
3007
|
+
list of lists
|
3008
|
+
rotated and translated trajectories positions
|
3009
|
+
list of lists
|
3010
|
+
rotated and translated trajectories positions
|
3011
|
+
"""
|
3012
|
+
nodes1_chain = self.get_chain_of_node(nodes1)
|
3013
|
+
nodes2_chain = self.get_chain_of_node(nodes2)
|
3014
|
+
|
3015
|
+
interp_chain1, interp_chain2 = self.__interpolate(
|
3016
|
+
nodes1_chain, nodes2_chain, threshold
|
2310
3017
|
)
|
2311
3018
|
|
2312
|
-
|
2313
|
-
|
3019
|
+
pos_chain1 = np.array([self.pos[c_id] for c_id in nodes1_chain])
|
3020
|
+
pos_chain2 = np.array([self.pos[c_id] for c_id in nodes2_chain])
|
2314
3021
|
|
2315
3022
|
if regist:
|
2316
3023
|
R, t = self.__rigid_transform_3D(
|
2317
|
-
np.transpose(
|
3024
|
+
np.transpose(interp_chain1), np.transpose(interp_chain2)
|
2318
3025
|
)
|
2319
|
-
|
3026
|
+
pos_chain1 = np.transpose(np.dot(R, pos_chain1.T) + t)
|
2320
3027
|
|
2321
|
-
dist_mat = distance.cdist(
|
3028
|
+
dist_mat = distance.cdist(pos_chain1, pos_chain2, "euclidean")
|
2322
3029
|
|
2323
3030
|
path, cost_mat, final_cost = self.__dp(
|
2324
3031
|
dist_mat,
|
@@ -2331,7 +3038,7 @@ class lineageTree(lineageTreeLoaders):
|
|
2331
3038
|
cost = final_cost / len(path)
|
2332
3039
|
|
2333
3040
|
if cost_mat_p:
|
2334
|
-
return cost, path, cost_mat,
|
3041
|
+
return cost, path, cost_mat, pos_chain1, pos_chain2
|
2335
3042
|
else:
|
2336
3043
|
return cost, path
|
2337
3044
|
|
@@ -2346,24 +3053,39 @@ class lineageTree(lineageTreeLoaders):
|
|
2346
3053
|
fast: bool = False,
|
2347
3054
|
w: int = 0,
|
2348
3055
|
centered_band: bool = True,
|
2349
|
-
) ->
|
2350
|
-
"""
|
2351
|
-
Plot DTW cost matrix between two
|
2352
|
-
|
2353
|
-
|
2354
|
-
|
2355
|
-
|
2356
|
-
|
2357
|
-
|
2358
|
-
|
2359
|
-
|
2360
|
-
|
2361
|
-
|
2362
|
-
|
2363
|
-
|
2364
|
-
|
2365
|
-
|
2366
|
-
|
3056
|
+
) -> tuple[float, plt.Figure]:
|
3057
|
+
"""
|
3058
|
+
Plot DTW cost matrix between two chains in heatmap format
|
3059
|
+
|
3060
|
+
Parameters
|
3061
|
+
----------
|
3062
|
+
nodes1 : int
|
3063
|
+
node to compare distance
|
3064
|
+
nodes2 : int
|
3065
|
+
node to compare distance
|
3066
|
+
threshold : int, default=1000
|
3067
|
+
set a maximum number of points a chain can have
|
3068
|
+
regist : bool, default=True
|
3069
|
+
Rotate and translate trajectories
|
3070
|
+
start_d : int, default=0
|
3071
|
+
start delay
|
3072
|
+
back_d : int, default=0
|
3073
|
+
end delay
|
3074
|
+
fast : bool, default=False
|
3075
|
+
if `True`, the algorithm will use a faster version but might not find the optimal alignment
|
3076
|
+
w : int, default=0
|
3077
|
+
window size
|
3078
|
+
centered_band : bool, default=True
|
3079
|
+
when running the fast algorithm, `True` if the windown is centered
|
3080
|
+
|
3081
|
+
Returns
|
3082
|
+
-------
|
3083
|
+
float
|
3084
|
+
DTW distance
|
3085
|
+
plt.Figure
|
3086
|
+
Heatmap of cost matrix with opitimal path
|
3087
|
+
"""
|
3088
|
+
cost, path, cost_mat, pos_chain1, pos_chain2 = self.calculate_dtw(
|
2367
3089
|
nodes1,
|
2368
3090
|
nodes2,
|
2369
3091
|
threshold,
|
@@ -2385,32 +3107,32 @@ class lineageTree(lineageTreeLoaders):
|
|
2385
3107
|
ax.set_title("Heatmap of DTW Cost Matrix")
|
2386
3108
|
ax.set_xlabel("Tree 1")
|
2387
3109
|
ax.set_ylabel("tree 2")
|
2388
|
-
x_path, y_path = zip(*path)
|
3110
|
+
x_path, y_path = zip(*path, strict=True)
|
2389
3111
|
ax.plot(y_path, x_path, color="black")
|
2390
3112
|
|
2391
3113
|
return cost, fig
|
2392
3114
|
|
2393
3115
|
@staticmethod
|
2394
3116
|
def __plot_2d(
|
2395
|
-
|
2396
|
-
|
2397
|
-
nodes1,
|
2398
|
-
nodes2,
|
2399
|
-
ax,
|
2400
|
-
x_idx,
|
2401
|
-
y_idx,
|
2402
|
-
x_label,
|
2403
|
-
y_label,
|
2404
|
-
):
|
3117
|
+
pos_chain1: np.ndarray,
|
3118
|
+
pos_chain2: np.ndarray,
|
3119
|
+
nodes1: list[int],
|
3120
|
+
nodes2: list[int],
|
3121
|
+
ax: plt.Axes,
|
3122
|
+
x_idx: list[int],
|
3123
|
+
y_idx: list[int],
|
3124
|
+
x_label: str,
|
3125
|
+
y_label: str,
|
3126
|
+
) -> None:
|
2405
3127
|
ax.plot(
|
2406
|
-
|
2407
|
-
|
3128
|
+
pos_chain1[:, x_idx],
|
3129
|
+
pos_chain1[:, y_idx],
|
2408
3130
|
"-",
|
2409
3131
|
label=f"root = {nodes1}",
|
2410
3132
|
)
|
2411
3133
|
ax.plot(
|
2412
|
-
|
2413
|
-
|
3134
|
+
pos_chain2[:, x_idx],
|
3135
|
+
pos_chain2[:, y_idx],
|
2414
3136
|
"-",
|
2415
3137
|
label=f"root = {nodes2}",
|
2416
3138
|
)
|
@@ -2428,40 +3150,55 @@ class lineageTree(lineageTreeLoaders):
|
|
2428
3150
|
fast: bool = False,
|
2429
3151
|
w: int = 0,
|
2430
3152
|
centered_band: bool = True,
|
2431
|
-
projection:
|
3153
|
+
projection: Literal["3d", "xy", "xz", "yz", "pca", None] = None,
|
2432
3154
|
alig: bool = False,
|
2433
|
-
) ->
|
2434
|
-
"""
|
2435
|
-
Plots DTW trajectories aligment between two
|
2436
|
-
|
2437
|
-
|
2438
|
-
|
2439
|
-
|
2440
|
-
|
2441
|
-
|
2442
|
-
|
2443
|
-
|
2444
|
-
|
2445
|
-
|
2446
|
-
|
2447
|
-
|
2448
|
-
|
2449
|
-
|
2450
|
-
|
2451
|
-
|
2452
|
-
|
2453
|
-
|
2454
|
-
|
2455
|
-
|
2456
|
-
|
2457
|
-
|
3155
|
+
) -> tuple[float, plt.Figure]:
|
3156
|
+
"""
|
3157
|
+
Plots DTW trajectories aligment between two chains in 2D or 3D
|
3158
|
+
|
3159
|
+
Parameters
|
3160
|
+
----------
|
3161
|
+
nodes1 : int
|
3162
|
+
node to compare distance
|
3163
|
+
nodes2 : int
|
3164
|
+
node to compare distance
|
3165
|
+
threshold : int, default=1000
|
3166
|
+
set a maximum number of points a chain can have
|
3167
|
+
regist : bool, default=True
|
3168
|
+
Rotate and translate trajectories
|
3169
|
+
start_d : int, default=0
|
3170
|
+
start delay
|
3171
|
+
back_d : int, default=0
|
3172
|
+
end delay
|
3173
|
+
w : int, default=0
|
3174
|
+
window size
|
3175
|
+
fast : bool, default=False
|
3176
|
+
True if the user wants to run the fast algorithm with window restrains
|
3177
|
+
centered_band : bool, default=True
|
3178
|
+
if running the fast algorithm, True if the windown is centered
|
3179
|
+
projection : {"3d", "xy", "xz", "yz", "pca"}, optional
|
3180
|
+
specify which 2D to plot ->
|
3181
|
+
"3d" : for the 3d visualization
|
3182
|
+
"xy" or None (default) : 2D projection of axis x and y
|
3183
|
+
"xz" : 2D projection of axis x and z
|
3184
|
+
"yz" : 2D projection of axis y and z
|
3185
|
+
"pca" : PCA projection
|
3186
|
+
alig : bool
|
3187
|
+
True to show alignment on plot
|
3188
|
+
|
3189
|
+
Returns
|
3190
|
+
-------
|
3191
|
+
float
|
3192
|
+
DTW distance
|
3193
|
+
figure
|
3194
|
+
Trajectories Plot
|
2458
3195
|
"""
|
2459
3196
|
(
|
2460
3197
|
distance,
|
2461
3198
|
alignment,
|
2462
3199
|
cost_mat,
|
2463
|
-
|
2464
|
-
|
3200
|
+
pos_chain1,
|
3201
|
+
pos_chain2,
|
2465
3202
|
) = self.calculate_dtw(
|
2466
3203
|
nodes1,
|
2467
3204
|
nodes2,
|
@@ -2484,16 +3221,16 @@ class lineageTree(lineageTreeLoaders):
|
|
2484
3221
|
|
2485
3222
|
if projection == "3d":
|
2486
3223
|
ax.plot(
|
2487
|
-
|
2488
|
-
|
2489
|
-
|
3224
|
+
pos_chain1[:, 0],
|
3225
|
+
pos_chain1[:, 1],
|
3226
|
+
pos_chain1[:, 2],
|
2490
3227
|
"-",
|
2491
3228
|
label=f"root = {nodes1}",
|
2492
3229
|
)
|
2493
3230
|
ax.plot(
|
2494
|
-
|
2495
|
-
|
2496
|
-
|
3231
|
+
pos_chain2[:, 0],
|
3232
|
+
pos_chain2[:, 1],
|
3233
|
+
pos_chain2[:, 2],
|
2497
3234
|
"-",
|
2498
3235
|
label=f"root = {nodes2}",
|
2499
3236
|
)
|
@@ -2503,8 +3240,8 @@ class lineageTree(lineageTreeLoaders):
|
|
2503
3240
|
else:
|
2504
3241
|
if projection == "xy" or projection == "yx" or projection is None:
|
2505
3242
|
self.__plot_2d(
|
2506
|
-
|
2507
|
-
|
3243
|
+
pos_chain1,
|
3244
|
+
pos_chain2,
|
2508
3245
|
nodes1,
|
2509
3246
|
nodes2,
|
2510
3247
|
ax,
|
@@ -2515,8 +3252,8 @@ class lineageTree(lineageTreeLoaders):
|
|
2515
3252
|
)
|
2516
3253
|
elif projection == "xz" or projection == "zx":
|
2517
3254
|
self.__plot_2d(
|
2518
|
-
|
2519
|
-
|
3255
|
+
pos_chain1,
|
3256
|
+
pos_chain2,
|
2520
3257
|
nodes1,
|
2521
3258
|
nodes2,
|
2522
3259
|
ax,
|
@@ -2527,8 +3264,8 @@ class lineageTree(lineageTreeLoaders):
|
|
2527
3264
|
)
|
2528
3265
|
elif projection == "yz" or projection == "zy":
|
2529
3266
|
self.__plot_2d(
|
2530
|
-
|
2531
|
-
|
3267
|
+
pos_chain1,
|
3268
|
+
pos_chain2,
|
2532
3269
|
nodes1,
|
2533
3270
|
nodes2,
|
2534
3271
|
ax,
|
@@ -2542,24 +3279,25 @@ class lineageTree(lineageTreeLoaders):
|
|
2542
3279
|
from sklearn.decomposition import PCA
|
2543
3280
|
except ImportError:
|
2544
3281
|
Warning(
|
2545
|
-
"scikit-learn is not installed, the PCA orientation cannot be used.
|
3282
|
+
"scikit-learn is not installed, the PCA orientation cannot be used."
|
3283
|
+
"You can install scikit-learn with pip install"
|
2546
3284
|
)
|
2547
3285
|
|
2548
3286
|
# Apply PCA
|
2549
3287
|
pca = PCA(n_components=2)
|
2550
|
-
pca.fit(np.vstack([
|
2551
|
-
|
2552
|
-
|
3288
|
+
pca.fit(np.vstack([pos_chain1, pos_chain2]))
|
3289
|
+
pos_chain1_2d = pca.transform(pos_chain1)
|
3290
|
+
pos_chain2_2d = pca.transform(pos_chain2)
|
2553
3291
|
|
2554
3292
|
ax.plot(
|
2555
|
-
|
2556
|
-
|
3293
|
+
pos_chain1_2d[:, 0],
|
3294
|
+
pos_chain1_2d[:, 1],
|
2557
3295
|
"-",
|
2558
3296
|
label=f"root = {nodes1}",
|
2559
3297
|
)
|
2560
3298
|
ax.plot(
|
2561
|
-
|
2562
|
-
|
3299
|
+
pos_chain2_2d[:, 0],
|
3300
|
+
pos_chain2_2d[:, 1],
|
2563
3301
|
"-",
|
2564
3302
|
label=f"root = {nodes2}",
|
2565
3303
|
)
|
@@ -2588,7 +3326,7 @@ class lineageTree(lineageTreeLoaders):
|
|
2588
3326
|
'pca' : PCA projection"""
|
2589
3327
|
)
|
2590
3328
|
|
2591
|
-
connections = [[
|
3329
|
+
connections = [[pos_chain1[i], pos_chain2[j]] for i, j in alignment]
|
2592
3330
|
|
2593
3331
|
for connection in connections:
|
2594
3332
|
xyz1 = connection[0]
|
@@ -2616,108 +3354,197 @@ class lineageTree(lineageTreeLoaders):
|
|
2616
3354
|
|
2617
3355
|
return distance, fig
|
2618
3356
|
|
2619
|
-
def
|
2620
|
-
|
3357
|
+
def get_subtree(self, node_list: set[int]) -> lineageTree:
|
3358
|
+
new_successors = {
|
3359
|
+
n: tuple(vi for vi in self.successor[n] if vi in node_list)
|
3360
|
+
for n in node_list
|
3361
|
+
}
|
3362
|
+
return lineageTree(
|
3363
|
+
successor=new_successors,
|
3364
|
+
time=self._time,
|
3365
|
+
pos=self.pos,
|
3366
|
+
name=self.name,
|
3367
|
+
root_leaf_value=[
|
3368
|
+
(),
|
3369
|
+
],
|
3370
|
+
**{
|
3371
|
+
name: self.__dict__[name]
|
3372
|
+
for name in self._custom_property_list
|
3373
|
+
},
|
3374
|
+
)
|
2621
3375
|
|
2622
3376
|
def __init__(
|
2623
3377
|
self,
|
2624
|
-
|
2625
|
-
|
2626
|
-
|
2627
|
-
|
2628
|
-
|
2629
|
-
|
2630
|
-
|
2631
|
-
|
2632
|
-
|
2633
|
-
reorder: bool = False,
|
2634
|
-
xml_attributes: tuple = None,
|
2635
|
-
name: str = None,
|
2636
|
-
time_resolution: Union[int, None] = None,
|
3378
|
+
*,
|
3379
|
+
successor: dict[int, Sequence] | None = None,
|
3380
|
+
predecessor: dict[int, int | Sequence] | None = None,
|
3381
|
+
time: dict[int, int] | None = None,
|
3382
|
+
starting_time: int | None = None,
|
3383
|
+
pos: dict[int, Iterable] | None = None,
|
3384
|
+
name: str | None = None,
|
3385
|
+
root_leaf_value: Sequence | None = None,
|
3386
|
+
**kwargs,
|
2637
3387
|
):
|
2638
|
-
"""
|
2639
|
-
|
2640
|
-
|
2641
|
-
|
2642
|
-
|
2643
|
-
|
2644
|
-
|
2645
|
-
|
2646
|
-
|
2647
|
-
|
2648
|
-
|
2649
|
-
|
2650
|
-
|
2651
|
-
|
2652
|
-
|
2653
|
-
|
2654
|
-
|
2655
|
-
|
2656
|
-
|
2657
|
-
|
2658
|
-
|
2659
|
-
|
2660
|
-
|
2661
|
-
|
3388
|
+
"""Create a lineageTree object from minimal information, without reading from a file.
|
3389
|
+
Either `successor` or `predecessor` should be specified.
|
3390
|
+
|
3391
|
+
Parameters
|
3392
|
+
----------
|
3393
|
+
successor : dict mapping int to Iterable
|
3394
|
+
Dictionary assigning nodes to their successors.
|
3395
|
+
predecessor : dict mapping int to int or Iterable
|
3396
|
+
Dictionary assigning nodes to their predecessors.
|
3397
|
+
time : dict mapping int to int, optional
|
3398
|
+
Dictionary assigning nodes to the time point they were recorded to.
|
3399
|
+
Defaults to None, in which case all times are set to `starting_time`.
|
3400
|
+
starting_time : int, optional
|
3401
|
+
Starting time of the lineage tree. Defaults to 0.
|
3402
|
+
pos : dict mapping int to Iterable, optional
|
3403
|
+
Dictionary assigning nodes to their positions. Defaults to None.
|
3404
|
+
name : str, optional
|
3405
|
+
Name of the lineage tree. Defaults to None.
|
3406
|
+
root_leaf_value : Iterable, optional
|
3407
|
+
Iterable of values of roots' predecessors and leaves' successors in the successor and predecessor dictionaries.
|
3408
|
+
Defaults are `[None, (), [], set()]`.
|
3409
|
+
**kwargs:
|
3410
|
+
Supported keyword arguments are dictionaries assigning nodes to any custom property.
|
3411
|
+
The property must be specified for every node, and named differently from lineageTree's own attributes.
|
3412
|
+
"""
|
3413
|
+
self.__version__ = __version__
|
3414
|
+
self.name = str(name) if name is not None else None
|
3415
|
+
if successor is not None and predecessor is not None:
|
3416
|
+
raise ValueError(
|
3417
|
+
"You cannot have both successors and predecessors."
|
3418
|
+
)
|
2662
3419
|
|
2663
|
-
|
2664
|
-
|
2665
|
-
|
2666
|
-
|
2667
|
-
|
2668
|
-
|
2669
|
-
|
2670
|
-
|
2671
|
-
|
2672
|
-
|
2673
|
-
self.
|
2674
|
-
|
2675
|
-
|
2676
|
-
|
2677
|
-
|
2678
|
-
|
2679
|
-
|
2680
|
-
|
2681
|
-
|
2682
|
-
|
2683
|
-
|
2684
|
-
|
2685
|
-
|
2686
|
-
|
2687
|
-
|
2688
|
-
|
2689
|
-
|
2690
|
-
|
2691
|
-
|
2692
|
-
|
2693
|
-
|
2694
|
-
|
2695
|
-
|
2696
|
-
|
2697
|
-
|
3420
|
+
if root_leaf_value is None:
|
3421
|
+
root_leaf_value = [None, (), [], set()]
|
3422
|
+
elif not isinstance(root_leaf_value, Iterable):
|
3423
|
+
raise TypeError(
|
3424
|
+
f"root_leaf_value is of type {type(root_leaf_value)}, expected Iterable."
|
3425
|
+
)
|
3426
|
+
elif len(root_leaf_value) < 1:
|
3427
|
+
raise ValueError(
|
3428
|
+
"root_leaf_value should have at least one element."
|
3429
|
+
)
|
3430
|
+
self._successor = {}
|
3431
|
+
self._predecessor = {}
|
3432
|
+
if successor is not None:
|
3433
|
+
for pred, succs in successor.items():
|
3434
|
+
if succs in root_leaf_value:
|
3435
|
+
self._successor[pred] = ()
|
3436
|
+
else:
|
3437
|
+
if not isinstance(succs, Iterable):
|
3438
|
+
raise TypeError(
|
3439
|
+
f"Successors should be Iterable, got {type(succs)}."
|
3440
|
+
)
|
3441
|
+
if len(succs) == 0:
|
3442
|
+
raise ValueError(
|
3443
|
+
f"{succs} was not declared as a leaf but was found as a successor.\n"
|
3444
|
+
"Please lift the ambiguity."
|
3445
|
+
)
|
3446
|
+
self._successor[pred] = tuple(succs)
|
3447
|
+
for succ in succs:
|
3448
|
+
if succ in self._predecessor:
|
3449
|
+
raise ValueError(
|
3450
|
+
"Node can have at most one predecessor."
|
3451
|
+
)
|
3452
|
+
self._predecessor[succ] = (pred,)
|
3453
|
+
elif predecessor is not None:
|
3454
|
+
for succ, pred in predecessor.items():
|
3455
|
+
if pred in root_leaf_value:
|
3456
|
+
self._predecessor[succ] = ()
|
3457
|
+
else:
|
3458
|
+
if isinstance(pred, Sequence):
|
3459
|
+
if len(pred) == 0:
|
3460
|
+
raise ValueError(
|
3461
|
+
f"{pred} was not declared as a leaf but was found as a successor.\n"
|
3462
|
+
"Please lift the ambiguity."
|
3463
|
+
)
|
3464
|
+
if 1 < len(pred):
|
3465
|
+
raise ValueError(
|
3466
|
+
"Node can have at most one predecessor."
|
3467
|
+
)
|
3468
|
+
pred = pred[0]
|
3469
|
+
self._predecessor[succ] = (pred,)
|
3470
|
+
self._successor.setdefault(pred, ())
|
3471
|
+
self._successor[pred] += (succ,)
|
3472
|
+
for root in set(self._successor).difference(self._predecessor):
|
3473
|
+
self._predecessor[root] = ()
|
3474
|
+
for leaf in set(self._predecessor).difference(self._successor):
|
3475
|
+
self._successor[leaf] = ()
|
3476
|
+
|
3477
|
+
if self.__check_for_cycles():
|
3478
|
+
raise ValueError(
|
3479
|
+
"Cycles were found in the tree, there should not be any."
|
3480
|
+
)
|
3481
|
+
|
3482
|
+
if pos is None or len(pos) == 0:
|
3483
|
+
self.pos = {}
|
3484
|
+
else:
|
3485
|
+
if self.nodes.difference(pos) != set():
|
3486
|
+
raise ValueError("Please provide the position of all nodes.")
|
3487
|
+
self.pos = {
|
3488
|
+
node: np.array(position) for node, position in pos.items()
|
3489
|
+
}
|
3490
|
+
if "labels" in kwargs:
|
3491
|
+
self._labels = kwargs["labels"]
|
3492
|
+
kwargs.pop("labels")
|
3493
|
+
if time is None:
|
3494
|
+
if starting_time is None:
|
3495
|
+
starting_time = 0
|
3496
|
+
if not isinstance(starting_time, int):
|
3497
|
+
warnings.warn(
|
3498
|
+
f"Attribute `starting_time` was a `{type(starting_time)}`, has been casted as an `int`.",
|
3499
|
+
stacklevel=2,
|
2698
3500
|
)
|
2699
|
-
|
2700
|
-
|
2701
|
-
|
3501
|
+
self._time = dict.fromkeys(self.roots, starting_time)
|
3502
|
+
queue = list(self.roots)
|
3503
|
+
for node in queue:
|
3504
|
+
for succ in self._successor[node]:
|
3505
|
+
self._time[succ] = self._time[node] + 1
|
3506
|
+
queue.append(succ)
|
3507
|
+
else:
|
3508
|
+
if starting_time is not None:
|
3509
|
+
warnings.warn(
|
3510
|
+
"Both `time` and `starting_time` were provided, `starting_time` was ignored.",
|
3511
|
+
stacklevel=2,
|
3512
|
+
)
|
3513
|
+
self._time = {n: int(time[n]) for n in self.nodes}
|
3514
|
+
if self._time != time:
|
3515
|
+
if len(self._time) != len(time):
|
3516
|
+
warnings.warn(
|
3517
|
+
"The provided `time` dictionary had keys that were not nodes. "
|
3518
|
+
"They have been removed",
|
3519
|
+
stacklevel=2,
|
3520
|
+
)
|
2702
3521
|
else:
|
2703
|
-
|
2704
|
-
|
2705
|
-
|
2706
|
-
|
2707
|
-
|
2708
|
-
|
2709
|
-
|
2710
|
-
|
2711
|
-
|
2712
|
-
|
2713
|
-
|
2714
|
-
|
2715
|
-
|
2716
|
-
|
2717
|
-
|
2718
|
-
|
2719
|
-
|
2720
|
-
|
2721
|
-
|
2722
|
-
|
2723
|
-
|
3522
|
+
warnings.warn(
|
3523
|
+
"The provided `time` dictionary had values that were not `int`. "
|
3524
|
+
"These values have been truncated and converted to `int`",
|
3525
|
+
stacklevel=2,
|
3526
|
+
)
|
3527
|
+
if self.nodes.symmetric_difference(self._time) != set():
|
3528
|
+
raise ValueError(
|
3529
|
+
"Please provide the time of all nodes and only existing nodes."
|
3530
|
+
)
|
3531
|
+
if not all(
|
3532
|
+
self._time[node] < self._time[s]
|
3533
|
+
for node, succ in self._successor.items()
|
3534
|
+
for s in succ
|
3535
|
+
):
|
3536
|
+
raise ValueError(
|
3537
|
+
"Provided times are not strictly increasing. Setting times to default."
|
3538
|
+
)
|
3539
|
+
# custom properties
|
3540
|
+
self._custom_property_list = []
|
3541
|
+
for name, d in kwargs.items():
|
3542
|
+
if name in self.__dict__:
|
3543
|
+
warnings.warn(
|
3544
|
+
f"Attribute name {name} is reserved.", stacklevel=2
|
3545
|
+
)
|
3546
|
+
continue
|
3547
|
+
setattr(self, name, d)
|
3548
|
+
self._custom_property_list.append(name)
|
3549
|
+
if not hasattr(self, "_comparisons"):
|
3550
|
+
self._comparisons = {}
|