LineageTree 1.7.0__py3-none-any.whl → 2.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -2,625 +2,538 @@
2
2
  # This file is subject to the terms and conditions defined in
3
3
  # file 'LICENCE', which is part of this source code package.
4
4
  # Author: Leo Guignard (leo.guignard...@AT@...gmail.com)
5
+
6
+ from __future__ import annotations
7
+
8
+ import importlib.metadata
5
9
  import os
6
10
  import pickle as pkl
7
11
  import struct
8
12
  import warnings
9
- from collections.abc import Iterable
10
- from functools import partial
13
+ from collections.abc import Callable, Iterable, Sequence
14
+ from functools import partial, wraps
11
15
  from itertools import combinations
12
16
  from numbers import Number
13
- from pathlib import Path
14
- from typing import TextIO, Union
15
-
16
- from .loaders import lineageTreeLoaders
17
- from .tree_styles import tree_style
18
-
19
- try:
20
- from edist import uted
21
- except ImportError:
22
- warnings.warn(
23
- "No edist installed therefore you will not be able to compute the tree edit distance.",
24
- stacklevel=2,
25
- )
17
+ from types import MappingProxyType
18
+ from typing import TYPE_CHECKING, Literal
19
+
20
+ import matplotlib.colors as mcolors
26
21
  import matplotlib.pyplot as plt
27
22
  import numpy as np
23
+ import svgwrite
24
+ from edist import uted
25
+ from matplotlib import colormaps
26
+ from matplotlib.collections import LineCollection
27
+ from packaging.version import Version
28
28
  from scipy.interpolate import InterpolatedUnivariateSpline
29
- from scipy.spatial import Delaunay, distance
30
- from scipy.spatial import cKDTree as KDTree
29
+ from scipy.sparse import dok_array
30
+ from scipy.spatial import Delaunay, KDTree, distance
31
31
 
32
+ from .tree_approximation import TreeApproximationTemplate, tree_style
32
33
  from .utils import (
33
- create_links_and_cycles,
34
+ convert_style_to_number,
35
+ create_links_and_chains,
34
36
  hierarchical_pos,
35
37
  )
36
38
 
39
+ if TYPE_CHECKING:
40
+ from edist.alignment import Alignment
41
+
42
+
43
+ class dynamic_property(property):
44
+ def __init__(
45
+ self, fget=None, fset=None, fdel=None, doc=None, protected_name=None
46
+ ):
47
+ super().__init__(fget, fset, fdel, doc)
48
+ self.protected_name = protected_name
49
+
50
+ def __set_name__(self, owner, name):
51
+ self.name = name
52
+ if self.protected_name is None:
53
+ self.protected_name = f"_{name}"
54
+ if not hasattr(owner, "_protected_dynamic_properties"):
55
+ owner._protected_dynamic_properties = []
56
+ owner._protected_dynamic_properties.append(self.protected_name)
57
+ if not hasattr(owner, "_dynamic_properties"):
58
+ owner._dynamic_properties = []
59
+ owner._dynamic_properties += [name, self.protected_name]
60
+ setattr(owner, self.protected_name, None)
61
+
62
+ def __get__(self, instance, owner):
63
+ if instance is None:
64
+ return self
65
+ instance._has_been_reset = False
66
+ if getattr(instance, self.protected_name) is None:
67
+ value = super().__get__(instance, owner)
68
+ setattr(instance, self.protected_name, value)
69
+ return value
70
+ else:
71
+ return getattr(instance, self.protected_name)
72
+
73
+
74
+ class lineageTree:
75
+ norm_dict = {"max": max, "sum": sum, None: lambda x: 1}
76
+
77
+ def modifier(wrapped_func):
78
+ @wraps(wrapped_func)
79
+ def raising_flag(self, *args, **kwargs):
80
+ should_reset = (
81
+ not hasattr(self, "_has_been_reset")
82
+ or not self._has_been_reset
83
+ )
84
+ out_func = wrapped_func(self, *args, **kwargs)
85
+ if should_reset:
86
+ for prop in self._protected_dynamic_properties:
87
+ self.__dict__[prop] = None
88
+ self._has_been_reset = True
89
+ return out_func
90
+
91
+ return raising_flag
92
+
93
+ def __check_cc_cycles(self, n: int) -> tuple[bool, set[int]]:
94
+ """Check if the connected component of a given node `n` has a cycle.
95
+
96
+ Returns
97
+ -------
98
+ bool
99
+ True if the tree has cycles, False otherwise.
100
+ set of int
101
+ The set of nodes that have been checked.
102
+ """
103
+ to_do = [n]
104
+ no_cycle = True
105
+ already_done = set()
106
+ while to_do and no_cycle:
107
+ current = to_do.pop(-1)
108
+ if current not in already_done:
109
+ already_done.add(current)
110
+ else:
111
+ no_cycle = False
112
+ to_do.extend(self._successor[current])
113
+ to_do = list(self._predecessor[n])
114
+ while to_do and no_cycle:
115
+ current = to_do.pop(-1)
116
+ if current not in already_done:
117
+ already_done.add(current)
118
+ else:
119
+ no_cycle = False
120
+ to_do.extend(self._predecessor[current])
121
+ return not no_cycle, already_done
122
+
123
+ def __check_for_cycles(self) -> bool:
124
+ """Check if the tree has cycles.
125
+
126
+ Returns
127
+ -------
128
+ bool
129
+ True if the tree has cycles, False otherwise.
130
+ """
131
+ to_do = set(self.nodes)
132
+ found_cycle = False
133
+ while to_do and not found_cycle:
134
+ current = to_do.pop()
135
+ found_cycle, done = self.__check_cc_cycles(current)
136
+ to_do.difference_update(done)
137
+ return found_cycle
37
138
 
38
- class lineageTree(lineageTreeLoaders):
39
- def __eq__(self, other):
139
+ def __eq__(self, other) -> bool:
40
140
  if isinstance(other, lineageTree):
41
- return other.successor == self.successor
42
- return False
141
+ return (
142
+ other._successor == self._successor
143
+ and other._predecessor == self._predecessor
144
+ and other._time == self._time
145
+ )
146
+ else:
147
+ return False
43
148
 
44
- def get_next_id(self):
45
- """Computes the next authorized id.
149
+ def get_next_id(self) -> int:
150
+ """Computes the next authorized id and assign it.
46
151
 
47
- Returns:
48
- int: next authorized id
152
+ Returns
153
+ -------
154
+ int
155
+ next authorized id
49
156
  """
50
- if self.max_id == -1 and self.nodes:
51
- self.max_id = max(self.nodes)
52
- if self.next_id == []:
157
+ if not hasattr(self, "max_id") or (self.max_id == -1 and self.nodes):
158
+ self.max_id = max(self.nodes) if len(self.nodes) else 0
159
+ if not hasattr(self, "next_id") or self.next_id == []:
53
160
  self.max_id += 1
54
161
  return self.max_id
55
162
  else:
56
163
  return self.next_id.pop()
57
164
 
58
- def complete_lineage(self, nodes: Union[int, set] = None):
59
- """Makes all leaf branches longer so that they reach the last timepoint( self.t_e), useful
60
- for tree edit distance algorithms.
61
-
62
- Args:
63
- nodes (int,set), optional): Which trees should be "completed", if None it will complete the whole dataset. Defaults to None.
64
- """
65
- if nodes is None:
66
- nodes = set(self.roots)
67
- elif isinstance(nodes, int):
68
- nodes = {nodes}
69
- for node in nodes:
70
- sub = set(self.get_sub_tree(node))
71
- specific_leaves = sub.intersection(self.leaves)
72
- for leaf in specific_leaves:
73
- self.add_branch(leaf, self.t_e - self.time[leaf], reverse=True)
74
-
75
165
  ###TODO pos can be callable and stay motionless (copy the position of the succ node, use something like optical flow)
76
- def add_branch(
166
+ @modifier
167
+ def add_chain(
77
168
  self,
78
- pred: int,
169
+ node: int,
79
170
  length: int,
80
- move_timepoints: bool = True,
81
- pos: Union[callable, None] = None,
82
- reverse: bool = False,
83
- ):
84
- """Adds a branch of specific length to a node either as a successor or as a predecessor.
171
+ downstream: bool,
172
+ pos: Callable | None = None,
173
+ ) -> int:
174
+ """Adds a chain of specific length to a node either as a successor or as a predecessor.
85
175
  If it is placed on top of a tree all the nodes will move timepoints #length down.
86
176
 
87
- Args:
88
- pred (int): Id of the successor (predecessor if reverse is False)
89
- length (int): The length of the new branch.
90
- pos (np.ndarray, optional): The new position of the branch. Defaults to None.
91
- move_timepoints (bool): Moves the ti Important only if reverse= True
92
- reverese (bool): If True will create a branch that goes forwards in time otherwise backwards.
93
- Returns:
94
- (int): Id of the first node of the sublineage.
177
+ Parameters
178
+ ----------
179
+ node : int
180
+ Id of the successor (predecessor if `downstream==False`)
181
+ length : int
182
+ The length of the new chain.
183
+ downstream : bool, default=True
184
+ If `True` will create a chain that goes forwards in time otherwise backwards.
185
+ pos : np.ndarray, optional
186
+ The new position of the chain. Defaults to None.
187
+
188
+ Returns
189
+ -------
190
+ int
191
+ Id of the first node of the sublineage.
95
192
  """
96
193
  if length == 0:
97
- return pred
98
- if self.predecessor.get(pred) and not reverse:
99
- raise Warning("Cannot add 2 predecessors to a node")
100
- time = self.time[pred]
101
- original = pred
102
- if not reverse:
103
- if move_timepoints:
104
- nodes_to_move = set(self.get_sub_tree(pred))
105
- new_times = {
106
- node: self.time[node] + length for node in nodes_to_move
107
- }
108
- for node in nodes_to_move:
109
- old_time = self.time[node]
110
- self.time_nodes[old_time].remove(node)
111
- self.time_nodes.setdefault(old_time + length, set()).add(
112
- node
113
- )
114
- self.time.update(new_times)
115
- for t in range(length - 1, -1, -1):
116
- _next = self.add_node(
117
- time + t,
118
- succ=pred,
119
- pos=self.pos[original],
120
- reverse=True,
121
- )
122
- pred = _next
123
- else:
124
- for t in range(length):
125
- _next = self.add_node(
126
- time - t,
127
- succ=pred,
128
- pos=self.pos[original],
129
- reverse=True,
130
- )
131
- pred = _next
194
+ return node
195
+ if length < 1:
196
+ raise ValueError("Length cannot be <1")
197
+ if downstream:
198
+ for _ in range(int(length)):
199
+ old_node = node
200
+ node = self._add_node(pred=[old_node])
201
+ self._time[node] = self._time[old_node] + 1
132
202
  else:
