LineageTree 1.7.0__py3-none-any.whl → 2.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- LineageTree/__init__.py +27 -2
- LineageTree/legacy/export_csv.py +70 -0
- LineageTree/legacy/to_lineajea.py +30 -0
- LineageTree/legacy/to_motile.py +36 -0
- LineageTree/lineageTree.py +2294 -1618
- LineageTree/lineageTreeManager.py +759 -55
- LineageTree/loaders.py +947 -695
- LineageTree/test/test_lineageTree.py +634 -0
- LineageTree/test/test_uted.py +233 -0
- LineageTree/tree_approximation.py +488 -0
- LineageTree/utils.py +106 -108
- {LineageTree-1.7.0.dist-info → lineagetree-2.0.1.dist-info}/METADATA +31 -34
- lineagetree-2.0.1.dist-info/RECORD +16 -0
- {LineageTree-1.7.0.dist-info → lineagetree-2.0.1.dist-info}/WHEEL +1 -1
- LineageTree/tree_styles.py +0 -322
- LineageTree-1.7.0.dist-info/RECORD +0 -11
- {LineageTree-1.7.0.dist-info → lineagetree-2.0.1.dist-info/licenses}/LICENSE +0 -0
- {LineageTree-1.7.0.dist-info → lineagetree-2.0.1.dist-info}/top_level.txt +0 -0
LineageTree/lineageTree.py
CHANGED
@@ -2,625 +2,538 @@
|
|
2
2
|
# This file is subject to the terms and conditions defined in
|
3
3
|
# file 'LICENCE', which is part of this source code package.
|
4
4
|
# Author: Leo Guignard (leo.guignard...@AT@...gmail.com)
|
5
|
+
|
6
|
+
from __future__ import annotations
|
7
|
+
|
8
|
+
import importlib.metadata
|
5
9
|
import os
|
6
10
|
import pickle as pkl
|
7
11
|
import struct
|
8
12
|
import warnings
|
9
|
-
from collections.abc import Iterable
|
10
|
-
from functools import partial
|
13
|
+
from collections.abc import Callable, Iterable, Sequence
|
14
|
+
from functools import partial, wraps
|
11
15
|
from itertools import combinations
|
12
16
|
from numbers import Number
|
13
|
-
from
|
14
|
-
from typing import
|
15
|
-
|
16
|
-
|
17
|
-
from .tree_styles import tree_style
|
18
|
-
|
19
|
-
try:
|
20
|
-
from edist import uted
|
21
|
-
except ImportError:
|
22
|
-
warnings.warn(
|
23
|
-
"No edist installed therefore you will not be able to compute the tree edit distance.",
|
24
|
-
stacklevel=2,
|
25
|
-
)
|
17
|
+
from types import MappingProxyType
|
18
|
+
from typing import TYPE_CHECKING, Literal
|
19
|
+
|
20
|
+
import matplotlib.colors as mcolors
|
26
21
|
import matplotlib.pyplot as plt
|
27
22
|
import numpy as np
|
23
|
+
import svgwrite
|
24
|
+
from edist import uted
|
25
|
+
from matplotlib import colormaps
|
26
|
+
from matplotlib.collections import LineCollection
|
27
|
+
from packaging.version import Version
|
28
28
|
from scipy.interpolate import InterpolatedUnivariateSpline
|
29
|
-
from scipy.
|
30
|
-
from scipy.spatial import
|
29
|
+
from scipy.sparse import dok_array
|
30
|
+
from scipy.spatial import Delaunay, KDTree, distance
|
31
31
|
|
32
|
+
from .tree_approximation import TreeApproximationTemplate, tree_style
|
32
33
|
from .utils import (
|
33
|
-
|
34
|
+
convert_style_to_number,
|
35
|
+
create_links_and_chains,
|
34
36
|
hierarchical_pos,
|
35
37
|
)
|
36
38
|
|
39
|
+
if TYPE_CHECKING:
|
40
|
+
from edist.alignment import Alignment
|
41
|
+
|
42
|
+
|
43
|
+
class dynamic_property(property):
|
44
|
+
def __init__(
|
45
|
+
self, fget=None, fset=None, fdel=None, doc=None, protected_name=None
|
46
|
+
):
|
47
|
+
super().__init__(fget, fset, fdel, doc)
|
48
|
+
self.protected_name = protected_name
|
49
|
+
|
50
|
+
def __set_name__(self, owner, name):
|
51
|
+
self.name = name
|
52
|
+
if self.protected_name is None:
|
53
|
+
self.protected_name = f"_{name}"
|
54
|
+
if not hasattr(owner, "_protected_dynamic_properties"):
|
55
|
+
owner._protected_dynamic_properties = []
|
56
|
+
owner._protected_dynamic_properties.append(self.protected_name)
|
57
|
+
if not hasattr(owner, "_dynamic_properties"):
|
58
|
+
owner._dynamic_properties = []
|
59
|
+
owner._dynamic_properties += [name, self.protected_name]
|
60
|
+
setattr(owner, self.protected_name, None)
|
61
|
+
|
62
|
+
def __get__(self, instance, owner):
|
63
|
+
if instance is None:
|
64
|
+
return self
|
65
|
+
instance._has_been_reset = False
|
66
|
+
if getattr(instance, self.protected_name) is None:
|
67
|
+
value = super().__get__(instance, owner)
|
68
|
+
setattr(instance, self.protected_name, value)
|
69
|
+
return value
|
70
|
+
else:
|
71
|
+
return getattr(instance, self.protected_name)
|
72
|
+
|
73
|
+
|
74
|
+
class lineageTree:
|
75
|
+
norm_dict = {"max": max, "sum": sum, None: lambda x: 1}
|
76
|
+
|
77
|
+
def modifier(wrapped_func):
|
78
|
+
@wraps(wrapped_func)
|
79
|
+
def raising_flag(self, *args, **kwargs):
|
80
|
+
should_reset = (
|
81
|
+
not hasattr(self, "_has_been_reset")
|
82
|
+
or not self._has_been_reset
|
83
|
+
)
|
84
|
+
out_func = wrapped_func(self, *args, **kwargs)
|
85
|
+
if should_reset:
|
86
|
+
for prop in self._protected_dynamic_properties:
|
87
|
+
self.__dict__[prop] = None
|
88
|
+
self._has_been_reset = True
|
89
|
+
return out_func
|
90
|
+
|
91
|
+
return raising_flag
|
92
|
+
|
93
|
+
def __check_cc_cycles(self, n: int) -> tuple[bool, set[int]]:
|
94
|
+
"""Check if the connected component of a given node `n` has a cycle.
|
95
|
+
|
96
|
+
Returns
|
97
|
+
-------
|
98
|
+
bool
|
99
|
+
True if the tree has cycles, False otherwise.
|
100
|
+
set of int
|
101
|
+
The set of nodes that have been checked.
|
102
|
+
"""
|
103
|
+
to_do = [n]
|
104
|
+
no_cycle = True
|
105
|
+
already_done = set()
|
106
|
+
while to_do and no_cycle:
|
107
|
+
current = to_do.pop(-1)
|
108
|
+
if current not in already_done:
|
109
|
+
already_done.add(current)
|
110
|
+
else:
|
111
|
+
no_cycle = False
|
112
|
+
to_do.extend(self._successor[current])
|
113
|
+
to_do = list(self._predecessor[n])
|
114
|
+
while to_do and no_cycle:
|
115
|
+
current = to_do.pop(-1)
|
116
|
+
if current not in already_done:
|
117
|
+
already_done.add(current)
|
118
|
+
else:
|
119
|
+
no_cycle = False
|
120
|
+
to_do.extend(self._predecessor[current])
|
121
|
+
return not no_cycle, already_done
|
122
|
+
|
123
|
+
def __check_for_cycles(self) -> bool:
|
124
|
+
"""Check if the tree has cycles.
|
125
|
+
|
126
|
+
Returns
|
127
|
+
-------
|
128
|
+
bool
|
129
|
+
True if the tree has cycles, False otherwise.
|
130
|
+
"""
|
131
|
+
to_do = set(self.nodes)
|
132
|
+
found_cycle = False
|
133
|
+
while to_do and not found_cycle:
|
134
|
+
current = to_do.pop()
|
135
|
+
found_cycle, done = self.__check_cc_cycles(current)
|
136
|
+
to_do.difference_update(done)
|
137
|
+
return found_cycle
|
37
138
|
|
38
|
-
|
39
|
-
def __eq__(self, other):
|
139
|
+
def __eq__(self, other) -> bool:
|
40
140
|
if isinstance(other, lineageTree):
|
41
|
-
return
|
42
|
-
|
141
|
+
return (
|
142
|
+
other._successor == self._successor
|
143
|
+
and other._predecessor == self._predecessor
|
144
|
+
and other._time == self._time
|
145
|
+
)
|
146
|
+
else:
|
147
|
+
return False
|
43
148
|
|
44
|
-
def get_next_id(self):
|
45
|
-
"""Computes the next authorized id.
|
149
|
+
def get_next_id(self) -> int:
|
150
|
+
"""Computes the next authorized id and assign it.
|
46
151
|
|
47
|
-
Returns
|
48
|
-
|
152
|
+
Returns
|
153
|
+
-------
|
154
|
+
int
|
155
|
+
next authorized id
|
49
156
|
"""
|
50
|
-
if self.max_id == -1 and self.nodes:
|
51
|
-
self.max_id = max(self.nodes)
|
52
|
-
if self.next_id == []:
|
157
|
+
if not hasattr(self, "max_id") or (self.max_id == -1 and self.nodes):
|
158
|
+
self.max_id = max(self.nodes) if len(self.nodes) else 0
|
159
|
+
if not hasattr(self, "next_id") or self.next_id == []:
|
53
160
|
self.max_id += 1
|
54
161
|
return self.max_id
|
55
162
|
else:
|
56
163
|
return self.next_id.pop()
|
57
164
|
|
58
|
-
def complete_lineage(self, nodes: Union[int, set] = None):
|
59
|
-
"""Makes all leaf branches longer so that they reach the last timepoint( self.t_e), useful
|
60
|
-
for tree edit distance algorithms.
|
61
|
-
|
62
|
-
Args:
|
63
|
-
nodes (int,set), optional): Which trees should be "completed", if None it will complete the whole dataset. Defaults to None.
|
64
|
-
"""
|
65
|
-
if nodes is None:
|
66
|
-
nodes = set(self.roots)
|
67
|
-
elif isinstance(nodes, int):
|
68
|
-
nodes = {nodes}
|
69
|
-
for node in nodes:
|
70
|
-
sub = set(self.get_sub_tree(node))
|
71
|
-
specific_leaves = sub.intersection(self.leaves)
|
72
|
-
for leaf in specific_leaves:
|
73
|
-
self.add_branch(leaf, self.t_e - self.time[leaf], reverse=True)
|
74
|
-
|
75
165
|
###TODO pos can be callable and stay motionless (copy the position of the succ node, use something like optical flow)
|
76
|
-
|
166
|
+
@modifier
|
167
|
+
def add_chain(
|
77
168
|
self,
|
78
|
-
|
169
|
+
node: int,
|
79
170
|
length: int,
|
80
|
-
|
81
|
-
pos:
|
82
|
-
|
83
|
-
|
84
|
-
"""Adds a branch of specific length to a node either as a successor or as a predecessor.
|
171
|
+
downstream: bool,
|
172
|
+
pos: Callable | None = None,
|
173
|
+
) -> int:
|
174
|
+
"""Adds a chain of specific length to a node either as a successor or as a predecessor.
|
85
175
|
If it is placed on top of a tree all the nodes will move timepoints #length down.
|
86
176
|
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
|
94
|
-
|
177
|
+
Parameters
|
178
|
+
----------
|
179
|
+
node : int
|
180
|
+
Id of the successor (predecessor if `downstream==False`)
|
181
|
+
length : int
|
182
|
+
The length of the new chain.
|
183
|
+
downstream : bool, default=True
|
184
|
+
If `True` will create a chain that goes forwards in time otherwise backwards.
|
185
|
+
pos : np.ndarray, optional
|
186
|
+
The new position of the chain. Defaults to None.
|
187
|
+
|
188
|
+
Returns
|
189
|
+
-------
|
190
|
+
int
|
191
|
+
Id of the first node of the sublineage.
|
95
192
|
"""
|
96
193
|
if length == 0:
|
97
|
-
return
|
98
|
-
if
|
99
|
-
raise
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
-
new_times = {
|
106
|
-
node: self.time[node] + length for node in nodes_to_move
|
107
|
-
}
|
108
|
-
for node in nodes_to_move:
|
109
|
-
old_time = self.time[node]
|
110
|
-
self.time_nodes[old_time].remove(node)
|
111
|
-
self.time_nodes.setdefault(old_time + length, set()).add(
|
112
|
-
node
|
113
|
-
)
|
114
|
-
self.time.update(new_times)
|
115
|
-
for t in range(length - 1, -1, -1):
|
116
|
-
_next = self.add_node(
|
117
|
-
time + t,
|
118
|
-
succ=pred,
|
119
|
-
pos=self.pos[original],
|
120
|
-
reverse=True,
|
121
|
-
)
|
122
|
-
pred = _next
|
123
|
-
else:
|
124
|
-
for t in range(length):
|
125
|
-
_next = self.add_node(
|
126
|
-
time - t,
|
127
|
-
succ=pred,
|
128
|
-
pos=self.pos[original],
|
129
|
-
reverse=True,
|
130
|
-
)
|
131
|
-
pred = _next
|
194
|
+
return node
|
195
|
+
if length < 1:
|
196
|
+
raise ValueError("Length cannot be <1")
|
197
|
+
if downstream:
|
198
|
+
for _ in range(int(length)):
|
199
|
+
old_node = node
|
200
|
+
node = self._add_node(pred=[old_node])
|
201
|
+
self._time[node] = self._time[old_node] + 1
|
132
202
|
else:
|
133
|
-
|
134
|
-
|
135
|
-
|
136
|
-
|
137
|
-
|
138
|
-
reverse=False,
|
203
|
+
if self._predecessor[node]:
|
204
|
+
raise Warning("The node already has a predecessor.")
|
205
|
+
if self._time[node] - length < self.t_b:
|
206
|
+
raise Warning(
|
207
|
+
"A node cannot created outside the lower bound of the dataset. (It is possible to change it by lT.t_b = int(...))"
|
139
208
|
)
|
140
|
-
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
|
150
|
-
|
151
|
-
|
152
|
-
|
153
|
-
|
154
|
-
|
155
|
-
|
156
|
-
|
157
|
-
|
158
|
-
|
159
|
-
|
160
|
-
|
161
|
-
|
162
|
-
|
163
|
-
|
164
|
-
|
165
|
-
if
|
166
|
-
|
167
|
-
|
168
|
-
label_of_root = self.labels.get(cycle[0], cycle[0])
|
169
|
-
self.labels[cycle[0]] = f"L-Split {label_of_root}"
|
170
|
-
new_tr = self.add_branch(
|
171
|
-
new_lT, len(cycle) + 1, move_timepoints=False
|
172
|
-
)
|
173
|
-
self.roots.add(new_tr)
|
174
|
-
self.labels[new_tr] = f"R-Split {label_of_root}"
|
175
|
-
return new_tr
|
176
|
-
else:
|
177
|
-
raise Warning("No division of the branch")
|
178
|
-
|
179
|
-
def fuse_lineage_tree(
|
180
|
-
self,
|
181
|
-
l1_root: int,
|
182
|
-
l2_root: int,
|
183
|
-
length_l1: int = 0,
|
184
|
-
length_l2: int = 0,
|
185
|
-
length: int = 1,
|
186
|
-
):
|
187
|
-
"""Fuses 2 lineages from the lineagetree object. The 2 lineages that are to be fused can have a longer
|
188
|
-
first node and the node of the resulting lineage can also be longer.
|
189
|
-
|
190
|
-
Args:
|
191
|
-
l1_root (int): Id of the first root
|
192
|
-
l2_root (int): Id of the second root
|
193
|
-
length_l1 (int, optional): The length of the branch that will be added on top of the first lineage. Defaults to 0, which means only one node will be added.
|
194
|
-
length_l2 (int, optional): The length of the branch that will be added on top of the second lineage. Defaults to 0, which means only one node will be added.
|
195
|
-
length (int, optional): The length of the branch that will be added on top of the resulting lineage. Defaults to 1.
|
196
|
-
|
197
|
-
Returns:
|
198
|
-
int: The id of the root of the new lineage.
|
199
|
-
"""
|
200
|
-
if self.predecessor.get(l1_root) or self.predecessor.get(l2_root):
|
201
|
-
raise ValueError("Please select 2 roots.")
|
202
|
-
if self.time[l1_root] != self.time[l2_root]:
|
203
|
-
warnings.warn(
|
204
|
-
"Using lineagetrees that do not exist in the same timepoint. The operation will continue"
|
205
|
-
)
|
206
|
-
new_root1 = self.add_branch(l1_root, length_l1)
|
207
|
-
new_root2 = self.add_branch(l2_root, length_l2)
|
208
|
-
next_root1 = self[new_root1][0]
|
209
|
-
self.remove_nodes(new_root1)
|
210
|
-
self.successor[new_root2].append(next_root1)
|
211
|
-
self.predecessor[next_root1] = [new_root2]
|
212
|
-
new_branch = self.add_branch(new_root2, length)
|
213
|
-
self.labels[new_branch] = f"Fusion of {new_root1} and {new_root2}"
|
214
|
-
return new_branch
|
215
|
-
|
216
|
-
def copy_lineage(self, root):
|
217
|
-
"""
|
218
|
-
Copies the structure of a tree and makes a new with new nodes.
|
219
|
-
Warning does not take into account the predecessor of the root node.
|
220
|
-
|
221
|
-
Args:
|
222
|
-
root (int): The root of the tree to be copied
|
209
|
+
for _ in range(int(length)):
|
210
|
+
old_node = node
|
211
|
+
node = self._add_node(succ=[old_node])
|
212
|
+
self._time[node] = self._time[old_node] - 1
|
213
|
+
return node
|
214
|
+
|
215
|
+
@modifier
|
216
|
+
def add_root(self, t: int, pos: list | None = None) -> int:
|
217
|
+
"""Adds a root to a specific timepoint.
|
218
|
+
|
219
|
+
Parameters
|
220
|
+
----------
|
221
|
+
t :int
|
222
|
+
The timepoint the node is going to be added.
|
223
|
+
pos : list
|
224
|
+
The position of the new node.
|
225
|
+
Returns
|
226
|
+
-------
|
227
|
+
int
|
228
|
+
The id of the new root.
|
229
|
+
"""
|
230
|
+
C_next = self.get_next_id()
|
231
|
+
self._successor[C_next] = ()
|
232
|
+
self._predecessor[C_next] = ()
|
233
|
+
self._time[C_next] = t
|
234
|
+
self.pos[C_next] = pos if isinstance(pos, list) else []
|
235
|
+
self._changed_roots = True
|
236
|
+
return C_next
|
223
237
|
|
224
|
-
|
225
|
-
int: The root of the new tree.
|
226
|
-
"""
|
227
|
-
new_nodes = {
|
228
|
-
old_node: self.get_next_id()
|
229
|
-
for old_node in self.get_sub_tree(root)
|
230
|
-
}
|
231
|
-
self.nodes.update(new_nodes.values())
|
232
|
-
for old_node, new_node in new_nodes.items():
|
233
|
-
self.time[new_node] = self.time[old_node]
|
234
|
-
succ = self.successor.get(old_node)
|
235
|
-
if succ:
|
236
|
-
self.successor[new_node] = [new_nodes[n] for n in succ]
|
237
|
-
pred = self.predecessor.get(old_node)
|
238
|
-
if pred:
|
239
|
-
self.predecessor[new_node] = [new_nodes[n] for n in pred]
|
240
|
-
self.pos[new_node] = self.pos[old_node] + 0.5
|
241
|
-
self.time_nodes[self.time[old_node]].add(new_nodes[old_node])
|
242
|
-
new_root = new_nodes[root]
|
243
|
-
self.labels[new_root] = f"Copy of {root}"
|
244
|
-
if self.time[new_root] == 0:
|
245
|
-
self.roots.add(new_root)
|
246
|
-
return new_root
|
247
|
-
|
248
|
-
def add_node(
|
238
|
+
def _add_node(
|
249
239
|
self,
|
250
|
-
|
251
|
-
|
252
|
-
pos: np.ndarray = None,
|
253
|
-
nid: int = None,
|
254
|
-
reverse: bool = False,
|
240
|
+
succ: list | None = None,
|
241
|
+
pred: list | None = None,
|
242
|
+
pos: np.ndarray | None = None,
|
243
|
+
nid: int | None = None,
|
255
244
|
) -> int:
|
256
|
-
"""Adds a node to the
|
257
|
-
|
258
|
-
|
259
|
-
|
260
|
-
|
261
|
-
|
262
|
-
|
263
|
-
|
264
|
-
|
265
|
-
|
266
|
-
|
267
|
-
|
268
|
-
|
269
|
-
|
270
|
-
|
245
|
+
"""Adds a node to the LineageTree object that is either a successor or a predecessor of another node.
|
246
|
+
Does not handle time! You cannot enter both a successor and a predecessor.
|
247
|
+
|
248
|
+
Parameters
|
249
|
+
----------
|
250
|
+
succ : list
|
251
|
+
list of ids of the nodes the new node is a successor to
|
252
|
+
pred : list
|
253
|
+
list of ids of the nodes the new node is a predecessor to
|
254
|
+
pos : np.ndarray, optional
|
255
|
+
position of the new node
|
256
|
+
nid : int, optional
|
257
|
+
id value of the new node, to be used carefully,
|
258
|
+
if None is provided the new id is automatically computed.
|
259
|
+
|
260
|
+
Returns
|
261
|
+
-------
|
262
|
+
int
|
263
|
+
id of the new node.
|
264
|
+
"""
|
265
|
+
if not succ and not pred:
|
266
|
+
raise Warning(
|
267
|
+
"Please enter a successor or a predecessor, otherwise use the add_roots() function."
|
268
|
+
)
|
271
269
|
C_next = self.get_next_id() if nid is None else nid
|
272
|
-
|
273
|
-
|
274
|
-
|
275
|
-
|
276
|
-
|
277
|
-
self.
|
278
|
-
|
279
|
-
|
280
|
-
|
281
|
-
|
270
|
+
if succ:
|
271
|
+
self._successor[C_next] = succ
|
272
|
+
for suc in succ:
|
273
|
+
self._predecessor[suc] = (C_next,)
|
274
|
+
else:
|
275
|
+
self._successor[C_next] = ()
|
276
|
+
if pred:
|
277
|
+
self._predecessor[C_next] = pred
|
278
|
+
self._successor[pred[0]] = self._successor.setdefault(
|
279
|
+
pred[0], ()
|
280
|
+
) + (C_next,)
|
281
|
+
else:
|
282
|
+
self._predecessor[C_next] = ()
|
283
|
+
if isinstance(pos, list):
|
284
|
+
self.pos[C_next] = pos
|
282
285
|
return C_next
|
283
286
|
|
284
|
-
|
287
|
+
@modifier
|
288
|
+
def remove_nodes(self, group: int | set | list) -> None:
|
285
289
|
"""Removes a group of nodes from the LineageTree
|
286
290
|
|
287
|
-
|
288
|
-
|
291
|
+
Parameters
|
292
|
+
----------
|
293
|
+
group : set of int or list of int or int
|
294
|
+
One or more nodes that are to be removed.
|
289
295
|
"""
|
290
|
-
if isinstance(group, int):
|
296
|
+
if isinstance(group, int | float):
|
291
297
|
group = {group}
|
292
298
|
if isinstance(group, list):
|
293
299
|
group = set(group)
|
294
|
-
group =
|
295
|
-
self.nodes.difference_update(group)
|
296
|
-
times = {self.time.pop(n) for n in group}
|
297
|
-
for t in times:
|
298
|
-
self.time_nodes[t] = set(self.time_nodes[t]).difference(group)
|
300
|
+
group = self.nodes.intersection(group)
|
299
301
|
for node in group:
|
300
|
-
self.
|
301
|
-
|
302
|
-
|
303
|
-
|
304
|
-
|
305
|
-
|
306
|
-
|
307
|
-
|
308
|
-
|
309
|
-
|
310
|
-
self.
|
311
|
-
|
312
|
-
|
313
|
-
|
314
|
-
|
315
|
-
|
316
|
-
|
317
|
-
|
318
|
-
|
319
|
-
|
320
|
-
node (int): Any node of the branch to be modified/
|
321
|
-
new_length (int): The new length of the tree.
|
322
|
-
"""
|
323
|
-
if new_length <= 1:
|
324
|
-
warnings.warn("New length should be more than 1")
|
325
|
-
return None
|
326
|
-
cycle = self.get_cycle(node)
|
327
|
-
length = len(cycle)
|
328
|
-
successors = self.successor.get(cycle[-1])
|
329
|
-
if length == 1 and new_length != 1:
|
330
|
-
pred = self.predecessor.pop(node, None)
|
331
|
-
new_node = self.add_branch(
|
332
|
-
node,
|
333
|
-
length=new_length - 1,
|
334
|
-
move_timepoints=True,
|
335
|
-
reverse=False,
|
336
|
-
)
|
337
|
-
if pred:
|
338
|
-
self.successor[pred[0]].remove(node)
|
339
|
-
self.successor[pred[0]].append(new_node)
|
340
|
-
elif self.leaves.intersection(cycle) and new_length < length:
|
341
|
-
self.remove_nodes(cycle[new_length:])
|
342
|
-
elif new_length < length:
|
343
|
-
to_remove = length - new_length
|
344
|
-
last_cell = cycle[new_length - 1]
|
345
|
-
subtree = self.get_sub_tree(cycle[-1])[1:]
|
346
|
-
self.remove_nodes(cycle[new_length:])
|
347
|
-
self.successor[last_cell] = successors
|
348
|
-
if successors:
|
349
|
-
for succ in successors:
|
350
|
-
self.predecessor[succ] = [last_cell]
|
351
|
-
for node in subtree:
|
352
|
-
if node not in cycle[new_length - 1 :]:
|
353
|
-
old_time = self.time[node]
|
354
|
-
self.time[node] = old_time - to_remove
|
355
|
-
self.time_nodes[old_time].remove(node)
|
356
|
-
self.time_nodes.setdefault(
|
357
|
-
old_time - to_remove, set()
|
358
|
-
).add(node)
|
359
|
-
elif length < new_length:
|
360
|
-
to_add = new_length - length
|
361
|
-
last_cell = cycle[-1]
|
362
|
-
self.successor.pop(cycle[-2])
|
363
|
-
self.predecessor.pop(last_cell)
|
364
|
-
succ = self.add_branch(
|
365
|
-
last_cell, length=to_add, move_timepoints=True, reverse=False
|
366
|
-
)
|
367
|
-
self.predecessor[succ] = [cycle[-2]]
|
368
|
-
self.successor[cycle[-2]] = [succ]
|
369
|
-
self.time[last_cell] = (
|
370
|
-
self.time[self.predecessor[last_cell][0]] + 1
|
371
|
-
)
|
372
|
-
else:
|
373
|
-
return None
|
374
|
-
|
375
|
-
@property
|
376
|
-
def time_resolution(self):
|
377
|
-
if not hasattr(self, "_time_resolution"):
|
378
|
-
self.time_resolution = 1
|
379
|
-
return self._time_resolution / 10
|
380
|
-
|
381
|
-
@time_resolution.setter
|
382
|
-
def time_resolution(self, time_resolution):
|
383
|
-
if time_resolution is not None:
|
384
|
-
self._time_resolution = int(time_resolution * 10)
|
385
|
-
else:
|
386
|
-
warnings.warn("Time resolution set to default 1", stacklevel=2)
|
387
|
-
self._time_resolution = 10
|
302
|
+
for attr in self.__dict__:
|
303
|
+
attr_value = self.__getattribute__(attr)
|
304
|
+
if isinstance(attr_value, dict) and attr not in [
|
305
|
+
"successor",
|
306
|
+
"predecessor",
|
307
|
+
"_successor",
|
308
|
+
"_predecessor",
|
309
|
+
"_time",
|
310
|
+
]:
|
311
|
+
attr_value.pop(node, ())
|
312
|
+
if self._predecessor.get(node):
|
313
|
+
self._successor[self._predecessor[node][0]] = tuple(
|
314
|
+
set(
|
315
|
+
self._successor[self._predecessor[node][0]]
|
316
|
+
).difference(group)
|
317
|
+
)
|
318
|
+
for p_node in self._successor.get(node, []):
|
319
|
+
self._predecessor[p_node] = ()
|
320
|
+
self._predecessor.pop(node, ())
|
321
|
+
self._successor.pop(node, ())
|
388
322
|
|
389
323
|
@property
|
390
|
-
def
|
391
|
-
|
392
|
-
|
393
|
-
|
324
|
+
def successor(self) -> MappingProxyType[int, tuple[int]]:
|
325
|
+
"""The successor of the tree."""
|
326
|
+
if not hasattr(self, "_protected_successor"):
|
327
|
+
self._protected_successor = MappingProxyType(self._successor)
|
328
|
+
return self._protected_successor
|
394
329
|
|
395
330
|
@property
|
396
|
-
def
|
397
|
-
|
331
|
+
def predecessor(self) -> MappingProxyType[int, tuple[int]]:
|
332
|
+
"""The predecessor of the tree."""
|
333
|
+
if not hasattr(self, "_protected_predecessor"):
|
334
|
+
self._protected_predecessor = MappingProxyType(self._predecessor)
|
335
|
+
return self._protected_predecessor
|
398
336
|
|
399
337
|
@property
|
400
|
-
def
|
401
|
-
|
338
|
+
def time(self) -> MappingProxyType[int, int]:
|
339
|
+
"""The time of the tree."""
|
340
|
+
if not hasattr(self, "_protected_time"):
|
341
|
+
self._protected_time = MappingProxyType(self._time)
|
342
|
+
return self._protected_time
|
343
|
+
|
344
|
+
@dynamic_property
|
345
|
+
def t_b(self) -> int:
|
346
|
+
"""The first timepoint of the tree."""
|
347
|
+
return min(self._time.values())
|
348
|
+
|
349
|
+
@dynamic_property
|
350
|
+
def t_e(self) -> int:
|
351
|
+
"""The last timepoint of the tree."""
|
352
|
+
return max(self._time.values())
|
353
|
+
|
354
|
+
@dynamic_property
|
355
|
+
def nodes(self) -> frozenset[int]:
|
356
|
+
"""Nodes of the tree"""
|
357
|
+
return frozenset(self._successor.keys())
|
358
|
+
|
359
|
+
@dynamic_property
|
360
|
+
def depth(self) -> dict[int, int]:
|
361
|
+
"""The depth of each node in the tree."""
|
362
|
+
_depth = {}
|
363
|
+
for leaf in self.leaves:
|
364
|
+
_depth[leaf] = 1
|
365
|
+
while leaf in self._predecessor and self._predecessor[leaf]:
|
366
|
+
parent = self._predecessor[leaf][0]
|
367
|
+
current_depth = _depth.get(parent, 0)
|
368
|
+
_depth[parent] = max(_depth[leaf] + 1, current_depth)
|
369
|
+
leaf = parent
|
370
|
+
for root in self.roots - set(_depth):
|
371
|
+
_depth[root] = 1
|
372
|
+
return _depth
|
373
|
+
|
374
|
+
@dynamic_property
|
375
|
+
def roots(self) -> frozenset[int]:
|
376
|
+
"""Set of roots of the tree"""
|
377
|
+
return frozenset({s for s, p in self._predecessor.items() if p == ()})
|
378
|
+
|
379
|
+
@dynamic_property
|
380
|
+
def leaves(self) -> frozenset[int]:
|
381
|
+
"""Set of leaves"""
|
382
|
+
return frozenset({p for p, s in self._successor.items() if s == ()})
|
383
|
+
|
384
|
+
@dynamic_property
|
385
|
+
def edges(self) -> tuple[tuple[int, int]]:
|
386
|
+
"""Set of edges"""
|
387
|
+
return tuple((p, si) for p, s in self._successor.items() for si in s)
|
402
388
|
|
403
389
|
@property
|
404
|
-
def labels(self):
|
390
|
+
def labels(self) -> dict[int, str]:
|
391
|
+
"""The labels of the nodes."""
|
405
392
|
if not hasattr(self, "_labels"):
|
406
|
-
if hasattr(self, "
|
393
|
+
if hasattr(self, "node_name"):
|
407
394
|
self._labels = {
|
408
|
-
i: self.
|
395
|
+
i: self.node_name.get(i, "Unlabeled") for i in self.roots
|
409
396
|
}
|
410
397
|
else:
|
411
398
|
self._labels = {
|
412
|
-
|
413
|
-
for
|
414
|
-
for
|
415
|
-
if abs(self.
|
399
|
+
root: "Unlabeled"
|
400
|
+
for root in self.roots
|
401
|
+
for leaf in self.find_leaves(root)
|
402
|
+
if abs(self._time[leaf] - self._time[root])
|
416
403
|
>= abs(self.t_e - self.t_b) / 4
|
417
404
|
}
|
418
405
|
return self._labels
|
419
406
|
|
420
|
-
|
421
|
-
|
422
|
-
|
423
|
-
|
424
|
-
|
425
|
-
f.write("define POINT %d\n" % ((length) * nb_points))
|
426
|
-
f.write("Parameters {\n")
|
427
|
-
f.write('\tContentType "HxSpatialGraph"\n')
|
428
|
-
f.write("}\n")
|
429
|
-
|
430
|
-
f.write("VERTEX { float[3] VertexCoordinates } @1\n")
|
431
|
-
f.write("EDGE { int[2] EdgeConnectivity } @2\n")
|
432
|
-
f.write("EDGE { int NumEdgePoints } @3\n")
|
433
|
-
f.write("POINT { float[3] EdgePointCoordinates } @4\n")
|
434
|
-
f.write("VERTEX { float Vcolor } @5\n")
|
435
|
-
f.write("VERTEX { int Vbool } @6\n")
|
436
|
-
f.write("EDGE { float Ecolor } @7\n")
|
437
|
-
f.write("VERTEX { int Vbool2 } @8\n")
|
438
|
-
|
439
|
-
def write_to_am(
|
440
|
-
self,
|
441
|
-
path_format: str,
|
442
|
-
t_b: int = None,
|
443
|
-
t_e: int = None,
|
444
|
-
length: int = 5,
|
445
|
-
manual_labels: dict = None,
|
446
|
-
default_label: int = 5,
|
447
|
-
new_pos: np.ndarray = None,
|
448
|
-
):
|
449
|
-
"""Writes a lineageTree into an Amira readable data (.am format).
|
450
|
-
|
451
|
-
Args:
|
452
|
-
path_format (str): path to the output. It should contain 1 %03d where the time step will be entered
|
453
|
-
t_b (int): first time point to write (if None, min(LT.to_take_time) is taken)
|
454
|
-
t_e (int): last time point to write (if None, max(LT.to_take_time) is taken)
|
455
|
-
note, if there is no 'to_take_time' attribute, self.time_nodes
|
456
|
-
is considered instead (historical)
|
457
|
-
length (int): length of the track to print (how many time before).
|
458
|
-
manual_labels ({id: label, }): dictionary that maps cell ids to
|
459
|
-
default_label (int): default value for the manual label
|
460
|
-
new_pos ({id: [x, y, z]}): dictionary that maps a 3D position to a cell ID.
|
461
|
-
if new_pos == None (default) then self.pos is considered.
|
462
|
-
"""
|
463
|
-
if not hasattr(self, "to_take_time"):
|
464
|
-
self.to_take_time = self.time_nodes
|
465
|
-
if t_b is None:
|
466
|
-
t_b = min(self.to_take_time.keys())
|
467
|
-
if t_e is None:
|
468
|
-
t_e = max(self.to_take_time.keys())
|
469
|
-
if new_pos is None:
|
470
|
-
new_pos = self.pos
|
471
|
-
|
472
|
-
if manual_labels is None:
|
473
|
-
manual_labels = {}
|
474
|
-
for t in range(t_b, t_e + 1):
|
475
|
-
with open(path_format % t, "w") as f:
|
476
|
-
nb_points = len(self.to_take_time[t])
|
477
|
-
self._write_header_am(f, nb_points, length)
|
478
|
-
points_v = {}
|
479
|
-
for C in self.to_take_time[t]:
|
480
|
-
C_tmp = C
|
481
|
-
positions = []
|
482
|
-
for _ in range(length):
|
483
|
-
C_tmp = self.predecessor.get(C_tmp, [C_tmp])[0]
|
484
|
-
positions.append(new_pos[C_tmp])
|
485
|
-
points_v[C] = positions
|
486
|
-
|
487
|
-
f.write("@1\n")
|
488
|
-
for C in self.to_take_time[t]:
|
489
|
-
f.write("{:f} {:f} {:f}\n".format(*tuple(points_v[C][0])))
|
490
|
-
f.write("{:f} {:f} {:f}\n".format(*tuple(points_v[C][-1])))
|
491
|
-
|
492
|
-
f.write("@2\n")
|
493
|
-
for i, _ in enumerate(self.to_take_time[t]):
|
494
|
-
f.write("%d %d\n" % (2 * i, 2 * i + 1))
|
495
|
-
|
496
|
-
f.write("@3\n")
|
497
|
-
for _ in self.to_take_time[t]:
|
498
|
-
f.write("%d\n" % (length))
|
499
|
-
|
500
|
-
f.write("@4\n")
|
501
|
-
for C in self.to_take_time[t]:
|
502
|
-
for p in points_v[C]:
|
503
|
-
f.write("{:f} {:f} {:f}\n".format(*tuple(p)))
|
504
|
-
|
505
|
-
f.write("@5\n")
|
506
|
-
for C in self.to_take_time[t]:
|
507
|
-
f.write("%f\n" % (manual_labels.get(C, default_label)))
|
508
|
-
f.write(f"{0:f}\n")
|
509
|
-
|
510
|
-
f.write("@6\n")
|
511
|
-
for C in self.to_take_time[t]:
|
512
|
-
f.write(
|
513
|
-
"%d\n"
|
514
|
-
% (
|
515
|
-
int(
|
516
|
-
manual_labels.get(C, default_label)
|
517
|
-
!= default_label
|
518
|
-
)
|
519
|
-
)
|
520
|
-
)
|
521
|
-
f.write("%d\n" % (0))
|
522
|
-
|
523
|
-
f.write("@7\n")
|
524
|
-
for C in self.to_take_time[t]:
|
525
|
-
f.write(
|
526
|
-
"%f\n"
|
527
|
-
% (np.linalg.norm(points_v[C][0] - points_v[C][-1]))
|
528
|
-
)
|
529
|
-
|
530
|
-
f.write("@8\n")
|
531
|
-
for _ in self.to_take_time[t]:
|
532
|
-
f.write("%d\n" % (1))
|
533
|
-
f.write("%d\n" % (0))
|
534
|
-
f.close()
|
407
|
+
@property
|
408
|
+
def time_resolution(self) -> float:
|
409
|
+
if not hasattr(self, "_time_resolution"):
|
410
|
+
self._time_resolution = 0
|
411
|
+
return self._time_resolution / 10
|
535
412
|
|
536
|
-
|
537
|
-
|
413
|
+
@time_resolution.setter
|
414
|
+
def time_resolution(self, time_resolution) -> None:
|
415
|
+
if time_resolution is not None and time_resolution > 0:
|
416
|
+
self._time_resolution = int(time_resolution * 10)
|
417
|
+
else:
|
418
|
+
warnings.warn("Time resolution set to default 0", stacklevel=2)
|
419
|
+
self._time_resolution = 0
|
420
|
+
|
421
|
+
def __setstate__(self, state):
|
422
|
+
if "_successor" not in state:
|
423
|
+
state["_successor"] = state["successor"]
|
424
|
+
if "_predecessor" not in state:
|
425
|
+
state["_predecessor"] = state["predecessor"]
|
426
|
+
if "_time" not in state:
|
427
|
+
state["_time"] = state["time"]
|
428
|
+
self.__dict__.update(state)
|
429
|
+
|
430
|
+
def _get_height(self, c: int, done: dict) -> float:
|
431
|
+
"""Recursively computes the height of a node within a tree times a space factor.
|
538
432
|
This function is specific to the function write_to_svg.
|
539
433
|
|
540
|
-
|
541
|
-
|
542
|
-
|
543
|
-
|
544
|
-
|
434
|
+
Parameters
|
435
|
+
----------
|
436
|
+
c : int
|
437
|
+
id of a node in a lineage tree from which the height will be computed from
|
438
|
+
done : dict mapping int to list of two int
|
439
|
+
a dictionary that maps a node id to its vertical and horizontal position
|
440
|
+
|
441
|
+
Returns
|
442
|
+
-------
|
443
|
+
float
|
444
|
+
the height of the node `c`
|
545
445
|
"""
|
546
446
|
if c in done:
|
547
447
|
return done[c][0]
|
548
448
|
else:
|
549
449
|
P = np.mean(
|
550
|
-
[self._get_height(di, done) for di in self.
|
450
|
+
[self._get_height(di, done) for di in self._successor[c]]
|
551
451
|
)
|
552
|
-
done[c] = [P, self.vert_space_factor * self.
|
452
|
+
done[c] = [P, self.vert_space_factor * self._time[c]]
|
553
453
|
return P
|
554
454
|
|
555
455
|
def write_to_svg(
|
556
456
|
self,
|
557
457
|
file_name: str,
|
558
|
-
roots: list = None,
|
458
|
+
roots: list | None = None,
|
559
459
|
draw_nodes: bool = True,
|
560
460
|
draw_edges: bool = True,
|
561
|
-
order_key:
|
461
|
+
order_key: Callable | None = None,
|
562
462
|
vert_space_factor: float = 0.5,
|
563
463
|
horizontal_space: float = 1,
|
564
|
-
node_size:
|
565
|
-
stroke_width:
|
464
|
+
node_size: Callable | str | None = None,
|
465
|
+
stroke_width: Callable | None = None,
|
566
466
|
factor: float = 1.0,
|
567
|
-
node_color:
|
568
|
-
stroke_color:
|
569
|
-
positions: dict = None,
|
570
|
-
node_color_map:
|
571
|
-
|
572
|
-
):
|
573
|
-
##### remove background? default True background value? default 1
|
574
|
-
|
467
|
+
node_color: Callable | str | None = None,
|
468
|
+
stroke_color: Callable | None = None,
|
469
|
+
positions: dict | None = None,
|
470
|
+
node_color_map: Callable | str | None = None,
|
471
|
+
) -> None:
|
575
472
|
"""Writes the lineage tree to an SVG file.
|
576
473
|
Node and edges coloring and size can be provided.
|
577
474
|
|
578
|
-
|
579
|
-
|
580
|
-
|
581
|
-
|
582
|
-
|
583
|
-
|
584
|
-
|
585
|
-
|
586
|
-
|
587
|
-
|
588
|
-
|
589
|
-
|
590
|
-
|
591
|
-
|
592
|
-
|
593
|
-
|
594
|
-
|
595
|
-
|
596
|
-
|
597
|
-
|
598
|
-
|
599
|
-
|
600
|
-
|
601
|
-
|
475
|
+
Parameters
|
476
|
+
----------
|
477
|
+
file_name : str
|
478
|
+
filesystem filename valid for `open()`
|
479
|
+
roots : list of int, defaults to `self.roots`
|
480
|
+
list of node ids to be drawn. If `None` or not provided all the nodes will be drawn. Default `None`
|
481
|
+
draw_nodes : bool, default True
|
482
|
+
wether to print the nodes or not
|
483
|
+
draw_edges : bool, default True
|
484
|
+
wether to print the edges or not
|
485
|
+
order_key : Callable, optional
|
486
|
+
function that would work for the attribute `key=` for the `sort`/`sorted` function
|
487
|
+
vert_space_factor : float, default=0.5
|
488
|
+
the vertical position of a node is its time. `vert_space_factor` is a
|
489
|
+
multiplier to space more or less nodes in time
|
490
|
+
horizontal_space : float, default=1
|
491
|
+
space between two consecutive nodes
|
492
|
+
node_size : Callable or str, optional
|
493
|
+
a function that maps a node id to a `float` value that will determine the
|
494
|
+
radius of the node. The default function return the constant value `vertical_space_factor/2.1`
|
495
|
+
If a string is given instead and it is a property of the tree,
|
496
|
+
the the size will be mapped according to the property
|
497
|
+
stroke_width : Callable, optional
|
498
|
+
a function that maps a node id to a `float` value that will determine the
|
499
|
+
width of the daughter edge. The default function return the constant value `vertical_space_factor/2.1`
|
500
|
+
factor : float, default=1.0
|
501
|
+
scaling factor for nodes positions, default 1
|
502
|
+
node_color : Callable or str, optional
|
503
|
+
a function that maps a node id to a triplet between 0 and 255.
|
504
|
+
The triplet will determine the color of the node. If a string is given instead and it is a property
|
505
|
+
of the tree, the the color will be mapped according to the property
|
506
|
+
node_color_map : Callable or str, optional
|
507
|
+
the name of the colormap to use to color the nodes, or a colormap function
|
508
|
+
stroke_color : Callable, optional
|
509
|
+
a function that maps a node id to a triplet between 0 and 255.
|
510
|
+
The triplet will determine the color of the stroke of the inward edge.
|
511
|
+
positions : dict mapping int to list of two float, optional
|
512
|
+
dictionary that maps a node id to a 2D position.
|
513
|
+
Default `None`. If provided it will be used to position the nodes.
|
602
514
|
"""
|
603
|
-
import svgwrite
|
604
515
|
|
605
516
|
def normalize_values(v, nodes, _range, shift, mult):
|
606
517
|
min_ = np.percentile(v, 1)
|
607
518
|
max_ = np.percentile(v, 99)
|
608
519
|
values = _range * ((v - min_) / (max_ - min_)) + shift
|
609
|
-
values_dict_nodes = dict(zip(nodes, values))
|
520
|
+
values_dict_nodes = dict(zip(nodes, values, strict=True))
|
610
521
|
return lambda x: values_dict_nodes[x] * mult
|
611
522
|
|
612
523
|
if roots is None:
|
613
524
|
roots = self.roots
|
614
525
|
if hasattr(self, "image_label"):
|
615
|
-
roots = [
|
526
|
+
roots = [node for node in roots if self.image_label[node] != 1]
|
616
527
|
|
617
528
|
if node_size is None:
|
618
529
|
|
619
530
|
def node_size(x):
|
620
531
|
return vert_space_factor / 2.1
|
621
532
|
|
622
|
-
|
623
|
-
values = np.array(
|
533
|
+
else:
|
534
|
+
values = np.array(
|
535
|
+
[self._successor[node_size][c] for c in self.nodes]
|
536
|
+
)
|
624
537
|
node_size = normalize_values(
|
625
538
|
values, self.nodes, 0.5, 0.5, vert_space_factor / 2.1
|
626
539
|
)
|
@@ -635,18 +548,19 @@ class lineageTree(lineageTreeLoaders):
|
|
635
548
|
return 0, 0, 0
|
636
549
|
|
637
550
|
elif isinstance(node_color, str) and node_color in self.__dict__:
|
638
|
-
|
639
|
-
from matplotlib import colormaps
|
551
|
+
from matplotlib import colormaps
|
640
552
|
|
641
|
-
|
642
|
-
|
643
|
-
|
644
|
-
|
645
|
-
values = np.array(
|
553
|
+
if node_color_map in colormaps:
|
554
|
+
cmap = colormaps[node_color_map]
|
555
|
+
else:
|
556
|
+
cmap = colormaps["viridis"]
|
557
|
+
values = np.array(
|
558
|
+
[self._successor[node_color][c] for c in self.nodes]
|
559
|
+
)
|
646
560
|
normed_vals = normalize_values(values, self.nodes, 1, 0, 1)
|
647
561
|
|
648
562
|
def node_color(x):
|
649
|
-
return [k * 255 for k in
|
563
|
+
return [k * 255 for k in cmap(normed_vals(x))[:-1]]
|
650
564
|
|
651
565
|
coloring_edges = stroke_color is not None
|
652
566
|
if not coloring_edges:
|
@@ -655,24 +569,25 @@ class lineageTree(lineageTreeLoaders):
|
|
655
569
|
return 0, 0, 0
|
656
570
|
|
657
571
|
elif isinstance(stroke_color, str) and stroke_color in self.__dict__:
|
658
|
-
|
659
|
-
from matplotlib import colormaps
|
572
|
+
from matplotlib import colormaps
|
660
573
|
|
661
|
-
|
662
|
-
|
663
|
-
|
664
|
-
|
665
|
-
values = np.array(
|
574
|
+
if node_color_map in colormaps:
|
575
|
+
cmap = colormaps[node_color_map]
|
576
|
+
else:
|
577
|
+
cmap = colormaps["viridis"]
|
578
|
+
values = np.array(
|
579
|
+
[self._successor[stroke_color][c] for c in self.nodes]
|
580
|
+
)
|
666
581
|
normed_vals = normalize_values(values, self.nodes, 1, 0, 1)
|
667
582
|
|
668
583
|
def stroke_color(x):
|
669
|
-
return [k * 255 for k in
|
584
|
+
return [k * 255 for k in cmap(normed_vals(x))[:-1]]
|
670
585
|
|
671
586
|
prev_x = 0
|
672
587
|
self.vert_space_factor = vert_space_factor
|
673
588
|
if order_key is not None:
|
674
589
|
roots.sort(key=order_key)
|
675
|
-
|
590
|
+
treated_nodes = []
|
676
591
|
|
677
592
|
pos_given = positions is not None
|
678
593
|
if not pos_given:
|
@@ -683,25 +598,26 @@ class lineageTree(lineageTreeLoaders):
|
|
683
598
|
[0.0, 0.0],
|
684
599
|
]
|
685
600
|
* len(self.nodes),
|
686
|
-
|
601
|
+
strict=True,
|
602
|
+
),
|
687
603
|
)
|
688
604
|
for _i, r in enumerate(roots):
|
689
605
|
r_leaves = []
|
690
606
|
to_do = [r]
|
691
607
|
while len(to_do) != 0:
|
692
608
|
curr = to_do.pop(0)
|
693
|
-
|
694
|
-
if
|
609
|
+
treated_nodes += [curr]
|
610
|
+
if not self._successor[curr]:
|
695
611
|
if order_key is not None:
|
696
|
-
to_do += sorted(self.
|
612
|
+
to_do += sorted(self._successor[curr], key=order_key)
|
697
613
|
else:
|
698
|
-
to_do += self.
|
614
|
+
to_do += self._successor[curr]
|
699
615
|
else:
|
700
616
|
r_leaves += [curr]
|
701
617
|
r_pos = {
|
702
618
|
leave: [
|
703
619
|
prev_x + horizontal_space * (1 + j),
|
704
|
-
self.vert_space_factor * self.
|
620
|
+
self.vert_space_factor * self._time[leave],
|
705
621
|
]
|
706
622
|
for j, leave in enumerate(r_leaves)
|
707
623
|
}
|
@@ -716,12 +632,12 @@ class lineageTree(lineageTreeLoaders):
|
|
716
632
|
size=factor * np.max(list(positions.values()), axis=0),
|
717
633
|
)
|
718
634
|
if draw_edges and not draw_nodes and not coloring_edges:
|
719
|
-
to_do = set(
|
635
|
+
to_do = set(treated_nodes)
|
720
636
|
while len(to_do) > 0:
|
721
637
|
curr = to_do.pop()
|
722
|
-
|
723
|
-
x1, y1 = positions[
|
724
|
-
x2, y2 = positions[
|
638
|
+
c_chain = self.get_chain_of_node(curr)
|
639
|
+
x1, y1 = positions[c_chain[0]]
|
640
|
+
x2, y2 = positions[c_chain[-1]]
|
725
641
|
dwg.add(
|
726
642
|
dwg.line(
|
727
643
|
(factor * x1, factor * y1),
|
@@ -729,7 +645,7 @@ class lineageTree(lineageTreeLoaders):
|
|
729
645
|
stroke=svgwrite.rgb(0, 0, 0),
|
730
646
|
)
|
731
647
|
)
|
732
|
-
for si in self[
|
648
|
+
for si in self._successor[c_chain[-1]]:
|
733
649
|
x3, y3 = positions[si]
|
734
650
|
dwg.add(
|
735
651
|
dwg.line(
|
@@ -738,11 +654,11 @@ class lineageTree(lineageTreeLoaders):
|
|
738
654
|
stroke=svgwrite.rgb(0, 0, 0),
|
739
655
|
)
|
740
656
|
)
|
741
|
-
to_do.difference_update(
|
657
|
+
to_do.difference_update(c_chain)
|
742
658
|
else:
|
743
|
-
for c in
|
659
|
+
for c in treated_nodes:
|
744
660
|
x1, y1 = positions[c]
|
745
|
-
for si in self[c]:
|
661
|
+
for si in self._successor[c]:
|
746
662
|
x2, y2 = positions[si]
|
747
663
|
if draw_edges:
|
748
664
|
dwg.add(
|
@@ -753,7 +669,7 @@ class lineageTree(lineageTreeLoaders):
|
|
753
669
|
stroke_width=svgwrite.pt(stroke_width(si)),
|
754
670
|
)
|
755
671
|
)
|
756
|
-
for c in
|
672
|
+
for c in treated_nodes:
|
757
673
|
x1, y1 = positions[c]
|
758
674
|
if draw_nodes:
|
759
675
|
dwg.add(
|
@@ -770,49 +686,58 @@ class lineageTree(lineageTreeLoaders):
|
|
770
686
|
fname: str,
|
771
687
|
t_min: int = -1,
|
772
688
|
t_max: int = np.inf,
|
773
|
-
nodes_to_use: list = None,
|
689
|
+
nodes_to_use: list[int] | None = None,
|
774
690
|
temporal: bool = True,
|
775
|
-
spatial: str = None,
|
691
|
+
spatial: str | None = None,
|
776
692
|
write_layout: bool = True,
|
777
|
-
node_properties: dict = None,
|
693
|
+
node_properties: dict | None = None,
|
778
694
|
Names: bool = False,
|
779
|
-
):
|
695
|
+
) -> None:
|
780
696
|
"""Write a lineage tree into an understable tulip file.
|
781
697
|
|
782
|
-
|
783
|
-
|
784
|
-
|
785
|
-
|
786
|
-
|
787
|
-
|
788
|
-
|
789
|
-
|
790
|
-
|
791
|
-
|
792
|
-
|
793
|
-
|
794
|
-
|
795
|
-
|
796
|
-
|
797
|
-
|
798
|
-
|
799
|
-
|
800
|
-
|
801
|
-
|
802
|
-
|
803
|
-
|
804
|
-
|
698
|
+
Parameters
|
699
|
+
----------
|
700
|
+
fname : str
|
701
|
+
path to the tulip file to create
|
702
|
+
t_min : int, default=-1
|
703
|
+
minimum time to consider
|
704
|
+
t_max : int, default=np.inf
|
705
|
+
maximum time to consider
|
706
|
+
nodes_to_use : list of int, optional
|
707
|
+
list of nodes to show in the graph,
|
708
|
+
if `None` then self.nodes is used
|
709
|
+
(taking into account `t_min` and `t_max`)
|
710
|
+
temporal : bool, default=True
|
711
|
+
True if the temporal links should be printed
|
712
|
+
spatial : str, optional
|
713
|
+
Build spatial edges from a spatial neighbourhood graph.
|
714
|
+
The graph has to be computed before running this function
|
715
|
+
'ball': neighbours at a given distance,
|
716
|
+
'kn': k-nearest neighbours,
|
717
|
+
'GG': gabriel graph,
|
718
|
+
None: no spatial edges are writen.
|
719
|
+
Default None
|
720
|
+
write_layout : bool, default=True
|
721
|
+
write the spatial position as layout if True
|
722
|
+
do not write spatial position otherwise
|
723
|
+
node_properties : dict mapping str to list of dict of properties and its default value, optional
|
724
|
+
a dictionary of properties to write
|
725
|
+
To a key representing the name of the property is
|
726
|
+
paired a dictionary that maps a node id to a property
|
727
|
+
and a default value for this property
|
728
|
+
Names : bool, default=True
|
729
|
+
Only works with ASTEC outputs, True to sort the nodes by their names
|
805
730
|
"""
|
806
731
|
|
807
732
|
def format_names(names_which_matter):
|
808
|
-
"""Return an ensured formated
|
733
|
+
"""Return an ensured formated node names"""
|
809
734
|
tmp = {}
|
810
735
|
for k, v in names_which_matter.items():
|
811
736
|
tmp[k] = (
|
812
737
|
v.split(".")[0][0]
|
813
|
-
+ "
|
738
|
+
+ "{:02d}".format(int(v.split(".")[0][1:]))
|
814
739
|
+ "."
|
815
|
-
+ "
|
740
|
+
+ "{:04d}".format(int(v.split(".")[1][:-1]))
|
816
741
|
+ v.split(".")[1][-1]
|
817
742
|
)
|
818
743
|
return tmp
|
@@ -839,20 +764,20 @@ class lineageTree(lineageTreeLoaders):
|
|
839
764
|
if not nodes_to_use:
|
840
765
|
if t_max != np.inf or t_min > -1:
|
841
766
|
nodes_to_use = [
|
842
|
-
n for n in self.nodes if t_min < self.
|
767
|
+
n for n in self.nodes if t_min < self._time[n] <= t_max
|
843
768
|
]
|
844
769
|
edges_to_use = []
|
845
770
|
if temporal:
|
846
771
|
edges_to_use += [
|
847
772
|
e
|
848
773
|
for e in self.edges
|
849
|
-
if t_min < self.
|
774
|
+
if t_min < self._time[e[0]] < t_max
|
850
775
|
]
|
851
776
|
if spatial:
|
852
777
|
edges_to_use += [
|
853
778
|
e
|
854
779
|
for e in s_edges
|
855
|
-
if t_min < self.
|
780
|
+
if t_min < self._time[e[0]] < t_max
|
856
781
|
]
|
857
782
|
else:
|
858
783
|
nodes_to_use = list(self.nodes)
|
@@ -866,12 +791,12 @@ class lineageTree(lineageTreeLoaders):
|
|
866
791
|
nodes_to_use = set(nodes_to_use)
|
867
792
|
if temporal:
|
868
793
|
for n in nodes_to_use:
|
869
|
-
for d in self.
|
794
|
+
for d in self._successor[n]:
|
870
795
|
if d in nodes_to_use:
|
871
796
|
edges_to_use.append((n, d))
|
872
797
|
if spatial:
|
873
798
|
edges_to_use += [
|
874
|
-
e for e in s_edges if t_min < self.
|
799
|
+
e for e in s_edges if t_min < self._time[e[0]] < t_max
|
875
800
|
]
|
876
801
|
nodes_to_use = set(nodes_to_use)
|
877
802
|
if Names:
|
@@ -889,12 +814,12 @@ class lineageTree(lineageTreeLoaders):
|
|
889
814
|
for k, v in node_properties[Names][0].items():
|
890
815
|
if (
|
891
816
|
len(
|
892
|
-
self.
|
893
|
-
self.
|
817
|
+
self._successor.get(
|
818
|
+
self._predecessor.get(k, [-1])[0], ()
|
894
819
|
)
|
895
820
|
)
|
896
821
|
!= 1
|
897
|
-
or self.
|
822
|
+
or self._time[k] == t_min + 1
|
898
823
|
):
|
899
824
|
tmp_names[k] = v
|
900
825
|
node_properties[Names][0] = tmp_names
|
@@ -924,7 +849,7 @@ class lineageTree(lineageTreeLoaders):
|
|
924
849
|
f.write('\t(default "0" "0")\n')
|
925
850
|
for n in nodes_to_use:
|
926
851
|
f.write(
|
927
|
-
"\t(node " + str(n) + ' "' + str(self.
|
852
|
+
"\t(node " + str(n) + ' "' + str(self._time[n]) + '")\n'
|
928
853
|
)
|
929
854
|
f.write(")\n")
|
930
855
|
|
@@ -953,10 +878,10 @@ class lineageTree(lineageTreeLoaders):
|
|
953
878
|
if node_properties:
|
954
879
|
for p_name, (p_dict, default) in node_properties.items():
|
955
880
|
if isinstance(list(p_dict.values())[0], str):
|
956
|
-
f.write('(property 0 string "
|
881
|
+
f.write(f'(property 0 string "{p_name}"\n')
|
957
882
|
f.write(f"\t(default {default} {default})\n")
|
958
883
|
elif isinstance(list(p_dict.values())[0], Number):
|
959
|
-
f.write('(property 0 double "
|
884
|
+
f.write(f'(property 0 double "{p_name}"\n')
|
960
885
|
f.write('\t(default "0" "0")\n')
|
961
886
|
for n in nodes_to_use:
|
962
887
|
f.write(
|
@@ -971,7 +896,9 @@ class lineageTree(lineageTreeLoaders):
|
|
971
896
|
f.write(")")
|
972
897
|
f.close()
|
973
898
|
|
974
|
-
def to_binary(
|
899
|
+
def to_binary(
|
900
|
+
self, fname: str, starting_points: list[int] | None = None
|
901
|
+
) -> None:
|
975
902
|
"""Writes the lineage tree (a forest) as a binary structure
|
976
903
|
(assuming it is a binary tree, it would not work for *n* ary tree with 2 < *n*).
|
977
904
|
The binary file is composed of 3 sequences of numbers and
|
@@ -986,36 +913,37 @@ class lineageTree(lineageTreeLoaders):
|
|
986
913
|
The *time_sequence* is stored as a list of unsigned short (0 -> 2^(8*2)-1)
|
987
914
|
The *pos_sequence* is stored as a list of double.
|
988
915
|
|
989
|
-
|
990
|
-
|
991
|
-
|
992
|
-
|
916
|
+
Parameters
|
917
|
+
----------
|
918
|
+
fname : str
|
919
|
+
name of the binary file
|
920
|
+
starting_points : list of int, optional
|
921
|
+
list of the roots to be written.
|
922
|
+
If `None`, all roots are written, default value, None
|
993
923
|
"""
|
994
924
|
if starting_points is None:
|
995
|
-
starting_points =
|
996
|
-
c for c in self.successor if self.predecessor.get(c, []) == []
|
997
|
-
]
|
925
|
+
starting_points = list(self.roots)
|
998
926
|
number_sequence = [-1]
|
999
927
|
pos_sequence = []
|
1000
928
|
time_sequence = []
|
1001
929
|
for c in starting_points:
|
1002
|
-
time_sequence.append(self.
|
930
|
+
time_sequence.append(self._time.get(c, 0))
|
1003
931
|
to_treat = [c]
|
1004
|
-
while to_treat
|
932
|
+
while to_treat:
|
1005
933
|
curr_c = to_treat.pop()
|
1006
934
|
number_sequence.append(curr_c)
|
1007
935
|
pos_sequence += list(self.pos[curr_c])
|
1008
|
-
if self[curr_c] ==
|
936
|
+
if self._successor[curr_c] == ():
|
1009
937
|
number_sequence.append(-1)
|
1010
|
-
elif len(self.
|
1011
|
-
to_treat += self.
|
938
|
+
elif len(self._successor[curr_c]) == 1:
|
939
|
+
to_treat += self._successor[curr_c]
|
1012
940
|
else:
|
1013
941
|
number_sequence.append(-2)
|
1014
|
-
to_treat += self.
|
942
|
+
to_treat += self._successor[curr_c]
|
1015
943
|
remaining_nodes = set(self.nodes) - set(number_sequence)
|
1016
944
|
|
1017
945
|
for c in remaining_nodes:
|
1018
|
-
time_sequence.append(self.
|
946
|
+
time_sequence.append(self._time.get(c, 0))
|
1019
947
|
number_sequence.append(c)
|
1020
948
|
pos_sequence += list(self.pos[c])
|
1021
949
|
number_sequence.append(-1)
|
@@ -1030,204 +958,270 @@ class lineageTree(lineageTreeLoaders):
|
|
1030
958
|
|
1031
959
|
f.close()
|
1032
960
|
|
1033
|
-
def
|
1034
|
-
"""
|
1035
|
-
Reads a binary lineageTree file name.
|
1036
|
-
Format description: see self.to_binary
|
961
|
+
def write(self, fname: str) -> None:
|
962
|
+
"""Write a lineage tree on disk as an .lT file.
|
1037
963
|
|
1038
|
-
|
1039
|
-
|
1040
|
-
|
1041
|
-
|
1042
|
-
q_size = struct.calcsize("q")
|
1043
|
-
H_size = struct.calcsize("H")
|
1044
|
-
d_size = struct.calcsize("d")
|
1045
|
-
|
1046
|
-
with open(fname, "rb") as f:
|
1047
|
-
len_tree = struct.unpack("q", f.read(q_size))[0]
|
1048
|
-
len_time = struct.unpack("q", f.read(q_size))[0]
|
1049
|
-
len_pos = struct.unpack("q", f.read(q_size))[0]
|
1050
|
-
number_sequence = list(
|
1051
|
-
struct.unpack("q" * len_tree, f.read(q_size * len_tree))
|
1052
|
-
)
|
1053
|
-
time_sequence = list(
|
1054
|
-
struct.unpack("H" * len_time, f.read(H_size * len_time))
|
1055
|
-
)
|
1056
|
-
pos_sequence = np.array(
|
1057
|
-
struct.unpack("d" * len_pos, f.read(d_size * len_pos))
|
1058
|
-
)
|
1059
|
-
|
1060
|
-
f.close()
|
1061
|
-
|
1062
|
-
successor = {}
|
1063
|
-
predecessor = {}
|
1064
|
-
time = {}
|
1065
|
-
time_nodes = {}
|
1066
|
-
time_edges = {}
|
1067
|
-
pos = {}
|
1068
|
-
is_root = {}
|
1069
|
-
nodes = []
|
1070
|
-
edges = []
|
1071
|
-
waiting_list = []
|
1072
|
-
print(number_sequence[0])
|
1073
|
-
i = 0
|
1074
|
-
done = False
|
1075
|
-
if max(number_sequence[::2]) == -1:
|
1076
|
-
tmp = number_sequence[1::2]
|
1077
|
-
if len(tmp) * 3 == len(pos_sequence) == len(time_sequence) * 3:
|
1078
|
-
time = dict(list(zip(tmp, time_sequence)))
|
1079
|
-
for c, t in time.items():
|
1080
|
-
time_nodes.setdefault(t, set()).add(c)
|
1081
|
-
pos = dict(
|
1082
|
-
list(zip(tmp, np.reshape(pos_sequence, (len_time, 3))))
|
1083
|
-
)
|
1084
|
-
is_root = {c: True for c in tmp}
|
1085
|
-
nodes = tmp
|
1086
|
-
done = True
|
1087
|
-
while (
|
1088
|
-
i < len(number_sequence) and not done
|
1089
|
-
): # , c in enumerate(number_sequence[:-1]):
|
1090
|
-
c = number_sequence[i]
|
1091
|
-
if c == -1:
|
1092
|
-
if waiting_list != []:
|
1093
|
-
prev_mother = waiting_list.pop()
|
1094
|
-
successor[prev_mother].insert(0, number_sequence[i + 1])
|
1095
|
-
edges.append((prev_mother, number_sequence[i + 1]))
|
1096
|
-
time_edges.setdefault(t, set()).add(
|
1097
|
-
(prev_mother, number_sequence[i + 1])
|
1098
|
-
)
|
1099
|
-
is_root[number_sequence[i + 1]] = False
|
1100
|
-
t = time[prev_mother] + 1
|
1101
|
-
else:
|
1102
|
-
t = time_sequence.pop(0)
|
1103
|
-
is_root[number_sequence[i + 1]] = True
|
1104
|
-
|
1105
|
-
elif c == -2:
|
1106
|
-
successor[waiting_list[-1]] = [number_sequence[i + 1]]
|
1107
|
-
edges.append((waiting_list[-1], number_sequence[i + 1]))
|
1108
|
-
time_edges.setdefault(t, set()).add(
|
1109
|
-
(waiting_list[-1], number_sequence[i + 1])
|
1110
|
-
)
|
1111
|
-
is_root[number_sequence[i + 1]] = False
|
1112
|
-
pos[waiting_list[-1]] = pos_sequence[:3]
|
1113
|
-
pos_sequence = pos_sequence[3:]
|
1114
|
-
nodes.append(waiting_list[-1])
|
1115
|
-
time[waiting_list[-1]] = t
|
1116
|
-
time_nodes.setdefault(t, set()).add(waiting_list[-1])
|
1117
|
-
t += 1
|
1118
|
-
|
1119
|
-
elif number_sequence[i + 1] >= 0:
|
1120
|
-
successor[c] = [number_sequence[i + 1]]
|
1121
|
-
edges.append((c, number_sequence[i + 1]))
|
1122
|
-
time_edges.setdefault(t, set()).add(
|
1123
|
-
(c, number_sequence[i + 1])
|
1124
|
-
)
|
1125
|
-
is_root[number_sequence[i + 1]] = False
|
1126
|
-
pos[c] = pos_sequence[:3]
|
1127
|
-
pos_sequence = pos_sequence[3:]
|
1128
|
-
nodes.append(c)
|
1129
|
-
time[c] = t
|
1130
|
-
time_nodes.setdefault(t, set()).add(c)
|
1131
|
-
t += 1
|
1132
|
-
|
1133
|
-
elif number_sequence[i + 1] == -2:
|
1134
|
-
waiting_list += [c]
|
1135
|
-
|
1136
|
-
elif number_sequence[i + 1] == -1:
|
1137
|
-
pos[c] = pos_sequence[:3]
|
1138
|
-
pos_sequence = pos_sequence[3:]
|
1139
|
-
nodes.append(c)
|
1140
|
-
time[c] = t
|
1141
|
-
time_nodes.setdefault(t, set()).add(c)
|
1142
|
-
t += 1
|
1143
|
-
i += 1
|
1144
|
-
if waiting_list != []:
|
1145
|
-
prev_mother = waiting_list.pop()
|
1146
|
-
successor[prev_mother].insert(0, number_sequence[i + 1])
|
1147
|
-
edges.append((prev_mother, number_sequence[i + 1]))
|
1148
|
-
time_edges.setdefault(t, set()).add(
|
1149
|
-
(prev_mother, number_sequence[i + 1])
|
1150
|
-
)
|
1151
|
-
if i + 1 < len(number_sequence):
|
1152
|
-
is_root[number_sequence[i + 1]] = False
|
1153
|
-
t = time[prev_mother] + 1
|
1154
|
-
else:
|
1155
|
-
if len(time_sequence) > 0:
|
1156
|
-
t = time_sequence.pop(0)
|
1157
|
-
if i + 1 < len(number_sequence):
|
1158
|
-
is_root[number_sequence[i + 1]] = True
|
1159
|
-
i += 1
|
1160
|
-
|
1161
|
-
predecessor = {vi: [k] for k, v in successor.items() for vi in v}
|
1162
|
-
|
1163
|
-
self.successor = successor
|
1164
|
-
self.predecessor = predecessor
|
1165
|
-
self.time = time
|
1166
|
-
self.time_nodes = time_nodes
|
1167
|
-
self.time_edges = time_edges
|
1168
|
-
self.pos = pos
|
1169
|
-
self.nodes = set(nodes)
|
1170
|
-
self.t_b = min(time_nodes.keys())
|
1171
|
-
self.t_e = max(time_nodes.keys())
|
1172
|
-
self.is_root = is_root
|
1173
|
-
self.max_id = max(self.nodes)
|
1174
|
-
|
1175
|
-
def write(self, fname: str):
|
1176
|
-
"""
|
1177
|
-
Write a lineage tree on disk as an .lT file.
|
1178
|
-
|
1179
|
-
Args:
|
1180
|
-
fname (str): path to and name of the file to save
|
964
|
+
Parameters
|
965
|
+
----------
|
966
|
+
fname : str
|
967
|
+
path to and name of the file to save
|
1181
968
|
"""
|
1182
969
|
if os.path.splitext(fname)[-1] != ".lT":
|
1183
970
|
fname = os.path.extsep.join((fname, "lT"))
|
971
|
+
if hasattr(self, "_protected_predecessor"):
|
972
|
+
del self._protected_predecessor
|
973
|
+
if hasattr(self, "_protected_successor"):
|
974
|
+
del self._protected_successor
|
975
|
+
if hasattr(self, "_protected_time"):
|
976
|
+
del self._protected_time
|
1184
977
|
with open(fname, "bw") as f:
|
1185
978
|
pkl.dump(self, f)
|
1186
979
|
f.close()
|
1187
980
|
|
1188
981
|
@classmethod
|
1189
|
-
def load(clf, fname: str
|
1190
|
-
"""
|
1191
|
-
Loading a lineage tree from a ".lT" file.
|
982
|
+
def load(clf, fname: str):
|
983
|
+
"""Loading a lineage tree from a '.lT' file.
|
1192
984
|
|
1193
|
-
|
1194
|
-
|
985
|
+
Parameters
|
986
|
+
----------
|
987
|
+
fname : str
|
988
|
+
path to and name of the file to read
|
1195
989
|
|
1196
|
-
Returns
|
1197
|
-
|
990
|
+
Returns
|
991
|
+
-------
|
992
|
+
LineageTree
|
993
|
+
loaded file
|
1198
994
|
"""
|
1199
995
|
with open(fname, "br") as f:
|
1200
996
|
lT = pkl.load(f)
|
1201
997
|
f.close()
|
998
|
+
if not hasattr(lT, "__version__") or Version(lT.__version__) < Version(
|
999
|
+
"2.0.0"
|
1000
|
+
):
|
1001
|
+
properties = {
|
1002
|
+
prop_name: prop
|
1003
|
+
for prop_name, prop in lT.__dict__.items()
|
1004
|
+
if isinstance(prop, dict)
|
1005
|
+
and prop_name
|
1006
|
+
not in [
|
1007
|
+
"successor",
|
1008
|
+
"predecessor",
|
1009
|
+
"time",
|
1010
|
+
"_successor",
|
1011
|
+
"_predecessor",
|
1012
|
+
"_time",
|
1013
|
+
"pos",
|
1014
|
+
"labels",
|
1015
|
+
]
|
1016
|
+
+ lineageTree._dynamic_properties
|
1017
|
+
+ lineageTree._protected_dynamic_properties
|
1018
|
+
}
|
1019
|
+
print("_comparisons" in properties)
|
1020
|
+
lT = lineageTree(
|
1021
|
+
successor=lT._successor,
|
1022
|
+
time=lT._time,
|
1023
|
+
pos=lT.pos,
|
1024
|
+
name=lT.name if hasattr(lT, "name") else None,
|
1025
|
+
**properties,
|
1026
|
+
)
|
1202
1027
|
if not hasattr(lT, "time_resolution"):
|
1203
|
-
lT.time_resolution =
|
1204
|
-
|
1205
|
-
if [] in lT.successor.values():
|
1206
|
-
for node, succ in lT.successor.items():
|
1207
|
-
if succ == []:
|
1208
|
-
lT.successor.pop(node)
|
1209
|
-
if [] in lT.predecessor.values():
|
1210
|
-
for node, succ in lT.predecessor.items():
|
1211
|
-
if succ == []:
|
1212
|
-
lT.predecessor.pop(node)
|
1213
|
-
lT.t_e = max(lT.time_nodes)
|
1214
|
-
lT.t_b = min(lT.time_nodes)
|
1215
|
-
warnings.warn("Empty lists have been removed")
|
1028
|
+
lT.time_resolution = 1
|
1029
|
+
|
1216
1030
|
return lT
|
1217
1031
|
|
1218
|
-
def
|
1219
|
-
|
1220
|
-
|
1221
|
-
|
1222
|
-
|
1223
|
-
|
1224
|
-
|
1225
|
-
|
1226
|
-
|
1227
|
-
|
1228
|
-
|
1032
|
+
def get_predecessors(
|
1033
|
+
self,
|
1034
|
+
x: int,
|
1035
|
+
depth: int | None = None,
|
1036
|
+
start_time: int | None = None,
|
1037
|
+
end_time: int | None = None,
|
1038
|
+
) -> list[int]:
|
1039
|
+
"""Computes the predecessors of the node `x` up to
|
1040
|
+
`depth` predecessors or the begining of the life of `x`.
|
1041
|
+
The ordered list of ids is returned.
|
1042
|
+
|
1043
|
+
Parameters
|
1044
|
+
----------
|
1045
|
+
x : int
|
1046
|
+
id of the node to compute
|
1047
|
+
depth : int
|
1048
|
+
maximum number of predecessors to return
|
1049
|
+
|
1050
|
+
Returns
|
1051
|
+
-------
|
1052
|
+
list of int
|
1053
|
+
list of ids, the last id is `x`
|
1054
|
+
"""
|
1055
|
+
if not start_time:
|
1056
|
+
start_time = self.t_b
|
1057
|
+
if not end_time:
|
1058
|
+
end_time = self.t_e
|
1059
|
+
unconstrained_chain = [x]
|
1060
|
+
chain = [x] if start_time <= self._time[x] <= end_time else []
|
1061
|
+
acc = 0
|
1062
|
+
while (
|
1063
|
+
acc != depth
|
1064
|
+
and start_time < self._time[unconstrained_chain[0]]
|
1065
|
+
and (
|
1066
|
+
self._predecessor[unconstrained_chain[0]] != ()
|
1067
|
+
and ( # Please dont change very important even if it looks weird.
|
1068
|
+
len(
|
1069
|
+
self._successor[
|
1070
|
+
self._predecessor[unconstrained_chain[0]][0]
|
1071
|
+
]
|
1072
|
+
)
|
1073
|
+
== 1
|
1074
|
+
)
|
1075
|
+
)
|
1076
|
+
):
|
1077
|
+
unconstrained_chain.insert(
|
1078
|
+
0, self._predecessor[unconstrained_chain[0]][0]
|
1079
|
+
)
|
1080
|
+
acc += 1
|
1081
|
+
if start_time <= self._time[unconstrained_chain[0]] <= end_time:
|
1082
|
+
chain.insert(0, unconstrained_chain[0])
|
1083
|
+
|
1084
|
+
return chain
|
1085
|
+
|
1086
|
+
def get_successors(
|
1087
|
+
self, x: int, depth: int | None = None, end_time: int | None = None
|
1088
|
+
) -> list[int]:
|
1089
|
+
"""Computes the successors of the node `x` up to
|
1090
|
+
`depth` successors or the end of the life of `x`.
|
1091
|
+
The ordered list of ids is returned.
|
1092
|
+
|
1093
|
+
Parameters
|
1094
|
+
----------
|
1095
|
+
x : int
|
1096
|
+
id of the node to compute
|
1097
|
+
depth : int, optional
|
1098
|
+
maximum number of predecessors to return
|
1099
|
+
end_time : int, optional
|
1100
|
+
maximum time to consider
|
1101
|
+
|
1102
|
+
Returns
|
1103
|
+
-------
|
1104
|
+
list of int
|
1105
|
+
list of ids, the first id is `x`
|
1106
|
+
"""
|
1107
|
+
if end_time is None:
|
1108
|
+
end_time = self.t_e
|
1109
|
+
chain = [x]
|
1110
|
+
acc = 0
|
1111
|
+
while (
|
1112
|
+
len(self._successor[chain[-1]]) == 1
|
1113
|
+
and acc != depth
|
1114
|
+
and self._time[chain[-1]] < end_time
|
1115
|
+
):
|
1116
|
+
chain += self._successor[chain[-1]]
|
1117
|
+
acc += 1
|
1118
|
+
|
1119
|
+
return chain
|
1120
|
+
|
1121
|
+
def get_chain_of_node(
|
1122
|
+
self,
|
1123
|
+
x: int,
|
1124
|
+
depth: int | None = None,
|
1125
|
+
depth_pred: int | None = None,
|
1126
|
+
depth_succ: int | None = None,
|
1127
|
+
end_time: int | None = None,
|
1128
|
+
) -> list[int]:
|
1129
|
+
"""Computes the predecessors and successors of the node `x` up to
|
1130
|
+
`depth_pred` predecessors plus `depth_succ` successors.
|
1131
|
+
If the value `depth` is provided and not None,
|
1132
|
+
`depth_pred` and `depth_succ` are overwriten by `depth`.
|
1133
|
+
The ordered list of ids is returned.
|
1134
|
+
If all `depth` are None, the full chain is returned.
|
1135
|
+
|
1136
|
+
Parameters
|
1137
|
+
----------
|
1138
|
+
x : int
|
1139
|
+
id of the node to compute
|
1140
|
+
depth : int, optional
|
1141
|
+
maximum number of predecessors and successor to return
|
1142
|
+
depth_pred : int, optional
|
1143
|
+
maximum number of predecessors to return
|
1144
|
+
depth_succ : int, optional
|
1145
|
+
maximum number of successors to return
|
1146
|
+
|
1147
|
+
Returns
|
1148
|
+
-------
|
1149
|
+
list of int
|
1150
|
+
list of node ids
|
1229
1151
|
"""
|
1230
|
-
|
1152
|
+
if end_time is None:
|
1153
|
+
end_time = self.t_e
|
1154
|
+
if depth is not None:
|
1155
|
+
depth_pred = depth_succ = depth
|
1156
|
+
return self.get_predecessors(x, depth_pred, end_time=end_time)[
|
1157
|
+
:-1
|
1158
|
+
] + self.get_successors(x, depth_succ, end_time=end_time)
|
1159
|
+
|
1160
|
+
@dynamic_property
|
1161
|
+
def all_chains(self) -> list[list[int]]:
|
1162
|
+
"""List of all chains in the tree, ordered in depth-first search."""
|
1163
|
+
return self._compute_all_chains()
|
1164
|
+
|
1165
|
+
@dynamic_property
|
1166
|
+
def time_nodes(self):
|
1167
|
+
_time_nodes = {}
|
1168
|
+
for c, t in self._time.items():
|
1169
|
+
_time_nodes.setdefault(t, set()).add(c)
|
1170
|
+
return _time_nodes
|
1171
|
+
|
1172
|
+
def m(self, i, j):
|
1173
|
+
if (i, j) not in self._tmp_parenting:
|
1174
|
+
if i == j: # the distance to the node itself is 0
|
1175
|
+
self._tmp_parenting[(i, j)] = 0
|
1176
|
+
self._parenting[i, j] = self._tmp_parenting[(i, j)]
|
1177
|
+
elif not self._predecessor[
|
1178
|
+
j
|
1179
|
+
]: # j and i are note connected so the distance if inf
|
1180
|
+
self._tmp_parenting[(i, j)] = np.inf
|
1181
|
+
else: # the distance between i and j is the distance between i and pred(j) + 1
|
1182
|
+
self._tmp_parenting[(i, j)] = (
|
1183
|
+
self.m(i, self._predecessor[j][0]) + 1
|
1184
|
+
)
|
1185
|
+
self._parenting[i, j] = self._tmp_parenting[(i, j)]
|
1186
|
+
self._parenting[j, i] = -self._tmp_parenting[(i, j)]
|
1187
|
+
return self._tmp_parenting[(i, j)]
|
1188
|
+
|
1189
|
+
@property
|
1190
|
+
def parenting(self):
|
1191
|
+
if not hasattr(self, "_parenting"):
|
1192
|
+
self._parenting = dok_array((max(self.nodes) + 1,) * 2)
|
1193
|
+
self._tmp_parenting = {}
|
1194
|
+
for i, j in combinations(self.nodes, 2):
|
1195
|
+
if self._time[j] < self.time[i]:
|
1196
|
+
i, j = j, i
|
1197
|
+
self._tmp_parenting[(i, j)] = self.m(i, j)
|
1198
|
+
del self._tmp_parenting
|
1199
|
+
return self._parenting
|
1200
|
+
|
1201
|
+
def get_idx3d(self, t: int) -> tuple[KDTree, np.ndarray]:
|
1202
|
+
"""Get a 3d kdtree for the dataset at time `t`.
|
1203
|
+
The kdtree is stored in `self.kdtrees[t]` and returned.
|
1204
|
+
The correspondancy list is also returned.
|
1205
|
+
|
1206
|
+
Parameters
|
1207
|
+
----------
|
1208
|
+
t : int
|
1209
|
+
time
|
1210
|
+
|
1211
|
+
Returns
|
1212
|
+
-------
|
1213
|
+
KDTree
|
1214
|
+
The KDTree corresponding to the lineage tree at time `t`
|
1215
|
+
np.ndarray
|
1216
|
+
The correspondancy list in the KDTree.
|
1217
|
+
If the query in the kdtree gives you the value `i`,
|
1218
|
+
then it corresponds to the id in the tree `to_check_self[i]`
|
1219
|
+
"""
|
1220
|
+
to_check_self = list(self.nodes_at_t(t=t))
|
1221
|
+
|
1222
|
+
if not hasattr(self, "kdtrees"):
|
1223
|
+
self.kdtrees = {}
|
1224
|
+
|
1231
1225
|
if t not in self.kdtrees:
|
1232
1226
|
data_corres = {}
|
1233
1227
|
data = []
|
@@ -1240,16 +1234,21 @@ class lineageTree(lineageTreeLoaders):
|
|
1240
1234
|
idx3d = self.kdtrees[t]
|
1241
1235
|
return idx3d, np.array(to_check_self)
|
1242
1236
|
|
1243
|
-
def get_gabriel_graph(self, t: int) -> dict:
|
1244
|
-
"""Build the Gabriel graph of the given graph for time point `t
|
1245
|
-
The Garbiel graph is then stored in self.Gabriel_graph and returned
|
1246
|
-
|
1237
|
+
def get_gabriel_graph(self, t: int) -> dict[int, set[int]]:
|
1238
|
+
"""Build the Gabriel graph of the given graph for time point `t`.
|
1239
|
+
The Garbiel graph is then stored in `self.Gabriel_graph` and returned.
|
1240
|
+
|
1241
|
+
.. warning:: the graph is not recomputed if already computed, even if the point cloud has changed
|
1247
1242
|
|
1248
|
-
|
1249
|
-
|
1250
|
-
|
1251
|
-
|
1252
|
-
|
1243
|
+
Parameters
|
1244
|
+
----------
|
1245
|
+
t : int
|
1246
|
+
time
|
1247
|
+
|
1248
|
+
Returns
|
1249
|
+
-------
|
1250
|
+
dict of int to set of int
|
1251
|
+
A dictionary that maps a node to the set of its neighbors
|
1253
1252
|
"""
|
1254
1253
|
if not hasattr(self, "Gabriel_graph"):
|
1255
1254
|
self.Gabriel_graph = {}
|
@@ -1294,178 +1293,110 @@ class lineageTree(lineageTreeLoaders):
|
|
1294
1293
|
|
1295
1294
|
return self.Gabriel_graph[t]
|
1296
1295
|
|
1297
|
-
def
|
1298
|
-
self,
|
1299
|
-
) -> list:
|
1300
|
-
"""Computes the
|
1301
|
-
|
1302
|
-
The ordered list of ids is returned.
|
1303
|
-
|
1304
|
-
Args:
|
1305
|
-
x (int): id of the node to compute
|
1306
|
-
depth (int): maximum number of predecessors to return
|
1307
|
-
Returns:
|
1308
|
-
[int, ]: list of ids, the last id is `x`
|
1309
|
-
"""
|
1310
|
-
if not start_time:
|
1311
|
-
start_time = self.t_b
|
1312
|
-
if not end_time:
|
1313
|
-
end_time = self.t_e
|
1314
|
-
unconstrained_cycle = [x]
|
1315
|
-
cycle = [x] if start_time <= self.time[x] <= end_time else []
|
1316
|
-
acc = 0
|
1317
|
-
while (
|
1318
|
-
len(self[self.predecessor.get(unconstrained_cycle[0], [-1])[0]])
|
1319
|
-
== 1
|
1320
|
-
and acc != depth
|
1321
|
-
and start_time
|
1322
|
-
<= self.time.get(
|
1323
|
-
self.predecessor.get(unconstrained_cycle[0], [-1])[0], -1
|
1324
|
-
)
|
1325
|
-
):
|
1326
|
-
unconstrained_cycle.insert(
|
1327
|
-
0, self.predecessor[unconstrained_cycle[0]][0]
|
1328
|
-
)
|
1329
|
-
acc += 1
|
1330
|
-
if start_time <= self.time[unconstrained_cycle[0]] <= end_time:
|
1331
|
-
cycle.insert(0, unconstrained_cycle[0])
|
1332
|
-
|
1333
|
-
return cycle
|
1334
|
-
|
1335
|
-
def get_successors(
|
1336
|
-
self, x: int, depth: int = None, end_time: int = None
|
1337
|
-
) -> list:
|
1338
|
-
"""Computes the successors of the node `x` up to
|
1339
|
-
`depth` successors or the end of the life of `x`.
|
1340
|
-
The ordered list of ids is returned.
|
1341
|
-
|
1342
|
-
Args:
|
1343
|
-
x (int): id of the node to compute
|
1344
|
-
depth (int): maximum number of predecessors to return
|
1345
|
-
Returns:
|
1346
|
-
[int, ]: list of ids, the first id is `x`
|
1347
|
-
"""
|
1348
|
-
if end_time is None:
|
1349
|
-
end_time = self.t_e
|
1350
|
-
cycle = [x]
|
1351
|
-
acc = 0
|
1352
|
-
while (
|
1353
|
-
len(self[cycle[-1]]) == 1
|
1354
|
-
and acc != depth
|
1355
|
-
and self.time[cycle[-1]] < end_time
|
1356
|
-
):
|
1357
|
-
cycle += self.successor[cycle[-1]]
|
1358
|
-
acc += 1
|
1359
|
-
|
1360
|
-
return cycle
|
1296
|
+
def get_all_chains_of_subtree(
|
1297
|
+
self, node: int, end_time: int | None = None
|
1298
|
+
) -> list[list[int]]:
|
1299
|
+
"""Computes all the chains of the subtree spawn by a given node.
|
1300
|
+
Similar to get_all_chains().
|
1361
1301
|
|
1362
|
-
|
1363
|
-
|
1364
|
-
|
1365
|
-
|
1366
|
-
|
1367
|
-
|
1368
|
-
end_time: int = None,
|
1369
|
-
) -> list:
|
1370
|
-
"""Computes the predecessors and successors of the node `x` up to
|
1371
|
-
`depth_pred` predecessors plus `depth_succ` successors.
|
1372
|
-
If the value `depth` is provided and not None,
|
1373
|
-
`depth_pred` and `depth_succ` are overwriten by `depth`.
|
1374
|
-
The ordered list of ids is returned.
|
1375
|
-
If all `depth` are None, the full cycle is returned.
|
1376
|
-
|
1377
|
-
Args:
|
1378
|
-
x (int): id of the node to compute
|
1379
|
-
depth (int): maximum number of predecessors and successor to return
|
1380
|
-
depth_pred (int): maximum number of predecessors to return
|
1381
|
-
depth_succ (int): maximum number of successors to return
|
1382
|
-
Returns:
|
1383
|
-
[int, ]: list of ids
|
1384
|
-
"""
|
1385
|
-
if end_time is None:
|
1386
|
-
end_time = self.t_e
|
1387
|
-
if depth is not None:
|
1388
|
-
depth_pred = depth_succ = depth
|
1389
|
-
return self.get_predecessors(x, depth_pred, end_time=end_time)[
|
1390
|
-
:-1
|
1391
|
-
] + self.get_successors(x, depth_succ, end_time=end_time)
|
1392
|
-
|
1393
|
-
@property
|
1394
|
-
def all_tracks(self):
|
1395
|
-
if not hasattr(self, "_all_tracks"):
|
1396
|
-
self._all_tracks = self.get_all_tracks()
|
1397
|
-
return self._all_tracks
|
1302
|
+
Parameters
|
1303
|
+
----------
|
1304
|
+
node : int
|
1305
|
+
The node from which we want to get its chains.
|
1306
|
+
end_time : int, optional
|
1307
|
+
The time at which we want to stop the chains.
|
1398
1308
|
|
1399
|
-
|
1400
|
-
|
1401
|
-
|
1402
|
-
|
1403
|
-
Similar to get_all_tracks().
|
1404
|
-
|
1405
|
-
Args:
|
1406
|
-
node (int, optional): The node that we want to get its branches.
|
1407
|
-
|
1408
|
-
Returns:
|
1409
|
-
([[int, ...], ...]): list of lists containing track cell ids
|
1309
|
+
Returns
|
1310
|
+
-------
|
1311
|
+
list of list of int
|
1312
|
+
list of chains
|
1410
1313
|
"""
|
1411
1314
|
if not end_time:
|
1412
1315
|
end_time = self.t_e
|
1413
|
-
|
1414
|
-
to_do = list(self[
|
1316
|
+
chains = [self.get_successors(node)]
|
1317
|
+
to_do = list(self._successor[chains[0][-1]])
|
1415
1318
|
while to_do:
|
1416
1319
|
current = to_do.pop()
|
1417
|
-
|
1418
|
-
|
1419
|
-
|
1420
|
-
|
1421
|
-
|
1422
|
-
|
1423
|
-
|
1424
|
-
|
1425
|
-
|
1426
|
-
|
1427
|
-
|
1428
|
-
|
1429
|
-
|
1430
|
-
|
1431
|
-
|
1432
|
-
|
1433
|
-
|
1434
|
-
|
1435
|
-
|
1436
|
-
|
1437
|
-
|
1438
|
-
|
1439
|
-
return
|
1440
|
-
|
1441
|
-
def
|
1442
|
-
|
1443
|
-
|
1444
|
-
|
1445
|
-
|
1446
|
-
|
1447
|
-
|
1448
|
-
|
1449
|
-
|
1450
|
-
|
1451
|
-
|
1452
|
-
|
1453
|
-
|
1454
|
-
|
1455
|
-
|
1456
|
-
|
1457
|
-
|
1458
|
-
|
1459
|
-
|
1320
|
+
chain = self.get_successors(current, end_time=end_time)
|
1321
|
+
if self._time[chain[-1]] <= end_time:
|
1322
|
+
chains += [chain]
|
1323
|
+
to_do += self._successor[chain[-1]]
|
1324
|
+
return chains
|
1325
|
+
|
1326
|
+
def _compute_all_chains(self) -> list[list[int]]:
|
1327
|
+
"""Computes all the chains of a given lineage tree,
|
1328
|
+
stores it in `self.all_chains` and returns it.
|
1329
|
+
|
1330
|
+
Returns
|
1331
|
+
-------
|
1332
|
+
list of list of int
|
1333
|
+
list of chains
|
1334
|
+
"""
|
1335
|
+
all_chains = []
|
1336
|
+
to_do = sorted(self.roots, key=self.time.get, reverse=True)
|
1337
|
+
while len(to_do) != 0:
|
1338
|
+
current = to_do.pop()
|
1339
|
+
chain = self.get_chain_of_node(current)
|
1340
|
+
all_chains += [chain]
|
1341
|
+
to_do.extend(self._successor[chain[-1]])
|
1342
|
+
return all_chains
|
1343
|
+
|
1344
|
+
def __get_chains( # TODO: Probably should be removed, might be used by DTW. Might also be a @dynamic_property
|
1345
|
+
self, nodes: Iterable | int | None = None
|
1346
|
+
) -> dict[int, list[list[int]]]:
|
1347
|
+
"""Returns all the chains in the subtrees spawned by each of the given nodes.
|
1348
|
+
|
1349
|
+
Parameters
|
1350
|
+
----------
|
1351
|
+
nodes : Iterable or int, optional
|
1352
|
+
id or Iterable of ids of the nodes to be computed, if `None` all roots are used
|
1353
|
+
|
1354
|
+
Returns
|
1355
|
+
-------
|
1356
|
+
dict mapping int to list of Chain
|
1357
|
+
dictionary mapping the node ids to a list of chains
|
1358
|
+
"""
|
1359
|
+
all_chains = self.all_chains
|
1360
|
+
if nodes is None:
|
1361
|
+
nodes = self.roots
|
1362
|
+
if not isinstance(nodes, Iterable):
|
1363
|
+
nodes = [nodes]
|
1364
|
+
output_chains = {}
|
1365
|
+
for n in nodes:
|
1366
|
+
starting_node = self.get_predecessors(n)[0]
|
1367
|
+
found = False
|
1368
|
+
done = False
|
1369
|
+
starting_time = self.time[n]
|
1370
|
+
i = 0
|
1371
|
+
current_chain = []
|
1372
|
+
while not done and i < len(all_chains):
|
1373
|
+
curr_found = all_chains[i][0] == starting_node
|
1374
|
+
found = found or curr_found
|
1375
|
+
if found:
|
1376
|
+
done = (
|
1377
|
+
self.time[all_chains[i][0]] <= starting_time
|
1378
|
+
) and not curr_found
|
1379
|
+
if not done:
|
1380
|
+
if curr_found:
|
1381
|
+
current_chain.append(self.get_successors(n))
|
1382
|
+
else:
|
1383
|
+
current_chain.append(all_chains[i])
|
1384
|
+
i += 1
|
1385
|
+
output_chains[n] = current_chain
|
1386
|
+
return output_chains
|
1460
1387
|
|
1461
|
-
def find_leaves(self, roots:
|
1388
|
+
def find_leaves(self, roots: int | Iterable) -> set[int]:
|
1462
1389
|
"""Finds the leaves of a tree spawned by one or more nodes.
|
1463
1390
|
|
1464
|
-
|
1465
|
-
|
1391
|
+
Parameters
|
1392
|
+
----------
|
1393
|
+
roots : int or Iterable
|
1394
|
+
The roots of the trees spawning the leaves
|
1466
1395
|
|
1467
|
-
Returns
|
1468
|
-
|
1396
|
+
Returns
|
1397
|
+
-------
|
1398
|
+
set
|
1399
|
+
The leaves of one or more trees.
|
1469
1400
|
"""
|
1470
1401
|
if not isinstance(roots, Iterable):
|
1471
1402
|
to_do = [roots]
|
@@ -1474,29 +1405,34 @@ class lineageTree(lineageTreeLoaders):
|
|
1474
1405
|
leaves = set()
|
1475
1406
|
while to_do:
|
1476
1407
|
curr = to_do.pop()
|
1477
|
-
succ = self.
|
1478
|
-
if
|
1408
|
+
succ = self._successor[curr]
|
1409
|
+
if not succ:
|
1479
1410
|
leaves.add(curr)
|
1480
|
-
|
1481
|
-
to_do += succ
|
1411
|
+
to_do += succ
|
1482
1412
|
return leaves
|
1483
1413
|
|
1484
|
-
def
|
1414
|
+
def get_subtree_nodes(
|
1485
1415
|
self,
|
1486
|
-
x:
|
1487
|
-
end_time:
|
1416
|
+
x: int | Iterable,
|
1417
|
+
end_time: int | None = None,
|
1488
1418
|
preorder: bool = False,
|
1489
|
-
) -> list:
|
1490
|
-
"""Computes the list of
|
1491
|
-
The default output order is
|
1419
|
+
) -> list[int]:
|
1420
|
+
"""Computes the list of nodes from the subtree spawned by *x*
|
1421
|
+
The default output order is Breadth First Traversal.
|
1492
1422
|
Unless preorder is `True` in that case the order is
|
1493
|
-
Depth
|
1423
|
+
Depth First Traversal (DFT) preordered.
|
1424
|
+
|
1425
|
+
Parameters
|
1426
|
+
----------
|
1427
|
+
x : int
|
1428
|
+
id of root node
|
1429
|
+
preorder : bool, default=False
|
1430
|
+
if True the output preorder is DFT
|
1494
1431
|
|
1495
|
-
|
1496
|
-
|
1497
|
-
|
1498
|
-
|
1499
|
-
([int, ...]): the ordered list of node ids
|
1432
|
+
Returns
|
1433
|
+
-------
|
1434
|
+
list of int
|
1435
|
+
the ordered list of node ids
|
1500
1436
|
"""
|
1501
1437
|
if not end_time:
|
1502
1438
|
end_time = self.t_e
|
@@ -1504,233 +1440,258 @@ class lineageTree(lineageTreeLoaders):
|
|
1504
1440
|
to_do = [x]
|
1505
1441
|
elif isinstance(x, Iterable):
|
1506
1442
|
to_do = list(x)
|
1507
|
-
|
1443
|
+
subtree = []
|
1508
1444
|
while to_do:
|
1509
1445
|
curr = to_do.pop()
|
1510
|
-
succ = self.
|
1511
|
-
if succ and end_time < self.
|
1446
|
+
succ = self._successor[curr]
|
1447
|
+
if succ and end_time < self._time.get(curr, end_time):
|
1512
1448
|
succ = []
|
1513
1449
|
continue
|
1514
1450
|
if preorder:
|
1515
1451
|
to_do = succ + to_do
|
1516
1452
|
else:
|
1517
1453
|
to_do += succ
|
1518
|
-
|
1519
|
-
return
|
1454
|
+
subtree += [curr]
|
1455
|
+
return subtree
|
1520
1456
|
|
1521
1457
|
def compute_spatial_density(
|
1522
|
-
self, t_b: int = None, t_e: int = None, th: float = 50
|
1523
|
-
) -> dict:
|
1524
|
-
"""Computes the spatial density of
|
1525
|
-
The
|
1526
|
-
|
1527
|
-
|
1528
|
-
|
1529
|
-
|
1530
|
-
|
1531
|
-
|
1532
|
-
|
1533
|
-
|
1534
|
-
|
1535
|
-
|
1458
|
+
self, t_b: int | None = None, t_e: int | None = None, th: float = 50
|
1459
|
+
) -> dict[int, float]:
|
1460
|
+
"""Computes the spatial density of nodes between `t_b` and `t_e`.
|
1461
|
+
The results is stored in `self.spatial_density` and returned.
|
1462
|
+
|
1463
|
+
Parameters
|
1464
|
+
----------
|
1465
|
+
t_b : int, optional
|
1466
|
+
starting time to look at, default first time point
|
1467
|
+
t_e : int, optional
|
1468
|
+
ending time to look at, default last time point
|
1469
|
+
th : float, default=50
|
1470
|
+
size of the neighbourhood
|
1471
|
+
|
1472
|
+
Returns
|
1473
|
+
-------
|
1474
|
+
dict mapping int to float
|
1475
|
+
dictionary that maps a node id to its spatial density
|
1476
|
+
"""
|
1477
|
+
if not hasattr(self, "spatial_density"):
|
1478
|
+
self.spatial_density = {}
|
1536
1479
|
s_vol = 4 / 3.0 * np.pi * th**3
|
1537
|
-
|
1480
|
+
if t_b is None:
|
1481
|
+
t_b = self.t_b
|
1482
|
+
if t_e is None:
|
1483
|
+
t_e = self.t_e
|
1484
|
+
time_range = set(range(t_b, t_e)).intersection(self._time.values())
|
1538
1485
|
for t in time_range:
|
1539
1486
|
idx3d, nodes = self.get_idx3d(t)
|
1540
1487
|
nb_ni = [
|
1541
1488
|
(len(ni) - 1) / s_vol
|
1542
1489
|
for ni in idx3d.query_ball_tree(idx3d, th)
|
1543
1490
|
]
|
1544
|
-
self.spatial_density.update(dict(zip(nodes, nb_ni)))
|
1491
|
+
self.spatial_density.update(dict(zip(nodes, nb_ni, strict=True)))
|
1545
1492
|
return self.spatial_density
|
1546
1493
|
|
1547
|
-
def compute_k_nearest_neighbours(self, k: int = 10) -> dict:
|
1494
|
+
def compute_k_nearest_neighbours(self, k: int = 10) -> dict[int, set[int]]:
|
1548
1495
|
"""Computes the k-nearest neighbors
|
1549
1496
|
Writes the output in the attribute `kn_graph`
|
1550
1497
|
and returns it.
|
1551
1498
|
|
1552
|
-
|
1553
|
-
|
1554
|
-
|
1555
|
-
|
1556
|
-
|
1499
|
+
Parameters
|
1500
|
+
----------
|
1501
|
+
k : float
|
1502
|
+
number of nearest neighours
|
1503
|
+
|
1504
|
+
Returns
|
1505
|
+
-------
|
1506
|
+
dict mapping int to set of int
|
1507
|
+
dictionary that maps
|
1508
|
+
a node id to its `k` nearest neighbors
|
1557
1509
|
"""
|
1558
1510
|
self.kn_graph = {}
|
1559
|
-
for t
|
1560
|
-
|
1561
|
-
|
1562
|
-
|
1563
|
-
|
1564
|
-
|
1565
|
-
|
1511
|
+
for t in set(self._time.values()):
|
1512
|
+
nodes = self.nodes_at_t(t)
|
1513
|
+
if 1 < len(nodes):
|
1514
|
+
use_k = k if k < len(nodes) else len(nodes)
|
1515
|
+
idx3d, nodes = self.get_idx3d(t)
|
1516
|
+
pos = [self.pos[c] for c in nodes]
|
1517
|
+
_, neighbs = idx3d.query(pos, use_k)
|
1518
|
+
out = dict(
|
1519
|
+
zip(
|
1520
|
+
nodes,
|
1521
|
+
map(set, nodes[neighbs]),
|
1522
|
+
strict=True,
|
1523
|
+
)
|
1524
|
+
)
|
1525
|
+
self.kn_graph.update(out)
|
1526
|
+
else:
|
1527
|
+
n = nodes.pop
|
1528
|
+
self.kn_graph.update({n: {n}})
|
1566
1529
|
return self.kn_graph
|
1567
1530
|
|
1568
|
-
def compute_spatial_edges(self, th: int = 50) -> dict:
|
1531
|
+
def compute_spatial_edges(self, th: int = 50) -> dict[int, set[int]]:
|
1569
1532
|
"""Computes the neighbors at a distance `th`
|
1570
1533
|
Writes the output in the attribute `th_edge`
|
1571
1534
|
and returns it.
|
1572
1535
|
|
1573
|
-
|
1574
|
-
|
1575
|
-
|
1576
|
-
|
1577
|
-
|
1536
|
+
Parameters
|
1537
|
+
----------
|
1538
|
+
th : float, default=50
|
1539
|
+
distance to consider neighbors
|
1540
|
+
|
1541
|
+
Returns
|
1542
|
+
-------
|
1543
|
+
dict mapping int to set of int
|
1544
|
+
dictionary that maps a node id to its neighbors at a distance `th`
|
1578
1545
|
"""
|
1579
1546
|
self.th_edges = {}
|
1580
|
-
for t
|
1547
|
+
for t in set(self._time.values()):
|
1548
|
+
nodes = self.nodes_at_t(t)
|
1581
1549
|
idx3d, nodes = self.get_idx3d(t)
|
1582
1550
|
neighbs = idx3d.query_ball_tree(idx3d, th)
|
1583
|
-
out = dict(
|
1551
|
+
out = dict(
|
1552
|
+
zip(nodes, [set(nodes[ni]) for ni in neighbs], strict=True)
|
1553
|
+
)
|
1584
1554
|
self.th_edges.update(
|
1585
1555
|
{k: v.difference([k]) for k, v in out.items()}
|
1586
1556
|
)
|
1587
1557
|
return self.th_edges
|
1588
1558
|
|
1589
|
-
def
|
1590
|
-
"""
|
1591
|
-
If none will select the timepoint with the highest amound of cells.
|
1592
|
-
|
1593
|
-
Args:
|
1594
|
-
time (int, optional): The timepoint to find the main axes.
|
1595
|
-
If None will find the timepoint
|
1596
|
-
with the largest number of cells.
|
1597
|
-
|
1598
|
-
Returns:
|
1599
|
-
list: A list that contains the array of eigenvalues and eigenvectors.
|
1600
|
-
"""
|
1601
|
-
if time is None:
|
1602
|
-
time = max(self.time_nodes, key=lambda x: len(self.time_nodes[x]))
|
1603
|
-
pos = np.array([self.pos[node] for node in self.time_nodes[time]])
|
1604
|
-
pos = pos - np.mean(pos, axis=0)
|
1605
|
-
cov = np.cov(np.array(pos).T)
|
1606
|
-
eig_val, eig_vec = np.linalg.eig(cov)
|
1607
|
-
srt = np.argsort(eig_val)[::-1]
|
1608
|
-
self.eig_val, self.eig_vec = eig_val[srt], eig_vec[:, srt]
|
1609
|
-
return eig_val[srt], eig_vec[:, srt]
|
1610
|
-
|
1611
|
-
def scale_embryo(self, scale=1000):
|
1612
|
-
"""Scale the embryo using their eigenvalues.
|
1613
|
-
|
1614
|
-
Args:
|
1615
|
-
scale (int, optional): The resulting scale you want to achieve. Defaults to 1000.
|
1616
|
-
|
1617
|
-
Returns:
|
1618
|
-
float: The scale factor.
|
1619
|
-
"""
|
1620
|
-
eig = self.main_axes()[0]
|
1621
|
-
return scale / (np.sqrt(eig[0]))
|
1622
|
-
|
1623
|
-
@staticmethod
|
1624
|
-
def __rodrigues_rotation_matrix(vector1, vector2=(0, 1, 0)):
|
1625
|
-
"""Calculates the rodrigues matrix of a dataset. It should use vectors from the find_main_axes(eigenvectors) function of LineagTree.
|
1626
|
-
Uses the Rodrigues rotation formula.
|
1627
|
-
|
1628
|
-
Args:
|
1629
|
-
vector1 (list|np.array): The vector that should be rotated to be aligned to the second vector
|
1630
|
-
vector2 (list|np.array, optional): The second vector. Defaults to [1,0,0].
|
1631
|
-
|
1632
|
-
Returns:
|
1633
|
-
np.array: The rotation matrix.
|
1634
|
-
"""
|
1635
|
-
vector1 = vector1 / np.linalg.norm(vector1)
|
1636
|
-
vector2 = vector2 / np.linalg.norm(vector2)
|
1637
|
-
if vector1 @ vector2 == 1:
|
1638
|
-
return np.eye(3)
|
1639
|
-
angle = np.arccos(vector1 @ vector2)
|
1640
|
-
axis = np.cross(vector1, vector2)
|
1641
|
-
axis = axis / np.linalg.norm(axis)
|
1642
|
-
K = np.array(
|
1643
|
-
[
|
1644
|
-
[0, -axis[2], axis[1]],
|
1645
|
-
[axis[2], 0, -axis[0]],
|
1646
|
-
[-axis[1], axis[0], 0],
|
1647
|
-
]
|
1648
|
-
)
|
1649
|
-
return np.eye(3) + np.sin(angle) * K + (1 - np.cos(angle)) * K @ K
|
1650
|
-
|
1651
|
-
def get_ancestor_at_t(self, n: int, time: int = None):
|
1652
|
-
"""
|
1653
|
-
Find the id of the ancestor of a give node `n`
|
1559
|
+
def get_ancestor_at_t(self, n: int, time: int | None = None) -> int:
|
1560
|
+
"""Find the id of the ancestor of a give node `n`
|
1654
1561
|
at a given time `time`.
|
1655
1562
|
|
1656
|
-
If there is no ancestor, returns
|
1657
|
-
If time is None return the root of the
|
1563
|
+
If there is no ancestor, returns `None`
|
1564
|
+
If time is None return the root of the subtree that spawns
|
1658
1565
|
the node n.
|
1659
1566
|
|
1660
|
-
|
1661
|
-
|
1662
|
-
|
1663
|
-
|
1664
|
-
|
1665
|
-
|
1666
|
-
|
1667
|
-
|
1668
|
-
|
1567
|
+
Parameters
|
1568
|
+
----------
|
1569
|
+
n : int
|
1570
|
+
node for which to look the ancestor
|
1571
|
+
time : int, optional
|
1572
|
+
time at which the ancestor has to be found.
|
1573
|
+
If `None` the ancestor at the first time point
|
1574
|
+
will be found.
|
1575
|
+
|
1576
|
+
Returns
|
1577
|
+
-------
|
1578
|
+
int
|
1579
|
+
the id of the ancestor at time `time`,
|
1580
|
+
`-1` if there is no ancestor.
|
1669
1581
|
"""
|
1670
1582
|
if n not in self.nodes:
|
1671
|
-
return
|
1583
|
+
return -1
|
1672
1584
|
if time is None:
|
1673
1585
|
time = self.t_b
|
1674
1586
|
ancestor = n
|
1675
1587
|
while (
|
1676
|
-
time < self.
|
1588
|
+
time < self._time.get(ancestor, self.t_b - 1)
|
1589
|
+
and self._predecessor[ancestor]
|
1677
1590
|
):
|
1678
|
-
ancestor = self.
|
1679
|
-
|
1591
|
+
ancestor = self._predecessor[ancestor][0]
|
1592
|
+
if self._time.get(ancestor, self.t_b - 1) == time:
|
1593
|
+
return ancestor
|
1594
|
+
else:
|
1595
|
+
return -1
|
1680
1596
|
|
1681
|
-
def get_labelled_ancestor(self, node: int):
|
1682
|
-
"""Finds the first labelled ancestor and returns its ID otherwise returns
|
1597
|
+
def get_labelled_ancestor(self, node: int) -> int:
|
1598
|
+
"""Finds the first labelled ancestor and returns its ID otherwise returns -1
|
1683
1599
|
|
1684
|
-
|
1685
|
-
|
1600
|
+
Parameters
|
1601
|
+
----------
|
1602
|
+
node : int
|
1603
|
+
The id of the node
|
1686
1604
|
|
1687
|
-
Returns
|
1688
|
-
|
1689
|
-
|
1605
|
+
Returns
|
1606
|
+
-------
|
1607
|
+
int
|
1608
|
+
Returns the first ancestor found that has a label otherwise `-1`.
|
1690
1609
|
"""
|
1691
1610
|
if node not in self.nodes:
|
1692
|
-
return
|
1611
|
+
return -1
|
1693
1612
|
ancestor = node
|
1694
1613
|
while (
|
1695
|
-
self.t_b <= self.
|
1614
|
+
self.t_b <= self._time.get(ancestor, self.t_b - 1)
|
1696
1615
|
and ancestor != -1
|
1697
1616
|
):
|
1698
1617
|
if ancestor in self.labels:
|
1699
1618
|
return ancestor
|
1700
|
-
ancestor = self.
|
1701
|
-
return
|
1619
|
+
ancestor = self._predecessor.get(ancestor, [-1])[0]
|
1620
|
+
return -1
|
1621
|
+
|
1622
|
+
def get_ancestor_with_attribute(self, node: int, attribute: str) -> int:
|
1623
|
+
"""General purpose function to help with searching the first ancestor that has an attribute.
|
1624
|
+
Similar to get_labeled_ancestor and may make it redundant.
|
1625
|
+
|
1626
|
+
Parameters
|
1627
|
+
----------
|
1628
|
+
node : int
|
1629
|
+
The id of the node
|
1630
|
+
|
1631
|
+
Returns
|
1632
|
+
-------
|
1633
|
+
int
|
1634
|
+
Returns the first ancestor found that has an attribute otherwise `-1`.
|
1635
|
+
"""
|
1636
|
+
attr_dict = self.__getattribute__(attribute)
|
1637
|
+
if not isinstance(attr_dict, dict):
|
1638
|
+
raise ValueError("Please select a dict attribute")
|
1639
|
+
if node not in self.nodes:
|
1640
|
+
return -1
|
1641
|
+
if node in attr_dict:
|
1642
|
+
return node
|
1643
|
+
if node in self.roots:
|
1644
|
+
return -1
|
1645
|
+
ancestor = (node,)
|
1646
|
+
while ancestor and ancestor != [-1]:
|
1647
|
+
ancestor = ancestor[0]
|
1648
|
+
if ancestor in attr_dict:
|
1649
|
+
return ancestor
|
1650
|
+
ancestor = self._predecessor.get(ancestor, [-1])
|
1651
|
+
return -1
|
1702
1652
|
|
1703
1653
|
def unordered_tree_edit_distances_at_time_t(
|
1704
1654
|
self,
|
1705
1655
|
t: int,
|
1706
|
-
end_time: int = None,
|
1707
|
-
style
|
1656
|
+
end_time: int | None = None,
|
1657
|
+
style: (
|
1658
|
+
Literal["simple", "full", "downsampled", "normalized_simple"]
|
1659
|
+
| type[TreeApproximationTemplate]
|
1660
|
+
) = "simple",
|
1708
1661
|
downsample: int = 2,
|
1709
|
-
|
1662
|
+
norm: Literal["max", "sum", None] = "max",
|
1710
1663
|
recompute: bool = False,
|
1711
|
-
) -> dict:
|
1712
|
-
"""
|
1713
|
-
|
1714
|
-
|
1715
|
-
|
1716
|
-
|
1717
|
-
|
1718
|
-
|
1719
|
-
|
1720
|
-
|
1721
|
-
|
1722
|
-
|
1723
|
-
|
1724
|
-
|
1725
|
-
|
1726
|
-
|
1664
|
+
) -> dict[tuple[int, int], float]:
|
1665
|
+
"""Compute all the pairwise unordered tree edit distances from Zhang 996 between the trees spawned at time `t`
|
1666
|
+
|
1667
|
+
Parameters
|
1668
|
+
----------
|
1669
|
+
t : int
|
1670
|
+
time to look at
|
1671
|
+
end_time : int
|
1672
|
+
The final time point the comparison algorithm will take into account.
|
1673
|
+
If None all nodes will be taken into account.
|
1674
|
+
style : {"simple", "full", "downsampled", "normalized_simple"} or TreeApproximationTemplate subclass, default="simple"
|
1675
|
+
Which tree approximation is going to be used for the comparisons.
|
1676
|
+
downsample : int, default=2
|
1677
|
+
The downsample factor for the downsampled tree approximation.
|
1678
|
+
Used only when `style="downsampled"`.
|
1679
|
+
norm : {"max", "sum"}, default="max"
|
1680
|
+
The normalization method to use.
|
1681
|
+
recompute : bool, default=False
|
1682
|
+
If True, forces to recompute the distances
|
1683
|
+
|
1684
|
+
Returns
|
1685
|
+
-------
|
1686
|
+
dict mapping a tuple of tuple that contains 2 ints to float
|
1687
|
+
a dictionary that maps a pair of node ids at time `t` to their unordered tree edit distance
|
1727
1688
|
"""
|
1728
1689
|
if not hasattr(self, "uted"):
|
1729
1690
|
self.uted = {}
|
1730
1691
|
elif t in self.uted and not recompute:
|
1731
1692
|
return self.uted[t]
|
1732
1693
|
self.uted[t] = {}
|
1733
|
-
roots = self.
|
1694
|
+
roots = self.nodes_at_t(t=t)
|
1734
1695
|
for n1, n2 in combinations(roots, 2):
|
1735
1696
|
key = tuple(sorted((n1, n2)))
|
1736
1697
|
self.uted[t][key] = self.unordered_tree_edit_distance(
|
@@ -1739,37 +1700,132 @@ class lineageTree(lineageTreeLoaders):
|
|
1739
1700
|
end_time=end_time,
|
1740
1701
|
style=style,
|
1741
1702
|
downsample=downsample,
|
1742
|
-
|
1703
|
+
norm=norm,
|
1743
1704
|
)
|
1744
1705
|
return self.uted[t]
|
1745
1706
|
|
1746
|
-
def
|
1707
|
+
def __calculate_distance_of_sub_tree(
|
1708
|
+
self,
|
1709
|
+
node1: int,
|
1710
|
+
node2: int,
|
1711
|
+
alignment: Alignment,
|
1712
|
+
corres1: dict[int, int],
|
1713
|
+
corres2: dict[int, int],
|
1714
|
+
delta_tmp: Callable,
|
1715
|
+
norm: Callable,
|
1716
|
+
norm1: int | float,
|
1717
|
+
norm2: int | float,
|
1718
|
+
) -> float:
|
1719
|
+
"""Calculates the distance of the subtree of each node matched in a comparison.
|
1720
|
+
DOES NOT CALCULATE THE DISTANCE FROM SCRATCH BUT USING THE ALIGNMENT.
|
1721
|
+
TODO ITS BOUND TO CHANGE
|
1722
|
+
Parameters
|
1723
|
+
----------
|
1724
|
+
node1 : int
|
1725
|
+
The root of the first subtree
|
1726
|
+
node2 : int
|
1727
|
+
The root of the second subtree
|
1728
|
+
alignment : Alignment
|
1729
|
+
The alignment of the subtree
|
1730
|
+
corres1 : dict
|
1731
|
+
The correspndance dictionary of the first lineage
|
1732
|
+
corres2 : dict
|
1733
|
+
The correspondance dictionary of the second lineage
|
1734
|
+
delta_tmp : Callable
|
1735
|
+
The delta function for the comparisons
|
1736
|
+
norm : Callable
|
1737
|
+
How should the lineages be normalized
|
1738
|
+
norm1 : int or float
|
1739
|
+
The result of the normalization of the first tree
|
1740
|
+
norm2 : int or float
|
1741
|
+
The result of the normalization of the second tree
|
1742
|
+
|
1743
|
+
Returns
|
1744
|
+
-------
|
1745
|
+
float
|
1746
|
+
The result of the comparison of the subtree
|
1747
|
+
"""
|
1748
|
+
sub_tree_1 = set(self.get_subtree_nodes(node1))
|
1749
|
+
sub_tree_2 = set(self.get_subtree_nodes(node2))
|
1750
|
+
res = 0
|
1751
|
+
for m in alignment:
|
1752
|
+
if (
|
1753
|
+
corres1.get(m._left, -1) in sub_tree_1
|
1754
|
+
or corres2.get(m._right, -1) in sub_tree_2
|
1755
|
+
):
|
1756
|
+
res += delta_tmp(
|
1757
|
+
m._left if m._left != -1 else None,
|
1758
|
+
m._right if m._right != -1 else None,
|
1759
|
+
)
|
1760
|
+
return res / norm([norm1, norm2])
|
1761
|
+
|
1762
|
+
def clear_comparisons(self):
|
1763
|
+
self._comparisons.clear()
|
1764
|
+
|
1765
|
+
def __unordereded_backtrace(
|
1747
1766
|
self,
|
1748
1767
|
n1: int,
|
1749
1768
|
n2: int,
|
1750
|
-
end_time: int = None,
|
1751
|
-
norm:
|
1752
|
-
style
|
1769
|
+
end_time: int | None = None,
|
1770
|
+
norm: Literal["max", "sum", None] = "max",
|
1771
|
+
style: (
|
1772
|
+
Literal["simple", "normalized_simple", "full", "downsampled"]
|
1773
|
+
| type[TreeApproximationTemplate]
|
1774
|
+
) = "simple",
|
1753
1775
|
downsample: int = 2,
|
1754
|
-
) ->
|
1776
|
+
) -> dict[
|
1777
|
+
str,
|
1778
|
+
Alignment
|
1779
|
+
| tuple[TreeApproximationTemplate, TreeApproximationTemplate],
|
1780
|
+
]:
|
1755
1781
|
"""
|
1756
|
-
Compute the unordered tree edit
|
1782
|
+
Compute the unordered tree edit backtrace from Zhang 1996 between the trees spawned
|
1757
1783
|
by two nodes `n1` and `n2`. The topology of the trees are compared and the matching
|
1758
1784
|
cost is given by the function delta (see edist doc for more information).
|
1759
|
-
The distance is normed by the function norm that takes the two list of nodes
|
1760
|
-
spawned by the trees `n1` and `n2`.
|
1761
1785
|
|
1762
|
-
|
1763
|
-
|
1764
|
-
|
1765
|
-
|
1766
|
-
|
1767
|
-
|
1768
|
-
|
1769
|
-
|
1770
|
-
|
1771
|
-
|
1772
|
-
|
1786
|
+
Parameters
|
1787
|
+
----------
|
1788
|
+
n1 : int
|
1789
|
+
id of the first node to compare
|
1790
|
+
n2 : int
|
1791
|
+
id of the second node to compare
|
1792
|
+
end_time : int
|
1793
|
+
The final time point the comparison algorithm will take into account.
|
1794
|
+
If None all nodes will be taken into account.
|
1795
|
+
norm : {"max", "sum"}, default="max"
|
1796
|
+
The normalization method to use.
|
1797
|
+
style : {"simple", "full", "downsampled", "normalized_simple"} or TreeApproximationTemplate subclass, default="simple"
|
1798
|
+
Which tree approximation is going to be used for the comparisons.
|
1799
|
+
downsample : int, default=2
|
1800
|
+
The downsample factor for the downsampled tree approximation.
|
1801
|
+
Used only when `style="downsampled"`.
|
1802
|
+
|
1803
|
+
Returns
|
1804
|
+
-------
|
1805
|
+
dict mapping str to Alignment or tuple of [TreeApproximationTemplate, TreeApproximationTemplate]
|
1806
|
+
- 'alignment'
|
1807
|
+
The alignment between the nodes by the subtrees spawned by the nodes n1,n2 and the normalization function.
|
1808
|
+
- 'trees'
|
1809
|
+
A list of the two trees that have been mapped to each other.
|
1810
|
+
"""
|
1811
|
+
|
1812
|
+
parameters = (
|
1813
|
+
end_time,
|
1814
|
+
convert_style_to_number(style=style, downsample=downsample),
|
1815
|
+
)
|
1816
|
+
n1, n2 = sorted([n1, n2])
|
1817
|
+
self._comparisons.setdefault(parameters, {})
|
1818
|
+
if len(self._comparisons) > 100:
|
1819
|
+
warnings.warn(
|
1820
|
+
"More than 100 comparisons are saved, use clear_comparisons() to delete them.",
|
1821
|
+
stacklevel=2,
|
1822
|
+
)
|
1823
|
+
if isinstance(style, str):
|
1824
|
+
tree = tree_style[style].value
|
1825
|
+
elif issubclass(style, TreeApproximationTemplate):
|
1826
|
+
tree = style
|
1827
|
+
else:
|
1828
|
+
raise ValueError("Please use a valid approximation.")
|
1773
1829
|
tree1 = tree(
|
1774
1830
|
lT=self,
|
1775
1831
|
downsample=downsample,
|
@@ -1798,7 +1854,11 @@ class lineageTree(lineageTreeLoaders):
|
|
1798
1854
|
corres2,
|
1799
1855
|
) = tree2.edist
|
1800
1856
|
if len(nodes1) == len(nodes2) == 0:
|
1801
|
-
|
1857
|
+
self._comparisons[parameters][(n1, n2)] = {
|
1858
|
+
"alignment": (),
|
1859
|
+
"trees": (),
|
1860
|
+
}
|
1861
|
+
return self._comparisons[parameters][(n1, n2)]
|
1802
1862
|
delta_tmp = partial(
|
1803
1863
|
delta,
|
1804
1864
|
corres1=corres1,
|
@@ -1806,127 +1866,538 @@ class lineageTree(lineageTreeLoaders):
|
|
1806
1866
|
times1=times1,
|
1807
1867
|
times2=times2,
|
1808
1868
|
)
|
1809
|
-
|
1810
|
-
|
1811
|
-
|
1812
|
-
|
1813
|
-
|
1814
|
-
|
1869
|
+
btrc = uted.uted_backtrace(nodes1, adj1, nodes2, adj2, delta=delta_tmp)
|
1870
|
+
|
1871
|
+
self._comparisons[parameters][(n1, n2)] = {
|
1872
|
+
"alignment": btrc,
|
1873
|
+
"trees": (tree1, tree2),
|
1874
|
+
}
|
1875
|
+
return self._comparisons[parameters][(n1, n2)]
|
1876
|
+
|
1877
|
+
def plot_tree_distance_graphs(
|
1878
|
+
self,
|
1879
|
+
n1: int,
|
1880
|
+
n2: int,
|
1881
|
+
end_time: int | None = None,
|
1882
|
+
norm: Literal["max", "sum", None] = "max",
|
1883
|
+
style: (
|
1884
|
+
Literal["simple", "normalized_simple", "full", "downsampled"]
|
1885
|
+
| type[TreeApproximationTemplate]
|
1886
|
+
) = "simple",
|
1887
|
+
downsample: int = 2,
|
1888
|
+
colormap: str = "cool",
|
1889
|
+
default_color: str = "black",
|
1890
|
+
size: float = 10,
|
1891
|
+
lw: float = 0.3,
|
1892
|
+
ax: list[plt.Axes] | None = None,
|
1893
|
+
) -> tuple[plt.figure, plt.Axes]:
|
1894
|
+
"""
|
1895
|
+
Plots the subtrees compared and colors them according to the quality of the matching of their subtree.
|
1896
|
+
|
1897
|
+
Parameters
|
1898
|
+
----------
|
1899
|
+
n1 : int
|
1900
|
+
id of the first node to compare
|
1901
|
+
n2 : int
|
1902
|
+
id of the second node to compare
|
1903
|
+
end_time : int
|
1904
|
+
The final time point the comparison algorithm will take into account.
|
1905
|
+
If None all nodes will be taken into account.
|
1906
|
+
norm : {"max", "sum"}, default="max"
|
1907
|
+
The normalization method to use.
|
1908
|
+
style : {"simple", "full", "downsampled", "normalized_simple} or TreeApproximationTemplate subclass, default="simple"
|
1909
|
+
Which tree approximation is going to be used for the comparisons.
|
1910
|
+
downsample : int, default=2
|
1911
|
+
The downsample factor for the downsampled tree approximation.
|
1912
|
+
Used only when `style="downsampled"`.
|
1913
|
+
colormap : str, default="cool"
|
1914
|
+
The colormap used for matched nodes, defaults to "cool"
|
1915
|
+
default_color : str
|
1916
|
+
The color of the unmatched nodes, defaults to "black"
|
1917
|
+
size : float
|
1918
|
+
The size of the nodes, defaults to 10
|
1919
|
+
lw : float
|
1920
|
+
The width of the edges, defaults to 0.3
|
1921
|
+
ax : np.ndarray, optional
|
1922
|
+
The axes used, if not provided another set of axes is produced, defaults to None
|
1923
|
+
|
1924
|
+
Returns
|
1925
|
+
-------
|
1926
|
+
plt.Figure
|
1927
|
+
The figure of the plot
|
1928
|
+
plt.Axes
|
1929
|
+
The axes of the plot
|
1930
|
+
"""
|
1931
|
+
parameters = (
|
1932
|
+
end_time,
|
1933
|
+
convert_style_to_number(style=style, downsample=downsample),
|
1934
|
+
)
|
1935
|
+
n1, n2 = sorted([n1, n2])
|
1936
|
+
self._comparisons.setdefault(parameters, {})
|
1937
|
+
if self._comparisons[parameters].get((n1, n2)):
|
1938
|
+
tmp = self._comparisons[parameters][(n1, n2)]
|
1939
|
+
else:
|
1940
|
+
tmp = self.__unordereded_backtrace(
|
1941
|
+
n1, n2, end_time, norm, style, downsample
|
1942
|
+
)
|
1943
|
+
btrc: Alignment = tmp["alignment"]
|
1944
|
+
tree1, tree2 = tmp["trees"]
|
1945
|
+
_, times1 = tree1.tree
|
1946
|
+
_, times2 = tree2.tree
|
1947
|
+
(
|
1948
|
+
*_,
|
1949
|
+
corres1,
|
1950
|
+
) = tree1.edist
|
1951
|
+
(
|
1952
|
+
*_,
|
1953
|
+
corres2,
|
1954
|
+
) = tree2.edist
|
1955
|
+
delta_tmp = partial(
|
1956
|
+
tree1.delta,
|
1957
|
+
corres1=corres1,
|
1958
|
+
corres2=corres2,
|
1959
|
+
times1=times1,
|
1960
|
+
times2=times2,
|
1961
|
+
)
|
1962
|
+
|
1963
|
+
if norm not in self.norm_dict:
|
1964
|
+
raise Warning(
|
1965
|
+
"Select a viable normalization method (max, sum, None)"
|
1966
|
+
)
|
1967
|
+
matched_right = []
|
1968
|
+
matched_left = []
|
1969
|
+
colors = {}
|
1970
|
+
if style not in ("full", "downsampled"):
|
1971
|
+
for m in btrc:
|
1972
|
+
if m._left != -1 and m._right != -1:
|
1973
|
+
cyc1 = self.get_chain_of_node(corres1[m._left])
|
1974
|
+
if len(cyc1) > 1:
|
1975
|
+
node_1, *_, l_node_1 = cyc1
|
1976
|
+
matched_left.append(node_1)
|
1977
|
+
matched_left.append(l_node_1)
|
1978
|
+
elif len(cyc1) == 1:
|
1979
|
+
node_1 = l_node_1 = cyc1.pop()
|
1980
|
+
matched_left.append(node_1)
|
1981
|
+
|
1982
|
+
cyc2 = self.get_chain_of_node(corres2[m._right])
|
1983
|
+
if len(cyc2) > 1:
|
1984
|
+
node_2, *_, l_node_2 = cyc2
|
1985
|
+
matched_right.append(node_2)
|
1986
|
+
matched_right.append(l_node_2)
|
1987
|
+
|
1988
|
+
elif len(cyc2) == 1:
|
1989
|
+
node_2 = l_node_2 = cyc2.pop()
|
1990
|
+
matched_right.append(node_2)
|
1991
|
+
|
1992
|
+
colors[node_1] = self.__calculate_distance_of_sub_tree(
|
1993
|
+
node_1,
|
1994
|
+
node_2,
|
1995
|
+
btrc,
|
1996
|
+
corres1,
|
1997
|
+
corres2,
|
1998
|
+
delta_tmp,
|
1999
|
+
self.norm_dict[norm],
|
2000
|
+
tree1.get_norm(node_1),
|
2001
|
+
tree2.get_norm(node_2),
|
2002
|
+
)
|
2003
|
+
colors[node_2] = colors[node_1]
|
2004
|
+
colors[l_node_1] = colors[node_1]
|
2005
|
+
colors[l_node_2] = colors[node_2]
|
2006
|
+
else:
|
2007
|
+
for m in btrc:
|
2008
|
+
if m._left != -1 and m._right != -1:
|
2009
|
+
node_1 = corres1[m._left]
|
2010
|
+
node_2 = corres2[m._right]
|
2011
|
+
|
2012
|
+
if (
|
2013
|
+
self.get_chain_of_node(node_1)[0] == node_1
|
2014
|
+
or self.get_chain_of_node(node_2)[0] == node_2
|
2015
|
+
and (node_1 not in colors or node_2 not in colors)
|
2016
|
+
):
|
2017
|
+
matched_left.append(node_1)
|
2018
|
+
l_node_1 = self.get_chain_of_node(node_1)[-1]
|
2019
|
+
matched_left.append(l_node_1)
|
2020
|
+
matched_right.append(node_2)
|
2021
|
+
l_node_2 = self.get_chain_of_node(node_2)[-1]
|
2022
|
+
matched_right.append(l_node_2)
|
2023
|
+
colors[node_1] = self.__calculate_distance_of_sub_tree(
|
2024
|
+
node_1,
|
2025
|
+
node_2,
|
2026
|
+
btrc,
|
2027
|
+
corres1,
|
2028
|
+
corres2,
|
2029
|
+
delta_tmp,
|
2030
|
+
self.norm_dict[norm],
|
2031
|
+
tree1.get_norm(node_1),
|
2032
|
+
tree2.get_norm(node_2),
|
2033
|
+
)
|
2034
|
+
colors[l_node_1] = colors[node_1]
|
2035
|
+
colors[node_2] = colors[node_1]
|
2036
|
+
colors[l_node_2] = colors[node_1]
|
2037
|
+
if ax is None:
|
2038
|
+
fig, ax = plt.subplots(nrows=1, ncols=2, sharey=True)
|
2039
|
+
cmap = colormaps[colormap]
|
2040
|
+
c_norm = mcolors.Normalize(0, 1)
|
2041
|
+
colors = {c: cmap(c_norm(v)) for c, v in colors.items()}
|
2042
|
+
self.plot_subtree(
|
2043
|
+
self.get_ancestor_at_t(n1),
|
2044
|
+
end_time=end_time,
|
2045
|
+
size=size,
|
2046
|
+
selected_nodes=matched_left,
|
2047
|
+
color_of_nodes=colors,
|
2048
|
+
selected_edges=matched_left,
|
2049
|
+
color_of_edges=colors,
|
2050
|
+
default_color=default_color,
|
2051
|
+
lw=lw,
|
2052
|
+
ax=ax[0],
|
2053
|
+
)
|
2054
|
+
self.plot_subtree(
|
2055
|
+
self.get_ancestor_at_t(n2),
|
2056
|
+
end_time=end_time,
|
2057
|
+
size=size,
|
2058
|
+
selected_nodes=matched_right,
|
2059
|
+
color_of_nodes=colors,
|
2060
|
+
selected_edges=matched_right,
|
2061
|
+
color_of_edges=colors,
|
2062
|
+
default_color=default_color,
|
2063
|
+
lw=lw,
|
2064
|
+
ax=ax[1],
|
2065
|
+
)
|
2066
|
+
return ax[0].get_figure(), ax
|
2067
|
+
|
2068
|
+
def labelled_mappings(
|
2069
|
+
self,
|
2070
|
+
n1: int,
|
2071
|
+
n2: int,
|
2072
|
+
end_time: int | None = None,
|
2073
|
+
norm: Literal["max", "sum", None] = "max",
|
2074
|
+
style: (
|
2075
|
+
Literal["simple", "normalized_simple", "full", "downsampled"]
|
2076
|
+
| type[TreeApproximationTemplate]
|
2077
|
+
) = "simple",
|
2078
|
+
downsample: int = 2,
|
2079
|
+
) -> dict[str, list[str]]:
|
2080
|
+
"""
|
2081
|
+
Returns the labels or IDs of all the nodes in the subtrees compared.
|
2082
|
+
|
2083
|
+
|
2084
|
+
Parameters
|
2085
|
+
----------
|
2086
|
+
n1 : int
|
2087
|
+
id of the first node to compare
|
2088
|
+
n2 : int
|
2089
|
+
id of the second node to compare
|
2090
|
+
end_time : int, optional
|
2091
|
+
The final time point the comparison algorithm will take into account.
|
2092
|
+
If None or not provided all nodes will be taken into account.
|
2093
|
+
norm : {"max", "sum"}, default="max"
|
2094
|
+
The normalization method to use, defaults to 'max'.
|
2095
|
+
style : {"simple", "full", "downsampled", "normalized_simple} or TreeApproximationTemplate subclass, default="simple"
|
2096
|
+
Which tree approximation is going to be used for the comparisons, defaults to 'simple'.
|
2097
|
+
downsample : int, default=2
|
2098
|
+
The downsample factor for the downsampled tree approximation.
|
2099
|
+
Used only when `style="downsampled"`.
|
2100
|
+
|
2101
|
+
Returns
|
2102
|
+
-------
|
2103
|
+
dict mapping str to list of str
|
2104
|
+
- 'matched' The labels of the matched nodes of the alignment.
|
2105
|
+
- 'unmatched' The labels of the unmatched nodes of the alginment.
|
2106
|
+
"""
|
2107
|
+
parameters = (
|
2108
|
+
end_time,
|
2109
|
+
convert_style_to_number(style=style, downsample=downsample),
|
2110
|
+
)
|
2111
|
+
n1, n2 = sorted([n1, n2])
|
2112
|
+
self._comparisons.setdefault(parameters, {})
|
2113
|
+
if self._comparisons[parameters].get((n1, n2)):
|
2114
|
+
tmp = self._comparisons[parameters][(n1, n2)]
|
2115
|
+
else:
|
2116
|
+
tmp = self.__unordereded_backtrace(
|
2117
|
+
n1, n2, end_time, norm, style, downsample
|
2118
|
+
)
|
2119
|
+
btrc = tmp["alignment"]
|
2120
|
+
tree1, tree2 = tmp["trees"]
|
2121
|
+
|
2122
|
+
(
|
2123
|
+
*_,
|
2124
|
+
corres1,
|
2125
|
+
) = tree1.edist
|
2126
|
+
(
|
2127
|
+
*_,
|
2128
|
+
corres2,
|
2129
|
+
) = tree2.edist
|
2130
|
+
|
2131
|
+
if norm not in self.norm_dict:
|
1815
2132
|
raise Warning(
|
1816
2133
|
"Select a viable normalization method (max, sum, None)"
|
1817
2134
|
)
|
1818
|
-
|
1819
|
-
|
1820
|
-
|
2135
|
+
matched = []
|
2136
|
+
unmatched = []
|
2137
|
+
if style not in ("full", "downsampled"):
|
2138
|
+
for m in btrc:
|
2139
|
+
if m._left != -1 and m._right != -1:
|
2140
|
+
cyc1 = self.get_chain_of_node(corres1[m._left])
|
2141
|
+
if len(cyc1) > 1:
|
2142
|
+
node_1, *_ = cyc1
|
2143
|
+
elif len(cyc1) == 1:
|
2144
|
+
node_1 = cyc1.pop()
|
2145
|
+
cyc2 = self.get_chain_of_node(corres2[m._right])
|
2146
|
+
if len(cyc2) > 1:
|
2147
|
+
node_2, *_ = cyc2
|
2148
|
+
elif len(cyc2) == 1:
|
2149
|
+
node_2 = cyc2.pop()
|
2150
|
+
matched.append(
|
2151
|
+
(
|
2152
|
+
self.labels.get(node_1, node_1),
|
2153
|
+
self.labels.get(node_2, node_2),
|
2154
|
+
)
|
2155
|
+
)
|
2156
|
+
|
2157
|
+
else:
|
2158
|
+
if m._left != -1:
|
2159
|
+
node_1 = self.get_chain_of_node(
|
2160
|
+
corres1.get(m._left, "-")
|
2161
|
+
)[0]
|
2162
|
+
else:
|
2163
|
+
node_1 = self.get_chain_of_node(
|
2164
|
+
corres2.get(m._right, "-")
|
2165
|
+
)[0]
|
2166
|
+
unmatched.append(self.labels.get(node_1, node_1))
|
2167
|
+
else:
|
2168
|
+
for m in btrc:
|
2169
|
+
if m._left != -1 and m._right != -1:
|
2170
|
+
node_1 = corres1[m._left]
|
2171
|
+
node_2 = corres2[m._right]
|
2172
|
+
matched.append(
|
2173
|
+
(
|
2174
|
+
self.labels.get(node_1, node_1),
|
2175
|
+
self.labels.get(node_2, node_2),
|
2176
|
+
)
|
2177
|
+
)
|
2178
|
+
else:
|
2179
|
+
if m._left != -1:
|
2180
|
+
node_1 = corres1[m._left]
|
2181
|
+
else:
|
2182
|
+
node_1 = corres2[m._right]
|
2183
|
+
unmatched.append(self.labels.get(node_1, node_1))
|
2184
|
+
return {"matched": matched, "unmatched": unmatched}
|
2185
|
+
|
2186
|
+
def unordered_tree_edit_distance(
|
2187
|
+
self,
|
2188
|
+
n1: int,
|
2189
|
+
n2: int,
|
2190
|
+
end_time: int | None = None,
|
2191
|
+
norm: Literal["max", "sum", None] = "max",
|
2192
|
+
style: (
|
2193
|
+
Literal["simple", "normalized_simple", "full", "downsampled"]
|
2194
|
+
| type[TreeApproximationTemplate]
|
2195
|
+
) = "simple",
|
2196
|
+
downsample: int = 2,
|
2197
|
+
return_norms: bool = False,
|
2198
|
+
) -> float | tuple[float, tuple[float, float]]:
|
2199
|
+
"""
|
2200
|
+
Compute the unordered tree edit distance from Zhang 1996 between the trees spawned
|
2201
|
+
by two nodes `n1` and `n2`. The topology of the trees are compared and the matching
|
2202
|
+
cost is given by the function delta (see edist doc for more information).
|
2203
|
+
The distance is normed by the function norm that takes the two list of nodes
|
2204
|
+
spawned by the trees `n1` and `n2`.
|
2205
|
+
|
2206
|
+
Parameters
|
2207
|
+
----------
|
2208
|
+
n1 : int
|
2209
|
+
id of the first node to compare
|
2210
|
+
n2 : int
|
2211
|
+
id of the second node to compare
|
2212
|
+
end_time : int, optional
|
2213
|
+
The final time point the comparison algorithm will take into account.
|
2214
|
+
If None or not provided all nodes will be taken into account.
|
2215
|
+
norm : {"max", "sum"}, default="max"
|
2216
|
+
The normalization method to use, defaults to 'max'.
|
2217
|
+
style : {"simple", "normalized_simple", "full", "downsampled"} or TreeApproximationTemplate subclass, default="simple"
|
2218
|
+
Which tree approximation is going to be used for the comparisons.
|
2219
|
+
downsample : int, default=2
|
2220
|
+
The downsample factor for the downsampled tree approximation.
|
2221
|
+
Used only when `style="downsampled"`.
|
2222
|
+
|
2223
|
+
Returns
|
2224
|
+
-------
|
2225
|
+
float
|
2226
|
+
The normalized unordered tree edit distance between `n1` and `n2`
|
2227
|
+
"""
|
2228
|
+
parameters = (
|
2229
|
+
end_time,
|
2230
|
+
convert_style_to_number(style=style, downsample=downsample),
|
2231
|
+
)
|
2232
|
+
n1, n2 = sorted([n1, n2])
|
2233
|
+
self._comparisons.setdefault(parameters, {})
|
2234
|
+
if self._comparisons[parameters].get((n1, n2)):
|
2235
|
+
tmp = self._comparisons[parameters][(n1, n2)]
|
2236
|
+
else:
|
2237
|
+
tmp = self.__unordereded_backtrace(
|
2238
|
+
n1, n2, end_time, norm, style, downsample
|
2239
|
+
)
|
2240
|
+
btrc = tmp["alignment"]
|
2241
|
+
tree1, tree2 = tmp["trees"]
|
2242
|
+
_, times1 = tree1.tree
|
2243
|
+
_, times2 = tree2.tree
|
2244
|
+
(
|
2245
|
+
nodes1,
|
2246
|
+
adj1,
|
2247
|
+
corres1,
|
2248
|
+
) = tree1.edist
|
2249
|
+
(
|
2250
|
+
nodes2,
|
2251
|
+
adj2,
|
2252
|
+
corres2,
|
2253
|
+
) = tree2.edist
|
2254
|
+
delta_tmp = partial(
|
2255
|
+
tree1.delta,
|
2256
|
+
corres1=corres1,
|
2257
|
+
corres2=corres2,
|
2258
|
+
times1=times1,
|
2259
|
+
times2=times2,
|
2260
|
+
)
|
2261
|
+
|
2262
|
+
if norm not in self.norm_dict:
|
2263
|
+
raise ValueError(
|
2264
|
+
"Select a viable normalization method (max, sum, None)"
|
2265
|
+
)
|
2266
|
+
cost = btrc.cost(nodes1, nodes2, delta_tmp)
|
2267
|
+
norm_values = (tree1.get_norm(n1), tree2.get_norm(n2))
|
2268
|
+
if return_norms:
|
2269
|
+
return cost, norm_values
|
2270
|
+
return cost / self.norm_dict[norm](norm_values)
|
1821
2271
|
|
1822
2272
|
@staticmethod
|
1823
2273
|
def __plot_nodes(
|
1824
|
-
hier
|
1825
|
-
|
2274
|
+
hier: dict,
|
2275
|
+
selected_nodes: set,
|
2276
|
+
color: str | dict | list,
|
2277
|
+
size: int | float,
|
2278
|
+
ax: plt.Axes,
|
2279
|
+
default_color: str = "black",
|
2280
|
+
**kwargs,
|
2281
|
+
) -> None:
|
1826
2282
|
"""
|
1827
2283
|
Private method that plots the nodes of the tree.
|
1828
2284
|
"""
|
1829
|
-
|
1830
|
-
|
1831
|
-
|
1832
|
-
|
1833
|
-
|
1834
|
-
|
1835
|
-
|
1836
|
-
|
1837
|
-
|
1838
|
-
|
1839
|
-
)
|
1840
|
-
if selected_nodes.intersection(hier.keys()):
|
1841
|
-
hier_selected = np.array(
|
1842
|
-
[v for k, v in hier.items() if k in selected_nodes]
|
1843
|
-
)
|
1844
|
-
ax.scatter(
|
1845
|
-
*hier_selected.T, s=size, zorder=10, color=color, **kwargs
|
1846
|
-
)
|
2285
|
+
|
2286
|
+
if isinstance(color, dict):
|
2287
|
+
color = [color.get(k, default_color) for k in hier]
|
2288
|
+
elif isinstance(color, str | list):
|
2289
|
+
color = [
|
2290
|
+
color if node in selected_nodes else default_color
|
2291
|
+
for node in hier
|
2292
|
+
]
|
2293
|
+
hier_pos = np.array(list(hier.values()))
|
2294
|
+
ax.scatter(*hier_pos.T, s=size, zorder=10, color=color, **kwargs)
|
1847
2295
|
|
1848
2296
|
@staticmethod
|
1849
2297
|
def __plot_edges(
|
1850
|
-
hier,
|
1851
|
-
lnks_tms,
|
1852
|
-
selected_edges,
|
1853
|
-
color,
|
1854
|
-
|
1855
|
-
|
2298
|
+
hier: dict,
|
2299
|
+
lnks_tms: dict,
|
2300
|
+
selected_edges: Iterable,
|
2301
|
+
color: str | dict | list,
|
2302
|
+
lw: float,
|
2303
|
+
ax: plt.Axes,
|
2304
|
+
default_color: str = "black",
|
1856
2305
|
**kwargs,
|
1857
|
-
):
|
2306
|
+
) -> None:
|
1858
2307
|
"""
|
1859
2308
|
Private method that plots the edges of the tree.
|
1860
2309
|
"""
|
1861
|
-
|
2310
|
+
if isinstance(color, dict):
|
2311
|
+
selected_edges = color.keys()
|
2312
|
+
lines = []
|
2313
|
+
c = []
|
1862
2314
|
for pred, succs in lnks_tms["links"].items():
|
1863
|
-
for
|
1864
|
-
|
1865
|
-
|
1866
|
-
|
1867
|
-
|
1868
|
-
|
1869
|
-
|
1870
|
-
|
1871
|
-
|
1872
|
-
|
1873
|
-
|
1874
|
-
|
2315
|
+
for suc in succs:
|
2316
|
+
lines.append(
|
2317
|
+
[
|
2318
|
+
[hier[suc][0], hier[suc][1]],
|
2319
|
+
[hier[pred][0], hier[pred][1]],
|
2320
|
+
]
|
2321
|
+
)
|
2322
|
+
if pred in selected_edges:
|
2323
|
+
if isinstance(color, str | list):
|
2324
|
+
c.append(color)
|
2325
|
+
elif isinstance(color, dict):
|
2326
|
+
c.append(color[pred])
|
2327
|
+
else:
|
2328
|
+
c.append(default_color)
|
2329
|
+
lc = LineCollection(lines, colors=c, linewidth=lw, **kwargs)
|
2330
|
+
ax.add_collection(lc)
|
1875
2331
|
|
1876
2332
|
def draw_tree_graph(
|
1877
2333
|
self,
|
1878
|
-
hier,
|
1879
|
-
lnks_tms,
|
1880
|
-
selected_nodes=None,
|
1881
|
-
selected_edges=None,
|
1882
|
-
color_of_nodes="magenta",
|
1883
|
-
color_of_edges=
|
1884
|
-
size=10,
|
1885
|
-
|
1886
|
-
|
1887
|
-
default_color="black",
|
2334
|
+
hier: dict[int, tuple[int, int]],
|
2335
|
+
lnks_tms: dict[str, dict[int, list | int]],
|
2336
|
+
selected_nodes: list | set | None = None,
|
2337
|
+
selected_edges: list | set | None = None,
|
2338
|
+
color_of_nodes: str | dict = "magenta",
|
2339
|
+
color_of_edges: str | dict = "magenta",
|
2340
|
+
size: int | float = 10,
|
2341
|
+
lw: float = 0.3,
|
2342
|
+
ax: plt.Axes | None = None,
|
2343
|
+
default_color: str = "black",
|
1888
2344
|
**kwargs,
|
1889
|
-
):
|
2345
|
+
) -> tuple[plt.Figure, plt.Axes]:
|
1890
2346
|
"""Function to plot the tree graph.
|
1891
2347
|
|
1892
|
-
|
1893
|
-
|
1894
|
-
|
1895
|
-
|
1896
|
-
|
1897
|
-
|
1898
|
-
|
1899
|
-
|
1900
|
-
|
1901
|
-
|
1902
|
-
|
1903
|
-
|
1904
|
-
|
1905
|
-
|
1906
|
-
|
1907
|
-
|
2348
|
+
Parameters
|
2349
|
+
----------
|
2350
|
+
hier : dict mapping int to tuple of int
|
2351
|
+
Dictionary that contains the positions of all nodes.
|
2352
|
+
lnks_tms : dict mapping string to dictionaries mapping int to list or int
|
2353
|
+
- 'links' : conatains the hierarchy of the nodes (only start and end of each chain)
|
2354
|
+
- 'times' : contains the distance between the start and the end of each chain.
|
2355
|
+
selected_nodes : list or set, optional
|
2356
|
+
Which nodes are to be selected (Painted with a different color, according to 'color_'of_nodes')
|
2357
|
+
selected_edges : list or set, optional
|
2358
|
+
Which edges are to be selected (Painted with a different color, according to 'color_'of_edges')
|
2359
|
+
color_of_nodes : str, default="magenta"
|
2360
|
+
Color of selected nodes
|
2361
|
+
color_of_edges : str, default="magenta"
|
2362
|
+
Color of selected edges
|
2363
|
+
size : int, default=10
|
2364
|
+
Size of the nodes, defaults to 10
|
2365
|
+
lw : float, default=0.3
|
2366
|
+
The width of the edges of the tree graph, defaults to 0.3
|
2367
|
+
ax : plt.Axes, optional
|
2368
|
+
Plot the graph on existing ax. If not provided or None a new ax is going to be created.
|
2369
|
+
default_color : str, default="black"
|
2370
|
+
Default color of nodes
|
2371
|
+
|
2372
|
+
Returns
|
2373
|
+
-------
|
2374
|
+
plt.Figure
|
2375
|
+
The matplotlib figure
|
2376
|
+
plt.Axes
|
2377
|
+
The matplotlib ax
|
1908
2378
|
"""
|
1909
2379
|
if selected_nodes is None:
|
1910
2380
|
selected_nodes = []
|
1911
2381
|
if selected_edges is None:
|
1912
2382
|
selected_edges = []
|
1913
2383
|
if ax is None:
|
1914
|
-
|
2384
|
+
_, ax = plt.subplots()
|
1915
2385
|
else:
|
1916
2386
|
ax.clear()
|
1917
2387
|
if not isinstance(selected_nodes, set):
|
1918
2388
|
selected_nodes = set(selected_nodes)
|
1919
2389
|
if not isinstance(selected_edges, set):
|
1920
2390
|
selected_edges = set(selected_edges)
|
1921
|
-
|
1922
|
-
|
1923
|
-
|
1924
|
-
|
1925
|
-
|
1926
|
-
|
1927
|
-
|
1928
|
-
|
1929
|
-
|
2391
|
+
if 0 < size:
|
2392
|
+
self.__plot_nodes(
|
2393
|
+
hier,
|
2394
|
+
selected_nodes,
|
2395
|
+
color_of_nodes,
|
2396
|
+
size=size,
|
2397
|
+
ax=ax,
|
2398
|
+
default_color=default_color,
|
2399
|
+
**kwargs,
|
2400
|
+
)
|
1930
2401
|
if not color_of_edges:
|
1931
2402
|
color_of_edges = color_of_nodes
|
1932
2403
|
self.__plot_edges(
|
@@ -1934,62 +2405,107 @@ class lineageTree(lineageTreeLoaders):
|
|
1934
2405
|
lnks_tms,
|
1935
2406
|
selected_edges,
|
1936
2407
|
color_of_edges,
|
2408
|
+
lw,
|
1937
2409
|
ax,
|
1938
2410
|
default_color=default_color,
|
1939
2411
|
**kwargs,
|
1940
2412
|
)
|
2413
|
+
ax.autoscale()
|
2414
|
+
plt.draw()
|
1941
2415
|
ax.get_yaxis().set_visible(False)
|
1942
2416
|
ax.get_xaxis().set_visible(False)
|
1943
|
-
return
|
2417
|
+
return ax.get_figure(), ax
|
1944
2418
|
|
1945
|
-
def
|
2419
|
+
def _create_dict_of_plots(
|
2420
|
+
self,
|
2421
|
+
node: int | Iterable[int] | None = None,
|
2422
|
+
start_time: int | None = None,
|
2423
|
+
end_time: int | None = None,
|
2424
|
+
) -> dict[int, dict]:
|
1946
2425
|
"""Generates a dictionary of graphs where the keys are the index of the graph and
|
1947
|
-
the values are the graphs themselves which are produced by
|
1948
|
-
|
1949
|
-
|
1950
|
-
|
1951
|
-
|
1952
|
-
|
1953
|
-
|
1954
|
-
|
1955
|
-
|
2426
|
+
the values are the graphs themselves which are produced by `create_links_and_chains`
|
2427
|
+
|
2428
|
+
Parameters
|
2429
|
+
----------
|
2430
|
+
node : int or Iterable of int, optional
|
2431
|
+
The id of the node/nodes to produce the simple graphs, if not provided or None will
|
2432
|
+
calculate the dicts for every root that starts before 'start_time'
|
2433
|
+
start_time : int, optional
|
2434
|
+
Important only if there are no nodes it will produce the graph of every
|
2435
|
+
root that starts before or at start time. If not provided or None the 'start_time' defaults to the start of the dataset.
|
2436
|
+
end_time : int, optional
|
2437
|
+
The last timepoint to be considered, if not provided or None the last timepoint of the
|
2438
|
+
dataset (t_e) is considered.
|
2439
|
+
|
2440
|
+
Returns
|
2441
|
+
-------
|
2442
|
+
dict mapping int to dict
|
2443
|
+
The keys are just index values 0-n and the values are the graphs produced.
|
1956
2444
|
"""
|
1957
2445
|
if start_time is None:
|
1958
2446
|
start_time = self.t_b
|
2447
|
+
if end_time is None:
|
2448
|
+
end_time = self.t_e
|
1959
2449
|
if node is None:
|
1960
2450
|
mothers = [
|
1961
|
-
root for root in self.roots if self.
|
2451
|
+
root for root in self.roots if self._time[root] <= start_time
|
1962
2452
|
]
|
2453
|
+
elif isinstance(node, Iterable):
|
2454
|
+
mothers = node
|
1963
2455
|
else:
|
1964
|
-
mothers =
|
2456
|
+
mothers = [node]
|
1965
2457
|
return {
|
1966
|
-
i:
|
2458
|
+
i: create_links_and_chains(self, mother, end_time=end_time)
|
1967
2459
|
for i, mother in enumerate(mothers)
|
1968
2460
|
}
|
1969
2461
|
|
1970
2462
|
def plot_all_lineages(
|
1971
2463
|
self,
|
1972
|
-
nodes: list = None,
|
1973
|
-
last_time_point_to_consider: int = None,
|
1974
|
-
nrows=2,
|
1975
|
-
figsize=(10, 15),
|
1976
|
-
dpi=100,
|
1977
|
-
fontsize=15,
|
1978
|
-
|
1979
|
-
|
2464
|
+
nodes: list | None = None,
|
2465
|
+
last_time_point_to_consider: int | None = None,
|
2466
|
+
nrows: int = 2,
|
2467
|
+
figsize: tuple[int, int] = (10, 15),
|
2468
|
+
dpi: int = 100,
|
2469
|
+
fontsize: int = 15,
|
2470
|
+
axes: plt.Axes | None = None,
|
2471
|
+
vert_gap: int = 1,
|
1980
2472
|
**kwargs,
|
1981
|
-
):
|
2473
|
+
) -> tuple[plt.Figure, plt.Axes, dict[plt.Axes, int]]:
|
1982
2474
|
"""Plots all lineages.
|
1983
2475
|
|
1984
|
-
|
1985
|
-
|
1986
|
-
|
1987
|
-
|
1988
|
-
|
1989
|
-
|
1990
|
-
|
2476
|
+
Parameters
|
2477
|
+
----------
|
2478
|
+
nodes : list, optional
|
2479
|
+
The nodes spawning the graphs to be plotted.
|
2480
|
+
last_time_point_to_consider : int, optional
|
2481
|
+
Which timepoints and upwards are the graphs to be plotted.
|
2482
|
+
For example if start_time is 10, then all trees that begin
|
2483
|
+
on tp 10 or before are calculated. Defaults to None, where
|
2484
|
+
it will plot all the roots that exist on `self.t_b`.
|
2485
|
+
nrows : int, default=2
|
2486
|
+
How many rows of plots should be printed.
|
2487
|
+
figsize : tuple, default=(10, 15)
|
2488
|
+
The size of the figure.
|
2489
|
+
dpi : int, default=100
|
2490
|
+
The dpi of the figure.
|
2491
|
+
fontsize : int, default=15
|
2492
|
+
The fontsize of the labels.
|
2493
|
+
axes : plt.Axes, optional
|
2494
|
+
The axes to plot the graphs on. If None or not provided new axes are going to be created.
|
2495
|
+
vert_gap : int, default=1
|
2496
|
+
space between the nodes, defaults to 1
|
2497
|
+
**kwargs:
|
2498
|
+
kwargs accepted by matplotlib.pyplot.plot, matplotlib.pyplot.scatter
|
2499
|
+
|
2500
|
+
Returns
|
2501
|
+
-------
|
2502
|
+
plt.Figure
|
2503
|
+
The figure
|
2504
|
+
plt.Axes
|
2505
|
+
The axes
|
2506
|
+
dict of plt.Axes to int
|
2507
|
+
A dictionary that maps the axes to the root of the tree.
|
1991
2508
|
"""
|
1992
|
-
|
1993
2509
|
nrows = int(nrows)
|
1994
2510
|
if last_time_point_to_consider is None:
|
1995
2511
|
last_time_point_to_consider = self.t_b
|
@@ -1998,22 +2514,33 @@ class lineageTree(lineageTreeLoaders):
|
|
1998
2514
|
raise Warning("Number of rows has to be at least 1")
|
1999
2515
|
if nodes:
|
2000
2516
|
graphs = {
|
2001
|
-
i: self.
|
2517
|
+
i: self._create_dict_of_plots(node)
|
2518
|
+
for i, node in enumerate(nodes)
|
2002
2519
|
}
|
2003
2520
|
else:
|
2004
|
-
graphs = self.
|
2521
|
+
graphs = self._create_dict_of_plots(
|
2005
2522
|
start_time=last_time_point_to_consider
|
2006
2523
|
)
|
2007
2524
|
pos = {
|
2008
2525
|
i: hierarchical_pos(
|
2009
|
-
g,
|
2526
|
+
g,
|
2527
|
+
g["root"],
|
2528
|
+
ycenter=-int(self._time[g["root"]]),
|
2529
|
+
vert_gap=vert_gap,
|
2010
2530
|
)
|
2011
2531
|
for i, g in graphs.items()
|
2012
2532
|
}
|
2013
|
-
|
2014
|
-
|
2015
|
-
|
2016
|
-
|
2533
|
+
if axes is None:
|
2534
|
+
ncols = int(len(graphs) // nrows) + (+np.sign(len(graphs) % nrows))
|
2535
|
+
figure, axes = plt.subplots(
|
2536
|
+
figsize=figsize, nrows=nrows, ncols=ncols, dpi=dpi, sharey=True
|
2537
|
+
)
|
2538
|
+
else:
|
2539
|
+
figure, axes = axes.flatten()[0].get_figure(), axes
|
2540
|
+
if len(axes.flatten()) < len(graphs):
|
2541
|
+
raise Exception(
|
2542
|
+
f"Not enough axes, they should be at least {len(graphs)}."
|
2543
|
+
)
|
2017
2544
|
flat_axes = axes.flatten()
|
2018
2545
|
ax2root = {}
|
2019
2546
|
min_width, min_height = float("inf"), float("inf")
|
@@ -2051,127 +2578,149 @@ class lineageTree(lineageTreeLoaders):
|
|
2051
2578
|
},
|
2052
2579
|
)
|
2053
2580
|
[figure.delaxes(ax) for ax in axes.flatten() if not ax.has_data()]
|
2054
|
-
return
|
2581
|
+
return axes.flatten()[0].get_figure(), axes, ax2root
|
2055
2582
|
|
2056
|
-
def
|
2583
|
+
def plot_subtree(
|
2584
|
+
self,
|
2585
|
+
node: int,
|
2586
|
+
end_time: int | None = None,
|
2587
|
+
figsize: tuple[int, int] = (4, 7),
|
2588
|
+
dpi: int = 150,
|
2589
|
+
vert_gap: int = 2,
|
2590
|
+
selected_nodes: list | None = None,
|
2591
|
+
selected_edges: list | None = None,
|
2592
|
+
color_of_nodes: str | dict = "magenta",
|
2593
|
+
color_of_edges: str | dict = "magenta",
|
2594
|
+
size: int | float = 10,
|
2595
|
+
lw: float = 0.1,
|
2596
|
+
default_color: str = "black",
|
2597
|
+
ax: plt.Axes | None = None,
|
2598
|
+
) -> tuple[plt.Figure, plt.Axes]:
|
2057
2599
|
"""Plots the subtree spawn by a node.
|
2058
2600
|
|
2059
|
-
|
2060
|
-
|
2061
|
-
|
2062
|
-
|
2063
|
-
|
2601
|
+
Parameters
|
2602
|
+
----------
|
2603
|
+
node : int
|
2604
|
+
The id of the node that is going to be plotted.
|
2605
|
+
end_time : int, optional
|
2606
|
+
The last timepoint to be considered, if None or not provided the last timepoint of the dataset (t_e) is considered.
|
2607
|
+
figsize : tuple of 2 ints, default=(4,7)
|
2608
|
+
The size of the figure, deafults to (4,7)
|
2609
|
+
vert_gap : int, default=2
|
2610
|
+
The verical gap of a node when it divides, defaults to 2.
|
2611
|
+
dpi : int, default=150
|
2612
|
+
The dpi of the figure, defaults to 150
|
2613
|
+
selected_nodes : list, optional
|
2614
|
+
The nodes that are selected by the user to be colored in a different color, defaults to None
|
2615
|
+
selected_edges : list, optional
|
2616
|
+
The edges that are selected by the user to be colored in a different color, defaults to None
|
2617
|
+
color_of_nodes : str, default="magenta"
|
2618
|
+
The color of the nodes to be colored, except the default colored ones, defaults to "magenta"
|
2619
|
+
color_of_edges : str, default="magenta"
|
2620
|
+
The color of the edges to be colored, except the default colored ones, defaults to "magenta"
|
2621
|
+
size : int, default=10
|
2622
|
+
The size of the nodes, defaults to 10
|
2623
|
+
lw : float, default=0.1
|
2624
|
+
The widthe of the edges of the tree graph, defaults to 0.1
|
2625
|
+
default_color : str, default="black"
|
2626
|
+
The default color of nodes and edges, defaults to "black"
|
2627
|
+
ax : plt.Axes, optional
|
2628
|
+
The ax where the plot is going to be applied, if not provided or None new axes will be created.
|
2629
|
+
|
2630
|
+
Returns
|
2631
|
+
-------
|
2632
|
+
plt.Figure
|
2633
|
+
The matplotlib figure
|
2634
|
+
plt.Axes
|
2635
|
+
The matplotlib axes
|
2636
|
+
|
2637
|
+
Raises
|
2638
|
+
------
|
2639
|
+
Warning
|
2640
|
+
If more than one nodes are received
|
2641
|
+
"""
|
2642
|
+
graph = self._create_dict_of_plots(node, end_time=end_time)
|
2064
2643
|
if len(graph) > 1:
|
2065
|
-
raise Warning(
|
2644
|
+
raise Warning(
|
2645
|
+
"Please use lT.plot_all_lineages(nodes) for plotting multiple nodes."
|
2646
|
+
)
|
2066
2647
|
graph = graph[0]
|
2067
|
-
|
2648
|
+
if not ax:
|
2649
|
+
_, ax = plt.subplots(nrows=1, ncols=1, figsize=figsize, dpi=dpi)
|
2068
2650
|
self.draw_tree_graph(
|
2069
2651
|
hier=hierarchical_pos(
|
2070
2652
|
graph,
|
2071
2653
|
graph["root"],
|
2072
2654
|
vert_gap=vert_gap,
|
2073
|
-
ycenter=-int(self.
|
2655
|
+
ycenter=-int(self._time[node]),
|
2074
2656
|
),
|
2657
|
+
selected_edges=selected_edges,
|
2658
|
+
selected_nodes=selected_nodes,
|
2659
|
+
color_of_edges=color_of_edges,
|
2660
|
+
color_of_nodes=color_of_nodes,
|
2661
|
+
default_color=default_color,
|
2662
|
+
size=size,
|
2663
|
+
lw=lw,
|
2075
2664
|
lnks_tms=graph,
|
2076
2665
|
ax=ax,
|
2077
|
-
**kwargs,
|
2078
2666
|
)
|
2079
|
-
return
|
2080
|
-
|
2081
|
-
# def DTW(self, t1, t2, max_w=None, start_delay=None, end_delay=None,
|
2082
|
-
# metric='euclidian', **kwargs):
|
2083
|
-
# """ Computes the dynamic time warping distance between the tracks t1 and t2
|
2084
|
-
|
2085
|
-
# Args:
|
2086
|
-
# t1 ([int, ]): list of node ids for the first track
|
2087
|
-
# t2 ([int, ]): list of node ids for the second track
|
2088
|
-
# w (int): maximum wapring allowed (default infinite),
|
2089
|
-
# if w=1 then the DTW is the distance between t1 and t2
|
2090
|
-
# start_delay (int): maximum number of time points that can be
|
2091
|
-
# skipped at the beginning of the track
|
2092
|
-
# end_delay (int): minimum number of time points that can be
|
2093
|
-
# skipped at the beginning of the track
|
2094
|
-
# metric (str): str or callable, optional The distance metric to use.
|
2095
|
-
# Default='euclidean'. Refer to the documentation for
|
2096
|
-
# scipy.spatial.distance.cdist. Some examples:
|
2097
|
-
# 'braycurtis', 'canberra', 'chebyshev', 'cityblock', 'correlation',
|
2098
|
-
# 'cosine', 'dice', 'euclidean', 'hamming', 'jaccard', 'kulsinski',
|
2099
|
-
# 'mahalanobis', 'matching', 'minkowski', 'rogerstanimoto', 'russellrao',
|
2100
|
-
# 'seuclidean', 'sokalmichener', 'sokalsneath', 'sqeuclidean',
|
2101
|
-
# 'wminkowski', 'yule'
|
2102
|
-
# **kwargs (dict): Extra arguments to `metric`: refer to each metric
|
2103
|
-
# documentation in scipy.spatial.distance (optional)
|
2104
|
-
|
2105
|
-
# Returns:
|
2106
|
-
# float: the dynamic time warping distance between the two tracks
|
2107
|
-
# """
|
2108
|
-
# from scipy.sparse import
|
2109
|
-
# pos_t1 = [self.pos[ti] for ti in t1]
|
2110
|
-
# pos_t2 = [self.pos[ti] for ti in t2]
|
2111
|
-
# distance_matrix = np.zeros((len(t1), len(t2))) + np.inf
|
2112
|
-
|
2113
|
-
# c = distance.cdist(exp_data, num_data, metric=metric, **kwargs)
|
2114
|
-
|
2115
|
-
# d = np.zeros(c.shape)
|
2116
|
-
# d[0, 0] = c[0, 0]
|
2117
|
-
# n, m = c.shape
|
2118
|
-
# for i in range(1, n):
|
2119
|
-
# d[i, 0] = d[i-1, 0] + c[i, 0]
|
2120
|
-
# for j in range(1, m):
|
2121
|
-
# d[0, j] = d[0, j-1] + c[0, j]
|
2122
|
-
# for i in range(1, n):
|
2123
|
-
# for j in range(1, m):
|
2124
|
-
# d[i, j] = c[i, j] + min((d[i-1, j], d[i, j-1], d[i-1, j-1]))
|
2125
|
-
# return d[-1, -1], d
|
2126
|
-
|
2127
|
-
def __getitem__(self, item):
|
2128
|
-
if isinstance(item, str):
|
2129
|
-
return self.__dict__[item]
|
2130
|
-
elif np.issubdtype(type(item), np.integer):
|
2131
|
-
return self.successor.get(item, [])
|
2132
|
-
else:
|
2133
|
-
raise KeyError(
|
2134
|
-
"Only integer or string are valid key for lineageTree"
|
2135
|
-
)
|
2667
|
+
return ax.get_figure(), ax
|
2136
2668
|
|
2137
|
-
def
|
2669
|
+
def nodes_at_t(
|
2670
|
+
self,
|
2671
|
+
t: int,
|
2672
|
+
r: int | Iterable[int] | None = None,
|
2673
|
+
) -> list:
|
2138
2674
|
"""
|
2139
|
-
Returns the list of
|
2675
|
+
Returns the list of nodes at time `t` that are spawn by the node(s) `r`.
|
2140
2676
|
|
2141
|
-
|
2142
|
-
|
2143
|
-
|
2144
|
-
|
2677
|
+
Parameters
|
2678
|
+
----------
|
2679
|
+
t : int
|
2680
|
+
target time, if `None` goes as far as possible
|
2681
|
+
r : int or Iterable of int, optional
|
2682
|
+
id or list of ids of the spawning node
|
2145
2683
|
|
2146
|
-
|
2147
|
-
|
2684
|
+
Returns
|
2685
|
+
-------
|
2686
|
+
list
|
2687
|
+
list of nodes at time `t` spawned by `r`
|
2148
2688
|
"""
|
2149
|
-
if not
|
2689
|
+
if not r and r != 0:
|
2690
|
+
r = {root for root in self.roots if self.time[root] <= t}
|
2691
|
+
if isinstance(r, int):
|
2150
2692
|
r = [r]
|
2693
|
+
if t is None:
|
2694
|
+
t = self.t_e
|
2151
2695
|
to_do = list(r)
|
2152
2696
|
final_nodes = []
|
2153
2697
|
while len(to_do) > 0:
|
2154
2698
|
curr = to_do.pop()
|
2155
|
-
for _next in self[curr]:
|
2156
|
-
if self.
|
2699
|
+
for _next in self._successor[curr]:
|
2700
|
+
if self._time[_next] < t:
|
2157
2701
|
to_do.append(_next)
|
2158
|
-
elif self.
|
2702
|
+
elif self._time[_next] == t:
|
2159
2703
|
final_nodes.append(_next)
|
2160
2704
|
if not final_nodes:
|
2161
2705
|
return list(r)
|
2162
2706
|
return final_nodes
|
2163
2707
|
|
2164
2708
|
@staticmethod
|
2165
|
-
def __calculate_diag_line(dist_mat: np.ndarray) ->
|
2709
|
+
def __calculate_diag_line(dist_mat: np.ndarray) -> tuple[float, float]:
|
2166
2710
|
"""
|
2167
2711
|
Calculate the line that centers the band w.
|
2168
2712
|
|
2169
|
-
|
2170
|
-
|
2713
|
+
Parameters
|
2714
|
+
----------
|
2715
|
+
dist_mat : np.ndarray
|
2716
|
+
distance matrix obtained by the function calculate_dtw
|
2171
2717
|
|
2172
|
-
|
2173
|
-
|
2174
|
-
|
2718
|
+
Returns
|
2719
|
+
-------
|
2720
|
+
float
|
2721
|
+
The slope of the curve
|
2722
|
+
float
|
2723
|
+
The intercept of the curve
|
2175
2724
|
"""
|
2176
2725
|
i, j = dist_mat.shape
|
2177
2726
|
x1 = max(0, i - j) / 2
|
@@ -2191,22 +2740,33 @@ class lineageTree(lineageTreeLoaders):
|
|
2191
2740
|
fast: bool = False,
|
2192
2741
|
w: int = 0,
|
2193
2742
|
centered_band: bool = True,
|
2194
|
-
) ->
|
2743
|
+
) -> tuple[list[int], np.ndarray, float]:
|
2195
2744
|
"""
|
2196
2745
|
Find DTW minimum cost between two series using dynamic programming.
|
2197
2746
|
|
2198
|
-
|
2199
|
-
|
2200
|
-
|
2201
|
-
|
2202
|
-
|
2203
|
-
|
2204
|
-
|
2205
|
-
|
2206
|
-
|
2207
|
-
|
2208
|
-
|
2209
|
-
|
2747
|
+
Parameters
|
2748
|
+
----------
|
2749
|
+
dist_mat : np.ndarray
|
2750
|
+
distance matrix obtained by the function calculate_dtw
|
2751
|
+
start_d : int, default=0
|
2752
|
+
start delay
|
2753
|
+
back_d : int, default=0
|
2754
|
+
end delay
|
2755
|
+
fast : bool, default=False
|
2756
|
+
if `True`, the algorithm will use a faster version but might not find the optimal alignment
|
2757
|
+
w : int, default=0
|
2758
|
+
window constrain
|
2759
|
+
centered_band : bool, default=True
|
2760
|
+
if `True`, the band will be centered around the diagonal
|
2761
|
+
|
2762
|
+
Returns
|
2763
|
+
-------
|
2764
|
+
tuple of tuples of int
|
2765
|
+
Aligment path
|
2766
|
+
np.ndarray
|
2767
|
+
cost matrix
|
2768
|
+
float
|
2769
|
+
optimal cost
|
2210
2770
|
"""
|
2211
2771
|
N, M = dist_mat.shape
|
2212
2772
|
w_limit = max(w, abs(N - M)) # Calculate the Sakoe-Chiba band width
|
@@ -2327,7 +2887,6 @@ class lineageTree(lineageTreeLoaders):
|
|
2327
2887
|
|
2328
2888
|
# special reflection case
|
2329
2889
|
if np.linalg.det(R) < 0:
|
2330
|
-
# print("det(R) < R, reflection detected!, correcting for it ...")
|
2331
2890
|
Vt[2, :] *= -1
|
2332
2891
|
R = Vt.T @ U.T
|
2333
2892
|
|
@@ -2336,52 +2895,59 @@ class lineageTree(lineageTreeLoaders):
|
|
2336
2895
|
return R, t
|
2337
2896
|
|
2338
2897
|
def __interpolate(
|
2339
|
-
self,
|
2340
|
-
) ->
|
2898
|
+
self, chain1: list, chain2: list, threshold: int
|
2899
|
+
) -> tuple[np.ndarray, np.ndarray]:
|
2341
2900
|
"""
|
2342
2901
|
Interpolate two series that have different lengths
|
2343
2902
|
|
2344
|
-
|
2345
|
-
|
2346
|
-
|
2347
|
-
|
2348
|
-
|
2349
|
-
|
2350
|
-
|
2351
|
-
|
2903
|
+
Parameters
|
2904
|
+
----------
|
2905
|
+
chain1 : list of int
|
2906
|
+
list of nodes of the first chain to compare
|
2907
|
+
chain2 : list of int
|
2908
|
+
list of nodes of the second chain to compare
|
2909
|
+
threshold : int
|
2910
|
+
set a maximum number of points a chain can have
|
2911
|
+
|
2912
|
+
Returns
|
2913
|
+
-------
|
2914
|
+
list of np.ndarray
|
2915
|
+
`x`, `y`, `z` postions for `chain1`
|
2916
|
+
list of np.ndarray
|
2917
|
+
`x`, `y`, `z` postions for `chain2`
|
2352
2918
|
"""
|
2353
2919
|
inter1_pos = []
|
2354
2920
|
inter2_pos = []
|
2355
2921
|
|
2356
|
-
|
2357
|
-
|
2922
|
+
chain1_pos = np.array([self.pos[c_id] for c_id in chain1])
|
2923
|
+
chain2_pos = np.array([self.pos[c_id] for c_id in chain2])
|
2358
2924
|
|
2359
|
-
# Both
|
2360
|
-
if len(
|
2361
|
-
len(
|
2925
|
+
# Both chains have the same length and size below the threshold - nothing is done
|
2926
|
+
if len(chain1) == len(chain2) and (
|
2927
|
+
len(chain1) <= threshold or len(chain2) <= threshold
|
2362
2928
|
):
|
2363
|
-
return
|
2364
|
-
# Both
|
2365
|
-
elif len(
|
2929
|
+
return chain1_pos, chain2_pos
|
2930
|
+
# Both chains have the same length but one or more sizes are above the threshold
|
2931
|
+
elif len(chain1) > threshold or len(chain2) > threshold:
|
2366
2932
|
sampling = threshold
|
2367
|
-
#
|
2933
|
+
# chains have different lengths and the sizes are below the threshold
|
2368
2934
|
else:
|
2369
|
-
sampling = max(len(
|
2935
|
+
sampling = max(len(chain1), len(chain2))
|
2370
2936
|
|
2371
2937
|
for pos in range(3):
|
2372
|
-
|
2373
|
-
np.linspace(0, 1, len(
|
2374
|
-
|
2938
|
+
chain1_interp = InterpolatedUnivariateSpline(
|
2939
|
+
np.linspace(0, 1, len(chain1_pos[:, pos])),
|
2940
|
+
chain1_pos[:, pos],
|
2375
2941
|
k=1,
|
2376
2942
|
)
|
2377
|
-
inter1_pos.append(
|
2943
|
+
inter1_pos.append(chain1_interp(np.linspace(0, 1, sampling)))
|
2378
2944
|
|
2379
|
-
|
2380
|
-
np.linspace(0, 1, len(
|
2381
|
-
|
2945
|
+
chain2_interp = InterpolatedUnivariateSpline(
|
2946
|
+
np.linspace(0, 1, len(chain2_pos[:, pos])),
|
2947
|
+
chain2_pos[:, pos],
|
2382
2948
|
k=1,
|
2383
2949
|
)
|
2384
|
-
inter2_pos.append(
|
2950
|
+
inter2_pos.append(chain2_interp(np.linspace(0, 1, sampling)))
|
2385
2951
|
|
2386
2952
|
return np.column_stack(inter1_pos), np.column_stack(inter2_pos)
|
2387
2953
|
|
@@ -2397,46 +2963,66 @@ class lineageTree(lineageTreeLoaders):
|
|
2397
2963
|
w: int = 0,
|
2398
2964
|
centered_band: bool = True,
|
2399
2965
|
cost_mat_p: bool = False,
|
2400
|
-
) -> (
|
2401
|
-
|
2402
|
-
|
2403
|
-
|
2404
|
-
Args:
|
2405
|
-
nodes1 (int): node to compare distance
|
2406
|
-
nodes2 (int): node to compare distance
|
2407
|
-
threshold: set a maximum number of points a track can have
|
2408
|
-
regist (boolean): Rotate and translate trajectories
|
2409
|
-
start_d (int): start delay
|
2410
|
-
back_d (int): end delay
|
2411
|
-
fast (boolean): True if the user wants to run the fast algorithm with window restrains
|
2412
|
-
w (int): window size
|
2413
|
-
centered_band (boolean): if running the fast algorithm, True if the windown is centered
|
2414
|
-
cost_mat_p (boolean): True if print the not normalized cost matrix
|
2415
|
-
|
2416
|
-
Returns:
|
2417
|
-
(float) DTW distance
|
2418
|
-
(tuple of tuples) Aligment path
|
2419
|
-
(matrix) Cost matrix
|
2420
|
-
(list of lists) pos_cycle1: rotated and translated trajectories positions
|
2421
|
-
(list of lists) pos_cycle2: rotated and translated trajectories positions
|
2966
|
+
) -> (
|
2967
|
+
tuple[float, tuple, np.ndarray, np.ndarray, np.ndarray]
|
2968
|
+
| tuple[float, tuple]
|
2969
|
+
):
|
2422
2970
|
"""
|
2423
|
-
|
2424
|
-
|
2425
|
-
|
2426
|
-
|
2427
|
-
|
2971
|
+
Calculate DTW distance between two chains
|
2972
|
+
|
2973
|
+
Parameters
|
2974
|
+
----------
|
2975
|
+
nodes1 : int
|
2976
|
+
node to compare distance
|
2977
|
+
nodes2 : int
|
2978
|
+
node to compare distance
|
2979
|
+
threshold : int, default=1000
|
2980
|
+
set a maximum number of points a chain can have
|
2981
|
+
regist : bool, default=True
|
2982
|
+
Rotate and translate trajectories
|
2983
|
+
start_d : int, default=0
|
2984
|
+
start delay
|
2985
|
+
back_d : int, default=0
|
2986
|
+
end delay
|
2987
|
+
fast : bool, default=False
|
2988
|
+
if `True`, the algorithm will use a faster version but might not find the optimal alignment
|
2989
|
+
w : int, default=0
|
2990
|
+
window size
|
2991
|
+
centered_band : bool, default=True
|
2992
|
+
when running the fast algorithm, `True` if the windown is centered
|
2993
|
+
cost_mat_p : bool, default=False
|
2994
|
+
True if print the not normalized cost matrix
|
2995
|
+
|
2996
|
+
Returns
|
2997
|
+
-------
|
2998
|
+
float
|
2999
|
+
DTW distance
|
3000
|
+
tuple of tuples
|
3001
|
+
Aligment path
|
3002
|
+
matrix
|
3003
|
+
Cost matrix
|
3004
|
+
list of lists
|
3005
|
+
rotated and translated trajectories positions
|
3006
|
+
list of lists
|
3007
|
+
rotated and translated trajectories positions
|
3008
|
+
"""
|
3009
|
+
nodes1_chain = self.get_chain_of_node(nodes1)
|
3010
|
+
nodes2_chain = self.get_chain_of_node(nodes2)
|
3011
|
+
|
3012
|
+
interp_chain1, interp_chain2 = self.__interpolate(
|
3013
|
+
nodes1_chain, nodes2_chain, threshold
|
2428
3014
|
)
|
2429
3015
|
|
2430
|
-
|
2431
|
-
|
3016
|
+
pos_chain1 = np.array([self.pos[c_id] for c_id in nodes1_chain])
|
3017
|
+
pos_chain2 = np.array([self.pos[c_id] for c_id in nodes2_chain])
|
2432
3018
|
|
2433
3019
|
if regist:
|
2434
3020
|
R, t = self.__rigid_transform_3D(
|
2435
|
-
np.transpose(
|
3021
|
+
np.transpose(interp_chain1), np.transpose(interp_chain2)
|
2436
3022
|
)
|
2437
|
-
|
3023
|
+
pos_chain1 = np.transpose(np.dot(R, pos_chain1.T) + t)
|
2438
3024
|
|
2439
|
-
dist_mat = distance.cdist(
|
3025
|
+
dist_mat = distance.cdist(pos_chain1, pos_chain2, "euclidean")
|
2440
3026
|
|
2441
3027
|
path, cost_mat, final_cost = self.__dp(
|
2442
3028
|
dist_mat,
|
@@ -2449,7 +3035,7 @@ class lineageTree(lineageTreeLoaders):
|
|
2449
3035
|
cost = final_cost / len(path)
|
2450
3036
|
|
2451
3037
|
if cost_mat_p:
|
2452
|
-
return cost, path, cost_mat,
|
3038
|
+
return cost, path, cost_mat, pos_chain1, pos_chain2
|
2453
3039
|
else:
|
2454
3040
|
return cost, path
|
2455
3041
|
|
@@ -2464,24 +3050,39 @@ class lineageTree(lineageTreeLoaders):
|
|
2464
3050
|
fast: bool = False,
|
2465
3051
|
w: int = 0,
|
2466
3052
|
centered_band: bool = True,
|
2467
|
-
) ->
|
2468
|
-
"""
|
2469
|
-
Plot DTW cost matrix between two
|
2470
|
-
|
2471
|
-
|
2472
|
-
|
2473
|
-
|
2474
|
-
|
2475
|
-
|
2476
|
-
|
2477
|
-
|
2478
|
-
|
2479
|
-
|
2480
|
-
|
2481
|
-
|
2482
|
-
|
2483
|
-
|
2484
|
-
|
3053
|
+
) -> tuple[float, plt.Figure]:
|
3054
|
+
"""
|
3055
|
+
Plot DTW cost matrix between two chains in heatmap format
|
3056
|
+
|
3057
|
+
Parameters
|
3058
|
+
----------
|
3059
|
+
nodes1 : int
|
3060
|
+
node to compare distance
|
3061
|
+
nodes2 : int
|
3062
|
+
node to compare distance
|
3063
|
+
threshold : int, default=1000
|
3064
|
+
set a maximum number of points a chain can have
|
3065
|
+
regist : bool, default=True
|
3066
|
+
Rotate and translate trajectories
|
3067
|
+
start_d : int, default=0
|
3068
|
+
start delay
|
3069
|
+
back_d : int, default=0
|
3070
|
+
end delay
|
3071
|
+
fast : bool, default=False
|
3072
|
+
if `True`, the algorithm will use a faster version but might not find the optimal alignment
|
3073
|
+
w : int, default=0
|
3074
|
+
window size
|
3075
|
+
centered_band : bool, default=True
|
3076
|
+
when running the fast algorithm, `True` if the windown is centered
|
3077
|
+
|
3078
|
+
Returns
|
3079
|
+
-------
|
3080
|
+
float
|
3081
|
+
DTW distance
|
3082
|
+
plt.Figure
|
3083
|
+
Heatmap of cost matrix with opitimal path
|
3084
|
+
"""
|
3085
|
+
cost, path, cost_mat, pos_chain1, pos_chain2 = self.calculate_dtw(
|
2485
3086
|
nodes1,
|
2486
3087
|
nodes2,
|
2487
3088
|
threshold,
|
@@ -2503,32 +3104,32 @@ class lineageTree(lineageTreeLoaders):
|
|
2503
3104
|
ax.set_title("Heatmap of DTW Cost Matrix")
|
2504
3105
|
ax.set_xlabel("Tree 1")
|
2505
3106
|
ax.set_ylabel("tree 2")
|
2506
|
-
x_path, y_path = zip(*path)
|
3107
|
+
x_path, y_path = zip(*path, strict=True)
|
2507
3108
|
ax.plot(y_path, x_path, color="black")
|
2508
3109
|
|
2509
3110
|
return cost, fig
|
2510
3111
|
|
2511
3112
|
@staticmethod
|
2512
3113
|
def __plot_2d(
|
2513
|
-
|
2514
|
-
|
2515
|
-
nodes1,
|
2516
|
-
nodes2,
|
2517
|
-
ax,
|
2518
|
-
x_idx,
|
2519
|
-
y_idx,
|
2520
|
-
x_label,
|
2521
|
-
y_label,
|
2522
|
-
):
|
3114
|
+
pos_chain1: np.ndarray,
|
3115
|
+
pos_chain2: np.ndarray,
|
3116
|
+
nodes1: list[int],
|
3117
|
+
nodes2: list[int],
|
3118
|
+
ax: plt.Axes,
|
3119
|
+
x_idx: list[int],
|
3120
|
+
y_idx: list[int],
|
3121
|
+
x_label: str,
|
3122
|
+
y_label: str,
|
3123
|
+
) -> None:
|
2523
3124
|
ax.plot(
|
2524
|
-
|
2525
|
-
|
3125
|
+
pos_chain1[:, x_idx],
|
3126
|
+
pos_chain1[:, y_idx],
|
2526
3127
|
"-",
|
2527
3128
|
label=f"root = {nodes1}",
|
2528
3129
|
)
|
2529
3130
|
ax.plot(
|
2530
|
-
|
2531
|
-
|
3131
|
+
pos_chain2[:, x_idx],
|
3132
|
+
pos_chain2[:, y_idx],
|
2532
3133
|
"-",
|
2533
3134
|
label=f"root = {nodes2}",
|
2534
3135
|
)
|
@@ -2546,40 +3147,55 @@ class lineageTree(lineageTreeLoaders):
|
|
2546
3147
|
fast: bool = False,
|
2547
3148
|
w: int = 0,
|
2548
3149
|
centered_band: bool = True,
|
2549
|
-
projection:
|
3150
|
+
projection: Literal["3d", "xy", "xz", "yz", "pca", None] = None,
|
2550
3151
|
alig: bool = False,
|
2551
|
-
) ->
|
2552
|
-
"""
|
2553
|
-
Plots DTW trajectories aligment between two
|
2554
|
-
|
2555
|
-
|
2556
|
-
|
2557
|
-
|
2558
|
-
|
2559
|
-
|
2560
|
-
|
2561
|
-
|
2562
|
-
|
2563
|
-
|
2564
|
-
|
2565
|
-
|
2566
|
-
|
2567
|
-
|
2568
|
-
|
2569
|
-
|
2570
|
-
|
2571
|
-
|
2572
|
-
|
2573
|
-
|
2574
|
-
|
2575
|
-
|
3152
|
+
) -> tuple[float, plt.Figure]:
|
3153
|
+
"""
|
3154
|
+
Plots DTW trajectories aligment between two chains in 2D or 3D
|
3155
|
+
|
3156
|
+
Parameters
|
3157
|
+
----------
|
3158
|
+
nodes1 : int
|
3159
|
+
node to compare distance
|
3160
|
+
nodes2 : int
|
3161
|
+
node to compare distance
|
3162
|
+
threshold : int, default=1000
|
3163
|
+
set a maximum number of points a chain can have
|
3164
|
+
regist : bool, default=True
|
3165
|
+
Rotate and translate trajectories
|
3166
|
+
start_d : int, default=0
|
3167
|
+
start delay
|
3168
|
+
back_d : int, default=0
|
3169
|
+
end delay
|
3170
|
+
w : int, default=0
|
3171
|
+
window size
|
3172
|
+
fast : bool, default=False
|
3173
|
+
True if the user wants to run the fast algorithm with window restrains
|
3174
|
+
centered_band : bool, default=True
|
3175
|
+
if running the fast algorithm, True if the windown is centered
|
3176
|
+
projection : {"3d", "xy", "xz", "yz", "pca"}, optional
|
3177
|
+
specify which 2D to plot ->
|
3178
|
+
"3d" : for the 3d visualization
|
3179
|
+
"xy" or None (default) : 2D projection of axis x and y
|
3180
|
+
"xz" : 2D projection of axis x and z
|
3181
|
+
"yz" : 2D projection of axis y and z
|
3182
|
+
"pca" : PCA projection
|
3183
|
+
alig : bool
|
3184
|
+
True to show alignment on plot
|
3185
|
+
|
3186
|
+
Returns
|
3187
|
+
-------
|
3188
|
+
float
|
3189
|
+
DTW distance
|
3190
|
+
figure
|
3191
|
+
Trajectories Plot
|
2576
3192
|
"""
|
2577
3193
|
(
|
2578
3194
|
distance,
|
2579
3195
|
alignment,
|
2580
3196
|
cost_mat,
|
2581
|
-
|
2582
|
-
|
3197
|
+
pos_chain1,
|
3198
|
+
pos_chain2,
|
2583
3199
|
) = self.calculate_dtw(
|
2584
3200
|
nodes1,
|
2585
3201
|
nodes2,
|
@@ -2602,16 +3218,16 @@ class lineageTree(lineageTreeLoaders):
|
|
2602
3218
|
|
2603
3219
|
if projection == "3d":
|
2604
3220
|
ax.plot(
|
2605
|
-
|
2606
|
-
|
2607
|
-
|
3221
|
+
pos_chain1[:, 0],
|
3222
|
+
pos_chain1[:, 1],
|
3223
|
+
pos_chain1[:, 2],
|
2608
3224
|
"-",
|
2609
3225
|
label=f"root = {nodes1}",
|
2610
3226
|
)
|
2611
3227
|
ax.plot(
|
2612
|
-
|
2613
|
-
|
2614
|
-
|
3228
|
+
pos_chain2[:, 0],
|
3229
|
+
pos_chain2[:, 1],
|
3230
|
+
pos_chain2[:, 2],
|
2615
3231
|
"-",
|
2616
3232
|
label=f"root = {nodes2}",
|
2617
3233
|
)
|
@@ -2621,8 +3237,8 @@ class lineageTree(lineageTreeLoaders):
|
|
2621
3237
|
else:
|
2622
3238
|
if projection == "xy" or projection == "yx" or projection is None:
|
2623
3239
|
self.__plot_2d(
|
2624
|
-
|
2625
|
-
|
3240
|
+
pos_chain1,
|
3241
|
+
pos_chain2,
|
2626
3242
|
nodes1,
|
2627
3243
|
nodes2,
|
2628
3244
|
ax,
|
@@ -2633,8 +3249,8 @@ class lineageTree(lineageTreeLoaders):
|
|
2633
3249
|
)
|
2634
3250
|
elif projection == "xz" or projection == "zx":
|
2635
3251
|
self.__plot_2d(
|
2636
|
-
|
2637
|
-
|
3252
|
+
pos_chain1,
|
3253
|
+
pos_chain2,
|
2638
3254
|
nodes1,
|
2639
3255
|
nodes2,
|
2640
3256
|
ax,
|
@@ -2645,8 +3261,8 @@ class lineageTree(lineageTreeLoaders):
|
|
2645
3261
|
)
|
2646
3262
|
elif projection == "yz" or projection == "zy":
|
2647
3263
|
self.__plot_2d(
|
2648
|
-
|
2649
|
-
|
3264
|
+
pos_chain1,
|
3265
|
+
pos_chain2,
|
2650
3266
|
nodes1,
|
2651
3267
|
nodes2,
|
2652
3268
|
ax,
|
@@ -2660,24 +3276,25 @@ class lineageTree(lineageTreeLoaders):
|
|
2660
3276
|
from sklearn.decomposition import PCA
|
2661
3277
|
except ImportError:
|
2662
3278
|
Warning(
|
2663
|
-
"scikit-learn is not installed, the PCA orientation cannot be used.
|
3279
|
+
"scikit-learn is not installed, the PCA orientation cannot be used."
|
3280
|
+
"You can install scikit-learn with pip install"
|
2664
3281
|
)
|
2665
3282
|
|
2666
3283
|
# Apply PCA
|
2667
3284
|
pca = PCA(n_components=2)
|
2668
|
-
pca.fit(np.vstack([
|
2669
|
-
|
2670
|
-
|
3285
|
+
pca.fit(np.vstack([pos_chain1, pos_chain2]))
|
3286
|
+
pos_chain1_2d = pca.transform(pos_chain1)
|
3287
|
+
pos_chain2_2d = pca.transform(pos_chain2)
|
2671
3288
|
|
2672
3289
|
ax.plot(
|
2673
|
-
|
2674
|
-
|
3290
|
+
pos_chain1_2d[:, 0],
|
3291
|
+
pos_chain1_2d[:, 1],
|
2675
3292
|
"-",
|
2676
3293
|
label=f"root = {nodes1}",
|
2677
3294
|
)
|
2678
3295
|
ax.plot(
|
2679
|
-
|
2680
|
-
|
3296
|
+
pos_chain2_2d[:, 0],
|
3297
|
+
pos_chain2_2d[:, 1],
|
2681
3298
|
"-",
|
2682
3299
|
label=f"root = {nodes2}",
|
2683
3300
|
)
|
@@ -2706,7 +3323,7 @@ class lineageTree(lineageTreeLoaders):
|
|
2706
3323
|
'pca' : PCA projection"""
|
2707
3324
|
)
|
2708
3325
|
|
2709
|
-
connections = [[
|
3326
|
+
connections = [[pos_chain1[i], pos_chain2[j]] for i, j in alignment]
|
2710
3327
|
|
2711
3328
|
for connection in connections:
|
2712
3329
|
xyz1 = connection[0]
|
@@ -2729,120 +3346,179 @@ class lineageTree(lineageTreeLoaders):
|
|
2729
3346
|
warnings.warn(
|
2730
3347
|
"Error: not possible to show alignment in PCA projection !",
|
2731
3348
|
UserWarning,
|
3349
|
+
stacklevel=2,
|
2732
3350
|
)
|
2733
3351
|
|
2734
3352
|
return distance, fig
|
2735
3353
|
|
2736
|
-
def first_labelling(self):
|
2737
|
-
self.labels = {i: "Unlabeled" for i in self.time_nodes[0]}
|
2738
|
-
|
2739
3354
|
def __init__(
|
2740
3355
|
self,
|
2741
|
-
|
2742
|
-
|
2743
|
-
|
2744
|
-
|
2745
|
-
|
2746
|
-
|
2747
|
-
|
2748
|
-
|
2749
|
-
|
2750
|
-
reorder: bool = False,
|
2751
|
-
xml_attributes: tuple = None,
|
2752
|
-
name: str = None,
|
2753
|
-
time_resolution: Union[int, None] = None,
|
3356
|
+
*,
|
3357
|
+
successor: dict[int, Sequence] | None = None,
|
3358
|
+
predecessor: dict[int, int | Sequence] | None = None,
|
3359
|
+
time: dict[int, int] | None = None,
|
3360
|
+
starting_time: int | None = None,
|
3361
|
+
pos: dict[int, Iterable] | None = None,
|
3362
|
+
name: str | None = None,
|
3363
|
+
root_leaf_value: Sequence | None = None,
|
3364
|
+
**kwargs,
|
2754
3365
|
):
|
2755
|
-
"""
|
2756
|
-
|
2757
|
-
|
2758
|
-
|
2759
|
-
|
2760
|
-
|
2761
|
-
|
2762
|
-
|
2763
|
-
|
2764
|
-
|
2765
|
-
|
2766
|
-
|
2767
|
-
|
2768
|
-
|
2769
|
-
|
2770
|
-
|
2771
|
-
|
2772
|
-
|
2773
|
-
|
2774
|
-
|
2775
|
-
|
2776
|
-
|
2777
|
-
|
2778
|
-
|
3366
|
+
"""Create a lineageTree object from minimal information, without reading from a file.
|
3367
|
+
Either `successor` or `predecessor` should be specified.
|
3368
|
+
|
3369
|
+
Parameters
|
3370
|
+
----------
|
3371
|
+
successor : dict mapping int to Iterable
|
3372
|
+
Dictionary assigning nodes to their successors.
|
3373
|
+
predecessor : dict mapping int to int or Iterable
|
3374
|
+
Dictionary assigning nodes to their predecessors.
|
3375
|
+
time : dict mapping int to int, optional
|
3376
|
+
Dictionary assigning nodes to the time point they were recorded to.
|
3377
|
+
Defaults to None, in which case all times are set to `starting_time`.
|
3378
|
+
starting_time : int, optional
|
3379
|
+
Starting time of the lineage tree. Defaults to 0.
|
3380
|
+
pos : dict mapping int to Iterable, optional
|
3381
|
+
Dictionary assigning nodes to their positions. Defaults to None.
|
3382
|
+
name : str, optional
|
3383
|
+
Name of the lineage tree. Defaults to None.
|
3384
|
+
root_leaf_value : Iterable, optional
|
3385
|
+
Iterable of values of roots' predecessors and leaves' successors in the successor and predecessor dictionaries.
|
3386
|
+
Defaults are `[None, (), [], set()]`.
|
3387
|
+
**kwargs:
|
3388
|
+
Supported keyword arguments are dictionaries assigning nodes to any custom property.
|
3389
|
+
The property must be specified for every node, and named differently from lineageTree's own attributes.
|
3390
|
+
"""
|
3391
|
+
self.__version__ = importlib.metadata.version("LineageTree")
|
3392
|
+
self.name = str(name) if name is not None else None
|
3393
|
+
if successor is not None and predecessor is not None:
|
3394
|
+
raise ValueError(
|
3395
|
+
"You cannot have both successors and predecessors."
|
3396
|
+
)
|
2779
3397
|
|
2780
|
-
|
2781
|
-
|
2782
|
-
|
2783
|
-
|
2784
|
-
|
2785
|
-
|
2786
|
-
|
2787
|
-
|
2788
|
-
|
2789
|
-
|
2790
|
-
self.
|
2791
|
-
|
2792
|
-
|
2793
|
-
|
2794
|
-
|
2795
|
-
|
2796
|
-
|
2797
|
-
|
2798
|
-
|
2799
|
-
|
2800
|
-
|
2801
|
-
|
2802
|
-
|
2803
|
-
|
2804
|
-
|
2805
|
-
|
2806
|
-
|
2807
|
-
|
2808
|
-
|
2809
|
-
|
2810
|
-
|
2811
|
-
|
2812
|
-
|
2813
|
-
|
2814
|
-
|
3398
|
+
if root_leaf_value is None:
|
3399
|
+
root_leaf_value = [None, (), [], set()]
|
3400
|
+
elif not isinstance(root_leaf_value, Iterable):
|
3401
|
+
raise TypeError(
|
3402
|
+
f"root_leaf_value is of type {type(root_leaf_value)}, expected Iterable."
|
3403
|
+
)
|
3404
|
+
elif len(root_leaf_value) < 1:
|
3405
|
+
raise ValueError(
|
3406
|
+
"root_leaf_value should have at least one element."
|
3407
|
+
)
|
3408
|
+
self._successor = {}
|
3409
|
+
self._predecessor = {}
|
3410
|
+
if successor is not None:
|
3411
|
+
for pred, succs in successor.items():
|
3412
|
+
if succs in root_leaf_value:
|
3413
|
+
self._successor[pred] = ()
|
3414
|
+
else:
|
3415
|
+
if not isinstance(succs, Iterable):
|
3416
|
+
raise TypeError(
|
3417
|
+
f"Successors should be Iterable, got {type(succs)}."
|
3418
|
+
)
|
3419
|
+
if len(succs) == 0:
|
3420
|
+
raise ValueError(
|
3421
|
+
f"{succs} was not declared as a leaf but was found as a successor.\n"
|
3422
|
+
"Please lift the ambiguity."
|
3423
|
+
)
|
3424
|
+
self._successor[pred] = tuple(succs)
|
3425
|
+
for succ in succs:
|
3426
|
+
if succ in self._predecessor:
|
3427
|
+
raise ValueError(
|
3428
|
+
"Node can have at most one predecessor."
|
3429
|
+
)
|
3430
|
+
self._predecessor[succ] = (pred,)
|
3431
|
+
elif predecessor is not None:
|
3432
|
+
for succ, pred in predecessor.items():
|
3433
|
+
if pred in root_leaf_value:
|
3434
|
+
self._predecessor[succ] = ()
|
3435
|
+
else:
|
3436
|
+
if isinstance(pred, Sequence):
|
3437
|
+
if len(pred) == 0:
|
3438
|
+
raise ValueError(
|
3439
|
+
f"{pred} was not declared as a leaf but was found as a successor.\n"
|
3440
|
+
"Please lift the ambiguity."
|
3441
|
+
)
|
3442
|
+
if 1 < len(pred):
|
3443
|
+
raise ValueError(
|
3444
|
+
"Node can have at most one predecessor."
|
3445
|
+
)
|
3446
|
+
pred = pred[0]
|
3447
|
+
self._predecessor[succ] = (pred,)
|
3448
|
+
self._successor.setdefault(pred, ())
|
3449
|
+
self._successor[pred] += (succ,)
|
3450
|
+
for root in set(self._successor).difference(self._predecessor):
|
3451
|
+
self._predecessor[root] = ()
|
3452
|
+
for leaf in set(self._predecessor).difference(self._successor):
|
3453
|
+
self._successor[leaf] = ()
|
3454
|
+
|
3455
|
+
if self.__check_for_cycles():
|
3456
|
+
raise ValueError(
|
3457
|
+
"Cycles were found in the tree, there should not be any."
|
3458
|
+
)
|
3459
|
+
|
3460
|
+
if pos is None:
|
3461
|
+
self.pos = {}
|
3462
|
+
else:
|
3463
|
+
if self.nodes.difference(pos) != set():
|
3464
|
+
raise ValueError("Please provide the position of all nodes.")
|
3465
|
+
self.pos = {
|
3466
|
+
node: np.array(position) for node, position in pos.items()
|
3467
|
+
}
|
3468
|
+
|
3469
|
+
if time is None:
|
3470
|
+
if starting_time is None:
|
3471
|
+
starting_time = 0
|
3472
|
+
if not isinstance(starting_time, int):
|
3473
|
+
warnings.warn(
|
3474
|
+
f"Attribute `starting_time` was a `{type(starting_time)}`, has been casted as an `int`.",
|
3475
|
+
stacklevel=2,
|
2815
3476
|
)
|
2816
|
-
|
2817
|
-
|
2818
|
-
|
3477
|
+
self._time = {node: starting_time for node in self.roots}
|
3478
|
+
queue = list(self.roots)
|
3479
|
+
for node in queue:
|
3480
|
+
for succ in self._successor[node]:
|
3481
|
+
self._time[succ] = self._time[node] + 1
|
3482
|
+
queue.append(succ)
|
3483
|
+
else:
|
3484
|
+
if starting_time is not None:
|
3485
|
+
warnings.warn(
|
3486
|
+
"Both `time` and `starting_time` were provided, `starting_time` was ignored.",
|
3487
|
+
stacklevel=2,
|
3488
|
+
)
|
3489
|
+
self._time = {n: int(time[n]) for n in self.nodes}
|
3490
|
+
if self._time != time:
|
3491
|
+
if len(self._time) != len(time):
|
3492
|
+
warnings.warn(
|
3493
|
+
"The provided `time` dictionary had keys that were not nodes. "
|
3494
|
+
"They have been removed",
|
3495
|
+
stacklevel=2,
|
3496
|
+
)
|
2819
3497
|
else:
|
2820
|
-
|
2821
|
-
|
2822
|
-
|
2823
|
-
|
2824
|
-
|
2825
|
-
|
2826
|
-
|
2827
|
-
|
2828
|
-
|
2829
|
-
|
2830
|
-
|
2831
|
-
self.
|
2832
|
-
|
2833
|
-
|
2834
|
-
|
2835
|
-
|
2836
|
-
|
2837
|
-
|
2838
|
-
|
2839
|
-
|
2840
|
-
|
2841
|
-
|
2842
|
-
|
2843
|
-
|
2844
|
-
|
2845
|
-
|
2846
|
-
|
2847
|
-
if self[succ] == []:
|
2848
|
-
self.predecessor.pop(succ)
|
3498
|
+
warnings.warn(
|
3499
|
+
"The provided `time` dictionary had values that were not `int`. "
|
3500
|
+
"These values have been truncated and converted to `int`",
|
3501
|
+
stacklevel=2,
|
3502
|
+
)
|
3503
|
+
if self.nodes.symmetric_difference(self._time) != set():
|
3504
|
+
raise ValueError(
|
3505
|
+
"Please provide the time of all nodes and only existing nodes."
|
3506
|
+
)
|
3507
|
+
if not all(
|
3508
|
+
self._time[node] < self._time[s]
|
3509
|
+
for node, succ in self._successor.items()
|
3510
|
+
for s in succ
|
3511
|
+
):
|
3512
|
+
raise ValueError(
|
3513
|
+
"Provided times are not strictly increasing. Setting times to default."
|
3514
|
+
)
|
3515
|
+
# custom properties
|
3516
|
+
for name, d in kwargs.items():
|
3517
|
+
if name in self.__dict__:
|
3518
|
+
warnings.warn(
|
3519
|
+
f"Attribute name {name} is reserved.", stacklevel=2
|
3520
|
+
)
|
3521
|
+
continue
|
3522
|
+
setattr(self, name, d)
|
3523
|
+
if not hasattr(self, "_comparisons"):
|
3524
|
+
self._comparisons = {}
|