LineageTree 1.8.0__py3-none-any.whl → 2.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- LineageTree/__init__.py +27 -2
- LineageTree/legacy/export_csv.py +70 -0
- LineageTree/legacy/to_lineajea.py +30 -0
- LineageTree/legacy/to_motile.py +36 -0
- LineageTree/lineageTree.py +2268 -1467
- LineageTree/lineageTreeManager.py +749 -70
- LineageTree/loaders.py +942 -864
- LineageTree/test/test_lineageTree.py +634 -0
- LineageTree/test/test_uted.py +233 -0
- LineageTree/tree_approximation.py +488 -0
- LineageTree/utils.py +103 -103
- {LineageTree-1.8.0.dist-info → lineagetree-2.0.1.dist-info}/METADATA +30 -34
- lineagetree-2.0.1.dist-info/RECORD +16 -0
- {LineageTree-1.8.0.dist-info → lineagetree-2.0.1.dist-info}/WHEEL +1 -1
- LineageTree/tree_styles.py +0 -334
- LineageTree-1.8.0.dist-info/RECORD +0 -11
- {LineageTree-1.8.0.dist-info → lineagetree-2.0.1.dist-info/licenses}/LICENSE +0 -0
- {LineageTree-1.8.0.dist-info → lineagetree-2.0.1.dist-info}/top_level.txt +0 -0
LineageTree/lineageTree.py
CHANGED
@@ -2,643 +2,538 @@
|
|
2
2
|
# This file is subject to the terms and conditions defined in
|
3
3
|
# file 'LICENCE', which is part of this source code package.
|
4
4
|
# Author: Leo Guignard (leo.guignard...@AT@...gmail.com)
|
5
|
+
|
6
|
+
from __future__ import annotations
|
7
|
+
|
8
|
+
import importlib.metadata
|
5
9
|
import os
|
6
10
|
import pickle as pkl
|
7
11
|
import struct
|
8
12
|
import warnings
|
9
|
-
from collections.abc import Iterable
|
10
|
-
from functools import partial
|
13
|
+
from collections.abc import Callable, Iterable, Sequence
|
14
|
+
from functools import partial, wraps
|
11
15
|
from itertools import combinations
|
12
16
|
from numbers import Number
|
13
|
-
from
|
14
|
-
from typing import
|
15
|
-
|
16
|
-
|
17
|
-
from .tree_styles import tree_style
|
18
|
-
|
19
|
-
try:
|
20
|
-
from edist import uted
|
21
|
-
except ImportError:
|
22
|
-
warnings.warn(
|
23
|
-
"No edist installed therefore you will not be able to compute the tree edit distance.",
|
24
|
-
stacklevel=2,
|
25
|
-
)
|
17
|
+
from types import MappingProxyType
|
18
|
+
from typing import TYPE_CHECKING, Literal
|
19
|
+
|
20
|
+
import matplotlib.colors as mcolors
|
26
21
|
import matplotlib.pyplot as plt
|
27
22
|
import numpy as np
|
23
|
+
import svgwrite
|
24
|
+
from edist import uted
|
25
|
+
from matplotlib import colormaps
|
26
|
+
from matplotlib.collections import LineCollection
|
27
|
+
from packaging.version import Version
|
28
28
|
from scipy.interpolate import InterpolatedUnivariateSpline
|
29
|
-
from scipy.
|
30
|
-
from scipy.spatial import
|
29
|
+
from scipy.sparse import dok_array
|
30
|
+
from scipy.spatial import Delaunay, KDTree, distance
|
31
31
|
|
32
|
+
from .tree_approximation import TreeApproximationTemplate, tree_style
|
32
33
|
from .utils import (
|
33
|
-
|
34
|
+
convert_style_to_number,
|
35
|
+
create_links_and_chains,
|
34
36
|
hierarchical_pos,
|
35
37
|
)
|
36
38
|
|
39
|
+
if TYPE_CHECKING:
|
40
|
+
from edist.alignment import Alignment
|
41
|
+
|
42
|
+
|
43
|
+
class dynamic_property(property):
|
44
|
+
def __init__(
|
45
|
+
self, fget=None, fset=None, fdel=None, doc=None, protected_name=None
|
46
|
+
):
|
47
|
+
super().__init__(fget, fset, fdel, doc)
|
48
|
+
self.protected_name = protected_name
|
49
|
+
|
50
|
+
def __set_name__(self, owner, name):
|
51
|
+
self.name = name
|
52
|
+
if self.protected_name is None:
|
53
|
+
self.protected_name = f"_{name}"
|
54
|
+
if not hasattr(owner, "_protected_dynamic_properties"):
|
55
|
+
owner._protected_dynamic_properties = []
|
56
|
+
owner._protected_dynamic_properties.append(self.protected_name)
|
57
|
+
if not hasattr(owner, "_dynamic_properties"):
|
58
|
+
owner._dynamic_properties = []
|
59
|
+
owner._dynamic_properties += [name, self.protected_name]
|
60
|
+
setattr(owner, self.protected_name, None)
|
61
|
+
|
62
|
+
def __get__(self, instance, owner):
|
63
|
+
if instance is None:
|
64
|
+
return self
|
65
|
+
instance._has_been_reset = False
|
66
|
+
if getattr(instance, self.protected_name) is None:
|
67
|
+
value = super().__get__(instance, owner)
|
68
|
+
setattr(instance, self.protected_name, value)
|
69
|
+
return value
|
70
|
+
else:
|
71
|
+
return getattr(instance, self.protected_name)
|
72
|
+
|
73
|
+
|
74
|
+
class lineageTree:
|
75
|
+
norm_dict = {"max": max, "sum": sum, None: lambda x: 1}
|
76
|
+
|
77
|
+
def modifier(wrapped_func):
|
78
|
+
@wraps(wrapped_func)
|
79
|
+
def raising_flag(self, *args, **kwargs):
|
80
|
+
should_reset = (
|
81
|
+
not hasattr(self, "_has_been_reset")
|
82
|
+
or not self._has_been_reset
|
83
|
+
)
|
84
|
+
out_func = wrapped_func(self, *args, **kwargs)
|
85
|
+
if should_reset:
|
86
|
+
for prop in self._protected_dynamic_properties:
|
87
|
+
self.__dict__[prop] = None
|
88
|
+
self._has_been_reset = True
|
89
|
+
return out_func
|
90
|
+
|
91
|
+
return raising_flag
|
92
|
+
|
93
|
+
def __check_cc_cycles(self, n: int) -> tuple[bool, set[int]]:
|
94
|
+
"""Check if the connected component of a given node `n` has a cycle.
|
95
|
+
|
96
|
+
Returns
|
97
|
+
-------
|
98
|
+
bool
|
99
|
+
True if the tree has cycles, False otherwise.
|
100
|
+
set of int
|
101
|
+
The set of nodes that have been checked.
|
102
|
+
"""
|
103
|
+
to_do = [n]
|
104
|
+
no_cycle = True
|
105
|
+
already_done = set()
|
106
|
+
while to_do and no_cycle:
|
107
|
+
current = to_do.pop(-1)
|
108
|
+
if current not in already_done:
|
109
|
+
already_done.add(current)
|
110
|
+
else:
|
111
|
+
no_cycle = False
|
112
|
+
to_do.extend(self._successor[current])
|
113
|
+
to_do = list(self._predecessor[n])
|
114
|
+
while to_do and no_cycle:
|
115
|
+
current = to_do.pop(-1)
|
116
|
+
if current not in already_done:
|
117
|
+
already_done.add(current)
|
118
|
+
else:
|
119
|
+
no_cycle = False
|
120
|
+
to_do.extend(self._predecessor[current])
|
121
|
+
return not no_cycle, already_done
|
37
122
|
|
38
|
-
|
39
|
-
|
123
|
+
def __check_for_cycles(self) -> bool:
|
124
|
+
"""Check if the tree has cycles.
|
125
|
+
|
126
|
+
Returns
|
127
|
+
-------
|
128
|
+
bool
|
129
|
+
True if the tree has cycles, False otherwise.
|
130
|
+
"""
|
131
|
+
to_do = set(self.nodes)
|
132
|
+
found_cycle = False
|
133
|
+
while to_do and not found_cycle:
|
134
|
+
current = to_do.pop()
|
135
|
+
found_cycle, done = self.__check_cc_cycles(current)
|
136
|
+
to_do.difference_update(done)
|
137
|
+
return found_cycle
|
138
|
+
|
139
|
+
def __eq__(self, other) -> bool:
|
40
140
|
if isinstance(other, lineageTree):
|
41
|
-
return
|
42
|
-
|
141
|
+
return (
|
142
|
+
other._successor == self._successor
|
143
|
+
and other._predecessor == self._predecessor
|
144
|
+
and other._time == self._time
|
145
|
+
)
|
146
|
+
else:
|
147
|
+
return False
|
43
148
|
|
44
|
-
def get_next_id(self):
|
45
|
-
"""Computes the next authorized id.
|
149
|
+
def get_next_id(self) -> int:
|
150
|
+
"""Computes the next authorized id and assign it.
|
46
151
|
|
47
|
-
Returns
|
48
|
-
|
152
|
+
Returns
|
153
|
+
-------
|
154
|
+
int
|
155
|
+
next authorized id
|
49
156
|
"""
|
50
|
-
if self.max_id == -1 and self.nodes:
|
51
|
-
self.max_id = max(self.nodes)
|
52
|
-
if self.next_id == []:
|
157
|
+
if not hasattr(self, "max_id") or (self.max_id == -1 and self.nodes):
|
158
|
+
self.max_id = max(self.nodes) if len(self.nodes) else 0
|
159
|
+
if not hasattr(self, "next_id") or self.next_id == []:
|
53
160
|
self.max_id += 1
|
54
161
|
return self.max_id
|
55
162
|
else:
|
56
163
|
return self.next_id.pop()
|
57
164
|
|
58
|
-
def complete_lineage(self, nodes: Union[int, set] = None):
|
59
|
-
"""Makes all leaf branches longer so that they reach the last timepoint( self.t_e), useful
|
60
|
-
for tree edit distance algorithms.
|
61
|
-
|
62
|
-
Args:
|
63
|
-
nodes (int,set), optional): Which trees should be "completed", if None it will complete the whole dataset. Defaults to None.
|
64
|
-
"""
|
65
|
-
if nodes is None:
|
66
|
-
nodes = set(self.roots)
|
67
|
-
elif isinstance(nodes, int):
|
68
|
-
nodes = {nodes}
|
69
|
-
for node in nodes:
|
70
|
-
sub = set(self.get_sub_tree(node))
|
71
|
-
specific_leaves = sub.intersection(self.leaves)
|
72
|
-
for leaf in specific_leaves:
|
73
|
-
self.add_branch(
|
74
|
-
leaf,
|
75
|
-
(self.t_e - self.time[leaf]),
|
76
|
-
reverse=True,
|
77
|
-
move_timepoints=True,
|
78
|
-
)
|
79
|
-
|
80
165
|
###TODO pos can be callable and stay motionless (copy the position of the succ node, use something like optical flow)
|
81
|
-
|
166
|
+
@modifier
|
167
|
+
def add_chain(
|
82
168
|
self,
|
83
|
-
|
169
|
+
node: int,
|
84
170
|
length: int,
|
85
|
-
|
86
|
-
pos:
|
87
|
-
|
88
|
-
|
89
|
-
"""Adds a branch of specific length to a node either as a successor or as a predecessor.
|
171
|
+
downstream: bool,
|
172
|
+
pos: Callable | None = None,
|
173
|
+
) -> int:
|
174
|
+
"""Adds a chain of specific length to a node either as a successor or as a predecessor.
|
90
175
|
If it is placed on top of a tree all the nodes will move timepoints #length down.
|
91
176
|
|
92
|
-
|
93
|
-
|
94
|
-
|
95
|
-
|
96
|
-
|
97
|
-
|
98
|
-
|
99
|
-
|
177
|
+
Parameters
|
178
|
+
----------
|
179
|
+
node : int
|
180
|
+
Id of the successor (predecessor if `downstream==False`)
|
181
|
+
length : int
|
182
|
+
The length of the new chain.
|
183
|
+
downstream : bool, default=True
|
184
|
+
If `True` will create a chain that goes forwards in time otherwise backwards.
|
185
|
+
pos : np.ndarray, optional
|
186
|
+
The new position of the chain. Defaults to None.
|
187
|
+
|
188
|
+
Returns
|
189
|
+
-------
|
190
|
+
int
|
191
|
+
Id of the first node of the sublineage.
|
100
192
|
"""
|
101
193
|
if length == 0:
|
102
|
-
return
|
103
|
-
if
|
104
|
-
raise
|
105
|
-
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
|
110
|
-
new_times = {
|
111
|
-
node: self.time[node] + length for node in nodes_to_move
|
112
|
-
}
|
113
|
-
for node in nodes_to_move:
|
114
|
-
old_time = self.time[node]
|
115
|
-
self.time_nodes[old_time].remove(node)
|
116
|
-
self.time_nodes.setdefault(old_time + length, set()).add(
|
117
|
-
node
|
118
|
-
)
|
119
|
-
self.time.update(new_times)
|
120
|
-
for t in range(length - 1, -1, -1):
|
121
|
-
_next = self.add_node(
|
122
|
-
time + t,
|
123
|
-
succ=pred,
|
124
|
-
pos=self.pos[original],
|
125
|
-
reverse=True,
|
126
|
-
)
|
127
|
-
pred = _next
|
128
|
-
else:
|
129
|
-
for t in range(length):
|
130
|
-
_next = self.add_node(
|
131
|
-
time - t,
|
132
|
-
succ=pred,
|
133
|
-
pos=self.pos[original],
|
134
|
-
reverse=True,
|
135
|
-
)
|
136
|
-
pred = _next
|
194
|
+
return node
|
195
|
+
if length < 1:
|
196
|
+
raise ValueError("Length cannot be <1")
|
197
|
+
if downstream:
|
198
|
+
for _ in range(int(length)):
|
199
|
+
old_node = node
|
200
|
+
node = self._add_node(pred=[old_node])
|
201
|
+
self._time[node] = self._time[old_node] + 1
|
137
202
|
else:
|
138
|
-
|
139
|
-
|
140
|
-
|
141
|
-
|
142
|
-
|
143
|
-
reverse=False,
|
203
|
+
if self._predecessor[node]:
|
204
|
+
raise Warning("The node already has a predecessor.")
|
205
|
+
if self._time[node] - length < self.t_b:
|
206
|
+
raise Warning(
|
207
|
+
"A node cannot created outside the lower bound of the dataset. (It is possible to change it by lT.t_b = int(...))"
|
144
208
|
)
|
145
|
-
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
|
150
|
-
|
151
|
-
|
152
|
-
|
153
|
-
|
154
|
-
|
155
|
-
|
156
|
-
|
157
|
-
|
158
|
-
|
159
|
-
|
160
|
-
|
161
|
-
|
162
|
-
|
163
|
-
|
164
|
-
|
165
|
-
"""
|
166
|
-
|
167
|
-
|
168
|
-
|
169
|
-
|
170
|
-
|
171
|
-
|
172
|
-
|
173
|
-
new_tr = self.add_branch(new_lT, len(cycle), move_timepoints=False)
|
174
|
-
self.roots.add(new_tr)
|
175
|
-
self.labels[new_tr] = f"R-Split {label_of_root}"
|
176
|
-
return new_tr
|
177
|
-
else:
|
178
|
-
raise Warning("No division of the branch")
|
179
|
-
|
180
|
-
def fuse_lineage_tree(
|
181
|
-
self,
|
182
|
-
l1_root: int,
|
183
|
-
l2_root: int,
|
184
|
-
length_l1: int = 0,
|
185
|
-
length_l2: int = 0,
|
186
|
-
length: int = 1,
|
187
|
-
):
|
188
|
-
"""Fuses 2 lineages from the lineagetree object. The 2 lineages that are to be fused can have a longer
|
189
|
-
first node and the node of the resulting lineage can also be longer.
|
190
|
-
|
191
|
-
Args:
|
192
|
-
l1_root (int): Id of the first root
|
193
|
-
l2_root (int): Id of the second root
|
194
|
-
length_l1 (int, optional): The length of the branch that will be added on top of the first lineage. Defaults to 0, which means only one node will be added.
|
195
|
-
length_l2 (int, optional): The length of the branch that will be added on top of the second lineage. Defaults to 0, which means only one node will be added.
|
196
|
-
length (int, optional): The length of the branch that will be added on top of the resulting lineage. Defaults to 1.
|
197
|
-
|
198
|
-
Returns:
|
199
|
-
int: The id of the root of the new lineage.
|
200
|
-
"""
|
201
|
-
if self.predecessor.get(l1_root) or self.predecessor.get(l2_root):
|
202
|
-
raise ValueError("Please select 2 roots.")
|
203
|
-
if self.time[l1_root] != self.time[l2_root]:
|
204
|
-
warnings.warn(
|
205
|
-
"Using lineagetrees that do not exist in the same timepoint. The operation will continue"
|
206
|
-
)
|
207
|
-
new_root1 = self.add_branch(l1_root, length_l1)
|
208
|
-
new_root2 = self.add_branch(l2_root, length_l2)
|
209
|
-
next_root1 = self[new_root1][0]
|
210
|
-
self.remove_nodes(new_root1)
|
211
|
-
self.successor[new_root2].append(next_root1)
|
212
|
-
self.predecessor[next_root1] = [new_root2]
|
213
|
-
new_branch = self.add_branch(
|
214
|
-
new_root2,
|
215
|
-
length - 1,
|
216
|
-
)
|
217
|
-
self.labels[new_branch] = f"Fusion of {new_root1} and {new_root2}"
|
218
|
-
return new_branch
|
219
|
-
|
220
|
-
def copy_lineage(self, root):
|
221
|
-
"""
|
222
|
-
Copies the structure of a tree and makes a new with new nodes.
|
223
|
-
Warning does not take into account the predecessor of the root node.
|
224
|
-
|
225
|
-
Args:
|
226
|
-
root (int): The root of the tree to be copied
|
209
|
+
for _ in range(int(length)):
|
210
|
+
old_node = node
|
211
|
+
node = self._add_node(succ=[old_node])
|
212
|
+
self._time[node] = self._time[old_node] - 1
|
213
|
+
return node
|
214
|
+
|
215
|
+
@modifier
|
216
|
+
def add_root(self, t: int, pos: list | None = None) -> int:
|
217
|
+
"""Adds a root to a specific timepoint.
|
218
|
+
|
219
|
+
Parameters
|
220
|
+
----------
|
221
|
+
t :int
|
222
|
+
The timepoint the node is going to be added.
|
223
|
+
pos : list
|
224
|
+
The position of the new node.
|
225
|
+
Returns
|
226
|
+
-------
|
227
|
+
int
|
228
|
+
The id of the new root.
|
229
|
+
"""
|
230
|
+
C_next = self.get_next_id()
|
231
|
+
self._successor[C_next] = ()
|
232
|
+
self._predecessor[C_next] = ()
|
233
|
+
self._time[C_next] = t
|
234
|
+
self.pos[C_next] = pos if isinstance(pos, list) else []
|
235
|
+
self._changed_roots = True
|
236
|
+
return C_next
|
227
237
|
|
228
|
-
|
229
|
-
int: The root of the new tree.
|
230
|
-
"""
|
231
|
-
new_nodes = {
|
232
|
-
old_node: self.get_next_id()
|
233
|
-
for old_node in self.get_sub_tree(root)
|
234
|
-
}
|
235
|
-
self.nodes.update(new_nodes.values())
|
236
|
-
for old_node, new_node in new_nodes.items():
|
237
|
-
self.time[new_node] = self.time[old_node]
|
238
|
-
succ = self.successor.get(old_node)
|
239
|
-
if succ:
|
240
|
-
self.successor[new_node] = [new_nodes[n] for n in succ]
|
241
|
-
pred = self.predecessor.get(old_node)
|
242
|
-
if pred:
|
243
|
-
self.predecessor[new_node] = [new_nodes[n] for n in pred]
|
244
|
-
self.pos[new_node] = self.pos[old_node] + 0.5
|
245
|
-
self.time_nodes[self.time[old_node]].add(new_nodes[old_node])
|
246
|
-
new_root = new_nodes[root]
|
247
|
-
self.labels[new_root] = f"Copy of {root}"
|
248
|
-
if self.time[new_root] == 0:
|
249
|
-
self.roots.add(new_root)
|
250
|
-
return new_root
|
251
|
-
|
252
|
-
def add_node(
|
238
|
+
def _add_node(
|
253
239
|
self,
|
254
|
-
|
255
|
-
|
256
|
-
pos: np.ndarray = None,
|
257
|
-
nid: int = None,
|
258
|
-
reverse: bool = False,
|
240
|
+
succ: list | None = None,
|
241
|
+
pred: list | None = None,
|
242
|
+
pos: np.ndarray | None = None,
|
243
|
+
nid: int | None = None,
|
259
244
|
) -> int:
|
260
|
-
"""Adds a node to the
|
261
|
-
|
262
|
-
|
263
|
-
|
264
|
-
|
265
|
-
|
266
|
-
|
267
|
-
|
268
|
-
|
269
|
-
|
270
|
-
|
271
|
-
|
272
|
-
|
273
|
-
|
274
|
-
|
245
|
+
"""Adds a node to the LineageTree object that is either a successor or a predecessor of another node.
|
246
|
+
Does not handle time! You cannot enter both a successor and a predecessor.
|
247
|
+
|
248
|
+
Parameters
|
249
|
+
----------
|
250
|
+
succ : list
|
251
|
+
list of ids of the nodes the new node is a successor to
|
252
|
+
pred : list
|
253
|
+
list of ids of the nodes the new node is a predecessor to
|
254
|
+
pos : np.ndarray, optional
|
255
|
+
position of the new node
|
256
|
+
nid : int, optional
|
257
|
+
id value of the new node, to be used carefully,
|
258
|
+
if None is provided the new id is automatically computed.
|
259
|
+
|
260
|
+
Returns
|
261
|
+
-------
|
262
|
+
int
|
263
|
+
id of the new node.
|
264
|
+
"""
|
265
|
+
if not succ and not pred:
|
266
|
+
raise Warning(
|
267
|
+
"Please enter a successor or a predecessor, otherwise use the add_roots() function."
|
268
|
+
)
|
275
269
|
C_next = self.get_next_id() if nid is None else nid
|
276
|
-
|
277
|
-
|
278
|
-
|
279
|
-
|
280
|
-
|
281
|
-
self.
|
282
|
-
|
283
|
-
|
284
|
-
|
285
|
-
|
270
|
+
if succ:
|
271
|
+
self._successor[C_next] = succ
|
272
|
+
for suc in succ:
|
273
|
+
self._predecessor[suc] = (C_next,)
|
274
|
+
else:
|
275
|
+
self._successor[C_next] = ()
|
276
|
+
if pred:
|
277
|
+
self._predecessor[C_next] = pred
|
278
|
+
self._successor[pred[0]] = self._successor.setdefault(
|
279
|
+
pred[0], ()
|
280
|
+
) + (C_next,)
|
281
|
+
else:
|
282
|
+
self._predecessor[C_next] = ()
|
283
|
+
if isinstance(pos, list):
|
284
|
+
self.pos[C_next] = pos
|
286
285
|
return C_next
|
287
286
|
|
288
|
-
|
287
|
+
@modifier
|
288
|
+
def remove_nodes(self, group: int | set | list) -> None:
|
289
289
|
"""Removes a group of nodes from the LineageTree
|
290
290
|
|
291
|
-
|
292
|
-
|
291
|
+
Parameters
|
292
|
+
----------
|
293
|
+
group : set of int or list of int or int
|
294
|
+
One or more nodes that are to be removed.
|
293
295
|
"""
|
294
|
-
if isinstance(group, int):
|
296
|
+
if isinstance(group, int | float):
|
295
297
|
group = {group}
|
296
298
|
if isinstance(group, list):
|
297
299
|
group = set(group)
|
298
|
-
group =
|
299
|
-
self.nodes.difference_update(group)
|
300
|
-
times = {self.time.pop(n) for n in group}
|
301
|
-
for t in times:
|
302
|
-
self.time_nodes[t] = set(self.time_nodes[t]).difference(group)
|
300
|
+
group = self.nodes.intersection(group)
|
303
301
|
for node in group:
|
304
|
-
self.
|
305
|
-
|
306
|
-
|
307
|
-
|
308
|
-
|
309
|
-
|
310
|
-
|
311
|
-
|
312
|
-
|
313
|
-
|
314
|
-
self.
|
315
|
-
|
316
|
-
|
317
|
-
|
318
|
-
|
319
|
-
|
320
|
-
|
321
|
-
|
322
|
-
|
323
|
-
|
324
|
-
node (int): Any node of the branch to be modified/
|
325
|
-
new_length (int): The new length of the tree.
|
326
|
-
"""
|
327
|
-
if new_length <= 1:
|
328
|
-
warnings.warn("New length should be more than 1", stacklevel=2)
|
329
|
-
return None
|
330
|
-
cycle = self.get_cycle(node)
|
331
|
-
length = len(cycle)
|
332
|
-
successors = self.successor.get(cycle[-1])
|
333
|
-
if length == 1 and new_length != 1:
|
334
|
-
pred = self.predecessor.pop(node, None)
|
335
|
-
new_node = self.add_branch(
|
336
|
-
node,
|
337
|
-
length=new_length - 1,
|
338
|
-
move_timepoints=True,
|
339
|
-
reverse=False,
|
340
|
-
)
|
341
|
-
if pred:
|
342
|
-
self.successor[pred[0]].remove(node)
|
343
|
-
self.successor[pred[0]].append(new_node)
|
344
|
-
elif self.leaves.intersection(cycle) and new_length < length:
|
345
|
-
self.remove_nodes(cycle[new_length:])
|
346
|
-
elif new_length < length:
|
347
|
-
to_remove = length - new_length
|
348
|
-
last_cell = cycle[new_length - 1]
|
349
|
-
subtree = self.get_sub_tree(cycle[-1])[1:]
|
350
|
-
self.remove_nodes(cycle[new_length:])
|
351
|
-
self.successor[last_cell] = successors
|
352
|
-
if successors:
|
353
|
-
for succ in successors:
|
354
|
-
self.predecessor[succ] = [last_cell]
|
355
|
-
for node in subtree:
|
356
|
-
if node not in cycle[new_length - 1 :]:
|
357
|
-
old_time = self.time[node]
|
358
|
-
self.time[node] = old_time - to_remove
|
359
|
-
self.time_nodes[old_time].remove(node)
|
360
|
-
self.time_nodes.setdefault(
|
361
|
-
old_time - to_remove, set()
|
362
|
-
).add(node)
|
363
|
-
elif length < new_length:
|
364
|
-
to_add = new_length - length
|
365
|
-
last_cell = cycle[-1]
|
366
|
-
self.successor.pop(cycle[-2])
|
367
|
-
self.predecessor.pop(last_cell)
|
368
|
-
succ = self.add_branch(
|
369
|
-
last_cell, length=to_add, move_timepoints=True, reverse=False
|
370
|
-
)
|
371
|
-
self.predecessor[succ] = [cycle[-2]]
|
372
|
-
self.successor[cycle[-2]] = [succ]
|
373
|
-
self.time[last_cell] = (
|
374
|
-
self.time[self.predecessor[last_cell][0]] + 1
|
375
|
-
)
|
376
|
-
else:
|
377
|
-
return None
|
378
|
-
|
379
|
-
@property
|
380
|
-
def time_resolution(self):
|
381
|
-
if not hasattr(self, "_time_resolution"):
|
382
|
-
self.time_resolution = 1
|
383
|
-
return self._time_resolution / 10
|
384
|
-
|
385
|
-
@time_resolution.setter
|
386
|
-
def time_resolution(self, time_resolution):
|
387
|
-
if time_resolution is not None:
|
388
|
-
self._time_resolution = int(time_resolution * 10)
|
389
|
-
else:
|
390
|
-
warnings.warn("Time resolution set to default 0", stacklevel=2)
|
391
|
-
self._time_resolution = 10
|
392
|
-
|
393
|
-
@property
|
394
|
-
def depth(self):
|
395
|
-
if not hasattr(self, "_depth"):
|
396
|
-
self._depth = {}
|
397
|
-
for leaf in self.leaves:
|
398
|
-
self._depth[leaf] = 1
|
399
|
-
while leaf in self.predecessor:
|
400
|
-
parent = self.predecessor[leaf][0]
|
401
|
-
current_depth = self._depth.get(parent, 0)
|
402
|
-
self._depth[parent] = max(
|
403
|
-
self._depth[leaf] + 1, current_depth
|
404
|
-
)
|
405
|
-
leaf = parent
|
406
|
-
for root in self.roots - set(self._depth):
|
407
|
-
self._depth[root] = 1
|
408
|
-
return self._depth
|
302
|
+
for attr in self.__dict__:
|
303
|
+
attr_value = self.__getattribute__(attr)
|
304
|
+
if isinstance(attr_value, dict) and attr not in [
|
305
|
+
"successor",
|
306
|
+
"predecessor",
|
307
|
+
"_successor",
|
308
|
+
"_predecessor",
|
309
|
+
"_time",
|
310
|
+
]:
|
311
|
+
attr_value.pop(node, ())
|
312
|
+
if self._predecessor.get(node):
|
313
|
+
self._successor[self._predecessor[node][0]] = tuple(
|
314
|
+
set(
|
315
|
+
self._successor[self._predecessor[node][0]]
|
316
|
+
).difference(group)
|
317
|
+
)
|
318
|
+
for p_node in self._successor.get(node, []):
|
319
|
+
self._predecessor[p_node] = ()
|
320
|
+
self._predecessor.pop(node, ())
|
321
|
+
self._successor.pop(node, ())
|
409
322
|
|
410
323
|
@property
|
411
|
-
def
|
412
|
-
|
324
|
+
def successor(self) -> MappingProxyType[int, tuple[int]]:
|
325
|
+
"""The successor of the tree."""
|
326
|
+
if not hasattr(self, "_protected_successor"):
|
327
|
+
self._protected_successor = MappingProxyType(self._successor)
|
328
|
+
return self._protected_successor
|
413
329
|
|
414
330
|
@property
|
415
|
-
def
|
416
|
-
|
331
|
+
def predecessor(self) -> MappingProxyType[int, tuple[int]]:
|
332
|
+
"""The predecessor of the tree."""
|
333
|
+
if not hasattr(self, "_protected_predecessor"):
|
334
|
+
self._protected_predecessor = MappingProxyType(self._predecessor)
|
335
|
+
return self._protected_predecessor
|
417
336
|
|
418
337
|
@property
|
419
|
-
def
|
420
|
-
|
338
|
+
def time(self) -> MappingProxyType[int, int]:
|
339
|
+
"""The time of the tree."""
|
340
|
+
if not hasattr(self, "_protected_time"):
|
341
|
+
self._protected_time = MappingProxyType(self._time)
|
342
|
+
return self._protected_time
|
343
|
+
|
344
|
+
@dynamic_property
|
345
|
+
def t_b(self) -> int:
|
346
|
+
"""The first timepoint of the tree."""
|
347
|
+
return min(self._time.values())
|
348
|
+
|
349
|
+
@dynamic_property
|
350
|
+
def t_e(self) -> int:
|
351
|
+
"""The last timepoint of the tree."""
|
352
|
+
return max(self._time.values())
|
353
|
+
|
354
|
+
@dynamic_property
|
355
|
+
def nodes(self) -> frozenset[int]:
|
356
|
+
"""Nodes of the tree"""
|
357
|
+
return frozenset(self._successor.keys())
|
358
|
+
|
359
|
+
@dynamic_property
|
360
|
+
def depth(self) -> dict[int, int]:
|
361
|
+
"""The depth of each node in the tree."""
|
362
|
+
_depth = {}
|
363
|
+
for leaf in self.leaves:
|
364
|
+
_depth[leaf] = 1
|
365
|
+
while leaf in self._predecessor and self._predecessor[leaf]:
|
366
|
+
parent = self._predecessor[leaf][0]
|
367
|
+
current_depth = _depth.get(parent, 0)
|
368
|
+
_depth[parent] = max(_depth[leaf] + 1, current_depth)
|
369
|
+
leaf = parent
|
370
|
+
for root in self.roots - set(_depth):
|
371
|
+
_depth[root] = 1
|
372
|
+
return _depth
|
373
|
+
|
374
|
+
@dynamic_property
|
375
|
+
def roots(self) -> frozenset[int]:
|
376
|
+
"""Set of roots of the tree"""
|
377
|
+
return frozenset({s for s, p in self._predecessor.items() if p == ()})
|
378
|
+
|
379
|
+
@dynamic_property
|
380
|
+
def leaves(self) -> frozenset[int]:
|
381
|
+
"""Set of leaves"""
|
382
|
+
return frozenset({p for p, s in self._successor.items() if s == ()})
|
383
|
+
|
384
|
+
@dynamic_property
|
385
|
+
def edges(self) -> tuple[tuple[int, int]]:
|
386
|
+
"""Set of edges"""
|
387
|
+
return tuple((p, si) for p, s in self._successor.items() for si in s)
|
421
388
|
|
422
389
|
@property
|
423
|
-
def labels(self):
|
390
|
+
def labels(self) -> dict[int, str]:
|
391
|
+
"""The labels of the nodes."""
|
424
392
|
if not hasattr(self, "_labels"):
|
425
|
-
if hasattr(self, "
|
393
|
+
if hasattr(self, "node_name"):
|
426
394
|
self._labels = {
|
427
|
-
i: self.
|
395
|
+
i: self.node_name.get(i, "Unlabeled") for i in self.roots
|
428
396
|
}
|
429
397
|
else:
|
430
398
|
self._labels = {
|
431
399
|
root: "Unlabeled"
|
432
400
|
for root in self.roots
|
433
401
|
for leaf in self.find_leaves(root)
|
434
|
-
if abs(self.
|
402
|
+
if abs(self._time[leaf] - self._time[root])
|
435
403
|
>= abs(self.t_e - self.t_b) / 4
|
436
404
|
}
|
437
405
|
return self._labels
|
438
406
|
|
439
|
-
|
440
|
-
|
441
|
-
|
442
|
-
|
443
|
-
|
444
|
-
f.write("define POINT %d\n" % ((length) * nb_points))
|
445
|
-
f.write("Parameters {\n")
|
446
|
-
f.write('\tContentType "HxSpatialGraph"\n')
|
447
|
-
f.write("}\n")
|
448
|
-
|
449
|
-
f.write("VERTEX { float[3] VertexCoordinates } @1\n")
|
450
|
-
f.write("EDGE { int[2] EdgeConnectivity } @2\n")
|
451
|
-
f.write("EDGE { int NumEdgePoints } @3\n")
|
452
|
-
f.write("POINT { float[3] EdgePointCoordinates } @4\n")
|
453
|
-
f.write("VERTEX { float Vcolor } @5\n")
|
454
|
-
f.write("VERTEX { int Vbool } @6\n")
|
455
|
-
f.write("EDGE { float Ecolor } @7\n")
|
456
|
-
f.write("VERTEX { int Vbool2 } @8\n")
|
457
|
-
|
458
|
-
def write_to_am(
|
459
|
-
self,
|
460
|
-
path_format: str,
|
461
|
-
t_b: int = None,
|
462
|
-
t_e: int = None,
|
463
|
-
length: int = 5,
|
464
|
-
manual_labels: dict = None,
|
465
|
-
default_label: int = 5,
|
466
|
-
new_pos: np.ndarray = None,
|
467
|
-
):
|
468
|
-
"""Writes a lineageTree into an Amira readable data (.am format).
|
469
|
-
|
470
|
-
Args:
|
471
|
-
path_format (str): path to the output. It should contain 1 %03d where the time step will be entered
|
472
|
-
t_b (int): first time point to write (if None, min(LT.to_take_time) is taken)
|
473
|
-
t_e (int): last time point to write (if None, max(LT.to_take_time) is taken)
|
474
|
-
note, if there is no 'to_take_time' attribute, self.time_nodes
|
475
|
-
is considered instead (historical)
|
476
|
-
length (int): length of the track to print (how many time before).
|
477
|
-
manual_labels ({id: label, }): dictionary that maps cell ids to
|
478
|
-
default_label (int): default value for the manual label
|
479
|
-
new_pos ({id: [x, y, z]}): dictionary that maps a 3D position to a cell ID.
|
480
|
-
if new_pos == None (default) then self.pos is considered.
|
481
|
-
"""
|
482
|
-
if not hasattr(self, "to_take_time"):
|
483
|
-
self.to_take_time = self.time_nodes
|
484
|
-
if t_b is None:
|
485
|
-
t_b = min(self.to_take_time.keys())
|
486
|
-
if t_e is None:
|
487
|
-
t_e = max(self.to_take_time.keys())
|
488
|
-
if new_pos is None:
|
489
|
-
new_pos = self.pos
|
490
|
-
|
491
|
-
if manual_labels is None:
|
492
|
-
manual_labels = {}
|
493
|
-
for t in range(t_b, t_e + 1):
|
494
|
-
with open(path_format % t, "w") as f:
|
495
|
-
nb_points = len(self.to_take_time[t])
|
496
|
-
self._write_header_am(f, nb_points, length)
|
497
|
-
points_v = {}
|
498
|
-
for C in self.to_take_time[t]:
|
499
|
-
C_tmp = C
|
500
|
-
positions = []
|
501
|
-
for _ in range(length):
|
502
|
-
C_tmp = self.predecessor.get(C_tmp, [C_tmp])[0]
|
503
|
-
positions.append(new_pos[C_tmp])
|
504
|
-
points_v[C] = positions
|
505
|
-
|
506
|
-
f.write("@1\n")
|
507
|
-
for C in self.to_take_time[t]:
|
508
|
-
f.write("{:f} {:f} {:f}\n".format(*tuple(points_v[C][0])))
|
509
|
-
f.write("{:f} {:f} {:f}\n".format(*tuple(points_v[C][-1])))
|
510
|
-
|
511
|
-
f.write("@2\n")
|
512
|
-
for i, _ in enumerate(self.to_take_time[t]):
|
513
|
-
f.write("%d %d\n" % (2 * i, 2 * i + 1))
|
514
|
-
|
515
|
-
f.write("@3\n")
|
516
|
-
for _ in self.to_take_time[t]:
|
517
|
-
f.write("%d\n" % (length))
|
518
|
-
|
519
|
-
f.write("@4\n")
|
520
|
-
for C in self.to_take_time[t]:
|
521
|
-
for p in points_v[C]:
|
522
|
-
f.write("{:f} {:f} {:f}\n".format(*tuple(p)))
|
523
|
-
|
524
|
-
f.write("@5\n")
|
525
|
-
for C in self.to_take_time[t]:
|
526
|
-
f.write(f"{manual_labels.get(C, default_label):f}\n")
|
527
|
-
f.write(f"{0:f}\n")
|
528
|
-
|
529
|
-
f.write("@6\n")
|
530
|
-
for C in self.to_take_time[t]:
|
531
|
-
f.write(
|
532
|
-
"%d\n"
|
533
|
-
% (
|
534
|
-
int(
|
535
|
-
manual_labels.get(C, default_label)
|
536
|
-
!= default_label
|
537
|
-
)
|
538
|
-
)
|
539
|
-
)
|
540
|
-
f.write("%d\n" % (0))
|
541
|
-
|
542
|
-
f.write("@7\n")
|
543
|
-
for C in self.to_take_time[t]:
|
544
|
-
f.write(
|
545
|
-
f"{np.linalg.norm(points_v[C][0] - points_v[C][-1]):f}\n"
|
546
|
-
)
|
547
|
-
|
548
|
-
f.write("@8\n")
|
549
|
-
for _ in self.to_take_time[t]:
|
550
|
-
f.write("%d\n" % (1))
|
551
|
-
f.write("%d\n" % (0))
|
552
|
-
f.close()
|
407
|
+
@property
|
408
|
+
def time_resolution(self) -> float:
|
409
|
+
if not hasattr(self, "_time_resolution"):
|
410
|
+
self._time_resolution = 0
|
411
|
+
return self._time_resolution / 10
|
553
412
|
|
554
|
-
|
555
|
-
|
413
|
+
@time_resolution.setter
|
414
|
+
def time_resolution(self, time_resolution) -> None:
|
415
|
+
if time_resolution is not None and time_resolution > 0:
|
416
|
+
self._time_resolution = int(time_resolution * 10)
|
417
|
+
else:
|
418
|
+
warnings.warn("Time resolution set to default 0", stacklevel=2)
|
419
|
+
self._time_resolution = 0
|
420
|
+
|
421
|
+
def __setstate__(self, state):
|
422
|
+
if "_successor" not in state:
|
423
|
+
state["_successor"] = state["successor"]
|
424
|
+
if "_predecessor" not in state:
|
425
|
+
state["_predecessor"] = state["predecessor"]
|
426
|
+
if "_time" not in state:
|
427
|
+
state["_time"] = state["time"]
|
428
|
+
self.__dict__.update(state)
|
429
|
+
|
430
|
+
def _get_height(self, c: int, done: dict) -> float:
|
431
|
+
"""Recursively computes the height of a node within a tree times a space factor.
|
556
432
|
This function is specific to the function write_to_svg.
|
557
433
|
|
558
|
-
|
559
|
-
|
560
|
-
|
561
|
-
|
562
|
-
|
434
|
+
Parameters
|
435
|
+
----------
|
436
|
+
c : int
|
437
|
+
id of the node in the lineage tree whose height will be computed
|
438
|
+
done : dict mapping int to list of two int
|
439
|
+
a dictionary that maps a node id to its vertical and horizontal position
|
440
|
+
|
441
|
+
Returns
|
442
|
+
-------
|
443
|
+
float
|
444
|
+
the height of the node `c`
|
563
445
|
"""
|
564
446
|
if c in done:
|
565
447
|
return done[c][0]
|
566
448
|
else:
|
567
449
|
P = np.mean(
|
568
|
-
[self._get_height(di, done) for di in self.
|
450
|
+
[self._get_height(di, done) for di in self._successor[c]]
|
569
451
|
)
|
570
|
-
done[c] = [P, self.vert_space_factor * self.
|
452
|
+
done[c] = [P, self.vert_space_factor * self._time[c]]
|
571
453
|
return P
|
572
454
|
|
573
455
|
def write_to_svg(
|
574
456
|
self,
|
575
457
|
file_name: str,
|
576
|
-
roots: list = None,
|
458
|
+
roots: list | None = None,
|
577
459
|
draw_nodes: bool = True,
|
578
460
|
draw_edges: bool = True,
|
579
|
-
order_key:
|
461
|
+
order_key: Callable | None = None,
|
580
462
|
vert_space_factor: float = 0.5,
|
581
463
|
horizontal_space: float = 1,
|
582
|
-
node_size:
|
583
|
-
stroke_width:
|
464
|
+
node_size: Callable | str | None = None,
|
465
|
+
stroke_width: Callable | None = None,
|
584
466
|
factor: float = 1.0,
|
585
|
-
node_color:
|
586
|
-
stroke_color:
|
587
|
-
positions: dict = None,
|
588
|
-
node_color_map:
|
589
|
-
|
590
|
-
):
|
591
|
-
##### remove background? default True background value? default 1
|
592
|
-
|
467
|
+
node_color: Callable | str | None = None,
|
468
|
+
stroke_color: Callable | None = None,
|
469
|
+
positions: dict | None = None,
|
470
|
+
node_color_map: Callable | str | None = None,
|
471
|
+
) -> None:
|
593
472
|
"""Writes the lineage tree to an SVG file.
|
594
473
|
Node and edges coloring and size can be provided.
|
595
474
|
|
596
|
-
|
597
|
-
|
598
|
-
|
599
|
-
|
600
|
-
|
601
|
-
|
602
|
-
|
603
|
-
|
604
|
-
|
605
|
-
|
606
|
-
|
607
|
-
|
608
|
-
|
609
|
-
|
610
|
-
|
611
|
-
|
612
|
-
|
613
|
-
|
614
|
-
|
615
|
-
|
616
|
-
|
617
|
-
|
618
|
-
|
619
|
-
|
475
|
+
Parameters
|
476
|
+
----------
|
477
|
+
file_name : str
|
478
|
+
filesystem filename valid for `open()`
|
479
|
+
roots : list of int, defaults to `self.roots`
|
480
|
+
list of node ids to be drawn. If `None` or not provided all the nodes will be drawn. Default `None`
|
481
|
+
draw_nodes : bool, default True
|
482
|
+
whether to print the nodes or not
|
483
|
+
draw_edges : bool, default True
|
484
|
+
whether to print the edges or not
|
485
|
+
order_key : Callable, optional
|
486
|
+
function that would work for the attribute `key=` for the `sort`/`sorted` function
|
487
|
+
vert_space_factor : float, default=0.5
|
488
|
+
the vertical position of a node is its time. `vert_space_factor` is a
|
489
|
+
multiplier to space more or less nodes in time
|
490
|
+
horizontal_space : float, default=1
|
491
|
+
space between two consecutive nodes
|
492
|
+
node_size : Callable or str, optional
|
493
|
+
a function that maps a node id to a `float` value that will determine the
|
494
|
+
radius of the node. The default function returns the constant value `vert_space_factor / 2.1`.
|
495
|
+
If a string is given instead and it is a property of the tree,
|
496
|
+
then the size will be mapped according to the property
|
497
|
+
stroke_width : Callable, optional
|
498
|
+
a function that maps a node id to a `float` value that will determine the
|
499
|
+
width of the daughter edge. The default function returns the constant value `vert_space_factor / 2.1`
|
500
|
+
factor : float, default=1.0
|
501
|
+
scaling factor for nodes positions, default 1
|
502
|
+
node_color : Callable or str, optional
|
503
|
+
a function that maps a node id to a triplet between 0 and 255.
|
504
|
+
The triplet will determine the color of the node. If a string is given instead and it is a property
|
505
|
+
of the tree, then the color will be mapped according to the property
|
506
|
+
node_color_map : Callable or str, optional
|
507
|
+
the name of the colormap to use to color the nodes, or a colormap function
|
508
|
+
stroke_color : Callable, optional
|
509
|
+
a function that maps a node id to a triplet between 0 and 255.
|
510
|
+
The triplet will determine the color of the stroke of the inward edge.
|
511
|
+
positions : dict mapping int to list of two float, optional
|
512
|
+
dictionary that maps a node id to a 2D position.
|
513
|
+
Default `None`. If provided it will be used to position the nodes.
|
620
514
|
"""
|
621
|
-
import svgwrite
|
622
515
|
|
623
516
|
def normalize_values(v, nodes, _range, shift, mult):
|
624
517
|
min_ = np.percentile(v, 1)
|
625
518
|
max_ = np.percentile(v, 99)
|
626
519
|
values = _range * ((v - min_) / (max_ - min_)) + shift
|
627
|
-
values_dict_nodes = dict(zip(nodes, values))
|
520
|
+
values_dict_nodes = dict(zip(nodes, values, strict=True))
|
628
521
|
return lambda x: values_dict_nodes[x] * mult
|
629
522
|
|
630
523
|
if roots is None:
|
631
524
|
roots = self.roots
|
632
525
|
if hasattr(self, "image_label"):
|
633
|
-
roots = [
|
526
|
+
roots = [node for node in roots if self.image_label[node] != 1]
|
634
527
|
|
635
528
|
if node_size is None:
|
636
529
|
|
637
530
|
def node_size(x):
|
638
531
|
return vert_space_factor / 2.1
|
639
532
|
|
640
|
-
|
641
|
-
values = np.array(
|
533
|
+
else:
|
534
|
+
values = np.array(
|
535
|
+
[self._successor[node_size][c] for c in self.nodes]
|
536
|
+
)
|
642
537
|
node_size = normalize_values(
|
643
538
|
values, self.nodes, 0.5, 0.5, vert_space_factor / 2.1
|
644
539
|
)
|
@@ -653,18 +548,19 @@ class lineageTree(lineageTreeLoaders):
|
|
653
548
|
return 0, 0, 0
|
654
549
|
|
655
550
|
elif isinstance(node_color, str) and node_color in self.__dict__:
|
656
|
-
|
657
|
-
from matplotlib import colormaps
|
551
|
+
from matplotlib import colormaps
|
658
552
|
|
659
|
-
|
660
|
-
|
661
|
-
|
662
|
-
|
663
|
-
values = np.array(
|
553
|
+
if node_color_map in colormaps:
|
554
|
+
cmap = colormaps[node_color_map]
|
555
|
+
else:
|
556
|
+
cmap = colormaps["viridis"]
|
557
|
+
values = np.array(
|
558
|
+
[self._successor[node_color][c] for c in self.nodes]
|
559
|
+
)
|
664
560
|
normed_vals = normalize_values(values, self.nodes, 1, 0, 1)
|
665
561
|
|
666
562
|
def node_color(x):
|
667
|
-
return [k * 255 for k in
|
563
|
+
return [k * 255 for k in cmap(normed_vals(x))[:-1]]
|
668
564
|
|
669
565
|
coloring_edges = stroke_color is not None
|
670
566
|
if not coloring_edges:
|
@@ -673,24 +569,25 @@ class lineageTree(lineageTreeLoaders):
|
|
673
569
|
return 0, 0, 0
|
674
570
|
|
675
571
|
elif isinstance(stroke_color, str) and stroke_color in self.__dict__:
|
676
|
-
|
677
|
-
from matplotlib import colormaps
|
572
|
+
from matplotlib import colormaps
|
678
573
|
|
679
|
-
|
680
|
-
|
681
|
-
|
682
|
-
|
683
|
-
values = np.array(
|
574
|
+
if node_color_map in colormaps:
|
575
|
+
cmap = colormaps[node_color_map]
|
576
|
+
else:
|
577
|
+
cmap = colormaps["viridis"]
|
578
|
+
values = np.array(
|
579
|
+
[self._successor[stroke_color][c] for c in self.nodes]
|
580
|
+
)
|
684
581
|
normed_vals = normalize_values(values, self.nodes, 1, 0, 1)
|
685
582
|
|
686
583
|
def stroke_color(x):
|
687
|
-
return [k * 255 for k in
|
584
|
+
return [k * 255 for k in cmap(normed_vals(x))[:-1]]
|
688
585
|
|
689
586
|
prev_x = 0
|
690
587
|
self.vert_space_factor = vert_space_factor
|
691
588
|
if order_key is not None:
|
692
589
|
roots.sort(key=order_key)
|
693
|
-
|
590
|
+
treated_nodes = []
|
694
591
|
|
695
592
|
pos_given = positions is not None
|
696
593
|
if not pos_given:
|
@@ -701,25 +598,26 @@ class lineageTree(lineageTreeLoaders):
|
|
701
598
|
[0.0, 0.0],
|
702
599
|
]
|
703
600
|
* len(self.nodes),
|
704
|
-
|
601
|
+
strict=True,
|
602
|
+
),
|
705
603
|
)
|
706
604
|
for _i, r in enumerate(roots):
|
707
605
|
r_leaves = []
|
708
606
|
to_do = [r]
|
709
607
|
while len(to_do) != 0:
|
710
608
|
curr = to_do.pop(0)
|
711
|
-
|
712
|
-
if
|
609
|
+
treated_nodes += [curr]
|
610
|
+
if not self._successor[curr]:
|
713
611
|
if order_key is not None:
|
714
|
-
to_do += sorted(self.
|
612
|
+
to_do += sorted(self._successor[curr], key=order_key)
|
715
613
|
else:
|
716
|
-
to_do += self.
|
614
|
+
to_do += self._successor[curr]
|
717
615
|
else:
|
718
616
|
r_leaves += [curr]
|
719
617
|
r_pos = {
|
720
618
|
leave: [
|
721
619
|
prev_x + horizontal_space * (1 + j),
|
722
|
-
self.vert_space_factor * self.
|
620
|
+
self.vert_space_factor * self._time[leave],
|
723
621
|
]
|
724
622
|
for j, leave in enumerate(r_leaves)
|
725
623
|
}
|
@@ -734,12 +632,12 @@ class lineageTree(lineageTreeLoaders):
|
|
734
632
|
size=factor * np.max(list(positions.values()), axis=0),
|
735
633
|
)
|
736
634
|
if draw_edges and not draw_nodes and not coloring_edges:
|
737
|
-
to_do = set(
|
635
|
+
to_do = set(treated_nodes)
|
738
636
|
while len(to_do) > 0:
|
739
637
|
curr = to_do.pop()
|
740
|
-
|
741
|
-
x1, y1 = positions[
|
742
|
-
x2, y2 = positions[
|
638
|
+
c_chain = self.get_chain_of_node(curr)
|
639
|
+
x1, y1 = positions[c_chain[0]]
|
640
|
+
x2, y2 = positions[c_chain[-1]]
|
743
641
|
dwg.add(
|
744
642
|
dwg.line(
|
745
643
|
(factor * x1, factor * y1),
|
@@ -747,7 +645,7 @@ class lineageTree(lineageTreeLoaders):
|
|
747
645
|
stroke=svgwrite.rgb(0, 0, 0),
|
748
646
|
)
|
749
647
|
)
|
750
|
-
for si in self[
|
648
|
+
for si in self._successor[c_chain[-1]]:
|
751
649
|
x3, y3 = positions[si]
|
752
650
|
dwg.add(
|
753
651
|
dwg.line(
|
@@ -756,11 +654,11 @@ class lineageTree(lineageTreeLoaders):
|
|
756
654
|
stroke=svgwrite.rgb(0, 0, 0),
|
757
655
|
)
|
758
656
|
)
|
759
|
-
to_do.difference_update(
|
657
|
+
to_do.difference_update(c_chain)
|
760
658
|
else:
|
761
|
-
for c in
|
659
|
+
for c in treated_nodes:
|
762
660
|
x1, y1 = positions[c]
|
763
|
-
for si in self[c]:
|
661
|
+
for si in self._successor[c]:
|
764
662
|
x2, y2 = positions[si]
|
765
663
|
if draw_edges:
|
766
664
|
dwg.add(
|
@@ -771,7 +669,7 @@ class lineageTree(lineageTreeLoaders):
|
|
771
669
|
stroke_width=svgwrite.pt(stroke_width(si)),
|
772
670
|
)
|
773
671
|
)
|
774
|
-
for c in
|
672
|
+
for c in treated_nodes:
|
775
673
|
x1, y1 = positions[c]
|
776
674
|
if draw_nodes:
|
777
675
|
dwg.add(
|
@@ -788,49 +686,58 @@ class lineageTree(lineageTreeLoaders):
|
|
788
686
|
fname: str,
|
789
687
|
t_min: int = -1,
|
790
688
|
t_max: int = np.inf,
|
791
|
-
nodes_to_use: list = None,
|
689
|
+
nodes_to_use: list[int] | None = None,
|
792
690
|
temporal: bool = True,
|
793
|
-
spatial: str = None,
|
691
|
+
spatial: str | None = None,
|
794
692
|
write_layout: bool = True,
|
795
|
-
node_properties: dict = None,
|
693
|
+
node_properties: dict | None = None,
|
796
694
|
Names: bool = False,
|
797
|
-
):
|
695
|
+
) -> None:
|
798
696
|
"""Write a lineage tree into an understable tulip file.
|
799
697
|
|
800
|
-
|
801
|
-
|
802
|
-
|
803
|
-
|
804
|
-
|
805
|
-
|
806
|
-
|
807
|
-
|
808
|
-
|
809
|
-
|
810
|
-
|
811
|
-
|
812
|
-
|
813
|
-
|
814
|
-
|
815
|
-
|
816
|
-
|
817
|
-
|
818
|
-
|
819
|
-
|
820
|
-
|
821
|
-
|
822
|
-
|
698
|
+
Parameters
|
699
|
+
----------
|
700
|
+
fname : str
|
701
|
+
path to the tulip file to create
|
702
|
+
t_min : int, default=-1
|
703
|
+
minimum time to consider
|
704
|
+
t_max : int, default=np.inf
|
705
|
+
maximum time to consider
|
706
|
+
nodes_to_use : list of int, optional
|
707
|
+
list of nodes to show in the graph,
|
708
|
+
if `None` then self.nodes is used
|
709
|
+
(taking into account `t_min` and `t_max`)
|
710
|
+
temporal : bool, default=True
|
711
|
+
True if the temporal links should be printed
|
712
|
+
spatial : str, optional
|
713
|
+
Build spatial edges from a spatial neighbourhood graph.
|
714
|
+
The graph has to be computed before running this function
|
715
|
+
'ball': neighbours at a given distance,
|
716
|
+
'kn': k-nearest neighbours,
|
717
|
+
'GG': gabriel graph,
|
718
|
+
None: no spatial edges are written.
|
719
|
+
Default None
|
720
|
+
write_layout : bool, default=True
|
721
|
+
write the spatial position as layout if True
|
722
|
+
do not write spatial position otherwise
|
723
|
+
node_properties : dict mapping str to list of dict of properties and its default value, optional
|
724
|
+
a dictionary of properties to write
|
725
|
+
To a key representing the name of the property is
|
726
|
+
paired a dictionary that maps a node id to a property
|
727
|
+
and a default value for this property
|
728
|
+
Names : bool, default=False
|
729
|
+
Only works with ASTEC outputs, True to sort the nodes by their names
|
823
730
|
"""
|
824
731
|
|
825
732
|
def format_names(names_which_matter):
|
826
|
-
"""Return an ensured formated
|
733
|
+
"""Return an ensured formated node names"""
|
827
734
|
tmp = {}
|
828
735
|
for k, v in names_which_matter.items():
|
829
736
|
tmp[k] = (
|
830
737
|
v.split(".")[0][0]
|
831
|
-
+ "
|
738
|
+
+ "{:02d}".format(int(v.split(".")[0][1:]))
|
832
739
|
+ "."
|
833
|
-
+ "
|
740
|
+
+ "{:04d}".format(int(v.split(".")[1][:-1]))
|
834
741
|
+ v.split(".")[1][-1]
|
835
742
|
)
|
836
743
|
return tmp
|
@@ -857,20 +764,20 @@ class lineageTree(lineageTreeLoaders):
|
|
857
764
|
if not nodes_to_use:
|
858
765
|
if t_max != np.inf or t_min > -1:
|
859
766
|
nodes_to_use = [
|
860
|
-
n for n in self.nodes if t_min < self.
|
767
|
+
n for n in self.nodes if t_min < self._time[n] <= t_max
|
861
768
|
]
|
862
769
|
edges_to_use = []
|
863
770
|
if temporal:
|
864
771
|
edges_to_use += [
|
865
772
|
e
|
866
773
|
for e in self.edges
|
867
|
-
if t_min < self.
|
774
|
+
if t_min < self._time[e[0]] < t_max
|
868
775
|
]
|
869
776
|
if spatial:
|
870
777
|
edges_to_use += [
|
871
778
|
e
|
872
779
|
for e in s_edges
|
873
|
-
if t_min < self.
|
780
|
+
if t_min < self._time[e[0]] < t_max
|
874
781
|
]
|
875
782
|
else:
|
876
783
|
nodes_to_use = list(self.nodes)
|
@@ -884,12 +791,12 @@ class lineageTree(lineageTreeLoaders):
|
|
884
791
|
nodes_to_use = set(nodes_to_use)
|
885
792
|
if temporal:
|
886
793
|
for n in nodes_to_use:
|
887
|
-
for d in self.
|
794
|
+
for d in self._successor[n]:
|
888
795
|
if d in nodes_to_use:
|
889
796
|
edges_to_use.append((n, d))
|
890
797
|
if spatial:
|
891
798
|
edges_to_use += [
|
892
|
-
e for e in s_edges if t_min < self.
|
799
|
+
e for e in s_edges if t_min < self._time[e[0]] < t_max
|
893
800
|
]
|
894
801
|
nodes_to_use = set(nodes_to_use)
|
895
802
|
if Names:
|
@@ -907,12 +814,12 @@ class lineageTree(lineageTreeLoaders):
|
|
907
814
|
for k, v in node_properties[Names][0].items():
|
908
815
|
if (
|
909
816
|
len(
|
910
|
-
self.
|
911
|
-
self.
|
817
|
+
self._successor.get(
|
818
|
+
self._predecessor.get(k, [-1])[0], ()
|
912
819
|
)
|
913
820
|
)
|
914
821
|
!= 1
|
915
|
-
or self.
|
822
|
+
or self._time[k] == t_min + 1
|
916
823
|
):
|
917
824
|
tmp_names[k] = v
|
918
825
|
node_properties[Names][0] = tmp_names
|
@@ -942,7 +849,7 @@ class lineageTree(lineageTreeLoaders):
|
|
942
849
|
f.write('\t(default "0" "0")\n')
|
943
850
|
for n in nodes_to_use:
|
944
851
|
f.write(
|
945
|
-
"\t(node " + str(n) + ' "' + str(self.
|
852
|
+
"\t(node " + str(n) + ' "' + str(self._time[n]) + '")\n'
|
946
853
|
)
|
947
854
|
f.write(")\n")
|
948
855
|
|
@@ -989,7 +896,9 @@ class lineageTree(lineageTreeLoaders):
|
|
989
896
|
f.write(")")
|
990
897
|
f.close()
|
991
898
|
|
992
|
-
def to_binary(
|
899
|
+
def to_binary(
|
900
|
+
self, fname: str, starting_points: list[int] | None = None
|
901
|
+
) -> None:
|
993
902
|
"""Writes the lineage tree (a forest) as a binary structure
|
994
903
|
(assuming it is a binary tree, it would not work for *n* ary tree with 2 < *n*).
|
995
904
|
The binary file is composed of 3 sequences of numbers and
|
@@ -1004,36 +913,37 @@ class lineageTree(lineageTreeLoaders):
|
|
1004
913
|
The *time_sequence* is stored as a list of unsigned short (0 -> 2^(8*2)-1)
|
1005
914
|
The *pos_sequence* is stored as a list of double.
|
1006
915
|
|
1007
|
-
|
1008
|
-
|
1009
|
-
|
1010
|
-
|
916
|
+
Parameters
|
917
|
+
----------
|
918
|
+
fname : str
|
919
|
+
name of the binary file
|
920
|
+
starting_points : list of int, optional
|
921
|
+
list of the roots to be written.
|
922
|
+
If `None` (the default), all roots are written.
|
1011
923
|
"""
|
1012
924
|
if starting_points is None:
|
1013
|
-
starting_points =
|
1014
|
-
c for c in self.successor if self.predecessor.get(c, []) == []
|
1015
|
-
]
|
925
|
+
starting_points = list(self.roots)
|
1016
926
|
number_sequence = [-1]
|
1017
927
|
pos_sequence = []
|
1018
928
|
time_sequence = []
|
1019
929
|
for c in starting_points:
|
1020
|
-
time_sequence.append(self.
|
930
|
+
time_sequence.append(self._time.get(c, 0))
|
1021
931
|
to_treat = [c]
|
1022
|
-
while to_treat
|
932
|
+
while to_treat:
|
1023
933
|
curr_c = to_treat.pop()
|
1024
934
|
number_sequence.append(curr_c)
|
1025
935
|
pos_sequence += list(self.pos[curr_c])
|
1026
|
-
if self[curr_c] ==
|
936
|
+
if self._successor[curr_c] == ():
|
1027
937
|
number_sequence.append(-1)
|
1028
|
-
elif len(self.
|
1029
|
-
to_treat += self.
|
938
|
+
elif len(self._successor[curr_c]) == 1:
|
939
|
+
to_treat += self._successor[curr_c]
|
1030
940
|
else:
|
1031
941
|
number_sequence.append(-2)
|
1032
|
-
to_treat += self.
|
942
|
+
to_treat += self._successor[curr_c]
|
1033
943
|
remaining_nodes = set(self.nodes) - set(number_sequence)
|
1034
944
|
|
1035
945
|
for c in remaining_nodes:
|
1036
|
-
time_sequence.append(self.
|
946
|
+
time_sequence.append(self._time.get(c, 0))
|
1037
947
|
number_sequence.append(c)
|
1038
948
|
pos_sequence += list(self.pos[c])
|
1039
949
|
number_sequence.append(-1)
|
@@ -1048,50 +958,270 @@ class lineageTree(lineageTreeLoaders):
|
|
1048
958
|
|
1049
959
|
f.close()
|
1050
960
|
|
1051
|
-
def write(self, fname: str):
|
1052
|
-
"""
|
1053
|
-
Write a lineage tree on disk as an .lT file.
|
961
|
+
def write(self, fname: str) -> None:
|
962
|
+
"""Write a lineage tree on disk as an .lT file.
|
1054
963
|
|
1055
|
-
|
1056
|
-
|
964
|
+
Parameters
|
965
|
+
----------
|
966
|
+
fname : str
|
967
|
+
path to and name of the file to save
|
1057
968
|
"""
|
1058
969
|
if os.path.splitext(fname)[-1] != ".lT":
|
1059
970
|
fname = os.path.extsep.join((fname, "lT"))
|
971
|
+
if hasattr(self, "_protected_predecessor"):
|
972
|
+
del self._protected_predecessor
|
973
|
+
if hasattr(self, "_protected_successor"):
|
974
|
+
del self._protected_successor
|
975
|
+
if hasattr(self, "_protected_time"):
|
976
|
+
del self._protected_time
|
1060
977
|
with open(fname, "bw") as f:
|
1061
978
|
pkl.dump(self, f)
|
1062
979
|
f.close()
|
1063
980
|
|
1064
981
|
@classmethod
|
1065
|
-
def load(clf, fname: str
|
1066
|
-
"""
|
1067
|
-
Loading a lineage tree from a ".lT" file.
|
982
|
+
def load(clf, fname: str):
|
983
|
+
"""Loading a lineage tree from a '.lT' file.
|
1068
984
|
|
1069
|
-
|
1070
|
-
|
985
|
+
Parameters
|
986
|
+
----------
|
987
|
+
fname : str
|
988
|
+
path to and name of the file to read
|
1071
989
|
|
1072
|
-
Returns
|
1073
|
-
|
990
|
+
Returns
|
991
|
+
-------
|
992
|
+
LineageTree
|
993
|
+
loaded file
|
1074
994
|
"""
|
1075
995
|
with open(fname, "br") as f:
|
1076
996
|
lT = pkl.load(f)
|
1077
997
|
f.close()
|
998
|
+
if not hasattr(lT, "__version__") or Version(lT.__version__) < Version(
|
999
|
+
"2.0.0"
|
1000
|
+
):
|
1001
|
+
properties = {
|
1002
|
+
prop_name: prop
|
1003
|
+
for prop_name, prop in lT.__dict__.items()
|
1004
|
+
if isinstance(prop, dict)
|
1005
|
+
and prop_name
|
1006
|
+
not in [
|
1007
|
+
"successor",
|
1008
|
+
"predecessor",
|
1009
|
+
"time",
|
1010
|
+
"_successor",
|
1011
|
+
"_predecessor",
|
1012
|
+
"_time",
|
1013
|
+
"pos",
|
1014
|
+
"labels",
|
1015
|
+
]
|
1016
|
+
+ lineageTree._dynamic_properties
|
1017
|
+
+ lineageTree._protected_dynamic_properties
|
1018
|
+
}
|
1019
|
+
print("_comparisons" in properties)
|
1020
|
+
lT = lineageTree(
|
1021
|
+
successor=lT._successor,
|
1022
|
+
time=lT._time,
|
1023
|
+
pos=lT.pos,
|
1024
|
+
name=lT.name if hasattr(lT, "name") else None,
|
1025
|
+
**properties,
|
1026
|
+
)
|
1078
1027
|
if not hasattr(lT, "time_resolution"):
|
1079
|
-
lT.time_resolution =
|
1028
|
+
lT.time_resolution = 1
|
1029
|
+
|
1080
1030
|
return lT
|
1081
1031
|
|
1082
|
-
def
|
1083
|
-
|
1084
|
-
|
1085
|
-
|
1086
|
-
|
1087
|
-
|
1088
|
-
|
1089
|
-
|
1090
|
-
|
1091
|
-
|
1092
|
-
|
1032
|
+
def get_predecessors(
|
1033
|
+
self,
|
1034
|
+
x: int,
|
1035
|
+
depth: int | None = None,
|
1036
|
+
start_time: int | None = None,
|
1037
|
+
end_time: int | None = None,
|
1038
|
+
) -> list[int]:
|
1039
|
+
"""Computes the predecessors of the node `x` up to
|
1040
|
+
`depth` predecessors or the beginning of the life of `x`.
|
1041
|
+
The ordered list of ids is returned.
|
1042
|
+
|
1043
|
+
Parameters
|
1044
|
+
----------
|
1045
|
+
x : int
|
1046
|
+
id of the node to compute
|
1047
|
+
depth : int
|
1048
|
+
maximum number of predecessors to return
|
1049
|
+
|
1050
|
+
Returns
|
1051
|
+
-------
|
1052
|
+
list of int
|
1053
|
+
list of ids, the last id is `x`
|
1093
1054
|
"""
|
1094
|
-
|
1055
|
+
if not start_time:
|
1056
|
+
start_time = self.t_b
|
1057
|
+
if not end_time:
|
1058
|
+
end_time = self.t_e
|
1059
|
+
unconstrained_chain = [x]
|
1060
|
+
chain = [x] if start_time <= self._time[x] <= end_time else []
|
1061
|
+
acc = 0
|
1062
|
+
while (
|
1063
|
+
acc != depth
|
1064
|
+
and start_time < self._time[unconstrained_chain[0]]
|
1065
|
+
and (
|
1066
|
+
self._predecessor[unconstrained_chain[0]] != ()
|
1067
|
+
and ( # Please dont change very important even if it looks weird.
|
1068
|
+
len(
|
1069
|
+
self._successor[
|
1070
|
+
self._predecessor[unconstrained_chain[0]][0]
|
1071
|
+
]
|
1072
|
+
)
|
1073
|
+
== 1
|
1074
|
+
)
|
1075
|
+
)
|
1076
|
+
):
|
1077
|
+
unconstrained_chain.insert(
|
1078
|
+
0, self._predecessor[unconstrained_chain[0]][0]
|
1079
|
+
)
|
1080
|
+
acc += 1
|
1081
|
+
if start_time <= self._time[unconstrained_chain[0]] <= end_time:
|
1082
|
+
chain.insert(0, unconstrained_chain[0])
|
1083
|
+
|
1084
|
+
return chain
|
1085
|
+
|
1086
|
+
def get_successors(
|
1087
|
+
self, x: int, depth: int | None = None, end_time: int | None = None
|
1088
|
+
) -> list[int]:
|
1089
|
+
"""Computes the successors of the node `x` up to
|
1090
|
+
`depth` successors or the end of the life of `x`.
|
1091
|
+
The ordered list of ids is returned.
|
1092
|
+
|
1093
|
+
Parameters
|
1094
|
+
----------
|
1095
|
+
x : int
|
1096
|
+
id of the node to compute
|
1097
|
+
depth : int, optional
|
1098
|
+
maximum number of successors to return
|
1099
|
+
end_time : int, optional
|
1100
|
+
maximum time to consider
|
1101
|
+
|
1102
|
+
Returns
|
1103
|
+
-------
|
1104
|
+
list of int
|
1105
|
+
list of ids, the first id is `x`
|
1106
|
+
"""
|
1107
|
+
if end_time is None:
|
1108
|
+
end_time = self.t_e
|
1109
|
+
chain = [x]
|
1110
|
+
acc = 0
|
1111
|
+
while (
|
1112
|
+
len(self._successor[chain[-1]]) == 1
|
1113
|
+
and acc != depth
|
1114
|
+
and self._time[chain[-1]] < end_time
|
1115
|
+
):
|
1116
|
+
chain += self._successor[chain[-1]]
|
1117
|
+
acc += 1
|
1118
|
+
|
1119
|
+
return chain
|
1120
|
+
|
1121
|
+
def get_chain_of_node(
|
1122
|
+
self,
|
1123
|
+
x: int,
|
1124
|
+
depth: int | None = None,
|
1125
|
+
depth_pred: int | None = None,
|
1126
|
+
depth_succ: int | None = None,
|
1127
|
+
end_time: int | None = None,
|
1128
|
+
) -> list[int]:
|
1129
|
+
"""Computes the predecessors and successors of the node `x` up to
|
1130
|
+
`depth_pred` predecessors plus `depth_succ` successors.
|
1131
|
+
If the value `depth` is provided and not None,
|
1132
|
+
`depth_pred` and `depth_succ` are overwritten by `depth`.
|
1133
|
+
The ordered list of ids is returned.
|
1134
|
+
If all `depth` are None, the full chain is returned.
|
1135
|
+
|
1136
|
+
Parameters
|
1137
|
+
----------
|
1138
|
+
x : int
|
1139
|
+
id of the node to compute
|
1140
|
+
depth : int, optional
|
1141
|
+
maximum number of predecessors and successors to return
|
1142
|
+
depth_pred : int, optional
|
1143
|
+
maximum number of predecessors to return
|
1144
|
+
depth_succ : int, optional
|
1145
|
+
maximum number of successors to return
|
1146
|
+
|
1147
|
+
Returns
|
1148
|
+
-------
|
1149
|
+
list of int
|
1150
|
+
list of node ids
|
1151
|
+
"""
|
1152
|
+
if end_time is None:
|
1153
|
+
end_time = self.t_e
|
1154
|
+
if depth is not None:
|
1155
|
+
depth_pred = depth_succ = depth
|
1156
|
+
return self.get_predecessors(x, depth_pred, end_time=end_time)[
|
1157
|
+
:-1
|
1158
|
+
] + self.get_successors(x, depth_succ, end_time=end_time)
|
1159
|
+
|
1160
|
+
@dynamic_property
|
1161
|
+
def all_chains(self) -> list[list[int]]:
|
1162
|
+
"""List of all chains in the tree, ordered in depth-first search."""
|
1163
|
+
return self._compute_all_chains()
|
1164
|
+
|
1165
|
+
@dynamic_property
|
1166
|
+
def time_nodes(self):
|
1167
|
+
_time_nodes = {}
|
1168
|
+
for c, t in self._time.items():
|
1169
|
+
_time_nodes.setdefault(t, set()).add(c)
|
1170
|
+
return _time_nodes
|
1171
|
+
|
1172
|
+
def m(self, i, j):
|
1173
|
+
if (i, j) not in self._tmp_parenting:
|
1174
|
+
if i == j: # the distance to the node itself is 0
|
1175
|
+
self._tmp_parenting[(i, j)] = 0
|
1176
|
+
self._parenting[i, j] = self._tmp_parenting[(i, j)]
|
1177
|
+
elif not self._predecessor[
|
1178
|
+
j
|
1179
|
+
]: # j and i are not connected so the distance is inf
|
1180
|
+
self._tmp_parenting[(i, j)] = np.inf
|
1181
|
+
else: # the distance between i and j is the distance between i and pred(j) + 1
|
1182
|
+
self._tmp_parenting[(i, j)] = (
|
1183
|
+
self.m(i, self._predecessor[j][0]) + 1
|
1184
|
+
)
|
1185
|
+
self._parenting[i, j] = self._tmp_parenting[(i, j)]
|
1186
|
+
self._parenting[j, i] = -self._tmp_parenting[(i, j)]
|
1187
|
+
return self._tmp_parenting[(i, j)]
|
1188
|
+
|
1189
|
+
@property
|
1190
|
+
def parenting(self):
|
1191
|
+
if not hasattr(self, "_parenting"):
|
1192
|
+
self._parenting = dok_array((max(self.nodes) + 1,) * 2)
|
1193
|
+
self._tmp_parenting = {}
|
1194
|
+
for i, j in combinations(self.nodes, 2):
|
1195
|
+
if self._time[j] < self.time[i]:
|
1196
|
+
i, j = j, i
|
1197
|
+
self._tmp_parenting[(i, j)] = self.m(i, j)
|
1198
|
+
del self._tmp_parenting
|
1199
|
+
return self._parenting
|
1200
|
+
|
1201
|
+
def get_idx3d(self, t: int) -> tuple[KDTree, np.ndarray]:
|
1202
|
+
"""Get a 3d kdtree for the dataset at time `t`.
|
1203
|
+
The kdtree is stored in `self.kdtrees[t]` and returned.
|
1204
|
+
The correspondence list is also returned.
|
1205
|
+
|
1206
|
+
Parameters
|
1207
|
+
----------
|
1208
|
+
t : int
|
1209
|
+
time
|
1210
|
+
|
1211
|
+
Returns
|
1212
|
+
-------
|
1213
|
+
KDTree
|
1214
|
+
The KDTree corresponding to the lineage tree at time `t`
|
1215
|
+
np.ndarray
|
1216
|
+
The correspondence list in the KDTree.
|
1217
|
+
If the query in the kdtree gives you the value `i`,
|
1218
|
+
then it corresponds to the id in the tree `to_check_self[i]`
|
1219
|
+
"""
|
1220
|
+
to_check_self = list(self.nodes_at_t(t=t))
|
1221
|
+
|
1222
|
+
if not hasattr(self, "kdtrees"):
|
1223
|
+
self.kdtrees = {}
|
1224
|
+
|
1095
1225
|
if t not in self.kdtrees:
|
1096
1226
|
data_corres = {}
|
1097
1227
|
data = []
|
@@ -1104,16 +1234,21 @@ class lineageTree(lineageTreeLoaders):
|
|
1104
1234
|
idx3d = self.kdtrees[t]
|
1105
1235
|
return idx3d, np.array(to_check_self)
|
1106
1236
|
|
1107
|
-
def get_gabriel_graph(self, t: int) -> dict:
|
1108
|
-
"""Build the Gabriel graph of the given graph for time point `t
|
1109
|
-
The Garbiel graph is then stored in self.Gabriel_graph and returned
|
1110
|
-
|
1237
|
+
def get_gabriel_graph(self, t: int) -> dict[int, set[int]]:
|
1238
|
+
"""Build the Gabriel graph of the given graph for time point `t`.
|
1239
|
+
The Gabriel graph is then stored in `self.Gabriel_graph` and returned.
|
1240
|
+
|
1241
|
+
.. warning:: the graph is not recomputed if already computed, even if the point cloud has changed
|
1242
|
+
|
1243
|
+
Parameters
|
1244
|
+
----------
|
1245
|
+
t : int
|
1246
|
+
time
|
1111
1247
|
|
1112
|
-
|
1113
|
-
|
1114
|
-
|
1115
|
-
|
1116
|
-
the set of its neighbors
|
1248
|
+
Returns
|
1249
|
+
-------
|
1250
|
+
dict of int to set of int
|
1251
|
+
A dictionary that maps a node to the set of its neighbors
|
1117
1252
|
"""
|
1118
1253
|
if not hasattr(self, "Gabriel_graph"):
|
1119
1254
|
self.Gabriel_graph = {}
|
@@ -1158,178 +1293,110 @@ class lineageTree(lineageTreeLoaders):
|
|
1158
1293
|
|
1159
1294
|
return self.Gabriel_graph[t]
|
1160
1295
|
|
1161
|
-
def
|
1162
|
-
self,
|
1163
|
-
) -> list:
|
1164
|
-
"""Computes the
|
1165
|
-
|
1166
|
-
The ordered list of ids is returned.
|
1167
|
-
|
1168
|
-
Args:
|
1169
|
-
x (int): id of the node to compute
|
1170
|
-
depth (int): maximum number of predecessors to return
|
1171
|
-
Returns:
|
1172
|
-
[int, ]: list of ids, the last id is `x`
|
1173
|
-
"""
|
1174
|
-
if not start_time:
|
1175
|
-
start_time = self.t_b
|
1176
|
-
if not end_time:
|
1177
|
-
end_time = self.t_e
|
1178
|
-
unconstrained_cycle = [x]
|
1179
|
-
cycle = [x] if start_time <= self.time[x] <= end_time else []
|
1180
|
-
acc = 0
|
1181
|
-
while (
|
1182
|
-
len(self[self.predecessor.get(unconstrained_cycle[0], [-1])[0]])
|
1183
|
-
== 1
|
1184
|
-
and acc != depth
|
1185
|
-
and start_time
|
1186
|
-
<= self.time.get(
|
1187
|
-
self.predecessor.get(unconstrained_cycle[0], [-1])[0], -1
|
1188
|
-
)
|
1189
|
-
):
|
1190
|
-
unconstrained_cycle.insert(
|
1191
|
-
0, self.predecessor[unconstrained_cycle[0]][0]
|
1192
|
-
)
|
1193
|
-
acc += 1
|
1194
|
-
if start_time <= self.time[unconstrained_cycle[0]] <= end_time:
|
1195
|
-
cycle.insert(0, unconstrained_cycle[0])
|
1196
|
-
|
1197
|
-
return cycle
|
1198
|
-
|
1199
|
-
def get_successors(
|
1200
|
-
self, x: int, depth: int = None, end_time: int = None
|
1201
|
-
) -> list:
|
1202
|
-
"""Computes the successors of the node `x` up to
|
1203
|
-
`depth` successors or the end of the life of `x`.
|
1204
|
-
The ordered list of ids is returned.
|
1205
|
-
|
1206
|
-
Args:
|
1207
|
-
x (int): id of the node to compute
|
1208
|
-
depth (int): maximum number of predecessors to return
|
1209
|
-
Returns:
|
1210
|
-
[int, ]: list of ids, the first id is `x`
|
1211
|
-
"""
|
1212
|
-
if end_time is None:
|
1213
|
-
end_time = self.t_e
|
1214
|
-
cycle = [x]
|
1215
|
-
acc = 0
|
1216
|
-
while (
|
1217
|
-
len(self[cycle[-1]]) == 1
|
1218
|
-
and acc != depth
|
1219
|
-
and self.time[cycle[-1]] < end_time
|
1220
|
-
):
|
1221
|
-
cycle += self.successor[cycle[-1]]
|
1222
|
-
acc += 1
|
1223
|
-
|
1224
|
-
return cycle
|
1225
|
-
|
1226
|
-
def get_cycle(
|
1227
|
-
self,
|
1228
|
-
x: int,
|
1229
|
-
depth: int = None,
|
1230
|
-
depth_pred: int = None,
|
1231
|
-
depth_succ: int = None,
|
1232
|
-
end_time: int = None,
|
1233
|
-
) -> list:
|
1234
|
-
"""Computes the predecessors and successors of the node `x` up to
|
1235
|
-
`depth_pred` predecessors plus `depth_succ` successors.
|
1236
|
-
If the value `depth` is provided and not None,
|
1237
|
-
`depth_pred` and `depth_succ` are overwriten by `depth`.
|
1238
|
-
The ordered list of ids is returned.
|
1239
|
-
If all `depth` are None, the full cycle is returned.
|
1240
|
-
|
1241
|
-
Args:
|
1242
|
-
x (int): id of the node to compute
|
1243
|
-
depth (int): maximum number of predecessors and successor to return
|
1244
|
-
depth_pred (int): maximum number of predecessors to return
|
1245
|
-
depth_succ (int): maximum number of successors to return
|
1246
|
-
Returns:
|
1247
|
-
[int, ]: list of ids
|
1248
|
-
"""
|
1249
|
-
if end_time is None:
|
1250
|
-
end_time = self.t_e
|
1251
|
-
if depth is not None:
|
1252
|
-
depth_pred = depth_succ = depth
|
1253
|
-
return self.get_predecessors(x, depth_pred, end_time=end_time)[
|
1254
|
-
:-1
|
1255
|
-
] + self.get_successors(x, depth_succ, end_time=end_time)
|
1256
|
-
|
1257
|
-
@property
|
1258
|
-
def all_tracks(self):
|
1259
|
-
if not hasattr(self, "_all_tracks"):
|
1260
|
-
self._all_tracks = self.get_all_tracks()
|
1261
|
-
return self._all_tracks
|
1262
|
-
|
1263
|
-
def get_all_branches_of_node(
|
1264
|
-
self, node: int, end_time: int = None
|
1265
|
-
) -> list:
|
1266
|
-
"""Computes all the tracks of the subtree spawn by a given node.
|
1267
|
-
Similar to get_all_tracks().
|
1296
|
+
def get_all_chains_of_subtree(
|
1297
|
+
self, node: int, end_time: int | None = None
|
1298
|
+
) -> list[list[int]]:
|
1299
|
+
"""Computes all the chains of the subtree spawn by a given node.
|
1300
|
+
Similar to get_all_chains().
|
1268
1301
|
|
1269
|
-
|
1270
|
-
|
1302
|
+
Parameters
|
1303
|
+
----------
|
1304
|
+
node : int
|
1305
|
+
The node from which we want to get its chains.
|
1306
|
+
end_time : int, optional
|
1307
|
+
The time at which we want to stop the chains.
|
1271
1308
|
|
1272
|
-
Returns
|
1273
|
-
|
1309
|
+
Returns
|
1310
|
+
-------
|
1311
|
+
list of list of int
|
1312
|
+
list of chains
|
1274
1313
|
"""
|
1275
1314
|
if not end_time:
|
1276
1315
|
end_time = self.t_e
|
1277
|
-
|
1278
|
-
to_do = list(self[
|
1316
|
+
chains = [self.get_successors(node)]
|
1317
|
+
to_do = list(self._successor[chains[0][-1]])
|
1279
1318
|
while to_do:
|
1280
1319
|
current = to_do.pop()
|
1281
|
-
|
1282
|
-
|
1283
|
-
|
1284
|
-
|
1285
|
-
|
1286
|
-
|
1287
|
-
|
1288
|
-
|
1289
|
-
|
1290
|
-
|
1291
|
-
|
1292
|
-
|
1293
|
-
|
1294
|
-
|
1295
|
-
|
1296
|
-
|
1297
|
-
|
1298
|
-
|
1299
|
-
|
1300
|
-
|
1301
|
-
|
1302
|
-
|
1303
|
-
return
|
1304
|
-
|
1305
|
-
def
|
1306
|
-
|
1307
|
-
|
1308
|
-
|
1309
|
-
|
1310
|
-
|
1311
|
-
|
1312
|
-
|
1313
|
-
|
1314
|
-
|
1315
|
-
|
1316
|
-
|
1317
|
-
|
1318
|
-
|
1319
|
-
|
1320
|
-
|
1321
|
-
|
1322
|
-
|
1323
|
-
|
1324
|
-
|
1325
|
-
|
1320
|
+
chain = self.get_successors(current, end_time=end_time)
|
1321
|
+
if self._time[chain[-1]] <= end_time:
|
1322
|
+
chains += [chain]
|
1323
|
+
to_do += self._successor[chain[-1]]
|
1324
|
+
return chains
|
1325
|
+
|
1326
|
+
def _compute_all_chains(self) -> list[list[int]]:
|
1327
|
+
"""Computes all the chains of a given lineage tree,
|
1328
|
+
stores it in `self.all_chains` and returns it.
|
1329
|
+
|
1330
|
+
Returns
|
1331
|
+
-------
|
1332
|
+
list of list of int
|
1333
|
+
list of chains
|
1334
|
+
"""
|
1335
|
+
all_chains = []
|
1336
|
+
to_do = sorted(self.roots, key=self.time.get, reverse=True)
|
1337
|
+
while len(to_do) != 0:
|
1338
|
+
current = to_do.pop()
|
1339
|
+
chain = self.get_chain_of_node(current)
|
1340
|
+
all_chains += [chain]
|
1341
|
+
to_do.extend(self._successor[chain[-1]])
|
1342
|
+
return all_chains
|
1343
|
+
|
1344
|
+
def __get_chains( # TODO: Probably should be removed, might be used by DTW. Might also be a @dynamic_property
|
1345
|
+
self, nodes: Iterable | int | None = None
|
1346
|
+
) -> dict[int, list[list[int]]]:
|
1347
|
+
"""Returns all the chains in the subtrees spawned by each of the given nodes.
|
1348
|
+
|
1349
|
+
Parameters
|
1350
|
+
----------
|
1351
|
+
nodes : Iterable or int, optional
|
1352
|
+
id or Iterable of ids of the nodes to be computed, if `None` all roots are used
|
1353
|
+
|
1354
|
+
Returns
|
1355
|
+
-------
|
1356
|
+
dict mapping int to list of Chain
|
1357
|
+
dictionary mapping the node ids to a list of chains
|
1358
|
+
"""
|
1359
|
+
all_chains = self.all_chains
|
1360
|
+
if nodes is None:
|
1361
|
+
nodes = self.roots
|
1362
|
+
if not isinstance(nodes, Iterable):
|
1363
|
+
nodes = [nodes]
|
1364
|
+
output_chains = {}
|
1365
|
+
for n in nodes:
|
1366
|
+
starting_node = self.get_predecessors(n)[0]
|
1367
|
+
found = False
|
1368
|
+
done = False
|
1369
|
+
starting_time = self.time[n]
|
1370
|
+
i = 0
|
1371
|
+
current_chain = []
|
1372
|
+
while not done and i < len(all_chains):
|
1373
|
+
curr_found = all_chains[i][0] == starting_node
|
1374
|
+
found = found or curr_found
|
1375
|
+
if found:
|
1376
|
+
done = (
|
1377
|
+
self.time[all_chains[i][0]] <= starting_time
|
1378
|
+
) and not curr_found
|
1379
|
+
if not done:
|
1380
|
+
if curr_found:
|
1381
|
+
current_chain.append(self.get_successors(n))
|
1382
|
+
else:
|
1383
|
+
current_chain.append(all_chains[i])
|
1384
|
+
i += 1
|
1385
|
+
output_chains[n] = current_chain
|
1386
|
+
return output_chains
|
1387
|
+
|
1388
|
+
def find_leaves(self, roots: int | Iterable) -> set[int]:
|
1326
1389
|
"""Finds the leaves of a tree spawned by one or more nodes.
|
1327
1390
|
|
1328
|
-
|
1329
|
-
|
1391
|
+
Parameters
|
1392
|
+
----------
|
1393
|
+
roots : int or Iterable
|
1394
|
+
The roots of the trees spawning the leaves
|
1330
1395
|
|
1331
|
-
Returns
|
1332
|
-
|
1396
|
+
Returns
|
1397
|
+
-------
|
1398
|
+
set
|
1399
|
+
The leaves of one or more trees.
|
1333
1400
|
"""
|
1334
1401
|
if not isinstance(roots, Iterable):
|
1335
1402
|
to_do = [roots]
|
@@ -1338,28 +1405,34 @@ class lineageTree(lineageTreeLoaders):
|
|
1338
1405
|
leaves = set()
|
1339
1406
|
while to_do:
|
1340
1407
|
curr = to_do.pop()
|
1341
|
-
succ = self.
|
1408
|
+
succ = self._successor[curr]
|
1342
1409
|
if not succ:
|
1343
1410
|
leaves.add(curr)
|
1344
1411
|
to_do += succ
|
1345
1412
|
return leaves
|
1346
1413
|
|
1347
|
-
def
|
1414
|
+
def get_subtree_nodes(
|
1348
1415
|
self,
|
1349
|
-
x:
|
1350
|
-
end_time:
|
1416
|
+
x: int | Iterable,
|
1417
|
+
end_time: int | None = None,
|
1351
1418
|
preorder: bool = False,
|
1352
|
-
) -> list:
|
1353
|
-
"""Computes the list of
|
1354
|
-
The default output order is
|
1419
|
+
) -> list[int]:
|
1420
|
+
"""Computes the list of nodes from the subtree spawned by *x*
|
1421
|
+
The default output order is Breadth First Traversal.
|
1355
1422
|
Unless preorder is `True` in that case the order is
|
1356
|
-
Depth
|
1423
|
+
Depth First Traversal (DFT) preordered.
|
1357
1424
|
|
1358
|
-
|
1359
|
-
|
1360
|
-
|
1361
|
-
|
1362
|
-
|
1425
|
+
Parameters
|
1426
|
+
----------
|
1427
|
+
x : int or Iterable
|
1428
|
+
id of root node
|
1429
|
+
preorder : bool, default=False
|
1430
|
+
if True the output preorder is DFT
|
1431
|
+
|
1432
|
+
Returns
|
1433
|
+
-------
|
1434
|
+
list of int
|
1435
|
+
the ordered list of node ids
|
1363
1436
|
"""
|
1364
1437
|
if not end_time:
|
1365
1438
|
end_time = self.t_e
|
@@ -1367,233 +1440,258 @@ class lineageTree(lineageTreeLoaders):
|
|
1367
1440
|
to_do = [x]
|
1368
1441
|
elif isinstance(x, Iterable):
|
1369
1442
|
to_do = list(x)
|
1370
|
-
|
1443
|
+
subtree = []
|
1371
1444
|
while to_do:
|
1372
1445
|
curr = to_do.pop()
|
1373
|
-
succ = self.
|
1374
|
-
if succ and end_time < self.
|
1446
|
+
succ = self._successor[curr]
|
1447
|
+
if succ and end_time < self._time.get(curr, end_time):
|
1375
1448
|
succ = []
|
1376
1449
|
continue
|
1377
1450
|
if preorder:
|
1378
1451
|
to_do = succ + to_do
|
1379
1452
|
else:
|
1380
1453
|
to_do += succ
|
1381
|
-
|
1382
|
-
return
|
1454
|
+
subtree += [curr]
|
1455
|
+
return subtree
|
1383
1456
|
|
1384
1457
|
def compute_spatial_density(
|
1385
|
-
self, t_b: int = None, t_e: int = None, th: float = 50
|
1386
|
-
) -> dict:
|
1387
|
-
"""Computes the spatial density of
|
1388
|
-
The
|
1389
|
-
|
1390
|
-
|
1391
|
-
|
1392
|
-
|
1393
|
-
|
1394
|
-
|
1395
|
-
|
1396
|
-
|
1397
|
-
|
1398
|
-
|
1458
|
+
self, t_b: int | None = None, t_e: int | None = None, th: float = 50
|
1459
|
+
) -> dict[int, float]:
|
1460
|
+
"""Computes the spatial density of nodes between `t_b` and `t_e`.
|
1461
|
+
The result is stored in `self.spatial_density` and returned.
|
1462
|
+
|
1463
|
+
Parameters
|
1464
|
+
----------
|
1465
|
+
t_b : int, optional
|
1466
|
+
starting time to look at, default first time point
|
1467
|
+
t_e : int, optional
|
1468
|
+
ending time to look at, default last time point
|
1469
|
+
th : float, default=50
|
1470
|
+
size of the neighbourhood
|
1471
|
+
|
1472
|
+
Returns
|
1473
|
+
-------
|
1474
|
+
dict mapping int to float
|
1475
|
+
dictionary that maps a node id to its spatial density
|
1476
|
+
"""
|
1477
|
+
if not hasattr(self, "spatial_density"):
|
1478
|
+
self.spatial_density = {}
|
1399
1479
|
s_vol = 4 / 3.0 * np.pi * th**3
|
1400
|
-
|
1480
|
+
if t_b is None:
|
1481
|
+
t_b = self.t_b
|
1482
|
+
if t_e is None:
|
1483
|
+
t_e = self.t_e
|
1484
|
+
time_range = set(range(t_b, t_e)).intersection(self._time.values())
|
1401
1485
|
for t in time_range:
|
1402
1486
|
idx3d, nodes = self.get_idx3d(t)
|
1403
1487
|
nb_ni = [
|
1404
1488
|
(len(ni) - 1) / s_vol
|
1405
1489
|
for ni in idx3d.query_ball_tree(idx3d, th)
|
1406
1490
|
]
|
1407
|
-
self.spatial_density.update(dict(zip(nodes, nb_ni)))
|
1491
|
+
self.spatial_density.update(dict(zip(nodes, nb_ni, strict=True)))
|
1408
1492
|
return self.spatial_density
|
1409
1493
|
|
1410
|
-
def compute_k_nearest_neighbours(self, k: int = 10) -> dict:
|
1494
|
+
def compute_k_nearest_neighbours(self, k: int = 10) -> dict[int, set[int]]:
|
1411
1495
|
"""Computes the k-nearest neighbors
|
1412
1496
|
Writes the output in the attribute `kn_graph`
|
1413
1497
|
and returns it.
|
1414
1498
|
|
1415
|
-
|
1416
|
-
|
1417
|
-
|
1418
|
-
|
1419
|
-
|
1499
|
+
Parameters
|
1500
|
+
----------
|
1501
|
+
k : int, default=10
|
1502
|
+
number of nearest neighbours
|
1503
|
+
|
1504
|
+
Returns
|
1505
|
+
-------
|
1506
|
+
dict mapping int to set of int
|
1507
|
+
dictionary that maps
|
1508
|
+
a node id to its `k` nearest neighbors
|
1420
1509
|
"""
|
1421
1510
|
self.kn_graph = {}
|
1422
|
-
for t
|
1423
|
-
|
1424
|
-
|
1425
|
-
|
1426
|
-
|
1427
|
-
|
1428
|
-
|
1511
|
+
for t in set(self._time.values()):
|
1512
|
+
nodes = self.nodes_at_t(t)
|
1513
|
+
if 1 < len(nodes):
|
1514
|
+
use_k = k if k < len(nodes) else len(nodes)
|
1515
|
+
idx3d, nodes = self.get_idx3d(t)
|
1516
|
+
pos = [self.pos[c] for c in nodes]
|
1517
|
+
_, neighbs = idx3d.query(pos, use_k)
|
1518
|
+
out = dict(
|
1519
|
+
zip(
|
1520
|
+
nodes,
|
1521
|
+
map(set, nodes[neighbs]),
|
1522
|
+
strict=True,
|
1523
|
+
)
|
1524
|
+
)
|
1525
|
+
self.kn_graph.update(out)
|
1526
|
+
else:
|
1527
|
+
n = nodes.pop
|
1528
|
+
self.kn_graph.update({n: {n}})
|
1429
1529
|
return self.kn_graph
|
1430
1530
|
|
1431
|
-
def compute_spatial_edges(self, th: int = 50) -> dict:
|
1531
|
+
def compute_spatial_edges(self, th: int = 50) -> dict[int, set[int]]:
|
1432
1532
|
"""Computes the neighbors at a distance `th`
|
1433
1533
|
Writes the output in the attribute `th_edge`
|
1434
1534
|
and returns it.
|
1435
1535
|
|
1436
|
-
|
1437
|
-
|
1438
|
-
|
1439
|
-
|
1440
|
-
|
1536
|
+
Parameters
|
1537
|
+
----------
|
1538
|
+
th : float, default=50
|
1539
|
+
distance to consider neighbors
|
1540
|
+
|
1541
|
+
Returns
|
1542
|
+
-------
|
1543
|
+
dict mapping int to set of int
|
1544
|
+
dictionary that maps a node id to its neighbors at a distance `th`
|
1441
1545
|
"""
|
1442
1546
|
self.th_edges = {}
|
1443
|
-
for t
|
1547
|
+
for t in set(self._time.values()):
|
1548
|
+
nodes = self.nodes_at_t(t)
|
1444
1549
|
idx3d, nodes = self.get_idx3d(t)
|
1445
1550
|
neighbs = idx3d.query_ball_tree(idx3d, th)
|
1446
|
-
out = dict(
|
1551
|
+
out = dict(
|
1552
|
+
zip(nodes, [set(nodes[ni]) for ni in neighbs], strict=True)
|
1553
|
+
)
|
1447
1554
|
self.th_edges.update(
|
1448
1555
|
{k: v.difference([k]) for k, v in out.items()}
|
1449
1556
|
)
|
1450
1557
|
return self.th_edges
|
1451
1558
|
|
1452
|
-
def
|
1453
|
-
"""
|
1454
|
-
If none will select the timepoint with the highest amound of cells.
|
1455
|
-
|
1456
|
-
Args:
|
1457
|
-
time (int, optional): The timepoint to find the main axes.
|
1458
|
-
If None will find the timepoint
|
1459
|
-
with the largest number of cells.
|
1460
|
-
|
1461
|
-
Returns:
|
1462
|
-
list: A list that contains the array of eigenvalues and eigenvectors.
|
1463
|
-
"""
|
1464
|
-
if time is None:
|
1465
|
-
time = max(self.time_nodes, key=lambda x: len(self.time_nodes[x]))
|
1466
|
-
pos = np.array([self.pos[node] for node in self.time_nodes[time]])
|
1467
|
-
pos = pos - np.mean(pos, axis=0)
|
1468
|
-
cov = np.cov(np.array(pos).T)
|
1469
|
-
eig_val, eig_vec = np.linalg.eig(cov)
|
1470
|
-
srt = np.argsort(eig_val)[::-1]
|
1471
|
-
self.eig_val, self.eig_vec = eig_val[srt], eig_vec[:, srt]
|
1472
|
-
return eig_val[srt], eig_vec[:, srt]
|
1473
|
-
|
1474
|
-
def scale_embryo(self, scale=1000):
|
1475
|
-
"""Scale the embryo using their eigenvalues.
|
1476
|
-
|
1477
|
-
Args:
|
1478
|
-
scale (int, optional): The resulting scale you want to achieve. Defaults to 1000.
|
1479
|
-
|
1480
|
-
Returns:
|
1481
|
-
float: The scale factor.
|
1482
|
-
"""
|
1483
|
-
eig = self.main_axes()[0]
|
1484
|
-
return scale / (np.sqrt(eig[0]))
|
1485
|
-
|
1486
|
-
@staticmethod
|
1487
|
-
def __rodrigues_rotation_matrix(vector1, vector2=(0, 1, 0)):
|
1488
|
-
"""Calculates the rodrigues matrix of a dataset. It should use vectors from the find_main_axes(eigenvectors) function of LineagTree.
|
1489
|
-
Uses the Rodrigues rotation formula.
|
1490
|
-
|
1491
|
-
Args:
|
1492
|
-
vector1 (list|np.array): The vector that should be rotated to be aligned to the second vector
|
1493
|
-
vector2 (list|np.array, optional): The second vector. Defaults to [1,0,0].
|
1494
|
-
|
1495
|
-
Returns:
|
1496
|
-
np.array: The rotation matrix.
|
1497
|
-
"""
|
1498
|
-
vector1 = vector1 / np.linalg.norm(vector1)
|
1499
|
-
vector2 = vector2 / np.linalg.norm(vector2)
|
1500
|
-
if vector1 @ vector2 == 1:
|
1501
|
-
return np.eye(3)
|
1502
|
-
angle = np.arccos(vector1 @ vector2)
|
1503
|
-
axis = np.cross(vector1, vector2)
|
1504
|
-
axis = axis / np.linalg.norm(axis)
|
1505
|
-
K = np.array(
|
1506
|
-
[
|
1507
|
-
[0, -axis[2], axis[1]],
|
1508
|
-
[axis[2], 0, -axis[0]],
|
1509
|
-
[-axis[1], axis[0], 0],
|
1510
|
-
]
|
1511
|
-
)
|
1512
|
-
return np.eye(3) + np.sin(angle) * K + (1 - np.cos(angle)) * K @ K
|
1513
|
-
|
1514
|
-
def get_ancestor_at_t(self, n: int, time: int = None):
|
1515
|
-
"""
|
1516
|
-
Find the id of the ancestor of a give node `n`
|
1559
|
+
def get_ancestor_at_t(self, n: int, time: int | None = None) -> int:
|
1560
|
+
"""Find the id of the ancestor of a give node `n`
|
1517
1561
|
at a given time `time`.
|
1518
1562
|
|
1519
|
-
If there is no ancestor, returns
|
1520
|
-
If time is None return the root of the
|
1563
|
+
If there is no ancestor, returns `-1`.
|
1564
|
+
If time is None return the root of the subtree that spawns
|
1521
1565
|
the node n.
|
1522
1566
|
|
1523
|
-
|
1524
|
-
|
1525
|
-
|
1526
|
-
|
1527
|
-
|
1528
|
-
|
1529
|
-
|
1530
|
-
|
1531
|
-
|
1567
|
+
Parameters
|
1568
|
+
----------
|
1569
|
+
n : int
|
1570
|
+
node for which to look up the ancestor
|
1571
|
+
time : int, optional
|
1572
|
+
time at which the ancestor has to be found.
|
1573
|
+
If `None` the ancestor at the first time point
|
1574
|
+
will be found.
|
1575
|
+
|
1576
|
+
Returns
|
1577
|
+
-------
|
1578
|
+
int
|
1579
|
+
the id of the ancestor at time `time`,
|
1580
|
+
`-1` if there is no ancestor.
|
1532
1581
|
"""
|
1533
1582
|
if n not in self.nodes:
|
1534
|
-
return
|
1583
|
+
return -1
|
1535
1584
|
if time is None:
|
1536
1585
|
time = self.t_b
|
1537
1586
|
ancestor = n
|
1538
1587
|
while (
|
1539
|
-
time < self.
|
1588
|
+
time < self._time.get(ancestor, self.t_b - 1)
|
1589
|
+
and self._predecessor[ancestor]
|
1540
1590
|
):
|
1541
|
-
ancestor = self.
|
1542
|
-
|
1591
|
+
ancestor = self._predecessor[ancestor][0]
|
1592
|
+
if self._time.get(ancestor, self.t_b - 1) == time:
|
1593
|
+
return ancestor
|
1594
|
+
else:
|
1595
|
+
return -1
|
1543
1596
|
|
1544
|
-
def get_labelled_ancestor(self, node: int):
|
1545
|
-
"""Finds the first labelled ancestor and returns its ID otherwise returns
|
1597
|
+
def get_labelled_ancestor(self, node: int) -> int:
|
1598
|
+
"""Finds the first labelled ancestor and returns its ID otherwise returns -1
|
1546
1599
|
|
1547
|
-
|
1548
|
-
|
1600
|
+
Parameters
|
1601
|
+
----------
|
1602
|
+
node : int
|
1603
|
+
The id of the node
|
1549
1604
|
|
1550
|
-
Returns
|
1551
|
-
|
1552
|
-
|
1605
|
+
Returns
|
1606
|
+
-------
|
1607
|
+
int
|
1608
|
+
Returns the first ancestor found that has a label otherwise `-1`.
|
1553
1609
|
"""
|
1554
1610
|
if node not in self.nodes:
|
1555
|
-
return
|
1611
|
+
return -1
|
1556
1612
|
ancestor = node
|
1557
1613
|
while (
|
1558
|
-
self.t_b <= self.
|
1614
|
+
self.t_b <= self._time.get(ancestor, self.t_b - 1)
|
1559
1615
|
and ancestor != -1
|
1560
1616
|
):
|
1561
1617
|
if ancestor in self.labels:
|
1562
1618
|
return ancestor
|
1563
|
-
ancestor = self.
|
1564
|
-
return
|
1619
|
+
ancestor = self._predecessor.get(ancestor, [-1])[0]
|
1620
|
+
return -1
|
1621
|
+
|
1622
|
+
def get_ancestor_with_attribute(self, node: int, attribute: str) -> int:
|
1623
|
+
"""General purpose function to help with searching the first ancestor that has an attribute.
|
1624
|
+
Similar to get_labelled_ancestor and may make it redundant.
|
1625
|
+
|
1626
|
+
Parameters
|
1627
|
+
----------
|
1628
|
+
node : int
|
1629
|
+
The id of the node
|
1630
|
+
|
1631
|
+
Returns
|
1632
|
+
-------
|
1633
|
+
int
|
1634
|
+
Returns the first ancestor found that has an attribute otherwise `-1`.
|
1635
|
+
"""
|
1636
|
+
attr_dict = self.__getattribute__(attribute)
|
1637
|
+
if not isinstance(attr_dict, dict):
|
1638
|
+
raise ValueError("Please select a dict attribute")
|
1639
|
+
if node not in self.nodes:
|
1640
|
+
return -1
|
1641
|
+
if node in attr_dict:
|
1642
|
+
return node
|
1643
|
+
if node in self.roots:
|
1644
|
+
return -1
|
1645
|
+
ancestor = (node,)
|
1646
|
+
while ancestor and ancestor != [-1]:
|
1647
|
+
ancestor = ancestor[0]
|
1648
|
+
if ancestor in attr_dict:
|
1649
|
+
return ancestor
|
1650
|
+
ancestor = self._predecessor.get(ancestor, [-1])
|
1651
|
+
return -1
|
1565
1652
|
|
1566
1653
|
def unordered_tree_edit_distances_at_time_t(
|
1567
1654
|
self,
|
1568
1655
|
t: int,
|
1569
|
-
end_time: int = None,
|
1570
|
-
style
|
1656
|
+
end_time: int | None = None,
|
1657
|
+
style: (
|
1658
|
+
Literal["simple", "full", "downsampled", "normalized_simple"]
|
1659
|
+
| type[TreeApproximationTemplate]
|
1660
|
+
) = "simple",
|
1571
1661
|
downsample: int = 2,
|
1572
|
-
|
1662
|
+
norm: Literal["max", "sum", None] = "max",
|
1573
1663
|
recompute: bool = False,
|
1574
|
-
) -> dict:
|
1575
|
-
"""
|
1576
|
-
|
1577
|
-
|
1578
|
-
|
1579
|
-
|
1580
|
-
|
1581
|
-
|
1582
|
-
|
1583
|
-
|
1584
|
-
|
1585
|
-
|
1586
|
-
|
1587
|
-
|
1588
|
-
|
1589
|
-
|
1664
|
+
) -> dict[tuple[int, int], float]:
|
1665
|
+
"""Compute all the pairwise unordered tree edit distances from Zhang 996 between the trees spawned at time `t`
|
1666
|
+
|
1667
|
+
Parameters
|
1668
|
+
----------
|
1669
|
+
t : int
|
1670
|
+
time to look at
|
1671
|
+
end_time : int
|
1672
|
+
The final time point the comparison algorithm will take into account.
|
1673
|
+
If None all nodes will be taken into account.
|
1674
|
+
style : {"simple", "full", "downsampled", "normalized_simple"} or TreeApproximationTemplate subclass, default="simple"
|
1675
|
+
Which tree approximation is going to be used for the comparisons.
|
1676
|
+
downsample : int, default=2
|
1677
|
+
The downsample factor for the downsampled tree approximation.
|
1678
|
+
Used only when `style="downsampled"`.
|
1679
|
+
norm : {"max", "sum"}, default="max"
|
1680
|
+
The normalization method to use.
|
1681
|
+
recompute : bool, default=False
|
1682
|
+
If True, forces to recompute the distances
|
1683
|
+
|
1684
|
+
Returns
|
1685
|
+
-------
|
1686
|
+
dict mapping a tuple of 2 ints to float
|
1687
|
+
a dictionary that maps a pair of node ids at time `t` to their unordered tree edit distance
|
1590
1688
|
"""
|
1591
1689
|
if not hasattr(self, "uted"):
|
1592
1690
|
self.uted = {}
|
1593
1691
|
elif t in self.uted and not recompute:
|
1594
1692
|
return self.uted[t]
|
1595
1693
|
self.uted[t] = {}
|
1596
|
-
roots = self.
|
1694
|
+
roots = self.nodes_at_t(t=t)
|
1597
1695
|
for n1, n2 in combinations(roots, 2):
|
1598
1696
|
key = tuple(sorted((n1, n2)))
|
1599
1697
|
self.uted[t][key] = self.unordered_tree_edit_distance(
|
@@ -1602,37 +1700,132 @@ class lineageTree(lineageTreeLoaders):
|
|
1602
1700
|
end_time=end_time,
|
1603
1701
|
style=style,
|
1604
1702
|
downsample=downsample,
|
1605
|
-
|
1703
|
+
norm=norm,
|
1606
1704
|
)
|
1607
1705
|
return self.uted[t]
|
1608
1706
|
|
1609
|
-
def
|
1707
|
+
def __calculate_distance_of_sub_tree(
|
1708
|
+
self,
|
1709
|
+
node1: int,
|
1710
|
+
node2: int,
|
1711
|
+
alignment: Alignment,
|
1712
|
+
corres1: dict[int, int],
|
1713
|
+
corres2: dict[int, int],
|
1714
|
+
delta_tmp: Callable,
|
1715
|
+
norm: Callable,
|
1716
|
+
norm1: int | float,
|
1717
|
+
norm2: int | float,
|
1718
|
+
) -> float:
|
1719
|
+
"""Calculates the distance of the subtree of each node matched in a comparison.
|
1720
|
+
DOES NOT CALCULATE THE DISTANCE FROM SCRATCH BUT USING THE ALIGNMENT.
|
1721
|
+
TODO ITS BOUND TO CHANGE
|
1722
|
+
Parameters
|
1723
|
+
----------
|
1724
|
+
node1 : int
|
1725
|
+
The root of the first subtree
|
1726
|
+
node2 : int
|
1727
|
+
The root of the second subtree
|
1728
|
+
alignment : Alignment
|
1729
|
+
The alignment of the subtree
|
1730
|
+
corres1 : dict
|
1731
|
+
The correspondence dictionary of the first lineage
|
1732
|
+
corres2 : dict
|
1733
|
+
The correspondence dictionary of the second lineage
|
1734
|
+
delta_tmp : Callable
|
1735
|
+
The delta function for the comparisons
|
1736
|
+
norm : Callable
|
1737
|
+
How should the lineages be normalized
|
1738
|
+
norm1 : int or float
|
1739
|
+
The result of the normalization of the first tree
|
1740
|
+
norm2 : int or float
|
1741
|
+
The result of the normalization of the second tree
|
1742
|
+
|
1743
|
+
Returns
|
1744
|
+
-------
|
1745
|
+
float
|
1746
|
+
The result of the comparison of the subtree
|
1747
|
+
"""
|
1748
|
+
sub_tree_1 = set(self.get_subtree_nodes(node1))
|
1749
|
+
sub_tree_2 = set(self.get_subtree_nodes(node2))
|
1750
|
+
res = 0
|
1751
|
+
for m in alignment:
|
1752
|
+
if (
|
1753
|
+
corres1.get(m._left, -1) in sub_tree_1
|
1754
|
+
or corres2.get(m._right, -1) in sub_tree_2
|
1755
|
+
):
|
1756
|
+
res += delta_tmp(
|
1757
|
+
m._left if m._left != -1 else None,
|
1758
|
+
m._right if m._right != -1 else None,
|
1759
|
+
)
|
1760
|
+
return res / norm([norm1, norm2])
|
1761
|
+
|
1762
|
+
def clear_comparisons(self):
|
1763
|
+
self._comparisons.clear()
|
1764
|
+
|
1765
|
+
def __unordereded_backtrace(
|
1610
1766
|
self,
|
1611
1767
|
n1: int,
|
1612
1768
|
n2: int,
|
1613
|
-
end_time: int = None,
|
1614
|
-
norm:
|
1615
|
-
style
|
1769
|
+
end_time: int | None = None,
|
1770
|
+
norm: Literal["max", "sum", None] = "max",
|
1771
|
+
style: (
|
1772
|
+
Literal["simple", "normalized_simple", "full", "downsampled"]
|
1773
|
+
| type[TreeApproximationTemplate]
|
1774
|
+
) = "simple",
|
1616
1775
|
downsample: int = 2,
|
1617
|
-
) ->
|
1776
|
+
) -> dict[
|
1777
|
+
str,
|
1778
|
+
Alignment
|
1779
|
+
| tuple[TreeApproximationTemplate, TreeApproximationTemplate],
|
1780
|
+
]:
|
1618
1781
|
"""
|
1619
|
-
Compute the unordered tree edit
|
1782
|
+
Compute the unordered tree edit backtrace from Zhang 1996 between the trees spawned
|
1620
1783
|
by two nodes `n1` and `n2`. The topology of the trees are compared and the matching
|
1621
1784
|
cost is given by the function delta (see edist doc for more information).
|
1622
|
-
The distance is normed by the function norm that takes the two list of nodes
|
1623
|
-
spawned by the trees `n1` and `n2`.
|
1624
|
-
|
1625
|
-
Args:
|
1626
|
-
n1 (int): id of the first node to compare
|
1627
|
-
n2 (int): id of the second node to compare
|
1628
|
-
tree_style ("mini","simple","fragmented","full"): Which tree approximation is going to be used for the comparisons.
|
1629
|
-
Defaults to "fragmented".
|
1630
|
-
|
1631
|
-
Returns:
|
1632
|
-
(float) The normed unordered tree edit distance
|
1633
|
-
"""
|
1634
1785
|
|
1635
|
-
|
1786
|
+
Parameters
|
1787
|
+
----------
|
1788
|
+
n1 : int
|
1789
|
+
id of the first node to compare
|
1790
|
+
n2 : int
|
1791
|
+
id of the second node to compare
|
1792
|
+
end_time : int
|
1793
|
+
The final time point the comparison algorithm will take into account.
|
1794
|
+
If None all nodes will be taken into account.
|
1795
|
+
norm : {"max", "sum"}, default="max"
|
1796
|
+
The normalization method to use.
|
1797
|
+
style : {"simple", "full", "downsampled", "normalized_simple"} or TreeApproximationTemplate subclass, default="simple"
|
1798
|
+
Which tree approximation is going to be used for the comparisons.
|
1799
|
+
downsample : int, default=2
|
1800
|
+
The downsample factor for the downsampled tree approximation.
|
1801
|
+
Used only when `style="downsampled"`.
|
1802
|
+
|
1803
|
+
Returns
|
1804
|
+
-------
|
1805
|
+
dict mapping str to Alignment or tuple of [TreeApproximationTemplate, TreeApproximationTemplate]
|
1806
|
+
- 'alignment'
|
1807
|
+
The alignment between the nodes by the subtrees spawned by the nodes n1,n2 and the normalization function.
|
1808
|
+
- 'trees'
|
1809
|
+
A list of the two trees that have been mapped to each other.
|
1810
|
+
"""
|
1811
|
+
|
1812
|
+
parameters = (
|
1813
|
+
end_time,
|
1814
|
+
convert_style_to_number(style=style, downsample=downsample),
|
1815
|
+
)
|
1816
|
+
n1, n2 = sorted([n1, n2])
|
1817
|
+
self._comparisons.setdefault(parameters, {})
|
1818
|
+
if len(self._comparisons) > 100:
|
1819
|
+
warnings.warn(
|
1820
|
+
"More than 100 comparisons are saved, use clear_comparisons() to delete them.",
|
1821
|
+
stacklevel=2,
|
1822
|
+
)
|
1823
|
+
if isinstance(style, str):
|
1824
|
+
tree = tree_style[style].value
|
1825
|
+
elif issubclass(style, TreeApproximationTemplate):
|
1826
|
+
tree = style
|
1827
|
+
else:
|
1828
|
+
raise ValueError("Please use a valid approximation.")
|
1636
1829
|
tree1 = tree(
|
1637
1830
|
lT=self,
|
1638
1831
|
downsample=downsample,
|
@@ -1661,7 +1854,11 @@ class lineageTree(lineageTreeLoaders):
|
|
1661
1854
|
corres2,
|
1662
1855
|
) = tree2.edist
|
1663
1856
|
if len(nodes1) == len(nodes2) == 0:
|
1664
|
-
|
1857
|
+
self._comparisons[parameters][(n1, n2)] = {
|
1858
|
+
"alignment": (),
|
1859
|
+
"trees": (),
|
1860
|
+
}
|
1861
|
+
return self._comparisons[parameters][(n1, n2)]
|
1665
1862
|
delta_tmp = partial(
|
1666
1863
|
delta,
|
1667
1864
|
corres1=corres1,
|
@@ -1669,126 +1866,538 @@ class lineageTree(lineageTreeLoaders):
|
|
1669
1866
|
times1=times1,
|
1670
1867
|
times2=times2,
|
1671
1868
|
)
|
1672
|
-
|
1673
|
-
|
1674
|
-
|
1675
|
-
|
1676
|
-
|
1677
|
-
|
1869
|
+
btrc = uted.uted_backtrace(nodes1, adj1, nodes2, adj2, delta=delta_tmp)
|
1870
|
+
|
1871
|
+
self._comparisons[parameters][(n1, n2)] = {
|
1872
|
+
"alignment": btrc,
|
1873
|
+
"trees": (tree1, tree2),
|
1874
|
+
}
|
1875
|
+
return self._comparisons[parameters][(n1, n2)]
|
1876
|
+
|
1877
|
+
def plot_tree_distance_graphs(
|
1878
|
+
self,
|
1879
|
+
n1: int,
|
1880
|
+
n2: int,
|
1881
|
+
end_time: int | None = None,
|
1882
|
+
norm: Literal["max", "sum", None] = "max",
|
1883
|
+
style: (
|
1884
|
+
Literal["simple", "normalized_simple", "full", "downsampled"]
|
1885
|
+
| type[TreeApproximationTemplate]
|
1886
|
+
) = "simple",
|
1887
|
+
downsample: int = 2,
|
1888
|
+
colormap: str = "cool",
|
1889
|
+
default_color: str = "black",
|
1890
|
+
size: float = 10,
|
1891
|
+
lw: float = 0.3,
|
1892
|
+
ax: list[plt.Axes] | None = None,
|
1893
|
+
) -> tuple[plt.figure, plt.Axes]:
|
1894
|
+
"""
|
1895
|
+
Plots the subtrees compared and colors them according to the quality of the matching of their subtree.
|
1896
|
+
|
1897
|
+
Parameters
|
1898
|
+
----------
|
1899
|
+
n1 : int
|
1900
|
+
id of the first node to compare
|
1901
|
+
n2 : int
|
1902
|
+
id of the second node to compare
|
1903
|
+
end_time : int
|
1904
|
+
The final time point the comparison algorithm will take into account.
|
1905
|
+
If None all nodes will be taken into account.
|
1906
|
+
norm : {"max", "sum"}, default="max"
|
1907
|
+
The normalization method to use.
|
1908
|
+
style : {"simple", "full", "downsampled", "normalized_simple"} or TreeApproximationTemplate subclass, default="simple"
|
1909
|
+
Which tree approximation is going to be used for the comparisons.
|
1910
|
+
downsample : int, default=2
|
1911
|
+
The downsample factor for the downsampled tree approximation.
|
1912
|
+
Used only when `style="downsampled"`.
|
1913
|
+
colormap : str, default="cool"
|
1914
|
+
The colormap used for matched nodes, defaults to "cool"
|
1915
|
+
default_color : str
|
1916
|
+
The color of the unmatched nodes, defaults to "black"
|
1917
|
+
size : float
|
1918
|
+
The size of the nodes, defaults to 10
|
1919
|
+
lw : float
|
1920
|
+
The width of the edges, defaults to 0.3
|
1921
|
+
ax : np.ndarray, optional
|
1922
|
+
The axes used, if not provided another set of axes is produced, defaults to None
|
1923
|
+
|
1924
|
+
Returns
|
1925
|
+
-------
|
1926
|
+
plt.Figure
|
1927
|
+
The figure of the plot
|
1928
|
+
plt.Axes
|
1929
|
+
The axes of the plot
|
1930
|
+
"""
|
1931
|
+
parameters = (
|
1932
|
+
end_time,
|
1933
|
+
convert_style_to_number(style=style, downsample=downsample),
|
1934
|
+
)
|
1935
|
+
n1, n2 = sorted([n1, n2])
|
1936
|
+
self._comparisons.setdefault(parameters, {})
|
1937
|
+
if self._comparisons[parameters].get((n1, n2)):
|
1938
|
+
tmp = self._comparisons[parameters][(n1, n2)]
|
1939
|
+
else:
|
1940
|
+
tmp = self.__unordereded_backtrace(
|
1941
|
+
n1, n2, end_time, norm, style, downsample
|
1942
|
+
)
|
1943
|
+
btrc: Alignment = tmp["alignment"]
|
1944
|
+
tree1, tree2 = tmp["trees"]
|
1945
|
+
_, times1 = tree1.tree
|
1946
|
+
_, times2 = tree2.tree
|
1947
|
+
(
|
1948
|
+
*_,
|
1949
|
+
corres1,
|
1950
|
+
) = tree1.edist
|
1951
|
+
(
|
1952
|
+
*_,
|
1953
|
+
corres2,
|
1954
|
+
) = tree2.edist
|
1955
|
+
delta_tmp = partial(
|
1956
|
+
tree1.delta,
|
1957
|
+
corres1=corres1,
|
1958
|
+
corres2=corres2,
|
1959
|
+
times1=times1,
|
1960
|
+
times2=times2,
|
1961
|
+
)
|
1962
|
+
|
1963
|
+
if norm not in self.norm_dict:
|
1964
|
+
raise Warning(
|
1965
|
+
"Select a viable normalization method (max, sum, None)"
|
1966
|
+
)
|
1967
|
+
matched_right = []
|
1968
|
+
matched_left = []
|
1969
|
+
colors = {}
|
1970
|
+
if style not in ("full", "downsampled"):
|
1971
|
+
for m in btrc:
|
1972
|
+
if m._left != -1 and m._right != -1:
|
1973
|
+
cyc1 = self.get_chain_of_node(corres1[m._left])
|
1974
|
+
if len(cyc1) > 1:
|
1975
|
+
node_1, *_, l_node_1 = cyc1
|
1976
|
+
matched_left.append(node_1)
|
1977
|
+
matched_left.append(l_node_1)
|
1978
|
+
elif len(cyc1) == 1:
|
1979
|
+
node_1 = l_node_1 = cyc1.pop()
|
1980
|
+
matched_left.append(node_1)
|
1981
|
+
|
1982
|
+
cyc2 = self.get_chain_of_node(corres2[m._right])
|
1983
|
+
if len(cyc2) > 1:
|
1984
|
+
node_2, *_, l_node_2 = cyc2
|
1985
|
+
matched_right.append(node_2)
|
1986
|
+
matched_right.append(l_node_2)
|
1987
|
+
|
1988
|
+
elif len(cyc2) == 1:
|
1989
|
+
node_2 = l_node_2 = cyc2.pop()
|
1990
|
+
matched_right.append(node_2)
|
1991
|
+
|
1992
|
+
colors[node_1] = self.__calculate_distance_of_sub_tree(
|
1993
|
+
node_1,
|
1994
|
+
node_2,
|
1995
|
+
btrc,
|
1996
|
+
corres1,
|
1997
|
+
corres2,
|
1998
|
+
delta_tmp,
|
1999
|
+
self.norm_dict[norm],
|
2000
|
+
tree1.get_norm(node_1),
|
2001
|
+
tree2.get_norm(node_2),
|
2002
|
+
)
|
2003
|
+
colors[node_2] = colors[node_1]
|
2004
|
+
colors[l_node_1] = colors[node_1]
|
2005
|
+
colors[l_node_2] = colors[node_2]
|
2006
|
+
else:
|
2007
|
+
for m in btrc:
|
2008
|
+
if m._left != -1 and m._right != -1:
|
2009
|
+
node_1 = corres1[m._left]
|
2010
|
+
node_2 = corres2[m._right]
|
2011
|
+
|
2012
|
+
if (
|
2013
|
+
self.get_chain_of_node(node_1)[0] == node_1
|
2014
|
+
or self.get_chain_of_node(node_2)[0] == node_2
|
2015
|
+
and (node_1 not in colors or node_2 not in colors)
|
2016
|
+
):
|
2017
|
+
matched_left.append(node_1)
|
2018
|
+
l_node_1 = self.get_chain_of_node(node_1)[-1]
|
2019
|
+
matched_left.append(l_node_1)
|
2020
|
+
matched_right.append(node_2)
|
2021
|
+
l_node_2 = self.get_chain_of_node(node_2)[-1]
|
2022
|
+
matched_right.append(l_node_2)
|
2023
|
+
colors[node_1] = self.__calculate_distance_of_sub_tree(
|
2024
|
+
node_1,
|
2025
|
+
node_2,
|
2026
|
+
btrc,
|
2027
|
+
corres1,
|
2028
|
+
corres2,
|
2029
|
+
delta_tmp,
|
2030
|
+
self.norm_dict[norm],
|
2031
|
+
tree1.get_norm(node_1),
|
2032
|
+
tree2.get_norm(node_2),
|
2033
|
+
)
|
2034
|
+
colors[l_node_1] = colors[node_1]
|
2035
|
+
colors[node_2] = colors[node_1]
|
2036
|
+
colors[l_node_2] = colors[node_1]
|
2037
|
+
if ax is None:
|
2038
|
+
fig, ax = plt.subplots(nrows=1, ncols=2, sharey=True)
|
2039
|
+
cmap = colormaps[colormap]
|
2040
|
+
c_norm = mcolors.Normalize(0, 1)
|
2041
|
+
colors = {c: cmap(c_norm(v)) for c, v in colors.items()}
|
2042
|
+
self.plot_subtree(
|
2043
|
+
self.get_ancestor_at_t(n1),
|
2044
|
+
end_time=end_time,
|
2045
|
+
size=size,
|
2046
|
+
selected_nodes=matched_left,
|
2047
|
+
color_of_nodes=colors,
|
2048
|
+
selected_edges=matched_left,
|
2049
|
+
color_of_edges=colors,
|
2050
|
+
default_color=default_color,
|
2051
|
+
lw=lw,
|
2052
|
+
ax=ax[0],
|
2053
|
+
)
|
2054
|
+
self.plot_subtree(
|
2055
|
+
self.get_ancestor_at_t(n2),
|
2056
|
+
end_time=end_time,
|
2057
|
+
size=size,
|
2058
|
+
selected_nodes=matched_right,
|
2059
|
+
color_of_nodes=colors,
|
2060
|
+
selected_edges=matched_right,
|
2061
|
+
color_of_edges=colors,
|
2062
|
+
default_color=default_color,
|
2063
|
+
lw=lw,
|
2064
|
+
ax=ax[1],
|
2065
|
+
)
|
2066
|
+
return ax[0].get_figure(), ax
|
2067
|
+
|
2068
|
+
def labelled_mappings(
|
2069
|
+
self,
|
2070
|
+
n1: int,
|
2071
|
+
n2: int,
|
2072
|
+
end_time: int | None = None,
|
2073
|
+
norm: Literal["max", "sum", None] = "max",
|
2074
|
+
style: (
|
2075
|
+
Literal["simple", "normalized_simple", "full", "downsampled"]
|
2076
|
+
| type[TreeApproximationTemplate]
|
2077
|
+
) = "simple",
|
2078
|
+
downsample: int = 2,
|
2079
|
+
) -> dict[str, list[str]]:
|
2080
|
+
"""
|
2081
|
+
Returns the labels or IDs of all the nodes in the subtrees compared.
|
2082
|
+
|
2083
|
+
|
2084
|
+
Parameters
|
2085
|
+
----------
|
2086
|
+
n1 : int
|
2087
|
+
id of the first node to compare
|
2088
|
+
n2 : int
|
2089
|
+
id of the second node to compare
|
2090
|
+
end_time : int, optional
|
2091
|
+
The final time point the comparison algorithm will take into account.
|
2092
|
+
If None or not provided all nodes will be taken into account.
|
2093
|
+
norm : {"max", "sum"}, default="max"
|
2094
|
+
The normalization method to use, defaults to 'max'.
|
2095
|
+
style : {"simple", "full", "downsampled", "normalized_simple"} or TreeApproximationTemplate subclass, default="simple"
|
2096
|
+
Which tree approximation is going to be used for the comparisons, defaults to 'simple'.
|
2097
|
+
downsample : int, default=2
|
2098
|
+
The downsample factor for the downsampled tree approximation.
|
2099
|
+
Used only when `style="downsampled"`.
|
2100
|
+
|
2101
|
+
Returns
|
2102
|
+
-------
|
2103
|
+
dict mapping str to list of str
|
2104
|
+
- 'matched' The labels of the matched nodes of the alignment.
|
2105
|
+
- 'unmatched' The labels of the unmatched nodes of the alignment.
|
2106
|
+
"""
|
2107
|
+
parameters = (
|
2108
|
+
end_time,
|
2109
|
+
convert_style_to_number(style=style, downsample=downsample),
|
2110
|
+
)
|
2111
|
+
n1, n2 = sorted([n1, n2])
|
2112
|
+
self._comparisons.setdefault(parameters, {})
|
2113
|
+
if self._comparisons[parameters].get((n1, n2)):
|
2114
|
+
tmp = self._comparisons[parameters][(n1, n2)]
|
2115
|
+
else:
|
2116
|
+
tmp = self.__unordereded_backtrace(
|
2117
|
+
n1, n2, end_time, norm, style, downsample
|
2118
|
+
)
|
2119
|
+
btrc = tmp["alignment"]
|
2120
|
+
tree1, tree2 = tmp["trees"]
|
2121
|
+
|
2122
|
+
(
|
2123
|
+
*_,
|
2124
|
+
corres1,
|
2125
|
+
) = tree1.edist
|
2126
|
+
(
|
2127
|
+
*_,
|
2128
|
+
corres2,
|
2129
|
+
) = tree2.edist
|
2130
|
+
|
2131
|
+
if norm not in self.norm_dict:
|
1678
2132
|
raise Warning(
|
1679
2133
|
"Select a viable normalization method (max, sum, None)"
|
1680
2134
|
)
|
1681
|
-
|
1682
|
-
|
1683
|
-
|
2135
|
+
matched = []
|
2136
|
+
unmatched = []
|
2137
|
+
if style not in ("full", "downsampled"):
|
2138
|
+
for m in btrc:
|
2139
|
+
if m._left != -1 and m._right != -1:
|
2140
|
+
cyc1 = self.get_chain_of_node(corres1[m._left])
|
2141
|
+
if len(cyc1) > 1:
|
2142
|
+
node_1, *_ = cyc1
|
2143
|
+
elif len(cyc1) == 1:
|
2144
|
+
node_1 = cyc1.pop()
|
2145
|
+
cyc2 = self.get_chain_of_node(corres2[m._right])
|
2146
|
+
if len(cyc2) > 1:
|
2147
|
+
node_2, *_ = cyc2
|
2148
|
+
elif len(cyc2) == 1:
|
2149
|
+
node_2 = cyc2.pop()
|
2150
|
+
matched.append(
|
2151
|
+
(
|
2152
|
+
self.labels.get(node_1, node_1),
|
2153
|
+
self.labels.get(node_2, node_2),
|
2154
|
+
)
|
2155
|
+
)
|
2156
|
+
|
2157
|
+
else:
|
2158
|
+
if m._left != -1:
|
2159
|
+
node_1 = self.get_chain_of_node(
|
2160
|
+
corres1.get(m._left, "-")
|
2161
|
+
)[0]
|
2162
|
+
else:
|
2163
|
+
node_1 = self.get_chain_of_node(
|
2164
|
+
corres2.get(m._right, "-")
|
2165
|
+
)[0]
|
2166
|
+
unmatched.append(self.labels.get(node_1, node_1))
|
2167
|
+
else:
|
2168
|
+
for m in btrc:
|
2169
|
+
if m._left != -1 and m._right != -1:
|
2170
|
+
node_1 = corres1[m._left]
|
2171
|
+
node_2 = corres2[m._right]
|
2172
|
+
matched.append(
|
2173
|
+
(
|
2174
|
+
self.labels.get(node_1, node_1),
|
2175
|
+
self.labels.get(node_2, node_2),
|
2176
|
+
)
|
2177
|
+
)
|
2178
|
+
else:
|
2179
|
+
if m._left != -1:
|
2180
|
+
node_1 = corres1[m._left]
|
2181
|
+
else:
|
2182
|
+
node_1 = corres2[m._right]
|
2183
|
+
unmatched.append(self.labels.get(node_1, node_1))
|
2184
|
+
return {"matched": matched, "unmatched": unmatched}
|
2185
|
+
|
2186
|
+
def unordered_tree_edit_distance(
|
2187
|
+
self,
|
2188
|
+
n1: int,
|
2189
|
+
n2: int,
|
2190
|
+
end_time: int | None = None,
|
2191
|
+
norm: Literal["max", "sum", None] = "max",
|
2192
|
+
style: (
|
2193
|
+
Literal["simple", "normalized_simple", "full", "downsampled"]
|
2194
|
+
| type[TreeApproximationTemplate]
|
2195
|
+
) = "simple",
|
2196
|
+
downsample: int = 2,
|
2197
|
+
return_norms: bool = False,
|
2198
|
+
) -> float | tuple[float, tuple[float, float]]:
|
2199
|
+
"""
|
2200
|
+
Compute the unordered tree edit distance from Zhang 1996 between the trees spawned
|
2201
|
+
by two nodes `n1` and `n2`. The topology of the trees is compared and the matching
|
2202
|
+
cost is given by the function delta (see edist doc for more information).
|
2203
|
+
The distance is normed by the function norm that takes the two lists of nodes
|
2204
|
+
spawned by the trees `n1` and `n2`.
|
2205
|
+
|
2206
|
+
Parameters
|
2207
|
+
----------
|
2208
|
+
n1 : int
|
2209
|
+
id of the first node to compare
|
2210
|
+
n2 : int
|
2211
|
+
id of the second node to compare
|
2212
|
+
end_time : int, optional
|
2213
|
+
The final time point the comparison algorithm will take into account.
|
2214
|
+
If None or not provided all nodes will be taken into account.
|
2215
|
+
norm : {"max", "sum"}, default="max"
|
2216
|
+
The normalization method to use, defaults to 'max'.
|
2217
|
+
style : {"simple", "normalized_simple", "full", "downsampled"} or TreeApproximationTemplate subclass, default="simple"
|
2218
|
+
Which tree approximation is going to be used for the comparisons.
|
2219
|
+
downsample : int, default=2
|
2220
|
+
The downsample factor for the downsampled tree approximation.
|
2221
|
+
Used only when `style="downsampled"`.
|
2222
|
+
|
2223
|
+
Returns
|
2224
|
+
-------
|
2225
|
+
float
|
2226
|
+
The normalized unordered tree edit distance between `n1` and `n2`
|
2227
|
+
"""
|
2228
|
+
parameters = (
|
2229
|
+
end_time,
|
2230
|
+
convert_style_to_number(style=style, downsample=downsample),
|
2231
|
+
)
|
2232
|
+
n1, n2 = sorted([n1, n2])
|
2233
|
+
self._comparisons.setdefault(parameters, {})
|
2234
|
+
if self._comparisons[parameters].get((n1, n2)):
|
2235
|
+
tmp = self._comparisons[parameters][(n1, n2)]
|
2236
|
+
else:
|
2237
|
+
tmp = self.__unordereded_backtrace(
|
2238
|
+
n1, n2, end_time, norm, style, downsample
|
2239
|
+
)
|
2240
|
+
btrc = tmp["alignment"]
|
2241
|
+
tree1, tree2 = tmp["trees"]
|
2242
|
+
_, times1 = tree1.tree
|
2243
|
+
_, times2 = tree2.tree
|
2244
|
+
(
|
2245
|
+
nodes1,
|
2246
|
+
adj1,
|
2247
|
+
corres1,
|
2248
|
+
) = tree1.edist
|
2249
|
+
(
|
2250
|
+
nodes2,
|
2251
|
+
adj2,
|
2252
|
+
corres2,
|
2253
|
+
) = tree2.edist
|
2254
|
+
delta_tmp = partial(
|
2255
|
+
tree1.delta,
|
2256
|
+
corres1=corres1,
|
2257
|
+
corres2=corres2,
|
2258
|
+
times1=times1,
|
2259
|
+
times2=times2,
|
2260
|
+
)
|
2261
|
+
|
2262
|
+
if norm not in self.norm_dict:
|
2263
|
+
raise ValueError(
|
2264
|
+
"Select a viable normalization method (max, sum, None)"
|
2265
|
+
)
|
2266
|
+
cost = btrc.cost(nodes1, nodes2, delta_tmp)
|
2267
|
+
norm_values = (tree1.get_norm(n1), tree2.get_norm(n2))
|
2268
|
+
if return_norms:
|
2269
|
+
return cost, norm_values
|
2270
|
+
return cost / self.norm_dict[norm](norm_values)
|
1684
2271
|
|
1685
2272
|
@staticmethod
|
1686
2273
|
def __plot_nodes(
|
1687
|
-
hier
|
1688
|
-
|
2274
|
+
hier: dict,
|
2275
|
+
selected_nodes: set,
|
2276
|
+
color: str | dict | list,
|
2277
|
+
size: int | float,
|
2278
|
+
ax: plt.Axes,
|
2279
|
+
default_color: str = "black",
|
2280
|
+
**kwargs,
|
2281
|
+
) -> None:
|
1689
2282
|
"""
|
1690
2283
|
Private method that plots the nodes of the tree.
|
1691
2284
|
"""
|
1692
|
-
|
1693
|
-
|
1694
|
-
|
1695
|
-
|
1696
|
-
|
1697
|
-
|
1698
|
-
|
1699
|
-
|
1700
|
-
|
1701
|
-
|
1702
|
-
)
|
1703
|
-
if selected_nodes.intersection(hier.keys()):
|
1704
|
-
hier_selected = np.array(
|
1705
|
-
[v for k, v in hier.items() if k in selected_nodes]
|
1706
|
-
)
|
1707
|
-
ax.scatter(
|
1708
|
-
*hier_selected.T, s=size, zorder=10, color=color, **kwargs
|
1709
|
-
)
|
2285
|
+
|
2286
|
+
if isinstance(color, dict):
|
2287
|
+
color = [color.get(k, default_color) for k in hier]
|
2288
|
+
elif isinstance(color, str | list):
|
2289
|
+
color = [
|
2290
|
+
color if node in selected_nodes else default_color
|
2291
|
+
for node in hier
|
2292
|
+
]
|
2293
|
+
hier_pos = np.array(list(hier.values()))
|
2294
|
+
ax.scatter(*hier_pos.T, s=size, zorder=10, color=color, **kwargs)
|
1710
2295
|
|
1711
2296
|
@staticmethod
|
1712
2297
|
def __plot_edges(
|
1713
|
-
hier,
|
1714
|
-
lnks_tms,
|
1715
|
-
selected_edges,
|
1716
|
-
color,
|
1717
|
-
|
1718
|
-
|
2298
|
+
hier: dict,
|
2299
|
+
lnks_tms: dict,
|
2300
|
+
selected_edges: Iterable,
|
2301
|
+
color: str | dict | list,
|
2302
|
+
lw: float,
|
2303
|
+
ax: plt.Axes,
|
2304
|
+
default_color: str = "black",
|
1719
2305
|
**kwargs,
|
1720
|
-
):
|
2306
|
+
) -> None:
|
1721
2307
|
"""
|
1722
2308
|
Private method that plots the edges of the tree.
|
1723
2309
|
"""
|
1724
|
-
|
1725
|
-
|
1726
|
-
|
1727
|
-
|
1728
|
-
x.extend((hier[succ][0], hier[pred][0], None))
|
1729
|
-
y.extend((hier[succ][1], hier[pred][1], None))
|
1730
|
-
ax.plot(x, y, linewidth=0.3, zorder=0.1, c=default_color, **kwargs)
|
1731
|
-
x, y = [], []
|
2310
|
+
if isinstance(color, dict):
|
2311
|
+
selected_edges = color.keys()
|
2312
|
+
lines = []
|
2313
|
+
c = []
|
1732
2314
|
for pred, succs in lnks_tms["links"].items():
|
1733
|
-
for
|
1734
|
-
|
1735
|
-
|
1736
|
-
|
1737
|
-
|
2315
|
+
for suc in succs:
|
2316
|
+
lines.append(
|
2317
|
+
[
|
2318
|
+
[hier[suc][0], hier[suc][1]],
|
2319
|
+
[hier[pred][0], hier[pred][1]],
|
2320
|
+
]
|
2321
|
+
)
|
2322
|
+
if pred in selected_edges:
|
2323
|
+
if isinstance(color, str | list):
|
2324
|
+
c.append(color)
|
2325
|
+
elif isinstance(color, dict):
|
2326
|
+
c.append(color[pred])
|
2327
|
+
else:
|
2328
|
+
c.append(default_color)
|
2329
|
+
lc = LineCollection(lines, colors=c, linewidth=lw, **kwargs)
|
2330
|
+
ax.add_collection(lc)
|
1738
2331
|
|
1739
2332
|
def draw_tree_graph(
|
1740
2333
|
self,
|
1741
|
-
hier,
|
1742
|
-
lnks_tms,
|
1743
|
-
selected_nodes=None,
|
1744
|
-
selected_edges=None,
|
1745
|
-
color_of_nodes="magenta",
|
1746
|
-
color_of_edges=
|
1747
|
-
size=10,
|
1748
|
-
|
1749
|
-
|
2334
|
+
hier: dict[int, tuple[int, int]],
|
2335
|
+
lnks_tms: dict[str, dict[int, list | int]],
|
2336
|
+
selected_nodes: list | set | None = None,
|
2337
|
+
selected_edges: list | set | None = None,
|
2338
|
+
color_of_nodes: str | dict = "magenta",
|
2339
|
+
color_of_edges: str | dict = "magenta",
|
2340
|
+
size: int | float = 10,
|
2341
|
+
lw: float = 0.3,
|
2342
|
+
ax: plt.Axes | None = None,
|
2343
|
+
default_color: str = "black",
|
1750
2344
|
**kwargs,
|
1751
|
-
):
|
2345
|
+
) -> tuple[plt.Figure, plt.Axes]:
|
1752
2346
|
"""Function to plot the tree graph.
|
1753
2347
|
|
1754
|
-
|
1755
|
-
|
1756
|
-
|
1757
|
-
|
1758
|
-
|
1759
|
-
|
1760
|
-
|
1761
|
-
|
1762
|
-
|
1763
|
-
|
1764
|
-
|
1765
|
-
|
1766
|
-
|
1767
|
-
|
1768
|
-
|
1769
|
-
|
2348
|
+
Parameters
|
2349
|
+
----------
|
2350
|
+
hier : dict mapping int to tuple of int
|
2351
|
+
Dictionary that contains the positions of all nodes.
|
2352
|
+
lnks_tms : dict mapping string to dictionaries mapping int to list or int
|
2353
|
+
- 'links' : contains the hierarchy of the nodes (only start and end of each chain)
|
2354
|
+
- 'times' : contains the distance between the start and the end of each chain.
|
2355
|
+
selected_nodes : list or set, optional
|
2356
|
+
Which nodes are to be selected (Painted with a different color, according to 'color_of_nodes')
|
2357
|
+
selected_edges : list or set, optional
|
2358
|
+
Which edges are to be selected (Painted with a different color, according to 'color_of_edges')
|
2359
|
+
color_of_nodes : str, default="magenta"
|
2360
|
+
Color of selected nodes
|
2361
|
+
color_of_edges : str, default="magenta"
|
2362
|
+
Color of selected edges
|
2363
|
+
size : int, default=10
|
2364
|
+
Size of the nodes, defaults to 10
|
2365
|
+
lw : float, default=0.3
|
2366
|
+
The width of the edges of the tree graph, defaults to 0.3
|
2367
|
+
ax : plt.Axes, optional
|
2368
|
+
Plot the graph on existing ax. If not provided or None a new ax is going to be created.
|
2369
|
+
default_color : str, default="black"
|
2370
|
+
Default color of the nodes and edges that are not selected
|
2371
|
+
|
2372
|
+
Returns
|
2373
|
+
-------
|
2374
|
+
plt.Figure
|
2375
|
+
The matplotlib figure
|
2376
|
+
plt.Axes
|
2377
|
+
The matplotlib ax
|
1770
2378
|
"""
|
1771
2379
|
if selected_nodes is None:
|
1772
2380
|
selected_nodes = []
|
1773
2381
|
if selected_edges is None:
|
1774
2382
|
selected_edges = []
|
1775
2383
|
if ax is None:
|
1776
|
-
|
2384
|
+
_, ax = plt.subplots()
|
1777
2385
|
else:
|
1778
2386
|
ax.clear()
|
1779
2387
|
if not isinstance(selected_nodes, set):
|
1780
2388
|
selected_nodes = set(selected_nodes)
|
1781
2389
|
if not isinstance(selected_edges, set):
|
1782
2390
|
selected_edges = set(selected_edges)
|
1783
|
-
|
1784
|
-
|
1785
|
-
|
1786
|
-
|
1787
|
-
|
1788
|
-
|
1789
|
-
|
1790
|
-
|
1791
|
-
|
2391
|
+
if 0 < size:
|
2392
|
+
self.__plot_nodes(
|
2393
|
+
hier,
|
2394
|
+
selected_nodes,
|
2395
|
+
color_of_nodes,
|
2396
|
+
size=size,
|
2397
|
+
ax=ax,
|
2398
|
+
default_color=default_color,
|
2399
|
+
**kwargs,
|
2400
|
+
)
|
1792
2401
|
if not color_of_edges:
|
1793
2402
|
color_of_edges = color_of_nodes
|
1794
2403
|
self.__plot_edges(
|
@@ -1796,62 +2405,107 @@ class lineageTree(lineageTreeLoaders):
|
|
1796
2405
|
lnks_tms,
|
1797
2406
|
selected_edges,
|
1798
2407
|
color_of_edges,
|
2408
|
+
lw,
|
1799
2409
|
ax,
|
1800
2410
|
default_color=default_color,
|
1801
2411
|
**kwargs,
|
1802
2412
|
)
|
2413
|
+
ax.autoscale()
|
2414
|
+
plt.draw()
|
1803
2415
|
ax.get_yaxis().set_visible(False)
|
1804
2416
|
ax.get_xaxis().set_visible(False)
|
1805
2417
|
return ax.get_figure(), ax
|
1806
2418
|
|
1807
|
-
def
|
2419
|
+
def _create_dict_of_plots(
|
2420
|
+
self,
|
2421
|
+
node: int | Iterable[int] | None = None,
|
2422
|
+
start_time: int | None = None,
|
2423
|
+
end_time: int | None = None,
|
2424
|
+
) -> dict[int, dict]:
|
1808
2425
|
"""Generates a dictionary of graphs where the keys are the index of the graph and
|
1809
|
-
the values are the graphs themselves which are produced by
|
1810
|
-
|
1811
|
-
|
1812
|
-
|
1813
|
-
|
1814
|
-
|
1815
|
-
|
1816
|
-
|
1817
|
-
|
2426
|
+
the values are the graphs themselves which are produced by `create_links_and_chains`
|
2427
|
+
|
2428
|
+
Parameters
|
2429
|
+
----------
|
2430
|
+
node : int or Iterable of int, optional
|
2431
|
+
The id of the node/nodes to produce the simple graphs, if not provided or None will
|
2432
|
+
calculate the dicts for every root that starts before 'start_time'
|
2433
|
+
start_time : int, optional
|
2434
|
+
Important only if there are no nodes it will produce the graph of every
|
2435
|
+
root that starts before or at start time. If not provided or None the 'start_time' defaults to the start of the dataset.
|
2436
|
+
end_time : int, optional
|
2437
|
+
The last timepoint to be considered, if not provided or None the last timepoint of the
|
2438
|
+
dataset (t_e) is considered.
|
2439
|
+
|
2440
|
+
Returns
|
2441
|
+
-------
|
2442
|
+
dict mapping int to dict
|
2443
|
+
The keys are just index values 0-n and the values are the graphs produced.
|
1818
2444
|
"""
|
1819
2445
|
if start_time is None:
|
1820
2446
|
start_time = self.t_b
|
2447
|
+
if end_time is None:
|
2448
|
+
end_time = self.t_e
|
1821
2449
|
if node is None:
|
1822
2450
|
mothers = [
|
1823
|
-
root for root in self.roots if self.
|
2451
|
+
root for root in self.roots if self._time[root] <= start_time
|
1824
2452
|
]
|
2453
|
+
elif isinstance(node, Iterable):
|
2454
|
+
mothers = node
|
1825
2455
|
else:
|
1826
|
-
mothers =
|
2456
|
+
mothers = [node]
|
1827
2457
|
return {
|
1828
|
-
i:
|
2458
|
+
i: create_links_and_chains(self, mother, end_time=end_time)
|
1829
2459
|
for i, mother in enumerate(mothers)
|
1830
2460
|
}
|
1831
2461
|
|
1832
2462
|
def plot_all_lineages(
|
1833
2463
|
self,
|
1834
|
-
nodes: list = None,
|
1835
|
-
last_time_point_to_consider: int = None,
|
1836
|
-
nrows=2,
|
1837
|
-
figsize=(10, 15),
|
1838
|
-
dpi=100,
|
1839
|
-
fontsize=15,
|
1840
|
-
axes=None,
|
1841
|
-
vert_gap=1,
|
2464
|
+
nodes: list | None = None,
|
2465
|
+
last_time_point_to_consider: int | None = None,
|
2466
|
+
nrows: int = 2,
|
2467
|
+
figsize: tuple[int, int] = (10, 15),
|
2468
|
+
dpi: int = 100,
|
2469
|
+
fontsize: int = 15,
|
2470
|
+
axes: plt.Axes | None = None,
|
2471
|
+
vert_gap: int = 1,
|
1842
2472
|
**kwargs,
|
1843
|
-
):
|
2473
|
+
) -> tuple[plt.Figure, plt.Axes, dict[plt.Axes, int]]:
|
1844
2474
|
"""Plots all lineages.
|
1845
2475
|
|
1846
|
-
|
1847
|
-
|
1848
|
-
|
1849
|
-
|
1850
|
-
|
1851
|
-
|
1852
|
-
|
2476
|
+
Parameters
|
2477
|
+
----------
|
2478
|
+
nodes : list, optional
|
2479
|
+
The nodes spawning the graphs to be plotted.
|
2480
|
+
last_time_point_to_consider : int, optional
|
2481
|
+
Used only when `nodes` is not provided: the trees whose root starts at this timepoint or earlier are plotted.
|
2482
|
+
For example if `last_time_point_to_consider` is 10, then all trees that begin
|
2483
|
+
on tp 10 or before are calculated. Defaults to None, where
|
2484
|
+
it will plot all the roots that exist on `self.t_b`.
|
2485
|
+
nrows : int, default=2
|
2486
|
+
How many rows of plots should be printed.
|
2487
|
+
figsize : tuple, default=(10, 15)
|
2488
|
+
The size of the figure.
|
2489
|
+
dpi : int, default=100
|
2490
|
+
The dpi of the figure.
|
2491
|
+
fontsize : int, default=15
|
2492
|
+
The fontsize of the labels.
|
2493
|
+
axes : plt.Axes, optional
|
2494
|
+
The axes to plot the graphs on. If None or not provided new axes are going to be created.
|
2495
|
+
vert_gap : int, default=1
|
2496
|
+
space between the nodes, defaults to 1
|
2497
|
+
**kwargs:
|
2498
|
+
kwargs accepted by matplotlib.pyplot.plot, matplotlib.pyplot.scatter
|
2499
|
+
|
2500
|
+
Returns
|
2501
|
+
-------
|
2502
|
+
plt.Figure
|
2503
|
+
The figure
|
2504
|
+
plt.Axes
|
2505
|
+
The axes
|
2506
|
+
dict of plt.Axes to int
|
2507
|
+
A dictionary that maps the axes to the root of the tree.
|
1853
2508
|
"""
|
1854
|
-
|
1855
2509
|
nrows = int(nrows)
|
1856
2510
|
if last_time_point_to_consider is None:
|
1857
2511
|
last_time_point_to_consider = self.t_b
|
@@ -1860,17 +2514,18 @@ class lineageTree(lineageTreeLoaders):
|
|
1860
2514
|
raise Warning("Number of rows has to be at least 1")
|
1861
2515
|
if nodes:
|
1862
2516
|
graphs = {
|
1863
|
-
i: self.
|
2517
|
+
i: self._create_dict_of_plots(node)
|
2518
|
+
for i, node in enumerate(nodes)
|
1864
2519
|
}
|
1865
2520
|
else:
|
1866
|
-
graphs = self.
|
2521
|
+
graphs = self._create_dict_of_plots(
|
1867
2522
|
start_time=last_time_point_to_consider
|
1868
2523
|
)
|
1869
2524
|
pos = {
|
1870
2525
|
i: hierarchical_pos(
|
1871
2526
|
g,
|
1872
2527
|
g["root"],
|
1873
|
-
ycenter=-int(self.
|
2528
|
+
ycenter=-int(self._time[g["root"]]),
|
1874
2529
|
vert_gap=vert_gap,
|
1875
2530
|
)
|
1876
2531
|
for i, g in graphs.items()
|
@@ -1925,135 +2580,147 @@ class lineageTree(lineageTreeLoaders):
|
|
1925
2580
|
[figure.delaxes(ax) for ax in axes.flatten() if not ax.has_data()]
|
1926
2581
|
return axes.flatten()[0].get_figure(), axes, ax2root
|
1927
2582
|
|
1928
|
-
def
|
2583
|
+
def plot_subtree(
|
1929
2584
|
self,
|
1930
|
-
node,
|
1931
|
-
|
1932
|
-
|
1933
|
-
|
1934
|
-
|
1935
|
-
|
1936
|
-
|
2585
|
+
node: int,
|
2586
|
+
end_time: int | None = None,
|
2587
|
+
figsize: tuple[int, int] = (4, 7),
|
2588
|
+
dpi: int = 150,
|
2589
|
+
vert_gap: int = 2,
|
2590
|
+
selected_nodes: list | None = None,
|
2591
|
+
selected_edges: list | None = None,
|
2592
|
+
color_of_nodes: str | dict = "magenta",
|
2593
|
+
color_of_edges: str | dict = "magenta",
|
2594
|
+
size: int | float = 10,
|
2595
|
+
lw: float = 0.1,
|
2596
|
+
default_color: str = "black",
|
2597
|
+
ax: plt.Axes | None = None,
|
2598
|
+
) -> tuple[plt.Figure, plt.Axes]:
|
1937
2599
|
"""Plots the subtree spawn by a node.
|
1938
2600
|
|
1939
|
-
|
1940
|
-
|
1941
|
-
|
1942
|
-
|
1943
|
-
|
2601
|
+
Parameters
|
2602
|
+
----------
|
2603
|
+
node : int
|
2604
|
+
The id of the node that is going to be plotted.
|
2605
|
+
end_time : int, optional
|
2606
|
+
The last timepoint to be considered, if None or not provided the last timepoint of the dataset (t_e) is considered.
|
2607
|
+
figsize : tuple of 2 ints, default=(4,7)
|
2608
|
+
The size of the figure, defaults to (4,7)
|
2609
|
+
vert_gap : int, default=2
|
2610
|
+
The vertical gap of a node when it divides, defaults to 2.
|
2611
|
+
dpi : int, default=150
|
2612
|
+
The dpi of the figure, defaults to 150
|
2613
|
+
selected_nodes : list, optional
|
2614
|
+
The nodes that are selected by the user to be colored in a different color, defaults to None
|
2615
|
+
selected_edges : list, optional
|
2616
|
+
The edges that are selected by the user to be colored in a different color, defaults to None
|
2617
|
+
color_of_nodes : str, default="magenta"
|
2618
|
+
The color of the nodes to be colored, except the default colored ones, defaults to "magenta"
|
2619
|
+
color_of_edges : str, default="magenta"
|
2620
|
+
The color of the edges to be colored, except the default colored ones, defaults to "magenta"
|
2621
|
+
size : int, default=10
|
2622
|
+
The size of the nodes, defaults to 10
|
2623
|
+
lw : float, default=0.1
|
2624
|
+
The width of the edges of the tree graph, defaults to 0.1
|
2625
|
+
default_color : str, default="black"
|
2626
|
+
The default color of nodes and edges, defaults to "black"
|
2627
|
+
ax : plt.Axes, optional
|
2628
|
+
The ax where the plot is going to be applied, if not provided or None new axes will be created.
|
2629
|
+
|
2630
|
+
Returns
|
2631
|
+
-------
|
2632
|
+
plt.Figure
|
2633
|
+
The matplotlib figure
|
2634
|
+
plt.Axes
|
2635
|
+
The matplotlib axes
|
2636
|
+
|
2637
|
+
Raises
|
2638
|
+
------
|
2639
|
+
Warning
|
2640
|
+
If more than one node is received
|
2641
|
+
"""
|
2642
|
+
graph = self._create_dict_of_plots(node, end_time=end_time)
|
1944
2643
|
if len(graph) > 1:
|
1945
|
-
raise Warning(
|
2644
|
+
raise Warning(
|
2645
|
+
"Please use lT.plot_all_lineages(nodes) for plotting multiple nodes."
|
2646
|
+
)
|
1946
2647
|
graph = graph[0]
|
1947
2648
|
if not ax:
|
1948
|
-
|
1949
|
-
nrows=1, ncols=1, figsize=figsize, dpi=dpi
|
1950
|
-
)
|
2649
|
+
_, ax = plt.subplots(nrows=1, ncols=1, figsize=figsize, dpi=dpi)
|
1951
2650
|
self.draw_tree_graph(
|
1952
2651
|
hier=hierarchical_pos(
|
1953
2652
|
graph,
|
1954
2653
|
graph["root"],
|
1955
2654
|
vert_gap=vert_gap,
|
1956
|
-
ycenter=-int(self.
|
2655
|
+
ycenter=-int(self._time[node]),
|
1957
2656
|
),
|
2657
|
+
selected_edges=selected_edges,
|
2658
|
+
selected_nodes=selected_nodes,
|
2659
|
+
color_of_edges=color_of_edges,
|
2660
|
+
color_of_nodes=color_of_nodes,
|
2661
|
+
default_color=default_color,
|
2662
|
+
size=size,
|
2663
|
+
lw=lw,
|
1958
2664
|
lnks_tms=graph,
|
1959
2665
|
ax=ax,
|
1960
2666
|
)
|
1961
2667
|
return ax.get_figure(), ax
|
1962
2668
|
|
1963
|
-
|
1964
|
-
|
1965
|
-
|
1966
|
-
|
1967
|
-
|
1968
|
-
# t1 ([int, ]): list of node ids for the first track
|
1969
|
-
# t2 ([int, ]): list of node ids for the second track
|
1970
|
-
# w (int): maximum wapring allowed (default infinite),
|
1971
|
-
# if w=1 then the DTW is the distance between t1 and t2
|
1972
|
-
# start_delay (int): maximum number of time points that can be
|
1973
|
-
# skipped at the beginning of the track
|
1974
|
-
# end_delay (int): minimum number of time points that can be
|
1975
|
-
# skipped at the beginning of the track
|
1976
|
-
# metric (str): str or callable, optional The distance metric to use.
|
1977
|
-
# Default='euclidean'. Refer to the documentation for
|
1978
|
-
# scipy.spatial.distance.cdist. Some examples:
|
1979
|
-
# 'braycurtis', 'canberra', 'chebyshev', 'cityblock', 'correlation',
|
1980
|
-
# 'cosine', 'dice', 'euclidean', 'hamming', 'jaccard', 'kulsinski',
|
1981
|
-
# 'mahalanobis', 'matching', 'minkowski', 'rogerstanimoto', 'russellrao',
|
1982
|
-
# 'seuclidean', 'sokalmichener', 'sokalsneath', 'sqeuclidean',
|
1983
|
-
# 'wminkowski', 'yule'
|
1984
|
-
# **kwargs (dict): Extra arguments to `metric`: refer to each metric
|
1985
|
-
# documentation in scipy.spatial.distance (optional)
|
1986
|
-
|
1987
|
-
# Returns:
|
1988
|
-
# float: the dynamic time warping distance between the two tracks
|
1989
|
-
# """
|
1990
|
-
# from scipy.sparse import
|
1991
|
-
# pos_t1 = [self.pos[ti] for ti in t1]
|
1992
|
-
# pos_t2 = [self.pos[ti] for ti in t2]
|
1993
|
-
# distance_matrix = np.zeros((len(t1), len(t2))) + np.inf
|
1994
|
-
|
1995
|
-
# c = distance.cdist(exp_data, num_data, metric=metric, **kwargs)
|
1996
|
-
|
1997
|
-
# d = np.zeros(c.shape)
|
1998
|
-
# d[0, 0] = c[0, 0]
|
1999
|
-
# n, m = c.shape
|
2000
|
-
# for i in range(1, n):
|
2001
|
-
# d[i, 0] = d[i-1, 0] + c[i, 0]
|
2002
|
-
# for j in range(1, m):
|
2003
|
-
# d[0, j] = d[0, j-1] + c[0, j]
|
2004
|
-
# for i in range(1, n):
|
2005
|
-
# for j in range(1, m):
|
2006
|
-
# d[i, j] = c[i, j] + min((d[i-1, j], d[i, j-1], d[i-1, j-1]))
|
2007
|
-
# return d[-1, -1], d
|
2008
|
-
|
2009
|
-
def __getitem__(self, item):
|
2010
|
-
if isinstance(item, str):
|
2011
|
-
return self.__dict__[item]
|
2012
|
-
elif np.issubdtype(type(item), np.integer):
|
2013
|
-
return self.successor.get(item, [])
|
2014
|
-
else:
|
2015
|
-
raise KeyError(
|
2016
|
-
"Only integer or string are valid key for lineageTree"
|
2017
|
-
)
|
2018
|
-
|
2019
|
-
def get_cells_at_t_from_root(self, r: [int, list], t: int = None) -> list:
|
2669
|
+
def nodes_at_t(
|
2670
|
+
self,
|
2671
|
+
t: int,
|
2672
|
+
r: int | Iterable[int] | None = None,
|
2673
|
+
) -> list:
|
2020
2674
|
"""
|
2021
|
-
Returns the list of
|
2675
|
+
Returns the list of nodes at time `t` that are spawned by the node(s) `r`.
|
2022
2676
|
|
2023
|
-
|
2024
|
-
|
2025
|
-
|
2026
|
-
|
2677
|
+
Parameters
|
2678
|
+
----------
|
2679
|
+
t : int
|
2680
|
+
target time, if `None` goes as far as possible
|
2681
|
+
r : int or Iterable of int, optional
|
2682
|
+
id or list of ids of the spawning node
|
2027
2683
|
|
2028
|
-
|
2029
|
-
|
2684
|
+
Returns
|
2685
|
+
-------
|
2686
|
+
list
|
2687
|
+
list of nodes at time `t` spawned by `r`
|
2030
2688
|
"""
|
2031
|
-
if not
|
2689
|
+
if not r and r != 0:
|
2690
|
+
r = {root for root in self.roots if self.time[root] <= t}
|
2691
|
+
if isinstance(r, int):
|
2032
2692
|
r = [r]
|
2693
|
+
if t is None:
|
2694
|
+
t = self.t_e
|
2033
2695
|
to_do = list(r)
|
2034
2696
|
final_nodes = []
|
2035
2697
|
while len(to_do) > 0:
|
2036
2698
|
curr = to_do.pop()
|
2037
|
-
for _next in self[curr]:
|
2038
|
-
if self.
|
2699
|
+
for _next in self._successor[curr]:
|
2700
|
+
if self._time[_next] < t:
|
2039
2701
|
to_do.append(_next)
|
2040
|
-
elif self.
|
2702
|
+
elif self._time[_next] == t:
|
2041
2703
|
final_nodes.append(_next)
|
2042
2704
|
if not final_nodes:
|
2043
2705
|
return list(r)
|
2044
2706
|
return final_nodes
|
2045
2707
|
|
2046
2708
|
@staticmethod
|
2047
|
-
def __calculate_diag_line(dist_mat: np.ndarray) ->
|
2709
|
+
def __calculate_diag_line(dist_mat: np.ndarray) -> tuple[float, float]:
|
2048
2710
|
"""
|
2049
2711
|
Calculate the line that centers the band w.
|
2050
2712
|
|
2051
|
-
|
2052
|
-
|
2713
|
+
Parameters
|
2714
|
+
----------
|
2715
|
+
dist_mat : np.ndarray
|
2716
|
+
distance matrix obtained by the function calculate_dtw
|
2053
2717
|
|
2054
|
-
|
2055
|
-
|
2056
|
-
|
2718
|
+
Returns
|
2719
|
+
-------
|
2720
|
+
float
|
2721
|
+
The slope of the curve
|
2722
|
+
float
|
2723
|
+
The intercept of the curve
|
2057
2724
|
"""
|
2058
2725
|
i, j = dist_mat.shape
|
2059
2726
|
x1 = max(0, i - j) / 2
|
@@ -2073,22 +2740,33 @@ class lineageTree(lineageTreeLoaders):
|
|
2073
2740
|
fast: bool = False,
|
2074
2741
|
w: int = 0,
|
2075
2742
|
centered_band: bool = True,
|
2076
|
-
) ->
|
2743
|
+
) -> tuple[list[int], np.ndarray, float]:
|
2077
2744
|
"""
|
2078
2745
|
Find DTW minimum cost between two series using dynamic programming.
|
2079
2746
|
|
2080
|
-
|
2081
|
-
|
2082
|
-
|
2083
|
-
|
2084
|
-
|
2085
|
-
|
2086
|
-
|
2087
|
-
|
2088
|
-
|
2089
|
-
|
2090
|
-
|
2091
|
-
|
2747
|
+
Parameters
|
2748
|
+
----------
|
2749
|
+
dist_mat : np.ndarray
|
2750
|
+
distance matrix obtained by the function calculate_dtw
|
2751
|
+
start_d : int, default=0
|
2752
|
+
start delay
|
2753
|
+
back_d : int, default=0
|
2754
|
+
end delay
|
2755
|
+
fast : bool, default=False
|
2756
|
+
if `True`, the algorithm will use a faster version but might not find the optimal alignment
|
2757
|
+
w : int, default=0
|
2758
|
+
window constraint
|
2759
|
+
centered_band : bool, default=True
|
2760
|
+
if `True`, the band will be centered around the diagonal
|
2761
|
+
|
2762
|
+
Returns
|
2763
|
+
-------
|
2764
|
+
tuple of tuples of int
|
2765
|
+
Alignment path
|
2766
|
+
np.ndarray
|
2767
|
+
cost matrix
|
2768
|
+
float
|
2769
|
+
optimal cost
|
2092
2770
|
"""
|
2093
2771
|
N, M = dist_mat.shape
|
2094
2772
|
w_limit = max(w, abs(N - M)) # Calculate the Sakoe-Chiba band width
|
@@ -2209,7 +2887,6 @@ class lineageTree(lineageTreeLoaders):
|
|
2209
2887
|
|
2210
2888
|
# special reflection case
|
2211
2889
|
if np.linalg.det(R) < 0:
|
2212
|
-
# print("det(R) < R, reflection detected!, correcting for it ...")
|
2213
2890
|
Vt[2, :] *= -1
|
2214
2891
|
R = Vt.T @ U.T
|
2215
2892
|
|
@@ -2218,52 +2895,59 @@ class lineageTree(lineageTreeLoaders):
|
|
2218
2895
|
return R, t
|
2219
2896
|
|
2220
2897
|
def __interpolate(
|
2221
|
-
self,
|
2222
|
-
) ->
|
2898
|
+
self, chain1: list, chain2: list, threshold: int
|
2899
|
+
) -> tuple[np.ndarray, np.ndarray]:
|
2223
2900
|
"""
|
2224
2901
|
Interpolate two series that have different lengths
|
2225
2902
|
|
2226
|
-
|
2227
|
-
|
2228
|
-
|
2229
|
-
|
2230
|
-
|
2231
|
-
|
2232
|
-
|
2233
|
-
|
2903
|
+
Parameters
|
2904
|
+
----------
|
2905
|
+
chain1 : list of int
|
2906
|
+
list of nodes of the first chain to compare
|
2907
|
+
chain2 : list of int
|
2908
|
+
list of nodes of the second chain to compare
|
2909
|
+
threshold : int
|
2910
|
+
set a maximum number of points a chain can have
|
2911
|
+
|
2912
|
+
Returns
|
2913
|
+
-------
|
2914
|
+
list of np.ndarray
|
2915
|
+
`x`, `y`, `z` positions for `chain1`
|
2916
|
+
list of np.ndarray
|
2917
|
+
`x`, `y`, `z` positions for `chain2`
|
2234
2918
|
"""
|
2235
2919
|
inter1_pos = []
|
2236
2920
|
inter2_pos = []
|
2237
2921
|
|
2238
|
-
|
2239
|
-
|
2922
|
+
chain1_pos = np.array([self.pos[c_id] for c_id in chain1])
|
2923
|
+
chain2_pos = np.array([self.pos[c_id] for c_id in chain2])
|
2240
2924
|
|
2241
|
-
# Both
|
2242
|
-
if len(
|
2243
|
-
len(
|
2925
|
+
# Both chains have the same length and size below the threshold - nothing is done
|
2926
|
+
if len(chain1) == len(chain2) and (
|
2927
|
+
len(chain1) <= threshold or len(chain2) <= threshold
|
2244
2928
|
):
|
2245
|
-
return
|
2246
|
-
# Both
|
2247
|
-
elif len(
|
2929
|
+
return chain1_pos, chain2_pos
|
2930
|
+
# Both chains have the same length but one or more sizes are above the threshold
|
2931
|
+
elif len(chain1) > threshold or len(chain2) > threshold:
|
2248
2932
|
sampling = threshold
|
2249
|
-
#
|
2933
|
+
# chains have different lengths and the sizes are below the threshold
|
2250
2934
|
else:
|
2251
|
-
sampling = max(len(
|
2935
|
+
sampling = max(len(chain1), len(chain2))
|
2252
2936
|
|
2253
2937
|
for pos in range(3):
|
2254
|
-
|
2255
|
-
np.linspace(0, 1, len(
|
2256
|
-
|
2938
|
+
chain1_interp = InterpolatedUnivariateSpline(
|
2939
|
+
np.linspace(0, 1, len(chain1_pos[:, pos])),
|
2940
|
+
chain1_pos[:, pos],
|
2257
2941
|
k=1,
|
2258
2942
|
)
|
2259
|
-
inter1_pos.append(
|
2943
|
+
inter1_pos.append(chain1_interp(np.linspace(0, 1, sampling)))
|
2260
2944
|
|
2261
|
-
|
2262
|
-
np.linspace(0, 1, len(
|
2263
|
-
|
2945
|
+
chain2_interp = InterpolatedUnivariateSpline(
|
2946
|
+
np.linspace(0, 1, len(chain2_pos[:, pos])),
|
2947
|
+
chain2_pos[:, pos],
|
2264
2948
|
k=1,
|
2265
2949
|
)
|
2266
|
-
inter2_pos.append(
|
2950
|
+
inter2_pos.append(chain2_interp(np.linspace(0, 1, sampling)))
|
2267
2951
|
|
2268
2952
|
return np.column_stack(inter1_pos), np.column_stack(inter2_pos)
|
2269
2953
|
|
@@ -2279,46 +2963,66 @@ class lineageTree(lineageTreeLoaders):
|
|
2279
2963
|
w: int = 0,
|
2280
2964
|
centered_band: bool = True,
|
2281
2965
|
cost_mat_p: bool = False,
|
2282
|
-
) -> (
|
2283
|
-
|
2284
|
-
|
2285
|
-
|
2286
|
-
Args:
|
2287
|
-
nodes1 (int): node to compare distance
|
2288
|
-
nodes2 (int): node to compare distance
|
2289
|
-
threshold: set a maximum number of points a track can have
|
2290
|
-
regist (boolean): Rotate and translate trajectories
|
2291
|
-
start_d (int): start delay
|
2292
|
-
back_d (int): end delay
|
2293
|
-
fast (boolean): True if the user wants to run the fast algorithm with window restrains
|
2294
|
-
w (int): window size
|
2295
|
-
centered_band (boolean): if running the fast algorithm, True if the windown is centered
|
2296
|
-
cost_mat_p (boolean): True if print the not normalized cost matrix
|
2297
|
-
|
2298
|
-
Returns:
|
2299
|
-
(float) DTW distance
|
2300
|
-
(tuple of tuples) Aligment path
|
2301
|
-
(matrix) Cost matrix
|
2302
|
-
(list of lists) pos_cycle1: rotated and translated trajectories positions
|
2303
|
-
(list of lists) pos_cycle2: rotated and translated trajectories positions
|
2966
|
+
) -> (
|
2967
|
+
tuple[float, tuple, np.ndarray, np.ndarray, np.ndarray]
|
2968
|
+
| tuple[float, tuple]
|
2969
|
+
):
|
2304
2970
|
"""
|
2305
|
-
|
2306
|
-
|
2307
|
-
|
2308
|
-
|
2309
|
-
|
2971
|
+
Calculate DTW distance between two chains
|
2972
|
+
|
2973
|
+
Parameters
|
2974
|
+
----------
|
2975
|
+
nodes1 : int
|
2976
|
+
node to compare distance
|
2977
|
+
nodes2 : int
|
2978
|
+
node to compare distance
|
2979
|
+
threshold : int, default=1000
|
2980
|
+
set a maximum number of points a chain can have
|
2981
|
+
regist : bool, default=True
|
2982
|
+
Rotate and translate trajectories
|
2983
|
+
start_d : int, default=0
|
2984
|
+
start delay
|
2985
|
+
back_d : int, default=0
|
2986
|
+
end delay
|
2987
|
+
fast : bool, default=False
|
2988
|
+
if `True`, the algorithm will use a faster version but might not find the optimal alignment
|
2989
|
+
w : int, default=0
|
2990
|
+
window size
|
2991
|
+
centered_band : bool, default=True
|
2992
|
+
when running the fast algorithm, `True` if the window is centered
|
2993
|
+
cost_mat_p : bool, default=False
|
2994
|
+
if `True`, also return the non-normalized cost matrix and the positions of both chains
|
2995
|
+
|
2996
|
+
Returns
|
2997
|
+
-------
|
2998
|
+
float
|
2999
|
+
DTW distance
|
3000
|
+
tuple of tuples
|
3001
|
+
Alignment path
|
3002
|
+
matrix
|
3003
|
+
Cost matrix
|
3004
|
+
list of lists
|
3005
|
+
rotated and translated trajectories positions
|
3006
|
+
list of lists
|
3007
|
+
rotated and translated trajectories positions
|
3008
|
+
"""
|
3009
|
+
nodes1_chain = self.get_chain_of_node(nodes1)
|
3010
|
+
nodes2_chain = self.get_chain_of_node(nodes2)
|
3011
|
+
|
3012
|
+
interp_chain1, interp_chain2 = self.__interpolate(
|
3013
|
+
nodes1_chain, nodes2_chain, threshold
|
2310
3014
|
)
|
2311
3015
|
|
2312
|
-
|
2313
|
-
|
3016
|
+
pos_chain1 = np.array([self.pos[c_id] for c_id in nodes1_chain])
|
3017
|
+
pos_chain2 = np.array([self.pos[c_id] for c_id in nodes2_chain])
|
2314
3018
|
|
2315
3019
|
if regist:
|
2316
3020
|
R, t = self.__rigid_transform_3D(
|
2317
|
-
np.transpose(
|
3021
|
+
np.transpose(interp_chain1), np.transpose(interp_chain2)
|
2318
3022
|
)
|
2319
|
-
|
3023
|
+
pos_chain1 = np.transpose(np.dot(R, pos_chain1.T) + t)
|
2320
3024
|
|
2321
|
-
dist_mat = distance.cdist(
|
3025
|
+
dist_mat = distance.cdist(pos_chain1, pos_chain2, "euclidean")
|
2322
3026
|
|
2323
3027
|
path, cost_mat, final_cost = self.__dp(
|
2324
3028
|
dist_mat,
|
@@ -2331,7 +3035,7 @@ class lineageTree(lineageTreeLoaders):
|
|
2331
3035
|
cost = final_cost / len(path)
|
2332
3036
|
|
2333
3037
|
if cost_mat_p:
|
2334
|
-
return cost, path, cost_mat,
|
3038
|
+
return cost, path, cost_mat, pos_chain1, pos_chain2
|
2335
3039
|
else:
|
2336
3040
|
return cost, path
|
2337
3041
|
|
@@ -2346,24 +3050,39 @@ class lineageTree(lineageTreeLoaders):
|
|
2346
3050
|
fast: bool = False,
|
2347
3051
|
w: int = 0,
|
2348
3052
|
centered_band: bool = True,
|
2349
|
-
) ->
|
2350
|
-
"""
|
2351
|
-
Plot DTW cost matrix between two
|
2352
|
-
|
2353
|
-
|
2354
|
-
|
2355
|
-
|
2356
|
-
|
2357
|
-
|
2358
|
-
|
2359
|
-
|
2360
|
-
|
2361
|
-
|
2362
|
-
|
2363
|
-
|
2364
|
-
|
2365
|
-
|
2366
|
-
|
3053
|
+
) -> tuple[float, plt.Figure]:
|
3054
|
+
"""
|
3055
|
+
Plot DTW cost matrix between two chains in heatmap format
|
3056
|
+
|
3057
|
+
Parameters
|
3058
|
+
----------
|
3059
|
+
nodes1 : int
|
3060
|
+
node to compare distance
|
3061
|
+
nodes2 : int
|
3062
|
+
node to compare distance
|
3063
|
+
threshold : int, default=1000
|
3064
|
+
set a maximum number of points a chain can have
|
3065
|
+
regist : bool, default=True
|
3066
|
+
Rotate and translate trajectories
|
3067
|
+
start_d : int, default=0
|
3068
|
+
start delay
|
3069
|
+
back_d : int, default=0
|
3070
|
+
end delay
|
3071
|
+
fast : bool, default=False
|
3072
|
+
if `True`, the algorithm will use a faster version but might not find the optimal alignment
|
3073
|
+
w : int, default=0
|
3074
|
+
window size
|
3075
|
+
centered_band : bool, default=True
|
3076
|
+
when running the fast algorithm, `True` if the window is centered
|
3077
|
+
|
3078
|
+
Returns
|
3079
|
+
-------
|
3080
|
+
float
|
3081
|
+
DTW distance
|
3082
|
+
plt.Figure
|
3083
|
+
Heatmap of cost matrix with optimal path
|
3084
|
+
"""
|
3085
|
+
cost, path, cost_mat, pos_chain1, pos_chain2 = self.calculate_dtw(
|
2367
3086
|
nodes1,
|
2368
3087
|
nodes2,
|
2369
3088
|
threshold,
|
@@ -2385,32 +3104,32 @@ class lineageTree(lineageTreeLoaders):
|
|
2385
3104
|
ax.set_title("Heatmap of DTW Cost Matrix")
|
2386
3105
|
ax.set_xlabel("Tree 1")
|
2387
3106
|
ax.set_ylabel("tree 2")
|
2388
|
-
x_path, y_path = zip(*path)
|
3107
|
+
x_path, y_path = zip(*path, strict=True)
|
2389
3108
|
ax.plot(y_path, x_path, color="black")
|
2390
3109
|
|
2391
3110
|
return cost, fig
|
2392
3111
|
|
2393
3112
|
@staticmethod
|
2394
3113
|
def __plot_2d(
|
2395
|
-
|
2396
|
-
|
2397
|
-
nodes1,
|
2398
|
-
nodes2,
|
2399
|
-
ax,
|
2400
|
-
x_idx,
|
2401
|
-
y_idx,
|
2402
|
-
x_label,
|
2403
|
-
y_label,
|
2404
|
-
):
|
3114
|
+
pos_chain1: np.ndarray,
|
3115
|
+
pos_chain2: np.ndarray,
|
3116
|
+
nodes1: list[int],
|
3117
|
+
nodes2: list[int],
|
3118
|
+
ax: plt.Axes,
|
3119
|
+
x_idx: list[int],
|
3120
|
+
y_idx: list[int],
|
3121
|
+
x_label: str,
|
3122
|
+
y_label: str,
|
3123
|
+
) -> None:
|
2405
3124
|
ax.plot(
|
2406
|
-
|
2407
|
-
|
3125
|
+
pos_chain1[:, x_idx],
|
3126
|
+
pos_chain1[:, y_idx],
|
2408
3127
|
"-",
|
2409
3128
|
label=f"root = {nodes1}",
|
2410
3129
|
)
|
2411
3130
|
ax.plot(
|
2412
|
-
|
2413
|
-
|
3131
|
+
pos_chain2[:, x_idx],
|
3132
|
+
pos_chain2[:, y_idx],
|
2414
3133
|
"-",
|
2415
3134
|
label=f"root = {nodes2}",
|
2416
3135
|
)
|
@@ -2428,40 +3147,55 @@ class lineageTree(lineageTreeLoaders):
|
|
2428
3147
|
fast: bool = False,
|
2429
3148
|
w: int = 0,
|
2430
3149
|
centered_band: bool = True,
|
2431
|
-
projection:
|
3150
|
+
projection: Literal["3d", "xy", "xz", "yz", "pca", None] = None,
|
2432
3151
|
alig: bool = False,
|
2433
|
-
) ->
|
2434
|
-
"""
|
2435
|
-
Plots DTW trajectories aligment between two
|
2436
|
-
|
2437
|
-
|
2438
|
-
|
2439
|
-
|
2440
|
-
|
2441
|
-
|
2442
|
-
|
2443
|
-
|
2444
|
-
|
2445
|
-
|
2446
|
-
|
2447
|
-
|
2448
|
-
|
2449
|
-
|
2450
|
-
|
2451
|
-
|
2452
|
-
|
2453
|
-
|
2454
|
-
|
2455
|
-
|
2456
|
-
|
2457
|
-
|
3152
|
+
) -> tuple[float, plt.Figure]:
|
3153
|
+
"""
|
3154
|
+
Plots DTW trajectories alignment between two chains in 2D or 3D
|
3155
|
+
|
3156
|
+
Parameters
|
3157
|
+
----------
|
3158
|
+
nodes1 : int
|
3159
|
+
node to compare distance
|
3160
|
+
nodes2 : int
|
3161
|
+
node to compare distance
|
3162
|
+
threshold : int, default=1000
|
3163
|
+
set a maximum number of points a chain can have
|
3164
|
+
regist : bool, default=True
|
3165
|
+
Rotate and translate trajectories
|
3166
|
+
start_d : int, default=0
|
3167
|
+
start delay
|
3168
|
+
back_d : int, default=0
|
3169
|
+
end delay
|
3170
|
+
w : int, default=0
|
3171
|
+
window size
|
3172
|
+
fast : bool, default=False
|
3173
|
+
True if the user wants to run the fast algorithm with window constraints
|
3174
|
+
centered_band : bool, default=True
|
3175
|
+
if running the fast algorithm, True if the window is centered
|
3176
|
+
projection : {"3d", "xy", "xz", "yz", "pca"}, optional
|
3177
|
+
specify which projection to plot:
|
3178
|
+
"3d" : for the 3d visualization
|
3179
|
+
"xy" or None (default) : 2D projection of axis x and y
|
3180
|
+
"xz" : 2D projection of axis x and z
|
3181
|
+
"yz" : 2D projection of axis y and z
|
3182
|
+
"pca" : PCA projection
|
3183
|
+
alig : bool
|
3184
|
+
True to show alignment on plot
|
3185
|
+
|
3186
|
+
Returns
|
3187
|
+
-------
|
3188
|
+
float
|
3189
|
+
DTW distance
|
3190
|
+
figure
|
3191
|
+
Trajectories Plot
|
2458
3192
|
"""
|
2459
3193
|
(
|
2460
3194
|
distance,
|
2461
3195
|
alignment,
|
2462
3196
|
cost_mat,
|
2463
|
-
|
2464
|
-
|
3197
|
+
pos_chain1,
|
3198
|
+
pos_chain2,
|
2465
3199
|
) = self.calculate_dtw(
|
2466
3200
|
nodes1,
|
2467
3201
|
nodes2,
|
@@ -2484,16 +3218,16 @@ class lineageTree(lineageTreeLoaders):
|
|
2484
3218
|
|
2485
3219
|
if projection == "3d":
|
2486
3220
|
ax.plot(
|
2487
|
-
|
2488
|
-
|
2489
|
-
|
3221
|
+
pos_chain1[:, 0],
|
3222
|
+
pos_chain1[:, 1],
|
3223
|
+
pos_chain1[:, 2],
|
2490
3224
|
"-",
|
2491
3225
|
label=f"root = {nodes1}",
|
2492
3226
|
)
|
2493
3227
|
ax.plot(
|
2494
|
-
|
2495
|
-
|
2496
|
-
|
3228
|
+
pos_chain2[:, 0],
|
3229
|
+
pos_chain2[:, 1],
|
3230
|
+
pos_chain2[:, 2],
|
2497
3231
|
"-",
|
2498
3232
|
label=f"root = {nodes2}",
|
2499
3233
|
)
|
@@ -2503,8 +3237,8 @@ class lineageTree(lineageTreeLoaders):
|
|
2503
3237
|
else:
|
2504
3238
|
if projection == "xy" or projection == "yx" or projection is None:
|
2505
3239
|
self.__plot_2d(
|
2506
|
-
|
2507
|
-
|
3240
|
+
pos_chain1,
|
3241
|
+
pos_chain2,
|
2508
3242
|
nodes1,
|
2509
3243
|
nodes2,
|
2510
3244
|
ax,
|
@@ -2515,8 +3249,8 @@ class lineageTree(lineageTreeLoaders):
|
|
2515
3249
|
)
|
2516
3250
|
elif projection == "xz" or projection == "zx":
|
2517
3251
|
self.__plot_2d(
|
2518
|
-
|
2519
|
-
|
3252
|
+
pos_chain1,
|
3253
|
+
pos_chain2,
|
2520
3254
|
nodes1,
|
2521
3255
|
nodes2,
|
2522
3256
|
ax,
|
@@ -2527,8 +3261,8 @@ class lineageTree(lineageTreeLoaders):
|
|
2527
3261
|
)
|
2528
3262
|
elif projection == "yz" or projection == "zy":
|
2529
3263
|
self.__plot_2d(
|
2530
|
-
|
2531
|
-
|
3264
|
+
pos_chain1,
|
3265
|
+
pos_chain2,
|
2532
3266
|
nodes1,
|
2533
3267
|
nodes2,
|
2534
3268
|
ax,
|
@@ -2542,24 +3276,25 @@ class lineageTree(lineageTreeLoaders):
|
|
2542
3276
|
from sklearn.decomposition import PCA
|
2543
3277
|
except ImportError:
|
2544
3278
|
Warning(
|
2545
|
-
"scikit-learn is not installed, the PCA orientation cannot be used.
|
3279
|
+
"scikit-learn is not installed, the PCA orientation cannot be used."
|
3280
|
+
"You can install scikit-learn with pip install"
|
2546
3281
|
)
|
2547
3282
|
|
2548
3283
|
# Apply PCA
|
2549
3284
|
pca = PCA(n_components=2)
|
2550
|
-
pca.fit(np.vstack([
|
2551
|
-
|
2552
|
-
|
3285
|
+
pca.fit(np.vstack([pos_chain1, pos_chain2]))
|
3286
|
+
pos_chain1_2d = pca.transform(pos_chain1)
|
3287
|
+
pos_chain2_2d = pca.transform(pos_chain2)
|
2553
3288
|
|
2554
3289
|
ax.plot(
|
2555
|
-
|
2556
|
-
|
3290
|
+
pos_chain1_2d[:, 0],
|
3291
|
+
pos_chain1_2d[:, 1],
|
2557
3292
|
"-",
|
2558
3293
|
label=f"root = {nodes1}",
|
2559
3294
|
)
|
2560
3295
|
ax.plot(
|
2561
|
-
|
2562
|
-
|
3296
|
+
pos_chain2_2d[:, 0],
|
3297
|
+
pos_chain2_2d[:, 1],
|
2563
3298
|
"-",
|
2564
3299
|
label=f"root = {nodes2}",
|
2565
3300
|
)
|
@@ -2588,7 +3323,7 @@ class lineageTree(lineageTreeLoaders):
|
|
2588
3323
|
'pca' : PCA projection"""
|
2589
3324
|
)
|
2590
3325
|
|
2591
|
-
connections = [[
|
3326
|
+
connections = [[pos_chain1[i], pos_chain2[j]] for i, j in alignment]
|
2592
3327
|
|
2593
3328
|
for connection in connections:
|
2594
3329
|
xyz1 = connection[0]
|
@@ -2616,108 +3351,174 @@ class lineageTree(lineageTreeLoaders):
|
|
2616
3351
|
|
2617
3352
|
return distance, fig
|
2618
3353
|
|
2619
|
-
def first_labelling(self):
|
2620
|
-
self.labels = {i: "Unlabeled" for i in self.time_nodes[0]}
|
2621
|
-
|
2622
3354
|
def __init__(
|
2623
3355
|
self,
|
2624
|
-
|
2625
|
-
|
2626
|
-
|
2627
|
-
|
2628
|
-
|
2629
|
-
|
2630
|
-
|
2631
|
-
|
2632
|
-
|
2633
|
-
reorder: bool = False,
|
2634
|
-
xml_attributes: tuple = None,
|
2635
|
-
name: str = None,
|
2636
|
-
time_resolution: Union[int, None] = None,
|
3356
|
+
*,
|
3357
|
+
successor: dict[int, Sequence] | None = None,
|
3358
|
+
predecessor: dict[int, int | Sequence] | None = None,
|
3359
|
+
time: dict[int, int] | None = None,
|
3360
|
+
starting_time: int | None = None,
|
3361
|
+
pos: dict[int, Iterable] | None = None,
|
3362
|
+
name: str | None = None,
|
3363
|
+
root_leaf_value: Sequence | None = None,
|
3364
|
+
**kwargs,
|
2637
3365
|
):
|
2638
|
-
"""
|
2639
|
-
|
2640
|
-
|
2641
|
-
|
2642
|
-
|
2643
|
-
|
2644
|
-
|
2645
|
-
|
2646
|
-
|
2647
|
-
|
2648
|
-
|
2649
|
-
|
2650
|
-
|
2651
|
-
|
2652
|
-
|
2653
|
-
|
2654
|
-
|
2655
|
-
|
2656
|
-
|
2657
|
-
|
2658
|
-
|
2659
|
-
|
2660
|
-
|
2661
|
-
|
3366
|
+
"""Create a lineageTree object from minimal information, without reading from a file.
|
3367
|
+
Either `successor` or `predecessor` should be specified.
|
3368
|
+
|
3369
|
+
Parameters
|
3370
|
+
----------
|
3371
|
+
successor : dict mapping int to Iterable
|
3372
|
+
Dictionary assigning nodes to their successors.
|
3373
|
+
predecessor : dict mapping int to int or Iterable
|
3374
|
+
Dictionary assigning nodes to their predecessors.
|
3375
|
+
time : dict mapping int to int, optional
|
3376
|
+
Dictionary assigning nodes to the time point they were recorded at.
|
3377
|
+
Defaults to None, in which case roots are set to `starting_time` and every other node to its predecessor's time plus one.
|
3378
|
+
starting_time : int, optional
|
3379
|
+
Starting time of the lineage tree. Defaults to 0.
|
3380
|
+
pos : dict mapping int to Iterable, optional
|
3381
|
+
Dictionary assigning nodes to their positions. Defaults to None.
|
3382
|
+
name : str, optional
|
3383
|
+
Name of the lineage tree. Defaults to None.
|
3384
|
+
root_leaf_value : Iterable, optional
|
3385
|
+
Iterable of values of roots' predecessors and leaves' successors in the successor and predecessor dictionaries.
|
3386
|
+
Defaults are `[None, (), [], set()]`.
|
3387
|
+
**kwargs:
|
3388
|
+
Supported keyword arguments are dictionaries assigning nodes to any custom property.
|
3389
|
+
The property must be specified for every node, and named differently from lineageTree's own attributes.
|
3390
|
+
"""
|
3391
|
+
self.__version__ = importlib.metadata.version("LineageTree")
|
3392
|
+
self.name = str(name) if name is not None else None
|
3393
|
+
if successor is not None and predecessor is not None:
|
3394
|
+
raise ValueError(
|
3395
|
+
"You cannot have both successors and predecessors."
|
3396
|
+
)
|
2662
3397
|
|
2663
|
-
|
2664
|
-
|
2665
|
-
|
2666
|
-
|
2667
|
-
|
2668
|
-
|
2669
|
-
|
2670
|
-
|
2671
|
-
|
2672
|
-
|
2673
|
-
self.
|
2674
|
-
|
2675
|
-
|
2676
|
-
|
2677
|
-
|
2678
|
-
|
2679
|
-
|
2680
|
-
|
2681
|
-
|
2682
|
-
|
2683
|
-
|
2684
|
-
|
2685
|
-
|
2686
|
-
|
2687
|
-
|
2688
|
-
|
2689
|
-
|
2690
|
-
|
2691
|
-
|
2692
|
-
|
2693
|
-
|
2694
|
-
|
2695
|
-
|
2696
|
-
|
2697
|
-
|
3398
|
+
if root_leaf_value is None:
|
3399
|
+
root_leaf_value = [None, (), [], set()]
|
3400
|
+
elif not isinstance(root_leaf_value, Iterable):
|
3401
|
+
raise TypeError(
|
3402
|
+
f"root_leaf_value is of type {type(root_leaf_value)}, expected Iterable."
|
3403
|
+
)
|
3404
|
+
elif len(root_leaf_value) < 1:
|
3405
|
+
raise ValueError(
|
3406
|
+
"root_leaf_value should have at least one element."
|
3407
|
+
)
|
3408
|
+
self._successor = {}
|
3409
|
+
self._predecessor = {}
|
3410
|
+
if successor is not None:
|
3411
|
+
for pred, succs in successor.items():
|
3412
|
+
if succs in root_leaf_value:
|
3413
|
+
self._successor[pred] = ()
|
3414
|
+
else:
|
3415
|
+
if not isinstance(succs, Iterable):
|
3416
|
+
raise TypeError(
|
3417
|
+
f"Successors should be Iterable, got {type(succs)}."
|
3418
|
+
)
|
3419
|
+
if len(succs) == 0:
|
3420
|
+
raise ValueError(
|
3421
|
+
f"{succs} was not declared as a leaf but was found as a successor.\n"
|
3422
|
+
"Please lift the ambiguity."
|
3423
|
+
)
|
3424
|
+
self._successor[pred] = tuple(succs)
|
3425
|
+
for succ in succs:
|
3426
|
+
if succ in self._predecessor:
|
3427
|
+
raise ValueError(
|
3428
|
+
"Node can have at most one predecessor."
|
3429
|
+
)
|
3430
|
+
self._predecessor[succ] = (pred,)
|
3431
|
+
elif predecessor is not None:
|
3432
|
+
for succ, pred in predecessor.items():
|
3433
|
+
if pred in root_leaf_value:
|
3434
|
+
self._predecessor[succ] = ()
|
3435
|
+
else:
|
3436
|
+
if isinstance(pred, Sequence):
|
3437
|
+
if len(pred) == 0:
|
3438
|
+
raise ValueError(
|
3439
|
+
f"{pred} was not declared as a leaf but was found as a successor.\n"
|
3440
|
+
"Please lift the ambiguity."
|
3441
|
+
)
|
3442
|
+
if 1 < len(pred):
|
3443
|
+
raise ValueError(
|
3444
|
+
"Node can have at most one predecessor."
|
3445
|
+
)
|
3446
|
+
pred = pred[0]
|
3447
|
+
self._predecessor[succ] = (pred,)
|
3448
|
+
self._successor.setdefault(pred, ())
|
3449
|
+
self._successor[pred] += (succ,)
|
3450
|
+
for root in set(self._successor).difference(self._predecessor):
|
3451
|
+
self._predecessor[root] = ()
|
3452
|
+
for leaf in set(self._predecessor).difference(self._successor):
|
3453
|
+
self._successor[leaf] = ()
|
3454
|
+
|
3455
|
+
if self.__check_for_cycles():
|
3456
|
+
raise ValueError(
|
3457
|
+
"Cycles were found in the tree, there should not be any."
|
3458
|
+
)
|
3459
|
+
|
3460
|
+
if pos is None:
|
3461
|
+
self.pos = {}
|
3462
|
+
else:
|
3463
|
+
if self.nodes.difference(pos) != set():
|
3464
|
+
raise ValueError("Please provide the position of all nodes.")
|
3465
|
+
self.pos = {
|
3466
|
+
node: np.array(position) for node, position in pos.items()
|
3467
|
+
}
|
3468
|
+
|
3469
|
+
if time is None:
|
3470
|
+
if starting_time is None:
|
3471
|
+
starting_time = 0
|
3472
|
+
if not isinstance(starting_time, int):
|
3473
|
+
warnings.warn(
|
3474
|
+
f"Attribute `starting_time` was a `{type(starting_time)}`, has been casted as an `int`.",
|
3475
|
+
stacklevel=2,
|
2698
3476
|
)
|
2699
|
-
|
2700
|
-
|
2701
|
-
|
3477
|
+
self._time = {node: starting_time for node in self.roots}
|
3478
|
+
queue = list(self.roots)
|
3479
|
+
for node in queue:
|
3480
|
+
for succ in self._successor[node]:
|
3481
|
+
self._time[succ] = self._time[node] + 1
|
3482
|
+
queue.append(succ)
|
3483
|
+
else:
|
3484
|
+
if starting_time is not None:
|
3485
|
+
warnings.warn(
|
3486
|
+
"Both `time` and `starting_time` were provided, `starting_time` was ignored.",
|
3487
|
+
stacklevel=2,
|
3488
|
+
)
|
3489
|
+
self._time = {n: int(time[n]) for n in self.nodes}
|
3490
|
+
if self._time != time:
|
3491
|
+
if len(self._time) != len(time):
|
3492
|
+
warnings.warn(
|
3493
|
+
"The provided `time` dictionary had keys that were not nodes. "
|
3494
|
+
"They have been removed",
|
3495
|
+
stacklevel=2,
|
3496
|
+
)
|
2702
3497
|
else:
|
2703
|
-
|
2704
|
-
|
2705
|
-
|
2706
|
-
|
2707
|
-
|
2708
|
-
|
2709
|
-
|
2710
|
-
|
2711
|
-
|
2712
|
-
|
2713
|
-
|
2714
|
-
|
2715
|
-
|
2716
|
-
|
2717
|
-
|
2718
|
-
|
2719
|
-
|
2720
|
-
|
2721
|
-
|
2722
|
-
|
2723
|
-
|
3498
|
+
warnings.warn(
|
3499
|
+
"The provided `time` dictionary had values that were not `int`. "
|
3500
|
+
"These values have been truncated and converted to `int`",
|
3501
|
+
stacklevel=2,
|
3502
|
+
)
|
3503
|
+
if self.nodes.symmetric_difference(self._time) != set():
|
3504
|
+
raise ValueError(
|
3505
|
+
"Please provide the time of all nodes and only existing nodes."
|
3506
|
+
)
|
3507
|
+
if not all(
|
3508
|
+
self._time[node] < self._time[s]
|
3509
|
+
for node, succ in self._successor.items()
|
3510
|
+
for s in succ
|
3511
|
+
):
|
3512
|
+
raise ValueError(
|
3513
|
+
"Provided times are not strictly increasing. Setting times to default."
|
3514
|
+
)
|
3515
|
+
# custom properties
|
3516
|
+
for name, d in kwargs.items():
|
3517
|
+
if name in self.__dict__:
|
3518
|
+
warnings.warn(
|
3519
|
+
f"Attribute name {name} is reserved.", stacklevel=2
|
3520
|
+
)
|
3521
|
+
continue
|
3522
|
+
setattr(self, name, d)
|
3523
|
+
if not hasattr(self, "_comparisons"):
|
3524
|
+
self._comparisons = {}
|