133
- for _ in range(length):
134
- _next = self.add_node(
135
- self.time[pred] + 1,
136
- succ=pred,
137
- pos=self.pos[original],
138
- reverse=False,
203
+ if self._predecessor[node]:
204
+ raise Warning("The node already has a predecessor.")
205
+ if self._time[node] - length < self.t_b:
206
+ raise Warning(
207
+ "A node cannot created outside the lower bound of the dataset. (It is possible to change it by lT.t_b = int(...))"
139
208
  )
140
- pred = _next
141
- self.labels[pred] = "New branch"
142
- if self.time[pred] == self.t_b:
143
- self.roots.add(pred)
144
- self.labels[pred] = "New branch"
145
- if original in self.roots and reverse is True:
146
- self.roots.add(pred)
147
- self.labels[pred] = "New branch"
148
- self.roots.remove(original)
149
- self.labels.pop(original, -1)
150
- self.t_e = max(self.time_nodes)
151
- return pred
152
-
153
- def cut_tree(self, root):
154
- """It transforms a lineage that has at least 2 divisions into 2 independent lineages,
155
- that spawn from the time point of the first node. (splits a tree into 2)
156
-
157
- Args:
158
- root (int): The id of the node, which will be cut.
159
-
160
- Returns:
161
- int: The id of the new tree
162
- """
163
- cycle = self.get_successors(root)
164
- last_cell = cycle[-1]
165
- if last_cell in self.successor:
166
- new_lT = self.successor[last_cell].pop()
167
- self.predecessor.pop(new_lT)
168
- label_of_root = self.labels.get(cycle[0], cycle[0])
169
- self.labels[cycle[0]] = f"L-Split {label_of_root}"
170
- new_tr = self.add_branch(
171
- new_lT, len(cycle) + 1, move_timepoints=False
172
- )
173
- self.roots.add(new_tr)
174
- self.labels[new_tr] = f"R-Split {label_of_root}"
175
- return new_tr
176
- else:
177
- raise Warning("No division of the branch")
178
-
179
- def fuse_lineage_tree(
180
- self,
181
- l1_root: int,
182
- l2_root: int,
183
- length_l1: int = 0,
184
- length_l2: int = 0,
185
- length: int = 1,
186
- ):
187
- """Fuses 2 lineages from the lineagetree object. The 2 lineages that are to be fused can have a longer
188
- first node and the node of the resulting lineage can also be longer.
189
-
190
- Args:
191
- l1_root (int): Id of the first root
192
- l2_root (int): Id of the second root
193
- length_l1 (int, optional): The length of the branch that will be added on top of the first lineage. Defaults to 0, which means only one node will be added.
194
- length_l2 (int, optional): The length of the branch that will be added on top of the second lineage. Defaults to 0, which means only one node will be added.
195
- length (int, optional): The length of the branch that will be added on top of the resulting lineage. Defaults to 1.
196
-
197
- Returns:
198
- int: The id of the root of the new lineage.
199
- """
200
- if self.predecessor.get(l1_root) or self.predecessor.get(l2_root):
201
- raise ValueError("Please select 2 roots.")
202
- if self.time[l1_root] != self.time[l2_root]:
203
- warnings.warn(
204
- "Using lineagetrees that do not exist in the same timepoint. The operation will continue"
205
- )
206
- new_root1 = self.add_branch(l1_root, length_l1)
207
- new_root2 = self.add_branch(l2_root, length_l2)
208
- next_root1 = self[new_root1][0]
209
- self.remove_nodes(new_root1)
210
- self.successor[new_root2].append(next_root1)
211
- self.predecessor[next_root1] = [new_root2]
212
- new_branch = self.add_branch(new_root2, length)
213
- self.labels[new_branch] = f"Fusion of {new_root1} and {new_root2}"
214
- return new_branch
215
-
216
- def copy_lineage(self, root):
217
- """
218
- Copies the structure of a tree and makes a new with new nodes.
219
- Warning does not take into account the predecessor of the root node.
220
-
221
- Args:
222
- root (int): The root of the tree to be copied
209
+ for _ in range(int(length)):
210
+ old_node = node
211
+ node = self._add_node(succ=[old_node])
212
+ self._time[node] = self._time[old_node] - 1
213
+ return node
214
+
215
+ @modifier
216
+ def add_root(self, t: int, pos: list | None = None) -> int:
217
+ """Adds a root to a specific timepoint.
218
+
219
+ Parameters
220
+ ----------
221
+ t :int
222
+ The timepoint the node is going to be added.
223
+ pos : list
224
+ The position of the new node.
225
+ Returns
226
+ -------
227
+ int
228
+ The id of the new root.
229
+ """
230
+ C_next = self.get_next_id()
231
+ self._successor[C_next] = ()
232
+ self._predecessor[C_next] = ()
233
+ self._time[C_next] = t
234
+ self.pos[C_next] = pos if isinstance(pos, list) else []
235
+ self._changed_roots = True
236
+ return C_next
223
237
 
224
- Returns:
225
- int: The root of the new tree.
226
- """
227
- new_nodes = {
228
- old_node: self.get_next_id()
229
- for old_node in self.get_sub_tree(root)
230
- }
231
- self.nodes.update(new_nodes.values())
232
- for old_node, new_node in new_nodes.items():
233
- self.time[new_node] = self.time[old_node]
234
- succ = self.successor.get(old_node)
235
- if succ:
236
- self.successor[new_node] = [new_nodes[n] for n in succ]
237
- pred = self.predecessor.get(old_node)
238
- if pred:
239
- self.predecessor[new_node] = [new_nodes[n] for n in pred]
240
- self.pos[new_node] = self.pos[old_node] + 0.5
241
- self.time_nodes[self.time[old_node]].add(new_nodes[old_node])
242
- new_root = new_nodes[root]
243
- self.labels[new_root] = f"Copy of {root}"
244
- if self.time[new_root] == 0:
245
- self.roots.add(new_root)
246
- return new_root
247
-
248
- def add_node(
238
+ def _add_node(
249
239
  self,
250
- t: int = None,
251
- succ: int = None,
252
- pos: np.ndarray = None,
253
- nid: int = None,
254
- reverse: bool = False,
240
+ succ: list | None = None,
241
+ pred: list | None = None,
242
+ pos: np.ndarray | None = None,
243
+ nid: int | None = None,
255
244
  ) -> int:
256
- """Adds a node to the lineageTree and update it accordingly.
257
-
258
- Args:
259
- t (int): int, time to which to add the node
260
- succ (int): id of the node the new node is a successor to
261
- pos ([float, ]): list of three floats representing the 3D
262
- spatial position of the node
263
- nid (int): id value of the new node, to be used carefully,
264
- if None is provided the new id is automatically computed.
265
- reverse (bool): True if in this lineageTree the predecessors
266
- are the successors and reciprocally.
267
- This is there for bacward compatibility, should be left at False.
268
- Returns:
269
- int: id of the new node.
270
- """
245
+ """Adds a node to the LineageTree object that is either a successor or a predecessor of another node.
246
+ Does not handle time! You cannot enter both a successor and a predecessor.
247
+
248
+ Parameters
249
+ ----------
250
+ succ : list
251
+ list of ids of the nodes the new node is a successor to
252
+ pred : list
253
+ list of ids of the nodes the new node is a predecessor to
254
+ pos : np.ndarray, optional
255
+ position of the new node
256
+ nid : int, optional
257
+ id value of the new node, to be used carefully,
258
+ if None is provided the new id is automatically computed.
259
+
260
+ Returns
261
+ -------
262
+ int
263
+ id of the new node.
264
+ """
265
+ if not succ and not pred:
266
+ raise Warning(
267
+ "Please enter a successor or a predecessor, otherwise use the add_roots() function."
268
+ )
271
269
  C_next = self.get_next_id() if nid is None else nid
272
- self.time_nodes.setdefault(t, set()).add(C_next)
273
- if succ is not None and not reverse:
274
- self.successor.setdefault(succ, []).append(C_next)
275
- self.predecessor.setdefault(C_next, []).append(succ)
276
- elif succ is not None:
277
- self.predecessor.setdefault(succ, []).append(C_next)
278
- self.successor.setdefault(C_next, []).append(succ)
279
- self.nodes.add(C_next)
280
- self.pos[C_next] = pos
281
- self.time[C_next] = t
270
+ if succ:
271
+ self._successor[C_next] = succ
272
+ for suc in succ:
273
+ self._predecessor[suc] = (C_next,)
274
+ else:
275
+ self._successor[C_next] = ()
276
+ if pred:
277
+ self._predecessor[C_next] = pred
278
+ self._successor[pred[0]] = self._successor.setdefault(
279
+ pred[0], ()
280
+ ) + (C_next,)
281
+ else:
282
+ self._predecessor[C_next] = ()
283
+ if isinstance(pos, list):
284
+ self.pos[C_next] = pos
282
285
  return C_next
283
286
 
284
- def remove_nodes(self, group: Union[int, set, list]):
287
+ @modifier
288
+ def remove_nodes(self, group: int | set | list) -> None:
285
289
  """Removes a group of nodes from the LineageTree
286
290
 
287
- Args:
288
- group (set|list|int): One or more nodes that are to be removed.
291
+ Parameters
292
+ ----------
293
+ group : set of int or list of int or int
294
+ One or more nodes that are to be removed.
289
295
  """
290
- if isinstance(group, int):
296
+ if isinstance(group, int | float):
291
297
  group = {group}
292
298
  if isinstance(group, list):
293
299
  group = set(group)
294
- group = group.intersection(self.nodes)
295
- self.nodes.difference_update(group)
296
- times = {self.time.pop(n) for n in group}
297
- for t in times:
298
- self.time_nodes[t] = set(self.time_nodes[t]).difference(group)
300
+ group = self.nodes.intersection(group)
299
301
  for node in group:
300
- self.pos.pop(node)
301
- if self.predecessor.get(node):
302
- pred = self.predecessor[node][0]
303
- siblings = self.successor.pop(pred, [])
304
- if len(siblings) == 2:
305
- siblings.remove(node)
306
- self.successor[pred] = siblings
307
- self.predecessor.pop(node, [])
308
- for succ in self.successor.get(node, []):
309
- self.predecessor.pop(succ, [])
310
- self.successor.pop(node, [])
311
- self.labels.pop(node, 0)
312
- if node in self.roots:
313
- self.roots.remove(node)
314
-
315
- def modify_branch(self, node, new_length):
316
- """Changes the length of a branch, so it adds or removes nodes
317
- to make the correct length of the cycle.
318
-
319
- Args:
320
- node (int): Any node of the branch to be modified/
321
- new_length (int): The new length of the tree.
322
- """
323
- if new_length <= 1:
324
- warnings.warn("New length should be more than 1")
325
- return None
326
- cycle = self.get_cycle(node)
327
- length = len(cycle)
328
- successors = self.successor.get(cycle[-1])
329
- if length == 1 and new_length != 1:
330
- pred = self.predecessor.pop(node, None)
331
- new_node = self.add_branch(
332
- node,
333
- length=new_length - 1,
334
- move_timepoints=True,
335
- reverse=False,
336
- )
337
- if pred:
338
- self.successor[pred[0]].remove(node)
339
- self.successor[pred[0]].append(new_node)
340
- elif self.leaves.intersection(cycle) and new_length < length:
341
- self.remove_nodes(cycle[new_length:])
342
- elif new_length < length:
343
- to_remove = length - new_length
344
- last_cell = cycle[new_length - 1]
345
- subtree = self.get_sub_tree(cycle[-1])[1:]
346
- self.remove_nodes(cycle[new_length:])
347
- self.successor[last_cell] = successors
348
- if successors:
349
- for succ in successors:
350
- self.predecessor[succ] = [last_cell]
351
- for node in subtree:
352
- if node not in cycle[new_length - 1 :]:
353
- old_time = self.time[node]
354
- self.time[node] = old_time - to_remove
355
- self.time_nodes[old_time].remove(node)
356
- self.time_nodes.setdefault(
357
- old_time - to_remove, set()
358
- ).add(node)
359
- elif length < new_length:
360
- to_add = new_length - length
361
- last_cell = cycle[-1]
362
- self.successor.pop(cycle[-2])
363
- self.predecessor.pop(last_cell)
364
- succ = self.add_branch(
365
- last_cell, length=to_add, move_timepoints=True, reverse=False
366
- )
367
- self.predecessor[succ] = [cycle[-2]]
368
- self.successor[cycle[-2]] = [succ]
369
- self.time[last_cell] = (
370
- self.time[self.predecessor[last_cell][0]] + 1
371
- )
372
- else:
373
- return None
374
-
375
- @property
376
- def time_resolution(self):
377
- if not hasattr(self, "_time_resolution"):
378
- self.time_resolution = 1
379
- return self._time_resolution / 10
380
-
381
- @time_resolution.setter
382
- def time_resolution(self, time_resolution):
383
- if time_resolution is not None:
384
- self._time_resolution = int(time_resolution * 10)
385
- else:
386
- warnings.warn("Time resolution set to default 1", stacklevel=2)
387
- self._time_resolution = 10
302
+ for attr in self.__dict__:
303
+ attr_value = self.__getattribute__(attr)
304
+ if isinstance(attr_value, dict) and attr not in [
305
+ "successor",
306
+ "predecessor",
307
+ "_successor",
308
+ "_predecessor",
309
+ "_time",
310
+ ]:
311
+ attr_value.pop(node, ())
312
+ if self._predecessor.get(node):
313
+ self._successor[self._predecessor[node][0]] = tuple(
314
+ set(
315
+ self._successor[self._predecessor[node][0]]
316
+ ).difference(group)
317
+ )
318
+ for p_node in self._successor.get(node, []):
319
+ self._predecessor[p_node] = ()
320
+ self._predecessor.pop(node, ())
321
+ self._successor.pop(node, ())
388
322
 
389
323
  @property
390
- def roots(self):
391
- if not hasattr(self, "_roots"):
392
- self._roots = set(self.nodes).difference(self.predecessor)
393
- return self._roots
324
+ def successor(self) -> MappingProxyType[int, tuple[int]]:
325
+ """The successor of the tree."""
326
+ if not hasattr(self, "_protected_successor"):
327
+ self._protected_successor = MappingProxyType(self._successor)
328
+ return self._protected_successor
394
329
 
395
330
  @property
396
- def edges(self):
397
- return {(k, vi) for k, v in self.successor.items() for vi in v}
331
+ def predecessor(self) -> MappingProxyType[int, tuple[int]]:
332
+ """The predecessor of the tree."""
333
+ if not hasattr(self, "_protected_predecessor"):
334
+ self._protected_predecessor = MappingProxyType(self._predecessor)
335
+ return self._protected_predecessor
398
336
 
399
337
  @property
400
- def leaves(self):
401
- return set(self.predecessor).difference(self.successor)
338
+ def time(self) -> MappingProxyType[int, int]:
339
+ """The time of the tree."""
340
+ if not hasattr(self, "_protected_time"):
341
+ self._protected_time = MappingProxyType(self._time)
342
+ return self._protected_time
343
+
344
+ @dynamic_property
345
+ def t_b(self) -> int:
346
+ """The first timepoint of the tree."""
347
+ return min(self._time.values())
348
+
349
+ @dynamic_property
350
+ def t_e(self) -> int:
351
+ """The last timepoint of the tree."""
352
+ return max(self._time.values())
353
+
354
+ @dynamic_property
355
+ def nodes(self) -> frozenset[int]:
356
+ """Nodes of the tree"""
357
+ return frozenset(self._successor.keys())
358
+
359
+ @dynamic_property
360
+ def depth(self) -> dict[int, int]:
361
+ """The depth of each node in the tree."""
362
+ _depth = {}
363
+ for leaf in self.leaves:
364
+ _depth[leaf] = 1
365
+ while leaf in self._predecessor and self._predecessor[leaf]:
366
+ parent = self._predecessor[leaf][0]
367
+ current_depth = _depth.get(parent, 0)
368
+ _depth[parent] = max(_depth[leaf] + 1, current_depth)
369
+ leaf = parent
370
+ for root in self.roots - set(_depth):
371
+ _depth[root] = 1
372
+ return _depth
373
+
374
+ @dynamic_property
375
+ def roots(self) -> frozenset[int]:
376
+ """Set of roots of the tree"""
377
+ return frozenset({s for s, p in self._predecessor.items() if p == ()})
378
+
379
+ @dynamic_property
380
+ def leaves(self) -> frozenset[int]:
381
+ """Set of leaves"""
382
+ return frozenset({p for p, s in self._successor.items() if s == ()})
383
+
384
+ @dynamic_property
385
+ def edges(self) -> tuple[tuple[int, int]]:
386
+ """Set of edges"""
387
+ return tuple((p, si) for p, s in self._successor.items() for si in s)
402
388
 
403
389
  @property
404
- def labels(self):
390
+ def labels(self) -> dict[int, str]:
391
+ """The labels of the nodes."""
405
392
  if not hasattr(self, "_labels"):
406
- if hasattr(self, "cell_name"):
393
+ if hasattr(self, "node_name"):
407
394
  self._labels = {
408
- i: self.cell_name.get(i, "Unlabeled") for i in self.roots
395
+ i: self.node_name.get(i, "Unlabeled") for i in self.roots
409
396
  }
410
397
  else:
411
398
  self._labels = {
412
- i: "Unlabeled"
413
- for i in self.roots
414
- for l in self.find_leaves(i)
415
- if abs(self.time[l] - self.time[i])
399
+ root: "Unlabeled"
400
+ for root in self.roots
401
+ for leaf in self.find_leaves(root)
402
+ if abs(self._time[leaf] - self._time[root])
416
403
  >= abs(self.t_e - self.t_b) / 4
417
404
  }
418
405
  return self._labels
419
406
 
420
- def _write_header_am(self, f: TextIO, nb_points: int, length: int):
421
- """Header for Amira .am files"""
422
- f.write("# AmiraMesh 3D ASCII 2.0\n")
423
- f.write("define VERTEX %d\n" % (nb_points * 2))
424
- f.write("define EDGE %d\n" % nb_points)
425
- f.write("define POINT %d\n" % ((length) * nb_points))
426
- f.write("Parameters {\n")
427
- f.write('\tContentType "HxSpatialGraph"\n')
428
- f.write("}\n")
429
-
430
- f.write("VERTEX { float[3] VertexCoordinates } @1\n")
431
- f.write("EDGE { int[2] EdgeConnectivity } @2\n")
432
- f.write("EDGE { int NumEdgePoints } @3\n")
433
- f.write("POINT { float[3] EdgePointCoordinates } @4\n")
434
- f.write("VERTEX { float Vcolor } @5\n")
435
- f.write("VERTEX { int Vbool } @6\n")
436
- f.write("EDGE { float Ecolor } @7\n")
437
- f.write("VERTEX { int Vbool2 } @8\n")
438
-
439
- def write_to_am(
440
- self,
441
- path_format: str,
442
- t_b: int = None,
443
- t_e: int = None,
444
- length: int = 5,
445
- manual_labels: dict = None,
446
- default_label: int = 5,
447
- new_pos: np.ndarray = None,
448
- ):
449
- """Writes a lineageTree into an Amira readable data (.am format).
450
-
451
- Args:
452
- path_format (str): path to the output. It should contain 1 %03d where the time step will be entered
453
- t_b (int): first time point to write (if None, min(LT.to_take_time) is taken)
454
- t_e (int): last time point to write (if None, max(LT.to_take_time) is taken)
455
- note, if there is no 'to_take_time' attribute, self.time_nodes
456
- is considered instead (historical)
457
- length (int): length of the track to print (how many time before).
458
- manual_labels ({id: label, }): dictionary that maps cell ids to
459
- default_label (int): default value for the manual label
460
- new_pos ({id: [x, y, z]}): dictionary that maps a 3D position to a cell ID.
461
- if new_pos == None (default) then self.pos is considered.
462
- """
463
- if not hasattr(self, "to_take_time"):
464
- self.to_take_time = self.time_nodes
465
- if t_b is None:
466
- t_b = min(self.to_take_time.keys())
467
- if t_e is None:
468
- t_e = max(self.to_take_time.keys())
469
- if new_pos is None:
470
- new_pos = self.pos
471
-
472
- if manual_labels is None:
473
- manual_labels = {}
474
- for t in range(t_b, t_e + 1):
475
- with open(path_format % t, "w") as f:
476
- nb_points = len(self.to_take_time[t])
477
- self._write_header_am(f, nb_points, length)
478
- points_v = {}
479
- for C in self.to_take_time[t]:
480
- C_tmp = C
481
- positions = []
482
- for _ in range(length):
483
- C_tmp = self.predecessor.get(C_tmp, [C_tmp])[0]
484
- positions.append(new_pos[C_tmp])
485
- points_v[C] = positions
486
-
487
- f.write("@1\n")
488
- for C in self.to_take_time[t]:
489
- f.write("{:f} {:f} {:f}\n".format(*tuple(points_v[C][0])))
490
- f.write("{:f} {:f} {:f}\n".format(*tuple(points_v[C][-1])))
491
-
492
- f.write("@2\n")
493
- for i, _ in enumerate(self.to_take_time[t]):
494
- f.write("%d %d\n" % (2 * i, 2 * i + 1))
495
-
496
- f.write("@3\n")
497
- for _ in self.to_take_time[t]:
498
- f.write("%d\n" % (length))
499
-
500
- f.write("@4\n")
501
- for C in self.to_take_time[t]:
502
- for p in points_v[C]:
503
- f.write("{:f} {:f} {:f}\n".format(*tuple(p)))
504
-
505
- f.write("@5\n")
506
- for C in self.to_take_time[t]:
507
- f.write("%f\n" % (manual_labels.get(C, default_label)))
508
- f.write(f"{0:f}\n")
509
-
510
- f.write("@6\n")
511
- for C in self.to_take_time[t]:
512
- f.write(
513
- "%d\n"
514
- % (
515
- int(
516
- manual_labels.get(C, default_label)
517
- != default_label
518
- )
519
- )
520
- )
521
- f.write("%d\n" % (0))
522
-
523
- f.write("@7\n")
524
- for C in self.to_take_time[t]:
525
- f.write(
526
- "%f\n"
527
- % (np.linalg.norm(points_v[C][0] - points_v[C][-1]))
528
- )
529
-
530
- f.write("@8\n")
531
- for _ in self.to_take_time[t]:
532
- f.write("%d\n" % (1))
533
- f.write("%d\n" % (0))
534
- f.close()
407
+ @property
408
+ def time_resolution(self) -> float:
409
+ if not hasattr(self, "_time_resolution"):
410
+ self._time_resolution = 0
411
+ return self._time_resolution / 10
535
412
 
536
- def _get_height(self, c: int, done: dict):
537
- """Recursively computes the height of a cell within a tree * a space factor.
413
+ @time_resolution.setter
414
+ def time_resolution(self, time_resolution) -> None:
415
+ if time_resolution is not None and time_resolution > 0:
416
+ self._time_resolution = int(time_resolution * 10)
417
+ else:
418
+ warnings.warn("Time resolution set to default 0", stacklevel=2)
419
+ self._time_resolution = 0
420
+
421
+ def __setstate__(self, state):
422
+ if "_successor" not in state:
423
+ state["_successor"] = state["successor"]
424
+ if "_predecessor" not in state:
425
+ state["_predecessor"] = state["predecessor"]
426
+ if "_time" not in state:
427
+ state["_time"] = state["time"]
428
+ self.__dict__.update(state)
429
+
430
+ def _get_height(self, c: int, done: dict) -> float:
431
+ """Recursively computes the height of a node within a tree times a space factor.
538
432
  This function is specific to the function write_to_svg.
539
433
 
540
- Args:
541
- c (int): id of a cell in a lineage tree from which the height will be computed from
542
- done ({int: [int, int]}): a dictionary that maps a cell id to its vertical and horizontal position
543
- Returns:
544
- float:
434
+ Parameters
435
+ ----------
436
+ c : int
437
+ id of a node in a lineage tree from which the height will be computed from
438
+ done : dict mapping int to list of two int
439
+ a dictionary that maps a node id to its vertical and horizontal position
440
+
441
+ Returns
442
+ -------
443
+ float
444
+ the height of the node `c`
545
445
  """
546
446
  if c in done:
547
447
  return done[c][0]
548
448
  else:
549
449
  P = np.mean(
550
- [self._get_height(di, done) for di in self.successor[c]]
450
+ [self._get_height(di, done) for di in self._successor[c]]
551
451
  )
552
- done[c] = [P, self.vert_space_factor * self.time[c]]
452
+ done[c] = [P, self.vert_space_factor * self._time[c]]
553
453
  return P
554
454
 
555
455
  def write_to_svg(
556
456
  self,
557
457
  file_name: str,
558
- roots: list = None,
458
+ roots: list | None = None,
559
459
  draw_nodes: bool = True,
560
460
  draw_edges: bool = True,
561
- order_key: callable = None,
461
+ order_key: Callable | None = None,
562
462
  vert_space_factor: float = 0.5,
563
463
  horizontal_space: float = 1,
564
- node_size: callable = None,
565
- stroke_width: callable = None,
464
+ node_size: Callable | str | None = None,
465
+ stroke_width: Callable | None = None,
566
466
  factor: float = 1.0,
567
- node_color: callable = None,
568
- stroke_color: callable = None,
569
- positions: dict = None,
570
- node_color_map: callable = None,
571
- normalize: bool = True,
572
- ):
573
- ##### remove background? default True background value? default 1
574
-
467
+ node_color: Callable | str | None = None,
468
+ stroke_color: Callable | None = None,
469
+ positions: dict | None = None,
470
+ node_color_map: Callable | str | None = None,
471
+ ) -> None:
575
472
  """Writes the lineage tree to an SVG file.
576
473
  Node and edges coloring and size can be provided.
577
474
 
578
- Args:
579
- file_name (str): filesystem filename valid for `open()`
580
- roots ([int, ...]): list of node ids to be drawn. If `None` all the nodes will be drawn. Default `None`
581
- draw_nodes (bool): wether to print the nodes or not, default `True`
582
- draw_edges (bool): wether to print the edges or not, default `True`
583
- order_key (callable): function that would work for the attribute `key=` for the `sort`/`sorted` function
584
- vert_space_factor (float): the vertical position of a node is its time. `vert_space_factor` is a
585
- multiplier to space more or less nodes in time
586
- horizontal_space (float): space between two consecutive nodes
587
- node_size (callable | str): a function that maps a node id to a `float` value that will determine the
588
- radius of the node. The default function return the constant value `vertical_space_factor/2.1`
589
- If a string is given instead and it is a property of the tree,
590
- the the size will be mapped according to the property
591
- stroke_width (callable): a function that maps a node id to a `float` value that will determine the
592
- width of the daughter edge. The default function return the constant value `vertical_space_factor/2.1`
593
- factor (float): scaling factor for nodes positions, default 1
594
- node_color (callable | str): a function that maps a node id to a triplet between 0 and 255.
595
- The triplet will determine the color of the node. If a string is given instead and it is a property
596
- of the tree, the the color will be mapped according to the property
597
- node_color_map (callable | str): the name of the colormap to use to color the nodes, or a colormap function
598
- stroke_color (callable): a function that maps a node id to a triplet between 0 and 255.
599
- The triplet will determine the color of the stroke of the inward edge.
600
- positions ({int: [float, float], ...}): dictionary that maps a node id to a 2D position.
601
- Default `None`. If provided it will be used to position the nodes.
475
+ Parameters
476
+ ----------
477
+ file_name : str
478
+ filesystem filename valid for `open()`
479
+ roots : list of int, defaults to `self.roots`
480
+ list of node ids to be drawn. If `None` or not provided all the nodes will be drawn. Default `None`
481
+ draw_nodes : bool, default True
482
+ wether to print the nodes or not
483
+ draw_edges : bool, default True
484
+ wether to print the edges or not
485
+ order_key : Callable, optional
486
+ function that would work for the attribute `key=` for the `sort`/`sorted` function
487
+ vert_space_factor : float, default=0.5
488
+ the vertical position of a node is its time. `vert_space_factor` is a
489
+ multiplier to space more or less nodes in time
490
+ horizontal_space : float, default=1
491
+ space between two consecutive nodes
492
+ node_size : Callable or str, optional
493
+ a function that maps a node id to a `float` value that will determine the
494
+ radius of the node. The default function return the constant value `vertical_space_factor/2.1`
495
+ If a string is given instead and it is a property of the tree,
496
+ the the size will be mapped according to the property
497
+ stroke_width : Callable, optional
498
+ a function that maps a node id to a `float` value that will determine the
499
+ width of the daughter edge. The default function return the constant value `vertical_space_factor/2.1`
500
+ factor : float, default=1.0
501
+ scaling factor for nodes positions, default 1
502
+ node_color : Callable or str, optional
503
+ a function that maps a node id to a triplet between 0 and 255.
504
+ The triplet will determine the color of the node. If a string is given instead and it is a property
505
+ of the tree, the the color will be mapped according to the property
506
+ node_color_map : Callable or str, optional
507
+ the name of the colormap to use to color the nodes, or a colormap function
508
+ stroke_color : Callable, optional
509
+ a function that maps a node id to a triplet between 0 and 255.
510
+ The triplet will determine the color of the stroke of the inward edge.
511
+ positions : dict mapping int to list of two float, optional
512
+ dictionary that maps a node id to a 2D position.
513
+ Default `None`. If provided it will be used to position the nodes.
602
514
  """
603
- import svgwrite
604
515
 
605
516
  def normalize_values(v, nodes, _range, shift, mult):
606
517
  min_ = np.percentile(v, 1)
607
518
  max_ = np.percentile(v, 99)
608
519
  values = _range * ((v - min_) / (max_ - min_)) + shift
609
- values_dict_nodes = dict(zip(nodes, values))
520
+ values_dict_nodes = dict(zip(nodes, values, strict=True))
610
521
  return lambda x: values_dict_nodes[x] * mult
611
522
 
612
523
  if roots is None:
613
524
  roots = self.roots
614
525
  if hasattr(self, "image_label"):
615
- roots = [cell for cell in roots if self.image_label[cell] != 1]
526
+ roots = [node for node in roots if self.image_label[node] != 1]
616
527
 
617
528
  if node_size is None:
618
529
 
619
530
  def node_size(x):
620
531
  return vert_space_factor / 2.1
621
532
 
622
- elif isinstance(node_size, str) and node_size in self.__dict__:
623
- values = np.array([self[node_size][c] for c in self.nodes])
533
+ else:
534
+ values = np.array(
535
+ [self._successor[node_size][c] for c in self.nodes]
536
+ )
624
537
  node_size = normalize_values(
625
538
  values, self.nodes, 0.5, 0.5, vert_space_factor / 2.1
626
539
  )
@@ -635,18 +548,19 @@ class lineageTree(lineageTreeLoaders):
635
548
  return 0, 0, 0
636
549
 
637
550
  elif isinstance(node_color, str) and node_color in self.__dict__:
638
- if isinstance(node_color_map, str):
639
- from matplotlib import colormaps
551
+ from matplotlib import colormaps
640
552
 
641
- if node_color_map in colormaps:
642
- node_color_map = colormaps[node_color_map]
643
- else:
644
- node_color_map = colormaps["viridis"]
645
- values = np.array([self[node_color][c] for c in self.nodes])
553
+ if node_color_map in colormaps:
554
+ cmap = colormaps[node_color_map]
555
+ else:
556
+ cmap = colormaps["viridis"]
557
+ values = np.array(
558
+ [self._successor[node_color][c] for c in self.nodes]
559
+ )
646
560
  normed_vals = normalize_values(values, self.nodes, 1, 0, 1)
647
561
 
648
562
  def node_color(x):
649
- return [k * 255 for k in node_color_map(normed_vals(x))[:-1]]
563
+ return [k * 255 for k in cmap(normed_vals(x))[:-1]]
650
564
 
651
565
  coloring_edges = stroke_color is not None
652
566
  if not coloring_edges:
@@ -655,24 +569,25 @@ class lineageTree(lineageTreeLoaders):
655
569
  return 0, 0, 0
656
570
 
657
571
  elif isinstance(stroke_color, str) and stroke_color in self.__dict__:
658
- if isinstance(node_color_map, str):
659
- from matplotlib import colormaps
572
+ from matplotlib import colormaps
660
573
 
661
- if node_color_map in colormaps:
662
- node_color_map = colormaps[node_color_map]
663
- else:
664
- node_color_map = colormaps["viridis"]
665
- values = np.array([self[stroke_color][c] for c in self.nodes])
574
+ if node_color_map in colormaps:
575
+ cmap = colormaps[node_color_map]
576
+ else:
577
+ cmap = colormaps["viridis"]
578
+ values = np.array(
579
+ [self._successor[stroke_color][c] for c in self.nodes]
580
+ )
666
581
  normed_vals = normalize_values(values, self.nodes, 1, 0, 1)
667
582
 
668
583
  def stroke_color(x):
669
- return [k * 255 for k in node_color_map(normed_vals(x))[:-1]]
584
+ return [k * 255 for k in cmap(normed_vals(x))[:-1]]
670
585
 
671
586
  prev_x = 0
672
587
  self.vert_space_factor = vert_space_factor
673
588
  if order_key is not None:
674
589
  roots.sort(key=order_key)
675
- treated_cells = []
590
+ treated_nodes = []
676
591
 
677
592
  pos_given = positions is not None
678
593
  if not pos_given:
@@ -683,25 +598,26 @@ class lineageTree(lineageTreeLoaders):
683
598
  [0.0, 0.0],
684
599
  ]
685
600
  * len(self.nodes),
686
- )
601
+ strict=True,
602
+ ),
687
603
  )
688
604
  for _i, r in enumerate(roots):
689
605
  r_leaves = []
690
606
  to_do = [r]
691
607
  while len(to_do) != 0:
692
608
  curr = to_do.pop(0)
693
- treated_cells += [curr]
694
- if curr in self.successor:
609
+ treated_nodes += [curr]
610
+ if not self._successor[curr]:
695
611
  if order_key is not None:
696
- to_do += sorted(self.successor[curr], key=order_key)
612
+ to_do += sorted(self._successor[curr], key=order_key)
697
613
  else:
698
- to_do += self.successor[curr]
614
+ to_do += self._successor[curr]
699
615
  else:
700
616
  r_leaves += [curr]
701
617
  r_pos = {
702
618
  leave: [
703
619
  prev_x + horizontal_space * (1 + j),
704
- self.vert_space_factor * self.time[leave],
620
+ self.vert_space_factor * self._time[leave],
705
621
  ]
706
622
  for j, leave in enumerate(r_leaves)
707
623
  }
@@ -716,12 +632,12 @@ class lineageTree(lineageTreeLoaders):
716
632
  size=factor * np.max(list(positions.values()), axis=0),
717
633
  )
718
634
  if draw_edges and not draw_nodes and not coloring_edges:
719
- to_do = set(treated_cells)
635
+ to_do = set(treated_nodes)
720
636
  while len(to_do) > 0:
721
637
  curr = to_do.pop()
722
- c_cycle = self.get_cycle(curr)
723
- x1, y1 = positions[c_cycle[0]]
724
- x2, y2 = positions[c_cycle[-1]]
638
+ c_chain = self.get_chain_of_node(curr)
639
+ x1, y1 = positions[c_chain[0]]
640
+ x2, y2 = positions[c_chain[-1]]
725
641
  dwg.add(
726
642
  dwg.line(
727
643
  (factor * x1, factor * y1),
@@ -729,7 +645,7 @@ class lineageTree(lineageTreeLoaders):
729
645
  stroke=svgwrite.rgb(0, 0, 0),
730
646
  )
731
647
  )
732
- for si in self[c_cycle[-1]]:
648
+ for si in self._successor[c_chain[-1]]:
733
649
  x3, y3 = positions[si]
734
650
  dwg.add(
735
651
  dwg.line(
@@ -738,11 +654,11 @@ class lineageTree(lineageTreeLoaders):
738
654
  stroke=svgwrite.rgb(0, 0, 0),
739
655
  )
740
656
  )
741
- to_do.difference_update(c_cycle)
657
+ to_do.difference_update(c_chain)
742
658
  else:
743
- for c in treated_cells:
659
+ for c in treated_nodes:
744
660
  x1, y1 = positions[c]
745
- for si in self[c]:
661
+ for si in self._successor[c]:
746
662
  x2, y2 = positions[si]
747
663
  if draw_edges:
748
664
  dwg.add(
@@ -753,7 +669,7 @@ class lineageTree(lineageTreeLoaders):
753
669
  stroke_width=svgwrite.pt(stroke_width(si)),
754
670
  )
755
671
  )
756
- for c in treated_cells:
672
+ for c in treated_nodes:
757
673
  x1, y1 = positions[c]
758
674
  if draw_nodes:
759
675
  dwg.add(
@@ -770,49 +686,58 @@ class lineageTree(lineageTreeLoaders):
770
686
  fname: str,
771
687
  t_min: int = -1,
772
688
  t_max: int = np.inf,
773
- nodes_to_use: list = None,
689
+ nodes_to_use: list[int] | None = None,
774
690
  temporal: bool = True,
775
- spatial: str = None,
691
+ spatial: str | None = None,
776
692
  write_layout: bool = True,
777
- node_properties: dict = None,
693
+ node_properties: dict | None = None,
778
694
  Names: bool = False,
779
- ):
695
+ ) -> None:
780
696
  """Write a lineage tree into an understable tulip file.
781
697
 
782
- Args:
783
- fname (str): path to the tulip file to create
784
- t_min (int): minimum time to consider, default -1
785
- t_max (int): maximum time to consider, default np.inf
786
- nodes_to_use ([int, ]): list of nodes to show in the graph,
787
- default *None*, then self.nodes is used
788
- (taking into account *t_min* and *t_max*)
789
- temporal (bool): True if the temporal links should be printed, default True
790
- spatial (str): Build spatial edges from a spatial neighbourhood graph.
791
- The graph has to be computed before running this function
792
- 'ball': neighbours at a given distance,
793
- 'kn': k-nearest neighbours,
794
- 'GG': gabriel graph,
795
- None: no spatial edges are writen.
796
- Default None
797
- write_layout (bool): True, write the spatial position as layout,
798
- False, do not write spatial positionm
799
- default True
800
- node_properties ({`p_name`, [{id, p_value}, default]}): a dictionary of properties to write
801
- To a key representing the name of the property is
802
- paired a dictionary that maps a cell id to a property
803
- and a default value for this property
804
- Names (bool): Only works with ASTEC outputs, True to sort the cells by their names
698
+ Parameters
699
+ ----------
700
+ fname : str
701
+ path to the tulip file to create
702
+ t_min : int, default=-1
703
+ minimum time to consider
704
+ t_max : int, default=np.inf
705
+ maximum time to consider
706
+ nodes_to_use : list of int, optional
707
+ list of nodes to show in the graph,
708
+ if `None` then self.nodes is used
709
+ (taking into account `t_min` and `t_max`)
710
+ temporal : bool, default=True
711
+ True if the temporal links should be printed
712
+ spatial : str, optional
713
+ Build spatial edges from a spatial neighbourhood graph.
714
+ The graph has to be computed before running this function
715
+ 'ball': neighbours at a given distance,
716
+ 'kn': k-nearest neighbours,
717
+ 'GG': gabriel graph,
718
+ None: no spatial edges are writen.
719
+ Default None
720
+ write_layout : bool, default=True
721
+ write the spatial position as layout if True
722
+ do not write spatial position otherwise
723
+ node_properties : dict mapping str to list of dict of properties and its default value, optional
724
+ a dictionary of properties to write
725
+ To a key representing the name of the property is
726
+ paired a dictionary that maps a node id to a property
727
+ and a default value for this property
728
+ Names : bool, default=True
729
+ Only works with ASTEC outputs, True to sort the nodes by their names
805
730
  """
806
731
 
807
732
  def format_names(names_which_matter):
808
- """Return an ensured formated cell names"""
733
+ """Return an ensured formated node names"""
809
734
  tmp = {}
810
735
  for k, v in names_which_matter.items():
811
736
  tmp[k] = (
812
737
  v.split(".")[0][0]
813
- + "%02d" % int(v.split(".")[0][1:])
738
+ + "{:02d}".format(int(v.split(".")[0][1:]))
814
739
  + "."
815
- + "%04d" % int(v.split(".")[1][:-1])
740
+ + "{:04d}".format(int(v.split(".")[1][:-1]))
816
741
  + v.split(".")[1][-1]
817
742
  )
818
743
  return tmp
@@ -839,20 +764,20 @@ class lineageTree(lineageTreeLoaders):
839
764
  if not nodes_to_use:
840
765
  if t_max != np.inf or t_min > -1:
841
766
  nodes_to_use = [
842
- n for n in self.nodes if t_min < self.time[n] <= t_max
767
+ n for n in self.nodes if t_min < self._time[n] <= t_max
843
768
  ]
844
769
  edges_to_use = []
845
770
  if temporal:
846
771
  edges_to_use += [
847
772
  e
848
773
  for e in self.edges
849
- if t_min < self.time[e[0]] < t_max
774
+ if t_min < self._time[e[0]] < t_max
850
775
  ]
851
776
  if spatial:
852
777
  edges_to_use += [
853
778
  e
854
779
  for e in s_edges
855
- if t_min < self.time[e[0]] < t_max
780
+ if t_min < self._time[e[0]] < t_max
856
781
  ]
857
782
  else:
858
783
  nodes_to_use = list(self.nodes)
@@ -866,12 +791,12 @@ class lineageTree(lineageTreeLoaders):
866
791
  nodes_to_use = set(nodes_to_use)
867
792
  if temporal:
868
793
  for n in nodes_to_use:
869
- for d in self.successor.get(n, []):
794
+ for d in self._successor[n]:
870
795
  if d in nodes_to_use:
871
796
  edges_to_use.append((n, d))
872
797
  if spatial:
873
798
  edges_to_use += [
874
- e for e in s_edges if t_min < self.time[e[0]] < t_max
799
+ e for e in s_edges if t_min < self._time[e[0]] < t_max
875
800
  ]
876
801
  nodes_to_use = set(nodes_to_use)
877
802
  if Names:
@@ -889,12 +814,12 @@ class lineageTree(lineageTreeLoaders):
889
814
  for k, v in node_properties[Names][0].items():
890
815
  if (
891
816
  len(
892
- self.successor.get(
893
- self.predecessor.get(k, [-1])[0], []
817
+ self._successor.get(
818
+ self._predecessor.get(k, [-1])[0], ()
894
819
  )
895
820
  )
896
821
  != 1
897
- or self.time[k] == t_min + 1
822
+ or self._time[k] == t_min + 1
898
823
  ):
899
824
  tmp_names[k] = v
900
825
  node_properties[Names][0] = tmp_names
@@ -924,7 +849,7 @@ class lineageTree(lineageTreeLoaders):
924
849
  f.write('\t(default "0" "0")\n')
925
850
  for n in nodes_to_use:
926
851
  f.write(
927
- "\t(node " + str(n) + ' "' + str(self.time[n]) + '")\n'
852
+ "\t(node " + str(n) + ' "' + str(self._time[n]) + '")\n'
928
853
  )
929
854
  f.write(")\n")
930
855
 
@@ -953,10 +878,10 @@ class lineageTree(lineageTreeLoaders):
953
878
  if node_properties:
954
879
  for p_name, (p_dict, default) in node_properties.items():
955
880
  if isinstance(list(p_dict.values())[0], str):
956
- f.write('(property 0 string "%s"\n' % p_name)
881
+ f.write(f'(property 0 string "{p_name}"\n')
957
882
  f.write(f"\t(default {default} {default})\n")
958
883
  elif isinstance(list(p_dict.values())[0], Number):
959
- f.write('(property 0 double "%s"\n' % p_name)
884
+ f.write(f'(property 0 double "{p_name}"\n')
960
885
  f.write('\t(default "0" "0")\n')
961
886
  for n in nodes_to_use:
962
887
  f.write(
@@ -971,7 +896,9 @@ class lineageTree(lineageTreeLoaders):
971
896
  f.write(")")
972
897
  f.close()
973
898
 
974
- def to_binary(self, fname: str, starting_points: list = None):
899
+ def to_binary(
900
+ self, fname: str, starting_points: list[int] | None = None
901
+ ) -> None:
975
902
  """Writes the lineage tree (a forest) as a binary structure
976
903
  (assuming it is a binary tree, it would not work for *n* ary tree with 2 < *n*).
977
904
  The binary file is composed of 3 sequences of numbers and
@@ -986,36 +913,37 @@ class lineageTree(lineageTreeLoaders):
986
913
  The *time_sequence* is stored as a list of unsigned short (0 -> 2^(8*2)-1)
987
914
  The *pos_sequence* is stored as a list of double.
988
915
 
989
- Args:
990
- fname (str): name of the binary file
991
- starting_points ([int, ]): list of the roots to be written.
992
- If None, all roots are written, default value, None
916
+ Parameters
917
+ ----------
918
+ fname : str
919
+ name of the binary file
920
+ starting_points : list of int, optional
921
+ list of the roots to be written.
922
+ If `None`, all roots are written, default value, None
993
923
  """
994
924
  if starting_points is None:
995
- starting_points = [
996
- c for c in self.successor if self.predecessor.get(c, []) == []
997
- ]
925
+ starting_points = list(self.roots)
998
926
  number_sequence = [-1]
999
927
  pos_sequence = []
1000
928
  time_sequence = []
1001
929
  for c in starting_points:
1002
- time_sequence.append(self.time.get(c, 0))
930
+ time_sequence.append(self._time.get(c, 0))
1003
931
  to_treat = [c]
1004
- while to_treat != []:
932
+ while to_treat:
1005
933
  curr_c = to_treat.pop()
1006
934
  number_sequence.append(curr_c)
1007
935
  pos_sequence += list(self.pos[curr_c])
1008
- if self[curr_c] == []:
936
+ if self._successor[curr_c] == ():
1009
937
  number_sequence.append(-1)
1010
- elif len(self.successor[curr_c]) == 1:
1011
- to_treat += self.successor[curr_c]
938
+ elif len(self._successor[curr_c]) == 1:
939
+ to_treat += self._successor[curr_c]
1012
940
  else:
1013
941
  number_sequence.append(-2)
1014
- to_treat += self.successor[curr_c]
942
+ to_treat += self._successor[curr_c]
1015
943
  remaining_nodes = set(self.nodes) - set(number_sequence)
1016
944
 
1017
945
  for c in remaining_nodes:
1018
- time_sequence.append(self.time.get(c, 0))
946
+ time_sequence.append(self._time.get(c, 0))
1019
947
  number_sequence.append(c)
1020
948
  pos_sequence += list(self.pos[c])
1021
949
  number_sequence.append(-1)
@@ -1030,204 +958,270 @@ class lineageTree(lineageTreeLoaders):
1030
958
 
1031
959
  f.close()
1032
960
 
1033
- def read_from_binary(self, fname: str):
1034
- """
1035
- Reads a binary lineageTree file name.
1036
- Format description: see self.to_binary
961
+ def write(self, fname: str) -> None:
962
+ """Write a lineage tree on disk as an .lT file.
1037
963
 
1038
- Args:
1039
- fname: string, path to the binary file
1040
- reverse_time: bool, not used
1041
- """
1042
- q_size = struct.calcsize("q")
1043
- H_size = struct.calcsize("H")
1044
- d_size = struct.calcsize("d")
1045
-
1046
- with open(fname, "rb") as f:
1047
- len_tree = struct.unpack("q", f.read(q_size))[0]
1048
- len_time = struct.unpack("q", f.read(q_size))[0]
1049
- len_pos = struct.unpack("q", f.read(q_size))[0]
1050
- number_sequence = list(
1051
- struct.unpack("q" * len_tree, f.read(q_size * len_tree))
1052
- )
1053
- time_sequence = list(
1054
- struct.unpack("H" * len_time, f.read(H_size * len_time))
1055
- )
1056
- pos_sequence = np.array(
1057
- struct.unpack("d" * len_pos, f.read(d_size * len_pos))
1058
- )
1059
-
1060
- f.close()
1061
-
1062
- successor = {}
1063
- predecessor = {}
1064
- time = {}
1065
- time_nodes = {}
1066
- time_edges = {}
1067
- pos = {}
1068
- is_root = {}
1069
- nodes = []
1070
- edges = []
1071
- waiting_list = []
1072
- print(number_sequence[0])
1073
- i = 0
1074
- done = False
1075
- if max(number_sequence[::2]) == -1:
1076
- tmp = number_sequence[1::2]
1077
- if len(tmp) * 3 == len(pos_sequence) == len(time_sequence) * 3:
1078
- time = dict(list(zip(tmp, time_sequence)))
1079
- for c, t in time.items():
1080
- time_nodes.setdefault(t, set()).add(c)
1081
- pos = dict(
1082
- list(zip(tmp, np.reshape(pos_sequence, (len_time, 3))))
1083
- )
1084
- is_root = {c: True for c in tmp}
1085
- nodes = tmp
1086
- done = True
1087
- while (
1088
- i < len(number_sequence) and not done
1089
- ): # , c in enumerate(number_sequence[:-1]):
1090
- c = number_sequence[i]
1091
- if c == -1:
1092
- if waiting_list != []:
1093
- prev_mother = waiting_list.pop()
1094
- successor[prev_mother].insert(0, number_sequence[i + 1])
1095
- edges.append((prev_mother, number_sequence[i + 1]))
1096
- time_edges.setdefault(t, set()).add(
1097
- (prev_mother, number_sequence[i + 1])
1098
- )
1099
- is_root[number_sequence[i + 1]] = False
1100
- t = time[prev_mother] + 1
1101
- else:
1102
- t = time_sequence.pop(0)
1103
- is_root[number_sequence[i + 1]] = True
1104
-
1105
- elif c == -2:
1106
- successor[waiting_list[-1]] = [number_sequence[i + 1]]
1107
- edges.append((waiting_list[-1], number_sequence[i + 1]))
1108
- time_edges.setdefault(t, set()).add(
1109
- (waiting_list[-1], number_sequence[i + 1])
1110
- )
1111
- is_root[number_sequence[i + 1]] = False
1112
- pos[waiting_list[-1]] = pos_sequence[:3]
1113
- pos_sequence = pos_sequence[3:]
1114
- nodes.append(waiting_list[-1])
1115
- time[waiting_list[-1]] = t
1116
- time_nodes.setdefault(t, set()).add(waiting_list[-1])
1117
- t += 1
1118
-
1119
- elif number_sequence[i + 1] >= 0:
1120
- successor[c] = [number_sequence[i + 1]]
1121
- edges.append((c, number_sequence[i + 1]))
1122
- time_edges.setdefault(t, set()).add(
1123
- (c, number_sequence[i + 1])
1124
- )
1125
- is_root[number_sequence[i + 1]] = False
1126
- pos[c] = pos_sequence[:3]
1127
- pos_sequence = pos_sequence[3:]
1128
- nodes.append(c)
1129
- time[c] = t
1130
- time_nodes.setdefault(t, set()).add(c)
1131
- t += 1
1132
-
1133
- elif number_sequence[i + 1] == -2:
1134
- waiting_list += [c]
1135
-
1136
- elif number_sequence[i + 1] == -1:
1137
- pos[c] = pos_sequence[:3]
1138
- pos_sequence = pos_sequence[3:]
1139
- nodes.append(c)
1140
- time[c] = t
1141
- time_nodes.setdefault(t, set()).add(c)
1142
- t += 1
1143
- i += 1
1144
- if waiting_list != []:
1145
- prev_mother = waiting_list.pop()
1146
- successor[prev_mother].insert(0, number_sequence[i + 1])
1147
- edges.append((prev_mother, number_sequence[i + 1]))
1148
- time_edges.setdefault(t, set()).add(
1149
- (prev_mother, number_sequence[i + 1])
1150
- )
1151
- if i + 1 < len(number_sequence):
1152
- is_root[number_sequence[i + 1]] = False
1153
- t = time[prev_mother] + 1
1154
- else:
1155
- if len(time_sequence) > 0:
1156
- t = time_sequence.pop(0)
1157
- if i + 1 < len(number_sequence):
1158
- is_root[number_sequence[i + 1]] = True
1159
- i += 1
1160
-
1161
- predecessor = {vi: [k] for k, v in successor.items() for vi in v}
1162
-
1163
- self.successor = successor
1164
- self.predecessor = predecessor
1165
- self.time = time
1166
- self.time_nodes = time_nodes
1167
- self.time_edges = time_edges
1168
- self.pos = pos
1169
- self.nodes = set(nodes)
1170
- self.t_b = min(time_nodes.keys())
1171
- self.t_e = max(time_nodes.keys())
1172
- self.is_root = is_root
1173
- self.max_id = max(self.nodes)
1174
-
1175
- def write(self, fname: str):
1176
- """
1177
- Write a lineage tree on disk as an .lT file.
1178
-
1179
- Args:
1180
- fname (str): path to and name of the file to save
964
+ Parameters
965
+ ----------
966
+ fname : str
967
+ path to and name of the file to save
1181
968
  """
1182
969
  if os.path.splitext(fname)[-1] != ".lT":
1183
970
  fname = os.path.extsep.join((fname, "lT"))
971
+ if hasattr(self, "_protected_predecessor"):
972
+ del self._protected_predecessor
973
+ if hasattr(self, "_protected_successor"):
974
+ del self._protected_successor
975
+ if hasattr(self, "_protected_time"):
976
+ del self._protected_time
1184
977
  with open(fname, "bw") as f:
1185
978
  pkl.dump(self, f)
1186
979
  f.close()
1187
980
 
1188
981
  @classmethod
1189
- def load(clf, fname: str, rm_empty_lists=True):
1190
- """
1191
- Loading a lineage tree from a ".lT" file.
982
+ def load(clf, fname: str):
983
+ """Loading a lineage tree from a '.lT' file.
1192
984
 
1193
- Args:
1194
- fname (str): path to and name of the file to read
985
+ Parameters
986
+ ----------
987
+ fname : str
988
+ path to and name of the file to read
1195
989
 
1196
- Returns:
1197
- (lineageTree): loaded file
990
+ Returns
991
+ -------
992
+ LineageTree
993
+ loaded file
1198
994
  """
1199
995
  with open(fname, "br") as f:
1200
996
  lT = pkl.load(f)
1201
997
  f.close()
998
+ if not hasattr(lT, "__version__") or Version(lT.__version__) < Version(
999
+ "2.0.0"
1000
+ ):
1001
+ properties = {
1002
+ prop_name: prop
1003
+ for prop_name, prop in lT.__dict__.items()
1004
+ if isinstance(prop, dict)
1005
+ and prop_name
1006
+ not in [
1007
+ "successor",
1008
+ "predecessor",
1009
+ "time",
1010
+ "_successor",
1011
+ "_predecessor",
1012
+ "_time",
1013
+ "pos",
1014
+ "labels",
1015
+ ]
1016
+ + lineageTree._dynamic_properties
1017
+ + lineageTree._protected_dynamic_properties
1018
+ }
1019
+ print("_comparisons" in properties)
1020
+ lT = lineageTree(
1021
+ successor=lT._successor,
1022
+ time=lT._time,
1023
+ pos=lT.pos,
1024
+ name=lT.name if hasattr(lT, "name") else None,
1025
+ **properties,
1026
+ )
1202
1027
  if not hasattr(lT, "time_resolution"):
1203
- lT.time_resolution = None
1204
- if rm_empty_lists:
1205
- if [] in lT.successor.values():
1206
- for node, succ in lT.successor.items():
1207
- if succ == []:
1208
- lT.successor.pop(node)
1209
- if [] in lT.predecessor.values():
1210
- for node, succ in lT.predecessor.items():
1211
- if succ == []:
1212
- lT.predecessor.pop(node)
1213
- lT.t_e = max(lT.time_nodes)
1214
- lT.t_b = min(lT.time_nodes)
1215
- warnings.warn("Empty lists have been removed")
1028
+ lT.time_resolution = 1
1029
+
1216
1030
  return lT
1217
1031
 
1218
- def get_idx3d(self, t: int) -> tuple:
1219
- """Get a 3d kdtree for the dataset at time *t* .
1220
- The kdtree is stored in *self.kdtrees[t]*
1221
-
1222
- Args:
1223
- t (int): time
1224
- Returns:
1225
- (kdtree, [int, ]): the built kdtree and
1226
- the correspondancy list,
1227
- If the query in the kdtree gives you the value i,
1228
- then it corresponds to the id in the tree to_check_self[i]
1032
+ def get_predecessors(
1033
+ self,
1034
+ x: int,
1035
+ depth: int | None = None,
1036
+ start_time: int | None = None,
1037
+ end_time: int | None = None,
1038
+ ) -> list[int]:
1039
+ """Computes the predecessors of the node `x` up to
1040
+ `depth` predecessors or the begining of the life of `x`.
1041
+ The ordered list of ids is returned.
1042
+
1043
+ Parameters
1044
+ ----------
1045
+ x : int
1046
+ id of the node to compute
1047
+ depth : int
1048
+ maximum number of predecessors to return
1049
+
1050
+ Returns
1051
+ -------
1052
+ list of int
1053
+ list of ids, the last id is `x`
1054
+ """
1055
+ if not start_time:
1056
+ start_time = self.t_b
1057
+ if not end_time:
1058
+ end_time = self.t_e
1059
+ unconstrained_chain = [x]
1060
+ chain = [x] if start_time <= self._time[x] <= end_time else []
1061
+ acc = 0
1062
+ while (
1063
+ acc != depth
1064
+ and start_time < self._time[unconstrained_chain[0]]
1065
+ and (
1066
+ self._predecessor[unconstrained_chain[0]] != ()
1067
+ and ( # Please dont change very important even if it looks weird.
1068
+ len(
1069
+ self._successor[
1070
+ self._predecessor[unconstrained_chain[0]][0]
1071
+ ]
1072
+ )
1073
+ == 1
1074
+ )
1075
+ )
1076
+ ):
1077
+ unconstrained_chain.insert(
1078
+ 0, self._predecessor[unconstrained_chain[0]][0]
1079
+ )
1080
+ acc += 1
1081
+ if start_time <= self._time[unconstrained_chain[0]] <= end_time:
1082
+ chain.insert(0, unconstrained_chain[0])
1083
+
1084
+ return chain
1085
+
1086
+ def get_successors(
1087
+ self, x: int, depth: int | None = None, end_time: int | None = None
1088
+ ) -> list[int]:
1089
+ """Computes the successors of the node `x` up to
1090
+ `depth` successors or the end of the life of `x`.
1091
+ The ordered list of ids is returned.
1092
+
1093
+ Parameters
1094
+ ----------
1095
+ x : int
1096
+ id of the node to compute
1097
+ depth : int, optional
1098
+ maximum number of predecessors to return
1099
+ end_time : int, optional
1100
+ maximum time to consider
1101
+
1102
+ Returns
1103
+ -------
1104
+ list of int
1105
+ list of ids, the first id is `x`
1106
+ """
1107
+ if end_time is None:
1108
+ end_time = self.t_e
1109
+ chain = [x]
1110
+ acc = 0
1111
+ while (
1112
+ len(self._successor[chain[-1]]) == 1
1113
+ and acc != depth
1114
+ and self._time[chain[-1]] < end_time
1115
+ ):
1116
+ chain += self._successor[chain[-1]]
1117
+ acc += 1
1118
+
1119
+ return chain
1120
+
1121
+ def get_chain_of_node(
1122
+ self,
1123
+ x: int,
1124
+ depth: int | None = None,
1125
+ depth_pred: int | None = None,
1126
+ depth_succ: int | None = None,
1127
+ end_time: int | None = None,
1128
+ ) -> list[int]:
1129
+ """Computes the predecessors and successors of the node `x` up to
1130
+ `depth_pred` predecessors plus `depth_succ` successors.
1131
+ If the value `depth` is provided and not None,
1132
+ `depth_pred` and `depth_succ` are overwriten by `depth`.
1133
+ The ordered list of ids is returned.
1134
+ If all `depth` are None, the full chain is returned.
1135
+
1136
+ Parameters
1137
+ ----------
1138
+ x : int
1139
+ id of the node to compute
1140
+ depth : int, optional
1141
+ maximum number of predecessors and successor to return
1142
+ depth_pred : int, optional
1143
+ maximum number of predecessors to return
1144
+ depth_succ : int, optional
1145
+ maximum number of successors to return
1146
+
1147
+ Returns
1148
+ -------
1149
+ list of int
1150
+ list of node ids
1229
1151
  """
1230
- to_check_self = list(self.time_nodes[t])
1152
+ if end_time is None:
1153
+ end_time = self.t_e
1154
+ if depth is not None:
1155
+ depth_pred = depth_succ = depth
1156
+ return self.get_predecessors(x, depth_pred, end_time=end_time)[
1157
+ :-1
1158
+ ] + self.get_successors(x, depth_succ, end_time=end_time)
1159
+
1160
+ @dynamic_property
1161
+ def all_chains(self) -> list[list[int]]:
1162
+ """List of all chains in the tree, ordered in depth-first search."""
1163
+ return self._compute_all_chains()
1164
+
1165
+ @dynamic_property
1166
+ def time_nodes(self):
1167
+ _time_nodes = {}
1168
+ for c, t in self._time.items():
1169
+ _time_nodes.setdefault(t, set()).add(c)
1170
+ return _time_nodes
1171
+
1172
+ def m(self, i, j):
1173
+ if (i, j) not in self._tmp_parenting:
1174
+ if i == j: # the distance to the node itself is 0
1175
+ self._tmp_parenting[(i, j)] = 0
1176
+ self._parenting[i, j] = self._tmp_parenting[(i, j)]
1177
+ elif not self._predecessor[
1178
+ j
1179
+ ]: # j and i are note connected so the distance if inf
1180
+ self._tmp_parenting[(i, j)] = np.inf
1181
+ else: # the distance between i and j is the distance between i and pred(j) + 1
1182
+ self._tmp_parenting[(i, j)] = (
1183
+ self.m(i, self._predecessor[j][0]) + 1
1184
+ )
1185
+ self._parenting[i, j] = self._tmp_parenting[(i, j)]
1186
+ self._parenting[j, i] = -self._tmp_parenting[(i, j)]
1187
+ return self._tmp_parenting[(i, j)]
1188
+
1189
+ @property
1190
+ def parenting(self):
1191
+ if not hasattr(self, "_parenting"):
1192
+ self._parenting = dok_array((max(self.nodes) + 1,) * 2)
1193
+ self._tmp_parenting = {}
1194
+ for i, j in combinations(self.nodes, 2):
1195
+ if self._time[j] < self.time[i]:
1196
+ i, j = j, i
1197
+ self._tmp_parenting[(i, j)] = self.m(i, j)
1198
+ del self._tmp_parenting
1199
+ return self._parenting
1200
+
1201
+ def get_idx3d(self, t: int) -> tuple[KDTree, np.ndarray]:
1202
+ """Get a 3d kdtree for the dataset at time `t`.
1203
+ The kdtree is stored in `self.kdtrees[t]` and returned.
1204
+ The correspondancy list is also returned.
1205
+
1206
+ Parameters
1207
+ ----------
1208
+ t : int
1209
+ time
1210
+
1211
+ Returns
1212
+ -------
1213
+ KDTree
1214
+ The KDTree corresponding to the lineage tree at time `t`
1215
+ np.ndarray
1216
+ The correspondancy list in the KDTree.
1217
+ If the query in the kdtree gives you the value `i`,
1218
+ then it corresponds to the id in the tree `to_check_self[i]`
1219
+ """
1220
+ to_check_self = list(self.nodes_at_t(t=t))
1221
+
1222
+ if not hasattr(self, "kdtrees"):
1223
+ self.kdtrees = {}
1224
+
1231
1225
  if t not in self.kdtrees:
1232
1226
  data_corres = {}
1233
1227
  data = []
@@ -1240,16 +1234,21 @@ class lineageTree(lineageTreeLoaders):
1240
1234
  idx3d = self.kdtrees[t]
1241
1235
  return idx3d, np.array(to_check_self)
1242
1236
 
1243
- def get_gabriel_graph(self, t: int) -> dict:
1244
- """Build the Gabriel graph of the given graph for time point `t`
1245
- The Garbiel graph is then stored in self.Gabriel_graph and returned
1246
- *WARNING: the graph is not recomputed if already computed. even if nodes were added*.
1237
+ def get_gabriel_graph(self, t: int) -> dict[int, set[int]]:
1238
+ """Build the Gabriel graph of the given graph for time point `t`.
1239
+ The Garbiel graph is then stored in `self.Gabriel_graph` and returned.
1240
+
1241
+ .. warning:: the graph is not recomputed if already computed, even if the point cloud has changed
1247
1242
 
1248
- Args:
1249
- t (int): time
1250
- Returns:
1251
- {int, set([int, ])}: a dictionary that maps a node to
1252
- the set of its neighbors
1243
+ Parameters
1244
+ ----------
1245
+ t : int
1246
+ time
1247
+
1248
+ Returns
1249
+ -------
1250
+ dict of int to set of int
1251
+ A dictionary that maps a node to the set of its neighbors
1253
1252
  """
1254
1253
  if not hasattr(self, "Gabriel_graph"):
1255
1254
  self.Gabriel_graph = {}
@@ -1294,178 +1293,110 @@ class lineageTree(lineageTreeLoaders):
1294
1293
 
1295
1294
  return self.Gabriel_graph[t]
1296
1295
 
1297
- def get_predecessors(
1298
- self, x: int, depth: int = None, start_time: int = None, end_time=None
1299
- ) -> list:
1300
- """Computes the predecessors of the node `x` up to
1301
- `depth` predecessors or the begining of the life of `x`.
1302
- The ordered list of ids is returned.
1303
-
1304
- Args:
1305
- x (int): id of the node to compute
1306
- depth (int): maximum number of predecessors to return
1307
- Returns:
1308
- [int, ]: list of ids, the last id is `x`
1309
- """
1310
- if not start_time:
1311
- start_time = self.t_b
1312
- if not end_time:
1313
- end_time = self.t_e
1314
- unconstrained_cycle = [x]
1315
- cycle = [x] if start_time <= self.time[x] <= end_time else []
1316
- acc = 0
1317
- while (
1318
- len(self[self.predecessor.get(unconstrained_cycle[0], [-1])[0]])
1319
- == 1
1320
- and acc != depth
1321
- and start_time
1322
- <= self.time.get(
1323
- self.predecessor.get(unconstrained_cycle[0], [-1])[0], -1
1324
- )
1325
- ):
1326
- unconstrained_cycle.insert(
1327
- 0, self.predecessor[unconstrained_cycle[0]][0]
1328
- )
1329
- acc += 1
1330
- if start_time <= self.time[unconstrained_cycle[0]] <= end_time:
1331
- cycle.insert(0, unconstrained_cycle[0])
1332
-
1333
- return cycle
1334
-
1335
- def get_successors(
1336
- self, x: int, depth: int = None, end_time: int = None
1337
- ) -> list:
1338
- """Computes the successors of the node `x` up to
1339
- `depth` successors or the end of the life of `x`.
1340
- The ordered list of ids is returned.
1341
-
1342
- Args:
1343
- x (int): id of the node to compute
1344
- depth (int): maximum number of predecessors to return
1345
- Returns:
1346
- [int, ]: list of ids, the first id is `x`
1347
- """
1348
- if end_time is None:
1349
- end_time = self.t_e
1350
- cycle = [x]
1351
- acc = 0
1352
- while (
1353
- len(self[cycle[-1]]) == 1
1354
- and acc != depth
1355
- and self.time[cycle[-1]] < end_time
1356
- ):
1357
- cycle += self.successor[cycle[-1]]
1358
- acc += 1
1359
-
1360
- return cycle
1296
def get_all_chains_of_subtree(
    self, node: int, end_time: int | None = None
) -> list[list[int]]:
    """Computes all the chains of the subtree spawned by a given node.
    Similar to get_all_chains().

    A chain is the run of nodes returned by :meth:`get_successors`.
    Every chain, including the first one starting at `node`, is cut at
    `end_time`. Chains that start after `end_time` are dropped.

    Parameters
    ----------
    node : int
        The node from which we want to get its chains.
    end_time : int, optional
        The time at which we want to stop the chains,
        defaults to the last time point (`self.t_e`).

    Returns
    -------
    list of list of int
        list of chains, the first one starting at `node`
    """
    # `is None` so that a legitimate end_time of 0 is honoured.
    if end_time is None:
        end_time = self.t_e
    # The first chain must honour `end_time` like all the others.
    chains = [self.get_successors(node, end_time=end_time)]
    to_do = list(self._successor[chains[0][-1]])
    while to_do:
        current = to_do.pop()
        chain = self.get_successors(current, end_time=end_time)
        if self._time[chain[-1]] <= end_time:
            chains.append(chain)
            to_do.extend(self._successor[chain[-1]])
    return chains
1325
+
1326
def _compute_all_chains(self) -> list[list[int]]:
    """Return every chain of the lineage tree.

    Roots are processed in increasing order of time. Each root's whole
    subtree is exhausted depth-first before the next root is started: the
    successors of the last node of a chain seed the next chains to visit.
    The result is only returned, not stored on the instance.

    Returns
    -------
    list of list of int
        list of chains, each given as the ordered list of its node ids
    """
    # Stack of chain starts; the earliest root sits on top.
    pending = sorted(self.roots, key=self.time.get, reverse=True)
    chains = []
    while pending:
        chain = self.get_chain_of_node(pending.pop())
        chains.append(chain)
        pending += self._successor[chain[-1]]
    return chains
1343
+
1344
def __get_chains(  # TODO: Probably should be removed, might be used by DTW. Might also be a @dynamic_property
    self, nodes: Iterable | int | None = None
) -> dict[int, list[list[int]]]:
    """Returns all the chains in the subtrees spawned by each of the given nodes.

    For each node `n`, `self.all_chains` is scanned in order.

    - The scan starts at the chain whose first node is the first element
      of `self.get_predecessors(n)`, i.e. the start of the chain holding `n`.
    - That chain is replaced by the part starting at `n`
      (`self.get_successors(n)`).
    - The chains that follow are then collected until one is met whose
      first node is at or before the time of `n`.

    NOTE(review): this presumably relies on `self.all_chains` listing the
    chains of a subtree contiguously right after the chain that spawns them
    (depth-first order). Confirm against how `all_chains` is built.

    Parameters
    ----------
    nodes : Iterable or int, optional
        id or Iterable of ids of the nodes to be computed, if `None` all roots are used

    Returns
    -------
    dict mapping int to list of list of int
        dictionary mapping the node ids to a list of chains
    """
    all_chains = self.all_chains
    if nodes is None:
        nodes = self.roots
    if not isinstance(nodes, Iterable):
        nodes = [nodes]
    output_chains = {}
    for n in nodes:
        # First node of the chain that contains `n`.
        starting_node = self.get_predecessors(n)[0]
        found = False  # becomes True once the chain holding `n` is reached
        done = False  # becomes True once the scan leaves the subtree of `n`
        starting_time = self.time[n]
        i = 0
        current_chain = []
        while not done and i < len(all_chains):
            curr_found = all_chains[i][0] == starting_node
            found = found or curr_found
            if found:
                # Stop at the first later chain that does not start after `n`.
                done = (
                    self.time[all_chains[i][0]] <= starting_time
                ) and not curr_found
                if not done:
                    if curr_found:
                        # Only keep the part of the chain from `n` onwards.
                        current_chain.append(self.get_successors(n))
                    else:
                        current_chain.append(all_chains[i])
            i += 1
        output_chains[n] = current_chain
    return output_chains
1460
1387
 
1461
- def find_leaves(self, roots: Union[int, set, list, tuple]):
1388
+ def find_leaves(self, roots: int | Iterable) -> set[int]:
1462
1389
  """Finds the leaves of a tree spawned by one or more nodes.
1463
1390
 
1464
- Args:
1465
- roots (Union[int,set,list,tuple]): The roots of the trees.
1391
+ Parameters
1392
+ ----------
1393
+ roots : int or Iterable
1394
+ The roots of the trees spawning the leaves
1466
1395
 
1467
- Returns:
1468
- set: The leaves of one or more trees.
1396
+ Returns
1397
+ -------
1398
+ set
1399
+ The leaves of one or more trees.
1469
1400
  """
1470
1401
  if not isinstance(roots, Iterable):
1471
1402
  to_do = [roots]
@@ -1474,29 +1405,34 @@ class lineageTree(lineageTreeLoaders):
1474
1405
  leaves = set()
1475
1406
  while to_do:
1476
1407
  curr = to_do.pop()
1477
- succ = self.successor.get(curr)
1478
- if succ is not None:
1408
+ succ = self._successor[curr]
1409
+ if not succ:
1479
1410
  leaves.add(curr)
1480
- else:
1481
- to_do += succ
1411
+ to_do += succ
1482
1412
  return leaves
1483
1413
 
1484
- def get_sub_tree(
1414
+ def get_subtree_nodes(
1485
1415
  self,
1486
- x: Union[int, Iterable],
1487
- end_time: Union[int, None] = None,
1416
+ x: int | Iterable,
1417
+ end_time: int | None = None,
1488
1418
  preorder: bool = False,
1489
- ) -> list:
1490
- """Computes the list of cells from the subtree spawned by *x*
1491
- The default output order is breadth first traversal.
1419
+ ) -> list[int]:
1420
+ """Computes the list of nodes from the subtree spawned by *x*
1421
+ The default output order is Breadth First Traversal.
1492
1422
  Unless preorder is `True` in that case the order is
1493
- Depth first traversal preordered.
1423
+ Depth First Traversal (DFT) preordered.
1424
+
1425
+ Parameters
1426
+ ----------
1427
+ x : int
1428
+ id of root node
1429
+ preorder : bool, default=False
1430
+ if True the output preorder is DFT
1494
1431
 
1495
- Args:
1496
- x (int): id of root node
1497
- preorder (bool): if True the output is preorder DFT
1498
- Returns:
1499
- ([int, ...]): the ordered list of node ids
1432
+ Returns
1433
+ -------
1434
+ list of int
1435
+ the ordered list of node ids
1500
1436
  """
1501
1437
  if not end_time:
1502
1438
  end_time = self.t_e
@@ -1504,233 +1440,258 @@ class lineageTree(lineageTreeLoaders):
1504
1440
  to_do = [x]
1505
1441
  elif isinstance(x, Iterable):
1506
1442
  to_do = list(x)
1507
- sub_tree = []
1443
+ subtree = []
1508
1444
  while to_do:
1509
1445
  curr = to_do.pop()
1510
- succ = self.successor.get(curr, [])
1511
- if succ and end_time < self.time.get(curr, end_time):
1446
+ succ = self._successor[curr]
1447
+ if succ and end_time < self._time.get(curr, end_time):
1512
1448
  succ = []
1513
1449
  continue
1514
1450
  if preorder:
1515
1451
  to_do = succ + to_do
1516
1452
  else:
1517
1453
  to_do += succ
1518
- sub_tree += [curr]
1519
- return sub_tree
1454
+ subtree += [curr]
1455
+ return subtree
1520
1456
 
1521
1457
  def compute_spatial_density(
1522
- self, t_b: int = None, t_e: int = None, th: float = 50
1523
- ) -> dict:
1524
- """Computes the spatial density of cells between `t_b` and `t_e`.
1525
- The spatial density is computed as follow:
1526
- #cell/(4/3*pi*th^3)
1527
- The results is stored in self.spatial_density is returned.
1528
-
1529
- Args:
1530
- t_b (int): starting time to look at, default first time point
1531
- t_e (int): ending time to look at, default last time point
1532
- th (float): size of the neighbourhood
1533
- Returns:
1534
- {int, float}: dictionary that maps a cell id to its spatial density
1535
- """
1458
+ self, t_b: int | None = None, t_e: int | None = None, th: float = 50
1459
+ ) -> dict[int, float]:
1460
+ """Computes the spatial density of nodes between `t_b` and `t_e`.
1461
+ The results is stored in `self.spatial_density` and returned.
1462
+
1463
+ Parameters
1464
+ ----------
1465
+ t_b : int, optional
1466
+ starting time to look at, default first time point
1467
+ t_e : int, optional
1468
+ ending time to look at, default last time point
1469
+ th : float, default=50
1470
+ size of the neighbourhood
1471
+
1472
+ Returns
1473
+ -------
1474
+ dict mapping int to float
1475
+ dictionary that maps a node id to its spatial density
1476
+ """
1477
+ if not hasattr(self, "spatial_density"):
1478
+ self.spatial_density = {}
1536
1479
  s_vol = 4 / 3.0 * np.pi * th**3
1537
- time_range = set(range(t_b, t_e + 1)).intersection(self.time_nodes)
1480
+ if t_b is None:
1481
+ t_b = self.t_b
1482
+ if t_e is None:
1483
+ t_e = self.t_e
1484
+ time_range = set(range(t_b, t_e)).intersection(self._time.values())
1538
1485
  for t in time_range:
1539
1486
  idx3d, nodes = self.get_idx3d(t)
1540
1487
  nb_ni = [
1541
1488
  (len(ni) - 1) / s_vol
1542
1489
  for ni in idx3d.query_ball_tree(idx3d, th)
1543
1490
  ]
1544
- self.spatial_density.update(dict(zip(nodes, nb_ni)))
1491
+ self.spatial_density.update(dict(zip(nodes, nb_ni, strict=True)))
1545
1492
  return self.spatial_density
1546
1493
 
1547
- def compute_k_nearest_neighbours(self, k: int = 10) -> dict:
1494
+ def compute_k_nearest_neighbours(self, k: int = 10) -> dict[int, set[int]]:
1548
1495
  """Computes the k-nearest neighbors
1549
1496
  Writes the output in the attribute `kn_graph`
1550
1497
  and returns it.
1551
1498
 
1552
- Args:
1553
- k (float): number of nearest neighours
1554
- Returns:
1555
- {int, set([int, ...])}: dictionary that maps
1556
- a cell id to its `k` nearest neighbors
1499
+ Parameters
1500
+ ----------
1501
+ k : float
1502
+ number of nearest neighours
1503
+
1504
+ Returns
1505
+ -------
1506
+ dict mapping int to set of int
1507
+ dictionary that maps
1508
+ a node id to its `k` nearest neighbors
1557
1509
  """
1558
1510
  self.kn_graph = {}
1559
- for t, nodes in self.time_nodes.items():
1560
- use_k = k if k < len(nodes) else len(nodes)
1561
- idx3d, nodes = self.get_idx3d(t)
1562
- pos = [self.pos[c] for c in nodes]
1563
- _, neighbs = idx3d.query(pos, use_k)
1564
- out = dict(zip(nodes, [set(nodes[ni[1:]]) for ni in neighbs]))
1565
- self.kn_graph.update(out)
1511
+ for t in set(self._time.values()):
1512
+ nodes = self.nodes_at_t(t)
1513
+ if 1 < len(nodes):
1514
+ use_k = k if k < len(nodes) else len(nodes)
1515
+ idx3d, nodes = self.get_idx3d(t)
1516
+ pos = [self.pos[c] for c in nodes]
1517
+ _, neighbs = idx3d.query(pos, use_k)
1518
+ out = dict(
1519
+ zip(
1520
+ nodes,
1521
+ map(set, nodes[neighbs]),
1522
+ strict=True,
1523
+ )
1524
+ )
1525
+ self.kn_graph.update(out)
1526
+ else:
1527
+ n = nodes.pop
1528
+ self.kn_graph.update({n: {n}})
1566
1529
  return self.kn_graph
1567
1530
 
1568
- def compute_spatial_edges(self, th: int = 50) -> dict:
1531
+ def compute_spatial_edges(self, th: int = 50) -> dict[int, set[int]]:
1569
1532
  """Computes the neighbors at a distance `th`
1570
1533
  Writes the output in the attribute `th_edge`
1571
1534
  and returns it.
1572
1535
 
1573
- Args:
1574
- th (float): distance to consider neighbors
1575
- Returns:
1576
- {int, set([int, ...])}: dictionary that maps
1577
- a cell id to its neighbors at a distance `th`
1536
+ Parameters
1537
+ ----------
1538
+ th : float, default=50
1539
+ distance to consider neighbors
1540
+
1541
+ Returns
1542
+ -------
1543
+ dict mapping int to set of int
1544
+ dictionary that maps a node id to its neighbors at a distance `th`
1578
1545
  """
1579
1546
  self.th_edges = {}
1580
- for t, _ in self.time_nodes.items():
1547
+ for t in set(self._time.values()):
1548
+ nodes = self.nodes_at_t(t)
1581
1549
  idx3d, nodes = self.get_idx3d(t)
1582
1550
  neighbs = idx3d.query_ball_tree(idx3d, th)
1583
- out = dict(zip(nodes, [set(nodes[ni]) for ni in neighbs]))
1551
+ out = dict(
1552
+ zip(nodes, [set(nodes[ni]) for ni in neighbs], strict=True)
1553
+ )
1584
1554
  self.th_edges.update(
1585
1555
  {k: v.difference([k]) for k, v in out.items()}
1586
1556
  )
1587
1557
  return self.th_edges
1588
1558
 
1589
- def main_axes(self, time: int = None):
1590
- """Finds the main axes for a timepoint.
1591
- If none will select the timepoint with the highest amound of cells.
1592
-
1593
- Args:
1594
- time (int, optional): The timepoint to find the main axes.
1595
- If None will find the timepoint
1596
- with the largest number of cells.
1597
-
1598
- Returns:
1599
- list: A list that contains the array of eigenvalues and eigenvectors.
1600
- """
1601
- if time is None:
1602
- time = max(self.time_nodes, key=lambda x: len(self.time_nodes[x]))
1603
- pos = np.array([self.pos[node] for node in self.time_nodes[time]])
1604
- pos = pos - np.mean(pos, axis=0)
1605
- cov = np.cov(np.array(pos).T)
1606
- eig_val, eig_vec = np.linalg.eig(cov)
1607
- srt = np.argsort(eig_val)[::-1]
1608
- self.eig_val, self.eig_vec = eig_val[srt], eig_vec[:, srt]
1609
- return eig_val[srt], eig_vec[:, srt]
1610
-
1611
- def scale_embryo(self, scale=1000):
1612
- """Scale the embryo using their eigenvalues.
1613
-
1614
- Args:
1615
- scale (int, optional): The resulting scale you want to achieve. Defaults to 1000.
1616
-
1617
- Returns:
1618
- float: The scale factor.
1619
- """
1620
- eig = self.main_axes()[0]
1621
- return scale / (np.sqrt(eig[0]))
1622
-
1623
- @staticmethod
1624
- def __rodrigues_rotation_matrix(vector1, vector2=(0, 1, 0)):
1625
- """Calculates the rodrigues matrix of a dataset. It should use vectors from the find_main_axes(eigenvectors) function of LineagTree.
1626
- Uses the Rodrigues rotation formula.
1627
-
1628
- Args:
1629
- vector1 (list|np.array): The vector that should be rotated to be aligned to the second vector
1630
- vector2 (list|np.array, optional): The second vector. Defaults to [1,0,0].
1631
-
1632
- Returns:
1633
- np.array: The rotation matrix.
1634
- """
1635
- vector1 = vector1 / np.linalg.norm(vector1)
1636
- vector2 = vector2 / np.linalg.norm(vector2)
1637
- if vector1 @ vector2 == 1:
1638
- return np.eye(3)
1639
- angle = np.arccos(vector1 @ vector2)
1640
- axis = np.cross(vector1, vector2)
1641
- axis = axis / np.linalg.norm(axis)
1642
- K = np.array(
1643
- [
1644
- [0, -axis[2], axis[1]],
1645
- [axis[2], 0, -axis[0]],
1646
- [-axis[1], axis[0], 0],
1647
- ]
1648
- )
1649
- return np.eye(3) + np.sin(angle) * K + (1 - np.cos(angle)) * K @ K
1650
-
1651
def get_ancestor_at_t(self, n: int, time: int | None = None) -> int:
    """Find the id of the ancestor of a given node `n`
    at a given time `time`.

    The node itself counts as its own ancestor when it lives at `time`.
    Returns `-1` when there is no ancestor at `time`, which happens when:

    - `n` is unknown,
    - `time` is after the time of `n`, or
    - the lineage of `n` starts after `time`.

    If `time` is None, the first time point of the dataset (`self.t_b`) is
    used. The root of the subtree that spawns `n` is therefore returned
    only when that root lives at `self.t_b`.

    Parameters
    ----------
    n : int
        node for which to look the ancestor
    time : int, optional
        time at which the ancestor has to be found.
        If `None` the ancestor at the first time point
        will be found.

    Returns
    -------
    int
        the id of the ancestor at time `time`,
        `-1` if there is no ancestor.
    """
    if n not in self.nodes:
        return -1
    if time is None:
        time = self.t_b
    ancestor = n
    # Climb one generation at a time while still later than `time`
    # and a predecessor exists.
    while (
        time < self._time.get(ancestor, self.t_b - 1)
        and self._predecessor[ancestor]
    ):
        ancestor = self._predecessor[ancestor][0]
    # The climb may stop at a root later than `time`, or step past `time`
    # when time points are skipped: only an exact match is an answer.
    if self._time.get(ancestor, self.t_b - 1) == time:
        return ancestor
    return -1
1680
1596
 
1681
def get_labelled_ancestor(self, node: int) -> int:
    """Finds the first labelled ancestor and returns its ID otherwise returns -1

    The node itself is checked first. Its predecessors are then checked one
    generation at a time, until a node present in `self.labels` is found or
    the top of the lineage is reached.

    Parameters
    ----------
    node : int
        The id of the node

    Returns
    -------
    int
        Returns the first ancestor found that has a label otherwise `-1`.
    """
    if node not in self.nodes:
        return -1
    ancestor = node
    while (
        self.t_b <= self._time.get(ancestor, self.t_b - 1)
        and ancestor != -1
    ):
        if ancestor in self.labels:
            return ancestor
        # Roots may map to an empty predecessor list (or be absent); both
        # mean there is no further ancestor. Indexing an empty list with
        # [0] would raise IndexError here.
        predecessors = self._predecessor.get(ancestor)
        ancestor = predecessors[0] if predecessors else -1
    return -1
1621
+
1622
def get_ancestor_with_attribute(self, node: int, attribute: str) -> int:
    """General purpose function to help with searching the first ancestor that has an attribute.
    Similar to get_labelled_ancestor and may make it redundant.

    The node itself is checked first. Its predecessors are then checked one
    generation at a time, until a node that is a key of the mapping
    ``getattr(self, attribute)`` is found.

    Parameters
    ----------
    node : int
        The id of the node
    attribute : str
        Name of an attribute of the lineage tree holding a node-keyed
        mapping (a ``dict`` or any read-only mapping such as
        ``MappingProxyType``).

    Returns
    -------
    int
        Returns the first ancestor found that has an attribute otherwise `-1`.

    Raises
    ------
    ValueError
        If the selected attribute is not a mapping.
    """
    from collections.abc import Mapping

    attr_dict = getattr(self, attribute)
    # Read-only mapping views are as good as a dict for key lookups.
    if not isinstance(attr_dict, Mapping):
        raise ValueError("Please select a dict attribute")
    if node not in self.nodes:
        return -1
    ancestor = node
    while ancestor not in attr_dict:
        predecessors = self._predecessor.get(ancestor)
        if not predecessors:  # reached the top of the lineage
            return -1
        ancestor = predecessors[0]
    return ancestor
1702
1652
 
1703
1653
  def unordered_tree_edit_distances_at_time_t(
1704
1654
  self,
1705
1655
  t: int,
1706
- end_time: int = None,
1707
- style="simple",
1656
+ end_time: int | None = None,
1657
+ style: (
1658
+ Literal["simple", "full", "downsampled", "normalized_simple"]
1659
+ | type[TreeApproximationTemplate]
1660
+ ) = "simple",
1708
1661
  downsample: int = 2,
1709
- normalize: bool = True,
1662
+ norm: Literal["max", "sum", None] = "max",
1710
1663
  recompute: bool = False,
1711
- ) -> dict:
1712
- """
1713
- Compute all the pairwise unordered tree edit distances from Zhang 996 between the trees spawned at time `t`
1714
-
1715
- Args:
1716
- t (int): time to look at
1717
- delta (callable): comparison function (see edist doc for more information)
1718
- norm (callable): norming function that takes the number of nodes
1719
- of the tree spawned by `n1` and the number of nodes
1720
- of the tree spawned by `n2` as arguments.
1721
- recompute (bool): if True, forces to recompute the distances (default: False)
1722
- end_time (int): The final time point the comparison algorithm will take into account. If None all nodes
1723
- will be taken into account.
1724
-
1725
- Returns:
1726
- (dict) a dictionary that maps a pair of cell ids at time `t` to their unordered tree edit distance
1664
+ ) -> dict[tuple[int, int], float]:
1665
+ """Compute all the pairwise unordered tree edit distances from Zhang 996 between the trees spawned at time `t`
1666
+
1667
+ Parameters
1668
+ ----------
1669
+ t : int
1670
+ time to look at
1671
+ end_time : int
1672
+ The final time point the comparison algorithm will take into account.
1673
+ If None all nodes will be taken into account.
1674
+ style : {"simple", "full", "downsampled", "normalized_simple"} or TreeApproximationTemplate subclass, default="simple"
1675
+ Which tree approximation is going to be used for the comparisons.
1676
+ downsample : int, default=2
1677
+ The downsample factor for the downsampled tree approximation.
1678
+ Used only when `style="downsampled"`.
1679
+ norm : {"max", "sum"}, default="max"
1680
+ The normalization method to use.
1681
+ recompute : bool, default=False
1682
+ If True, forces to recompute the distances
1683
+
1684
+ Returns
1685
+ -------
1686
+ dict mapping a tuple of tuple that contains 2 ints to float
1687
+ a dictionary that maps a pair of node ids at time `t` to their unordered tree edit distance
1727
1688
  """
1728
1689
  if not hasattr(self, "uted"):
1729
1690
  self.uted = {}
1730
1691
  elif t in self.uted and not recompute:
1731
1692
  return self.uted[t]
1732
1693
  self.uted[t] = {}
1733
- roots = self.time_nodes[t]
1694
+ roots = self.nodes_at_t(t=t)
1734
1695
  for n1, n2 in combinations(roots, 2):
1735
1696
  key = tuple(sorted((n1, n2)))
1736
1697
  self.uted[t][key] = self.unordered_tree_edit_distance(
@@ -1739,37 +1700,132 @@ class lineageTree(lineageTreeLoaders):
1739
1700
  end_time=end_time,
1740
1701
  style=style,
1741
1702
  downsample=downsample,
1742
- normalize=normalize,
1703
+ norm=norm,
1743
1704
  )
1744
1705
  return self.uted[t]
1745
1706
 
1746
- def unordered_tree_edit_distance(
1707
def __calculate_distance_of_sub_tree(
    self,
    node1: int,
    node2: int,
    alignment: Alignment,
    corres1: dict[int, int],
    corres2: dict[int, int],
    delta_tmp: Callable,
    norm: Callable,
    norm1: int | float,
    norm2: int | float,
) -> float:
    """Score how well the subtrees below `node1` and `node2` are matched.

    Reuses an existing alignment instead of recomputing a distance from
    scratch. Every alignment operation whose left side maps (through
    `corres1`) into the subtree of `node1`, or whose right side maps
    (through `corres2`) into the subtree of `node2`, contributes
    ``delta_tmp(left, right)``. An index of ``-1`` is passed as ``None``.
    The sum is divided by ``norm([norm1, norm2])``.
    TODO ITS BOUND TO CHANGE

    Parameters
    ----------
    node1 : int
        The root of the first subtree
    node2 : int
        The root of the second subtree
    alignment : Alignment
        The alignment of the subtree
    corres1 : dict
        The correspondence dictionary of the first lineage
    corres2 : dict
        The correspondence dictionary of the second lineage
    delta_tmp : Callable
        The delta function for the comparisons
    norm : Callable
        How should the lineages be normalized
    norm1 : int or float
        The result of the normalization of the first tree
    norm2 : int or float
        The result of the normalization of the second tree

    Returns
    -------
    float
        The result of the comparison of the subtree
    """
    nodes_in_1 = set(self.get_subtree_nodes(node1))
    nodes_in_2 = set(self.get_subtree_nodes(node2))

    def _touches_subtrees(op) -> bool:
        # Missing correspondences map to -1, which is never a node id here.
        return (
            corres1.get(op._left, -1) in nodes_in_1
            or corres2.get(op._right, -1) in nodes_in_2
        )

    def _side(index):
        # edist marks an insertion/deletion side with -1.
        return None if index == -1 else index

    total = sum(
        delta_tmp(_side(op._left), _side(op._right))
        for op in alignment
        if _touches_subtrees(op)
    )
    return total / norm([norm1, norm2])
1761
+
1762
def clear_comparisons(self):
    """Forget every cached tree comparison.

    Empties ``self._comparisons`` in place, keeping the same dictionary
    object. Once cleared, the cache is back under the size that triggers
    the "More than 100 comparisons are saved" warning.
    """
    comparisons = self._comparisons
    comparisons.clear()
1764
+
1765
+ def __unordereded_backtrace(
1747
1766
  self,
1748
1767
  n1: int,
1749
1768
  n2: int,
1750
- end_time: int = None,
1751
- norm: Union["max", "sum", None] = "max",
1752
- style="simple",
1769
+ end_time: int | None = None,
1770
+ norm: Literal["max", "sum", None] = "max",
1771
+ style: (
1772
+ Literal["simple", "normalized_simple", "full", "downsampled"]
1773
+ | type[TreeApproximationTemplate]
1774
+ ) = "simple",
1753
1775
  downsample: int = 2,
1754
- ) -> float:
1776
+ ) -> dict[
1777
+ str,
1778
+ Alignment
1779
+ | tuple[TreeApproximationTemplate, TreeApproximationTemplate],
1780
+ ]:
1755
1781
  """
1756
- Compute the unordered tree edit distance from Zhang 1996 between the trees spawned
1782
+ Compute the unordered tree edit backtrace from Zhang 1996 between the trees spawned
1757
1783
  by two nodes `n1` and `n2`. The topology of the trees are compared and the matching
1758
1784
  cost is given by the function delta (see edist doc for more information).
1759
- The distance is normed by the function norm that takes the two list of nodes
1760
- spawned by the trees `n1` and `n2`.
1761
1785
 
1762
- Args:
1763
- n1 (int): id of the first node to compare
1764
- n2 (int): id of the second node to compare
1765
- tree_style ("mini","simple","fragmented","full"): Which tree approximation is going to be used for the comparisons.
1766
- Defaults to "fragmented".
1767
-
1768
- Returns:
1769
- (float) The normed unordered tree edit distance
1770
- """
1771
-
1772
- tree = tree_style[style].value
1786
+ Parameters
1787
+ ----------
1788
+ n1 : int
1789
+ id of the first node to compare
1790
+ n2 : int
1791
+ id of the second node to compare
1792
+ end_time : int
1793
+ The final time point the comparison algorithm will take into account.
1794
+ If None all nodes will be taken into account.
1795
+ norm : {"max", "sum"}, default="max"
1796
+ The normalization method to use.
1797
+ style : {"simple", "full", "downsampled", "normalized_simple"} or TreeApproximationTemplate subclass, default="simple"
1798
+ Which tree approximation is going to be used for the comparisons.
1799
+ downsample : int, default=2
1800
+ The downsample factor for the downsampled tree approximation.
1801
+ Used only when `style="downsampled"`.
1802
+
1803
+ Returns
1804
+ -------
1805
+ dict mapping str to Alignment or tuple of [TreeApproximationTemplate, TreeApproximationTemplate]
1806
+ - 'alignment'
1807
+ The alignment between the nodes by the subtrees spawned by the nodes n1,n2 and the normalization function.
1808
+ - 'trees'
1809
+ A list of the two trees that have been mapped to each other.
1810
+ """
1811
+
1812
+ parameters = (
1813
+ end_time,
1814
+ convert_style_to_number(style=style, downsample=downsample),
1815
+ )
1816
+ n1, n2 = sorted([n1, n2])
1817
+ self._comparisons.setdefault(parameters, {})
1818
+ if len(self._comparisons) > 100:
1819
+ warnings.warn(
1820
+ "More than 100 comparisons are saved, use clear_comparisons() to delete them.",
1821
+ stacklevel=2,
1822
+ )
1823
+ if isinstance(style, str):
1824
+ tree = tree_style[style].value
1825
+ elif issubclass(style, TreeApproximationTemplate):
1826
+ tree = style
1827
+ else:
1828
+ raise ValueError("Please use a valid approximation.")
1773
1829
  tree1 = tree(
1774
1830
  lT=self,
1775
1831
  downsample=downsample,
@@ -1798,7 +1854,11 @@ class lineageTree(lineageTreeLoaders):
1798
1854
  corres2,
1799
1855
  ) = tree2.edist
1800
1856
  if len(nodes1) == len(nodes2) == 0:
1801
- return 0
1857
+ self._comparisons[parameters][(n1, n2)] = {
1858
+ "alignment": (),
1859
+ "trees": (),
1860
+ }
1861
+ return self._comparisons[parameters][(n1, n2)]
1802
1862
  delta_tmp = partial(
1803
1863
  delta,
1804
1864
  corres1=corres1,
@@ -1806,127 +1866,538 @@ class lineageTree(lineageTreeLoaders):
1806
1866
  times1=times1,
1807
1867
  times2=times2,
1808
1868
  )
1809
- norm1 = tree1.get_norm()
1810
- norm2 = tree2.get_norm()
1811
- norm_dict = {"max": max, "sum": sum, "None": lambda x: 1}
1812
- if norm is None:
1813
- norm = "None"
1814
- if norm not in norm_dict:
1869
+ btrc = uted.uted_backtrace(nodes1, adj1, nodes2, adj2, delta=delta_tmp)
1870
+
1871
+ self._comparisons[parameters][(n1, n2)] = {
1872
+ "alignment": btrc,
1873
+ "trees": (tree1, tree2),
1874
+ }
1875
+ return self._comparisons[parameters][(n1, n2)]
1876
+
1877
def plot_tree_distance_graphs(
    self,
    n1: int,
    n2: int,
    end_time: int | None = None,
    norm: Literal["max", "sum", None] = "max",
    style: (
        Literal["simple", "normalized_simple", "full", "downsampled"]
        | type[TreeApproximationTemplate]
    ) = "simple",
    downsample: int = 2,
    colormap: str = "cool",
    default_color: str = "black",
    size: float = 10,
    lw: float = 0.3,
    ax: list[plt.Axes] | None = None,
) -> tuple[plt.Figure, plt.Axes]:
    """
    Plots the subtrees compared and colors them according to the quality of the matching of their subtree.

    The alignment is taken from `self._comparisons` when already computed for
    the same `(end_time, style, downsample)` parameters, otherwise it is computed
    (and cached) with the unordered backtrace.

    Parameters
    ----------
    n1 : int
        id of the first node to compare
    n2 : int
        id of the second node to compare
    end_time : int, optional
        The final time point the comparison algorithm will take into account.
        If None all nodes will be taken into account.
    norm : {"max", "sum", None}, default="max"
        The normalization method to use.
    style : {"simple", "full", "downsampled", "normalized_simple"} or TreeApproximationTemplate subclass, default="simple"
        Which tree approximation is going to be used for the comparisons.
    downsample : int, default=2
        The downsample factor for the downsampled tree approximation.
        Used only when `style="downsampled"`.
    colormap : str, default="cool"
        The colormap used for matched nodes.
    default_color : str, default="black"
        The color of the unmatched nodes.
    size : float, default=10
        The size of the nodes.
    lw : float, default=0.3
        The width of the edges.
    ax : sequence of two plt.Axes, optional
        The axes used. The subtree of the smaller of the two node ids is drawn
        on ``ax[0]``. If not provided, a new figure with two axes is created.

    Returns
    -------
    plt.Figure
        The figure of the plot
    plt.Axes
        The axes of the plot

    Raises
    ------
    Warning
        If `norm` is not a key of `self.norm_dict`.
    """
    # Validate first so an invalid norm does not trigger an expensive alignment.
    if norm not in self.norm_dict:
        raise Warning(
            "Select a viable normalization method (max, sum, None)"
        )
    parameters = (
        end_time,
        convert_style_to_number(style=style, downsample=downsample),
    )
    n1, n2 = sorted([n1, n2])
    self._comparisons.setdefault(parameters, {})
    comparison = self._comparisons[parameters].get((n1, n2))
    if not comparison:
        comparison = self.__unordereded_backtrace(
            n1, n2, end_time, norm, style, downsample
        )

    matched_left = []
    matched_right = []
    colors = {}
    # An empty "trees" entry is stored when neither subtree has nodes to
    # align: there is nothing to color, the subtrees are plotted plainly.
    if comparison["trees"]:
        btrc: Alignment = comparison["alignment"]
        tree1, tree2 = comparison["trees"]
        _, times1 = tree1.tree
        _, times2 = tree2.tree
        *_, corres1 = tree1.edist
        *_, corres2 = tree2.edist
        delta_tmp = partial(
            tree1.delta,
            corres1=corres1,
            corres2=corres2,
            times1=times1,
            times2=times2,
        )

        def subtree_distance(node_1, node_2):
            # Normalized distance of the sub-alignment rooted at the pair.
            return self.__calculate_distance_of_sub_tree(
                node_1,
                node_2,
                btrc,
                corres1,
                corres2,
                delta_tmp,
                self.norm_dict[norm],
                tree1.get_norm(node_1),
                tree2.get_norm(node_2),
            )

        if style not in ("full", "downsampled"):
            # Chain-based approximations: an aligned node stands for a whole
            # chain, so its first and last nodes are colored.
            for m in btrc:
                if m._left == -1 or m._right == -1:
                    continue
                # Index instead of pop(): the chain list is not mutated.
                chain1 = self.get_chain_of_node(corres1[m._left])
                chain2 = self.get_chain_of_node(corres2[m._right])
                node_1, l_node_1 = chain1[0], chain1[-1]
                node_2, l_node_2 = chain2[0], chain2[-1]
                matched_left.append(node_1)
                if l_node_1 != node_1:
                    matched_left.append(l_node_1)
                matched_right.append(node_2)
                if l_node_2 != node_2:
                    matched_right.append(l_node_2)
                value = subtree_distance(node_1, node_2)
                for node in (node_1, node_2, l_node_1, l_node_2):
                    colors[node] = value
        else:
            for m in btrc:
                if m._left == -1 or m._right == -1:
                    continue
                node_1 = corres1[m._left]
                node_2 = corres2[m._right]
                chain1 = self.get_chain_of_node(node_1)
                chain2 = self.get_chain_of_node(node_2)
                # Color a pair once, when one of its nodes starts a chain and
                # the pair is not colored yet. The parentheses matter: without
                # them `and` binds tighter than `or` and the second test is
                # skipped whenever node_1 starts a chain.
                starts_chain = chain1[0] == node_1 or chain2[0] == node_2
                already_colored = node_1 in colors and node_2 in colors
                if not starts_chain or already_colored:
                    continue
                l_node_1 = chain1[-1]
                l_node_2 = chain2[-1]
                matched_left.extend((node_1, l_node_1))
                matched_right.extend((node_2, l_node_2))
                value = subtree_distance(node_1, node_2)
                for node in (node_1, l_node_1, node_2, l_node_2):
                    colors[node] = value

    if ax is None:
        _, ax = plt.subplots(nrows=1, ncols=2, sharey=True)
    cmap = colormaps[colormap]
    c_norm = mcolors.Normalize(0, 1)
    colors = {node: cmap(c_norm(v)) for node, v in colors.items()}
    for axis, node, matched in (
        (ax[0], n1, matched_left),
        (ax[1], n2, matched_right),
    ):
        self.plot_subtree(
            self.get_ancestor_at_t(node),
            end_time=end_time,
            size=size,
            selected_nodes=matched,
            color_of_nodes=colors,
            selected_edges=matched,
            color_of_edges=colors,
            default_color=default_color,
            lw=lw,
            ax=axis,
        )
    return ax[0].get_figure(), ax
2067
+
2068
def labelled_mappings(
    self,
    n1: int,
    n2: int,
    end_time: int | None = None,
    norm: Literal["max", "sum", None] = "max",
    style: (
        Literal["simple", "normalized_simple", "full", "downsampled"]
        | type[TreeApproximationTemplate]
    ) = "simple",
    downsample: int = 2,
) -> dict[str, list]:
    """
    Returns the labels or IDs of all the nodes in the subtrees compared.

    Nodes without an entry in `self.labels` are reported by their id. For
    chain-based styles each chain is reported by its first node.

    Parameters
    ----------
    n1 : int
        id of the first node to compare
    n2 : int
        id of the second node to compare
    end_time : int, optional
        The final time point the comparison algorithm will take into account.
        If None or not provided all nodes will be taken into account.
    norm : {"max", "sum", None}, default="max"
        The normalization method. It is validated and forwarded to the
        alignment computation only.
    style : {"simple", "full", "downsampled", "normalized_simple"} or TreeApproximationTemplate subclass, default="simple"
        Which tree approximation is going to be used for the comparisons.
    downsample : int, default=2
        The downsample factor for the downsampled tree approximation.
        Used only when `style="downsampled"`.

    Returns
    -------
    dict mapping str to list
        - 'matched' : list of ``(label_1, label_2)`` tuples, one per aligned pair.
        - 'unmatched' : list of the labels of the nodes left unmatched by the alignment.
        Both lists are empty when neither subtree has nodes to align.

    Raises
    ------
    Warning
        If `norm` is not a key of `self.norm_dict`.
    """
    if norm not in self.norm_dict:
        raise Warning(
            "Select a viable normalization method (max, sum, None)"
        )
    parameters = (
        end_time,
        convert_style_to_number(style=style, downsample=downsample),
    )
    n1, n2 = sorted([n1, n2])
    self._comparisons.setdefault(parameters, {})
    comparison = self._comparisons[parameters].get((n1, n2))
    if not comparison:
        comparison = self.__unordereded_backtrace(
            n1, n2, end_time, norm, style, downsample
        )

    matched = []
    unmatched = []
    # No trees are stored when neither subtree has nodes to align.
    if not comparison["trees"]:
        return {"matched": matched, "unmatched": unmatched}

    btrc = comparison["alignment"]
    tree1, tree2 = comparison["trees"]
    *_, corres1 = tree1.edist
    *_, corres2 = tree2.edist

    def label(node):
        return self.labels.get(node, node)

    chain_based = style not in ("full", "downsampled")
    for m in btrc:
        if m._left != -1 and m._right != -1:
            node_1 = corres1[m._left]
            node_2 = corres2[m._right]
            if chain_based:
                # Report each chain by its first node (indexing, not pop(),
                # so the chain list is left untouched).
                node_1 = self.get_chain_of_node(node_1)[0]
                node_2 = self.get_chain_of_node(node_2)[0]
            matched.append((label(node_1), label(node_2)))
        else:
            if chain_based:
                if m._left != -1:
                    node = self.get_chain_of_node(
                        corres1.get(m._left, "-")
                    )[0]
                else:
                    node = self.get_chain_of_node(
                        corres2.get(m._right, "-")
                    )[0]
            elif m._left != -1:
                node = corres1[m._left]
            else:
                node = corres2[m._right]
            unmatched.append(label(node))
    return {"matched": matched, "unmatched": unmatched}
2185
+
2186
def unordered_tree_edit_distance(
    self,
    n1: int,
    n2: int,
    end_time: int | None = None,
    norm: Literal["max", "sum", None] = "max",
    style: (
        Literal["simple", "normalized_simple", "full", "downsampled"]
        | type[TreeApproximationTemplate]
    ) = "simple",
    downsample: int = 2,
    return_norms: bool = False,
) -> float | tuple[float, tuple[float, float]]:
    """
    Compute the unordered tree edit distance from Zhang 1996 between the trees spawned
    by two nodes `n1` and `n2`. The topology of the trees are compared and the matching
    cost is given by the function delta (see edist doc for more information).
    The distance is normed by the function norm that takes the two list of nodes
    spawned by the trees `n1` and `n2`.

    Parameters
    ----------
    n1 : int
        id of the first node to compare
    n2 : int
        id of the second node to compare
    end_time : int, optional
        The final time point the comparison algorithm will take into account.
        If None or not provided all nodes will be taken into account.
    norm : {"max", "sum", None}, default="max"
        The normalization method to use, defaults to 'max'.
    style : {"simple", "normalized_simple", "full", "downsampled"} or TreeApproximationTemplate subclass, default="simple"
        Which tree approximation is going to be used for the comparisons.
    downsample : int, default=2
        The downsample factor for the downsampled tree approximation.
        Used only when `style="downsampled"`.
    return_norms : bool, default=False
        If True, return the raw (unnormalized) cost together with the norms
        of the two trees instead of the normalized distance.

    Returns
    -------
    float
        The normalized unordered tree edit distance between `n1` and `n2`.
        It is 0 when neither subtree has nodes to align.
    tuple of (float, tuple of (float, float))
        When `return_norms` is True: the raw cost and the two norms.

    Raises
    ------
    ValueError
        If `norm` is not a key of `self.norm_dict`.
    """
    # Validate first so an invalid norm does not trigger an expensive alignment.
    if norm not in self.norm_dict:
        raise ValueError(
            "Select a viable normalization method (max, sum, None)"
        )
    parameters = (
        end_time,
        convert_style_to_number(style=style, downsample=downsample),
    )
    n1, n2 = sorted([n1, n2])
    self._comparisons.setdefault(parameters, {})
    comparison = self._comparisons[parameters].get((n1, n2))
    if not comparison:
        comparison = self.__unordereded_backtrace(
            n1, n2, end_time, norm, style, downsample
        )
    # The backtrace stores empty "trees" when both subtrees have no nodes:
    # there is nothing to edit, so the distance is 0 (instead of failing to
    # unpack the missing trees).
    if not comparison["trees"]:
        return (0.0, (0.0, 0.0)) if return_norms else 0.0

    btrc = comparison["alignment"]
    tree1, tree2 = comparison["trees"]
    _, times1 = tree1.tree
    _, times2 = tree2.tree
    nodes1, _, corres1 = tree1.edist
    nodes2, _, corres2 = tree2.edist
    delta_tmp = partial(
        tree1.delta,
        corres1=corres1,
        corres2=corres2,
        times1=times1,
        times2=times2,
    )

    cost = btrc.cost(nodes1, nodes2, delta_tmp)
    norm_values = (tree1.get_norm(n1), tree2.get_norm(n2))
    if return_norms:
        return cost, norm_values
    return cost / self.norm_dict[norm](norm_values)
1821
2271
 
1822
2272
@staticmethod
def __plot_nodes(
    hier: dict,
    selected_nodes: set,
    color: str | dict | list | tuple,
    size: int | float,
    ax: plt.Axes,
    default_color: str = "black",
    **kwargs,
) -> None:
    """
    Private method that plots the nodes of the tree.

    Parameters
    ----------
    hier : dict
        Maps each node to its ``(x, y)`` position.
    selected_nodes : set
        Nodes drawn with `color` when `color` is a single color.
    color : str, list, tuple or dict
        Either a single color (name or RGB(A) sequence) applied to
        `selected_nodes`, or a dict mapping nodes to colors. With a dict,
        `selected_nodes` is ignored and nodes missing from it get
        `default_color`.
    size : int or float
        Marker size passed to ``ax.scatter``.
    ax : plt.Axes
        The axes to draw on.
    default_color : str, default="black"
        Color of the nodes that are not selected.
    **kwargs
        Forwarded to ``ax.scatter``.
    """
    # Nothing to draw; scatter() would otherwise be called without x/y.
    if not hier:
        return
    if isinstance(color, dict):
        node_colors = [color.get(node, default_color) for node in hier]
    elif isinstance(color, (str, list, tuple)):
        # A single color spec: only the selected nodes receive it.
        node_colors = [
            color if node in selected_nodes else default_color
            for node in hier
        ]
    else:
        node_colors = color
    positions = np.array(list(hier.values()))
    ax.scatter(*positions.T, s=size, zorder=10, color=node_colors, **kwargs)
1847
2295
 
1848
2296
@staticmethod
def __plot_edges(
    hier: dict,
    lnks_tms: dict,
    selected_edges: Iterable,
    color: str | dict | list | tuple,
    lw: float,
    ax: plt.Axes,
    default_color: str = "black",
    **kwargs,
) -> None:
    """
    Private method that plots the edges of the tree.

    Each link ``pred -> succ`` of ``lnks_tms["links"]`` becomes one segment of
    a single LineCollection added to `ax`. An edge is highlighted when its
    predecessor is in `selected_edges`.

    Parameters
    ----------
    hier : dict
        Maps each node to its ``(x, y)`` position.
    lnks_tms : dict
        Must contain a 'links' entry mapping each predecessor to its successors.
    selected_edges : Iterable
        Predecessors whose outgoing edges are highlighted.
    color : str, list, tuple or dict
        A single color for the highlighted edges, or a dict mapping
        predecessors to colors. With a dict, its keys are the selected edges.
    lw : float
        Line width of the edges.
    ax : plt.Axes
        The axes to draw on.
    default_color : str, default="black"
        Color of the edges that are not highlighted.
    **kwargs
        Forwarded to ``LineCollection``.
    """
    if isinstance(color, dict):
        selected_edges = color.keys()
    segments = []
    segment_colors = []
    for pred, succs in lnks_tms["links"].items():
        # Exactly one color per segment, so colors never fall out of step
        # with segments (e.g. when a tuple RGB color is given).
        if pred in selected_edges:
            edge_color = color[pred] if isinstance(color, dict) else color
        else:
            edge_color = default_color
        for succ in succs:
            segments.append(
                [
                    [hier[succ][0], hier[succ][1]],
                    [hier[pred][0], hier[pred][1]],
                ]
            )
            segment_colors.append(edge_color)
    lc = LineCollection(segments, colors=segment_colors, linewidth=lw, **kwargs)
    ax.add_collection(lc)
1875
2331
 
1876
2332
def draw_tree_graph(
    self,
    hier: dict[int, tuple[int, int]],
    lnks_tms: dict[str, dict[int, list | int]],
    selected_nodes: list | set | None = None,
    selected_edges: list | set | None = None,
    color_of_nodes: str | dict = "magenta",
    color_of_edges: str | dict = "magenta",
    size: int | float = 10,
    lw: float = 0.3,
    ax: plt.Axes | None = None,
    default_color: str = "black",
    **kwargs,
) -> tuple[plt.Figure, plt.Axes]:
    """Function to plot the tree graph.

    Parameters
    ----------
    hier : dict mapping int to tuple of int
        Position of every node to draw.
    lnks_tms : dict mapping string to dictionaries mapping int to list or int
        - 'links' : the hierarchy of the nodes (only start and end of each chain)
        - 'times' : the distance between the start and the end of each chain.
    selected_nodes : list or set, optional
        Nodes painted with `color_of_nodes` instead of `default_color`.
    selected_edges : list or set, optional
        Edges painted with `color_of_edges` instead of `default_color`.
    color_of_nodes : str or dict, default="magenta"
        Color of the selected nodes, or a mapping from node to color.
    color_of_edges : str or dict, default="magenta"
        Color of the selected edges, or a mapping from node to color.
        A falsy value falls back to `color_of_nodes`.
    size : int or float, default=10
        Size of the nodes; nodes are not drawn at all when it is not positive.
    lw : float, default=0.3
        The width of the edges of the tree graph.
    ax : plt.Axes, optional
        Existing axes to draw on; they are cleared first. If not provided or
        None new axes are created.
    default_color : str, default="black"
        Color of the nodes and edges that are not selected.
    **kwargs
        Forwarded to the node and edge plotting helpers.

    Returns
    -------
    plt.Figure
        The matplotlib figure
    plt.Axes
        The matplotlib ax
    """
    if ax is not None:
        # Reuse the caller's axes, wiping whatever was drawn on them.
        ax.clear()
    else:
        _, ax = plt.subplots()
    nodes_to_highlight = (
        set() if selected_nodes is None else set(selected_nodes)
    )
    edges_to_highlight = (
        set() if selected_edges is None else set(selected_edges)
    )
    if size > 0:
        self.__plot_nodes(
            hier,
            nodes_to_highlight,
            color_of_nodes,
            size=size,
            ax=ax,
            default_color=default_color,
            **kwargs,
        )
    edge_color = color_of_edges or color_of_nodes
    self.__plot_edges(
        hier,
        lnks_tms,
        edges_to_highlight,
        edge_color,
        lw,
        ax,
        default_color=default_color,
        **kwargs,
    )
    # A LineCollection does not update the data limits by itself.
    ax.autoscale()
    plt.draw()
    for axis in (ax.get_yaxis(), ax.get_xaxis()):
        axis.set_visible(False)
    return ax.get_figure(), ax
1944
2418
 
1945
def _create_dict_of_plots(
    self,
    node: int | Iterable[int] | None = None,
    start_time: int | None = None,
    end_time: int | None = None,
) -> dict[int, dict]:
    """Generates a dictionary of graphs, one per mother node, built by `create_links_and_chains`.

    Parameters
    ----------
    node : int or Iterable of int, optional
        The id of the node/nodes to produce the simple graphs. If not provided
        or None, every root whose time is at or before `start_time` is used.
    start_time : int, optional
        Only used when `node` is None. Defaults to the start of the dataset (`self.t_b`).
    end_time : int, optional
        The last timepoint to be considered. Defaults to the last timepoint of
        the dataset (`self.t_e`).

    Returns
    -------
    dict mapping int to dict
        Keys are the indices 0..n-1 of the mother nodes, values are the graphs produced.
    """
    first_time = self.t_b if start_time is None else start_time
    last_time = self.t_e if end_time is None else end_time
    if node is None:
        mothers = [r for r in self.roots if self._time[r] <= first_time]
    elif isinstance(node, Iterable):
        mothers = node
    else:
        mothers = (node,)
    graphs = {}
    for index, mother in enumerate(mothers):
        graphs[index] = create_links_and_chains(
            self, mother, end_time=last_time
        )
    return graphs
1969
2461
 
1970
2462
  def plot_all_lineages(
1971
2463
  self,
1972
- nodes: list = None,
1973
- last_time_point_to_consider: int = None,
1974
- nrows=2,
1975
- figsize=(10, 15),
1976
- dpi=100,
1977
- fontsize=15,
1978
- figure=None,
1979
- axes=None,
2464
+ nodes: list | None = None,
2465
+ last_time_point_to_consider: int | None = None,
2466
+ nrows: int = 2,
2467
+ figsize: tuple[int, int] = (10, 15),
2468
+ dpi: int = 100,
2469
+ fontsize: int = 15,
2470
+ axes: plt.Axes | None = None,
2471
+ vert_gap: int = 1,
1980
2472
  **kwargs,
1981
- ):
2473
+ ) -> tuple[plt.Figure, plt.Axes, dict[plt.Axes, int]]:
1982
2474
  """Plots all lineages.
1983
2475
 
1984
- Args:
1985
- last_time_point_to_consider (int, optional): Which timepoints and upwards are the graphs to be plotted.
1986
- For example if start_time is 10, then all trees that begin
1987
- on tp 10 or before are calculated. Defaults to None, where
1988
- it will plot all the roots that exist on self.t_b.
1989
- nrows (int): How many rows of plots should be printed.
1990
- kwargs: args accepted by matplotlib
2476
+ Parameters
2477
+ ----------
2478
+ nodes : list, optional
2479
+ The nodes spawning the graphs to be plotted.
2480
+ last_time_point_to_consider : int, optional
2481
+ Which timepoints and upwards are the graphs to be plotted.
2482
+ For example if start_time is 10, then all trees that begin
2483
+ on tp 10 or before are calculated. Defaults to None, where
2484
+ it will plot all the roots that exist on `self.t_b`.
2485
+ nrows : int, default=2
2486
+ How many rows of plots should be printed.
2487
+ figsize : tuple, default=(10, 15)
2488
+ The size of the figure.
2489
+ dpi : int, default=100
2490
+ The dpi of the figure.
2491
+ fontsize : int, default=15
2492
+ The fontsize of the labels.
2493
+ axes : plt.Axes, optional
2494
+ The axes to plot the graphs on. If None or not provided new axes are going to be created.
2495
+ vert_gap : int, default=1
2496
+ space between the nodes, defaults to 1
2497
+ **kwargs:
2498
+ kwargs accepted by matplotlib.pyplot.plot, matplotlib.pyplot.scatter
2499
+
2500
+ Returns
2501
+ -------
2502
+ plt.Figure
2503
+ The figure
2504
+ plt.Axes
2505
+ The axes
2506
+ dict of plt.Axes to int
2507
+ A dictionary that maps the axes to the root of the tree.
1991
2508
  """
1992
-
1993
2509
  nrows = int(nrows)
1994
2510
  if last_time_point_to_consider is None:
1995
2511
  last_time_point_to_consider = self.t_b
@@ -1998,22 +2514,33 @@ class lineageTree(lineageTreeLoaders):
1998
2514
  raise Warning("Number of rows has to be at least 1")
1999
2515
  if nodes:
2000
2516
  graphs = {
2001
- i: self.to_simple_graph(node) for i, node in enumerate(nodes)
2517
+ i: self._create_dict_of_plots(node)
2518
+ for i, node in enumerate(nodes)
2002
2519
  }
2003
2520
  else:
2004
- graphs = self.to_simple_graph(
2521
+ graphs = self._create_dict_of_plots(
2005
2522
  start_time=last_time_point_to_consider
2006
2523
  )
2007
2524
  pos = {
2008
2525
  i: hierarchical_pos(
2009
- g, g["root"], ycenter=-int(self.time[g["root"]])
2526
+ g,
2527
+ g["root"],
2528
+ ycenter=-int(self._time[g["root"]]),
2529
+ vert_gap=vert_gap,
2010
2530
  )
2011
2531
  for i, g in graphs.items()
2012
2532
  }
2013
- ncols = int(len(graphs) // nrows) + (+np.sign(len(graphs) % nrows))
2014
- figure, axes = plt.subplots(
2015
- figsize=figsize, nrows=nrows, ncols=ncols, dpi=dpi, sharey=True
2016
- )
2533
+ if axes is None:
2534
+ ncols = int(len(graphs) // nrows) + (+np.sign(len(graphs) % nrows))
2535
+ figure, axes = plt.subplots(
2536
+ figsize=figsize, nrows=nrows, ncols=ncols, dpi=dpi, sharey=True
2537
+ )
2538
+ else:
2539
+ figure, axes = axes.flatten()[0].get_figure(), axes
2540
+ if len(axes.flatten()) < len(graphs):
2541
+ raise Exception(
2542
+ f"Not enough axes, they should be at least {len(graphs)}."
2543
+ )
2017
2544
  flat_axes = axes.flatten()
2018
2545
  ax2root = {}
2019
2546
  min_width, min_height = float("inf"), float("inf")
@@ -2051,127 +2578,149 @@ class lineageTree(lineageTreeLoaders):
2051
2578
  },
2052
2579
  )
2053
2580
  [figure.delaxes(ax) for ax in axes.flatten() if not ax.has_data()]
2054
- return figure, axes, ax2root
2581
+ return axes.flatten()[0].get_figure(), axes, ax2root
2055
2582
 
2056
- def plot_node(self, node, figsize=(4, 7), dpi=150, vert_gap=2, **kwargs):
2583
+ def plot_subtree(
2584
+ self,
2585
+ node: int,
2586
+ end_time: int | None = None,
2587
+ figsize: tuple[int, int] = (4, 7),
2588
+ dpi: int = 150,
2589
+ vert_gap: int = 2,
2590
+ selected_nodes: list | None = None,
2591
+ selected_edges: list | None = None,
2592
+ color_of_nodes: str | dict = "magenta",
2593
+ color_of_edges: str | dict = "magenta",
2594
+ size: int | float = 10,
2595
+ lw: float = 0.1,
2596
+ default_color: str = "black",
2597
+ ax: plt.Axes | None = None,
2598
+ ) -> tuple[plt.Figure, plt.Axes]:
2057
2599
  """Plots the subtree spawn by a node.
2058
2600
 
2059
- Args:
2060
- node (int): The id of the node that is going to be plotted.
2061
- kwargs: args accepted by matplotlib
2062
- """
2063
- graph = self.to_simple_graph(node)
2601
+ Parameters
2602
+ ----------
2603
+ node : int
2604
+ The id of the node that is going to be plotted.
2605
+ end_time : int, optional
2606
+ The last timepoint to be considered, if None or not provided the last timepoint of the dataset (t_e) is considered.
2607
+ figsize : tuple of 2 ints, default=(4,7)
2608
+ The size of the figure, deafults to (4,7)
2609
+ vert_gap : int, default=2
2610
+ The verical gap of a node when it divides, defaults to 2.
2611
+ dpi : int, default=150
2612
+ The dpi of the figure, defaults to 150
2613
+ selected_nodes : list, optional
2614
+ The nodes that are selected by the user to be colored in a different color, defaults to None
2615
+ selected_edges : list, optional
2616
+ The edges that are selected by the user to be colored in a different color, defaults to None
2617
+ color_of_nodes : str, default="magenta"
2618
+ The color of the nodes to be colored, except the default colored ones, defaults to "magenta"
2619
+ color_of_edges : str, default="magenta"
2620
+ The color of the edges to be colored, except the default colored ones, defaults to "magenta"
2621
+ size : int, default=10
2622
+ The size of the nodes, defaults to 10
2623
+ lw : float, default=0.1
2624
+ The widthe of the edges of the tree graph, defaults to 0.1
2625
+ default_color : str, default="black"
2626
+ The default color of nodes and edges, defaults to "black"
2627
+ ax : plt.Axes, optional
2628
+ The ax where the plot is going to be applied, if not provided or None new axes will be created.
2629
+
2630
+ Returns
2631
+ -------
2632
+ plt.Figure
2633
+ The matplotlib figure
2634
+ plt.Axes
2635
+ The matplotlib axes
2636
+
2637
+ Raises
2638
+ ------
2639
+ Warning
2640
+ If more than one nodes are received
2641
+ """
2642
+ graph = self._create_dict_of_plots(node, end_time=end_time)
2064
2643
  if len(graph) > 1:
2065
- raise Warning("Please enter only one node")
2644
+ raise Warning(
2645
+ "Please use lT.plot_all_lineages(nodes) for plotting multiple nodes."
2646
+ )
2066
2647
  graph = graph[0]
2067
- figure, ax = plt.subplots(nrows=1, ncols=1, figsize=figsize, dpi=dpi)
2648
+ if not ax:
2649
+ _, ax = plt.subplots(nrows=1, ncols=1, figsize=figsize, dpi=dpi)
2068
2650
  self.draw_tree_graph(
2069
2651
  hier=hierarchical_pos(
2070
2652
  graph,
2071
2653
  graph["root"],
2072
2654
  vert_gap=vert_gap,
2073
- ycenter=-int(self.time[node]),
2655
+ ycenter=-int(self._time[node]),
2074
2656
  ),
2657
+ selected_edges=selected_edges,
2658
+ selected_nodes=selected_nodes,
2659
+ color_of_edges=color_of_edges,
2660
+ color_of_nodes=color_of_nodes,
2661
+ default_color=default_color,
2662
+ size=size,
2663
+ lw=lw,
2075
2664
  lnks_tms=graph,
2076
2665
  ax=ax,
2077
- **kwargs,
2078
2666
  )
2079
- return figure, ax
2080
-
2081
- # def DTW(self, t1, t2, max_w=None, start_delay=None, end_delay=None,
2082
- # metric='euclidian', **kwargs):
2083
- # """ Computes the dynamic time warping distance between the tracks t1 and t2
2084
-
2085
- # Args:
2086
- # t1 ([int, ]): list of node ids for the first track
2087
- # t2 ([int, ]): list of node ids for the second track
2088
- # w (int): maximum wapring allowed (default infinite),
2089
- # if w=1 then the DTW is the distance between t1 and t2
2090
- # start_delay (int): maximum number of time points that can be
2091
- # skipped at the beginning of the track
2092
- # end_delay (int): minimum number of time points that can be
2093
- # skipped at the beginning of the track
2094
- # metric (str): str or callable, optional The distance metric to use.
2095
- # Default='euclidean'. Refer to the documentation for
2096
- # scipy.spatial.distance.cdist. Some examples:
2097
- # 'braycurtis', 'canberra', 'chebyshev', 'cityblock', 'correlation',
2098
- # 'cosine', 'dice', 'euclidean', 'hamming', 'jaccard', 'kulsinski',
2099
- # 'mahalanobis', 'matching', 'minkowski', 'rogerstanimoto', 'russellrao',
2100
- # 'seuclidean', 'sokalmichener', 'sokalsneath', 'sqeuclidean',
2101
- # 'wminkowski', 'yule'
2102
- # **kwargs (dict): Extra arguments to `metric`: refer to each metric
2103
- # documentation in scipy.spatial.distance (optional)
2104
-
2105
- # Returns:
2106
- # float: the dynamic time warping distance between the two tracks
2107
- # """
2108
- # from scipy.sparse import
2109
- # pos_t1 = [self.pos[ti] for ti in t1]
2110
- # pos_t2 = [self.pos[ti] for ti in t2]
2111
- # distance_matrix = np.zeros((len(t1), len(t2))) + np.inf
2112
-
2113
- # c = distance.cdist(exp_data, num_data, metric=metric, **kwargs)
2114
-
2115
- # d = np.zeros(c.shape)
2116
- # d[0, 0] = c[0, 0]
2117
- # n, m = c.shape
2118
- # for i in range(1, n):
2119
- # d[i, 0] = d[i-1, 0] + c[i, 0]
2120
- # for j in range(1, m):
2121
- # d[0, j] = d[0, j-1] + c[0, j]
2122
- # for i in range(1, n):
2123
- # for j in range(1, m):
2124
- # d[i, j] = c[i, j] + min((d[i-1, j], d[i, j-1], d[i-1, j-1]))
2125
- # return d[-1, -1], d
2126
-
2127
- def __getitem__(self, item):
2128
- if isinstance(item, str):
2129
- return self.__dict__[item]
2130
- elif np.issubdtype(type(item), np.integer):
2131
- return self.successor.get(item, [])
2132
- else:
2133
- raise KeyError(
2134
- "Only integer or string are valid key for lineageTree"
2135
- )
2667
+ return ax.get_figure(), ax
2136
2668
 
2137
- def get_cells_at_t_from_root(self, r: [int, list], t: int = None) -> list:
2669
+ def nodes_at_t(
2670
+ self,
2671
+ t: int,
2672
+ r: int | Iterable[int] | None = None,
2673
+ ) -> list:
2138
2674
  """
2139
- Returns the list of cells at time `t` that are spawn by the node(s) `r`.
2675
+ Returns the list of nodes at time `t` that are spawn by the node(s) `r`.
2140
2676
 
2141
- Args:
2142
- r (int | list): id or list of ids of the spawning node
2143
- t (int): target time, if None goes as far as possible
2144
- (default None)
2677
+ Parameters
2678
+ ----------
2679
+ t : int
2680
+ target time, if `None` goes as far as possible
2681
+ r : int or Iterable of int, optional
2682
+ id or list of ids of the spawning node
2145
2683
 
2146
- Returns:
2147
- (list) list of nodes at time `t` spawned by `r`
2684
+ Returns
2685
+ -------
2686
+ list
2687
+ list of nodes at time `t` spawned by `r`
2148
2688
  """
2149
- if not isinstance(r, list):
2689
+ if not r and r != 0:
2690
+ r = {root for root in self.roots if self.time[root] <= t}
2691
+ if isinstance(r, int):
2150
2692
  r = [r]
2693
+ if t is None:
2694
+ t = self.t_e
2151
2695
  to_do = list(r)
2152
2696
  final_nodes = []
2153
2697
  while len(to_do) > 0:
2154
2698
  curr = to_do.pop()
2155
- for _next in self[curr]:
2156
- if self.time[_next] < t:
2699
+ for _next in self._successor[curr]:
2700
+ if self._time[_next] < t:
2157
2701
  to_do.append(_next)
2158
- elif self.time[_next] == t:
2702
+ elif self._time[_next] == t:
2159
2703
  final_nodes.append(_next)
2160
2704
  if not final_nodes:
2161
2705
  return list(r)
2162
2706
  return final_nodes
2163
2707
 
2164
2708
  @staticmethod
2165
- def __calculate_diag_line(dist_mat: np.ndarray) -> (float, float):
2709
+ def __calculate_diag_line(dist_mat: np.ndarray) -> tuple[float, float]:
2166
2710
  """
2167
2711
  Calculate the line that centers the band w.
2168
2712
 
2169
- Args:
2170
- dist_mat (matrix): distance matrix obtained by the function calculate_dtw
2713
+ Parameters
2714
+ ----------
2715
+ dist_mat : np.ndarray
2716
+ distance matrix obtained by the function calculate_dtw
2171
2717
 
2172
- Returns:
2173
- (float) Slope
2174
- (float) intercept of the line
2718
+ Returns
2719
+ -------
2720
+ float
2721
+ The slope of the curve
2722
+ float
2723
+ The intercept of the curve
2175
2724
  """
2176
2725
  i, j = dist_mat.shape
2177
2726
  x1 = max(0, i - j) / 2
@@ -2191,22 +2740,33 @@ class lineageTree(lineageTreeLoaders):
2191
2740
  fast: bool = False,
2192
2741
  w: int = 0,
2193
2742
  centered_band: bool = True,
2194
- ) -> (((int, int), ...), np.ndarray):
2743
+ ) -> tuple[list[int], np.ndarray, float]:
2195
2744
  """
2196
2745
  Find DTW minimum cost between two series using dynamic programming.
2197
2746
 
2198
- Args:
2199
- dist_mat (matrix): distance matrix obtained by the function calculate_dtw
2200
- start_d (int): start delay
2201
- back_d (int): end delay
2202
- w (int): window constrain
2203
- slope (float): to calculate window - givem by the function __calculate_diag_line
2204
- intercept (flost): to calculate window - givem by the function __calculate_diag_line
2205
- use_absolute (boolean): if the window constraing is calculate by the absolute difference between points (uncentered)
2206
-
2207
- Returns:
2208
- (tuple of tuples) Aligment path
2209
- (matrix) Cost matrix
2747
+ Parameters
2748
+ ----------
2749
+ dist_mat : np.ndarray
2750
+ distance matrix obtained by the function calculate_dtw
2751
+ start_d : int, default=0
2752
+ start delay
2753
+ back_d : int, default=0
2754
+ end delay
2755
+ fast : bool, default=False
2756
+ if `True`, the algorithm will use a faster version but might not find the optimal alignment
2757
+ w : int, default=0
2758
+ window constrain
2759
+ centered_band : bool, default=True
2760
+ if `True`, the band will be centered around the diagonal
2761
+
2762
+ Returns
2763
+ -------
2764
+ tuple of tuples of int
2765
+ Aligment path
2766
+ np.ndarray
2767
+ cost matrix
2768
+ float
2769
+ optimal cost
2210
2770
  """
2211
2771
  N, M = dist_mat.shape
2212
2772
  w_limit = max(w, abs(N - M)) # Calculate the Sakoe-Chiba band width
@@ -2327,7 +2887,6 @@ class lineageTree(lineageTreeLoaders):
2327
2887
 
2328
2888
  # special reflection case
2329
2889
  if np.linalg.det(R) < 0:
2330
- # print("det(R) < R, reflection detected!, correcting for it ...")
2331
2890
  Vt[2, :] *= -1
2332
2891
  R = Vt.T @ U.T
2333
2892
 
@@ -2336,52 +2895,59 @@ class lineageTree(lineageTreeLoaders):
2336
2895
  return R, t
2337
2896
 
2338
2897
  def __interpolate(
2339
- self, track1: list, track2: list, threshold: int
2340
- ) -> (np.ndarray, np.ndarray):
2898
+ self, chain1: list, chain2: list, threshold: int
2899
+ ) -> tuple[np.ndarray, np.ndarray]:
2341
2900
  """
2342
2901
  Interpolate two series that have different lengths
2343
2902
 
2344
- Args:
2345
- track1 (list): list of nodes of the first cell cycle to compare
2346
- track2 (list): list of nodes of the second cell cycle to compare
2347
- threshold (int): set a maximum number of points a track can have
2348
-
2349
- Returns:
2350
- (list of list) x, y, z postions for track1
2351
- (list of list) x, y, z postions for track2
2903
+ Parameters
2904
+ ----------
2905
+ chain1 : list of int
2906
+ list of nodes of the first chain to compare
2907
+ chain2 : list of int
2908
+ list of nodes of the second chain to compare
2909
+ threshold : int
2910
+ set a maximum number of points a chain can have
2911
+
2912
+ Returns
2913
+ -------
2914
+ list of np.ndarray
2915
+ `x`, `y`, `z` postions for `chain1`
2916
+ list of np.ndarray
2917
+ `x`, `y`, `z` postions for `chain2`
2352
2918
  """
2353
2919
  inter1_pos = []
2354
2920
  inter2_pos = []
2355
2921
 
2356
- track1_pos = np.array([self.pos[c_id] for c_id in track1])
2357
- track2_pos = np.array([self.pos[c_id] for c_id in track2])
2922
+ chain1_pos = np.array([self.pos[c_id] for c_id in chain1])
2923
+ chain2_pos = np.array([self.pos[c_id] for c_id in chain2])
2358
2924
 
2359
- # Both tracks have the same length and size below the threshold - nothing is done
2360
- if len(track1) == len(track2) and (
2361
- len(track1) <= threshold or len(track2) <= threshold
2925
+ # Both chains have the same length and size below the threshold - nothing is done
2926
+ if len(chain1) == len(chain2) and (
2927
+ len(chain1) <= threshold or len(chain2) <= threshold
2362
2928
  ):
2363
- return track1_pos, track2_pos
2364
- # Both tracks have the same length but one or more sizes are above the threshold
2365
- elif len(track1) > threshold or len(track2) > threshold:
2929
+ return chain1_pos, chain2_pos
2930
+ # Both chains have the same length but one or more sizes are above the threshold
2931
+ elif len(chain1) > threshold or len(chain2) > threshold:
2366
2932
  sampling = threshold
2367
- # Tracks have different lengths and the sizes are below the threshold
2933
+ # chains have different lengths and the sizes are below the threshold
2368
2934
  else:
2369
- sampling = max(len(track1), len(track2))
2935
+ sampling = max(len(chain1), len(chain2))
2370
2936
 
2371
2937
  for pos in range(3):
2372
- track1_interp = InterpolatedUnivariateSpline(
2373
- np.linspace(0, 1, len(track1_pos[:, pos])),
2374
- track1_pos[:, pos],
2938
+ chain1_interp = InterpolatedUnivariateSpline(
2939
+ np.linspace(0, 1, len(chain1_pos[:, pos])),
2940
+ chain1_pos[:, pos],
2375
2941
  k=1,
2376
2942
  )
2377
- inter1_pos.append(track1_interp(np.linspace(0, 1, sampling)))
2943
+ inter1_pos.append(chain1_interp(np.linspace(0, 1, sampling)))
2378
2944
 
2379
- track2_interp = InterpolatedUnivariateSpline(
2380
- np.linspace(0, 1, len(track2_pos[:, pos])),
2381
- track2_pos[:, pos],
2945
+ chain2_interp = InterpolatedUnivariateSpline(
2946
+ np.linspace(0, 1, len(chain2_pos[:, pos])),
2947
+ chain2_pos[:, pos],
2382
2948
  k=1,
2383
2949
  )
2384
- inter2_pos.append(track2_interp(np.linspace(0, 1, sampling)))
2950
+ inter2_pos.append(chain2_interp(np.linspace(0, 1, sampling)))
2385
2951
 
2386
2952
  return np.column_stack(inter1_pos), np.column_stack(inter2_pos)
2387
2953
 
@@ -2397,46 +2963,66 @@ class lineageTree(lineageTreeLoaders):
2397
2963
  w: int = 0,
2398
2964
  centered_band: bool = True,
2399
2965
  cost_mat_p: bool = False,
2400
- ) -> (float, tuple, np.ndarray, np.ndarray, np.ndarray):
2401
- """
2402
- Calculate DTW distance between two cell cycles
2403
-
2404
- Args:
2405
- nodes1 (int): node to compare distance
2406
- nodes2 (int): node to compare distance
2407
- threshold: set a maximum number of points a track can have
2408
- regist (boolean): Rotate and translate trajectories
2409
- start_d (int): start delay
2410
- back_d (int): end delay
2411
- fast (boolean): True if the user wants to run the fast algorithm with window restrains
2412
- w (int): window size
2413
- centered_band (boolean): if running the fast algorithm, True if the windown is centered
2414
- cost_mat_p (boolean): True if print the not normalized cost matrix
2415
-
2416
- Returns:
2417
- (float) DTW distance
2418
- (tuple of tuples) Aligment path
2419
- (matrix) Cost matrix
2420
- (list of lists) pos_cycle1: rotated and translated trajectories positions
2421
- (list of lists) pos_cycle2: rotated and translated trajectories positions
2966
+ ) -> (
2967
+ tuple[float, tuple, np.ndarray, np.ndarray, np.ndarray]
2968
+ | tuple[float, tuple]
2969
+ ):
2422
2970
  """
2423
- nodes1_cycle = self.get_cycle(nodes1)
2424
- nodes2_cycle = self.get_cycle(nodes2)
2425
-
2426
- interp_cycle1, interp_cycle2 = self.__interpolate(
2427
- nodes1_cycle, nodes2_cycle, threshold
2971
+ Calculate DTW distance between two chains
2972
+
2973
+ Parameters
2974
+ ----------
2975
+ nodes1 : int
2976
+ node to compare distance
2977
+ nodes2 : int
2978
+ node to compare distance
2979
+ threshold : int, default=1000
2980
+ set a maximum number of points a chain can have
2981
+ regist : bool, default=True
2982
+ Rotate and translate trajectories
2983
+ start_d : int, default=0
2984
+ start delay
2985
+ back_d : int, default=0
2986
+ end delay
2987
+ fast : bool, default=False
2988
+ if `True`, the algorithm will use a faster version but might not find the optimal alignment
2989
+ w : int, default=0
2990
+ window size
2991
+ centered_band : bool, default=True
2992
+ when running the fast algorithm, `True` if the windown is centered
2993
+ cost_mat_p : bool, default=False
2994
+ True if print the not normalized cost matrix
2995
+
2996
+ Returns
2997
+ -------
2998
+ float
2999
+ DTW distance
3000
+ tuple of tuples
3001
+ Aligment path
3002
+ matrix
3003
+ Cost matrix
3004
+ list of lists
3005
+ rotated and translated trajectories positions
3006
+ list of lists
3007
+ rotated and translated trajectories positions
3008
+ """
3009
+ nodes1_chain = self.get_chain_of_node(nodes1)
3010
+ nodes2_chain = self.get_chain_of_node(nodes2)
3011
+
3012
+ interp_chain1, interp_chain2 = self.__interpolate(
3013
+ nodes1_chain, nodes2_chain, threshold
2428
3014
  )
2429
3015
 
2430
- pos_cycle1 = np.array([self.pos[c_id] for c_id in nodes1_cycle])
2431
- pos_cycle2 = np.array([self.pos[c_id] for c_id in nodes2_cycle])
3016
+ pos_chain1 = np.array([self.pos[c_id] for c_id in nodes1_chain])
3017
+ pos_chain2 = np.array([self.pos[c_id] for c_id in nodes2_chain])
2432
3018
 
2433
3019
  if regist:
2434
3020
  R, t = self.__rigid_transform_3D(
2435
- np.transpose(interp_cycle1), np.transpose(interp_cycle2)
3021
+ np.transpose(interp_chain1), np.transpose(interp_chain2)
2436
3022
  )
2437
- pos_cycle1 = np.transpose(np.dot(R, pos_cycle1.T) + t)
3023
+ pos_chain1 = np.transpose(np.dot(R, pos_chain1.T) + t)
2438
3024
 
2439
- dist_mat = distance.cdist(pos_cycle1, pos_cycle2, "euclidean")
3025
+ dist_mat = distance.cdist(pos_chain1, pos_chain2, "euclidean")
2440
3026
 
2441
3027
  path, cost_mat, final_cost = self.__dp(
2442
3028
  dist_mat,
@@ -2449,7 +3035,7 @@ class lineageTree(lineageTreeLoaders):
2449
3035
  cost = final_cost / len(path)
2450
3036
 
2451
3037
  if cost_mat_p:
2452
- return cost, path, cost_mat, pos_cycle1, pos_cycle2
3038
+ return cost, path, cost_mat, pos_chain1, pos_chain2
2453
3039
  else:
2454
3040
  return cost, path
2455
3041
 
@@ -2464,24 +3050,39 @@ class lineageTree(lineageTreeLoaders):
2464
3050
  fast: bool = False,
2465
3051
  w: int = 0,
2466
3052
  centered_band: bool = True,
2467
- ) -> (float, plt.figure):
2468
- """
2469
- Plot DTW cost matrix between two cell cycles in heatmap format
2470
-
2471
- Args:
2472
- nodes1 (int): node to compare distance
2473
- nodes2 (int): node to compare distance
2474
- start_d (int): start delay
2475
- back_d (int): end delay
2476
- fast (boolean): True if the user wants to run the fast algorithm with window restrains
2477
- w (int): window size
2478
- centered_band (boolean): if running the fast algorithm, True if the windown is centered
2479
-
2480
- Returns:
2481
- (float) DTW distance
2482
- (figure) Heatmap of cost matrix with opitimal path
2483
- """
2484
- cost, path, cost_mat, pos_cycle1, pos_cycle2 = self.calculate_dtw(
3053
+ ) -> tuple[float, plt.Figure]:
3054
+ """
3055
+ Plot DTW cost matrix between two chains in heatmap format
3056
+
3057
+ Parameters
3058
+ ----------
3059
+ nodes1 : int
3060
+ node to compare distance
3061
+ nodes2 : int
3062
+ node to compare distance
3063
+ threshold : int, default=1000
3064
+ set a maximum number of points a chain can have
3065
+ regist : bool, default=True
3066
+ Rotate and translate trajectories
3067
+ start_d : int, default=0
3068
+ start delay
3069
+ back_d : int, default=0
3070
+ end delay
3071
+ fast : bool, default=False
3072
+ if `True`, the algorithm will use a faster version but might not find the optimal alignment
3073
+ w : int, default=0
3074
+ window size
3075
+ centered_band : bool, default=True
3076
+ when running the fast algorithm, `True` if the windown is centered
3077
+
3078
+ Returns
3079
+ -------
3080
+ float
3081
+ DTW distance
3082
+ plt.Figure
3083
+ Heatmap of cost matrix with opitimal path
3084
+ """
3085
+ cost, path, cost_mat, pos_chain1, pos_chain2 = self.calculate_dtw(
2485
3086
  nodes1,
2486
3087
  nodes2,
2487
3088
  threshold,
@@ -2503,32 +3104,32 @@ class lineageTree(lineageTreeLoaders):
2503
3104
  ax.set_title("Heatmap of DTW Cost Matrix")
2504
3105
  ax.set_xlabel("Tree 1")
2505
3106
  ax.set_ylabel("tree 2")
2506
- x_path, y_path = zip(*path)
3107
+ x_path, y_path = zip(*path, strict=True)
2507
3108
  ax.plot(y_path, x_path, color="black")
2508
3109
 
2509
3110
  return cost, fig
2510
3111
 
2511
3112
  @staticmethod
2512
3113
  def __plot_2d(
2513
- pos_cycle1,
2514
- pos_cycle2,
2515
- nodes1,
2516
- nodes2,
2517
- ax,
2518
- x_idx,
2519
- y_idx,
2520
- x_label,
2521
- y_label,
2522
- ):
3114
+ pos_chain1: np.ndarray,
3115
+ pos_chain2: np.ndarray,
3116
+ nodes1: list[int],
3117
+ nodes2: list[int],
3118
+ ax: plt.Axes,
3119
+ x_idx: list[int],
3120
+ y_idx: list[int],
3121
+ x_label: str,
3122
+ y_label: str,
3123
+ ) -> None:
2523
3124
  ax.plot(
2524
- pos_cycle1[:, x_idx],
2525
- pos_cycle1[:, y_idx],
3125
+ pos_chain1[:, x_idx],
3126
+ pos_chain1[:, y_idx],
2526
3127
  "-",
2527
3128
  label=f"root = {nodes1}",
2528
3129
  )
2529
3130
  ax.plot(
2530
- pos_cycle2[:, x_idx],
2531
- pos_cycle2[:, y_idx],
3131
+ pos_chain2[:, x_idx],
3132
+ pos_chain2[:, y_idx],
2532
3133
  "-",
2533
3134
  label=f"root = {nodes2}",
2534
3135
  )
@@ -2546,40 +3147,55 @@ class lineageTree(lineageTreeLoaders):
2546
3147
  fast: bool = False,
2547
3148
  w: int = 0,
2548
3149
  centered_band: bool = True,
2549
- projection: str = None,
3150
+ projection: Literal["3d", "xy", "xz", "yz", "pca", None] = None,
2550
3151
  alig: bool = False,
2551
- ) -> (float, plt.figure):
2552
- """
2553
- Plots DTW trajectories aligment between two cell cycles in 2D or 3D
2554
-
2555
- Args:
2556
- nodes1 (int): node to compare distance
2557
- nodes2 (int): node to compare distance
2558
- threshold (int): set a maximum number of points a track can have
2559
- regist (boolean): Rotate and translate trajectories
2560
- start_d (int): start delay
2561
- back_d (int): end delay
2562
- w (int): window size
2563
- fast (boolean): True if the user wants to run the fast algorithm with window restrains
2564
- centered_band (boolean): if running the fast algorithm, True if the windown is centered
2565
- projection (string): specify which 2D to plot ->
2566
- '3d' : for the 3d visualization
2567
- 'xy' or None (default) : 2D projection of axis x and y
2568
- 'xz' : 2D projection of axis x and z
2569
- 'yz' : 2D projection of axis y and z
2570
- 'pca' : PCA projection
2571
- alig (boolean): True to show alignment on plot
2572
-
2573
- Returns:
2574
- (float) DTW distance
2575
- (figue) Trajectories Plot
3152
+ ) -> tuple[float, plt.Figure]:
3153
+ """
3154
+ Plots DTW trajectories aligment between two chains in 2D or 3D
3155
+
3156
+ Parameters
3157
+ ----------
3158
+ nodes1 : int
3159
+ node to compare distance
3160
+ nodes2 : int
3161
+ node to compare distance
3162
+ threshold : int, default=1000
3163
+ set a maximum number of points a chain can have
3164
+ regist : bool, default=True
3165
+ Rotate and translate trajectories
3166
+ start_d : int, default=0
3167
+ start delay
3168
+ back_d : int, default=0
3169
+ end delay
3170
+ w : int, default=0
3171
+ window size
3172
+ fast : bool, default=False
3173
+ True if the user wants to run the fast algorithm with window restrains
3174
+ centered_band : bool, default=True
3175
+ if running the fast algorithm, True if the windown is centered
3176
+ projection : {"3d", "xy", "xz", "yz", "pca"}, optional
3177
+ specify which 2D to plot ->
3178
+ "3d" : for the 3d visualization
3179
+ "xy" or None (default) : 2D projection of axis x and y
3180
+ "xz" : 2D projection of axis x and z
3181
+ "yz" : 2D projection of axis y and z
3182
+ "pca" : PCA projection
3183
+ alig : bool
3184
+ True to show alignment on plot
3185
+
3186
+ Returns
3187
+ -------
3188
+ float
3189
+ DTW distance
3190
+ figure
3191
+ Trajectories Plot
2576
3192
  """
2577
3193
  (
2578
3194
  distance,
2579
3195
  alignment,
2580
3196
  cost_mat,
2581
- pos_cycle1,
2582
- pos_cycle2,
3197
+ pos_chain1,
3198
+ pos_chain2,
2583
3199
  ) = self.calculate_dtw(
2584
3200
  nodes1,
2585
3201
  nodes2,
@@ -2602,16 +3218,16 @@ class lineageTree(lineageTreeLoaders):
2602
3218
 
2603
3219
  if projection == "3d":
2604
3220
  ax.plot(
2605
- pos_cycle1[:, 0],
2606
- pos_cycle1[:, 1],
2607
- pos_cycle1[:, 2],
3221
+ pos_chain1[:, 0],
3222
+ pos_chain1[:, 1],
3223
+ pos_chain1[:, 2],
2608
3224
  "-",
2609
3225
  label=f"root = {nodes1}",
2610
3226
  )
2611
3227
  ax.plot(
2612
- pos_cycle2[:, 0],
2613
- pos_cycle2[:, 1],
2614
- pos_cycle2[:, 2],
3228
+ pos_chain2[:, 0],
3229
+ pos_chain2[:, 1],
3230
+ pos_chain2[:, 2],
2615
3231
  "-",
2616
3232
  label=f"root = {nodes2}",
2617
3233
  )
@@ -2621,8 +3237,8 @@ class lineageTree(lineageTreeLoaders):
2621
3237
  else:
2622
3238
  if projection == "xy" or projection == "yx" or projection is None:
2623
3239
  self.__plot_2d(
2624
- pos_cycle1,
2625
- pos_cycle2,
3240
+ pos_chain1,
3241
+ pos_chain2,
2626
3242
  nodes1,
2627
3243
  nodes2,
2628
3244
  ax,
@@ -2633,8 +3249,8 @@ class lineageTree(lineageTreeLoaders):
2633
3249
  )
2634
3250
  elif projection == "xz" or projection == "zx":
2635
3251
  self.__plot_2d(
2636
- pos_cycle1,
2637
- pos_cycle2,
3252
+ pos_chain1,
3253
+ pos_chain2,
2638
3254
  nodes1,
2639
3255
  nodes2,
2640
3256
  ax,
@@ -2645,8 +3261,8 @@ class lineageTree(lineageTreeLoaders):
2645
3261
  )
2646
3262
  elif projection == "yz" or projection == "zy":
2647
3263
  self.__plot_2d(
2648
- pos_cycle1,
2649
- pos_cycle2,
3264
+ pos_chain1,
3265
+ pos_chain2,
2650
3266
  nodes1,
2651
3267
  nodes2,
2652
3268
  ax,
@@ -2660,24 +3276,25 @@ class lineageTree(lineageTreeLoaders):
2660
3276
  from sklearn.decomposition import PCA
2661
3277
  except ImportError:
2662
3278
  Warning(
2663
- "scikit-learn is not installed, the PCA orientation cannot be used. You can install scikit-learn with pip install"
3279
+ "scikit-learn is not installed, the PCA orientation cannot be used."
3280
+ "You can install scikit-learn with pip install"
2664
3281
  )
2665
3282
 
2666
3283
  # Apply PCA
2667
3284
  pca = PCA(n_components=2)
2668
- pca.fit(np.vstack([pos_cycle1, pos_cycle2]))
2669
- pos_cycle1_2d = pca.transform(pos_cycle1)
2670
- pos_cycle2_2d = pca.transform(pos_cycle2)
3285
+ pca.fit(np.vstack([pos_chain1, pos_chain2]))
3286
+ pos_chain1_2d = pca.transform(pos_chain1)
3287
+ pos_chain2_2d = pca.transform(pos_chain2)
2671
3288
 
2672
3289
  ax.plot(
2673
- pos_cycle1_2d[:, 0],
2674
- pos_cycle1_2d[:, 1],
3290
+ pos_chain1_2d[:, 0],
3291
+ pos_chain1_2d[:, 1],
2675
3292
  "-",
2676
3293
  label=f"root = {nodes1}",
2677
3294
  )
2678
3295
  ax.plot(
2679
- pos_cycle2_2d[:, 0],
2680
- pos_cycle2_2d[:, 1],
3296
+ pos_chain2_2d[:, 0],
3297
+ pos_chain2_2d[:, 1],
2681
3298
  "-",
2682
3299
  label=f"root = {nodes2}",
2683
3300
  )
@@ -2706,7 +3323,7 @@ class lineageTree(lineageTreeLoaders):
2706
3323
  'pca' : PCA projection"""
2707
3324
  )
2708
3325
 
2709
- connections = [[pos_cycle1[i], pos_cycle2[j]] for i, j in alignment]
3326
+ connections = [[pos_chain1[i], pos_chain2[j]] for i, j in alignment]
2710
3327
 
2711
3328
  for connection in connections:
2712
3329
  xyz1 = connection[0]
@@ -2729,120 +3346,179 @@ class lineageTree(lineageTreeLoaders):
2729
3346
  warnings.warn(
2730
3347
  "Error: not possible to show alignment in PCA projection !",
2731
3348
  UserWarning,
3349
+ stacklevel=2,
2732
3350
  )
2733
3351
 
2734
3352
  return distance, fig
2735
3353
 
2736
- def first_labelling(self):
2737
- self.labels = {i: "Unlabeled" for i in self.time_nodes[0]}
2738
-
2739
3354
  def __init__(
2740
3355
  self,
2741
- file_format: str = None,
2742
- tb: int = None,
2743
- te: int = None,
2744
- z_mult: float = 1.0,
2745
- file_type: str = "",
2746
- delim: str = ",",
2747
- eigen: bool = False,
2748
- shape: tuple = None,
2749
- raw_size: tuple = None,
2750
- reorder: bool = False,
2751
- xml_attributes: tuple = None,
2752
- name: str = None,
2753
- time_resolution: Union[int, None] = None,
3356
+ *,
3357
+ successor: dict[int, Sequence] | None = None,
3358
+ predecessor: dict[int, int | Sequence] | None = None,
3359
+ time: dict[int, int] | None = None,
3360
+ starting_time: int | None = None,
3361
+ pos: dict[int, Iterable] | None = None,
3362
+ name: str | None = None,
3363
+ root_leaf_value: Sequence | None = None,
3364
+ **kwargs,
2754
3365
  ):
2755
- """
2756
- TODO: complete the doc
2757
- Main library to build tree graph representation of lineage tree data
2758
- It can read TGMM, ASTEC, SVF, MaMuT and TrackMate outputs.
2759
-
2760
- Args:
2761
- file_format (str): either - path format to TGMM xmls
2762
- - path to the MaMuT xml
2763
- - path to the binary file
2764
- tb (int, optional):first time point (necessary for TGMM xmls only)
2765
- te (int, optional): last time point (necessary for TGMM xmls only)
2766
- z_mult (float, optional):z aspect ratio if necessary (usually only for TGMM xmls)
2767
- file_type (str, optional):type of input file. Accepts:
2768
- 'TGMM, 'ASTEC', MaMuT', 'TrackMate', 'csv', 'celegans', 'binary'
2769
- default is 'binary'
2770
- delim (str, optional): _description_. Defaults to ",".
2771
- eigen (bool, optional): _description_. Defaults to False.
2772
- shape (tuple, optional): _description_. Defaults to None.
2773
- raw_size (tuple, optional): _description_. Defaults to None.
2774
- reorder (bool, optional): _description_. Defaults to False.
2775
- xml_attributes (tuple, optional): _description_. Defaults to None.
2776
- name (str, optional): The name of the dataset. Defaults to None.
2777
- time_resolution (Union[int, None], optional): Time resolution in mins (If time resolution is smaller than one minute input the time in ms). Defaults to None.
2778
- """
3366
+ """Create a lineageTree object from minimal information, without reading from a file.
3367
+ Either `successor` or `predecessor` should be specified.
3368
+
3369
+ Parameters
3370
+ ----------
3371
+ successor : dict mapping int to Iterable
3372
+ Dictionary assigning nodes to their successors.
3373
+ predecessor : dict mapping int to int or Iterable
3374
+ Dictionary assigning nodes to their predecessors.
3375
+ time : dict mapping int to int, optional
3376
+ Dictionary assigning nodes to the time point they were recorded to.
3377
+ Defaults to None, in which case all times are set to `starting_time`.
3378
+ starting_time : int, optional
3379
+ Starting time of the lineage tree. Defaults to 0.
3380
+ pos : dict mapping int to Iterable, optional
3381
+ Dictionary assigning nodes to their positions. Defaults to None.
3382
+ name : str, optional
3383
+ Name of the lineage tree. Defaults to None.
3384
+ root_leaf_value : Iterable, optional
3385
+ Iterable of values of roots' predecessors and leaves' successors in the successor and predecessor dictionaries.
3386
+ Defaults are `[None, (), [], set()]`.
3387
+ **kwargs:
3388
+ Supported keyword arguments are dictionaries assigning nodes to any custom property.
3389
+ The property must be specified for every node, and named differently from lineageTree's own attributes.
3390
+ """
3391
+ self.__version__ = importlib.metadata.version("LineageTree")
3392
+ self.name = str(name) if name is not None else None
3393
+ if successor is not None and predecessor is not None:
3394
+ raise ValueError(
3395
+ "You cannot have both successors and predecessors."
3396
+ )
2779
3397
 
2780
- self.name = name
2781
- self.time_nodes = {}
2782
- self.time_edges = {}
2783
- self.max_id = -1
2784
- self.next_id = []
2785
- self.nodes = set()
2786
- self.successor = {}
2787
- self.predecessor = {}
2788
- self.pos = {}
2789
- self.time_id = {}
2790
- self.time = {}
2791
- if time_resolution is not None:
2792
- self._time_resolution = time_resolution
2793
- self.kdtrees = {}
2794
- self.spatial_density = {}
2795
- if file_type and file_format:
2796
- if xml_attributes is None:
2797
- self.xml_attributes = []
2798
- else:
2799
- self.xml_attributes = xml_attributes
2800
- file_type = file_type.lower()
2801
- if file_type == "tgmm":
2802
- self.read_tgmm_xml(file_format, tb, te, z_mult)
2803
- self.t_b = tb
2804
- self.t_e = te
2805
- elif file_type == "mamut" or file_type == "trackmate":
2806
- self.read_from_mamut_xml(file_format)
2807
- elif file_type == "celegans":
2808
- self.read_from_txt_for_celegans(file_format)
2809
- elif file_type == "celegans_cao":
2810
- self.read_from_txt_for_celegans_CAO(
2811
- file_format,
2812
- reorder=reorder,
2813
- shape=shape,
2814
- raw_size=raw_size,
3398
+ if root_leaf_value is None:
3399
+ root_leaf_value = [None, (), [], set()]
3400
+ elif not isinstance(root_leaf_value, Iterable):
3401
+ raise TypeError(
3402
+ f"root_leaf_value is of type {type(root_leaf_value)}, expected Iterable."
3403
+ )
3404
+ elif len(root_leaf_value) < 1:
3405
+ raise ValueError(
3406
+ "root_leaf_value should have at least one element."
3407
+ )
3408
+ self._successor = {}
3409
+ self._predecessor = {}
3410
+ if successor is not None:
3411
+ for pred, succs in successor.items():
3412
+ if succs in root_leaf_value:
3413
+ self._successor[pred] = ()
3414
+ else:
3415
+ if not isinstance(succs, Iterable):
3416
+ raise TypeError(
3417
+ f"Successors should be Iterable, got {type(succs)}."
3418
+ )
3419
+ if len(succs) == 0:
3420
+ raise ValueError(
3421
+ f"{succs} was not declared as a leaf but was found as a successor.\n"
3422
+ "Please lift the ambiguity."
3423
+ )
3424
+ self._successor[pred] = tuple(succs)
3425
+ for succ in succs:
3426
+ if succ in self._predecessor:
3427
+ raise ValueError(
3428
+ "Node can have at most one predecessor."
3429
+ )
3430
+ self._predecessor[succ] = (pred,)
3431
+ elif predecessor is not None:
3432
+ for succ, pred in predecessor.items():
3433
+ if pred in root_leaf_value:
3434
+ self._predecessor[succ] = ()
3435
+ else:
3436
+ if isinstance(pred, Sequence):
3437
+ if len(pred) == 0:
3438
+ raise ValueError(
3439
+ f"{pred} was not declared as a leaf but was found as a successor.\n"
3440
+ "Please lift the ambiguity."
3441
+ )
3442
+ if 1 < len(pred):
3443
+ raise ValueError(
3444
+ "Node can have at most one predecessor."
3445
+ )
3446
+ pred = pred[0]
3447
+ self._predecessor[succ] = (pred,)
3448
+ self._successor.setdefault(pred, ())
3449
+ self._successor[pred] += (succ,)
3450
+ for root in set(self._successor).difference(self._predecessor):
3451
+ self._predecessor[root] = ()
3452
+ for leaf in set(self._predecessor).difference(self._successor):
3453
+ self._successor[leaf] = ()
3454
+
3455
+ if self.__check_for_cycles():
3456
+ raise ValueError(
3457
+ "Cycles were found in the tree, there should not be any."
3458
+ )
3459
+
3460
+ if pos is None:
3461
+ self.pos = {}
3462
+ else:
3463
+ if self.nodes.difference(pos) != set():
3464
+ raise ValueError("Please provide the position of all nodes.")
3465
+ self.pos = {
3466
+ node: np.array(position) for node, position in pos.items()
3467
+ }
3468
+
3469
+ if time is None:
3470
+ if starting_time is None:
3471
+ starting_time = 0
3472
+ if not isinstance(starting_time, int):
3473
+ warnings.warn(
3474
+ f"Attribute `starting_time` was a `{type(starting_time)}`, has been casted as an `int`.",
3475
+ stacklevel=2,
2815
3476
  )
2816
- elif file_type == "mastodon":
2817
- if isinstance(file_format, list) and len(file_format) == 2:
2818
- self.read_from_mastodon_csv(file_format)
3477
+ self._time = {node: starting_time for node in self.roots}
3478
+ queue = list(self.roots)
3479
+ for node in queue:
3480
+ for succ in self._successor[node]:
3481
+ self._time[succ] = self._time[node] + 1
3482
+ queue.append(succ)
3483
+ else:
3484
+ if starting_time is not None:
3485
+ warnings.warn(
3486
+ "Both `time` and `starting_time` were provided, `starting_time` was ignored.",
3487
+ stacklevel=2,
3488
+ )
3489
+ self._time = {n: int(time[n]) for n in self.nodes}
3490
+ if self._time != time:
3491
+ if len(self._time) != len(time):
3492
+ warnings.warn(
3493
+ "The provided `time` dictionary had keys that were not nodes. "
3494
+ "They have been removed",
3495
+ stacklevel=2,
3496
+ )
2819
3497
  else:
2820
- if isinstance(file_format, list):
2821
- file_format = file_format[0]
2822
- self.read_from_mastodon(file_format, name)
2823
- elif file_type == "astec":
2824
- self.read_from_ASTEC(file_format, eigen)
2825
- elif file_type == "csv":
2826
- self.read_from_csv(file_format, z_mult, link=1, delim=delim)
2827
- elif file_format and file_format.endswith(".lT"):
2828
- with open(file_format, "br") as f:
2829
- tmp = pkl.load(f)
2830
- f.close()
2831
- self.__dict__.update(tmp.__dict__)
2832
- elif file_format is not None:
2833
- self.read_from_binary(file_format)
2834
- if self.name is None:
2835
- try:
2836
- self.name = Path(file_format).stem
2837
- except:
2838
- self.name = Path(file_format[0]).stem
2839
- if [] in self.successor.values():
2840
- successors = list(self.successor.keys())
2841
- for succ in successors:
2842
- if self[succ] == []:
2843
- self.successor.pop(succ)
2844
- if [] in self.predecessor.values():
2845
- predecessors = list(self.predecessor.keys())
2846
- for succ in predecessors:
2847
- if self[succ] == []:
2848
- self.predecessor.pop(succ)
3498
+ warnings.warn(
3499
+ "The provided `time` dictionary had values that were not `int`. "
3500
+ "These values have been truncated and converted to `int`",
3501
+ stacklevel=2,
3502
+ )
3503
+ if self.nodes.symmetric_difference(self._time) != set():
3504
+ raise ValueError(
3505
+ "Please provide the time of all nodes and only existing nodes."
3506
+ )
3507
+ if not all(
3508
+ self._time[node] < self._time[s]
3509
+ for node, succ in self._successor.items()
3510
+ for s in succ
3511
+ ):
3512
+ raise ValueError(
3513
+ "Provided times are not strictly increasing. Setting times to default."
3514
+ )
3515
+ # custom properties
3516
+ for name, d in kwargs.items():
3517
+ if name in self.__dict__:
3518
+ warnings.warn(
3519
+ f"Attribute name {name} is reserved.", stacklevel=2
3520
+ )
3521
+ continue
3522
+ setattr(self, name, d)
3523
+ if not hasattr(self, "_comparisons"):
3524
+ self._comparisons = {}