LineageTree 1.8.0__py3-none-any.whl → 2.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -2,643 +2,542 @@
2
2
  # This file is subject to the terms and conditions defined in
3
3
  # file 'LICENCE', which is part of this source code package.
4
4
  # Author: Leo Guignard (leo.guignard...@AT@...gmail.com)
5
+
6
+ from __future__ import annotations
7
+
8
+ import importlib.metadata
5
9
  import os
6
10
  import pickle as pkl
7
11
  import struct
8
12
  import warnings
9
- from collections.abc import Iterable
10
- from functools import partial
13
+ from collections.abc import Callable, Iterable, Sequence
14
+ from functools import partial, wraps
11
15
  from itertools import combinations
12
16
  from numbers import Number
13
- from pathlib import Path
14
- from typing import TextIO, Union
15
-
16
- from .loaders import lineageTreeLoaders
17
- from .tree_styles import tree_style
18
-
19
- try:
20
- from edist import uted
21
- except ImportError:
22
- warnings.warn(
23
- "No edist installed therefore you will not be able to compute the tree edit distance.",
24
- stacklevel=2,
25
- )
17
+ from types import MappingProxyType
18
+ from typing import TYPE_CHECKING, Literal
19
+
20
+ import matplotlib.colors as mcolors
26
21
  import matplotlib.pyplot as plt
27
22
  import numpy as np
23
+ import svgwrite
24
+ from edist import uted
25
+ from matplotlib import colormaps
26
+ from matplotlib.collections import LineCollection
27
+ from packaging.version import Version
28
28
  from scipy.interpolate import InterpolatedUnivariateSpline
29
- from scipy.spatial import Delaunay, distance
30
- from scipy.spatial import cKDTree as KDTree
29
+ from scipy.sparse import dok_array
30
+ from scipy.spatial import Delaunay, KDTree, distance
31
31
 
32
+ from . import __version__
33
+ from .tree_approximation import TreeApproximationTemplate, tree_style
32
34
  from .utils import (
33
- create_links_and_cycles,
35
+ convert_style_to_number,
36
+ create_links_and_chains,
34
37
  hierarchical_pos,
35
38
  )
36
39
 
40
+ if TYPE_CHECKING:
41
+ from edist.alignment import Alignment
42
+
43
+
44
+ class dynamic_property(property):
45
+ def __init__(
46
+ self, fget=None, fset=None, fdel=None, doc=None, protected_name=None
47
+ ):
48
+ super().__init__(fget, fset, fdel, doc)
49
+ self.protected_name = protected_name
50
+
51
+ def __set_name__(self, owner, name):
52
+ self.name = name
53
+ if self.protected_name is None:
54
+ self.protected_name = f"_{name}"
55
+ if not hasattr(owner, "_protected_dynamic_properties"):
56
+ owner._protected_dynamic_properties = []
57
+ owner._protected_dynamic_properties.append(self.protected_name)
58
+ if not hasattr(owner, "_dynamic_properties"):
59
+ owner._dynamic_properties = []
60
+ owner._dynamic_properties += [name, self.protected_name]
61
+ setattr(owner, self.protected_name, None)
62
+
63
+ def __get__(self, instance, owner):
64
+ if instance is None:
65
+ return self
66
+ instance._has_been_reset = False
67
+ if getattr(instance, self.protected_name) is None:
68
+ value = super().__get__(instance, owner)
69
+ setattr(instance, self.protected_name, value)
70
+ return value
71
+ else:
72
+ return getattr(instance, self.protected_name)
73
+
74
+
75
+ class lineageTree:
76
+ norm_dict = {"max": max, "sum": sum, None: lambda x: 1}
77
+
78
+ def modifier(wrapped_func):
79
+ @wraps(wrapped_func)
80
+ def raising_flag(self, *args, **kwargs):
81
+ should_reset = (
82
+ not hasattr(self, "_has_been_reset")
83
+ or not self._has_been_reset
84
+ )
85
+ out_func = wrapped_func(self, *args, **kwargs)
86
+ if should_reset:
87
+ for prop in self._protected_dynamic_properties:
88
+ self.__dict__[prop] = None
89
+ self._has_been_reset = True
90
+ return out_func
91
+
92
+ return raising_flag
93
+
94
+ def __check_cc_cycles(self, n: int) -> tuple[bool, set[int]]:
95
+ """Check if the connected component of a given node `n` has a cycle.
96
+
97
+ Returns
98
+ -------
99
+ bool
100
+ True if the tree has cycles, False otherwise.
101
+ set of int
102
+ The set of nodes that have been checked.
103
+ """
104
+ to_do = [n]
105
+ no_cycle = True
106
+ already_done = set()
107
+ while to_do and no_cycle:
108
+ current = to_do.pop(-1)
109
+ if current not in already_done:
110
+ already_done.add(current)
111
+ else:
112
+ no_cycle = False
113
+ to_do.extend(self._successor[current])
114
+ to_do = list(self._predecessor[n])
115
+ while to_do and no_cycle:
116
+ current = to_do.pop(-1)
117
+ if current not in already_done:
118
+ already_done.add(current)
119
+ else:
120
+ no_cycle = False
121
+ to_do.extend(self._predecessor[current])
122
+ return not no_cycle, already_done
123
+
124
+ def __check_for_cycles(self) -> bool:
125
+ """Check if the tree has cycles.
126
+
127
+ Returns
128
+ -------
129
+ bool
130
+ True if the tree has cycles, False otherwise.
131
+ """
132
+ to_do = set(self.nodes)
133
+ found_cycle = False
134
+ while to_do and not found_cycle:
135
+ current = to_do.pop()
136
+ found_cycle, done = self.__check_cc_cycles(current)
137
+ to_do.difference_update(done)
138
+ return found_cycle
37
139
 
38
- class lineageTree(lineageTreeLoaders):
39
- def __eq__(self, other):
140
+ def __eq__(self, other) -> bool:
40
141
  if isinstance(other, lineageTree):
41
- return other.successor == self.successor
42
- return False
142
+ return (
143
+ other._successor == self._successor
144
+ and other._predecessor == self._predecessor
145
+ and other._time == self._time
146
+ )
147
+ else:
148
+ return False
43
149
 
44
- def get_next_id(self):
45
- """Computes the next authorized id.
150
+ def get_next_id(self) -> int:
151
+ """Computes the next authorized id and assign it.
46
152
 
47
- Returns:
48
- int: next authorized id
153
+ Returns
154
+ -------
155
+ int
156
+ next authorized id
49
157
  """
50
- if self.max_id == -1 and self.nodes:
51
- self.max_id = max(self.nodes)
52
- if self.next_id == []:
158
+ if not hasattr(self, "max_id") or (self.max_id == -1 and self.nodes):
159
+ self.max_id = max(self.nodes) if len(self.nodes) else 0
160
+ if not hasattr(self, "next_id") or self.next_id == []:
53
161
  self.max_id += 1
54
162
  return self.max_id
55
163
  else:
56
164
  return self.next_id.pop()
57
165
 
58
- def complete_lineage(self, nodes: Union[int, set] = None):
59
- """Makes all leaf branches longer so that they reach the last timepoint( self.t_e), useful
60
- for tree edit distance algorithms.
61
-
62
- Args:
63
- nodes (int,set), optional): Which trees should be "completed", if None it will complete the whole dataset. Defaults to None.
64
- """
65
- if nodes is None:
66
- nodes = set(self.roots)
67
- elif isinstance(nodes, int):
68
- nodes = {nodes}
69
- for node in nodes:
70
- sub = set(self.get_sub_tree(node))
71
- specific_leaves = sub.intersection(self.leaves)
72
- for leaf in specific_leaves:
73
- self.add_branch(
74
- leaf,
75
- (self.t_e - self.time[leaf]),
76
- reverse=True,
77
- move_timepoints=True,
78
- )
79
-
80
166
  ###TODO pos can be callable and stay motionless (copy the position of the succ node, use something like optical flow)
81
- def add_branch(
167
+ @modifier
168
+ def add_chain(
82
169
  self,
83
- pred: int,
170
+ node: int,
84
171
  length: int,
85
- move_timepoints: bool = True,
86
- pos: Union[callable, None] = None,
87
- reverse: bool = False,
88
- ):
89
- """Adds a branch of specific length to a node either as a successor or as a predecessor.
172
+ downstream: bool,
173
+ pos: Callable | None = None,
174
+ ) -> int:
175
+ """Adds a chain of specific length to a node either as a successor or as a predecessor.
90
176
  If it is placed on top of a tree all the nodes will move timepoints #length down.
91
177
 
92
- Args:
93
- pred (int): Id of the successor (predecessor if reverse is False)
94
- length (int): The length of the new branch.
95
- pos (np.ndarray, optional): The new position of the branch. Defaults to None.
96
- move_timepoints (bool): Moves the time, important only if reverse= True
97
- reverese (bool): If True will create a branch that goes forwards in time otherwise backwards.
98
- Returns:
99
- (int): Id of the first node of the sublineage.
178
+ Parameters
179
+ ----------
180
+ node : int
181
+ Id of the successor (predecessor if `downstream==False`)
182
+ length : int
183
+ The length of the new chain.
184
+ downstream : bool, default=True
185
+ If `True` will create a chain that goes forwards in time otherwise backwards.
186
+ pos : np.ndarray, optional
187
+ The new position of the chain. Defaults to None.
188
+
189
+ Returns
190
+ -------
191
+ int
192
+ Id of the first node of the sublineage.
100
193
  """
101
194
  if length == 0:
102
- return pred
103
- if self.predecessor.get(pred) and not reverse:
104
- raise Warning("Cannot add 2 predecessors to a node")
105
- time = self.time[pred]
106
- original = pred
107
- if not reverse:
108
- if move_timepoints:
109
- nodes_to_move = set(self.get_sub_tree(pred))
110
- new_times = {
111
- node: self.time[node] + length for node in nodes_to_move
112
- }
113
- for node in nodes_to_move:
114
- old_time = self.time[node]
115
- self.time_nodes[old_time].remove(node)
116
- self.time_nodes.setdefault(old_time + length, set()).add(
117
- node
118
- )
119
- self.time.update(new_times)
120
- for t in range(length - 1, -1, -1):
121
- _next = self.add_node(
122
- time + t,
123
- succ=pred,
124
- pos=self.pos[original],
125
- reverse=True,
126
- )
127
- pred = _next
128
- else:
129
- for t in range(length):
130
- _next = self.add_node(
131
- time - t,
132
- succ=pred,
133
- pos=self.pos[original],
134
- reverse=True,
135
- )
136
- pred = _next
195
+ return node
196
+ if length < 1:
197
+ raise ValueError("Length cannot be <1")
198
+ if downstream:
199
+ for _ in range(int(length)):
200
+ old_node = node
201
+ node = self._add_node(pred=[old_node])
202
+ self._time[node] = self._time[old_node] + 1
137
203
  else:
138
- for _ in range(length):
139
- _next = self.add_node(
140
- self.time[pred] + 1,
141
- succ=pred,
142
- pos=self.pos[original],
143
- reverse=False,
204
+ if self._predecessor[node]:
205
+ raise Warning("The node already has a predecessor.")
206
+ if self._time[node] - length < self.t_b:
207
+ raise Warning(
208
+ "A node cannot created outside the lower bound of the dataset. (It is possible to change it by lT.t_b = int(...))"
144
209
  )
145
- pred = _next
146
- self.successor[self.get_cycle(pred)[-1]] = []
147
- self.labels[pred] = "New branch"
148
- if self.time[pred] == self.t_b:
149
- self.labels[pred] = "New branch"
150
- if original in self.roots and reverse is True:
151
- self.labels[pred] = "New branch"
152
- self.labels.pop(original, -1)
153
- self.t_e = max(self.time_nodes)
154
- return pred
155
-
156
- def cut_tree(self, root):
157
- """It transforms a lineage that has at least 2 divisions into 2 independent lineages,
158
- that spawn from the time point of the first node. (splits a tree into 2)
159
-
160
- Args:
161
- root (int): The id of the node, which will be cut.
162
-
163
- Returns:
164
- int: The id of the new tree
165
- """
166
- cycle = self.get_successors(root)
167
- last_cell = cycle[-1]
168
- if last_cell in self.successor:
169
- new_lT = self.successor[last_cell].pop()
170
- self.predecessor.pop(new_lT)
171
- label_of_root = self.labels.get(cycle[0], cycle[0])
172
- self.labels[cycle[0]] = f"L-Split {label_of_root}"
173
- new_tr = self.add_branch(new_lT, len(cycle), move_timepoints=False)
174
- self.roots.add(new_tr)
175
- self.labels[new_tr] = f"R-Split {label_of_root}"
176
- return new_tr
177
- else:
178
- raise Warning("No division of the branch")
179
-
180
- def fuse_lineage_tree(
181
- self,
182
- l1_root: int,
183
- l2_root: int,
184
- length_l1: int = 0,
185
- length_l2: int = 0,
186
- length: int = 1,
187
- ):
188
- """Fuses 2 lineages from the lineagetree object. The 2 lineages that are to be fused can have a longer
189
- first node and the node of the resulting lineage can also be longer.
190
-
191
- Args:
192
- l1_root (int): Id of the first root
193
- l2_root (int): Id of the second root
194
- length_l1 (int, optional): The length of the branch that will be added on top of the first lineage. Defaults to 0, which means only one node will be added.
195
- length_l2 (int, optional): The length of the branch that will be added on top of the second lineage. Defaults to 0, which means only one node will be added.
196
- length (int, optional): The length of the branch that will be added on top of the resulting lineage. Defaults to 1.
197
-
198
- Returns:
199
- int: The id of the root of the new lineage.
200
- """
201
- if self.predecessor.get(l1_root) or self.predecessor.get(l2_root):
202
- raise ValueError("Please select 2 roots.")
203
- if self.time[l1_root] != self.time[l2_root]:
204
- warnings.warn(
205
- "Using lineagetrees that do not exist in the same timepoint. The operation will continue"
206
- )
207
- new_root1 = self.add_branch(l1_root, length_l1)
208
- new_root2 = self.add_branch(l2_root, length_l2)
209
- next_root1 = self[new_root1][0]
210
- self.remove_nodes(new_root1)
211
- self.successor[new_root2].append(next_root1)
212
- self.predecessor[next_root1] = [new_root2]
213
- new_branch = self.add_branch(
214
- new_root2,
215
- length - 1,
216
- )
217
- self.labels[new_branch] = f"Fusion of {new_root1} and {new_root2}"
218
- return new_branch
219
-
220
- def copy_lineage(self, root):
221
- """
222
- Copies the structure of a tree and makes a new with new nodes.
223
- Warning does not take into account the predecessor of the root node.
224
-
225
- Args:
226
- root (int): The root of the tree to be copied
210
+ for _ in range(int(length)):
211
+ old_node = node
212
+ node = self._add_node(succ=[old_node])
213
+ self._time[node] = self._time[old_node] - 1
214
+ return node
215
+
216
+ @modifier
217
+ def add_root(self, t: int, pos: list | None = None) -> int:
218
+ """Adds a root to a specific timepoint.
219
+
220
+ Parameters
221
+ ----------
222
+ t :int
223
+ The timepoint the node is going to be added.
224
+ pos : list
225
+ The position of the new node.
226
+ Returns
227
+ -------
228
+ int
229
+ The id of the new root.
230
+ """
231
+ C_next = self.get_next_id()
232
+ self._successor[C_next] = ()
233
+ self._predecessor[C_next] = ()
234
+ self._time[C_next] = t
235
+ self.pos[C_next] = pos if isinstance(pos, list) else []
236
+ self._changed_roots = True
237
+ return C_next
227
238
 
228
- Returns:
229
- int: The root of the new tree.
230
- """
231
- new_nodes = {
232
- old_node: self.get_next_id()
233
- for old_node in self.get_sub_tree(root)
234
- }
235
- self.nodes.update(new_nodes.values())
236
- for old_node, new_node in new_nodes.items():
237
- self.time[new_node] = self.time[old_node]
238
- succ = self.successor.get(old_node)
239
- if succ:
240
- self.successor[new_node] = [new_nodes[n] for n in succ]
241
- pred = self.predecessor.get(old_node)
242
- if pred:
243
- self.predecessor[new_node] = [new_nodes[n] for n in pred]
244
- self.pos[new_node] = self.pos[old_node] + 0.5
245
- self.time_nodes[self.time[old_node]].add(new_nodes[old_node])
246
- new_root = new_nodes[root]
247
- self.labels[new_root] = f"Copy of {root}"
248
- if self.time[new_root] == 0:
249
- self.roots.add(new_root)
250
- return new_root
251
-
252
- def add_node(
239
+ def _add_node(
253
240
  self,
254
- t: int = None,
255
- succ: int = None,
256
- pos: np.ndarray = None,
257
- nid: int = None,
258
- reverse: bool = False,
241
+ succ: list | None = None,
242
+ pred: list | None = None,
243
+ pos: np.ndarray | None = None,
244
+ nid: int | None = None,
259
245
  ) -> int:
260
- """Adds a node to the lineageTree and update it accordingly.
261
-
262
- Args:
263
- t (int): int, time to which to add the node
264
- succ (int): id of the node the new node is a successor to
265
- pos ([float, ]): list of three floats representing the 3D
266
- spatial position of the node
267
- nid (int): id value of the new node, to be used carefully,
268
- if None is provided the new id is automatically computed.
269
- reverse (bool): True if in this lineageTree the predecessors
270
- are the successors and reciprocally.
271
- This is there for bacward compatibility, should be left at False.
272
- Returns:
273
- int: id of the new node.
274
- """
246
+ """Adds a node to the LineageTree object that is either a successor or a predecessor of another node.
247
+ Does not handle time! You cannot enter both a successor and a predecessor.
248
+
249
+ Parameters
250
+ ----------
251
+ succ : list
252
+ list of ids of the nodes the new node is a successor to
253
+ pred : list
254
+ list of ids of the nodes the new node is a predecessor to
255
+ pos : np.ndarray, optional
256
+ position of the new node
257
+ nid : int, optional
258
+ id value of the new node, to be used carefully,
259
+ if None is provided the new id is automatically computed.
260
+
261
+ Returns
262
+ -------
263
+ int
264
+ id of the new node.
265
+ """
266
+ if not succ and not pred:
267
+ raise Warning(
268
+ "Please enter a successor or a predecessor, otherwise use the add_roots() function."
269
+ )
275
270
  C_next = self.get_next_id() if nid is None else nid
276
- self.time_nodes.setdefault(t, set()).add(C_next)
277
- if succ is not None and not reverse:
278
- self.successor.setdefault(succ, []).append(C_next)
279
- self.predecessor.setdefault(C_next, []).append(succ)
280
- elif succ is not None:
281
- self.predecessor.setdefault(succ, []).append(C_next)
282
- self.successor.setdefault(C_next, []).append(succ)
283
- self.nodes.add(C_next)
284
- self.pos[C_next] = pos
285
- self.time[C_next] = t
271
+ if succ:
272
+ self._successor[C_next] = succ
273
+ for suc in succ:
274
+ self._predecessor[suc] = (C_next,)
275
+ else:
276
+ self._successor[C_next] = ()
277
+ if pred:
278
+ self._predecessor[C_next] = pred
279
+ self._successor[pred[0]] = self._successor.setdefault(
280
+ pred[0], ()
281
+ ) + (C_next,)
282
+ else:
283
+ self._predecessor[C_next] = ()
284
+ if isinstance(pos, list):
285
+ self.pos[C_next] = pos
286
286
  return C_next
287
287
 
288
- def remove_nodes(self, group: Union[int, set, list]):
288
+ @modifier
289
+ def remove_nodes(self, group: int | set | list) -> None:
289
290
  """Removes a group of nodes from the LineageTree
290
291
 
291
- Args:
292
- group (set|list|int): One or more nodes that are to be removed.
292
+ Parameters
293
+ ----------
294
+ group : set of int or list of int or int
295
+ One or more nodes that are to be removed.
293
296
  """
294
- if isinstance(group, int):
297
+ if isinstance(group, int | float):
295
298
  group = {group}
296
299
  if isinstance(group, list):
297
300
  group = set(group)
298
- group = group.intersection(self.nodes)
299
- self.nodes.difference_update(group)
300
- times = {self.time.pop(n) for n in group}
301
- for t in times:
302
- self.time_nodes[t] = set(self.time_nodes[t]).difference(group)
301
+ group = self.nodes.intersection(group)
303
302
  for node in group:
304
- self.pos.pop(node)
305
- if self.predecessor.get(node):
306
- pred = self.predecessor[node][0]
307
- siblings = self.successor.pop(pred, [])
308
- if len(siblings) == 2:
309
- siblings.remove(node)
310
- self.successor[pred] = siblings
311
- self.predecessor.pop(node, [])
312
- for succ in self.successor.get(node, []):
313
- self.predecessor.pop(succ, [])
314
- self.successor.pop(node, [])
315
- self.labels.pop(node, 0)
316
- if node in self.roots:
317
- self.roots.remove(node)
318
-
319
- def modify_branch(self, node, new_length):
320
- """Changes the length of a branch, so it adds or removes nodes
321
- to make the correct length of the cycle.
322
-
323
- Args:
324
- node (int): Any node of the branch to be modified/
325
- new_length (int): The new length of the tree.
326
- """
327
- if new_length <= 1:
328
- warnings.warn("New length should be more than 1", stacklevel=2)
329
- return None
330
- cycle = self.get_cycle(node)
331
- length = len(cycle)
332
- successors = self.successor.get(cycle[-1])
333
- if length == 1 and new_length != 1:
334
- pred = self.predecessor.pop(node, None)
335
- new_node = self.add_branch(
336
- node,
337
- length=new_length - 1,
338
- move_timepoints=True,
339
- reverse=False,
340
- )
341
- if pred:
342
- self.successor[pred[0]].remove(node)
343
- self.successor[pred[0]].append(new_node)
344
- elif self.leaves.intersection(cycle) and new_length < length:
345
- self.remove_nodes(cycle[new_length:])
346
- elif new_length < length:
347
- to_remove = length - new_length
348
- last_cell = cycle[new_length - 1]
349
- subtree = self.get_sub_tree(cycle[-1])[1:]
350
- self.remove_nodes(cycle[new_length:])
351
- self.successor[last_cell] = successors
352
- if successors:
353
- for succ in successors:
354
- self.predecessor[succ] = [last_cell]
355
- for node in subtree:
356
- if node not in cycle[new_length - 1 :]:
357
- old_time = self.time[node]
358
- self.time[node] = old_time - to_remove
359
- self.time_nodes[old_time].remove(node)
360
- self.time_nodes.setdefault(
361
- old_time - to_remove, set()
362
- ).add(node)
363
- elif length < new_length:
364
- to_add = new_length - length
365
- last_cell = cycle[-1]
366
- self.successor.pop(cycle[-2])
367
- self.predecessor.pop(last_cell)
368
- succ = self.add_branch(
369
- last_cell, length=to_add, move_timepoints=True, reverse=False
370
- )
371
- self.predecessor[succ] = [cycle[-2]]
372
- self.successor[cycle[-2]] = [succ]
373
- self.time[last_cell] = (
374
- self.time[self.predecessor[last_cell][0]] + 1
375
- )
376
- else:
377
- return None
378
-
379
- @property
380
- def time_resolution(self):
381
- if not hasattr(self, "_time_resolution"):
382
- self.time_resolution = 1
383
- return self._time_resolution / 10
384
-
385
- @time_resolution.setter
386
- def time_resolution(self, time_resolution):
387
- if time_resolution is not None:
388
- self._time_resolution = int(time_resolution * 10)
389
- else:
390
- warnings.warn("Time resolution set to default 0", stacklevel=2)
391
- self._time_resolution = 10
392
-
393
- @property
394
- def depth(self):
395
- if not hasattr(self, "_depth"):
396
- self._depth = {}
397
- for leaf in self.leaves:
398
- self._depth[leaf] = 1
399
- while leaf in self.predecessor:
400
- parent = self.predecessor[leaf][0]
401
- current_depth = self._depth.get(parent, 0)
402
- self._depth[parent] = max(
403
- self._depth[leaf] + 1, current_depth
404
- )
405
- leaf = parent
406
- for root in self.roots - set(self._depth):
407
- self._depth[root] = 1
408
- return self._depth
303
+ for attr in self.__dict__:
304
+ attr_value = self.__getattribute__(attr)
305
+ if isinstance(attr_value, dict) and attr not in [
306
+ "successor",
307
+ "predecessor",
308
+ "_successor",
309
+ "_predecessor",
310
+ ]:
311
+ attr_value.pop(node, ())
312
+ if self._predecessor.get(node):
313
+ self._successor[self._predecessor[node][0]] = tuple(
314
+ set(
315
+ self._successor[self._predecessor[node][0]]
316
+ ).difference(group)
317
+ )
318
+ for p_node in self._successor.get(node, []):
319
+ self._predecessor[p_node] = ()
320
+ self._predecessor.pop(node, ())
321
+ self._successor.pop(node, ())
409
322
 
410
323
  @property
411
- def roots(self):
412
- return set(self.nodes).difference(self.predecessor)
324
+ def successor(self) -> MappingProxyType[int, tuple[int]]:
325
+ """The successor of the tree."""
326
+ if not hasattr(self, "_protected_successor"):
327
+ self._protected_successor = MappingProxyType(self._successor)
328
+ return self._protected_successor
413
329
 
414
330
  @property
415
- def edges(self):
416
- return {(k, vi) for k, v in self.successor.items() for vi in v}
331
+ def predecessor(self) -> MappingProxyType[int, tuple[int]]:
332
+ """The predecessor of the tree."""
333
+ if not hasattr(self, "_protected_predecessor"):
334
+ self._protected_predecessor = MappingProxyType(self._predecessor)
335
+ return self._protected_predecessor
417
336
 
418
337
  @property
419
- def leaves(self):
420
- return {p for p, s in self.successor.items() if s == []}
338
+ def time(self) -> MappingProxyType[int, int]:
339
+ """The time of the tree."""
340
+ if not hasattr(self, "_protected_time"):
341
+ self._protected_time = MappingProxyType(self._time)
342
+ return self._protected_time
343
+
344
+ @dynamic_property
345
+ def t_b(self) -> int:
346
+ """The first timepoint of the tree."""
347
+ return min(self._time.values())
348
+
349
+ @dynamic_property
350
+ def t_e(self) -> int:
351
+ """The last timepoint of the tree."""
352
+ return max(self._time.values())
353
+
354
+ @dynamic_property
355
+ def nodes(self) -> frozenset[int]:
356
+ """Nodes of the tree"""
357
+ return frozenset(self._successor.keys())
358
+
359
+ @dynamic_property
360
+ def number_of_nodes(self) -> int:
361
+ return len(self.nodes)
362
+
363
+ @dynamic_property
364
+ def depth(self) -> dict[int, int]:
365
+ """The depth of each node in the tree."""
366
+ _depth = {}
367
+ for leaf in self.leaves:
368
+ _depth[leaf] = 1
369
+ while leaf in self._predecessor and self._predecessor[leaf]:
370
+ parent = self._predecessor[leaf][0]
371
+ current_depth = _depth.get(parent, 0)
372
+ _depth[parent] = max(_depth[leaf] + 1, current_depth)
373
+ leaf = parent
374
+ for root in self.roots - set(_depth):
375
+ _depth[root] = 1
376
+ return _depth
377
+
378
+ @dynamic_property
379
+ def roots(self) -> frozenset[int]:
380
+ """Set of roots of the tree"""
381
+ return frozenset({s for s, p in self._predecessor.items() if p == ()})
382
+
383
+ @dynamic_property
384
+ def leaves(self) -> frozenset[int]:
385
+ """Set of leaves"""
386
+ return frozenset({p for p, s in self._successor.items() if s == ()})
387
+
388
+ @dynamic_property
389
+ def edges(self) -> tuple[tuple[int, int]]:
390
+ """Set of edges"""
391
+ return tuple((p, si) for p, s in self._successor.items() for si in s)
421
392
 
422
393
  @property
423
- def labels(self):
394
+ def labels(self) -> dict[int, str]:
395
+ """The labels of the nodes."""
424
396
  if not hasattr(self, "_labels"):
425
- if hasattr(self, "cell_name"):
397
+ if hasattr(self, "node_name"):
426
398
  self._labels = {
427
- i: self.cell_name.get(i, "Unlabeled") for i in self.roots
399
+ i: self.node_name.get(i, "Unlabeled") for i in self.roots
428
400
  }
429
401
  else:
430
402
  self._labels = {
431
403
  root: "Unlabeled"
432
404
  for root in self.roots
433
405
  for leaf in self.find_leaves(root)
434
- if abs(self.time[leaf] - self.time[root])
406
+ if abs(self._time[leaf] - self._time[root])
435
407
  >= abs(self.t_e - self.t_b) / 4
436
408
  }
437
409
  return self._labels
438
410
 
439
- def _write_header_am(self, f: TextIO, nb_points: int, length: int):
440
- """Header for Amira .am files"""
441
- f.write("# AmiraMesh 3D ASCII 2.0\n")
442
- f.write("define VERTEX %d\n" % (nb_points * 2))
443
- f.write("define EDGE %d\n" % nb_points)
444
- f.write("define POINT %d\n" % ((length) * nb_points))
445
- f.write("Parameters {\n")
446
- f.write('\tContentType "HxSpatialGraph"\n')
447
- f.write("}\n")
448
-
449
- f.write("VERTEX { float[3] VertexCoordinates } @1\n")
450
- f.write("EDGE { int[2] EdgeConnectivity } @2\n")
451
- f.write("EDGE { int NumEdgePoints } @3\n")
452
- f.write("POINT { float[3] EdgePointCoordinates } @4\n")
453
- f.write("VERTEX { float Vcolor } @5\n")
454
- f.write("VERTEX { int Vbool } @6\n")
455
- f.write("EDGE { float Ecolor } @7\n")
456
- f.write("VERTEX { int Vbool2 } @8\n")
457
-
458
- def write_to_am(
459
- self,
460
- path_format: str,
461
- t_b: int = None,
462
- t_e: int = None,
463
- length: int = 5,
464
- manual_labels: dict = None,
465
- default_label: int = 5,
466
- new_pos: np.ndarray = None,
467
- ):
468
- """Writes a lineageTree into an Amira readable data (.am format).
469
-
470
- Args:
471
- path_format (str): path to the output. It should contain 1 %03d where the time step will be entered
472
- t_b (int): first time point to write (if None, min(LT.to_take_time) is taken)
473
- t_e (int): last time point to write (if None, max(LT.to_take_time) is taken)
474
- note, if there is no 'to_take_time' attribute, self.time_nodes
475
- is considered instead (historical)
476
- length (int): length of the track to print (how many time before).
477
- manual_labels ({id: label, }): dictionary that maps cell ids to
478
- default_label (int): default value for the manual label
479
- new_pos ({id: [x, y, z]}): dictionary that maps a 3D position to a cell ID.
480
- if new_pos == None (default) then self.pos is considered.
481
- """
482
- if not hasattr(self, "to_take_time"):
483
- self.to_take_time = self.time_nodes
484
- if t_b is None:
485
- t_b = min(self.to_take_time.keys())
486
- if t_e is None:
487
- t_e = max(self.to_take_time.keys())
488
- if new_pos is None:
489
- new_pos = self.pos
490
-
491
- if manual_labels is None:
492
- manual_labels = {}
493
- for t in range(t_b, t_e + 1):
494
- with open(path_format % t, "w") as f:
495
- nb_points = len(self.to_take_time[t])
496
- self._write_header_am(f, nb_points, length)
497
- points_v = {}
498
- for C in self.to_take_time[t]:
499
- C_tmp = C
500
- positions = []
501
- for _ in range(length):
502
- C_tmp = self.predecessor.get(C_tmp, [C_tmp])[0]
503
- positions.append(new_pos[C_tmp])
504
- points_v[C] = positions
505
-
506
- f.write("@1\n")
507
- for C in self.to_take_time[t]:
508
- f.write("{:f} {:f} {:f}\n".format(*tuple(points_v[C][0])))
509
- f.write("{:f} {:f} {:f}\n".format(*tuple(points_v[C][-1])))
510
-
511
- f.write("@2\n")
512
- for i, _ in enumerate(self.to_take_time[t]):
513
- f.write("%d %d\n" % (2 * i, 2 * i + 1))
514
-
515
- f.write("@3\n")
516
- for _ in self.to_take_time[t]:
517
- f.write("%d\n" % (length))
518
-
519
- f.write("@4\n")
520
- for C in self.to_take_time[t]:
521
- for p in points_v[C]:
522
- f.write("{:f} {:f} {:f}\n".format(*tuple(p)))
523
-
524
- f.write("@5\n")
525
- for C in self.to_take_time[t]:
526
- f.write(f"{manual_labels.get(C, default_label):f}\n")
527
- f.write(f"{0:f}\n")
528
-
529
- f.write("@6\n")
530
- for C in self.to_take_time[t]:
531
- f.write(
532
- "%d\n"
533
- % (
534
- int(
535
- manual_labels.get(C, default_label)
536
- != default_label
537
- )
538
- )
539
- )
540
- f.write("%d\n" % (0))
541
-
542
- f.write("@7\n")
543
- for C in self.to_take_time[t]:
544
- f.write(
545
- f"{np.linalg.norm(points_v[C][0] - points_v[C][-1]):f}\n"
546
- )
547
-
548
- f.write("@8\n")
549
- for _ in self.to_take_time[t]:
550
- f.write("%d\n" % (1))
551
- f.write("%d\n" % (0))
552
- f.close()
411
+ @property
412
+ def time_resolution(self) -> float:
413
+ if not hasattr(self, "_time_resolution"):
414
+ self._time_resolution = 0
415
+ return self._time_resolution / 10
553
416
 
554
- def _get_height(self, c: int, done: dict):
555
- """Recursively computes the height of a cell within a tree * a space factor.
417
+ @time_resolution.setter
418
+ def time_resolution(self, time_resolution) -> None:
419
+ if time_resolution is not None and time_resolution > 0:
420
+ self._time_resolution = int(time_resolution * 10)
421
+ else:
422
+ warnings.warn("Time resolution set to default 0", stacklevel=2)
423
+ self._time_resolution = 0
424
+
425
+ def __setstate__(self, state):
426
+ if "_successor" not in state:
427
+ state["_successor"] = state["successor"]
428
+ if "_predecessor" not in state:
429
+ state["_predecessor"] = state["predecessor"]
430
+ if "_time" not in state:
431
+ state["_time"] = state["time"]
432
+ self.__dict__.update(state)
433
+
434
+ def _get_height(self, c: int, done: dict) -> float:
435
+ """Recursively computes the height of a node within a tree times a space factor.
556
436
  This function is specific to the function write_to_svg.
557
437
 
558
- Args:
559
- c (int): id of a cell in a lineage tree from which the height will be computed from
560
- done ({int: [int, int]}): a dictionary that maps a cell id to its vertical and horizontal position
561
- Returns:
562
- float:
438
+ Parameters
439
+ ----------
440
+ c : int
441
+ id of a node in a lineage tree from which the height will be computed from
442
+ done : dict mapping int to list of two int
443
+ a dictionary that maps a node id to its vertical and horizontal position
444
+
445
+ Returns
446
+ -------
447
+ float
448
+ the height of the node `c`
563
449
  """
564
450
  if c in done:
565
451
  return done[c][0]
566
452
  else:
567
453
  P = np.mean(
568
- [self._get_height(di, done) for di in self.successor[c]]
454
+ [self._get_height(di, done) for di in self._successor[c]]
569
455
  )
570
- done[c] = [P, self.vert_space_factor * self.time[c]]
456
+ done[c] = [P, self.vert_space_factor * self._time[c]]
571
457
  return P
572
458
 
573
459
  def write_to_svg(
574
460
  self,
575
461
  file_name: str,
576
- roots: list = None,
462
+ roots: list | None = None,
577
463
  draw_nodes: bool = True,
578
464
  draw_edges: bool = True,
579
- order_key: callable = None,
465
+ order_key: Callable | None = None,
580
466
  vert_space_factor: float = 0.5,
581
467
  horizontal_space: float = 1,
582
- node_size: callable = None,
583
- stroke_width: callable = None,
468
+ node_size: Callable | str | None = None,
469
+ stroke_width: Callable | None = None,
584
470
  factor: float = 1.0,
585
- node_color: callable = None,
586
- stroke_color: callable = None,
587
- positions: dict = None,
588
- node_color_map: callable = None,
589
- normalize: bool = True,
590
- ):
591
- ##### remove background? default True background value? default 1
592
-
471
+ node_color: Callable | str | None = None,
472
+ stroke_color: Callable | None = None,
473
+ positions: dict | None = None,
474
+ node_color_map: Callable | str | None = None,
475
+ ) -> None:
593
476
  """Writes the lineage tree to an SVG file.
594
477
  Node and edges coloring and size can be provided.
595
478
 
596
- Args:
597
- file_name (str): filesystem filename valid for `open()`
598
- roots ([int, ...]): list of node ids to be drawn. If `None` all the nodes will be drawn. Default `None`
599
- draw_nodes (bool): wether to print the nodes or not, default `True`
600
- draw_edges (bool): wether to print the edges or not, default `True`
601
- order_key (callable): function that would work for the attribute `key=` for the `sort`/`sorted` function
602
- vert_space_factor (float): the vertical position of a node is its time. `vert_space_factor` is a
603
- multiplier to space more or less nodes in time
604
- horizontal_space (float): space between two consecutive nodes
605
- node_size (callable | str): a function that maps a node id to a `float` value that will determine the
606
- radius of the node. The default function return the constant value `vertical_space_factor/2.1`
607
- If a string is given instead and it is a property of the tree,
608
- the the size will be mapped according to the property
609
- stroke_width (callable): a function that maps a node id to a `float` value that will determine the
610
- width of the daughter edge. The default function return the constant value `vertical_space_factor/2.1`
611
- factor (float): scaling factor for nodes positions, default 1
612
- node_color (callable | str): a function that maps a node id to a triplet between 0 and 255.
613
- The triplet will determine the color of the node. If a string is given instead and it is a property
614
- of the tree, the the color will be mapped according to the property
615
- node_color_map (callable | str): the name of the colormap to use to color the nodes, or a colormap function
616
- stroke_color (callable): a function that maps a node id to a triplet between 0 and 255.
617
- The triplet will determine the color of the stroke of the inward edge.
618
- positions ({int: [float, float], ...}): dictionary that maps a node id to a 2D position.
619
- Default `None`. If provided it will be used to position the nodes.
479
+ Parameters
480
+ ----------
481
+ file_name : str
482
+ filesystem filename valid for `open()`
483
+ roots : list of int, defaults to `self.roots`
484
+ list of node ids to be drawn. If `None` or not provided all the nodes will be drawn. Default `None`
485
+ draw_nodes : bool, default True
486
+ wether to print the nodes or not
487
+ draw_edges : bool, default True
488
+ wether to print the edges or not
489
+ order_key : Callable, optional
490
+ function that would work for the attribute `key=` for the `sort`/`sorted` function
491
+ vert_space_factor : float, default=0.5
492
+ the vertical position of a node is its time. `vert_space_factor` is a
493
+ multiplier to space more or less nodes in time
494
+ horizontal_space : float, default=1
495
+ space between two consecutive nodes
496
+ node_size : Callable or str, optional
497
+ a function that maps a node id to a `float` value that will determine the
498
+ radius of the node. The default function return the constant value `vertical_space_factor/2.1`
499
+ If a string is given instead and it is a property of the tree,
500
+ the the size will be mapped according to the property
501
+ stroke_width : Callable, optional
502
+ a function that maps a node id to a `float` value that will determine the
503
+ width of the daughter edge. The default function return the constant value `vertical_space_factor/2.1`
504
+ factor : float, default=1.0
505
+ scaling factor for nodes positions, default 1
506
+ node_color : Callable or str, optional
507
+ a function that maps a node id to a triplet between 0 and 255.
508
+ The triplet will determine the color of the node. If a string is given instead and it is a property
509
+ of the tree, the the color will be mapped according to the property
510
+ node_color_map : Callable or str, optional
511
+ the name of the colormap to use to color the nodes, or a colormap function
512
+ stroke_color : Callable, optional
513
+ a function that maps a node id to a triplet between 0 and 255.
514
+ The triplet will determine the color of the stroke of the inward edge.
515
+ positions : dict mapping int to list of two float, optional
516
+ dictionary that maps a node id to a 2D position.
517
+ Default `None`. If provided it will be used to position the nodes.
620
518
  """
621
- import svgwrite
622
519
 
623
520
  def normalize_values(v, nodes, _range, shift, mult):
624
521
  min_ = np.percentile(v, 1)
625
522
  max_ = np.percentile(v, 99)
626
523
  values = _range * ((v - min_) / (max_ - min_)) + shift
627
- values_dict_nodes = dict(zip(nodes, values))
524
+ values_dict_nodes = dict(zip(nodes, values, strict=True))
628
525
  return lambda x: values_dict_nodes[x] * mult
629
526
 
630
527
  if roots is None:
631
528
  roots = self.roots
632
529
  if hasattr(self, "image_label"):
633
- roots = [cell for cell in roots if self.image_label[cell] != 1]
530
+ roots = [node for node in roots if self.image_label[node] != 1]
634
531
 
635
532
  if node_size is None:
636
533
 
637
534
  def node_size(x):
638
535
  return vert_space_factor / 2.1
639
536
 
640
- elif isinstance(node_size, str) and node_size in self.__dict__:
641
- values = np.array([self[node_size][c] for c in self.nodes])
537
+ else:
538
+ values = np.array(
539
+ [self._successor[node_size][c] for c in self.nodes]
540
+ )
642
541
  node_size = normalize_values(
643
542
  values, self.nodes, 0.5, 0.5, vert_space_factor / 2.1
644
543
  )
@@ -653,18 +552,19 @@ class lineageTree(lineageTreeLoaders):
653
552
  return 0, 0, 0
654
553
 
655
554
  elif isinstance(node_color, str) and node_color in self.__dict__:
656
- if isinstance(node_color_map, str):
657
- from matplotlib import colormaps
555
+ from matplotlib import colormaps
658
556
 
659
- if node_color_map in colormaps:
660
- node_color_map = colormaps[node_color_map]
661
- else:
662
- node_color_map = colormaps["viridis"]
663
- values = np.array([self[node_color][c] for c in self.nodes])
557
+ if node_color_map in colormaps:
558
+ cmap = colormaps[node_color_map]
559
+ else:
560
+ cmap = colormaps["viridis"]
561
+ values = np.array(
562
+ [self._successor[node_color][c] for c in self.nodes]
563
+ )
664
564
  normed_vals = normalize_values(values, self.nodes, 1, 0, 1)
665
565
 
666
566
  def node_color(x):
667
- return [k * 255 for k in node_color_map(normed_vals(x))[:-1]]
567
+ return [k * 255 for k in cmap(normed_vals(x))[:-1]]
668
568
 
669
569
  coloring_edges = stroke_color is not None
670
570
  if not coloring_edges:
@@ -673,24 +573,25 @@ class lineageTree(lineageTreeLoaders):
673
573
  return 0, 0, 0
674
574
 
675
575
  elif isinstance(stroke_color, str) and stroke_color in self.__dict__:
676
- if isinstance(node_color_map, str):
677
- from matplotlib import colormaps
576
+ from matplotlib import colormaps
678
577
 
679
- if node_color_map in colormaps:
680
- node_color_map = colormaps[node_color_map]
681
- else:
682
- node_color_map = colormaps["viridis"]
683
- values = np.array([self[stroke_color][c] for c in self.nodes])
578
+ if node_color_map in colormaps:
579
+ cmap = colormaps[node_color_map]
580
+ else:
581
+ cmap = colormaps["viridis"]
582
+ values = np.array(
583
+ [self._successor[stroke_color][c] for c in self.nodes]
584
+ )
684
585
  normed_vals = normalize_values(values, self.nodes, 1, 0, 1)
685
586
 
686
587
  def stroke_color(x):
687
- return [k * 255 for k in node_color_map(normed_vals(x))[:-1]]
588
+ return [k * 255 for k in cmap(normed_vals(x))[:-1]]
688
589
 
689
590
  prev_x = 0
690
591
  self.vert_space_factor = vert_space_factor
691
592
  if order_key is not None:
692
593
  roots.sort(key=order_key)
693
- treated_cells = []
594
+ treated_nodes = []
694
595
 
695
596
  pos_given = positions is not None
696
597
  if not pos_given:
@@ -701,25 +602,26 @@ class lineageTree(lineageTreeLoaders):
701
602
  [0.0, 0.0],
702
603
  ]
703
604
  * len(self.nodes),
704
- )
605
+ strict=True,
606
+ ),
705
607
  )
706
608
  for _i, r in enumerate(roots):
707
609
  r_leaves = []
708
610
  to_do = [r]
709
611
  while len(to_do) != 0:
710
612
  curr = to_do.pop(0)
711
- treated_cells += [curr]
712
- if curr in self.successor:
613
+ treated_nodes += [curr]
614
+ if not self._successor[curr]:
713
615
  if order_key is not None:
714
- to_do += sorted(self.successor[curr], key=order_key)
616
+ to_do += sorted(self._successor[curr], key=order_key)
715
617
  else:
716
- to_do += self.successor[curr]
618
+ to_do += self._successor[curr]
717
619
  else:
718
620
  r_leaves += [curr]
719
621
  r_pos = {
720
622
  leave: [
721
623
  prev_x + horizontal_space * (1 + j),
722
- self.vert_space_factor * self.time[leave],
624
+ self.vert_space_factor * self._time[leave],
723
625
  ]
724
626
  for j, leave in enumerate(r_leaves)
725
627
  }
@@ -734,12 +636,12 @@ class lineageTree(lineageTreeLoaders):
734
636
  size=factor * np.max(list(positions.values()), axis=0),
735
637
  )
736
638
  if draw_edges and not draw_nodes and not coloring_edges:
737
- to_do = set(treated_cells)
639
+ to_do = set(treated_nodes)
738
640
  while len(to_do) > 0:
739
641
  curr = to_do.pop()
740
- c_cycle = self.get_cycle(curr)
741
- x1, y1 = positions[c_cycle[0]]
742
- x2, y2 = positions[c_cycle[-1]]
642
+ c_chain = self.get_chain_of_node(curr)
643
+ x1, y1 = positions[c_chain[0]]
644
+ x2, y2 = positions[c_chain[-1]]
743
645
  dwg.add(
744
646
  dwg.line(
745
647
  (factor * x1, factor * y1),
@@ -747,7 +649,7 @@ class lineageTree(lineageTreeLoaders):
747
649
  stroke=svgwrite.rgb(0, 0, 0),
748
650
  )
749
651
  )
750
- for si in self[c_cycle[-1]]:
652
+ for si in self._successor[c_chain[-1]]:
751
653
  x3, y3 = positions[si]
752
654
  dwg.add(
753
655
  dwg.line(
@@ -756,11 +658,11 @@ class lineageTree(lineageTreeLoaders):
756
658
  stroke=svgwrite.rgb(0, 0, 0),
757
659
  )
758
660
  )
759
- to_do.difference_update(c_cycle)
661
+ to_do.difference_update(c_chain)
760
662
  else:
761
- for c in treated_cells:
663
+ for c in treated_nodes:
762
664
  x1, y1 = positions[c]
763
- for si in self[c]:
665
+ for si in self._successor[c]:
764
666
  x2, y2 = positions[si]
765
667
  if draw_edges:
766
668
  dwg.add(
@@ -771,7 +673,7 @@ class lineageTree(lineageTreeLoaders):
771
673
  stroke_width=svgwrite.pt(stroke_width(si)),
772
674
  )
773
675
  )
774
- for c in treated_cells:
676
+ for c in treated_nodes:
775
677
  x1, y1 = positions[c]
776
678
  if draw_nodes:
777
679
  dwg.add(
@@ -788,49 +690,58 @@ class lineageTree(lineageTreeLoaders):
788
690
  fname: str,
789
691
  t_min: int = -1,
790
692
  t_max: int = np.inf,
791
- nodes_to_use: list = None,
693
+ nodes_to_use: list[int] | None = None,
792
694
  temporal: bool = True,
793
- spatial: str = None,
695
+ spatial: str | None = None,
794
696
  write_layout: bool = True,
795
- node_properties: dict = None,
697
+ node_properties: dict | None = None,
796
698
  Names: bool = False,
797
- ):
699
+ ) -> None:
798
700
  """Write a lineage tree into an understable tulip file.
799
701
 
800
- Args:
801
- fname (str): path to the tulip file to create
802
- t_min (int): minimum time to consider, default -1
803
- t_max (int): maximum time to consider, default np.inf
804
- nodes_to_use ([int, ]): list of nodes to show in the graph,
805
- default *None*, then self.nodes is used
806
- (taking into account *t_min* and *t_max*)
807
- temporal (bool): True if the temporal links should be printed, default True
808
- spatial (str): Build spatial edges from a spatial neighbourhood graph.
809
- The graph has to be computed before running this function
810
- 'ball': neighbours at a given distance,
811
- 'kn': k-nearest neighbours,
812
- 'GG': gabriel graph,
813
- None: no spatial edges are writen.
814
- Default None
815
- write_layout (bool): True, write the spatial position as layout,
816
- False, do not write spatial positionm
817
- default True
818
- node_properties ({`p_name`, [{id, p_value}, default]}): a dictionary of properties to write
819
- To a key representing the name of the property is
820
- paired a dictionary that maps a cell id to a property
821
- and a default value for this property
822
- Names (bool): Only works with ASTEC outputs, True to sort the cells by their names
702
+ Parameters
703
+ ----------
704
+ fname : str
705
+ path to the tulip file to create
706
+ t_min : int, default=-1
707
+ minimum time to consider
708
+ t_max : int, default=np.inf
709
+ maximum time to consider
710
+ nodes_to_use : list of int, optional
711
+ list of nodes to show in the graph,
712
+ if `None` then self.nodes is used
713
+ (taking into account `t_min` and `t_max`)
714
+ temporal : bool, default=True
715
+ True if the temporal links should be printed
716
+ spatial : str, optional
717
+ Build spatial edges from a spatial neighbourhood graph.
718
+ The graph has to be computed before running this function
719
+ 'ball': neighbours at a given distance,
720
+ 'kn': k-nearest neighbours,
721
+ 'GG': gabriel graph,
722
+ None: no spatial edges are writen.
723
+ Default None
724
+ write_layout : bool, default=True
725
+ write the spatial position as layout if True
726
+ do not write spatial position otherwise
727
+ node_properties : dict mapping str to list of dict of properties and its default value, optional
728
+ a dictionary of properties to write
729
+ To a key representing the name of the property is
730
+ paired a dictionary that maps a node id to a property
731
+ and a default value for this property
732
+ Names : bool, default=True
733
+ Only works with ASTEC outputs, True to sort the nodes by their names
823
734
  """
824
735
 
825
736
  def format_names(names_which_matter):
826
- """Return an ensured formated cell names"""
737
+ """Return an ensured formated node names"""
827
738
  tmp = {}
828
739
  for k, v in names_which_matter.items():
829
740
  tmp[k] = (
830
741
  v.split(".")[0][0]
831
- + "%02d" % int(v.split(".")[0][1:])
742
+ + "{:02d}".format(int(v.split(".")[0][1:]))
832
743
  + "."
833
- + "%04d" % int(v.split(".")[1][:-1])
744
+ + "{:04d}".format(int(v.split(".")[1][:-1]))
834
745
  + v.split(".")[1][-1]
835
746
  )
836
747
  return tmp
@@ -857,20 +768,20 @@ class lineageTree(lineageTreeLoaders):
857
768
  if not nodes_to_use:
858
769
  if t_max != np.inf or t_min > -1:
859
770
  nodes_to_use = [
860
- n for n in self.nodes if t_min < self.time[n] <= t_max
771
+ n for n in self.nodes if t_min < self._time[n] <= t_max
861
772
  ]
862
773
  edges_to_use = []
863
774
  if temporal:
864
775
  edges_to_use += [
865
776
  e
866
777
  for e in self.edges
867
- if t_min < self.time[e[0]] < t_max
778
+ if t_min < self._time[e[0]] < t_max
868
779
  ]
869
780
  if spatial:
870
781
  edges_to_use += [
871
782
  e
872
783
  for e in s_edges
873
- if t_min < self.time[e[0]] < t_max
784
+ if t_min < self._time[e[0]] < t_max
874
785
  ]
875
786
  else:
876
787
  nodes_to_use = list(self.nodes)
@@ -884,12 +795,12 @@ class lineageTree(lineageTreeLoaders):
884
795
  nodes_to_use = set(nodes_to_use)
885
796
  if temporal:
886
797
  for n in nodes_to_use:
887
- for d in self.successor.get(n, []):
798
+ for d in self._successor[n]:
888
799
  if d in nodes_to_use:
889
800
  edges_to_use.append((n, d))
890
801
  if spatial:
891
802
  edges_to_use += [
892
- e for e in s_edges if t_min < self.time[e[0]] < t_max
803
+ e for e in s_edges if t_min < self._time[e[0]] < t_max
893
804
  ]
894
805
  nodes_to_use = set(nodes_to_use)
895
806
  if Names:
@@ -907,12 +818,12 @@ class lineageTree(lineageTreeLoaders):
907
818
  for k, v in node_properties[Names][0].items():
908
819
  if (
909
820
  len(
910
- self.successor.get(
911
- self.predecessor.get(k, [-1])[0], []
821
+ self._successor.get(
822
+ self._predecessor.get(k, [-1])[0], ()
912
823
  )
913
824
  )
914
825
  != 1
915
- or self.time[k] == t_min + 1
826
+ or self._time[k] == t_min + 1
916
827
  ):
917
828
  tmp_names[k] = v
918
829
  node_properties[Names][0] = tmp_names
@@ -942,7 +853,7 @@ class lineageTree(lineageTreeLoaders):
942
853
  f.write('\t(default "0" "0")\n')
943
854
  for n in nodes_to_use:
944
855
  f.write(
945
- "\t(node " + str(n) + ' "' + str(self.time[n]) + '")\n'
856
+ "\t(node " + str(n) + ' "' + str(self._time[n]) + '")\n'
946
857
  )
947
858
  f.write(")\n")
948
859
 
@@ -989,7 +900,9 @@ class lineageTree(lineageTreeLoaders):
989
900
  f.write(")")
990
901
  f.close()
991
902
 
992
- def to_binary(self, fname: str, starting_points: list = None):
903
+ def to_binary(
904
+ self, fname: str, starting_points: list[int] | None = None
905
+ ) -> None:
993
906
  """Writes the lineage tree (a forest) as a binary structure
994
907
  (assuming it is a binary tree, it would not work for *n* ary tree with 2 < *n*).
995
908
  The binary file is composed of 3 sequences of numbers and
@@ -1004,36 +917,37 @@ class lineageTree(lineageTreeLoaders):
1004
917
  The *time_sequence* is stored as a list of unsigned short (0 -> 2^(8*2)-1)
1005
918
  The *pos_sequence* is stored as a list of double.
1006
919
 
1007
- Args:
1008
- fname (str): name of the binary file
1009
- starting_points ([int, ]): list of the roots to be written.
1010
- If None, all roots are written, default value, None
920
+ Parameters
921
+ ----------
922
+ fname : str
923
+ name of the binary file
924
+ starting_points : list of int, optional
925
+ list of the roots to be written.
926
+ If `None`, all roots are written, default value, None
1011
927
  """
1012
928
  if starting_points is None:
1013
- starting_points = [
1014
- c for c in self.successor if self.predecessor.get(c, []) == []
1015
- ]
929
+ starting_points = list(self.roots)
1016
930
  number_sequence = [-1]
1017
931
  pos_sequence = []
1018
932
  time_sequence = []
1019
933
  for c in starting_points:
1020
- time_sequence.append(self.time.get(c, 0))
934
+ time_sequence.append(self._time.get(c, 0))
1021
935
  to_treat = [c]
1022
- while to_treat != []:
936
+ while to_treat:
1023
937
  curr_c = to_treat.pop()
1024
938
  number_sequence.append(curr_c)
1025
939
  pos_sequence += list(self.pos[curr_c])
1026
- if self[curr_c] == []:
940
+ if self._successor[curr_c] == ():
1027
941
  number_sequence.append(-1)
1028
- elif len(self.successor[curr_c]) == 1:
1029
- to_treat += self.successor[curr_c]
942
+ elif len(self._successor[curr_c]) == 1:
943
+ to_treat += self._successor[curr_c]
1030
944
  else:
1031
945
  number_sequence.append(-2)
1032
- to_treat += self.successor[curr_c]
946
+ to_treat += self._successor[curr_c]
1033
947
  remaining_nodes = set(self.nodes) - set(number_sequence)
1034
948
 
1035
949
  for c in remaining_nodes:
1036
- time_sequence.append(self.time.get(c, 0))
950
+ time_sequence.append(self._time.get(c, 0))
1037
951
  number_sequence.append(c)
1038
952
  pos_sequence += list(self.pos[c])
1039
953
  number_sequence.append(-1)
@@ -1048,50 +962,269 @@ class lineageTree(lineageTreeLoaders):
1048
962
 
1049
963
  f.close()
1050
964
 
1051
- def write(self, fname: str):
1052
- """
1053
- Write a lineage tree on disk as an .lT file.
965
+ def write(self, fname: str) -> None:
966
+ """Write a lineage tree on disk as an .lT file.
1054
967
 
1055
- Args:
1056
- fname (str): path to and name of the file to save
968
+ Parameters
969
+ ----------
970
+ fname : str
971
+ path to and name of the file to save
1057
972
  """
1058
- if os.path.splitext(fname)[-1] != ".lT":
973
+ if os.path.splitext(fname)[-1].upper() != ".LT":
1059
974
  fname = os.path.extsep.join((fname, "lT"))
975
+ if hasattr(self, "_protected_predecessor"):
976
+ del self._protected_predecessor
977
+ if hasattr(self, "_protected_successor"):
978
+ del self._protected_successor
979
+ if hasattr(self, "_protected_time"):
980
+ del self._protected_time
1060
981
  with open(fname, "bw") as f:
1061
982
  pkl.dump(self, f)
1062
983
  f.close()
1063
984
 
1064
985
  @classmethod
1065
- def load(clf, fname: str, rm_empty_lists=False):
1066
- """
1067
- Loading a lineage tree from a ".lT" file.
986
+ def load(clf, fname: str):
987
+ """Loading a lineage tree from a '.lT' file.
1068
988
 
1069
- Args:
1070
- fname (str): path to and name of the file to read
989
+ Parameters
990
+ ----------
991
+ fname : str
992
+ path to and name of the file to read
1071
993
 
1072
- Returns:
1073
- (lineageTree): loaded file
994
+ Returns
995
+ -------
996
+ LineageTree
997
+ loaded file
1074
998
  """
1075
999
  with open(fname, "br") as f:
1076
1000
  lT = pkl.load(f)
1077
1001
  f.close()
1002
+ if not hasattr(lT, "__version__") or Version(lT.__version__) < Version(
1003
+ "2.0.0"
1004
+ ):
1005
+ properties = {
1006
+ prop_name: prop
1007
+ for prop_name, prop in lT.__dict__.items()
1008
+ if (isinstance(prop, dict) or prop_name == "_time_resolution")
1009
+ and prop_name
1010
+ not in [
1011
+ "successor",
1012
+ "predecessor",
1013
+ "time",
1014
+ "_successor",
1015
+ "_predecessor",
1016
+ "_time",
1017
+ "pos",
1018
+ "labels",
1019
+ ]
1020
+ + lineageTree._dynamic_properties
1021
+ + lineageTree._protected_dynamic_properties
1022
+ }
1023
+ lT = lineageTree(
1024
+ successor=lT._successor,
1025
+ time=lT._time,
1026
+ pos=lT.pos,
1027
+ name=lT.name if hasattr(lT, "name") else None,
1028
+ **properties,
1029
+ )
1078
1030
  if not hasattr(lT, "time_resolution"):
1079
- lT.time_resolution = None
1031
+ lT.time_resolution = 1
1032
+
1080
1033
  return lT
1081
1034
 
1082
- def get_idx3d(self, t: int) -> tuple:
1083
- """Get a 3d kdtree for the dataset at time *t* .
1084
- The kdtree is stored in *self.kdtrees[t]*
1085
-
1086
- Args:
1087
- t (int): time
1088
- Returns:
1089
- (kdtree, [int, ]): the built kdtree and
1090
- the correspondancy list,
1091
- If the query in the kdtree gives you the value i,
1092
- then it corresponds to the id in the tree to_check_self[i]
1035
+ def get_predecessors(
1036
+ self,
1037
+ x: int,
1038
+ depth: int | None = None,
1039
+ start_time: int | None = None,
1040
+ end_time: int | None = None,
1041
+ ) -> list[int]:
1042
+ """Computes the predecessors of the node `x` up to
1043
+ `depth` predecessors or the begining of the life of `x`.
1044
+ The ordered list of ids is returned.
1045
+
1046
+ Parameters
1047
+ ----------
1048
+ x : int
1049
+ id of the node to compute
1050
+ depth : int
1051
+ maximum number of predecessors to return
1052
+
1053
+ Returns
1054
+ -------
1055
+ list of int
1056
+ list of ids, the last id is `x`
1057
+ """
1058
+ if not start_time:
1059
+ start_time = self.t_b
1060
+ if not end_time:
1061
+ end_time = self.t_e
1062
+ unconstrained_chain = [x]
1063
+ chain = [x] if start_time <= self._time[x] <= end_time else []
1064
+ acc = 0
1065
+ while (
1066
+ acc != depth
1067
+ and start_time < self._time[unconstrained_chain[0]]
1068
+ and (
1069
+ self._predecessor[unconstrained_chain[0]] != ()
1070
+ and ( # Please dont change very important even if it looks weird.
1071
+ len(
1072
+ self._successor[
1073
+ self._predecessor[unconstrained_chain[0]][0]
1074
+ ]
1075
+ )
1076
+ == 1
1077
+ )
1078
+ )
1079
+ ):
1080
+ unconstrained_chain.insert(
1081
+ 0, self._predecessor[unconstrained_chain[0]][0]
1082
+ )
1083
+ acc += 1
1084
+ if start_time <= self._time[unconstrained_chain[0]] <= end_time:
1085
+ chain.insert(0, unconstrained_chain[0])
1086
+
1087
+ return chain
1088
+
1089
+ def get_successors(
1090
+ self, x: int, depth: int | None = None, end_time: int | None = None
1091
+ ) -> list[int]:
1092
+ """Computes the successors of the node `x` up to
1093
+ `depth` successors or the end of the life of `x`.
1094
+ The ordered list of ids is returned.
1095
+
1096
+ Parameters
1097
+ ----------
1098
+ x : int
1099
+ id of the node to compute
1100
+ depth : int, optional
1101
+ maximum number of predecessors to return
1102
+ end_time : int, optional
1103
+ maximum time to consider
1104
+
1105
+ Returns
1106
+ -------
1107
+ list of int
1108
+ list of ids, the first id is `x`
1109
+ """
1110
+ if end_time is None:
1111
+ end_time = self.t_e
1112
+ chain = [x]
1113
+ acc = 0
1114
+ while (
1115
+ len(self._successor[chain[-1]]) == 1
1116
+ and acc != depth
1117
+ and self._time[chain[-1]] < end_time
1118
+ ):
1119
+ chain += self._successor[chain[-1]]
1120
+ acc += 1
1121
+
1122
+ return chain
1123
+
1124
+ def get_chain_of_node(
1125
+ self,
1126
+ x: int,
1127
+ depth: int | None = None,
1128
+ depth_pred: int | None = None,
1129
+ depth_succ: int | None = None,
1130
+ end_time: int | None = None,
1131
+ ) -> list[int]:
1132
+ """Computes the predecessors and successors of the node `x` up to
1133
+ `depth_pred` predecessors plus `depth_succ` successors.
1134
+ If the value `depth` is provided and not None,
1135
+ `depth_pred` and `depth_succ` are overwriten by `depth`.
1136
+ The ordered list of ids is returned.
1137
+ If all `depth` are None, the full chain is returned.
1138
+
1139
+ Parameters
1140
+ ----------
1141
+ x : int
1142
+ id of the node to compute
1143
+ depth : int, optional
1144
+ maximum number of predecessors and successor to return
1145
+ depth_pred : int, optional
1146
+ maximum number of predecessors to return
1147
+ depth_succ : int, optional
1148
+ maximum number of successors to return
1149
+
1150
+ Returns
1151
+ -------
1152
+ list of int
1153
+ list of node ids
1093
1154
  """
1094
- to_check_self = list(self.time_nodes[t])
1155
+ if end_time is None:
1156
+ end_time = self.t_e
1157
+ if depth is not None:
1158
+ depth_pred = depth_succ = depth
1159
+ return self.get_predecessors(x, depth_pred, end_time=end_time)[
1160
+ :-1
1161
+ ] + self.get_successors(x, depth_succ, end_time=end_time)
1162
+
1163
+ @dynamic_property
1164
+ def all_chains(self) -> list[list[int]]:
1165
+ """List of all chains in the tree, ordered in depth-first search."""
1166
+ return self._compute_all_chains()
1167
+
1168
+ @dynamic_property
1169
+ def time_nodes(self):
1170
+ _time_nodes = {}
1171
+ for c, t in self._time.items():
1172
+ _time_nodes.setdefault(t, set()).add(c)
1173
+ return _time_nodes
1174
+
1175
+ def m(self, i, j):
1176
+ if (i, j) not in self._tmp_parenting:
1177
+ if i == j: # the distance to the node itself is 0
1178
+ self._tmp_parenting[(i, j)] = 0
1179
+ self._parenting[i, j] = self._tmp_parenting[(i, j)]
1180
+ elif not self._predecessor[
1181
+ j
1182
+ ]: # j and i are note connected so the distance if inf
1183
+ self._tmp_parenting[(i, j)] = np.inf
1184
+ else: # the distance between i and j is the distance between i and pred(j) + 1
1185
+ self._tmp_parenting[(i, j)] = (
1186
+ self.m(i, self._predecessor[j][0]) + 1
1187
+ )
1188
+ self._parenting[i, j] = self._tmp_parenting[(i, j)]
1189
+ self._parenting[j, i] = -self._tmp_parenting[(i, j)]
1190
+ return self._tmp_parenting[(i, j)]
1191
+
1192
+ @property
1193
+ def parenting(self):
1194
+ if not hasattr(self, "_parenting"):
1195
+ self._parenting = dok_array((max(self.nodes) + 1,) * 2)
1196
+ self._tmp_parenting = {}
1197
+ for i, j in combinations(self.nodes, 2):
1198
+ if self._time[j] < self.time[i]:
1199
+ i, j = j, i
1200
+ self._tmp_parenting[(i, j)] = self.m(i, j)
1201
+ del self._tmp_parenting
1202
+ return self._parenting
1203
+
1204
+ def get_idx3d(self, t: int) -> tuple[KDTree, np.ndarray]:
1205
+ """Get a 3d kdtree for the dataset at time `t`.
1206
+ The kdtree is stored in `self.kdtrees[t]` and returned.
1207
+ The correspondancy list is also returned.
1208
+
1209
+ Parameters
1210
+ ----------
1211
+ t : int
1212
+ time
1213
+
1214
+ Returns
1215
+ -------
1216
+ KDTree
1217
+ The KDTree corresponding to the lineage tree at time `t`
1218
+ np.ndarray
1219
+ The correspondancy list in the KDTree.
1220
+ If the query in the kdtree gives you the value `i`,
1221
+ then it corresponds to the id in the tree `to_check_self[i]`
1222
+ """
1223
+ to_check_self = list(self.nodes_at_t(t=t))
1224
+
1225
+ if not hasattr(self, "kdtrees"):
1226
+ self.kdtrees = {}
1227
+
1095
1228
  if t not in self.kdtrees:
1096
1229
  data_corres = {}
1097
1230
  data = []
@@ -1104,16 +1237,21 @@ class lineageTree(lineageTreeLoaders):
1104
1237
  idx3d = self.kdtrees[t]
1105
1238
  return idx3d, np.array(to_check_self)
1106
1239
 
1107
- def get_gabriel_graph(self, t: int) -> dict:
1108
- """Build the Gabriel graph of the given graph for time point `t`
1109
- The Garbiel graph is then stored in self.Gabriel_graph and returned
1110
- *WARNING: the graph is not recomputed if already computed. even if nodes were added*.
1240
+ def get_gabriel_graph(self, t: int) -> dict[int, set[int]]:
1241
+ """Build the Gabriel graph of the given graph for time point `t`.
1242
+ The Garbiel graph is then stored in `self.Gabriel_graph` and returned.
1243
+
1244
+ .. warning:: the graph is not recomputed if already computed, even if the point cloud has changed
1111
1245
 
1112
- Args:
1113
- t (int): time
1114
- Returns:
1115
- {int, set([int, ])}: a dictionary that maps a node to
1116
- the set of its neighbors
1246
+ Parameters
1247
+ ----------
1248
+ t : int
1249
+ time
1250
+
1251
+ Returns
1252
+ -------
1253
+ dict of int to set of int
1254
+ A dictionary that maps a node to the set of its neighbors
1117
1255
  """
1118
1256
  if not hasattr(self, "Gabriel_graph"):
1119
1257
  self.Gabriel_graph = {}
@@ -1158,178 +1296,110 @@ class lineageTree(lineageTreeLoaders):
1158
1296
 
1159
1297
  return self.Gabriel_graph[t]
1160
1298
 
1161
- def get_predecessors(
1162
- self, x: int, depth: int = None, start_time: int = None, end_time=None
1163
- ) -> list:
1164
- """Computes the predecessors of the node `x` up to
1165
- `depth` predecessors or the begining of the life of `x`.
1166
- The ordered list of ids is returned.
1299
def get_all_chains_of_subtree(
    self, node: int, end_time: int | None = None
) -> list[list[int]]:
    """Computes all the chains of the subtree spawned by a given node.
    Similar to get_all_chains().

    Parameters
    ----------
    node : int
        The node from which we want to get its chains.
    end_time : int, optional
        The time at which we want to stop the chains.
        If `None`, `self.t_e` is used.

    Returns
    -------
    list of list of int
        list of chains, the first one being the chain that starts at `node`
    """
    # Compare against `None` explicitly: `0` is a legitimate end time and
    # must not be silently replaced by `self.t_e`.
    if end_time is None:
        end_time = self.t_e
    # The first chain is requested without `end_time`, hence it is always
    # part of the output, whatever its time span.
    chains = [self.get_successors(node)]
    to_do = list(self._successor[chains[0][-1]])
    while to_do:
        current = to_do.pop()
        chain = self.get_successors(current, end_time=end_time)
        # Only chains finishing at or before `end_time` are kept and expanded.
        if self._time[chain[-1]] <= end_time:
            chains.append(chain)
            to_do.extend(self._successor[chain[-1]])
    return chains
1328
+
1329
def _compute_all_chains(self) -> list[list[int]]:
    """Computes all the chains of a given lineage tree and returns them.

    The traversal starts from the roots, earliest root first, and follows
    the successors of the last node of every chain found.

    Returns
    -------
    list of list of int
        list of chains
    """
    # NOTE(review): the result is presumably cached as `self.all_chains`
    # by the caller; nothing is stored here.
    # Sorted by decreasing time so that `pop()` serves the earliest root first.
    pending = sorted(self.roots, key=self.time.get, reverse=True)
    collected = []
    while pending:
        start = pending.pop()
        chain = self.get_chain_of_node(start)
        collected.append(chain)
        # Continue with the nodes following the end of the chain.
        pending += self._successor[chain[-1]]
    return collected
1346
+
1347
# TODO: Probably should be removed, might be used by DTW. Might also be a @dynamic_property
def __get_chains(
    self, nodes: Iterable | int | None = None
) -> dict[int, list[list[int]]]:
    """Returns all the chains in the subtrees spawned by each of the given nodes.

    Parameters
    ----------
    nodes : Iterable or int, optional
        id or Iterable of ids of the nodes to be computed, if `None` all roots are used

    Returns
    -------
    dict mapping int to list of Chain
        dictionary mapping the node ids to a list of chains
    """
    # NOTE(review): assumes `self.all_chains` lists the chains of a subtree
    # contiguously, right after the chain holding its root - TODO confirm.
    every_chain = self.all_chains
    if nodes is None:
        nodes = self.roots
    if not isinstance(nodes, Iterable):
        nodes = [nodes]
    result = {}
    for n in nodes:
        # First node of the chain `n` belongs to
        chain_start = self.get_predecessors(n)[0]
        birth = self.time[n]
        seen_start = False
        subtree_chains = []
        for chain in every_chain:
            is_start = chain[0] == chain_start
            seen_start = seen_start or is_start
            if not seen_start:
                continue
            # A later chain starting at or before the time of `n`
            # marks the end of the subtree spawned by `n`.
            left_subtree = (self.time[chain[0]] <= birth) and not is_start
            if left_subtree:
                break
            if is_start:
                # Only the part of the chain from `n` onwards is kept
                subtree_chains.append(self.get_successors(n))
            else:
                subtree_chains.append(chain)
        result[n] = subtree_chains
    return result
1390
+
1391
+ def find_leaves(self, roots: int | Iterable) -> set[int]:
1326
1392
  """Finds the leaves of a tree spawned by one or more nodes.
1327
1393
 
1328
- Args:
1329
- roots (Union[int,set,list,tuple]): The roots of the trees.
1394
+ Parameters
1395
+ ----------
1396
+ roots : int or Iterable
1397
+ The roots of the trees spawning the leaves
1330
1398
 
1331
- Returns:
1332
- set: The leaves of one or more trees.
1399
+ Returns
1400
+ -------
1401
+ set
1402
+ The leaves of one or more trees.
1333
1403
  """
1334
1404
  if not isinstance(roots, Iterable):
1335
1405
  to_do = [roots]
@@ -1338,28 +1408,34 @@ class lineageTree(lineageTreeLoaders):
1338
1408
  leaves = set()
1339
1409
  while to_do:
1340
1410
  curr = to_do.pop()
1341
- succ = self.successor.get(curr, [])
1411
+ succ = self._successor[curr]
1342
1412
  if not succ:
1343
1413
  leaves.add(curr)
1344
1414
  to_do += succ
1345
1415
  return leaves
1346
1416
 
1347
- def get_sub_tree(
1417
+ def get_subtree_nodes(
1348
1418
  self,
1349
- x: Union[int, Iterable],
1350
- end_time: Union[int, None] = None,
1419
+ x: int | Iterable,
1420
+ end_time: int | None = None,
1351
1421
  preorder: bool = False,
1352
- ) -> list:
1353
- """Computes the list of cells from the subtree spawned by *x*
1354
- The default output order is breadth first traversal.
1422
+ ) -> list[int]:
1423
+ """Computes the list of nodes from the subtree spawned by *x*
1424
+ The default output order is Breadth First Traversal.
1355
1425
  Unless preorder is `True` in that case the order is
1356
- Depth first traversal preordered.
1426
+ Depth First Traversal (DFT) preordered.
1427
+
1428
+ Parameters
1429
+ ----------
1430
+ x : int
1431
+ id of root node
1432
+ preorder : bool, default=False
1433
+ if True the output preorder is DFT
1357
1434
 
1358
- Args:
1359
- x (int): id of root node
1360
- preorder (bool): if True the output is preorder DFT
1361
- Returns:
1362
- ([int, ...]): the ordered list of node ids
1435
+ Returns
1436
+ -------
1437
+ list of int
1438
+ the ordered list of node ids
1363
1439
  """
1364
1440
  if not end_time:
1365
1441
  end_time = self.t_e
@@ -1367,233 +1443,258 @@ class lineageTree(lineageTreeLoaders):
1367
1443
  to_do = [x]
1368
1444
  elif isinstance(x, Iterable):
1369
1445
  to_do = list(x)
1370
- sub_tree = []
1446
+ subtree = []
1371
1447
  while to_do:
1372
1448
  curr = to_do.pop()
1373
- succ = self.successor.get(curr, [])
1374
- if succ and end_time < self.time.get(curr, end_time):
1449
+ succ = self._successor[curr]
1450
+ if succ and end_time < self._time.get(curr, end_time):
1375
1451
  succ = []
1376
1452
  continue
1377
1453
  if preorder:
1378
1454
  to_do = succ + to_do
1379
1455
  else:
1380
1456
  to_do += succ
1381
- sub_tree += [curr]
1382
- return sub_tree
1457
+ subtree += [curr]
1458
+ return subtree
1383
1459
 
1384
1460
  def compute_spatial_density(
1385
- self, t_b: int = None, t_e: int = None, th: float = 50
1386
- ) -> dict:
1387
- """Computes the spatial density of cells between `t_b` and `t_e`.
1388
- The spatial density is computed as follow:
1389
- #cell/(4/3*pi*th^3)
1390
- The results is stored in self.spatial_density is returned.
1391
-
1392
- Args:
1393
- t_b (int): starting time to look at, default first time point
1394
- t_e (int): ending time to look at, default last time point
1395
- th (float): size of the neighbourhood
1396
- Returns:
1397
- {int, float}: dictionary that maps a cell id to its spatial density
1398
- """
1461
+ self, t_b: int | None = None, t_e: int | None = None, th: float = 50
1462
+ ) -> dict[int, float]:
1463
+ """Computes the spatial density of nodes between `t_b` and `t_e`.
1464
+ The results is stored in `self.spatial_density` and returned.
1465
+
1466
+ Parameters
1467
+ ----------
1468
+ t_b : int, optional
1469
+ starting time to look at, default first time point
1470
+ t_e : int, optional
1471
+ ending time to look at, default last time point
1472
+ th : float, default=50
1473
+ size of the neighbourhood
1474
+
1475
+ Returns
1476
+ -------
1477
+ dict mapping int to float
1478
+ dictionary that maps a node id to its spatial density
1479
+ """
1480
+ if not hasattr(self, "spatial_density"):
1481
+ self.spatial_density = {}
1399
1482
  s_vol = 4 / 3.0 * np.pi * th**3
1400
- time_range = set(range(t_b, t_e + 1)).intersection(self.time_nodes)
1483
+ if t_b is None:
1484
+ t_b = self.t_b
1485
+ if t_e is None:
1486
+ t_e = self.t_e
1487
+ time_range = set(range(t_b, t_e)).intersection(self._time.values())
1401
1488
  for t in time_range:
1402
1489
  idx3d, nodes = self.get_idx3d(t)
1403
1490
  nb_ni = [
1404
1491
  (len(ni) - 1) / s_vol
1405
1492
  for ni in idx3d.query_ball_tree(idx3d, th)
1406
1493
  ]
1407
- self.spatial_density.update(dict(zip(nodes, nb_ni)))
1494
+ self.spatial_density.update(dict(zip(nodes, nb_ni, strict=True)))
1408
1495
  return self.spatial_density
1409
1496
 
1410
- def compute_k_nearest_neighbours(self, k: int = 10) -> dict:
1497
+ def compute_k_nearest_neighbours(self, k: int = 10) -> dict[int, set[int]]:
1411
1498
  """Computes the k-nearest neighbors
1412
1499
  Writes the output in the attribute `kn_graph`
1413
1500
  and returns it.
1414
1501
 
1415
- Args:
1416
- k (float): number of nearest neighours
1417
- Returns:
1418
- {int, set([int, ...])}: dictionary that maps
1419
- a cell id to its `k` nearest neighbors
1502
+ Parameters
1503
+ ----------
1504
+ k : float
1505
+ number of nearest neighours
1506
+
1507
+ Returns
1508
+ -------
1509
+ dict mapping int to set of int
1510
+ dictionary that maps
1511
+ a node id to its `k` nearest neighbors
1420
1512
  """
1421
1513
  self.kn_graph = {}
1422
- for t, nodes in self.time_nodes.items():
1423
- use_k = k if k < len(nodes) else len(nodes)
1424
- idx3d, nodes = self.get_idx3d(t)
1425
- pos = [self.pos[c] for c in nodes]
1426
- _, neighbs = idx3d.query(pos, use_k)
1427
- out = dict(zip(nodes, [set(nodes[ni[1:]]) for ni in neighbs]))
1428
- self.kn_graph.update(out)
1514
+ for t in set(self._time.values()):
1515
+ nodes = self.nodes_at_t(t)
1516
+ if 1 < len(nodes):
1517
+ use_k = k if k < len(nodes) else len(nodes)
1518
+ idx3d, nodes = self.get_idx3d(t)
1519
+ pos = [self.pos[c] for c in nodes]
1520
+ _, neighbs = idx3d.query(pos, use_k)
1521
+ out = dict(
1522
+ zip(
1523
+ nodes,
1524
+ map(set, nodes[neighbs]),
1525
+ strict=True,
1526
+ )
1527
+ )
1528
+ self.kn_graph.update(out)
1529
+ else:
1530
+ n = nodes.pop
1531
+ self.kn_graph.update({n: {n}})
1429
1532
  return self.kn_graph
1430
1533
 
1431
- def compute_spatial_edges(self, th: int = 50) -> dict:
1534
+ def compute_spatial_edges(self, th: int = 50) -> dict[int, set[int]]:
1432
1535
  """Computes the neighbors at a distance `th`
1433
1536
  Writes the output in the attribute `th_edge`
1434
1537
  and returns it.
1435
1538
 
1436
- Args:
1437
- th (float): distance to consider neighbors
1438
- Returns:
1439
- {int, set([int, ...])}: dictionary that maps
1440
- a cell id to its neighbors at a distance `th`
1539
+ Parameters
1540
+ ----------
1541
+ th : float, default=50
1542
+ distance to consider neighbors
1543
+
1544
+ Returns
1545
+ -------
1546
+ dict mapping int to set of int
1547
+ dictionary that maps a node id to its neighbors at a distance `th`
1441
1548
  """
1442
1549
  self.th_edges = {}
1443
- for t, _ in self.time_nodes.items():
1550
+ for t in set(self._time.values()):
1551
+ nodes = self.nodes_at_t(t)
1444
1552
  idx3d, nodes = self.get_idx3d(t)
1445
1553
  neighbs = idx3d.query_ball_tree(idx3d, th)
1446
- out = dict(zip(nodes, [set(nodes[ni]) for ni in neighbs]))
1554
+ out = dict(
1555
+ zip(nodes, [set(nodes[ni]) for ni in neighbs], strict=True)
1556
+ )
1447
1557
  self.th_edges.update(
1448
1558
  {k: v.difference([k]) for k, v in out.items()}
1449
1559
  )
1450
1560
  return self.th_edges
1451
1561
 
1452
- def main_axes(self, time: int = None):
1453
- """Finds the main axes for a timepoint.
1454
- If none will select the timepoint with the highest amound of cells.
1455
-
1456
- Args:
1457
- time (int, optional): The timepoint to find the main axes.
1458
- If None will find the timepoint
1459
- with the largest number of cells.
1460
-
1461
- Returns:
1462
- list: A list that contains the array of eigenvalues and eigenvectors.
1463
- """
1464
- if time is None:
1465
- time = max(self.time_nodes, key=lambda x: len(self.time_nodes[x]))
1466
- pos = np.array([self.pos[node] for node in self.time_nodes[time]])
1467
- pos = pos - np.mean(pos, axis=0)
1468
- cov = np.cov(np.array(pos).T)
1469
- eig_val, eig_vec = np.linalg.eig(cov)
1470
- srt = np.argsort(eig_val)[::-1]
1471
- self.eig_val, self.eig_vec = eig_val[srt], eig_vec[:, srt]
1472
- return eig_val[srt], eig_vec[:, srt]
1473
-
1474
- def scale_embryo(self, scale=1000):
1475
- """Scale the embryo using their eigenvalues.
1476
-
1477
- Args:
1478
- scale (int, optional): The resulting scale you want to achieve. Defaults to 1000.
1479
-
1480
- Returns:
1481
- float: The scale factor.
1482
- """
1483
- eig = self.main_axes()[0]
1484
- return scale / (np.sqrt(eig[0]))
1485
-
1486
- @staticmethod
1487
- def __rodrigues_rotation_matrix(vector1, vector2=(0, 1, 0)):
1488
- """Calculates the rodrigues matrix of a dataset. It should use vectors from the find_main_axes(eigenvectors) function of LineagTree.
1489
- Uses the Rodrigues rotation formula.
1490
-
1491
- Args:
1492
- vector1 (list|np.array): The vector that should be rotated to be aligned to the second vector
1493
- vector2 (list|np.array, optional): The second vector. Defaults to [1,0,0].
1494
-
1495
- Returns:
1496
- np.array: The rotation matrix.
1497
- """
1498
- vector1 = vector1 / np.linalg.norm(vector1)
1499
- vector2 = vector2 / np.linalg.norm(vector2)
1500
- if vector1 @ vector2 == 1:
1501
- return np.eye(3)
1502
- angle = np.arccos(vector1 @ vector2)
1503
- axis = np.cross(vector1, vector2)
1504
- axis = axis / np.linalg.norm(axis)
1505
- K = np.array(
1506
- [
1507
- [0, -axis[2], axis[1]],
1508
- [axis[2], 0, -axis[0]],
1509
- [-axis[1], axis[0], 0],
1510
- ]
1511
- )
1512
- return np.eye(3) + np.sin(angle) * K + (1 - np.cos(angle)) * K @ K
1513
-
1514
- def get_ancestor_at_t(self, n: int, time: int = None):
1515
- """
1516
- Find the id of the ancestor of a give node `n`
1562
def get_ancestor_at_t(self, n: int, time: int | None = None) -> int:
    """Find the id of the ancestor of a given node `n`
    at a given time `time`.

    The predecessors of `n` are followed until a node which is not later
    than `time` is reached, or until there is no predecessor left.
    That node is returned only if its time is exactly `time`.

    Parameters
    ----------
    n : int
        node for which to look the ancestor
    time : int, optional
        time at which the ancestor has to be found.
        If `None`, `self.t_b` is used.

    Returns
    -------
    int
        the id of the ancestor at time `time`,
        `-1` if `n` is not a node of the tree or
        if there is no ancestor at that time.
    """
    if n not in self.nodes:
        return -1
    target = self.t_b if time is None else time
    current = n
    while True:
        # Nodes without a time are considered to be before the first time point
        if not target < self._time.get(current, self.t_b - 1):
            break
        parents = self._predecessor[current]
        if not parents:
            break
        current = parents[0]
    reached_target = self._time.get(current, self.t_b - 1) == target
    return current if reached_target else -1
1543
1599
 
1544
- def get_labelled_ancestor(self, node: int):
1545
- """Finds the first labelled ancestor and returns its ID otherwise returns None
1600
def get_labelled_ancestor(self, node: int) -> int:
    """Finds the first labelled ancestor and returns its ID otherwise returns -1

    The search starts at `node` itself, which is returned when labelled.

    Parameters
    ----------
    node : int
        The id of the node

    Returns
    -------
    int
        Returns the first ancestor found that has a label otherwise `-1`.
    """
    if node not in self.nodes:
        return -1
    ancestor = node
    while (
        self.t_b <= self._time.get(ancestor, self.t_b - 1)
        and ancestor != -1
    ):
        if ancestor in self.labels:
            return ancestor
        # A node without predecessor can either be absent from
        # `_predecessor` or be mapped to an empty sequence; indexing
        # blindly would raise an IndexError in the latter case.
        predecessors = self._predecessor.get(ancestor)
        if not predecessors:
            return -1
        ancestor = predecessors[0]
    return -1
1624
+
1625
def get_ancestor_with_attribute(self, node: int, attribute: str) -> int:
    """General purpose function to search the first ancestor of a node
    that is a key of a given dict attribute. The node itself is checked first.
    Similar to get_labelled_ancestor and may make it redundant.

    Parameters
    ----------
    node : int
        The id of the node
    attribute : str
        The name of the attribute of the lineage tree to look into,
        it has to be a dict.

    Returns
    -------
    int
        Returns the first ancestor found that has an attribute otherwise `-1`.

    Raises
    ------
    ValueError
        If the selected attribute is not a dict.
    """
    values = self.__getattribute__(attribute)
    if not isinstance(values, dict):
        raise ValueError("Please select a dict attribute")
    if node not in self.nodes:
        return -1
    if node in values:
        return node
    if node in self.roots:
        return -1
    current = node
    while True:
        parents = self._predecessor.get(current, [-1])
        # Either no predecessor or the `[-1]` placeholder: top of the tree
        if not parents or parents == [-1]:
            return -1
        current = parents[0]
        if current in values:
            return current
1565
1655
 
1566
1656
  def unordered_tree_edit_distances_at_time_t(
1567
1657
  self,
1568
1658
  t: int,
1569
- end_time: int = None,
1570
- style="simple",
1659
+ end_time: int | None = None,
1660
+ style: (
1661
+ Literal["simple", "full", "downsampled", "normalized_simple"]
1662
+ | type[TreeApproximationTemplate]
1663
+ ) = "simple",
1571
1664
  downsample: int = 2,
1572
- normalize: bool = True,
1665
+ norm: Literal["max", "sum", None] = "max",
1573
1666
  recompute: bool = False,
1574
- ) -> dict:
1575
- """
1576
- Compute all the pairwise unordered tree edit distances from Zhang 996 between the trees spawned at time `t`
1577
-
1578
- Args:
1579
- t (int): time to look at
1580
- delta (callable): comparison function (see edist doc for more information)
1581
- norm (callable): norming function that takes the number of nodes
1582
- of the tree spawned by `n1` and the number of nodes
1583
- of the tree spawned by `n2` as arguments.
1584
- recompute (bool): if True, forces to recompute the distances (default: False)
1585
- end_time (int): The final time point the comparison algorithm will take into account. If None all nodes
1586
- will be taken into account.
1587
-
1588
- Returns:
1589
- (dict) a dictionary that maps a pair of cell ids at time `t` to their unordered tree edit distance
1667
+ ) -> dict[tuple[int, int], float]:
1668
+ """Compute all the pairwise unordered tree edit distances from Zhang 996 between the trees spawned at time `t`
1669
+
1670
+ Parameters
1671
+ ----------
1672
+ t : int
1673
+ time to look at
1674
+ end_time : int
1675
+ The final time point the comparison algorithm will take into account.
1676
+ If None all nodes will be taken into account.
1677
+ style : {"simple", "full", "downsampled", "normalized_simple"} or TreeApproximationTemplate subclass, default="simple"
1678
+ Which tree approximation is going to be used for the comparisons.
1679
+ downsample : int, default=2
1680
+ The downsample factor for the downsampled tree approximation.
1681
+ Used only when `style="downsampled"`.
1682
+ norm : {"max", "sum"}, default="max"
1683
+ The normalization method to use.
1684
+ recompute : bool, default=False
1685
+ If True, forces to recompute the distances
1686
+
1687
+ Returns
1688
+ -------
1689
+ dict mapping a tuple of tuple that contains 2 ints to float
1690
+ a dictionary that maps a pair of node ids at time `t` to their unordered tree edit distance
1590
1691
  """
1591
1692
  if not hasattr(self, "uted"):
1592
1693
  self.uted = {}
1593
1694
  elif t in self.uted and not recompute:
1594
1695
  return self.uted[t]
1595
1696
  self.uted[t] = {}
1596
- roots = self.time_nodes[t]
1697
+ roots = self.nodes_at_t(t=t)
1597
1698
  for n1, n2 in combinations(roots, 2):
1598
1699
  key = tuple(sorted((n1, n2)))
1599
1700
  self.uted[t][key] = self.unordered_tree_edit_distance(
@@ -1602,37 +1703,132 @@ class lineageTree(lineageTreeLoaders):
1602
1703
  end_time=end_time,
1603
1704
  style=style,
1604
1705
  downsample=downsample,
1605
- normalize=normalize,
1706
+ norm=norm,
1606
1707
  )
1607
1708
  return self.uted[t]
1608
1709
 
1609
- def unordered_tree_edit_distance(
1710
def __calculate_distance_of_sub_tree(
    self,
    node1: int,
    node2: int,
    alignment: Alignment,
    corres1: dict[int, int],
    corres2: dict[int, int],
    delta_tmp: Callable,
    norm: Callable,
    norm1: int | float,
    norm2: int | float,
) -> float:
    """Calculates the distance between the subtrees of two matched nodes.
    The distance is NOT computed from scratch: the costs of the matches of
    an existing alignment that involve at least one of the two subtrees
    are summed up and then normalized.
    TODO: bound to change.

    Parameters
    ----------
    node1 : int
        The root of the first subtree
    node2 : int
        The root of the second subtree
    alignment : Alignment
        The alignment of the subtree
    corres1 : dict
        The correspondence dictionary of the first lineage
    corres2 : dict
        The correspondence dictionary of the second lineage
    delta_tmp : Callable
        The delta function for the comparisons
    norm : Callable
        How should the lineages be normalized
    norm1 : int or float
        The result of the normalization of the first tree
    norm2 : int or float
        The result of the normalization of the second tree

    Returns
    -------
    float
        The result of the comparison of the subtree
    """
    nodes_1 = set(self.get_subtree_nodes(node1))
    nodes_2 = set(self.get_subtree_nodes(node2))

    def _or_none(index):
        # In the alignment `-1` stands for "no node", `delta_tmp` expects `None`
        return None if index == -1 else index

    total = 0
    for match in alignment:
        left, right = match._left, match._right
        involved = (
            corres1.get(left, -1) in nodes_1
            or corres2.get(right, -1) in nodes_2
        )
        if involved:
            total += delta_tmp(_or_none(left), _or_none(right))
    return total / norm([norm1, norm2])
1764
+
1765
def clear_comparisons(self) -> None:
    """Removes, in place, everything stored in `self._comparisons`."""
    self._comparisons.clear()
1767
+
1768
+ def __unordereded_backtrace(
1610
1769
  self,
1611
1770
  n1: int,
1612
1771
  n2: int,
1613
- end_time: int = None,
1614
- norm: Union["max", "sum", None] = "max",
1615
- style="simple",
1772
+ end_time: int | None = None,
1773
+ norm: Literal["max", "sum", None] = "max",
1774
+ style: (
1775
+ Literal["simple", "normalized_simple", "full", "downsampled"]
1776
+ | type[TreeApproximationTemplate]
1777
+ ) = "simple",
1616
1778
  downsample: int = 2,
1617
- ) -> float:
1779
+ ) -> dict[
1780
+ str,
1781
+ Alignment
1782
+ | tuple[TreeApproximationTemplate, TreeApproximationTemplate],
1783
+ ]:
1618
1784
  """
1619
- Compute the unordered tree edit distance from Zhang 1996 between the trees spawned
1785
+ Compute the unordered tree edit backtrace from Zhang 1996 between the trees spawned
1620
1786
  by two nodes `n1` and `n2`. The topology of the trees are compared and the matching
1621
1787
  cost is given by the function delta (see edist doc for more information).
1622
- The distance is normed by the function norm that takes the two list of nodes
1623
- spawned by the trees `n1` and `n2`.
1624
1788
 
1625
- Args:
1626
- n1 (int): id of the first node to compare
1627
- n2 (int): id of the second node to compare
1628
- tree_style ("mini","simple","fragmented","full"): Which tree approximation is going to be used for the comparisons.
1629
- Defaults to "fragmented".
1630
-
1631
- Returns:
1632
- (float) The normed unordered tree edit distance
1633
- """
1634
-
1635
- tree = tree_style[style].value
1789
+ Parameters
1790
+ ----------
1791
+ n1 : int
1792
+ id of the first node to compare
1793
+ n2 : int
1794
+ id of the second node to compare
1795
+ end_time : int
1796
+ The final time point the comparison algorithm will take into account.
1797
+ If None all nodes will be taken into account.
1798
+ norm : {"max", "sum"}, default="max"
1799
+ The normalization method to use.
1800
+ style : {"simple", "full", "downsampled", "normalized_simple"} or TreeApproximationTemplate subclass, default="simple"
1801
+ Which tree approximation is going to be used for the comparisons.
1802
+ downsample : int, default=2
1803
+ The downsample factor for the downsampled tree approximation.
1804
+ Used only when `style="downsampled"`.
1805
+
1806
+ Returns
1807
+ -------
1808
+ dict mapping str to Alignment or tuple of [TreeApproximationTemplate, TreeApproximationTemplate]
1809
+ - 'alignment'
1810
+ The alignment between the nodes by the subtrees spawned by the nodes n1,n2 and the normalization function.
1811
+ - 'trees'
1812
+ A list of the two trees that have been mapped to each other.
1813
+ """
1814
+
1815
+ parameters = (
1816
+ end_time,
1817
+ convert_style_to_number(style=style, downsample=downsample),
1818
+ )
1819
+ n1, n2 = sorted([n1, n2])
1820
+ self._comparisons.setdefault(parameters, {})
1821
+ if len(self._comparisons) > 100:
1822
+ warnings.warn(
1823
+ "More than 100 comparisons are saved, use clear_comparisons() to delete them.",
1824
+ stacklevel=2,
1825
+ )
1826
+ if isinstance(style, str):
1827
+ tree = tree_style[style].value
1828
+ elif issubclass(style, TreeApproximationTemplate):
1829
+ tree = style
1830
+ else:
1831
+ raise ValueError("Please use a valid approximation.")
1636
1832
  tree1 = tree(
1637
1833
  lT=self,
1638
1834
  downsample=downsample,
@@ -1661,7 +1857,11 @@ class lineageTree(lineageTreeLoaders):
1661
1857
  corres2,
1662
1858
  ) = tree2.edist
1663
1859
  if len(nodes1) == len(nodes2) == 0:
1664
- return 0
1860
+ self._comparisons[parameters][(n1, n2)] = {
1861
+ "alignment": (),
1862
+ "trees": (),
1863
+ }
1864
+ return self._comparisons[parameters][(n1, n2)]
1665
1865
  delta_tmp = partial(
1666
1866
  delta,
1667
1867
  corres1=corres1,
@@ -1669,126 +1869,538 @@ class lineageTree(lineageTreeLoaders):
1669
1869
  times1=times1,
1670
1870
  times2=times2,
1671
1871
  )
1672
- norm1 = tree1.get_norm()
1673
- norm2 = tree2.get_norm()
1674
- norm_dict = {"max": max, "sum": sum, "None": lambda x: 1}
1675
- if norm is None:
1676
- norm = "None"
1677
- if norm not in norm_dict:
1872
+ btrc = uted.uted_backtrace(nodes1, adj1, nodes2, adj2, delta=delta_tmp)
1873
+
1874
+ self._comparisons[parameters][(n1, n2)] = {
1875
+ "alignment": btrc,
1876
+ "trees": (tree1, tree2),
1877
+ }
1878
+ return self._comparisons[parameters][(n1, n2)]
1879
+
1880
def plot_tree_distance_graphs(
    self,
    n1: int,
    n2: int,
    end_time: int | None = None,
    norm: Literal["max", "sum", None] = "max",
    style: (
        Literal["simple", "normalized_simple", "full", "downsampled"]
        | type[TreeApproximationTemplate]
    ) = "simple",
    downsample: int = 2,
    colormap: str = "cool",
    default_color: str = "black",
    size: float = 10,
    lw: float = 0.3,
    ax: list[plt.Axes] | None = None,
) -> tuple[plt.Figure, plt.Axes]:
    """
    Plots the subtrees compared and colors them according to the quality of the matching of their subtree.

    Parameters
    ----------
    n1 : int
        id of the first node to compare
    n2 : int
        id of the second node to compare
    end_time : int, optional
        The final time point the comparison algorithm will take into account.
        If None all nodes will be taken into account.
    norm : {"max", "sum"}, default="max"
        The normalization method to use.
    style : {"simple", "full", "downsampled", "normalized_simple"} or TreeApproximationTemplate subclass, default="simple"
        Which tree approximation is going to be used for the comparisons.
    downsample : int, default=2
        The downsample factor for the downsampled tree approximation.
        Used only when `style="downsampled"`.
    colormap : str, default="cool"
        The colormap used for matched nodes.
    default_color : str, default="black"
        The color of the unmatched nodes.
    size : float, default=10
        The size of the nodes.
    lw : float, default=0.3
        The width of the edges.
    ax : sequence of two plt.Axes, optional
        The axes used (`ax[0]` for the first subtree, `ax[1]` for the second).
        If not provided a new pair of axes is created.

    Returns
    -------
    plt.Figure
        The figure of the plot
    plt.Axes
        The axes of the plot (the pair of axes that was used)

    Raises
    ------
    Warning
        If `norm` is not a key of `self.norm_dict`.
    """
    parameters = (
        end_time,
        convert_style_to_number(style=style, downsample=downsample),
    )
    # The pair is stored under its sorted ids, so (n1, n2) and (n2, n1)
    # share one cache entry.
    n1, n2 = sorted([n1, n2])
    cached_comparisons = self._comparisons.setdefault(parameters, {})
    tmp = cached_comparisons.get((n1, n2)) or self.__unordereded_backtrace(
        n1, n2, end_time, norm, style, downsample
    )
    btrc: Alignment = tmp["alignment"]
    tree1, tree2 = tmp["trees"]
    _, times1 = tree1.tree
    _, times2 = tree2.tree
    *_, corres1 = tree1.edist
    *_, corres2 = tree2.edist
    delta_tmp = partial(
        tree1.delta,
        corres1=corres1,
        corres2=corres2,
        times1=times1,
        times2=times2,
    )

    if norm not in self.norm_dict:
        raise Warning("Select a viable normalization method (max, sum, None)")
    norm_function = self.norm_dict[norm]

    # In the chain-level styles every matched tree node is mapped back to
    # the first node of its chain; in "full"/"downsampled" the matched
    # node ids are used directly.
    chain_level = style not in ("full", "downsampled")
    matched_left = []
    matched_right = []
    colors = {}
    for m in btrc:
        if m._left == -1 or m._right == -1:
            # Unmatched on one side: keeps the default color.
            continue
        if chain_level:
            chain_1 = self.get_chain_of_node(corres1[m._left])
            chain_2 = self.get_chain_of_node(corres2[m._right])
            # Indexing (rather than pop()) leaves the returned chains intact.
            node_1, l_node_1 = chain_1[0], chain_1[-1]
            node_2, l_node_2 = chain_2[0], chain_2[-1]
        else:
            node_1 = corres1[m._left]
            node_2 = corres2[m._right]
            chain_1 = self.get_chain_of_node(node_1)
            chain_2 = self.get_chain_of_node(node_2)
            # `a or b and c` parses as `a or (b and c)`; the explicit
            # grouping applies the "not colored yet" guard to both
            # chain-start tests.
            starts_chain = chain_1[0] == node_1 or chain_2[0] == node_2
            needs_color = node_1 not in colors or node_2 not in colors
            if not (starts_chain and needs_color):
                continue
            l_node_1, l_node_2 = chain_1[-1], chain_2[-1]

        # dict.fromkeys drops the duplicate when a chain has a single node.
        matched_left.extend(dict.fromkeys((node_1, l_node_1)))
        matched_right.extend(dict.fromkeys((node_2, l_node_2)))
        pair_distance = self.__calculate_distance_of_sub_tree(
            node_1,
            node_2,
            btrc,
            corres1,
            corres2,
            delta_tmp,
            norm_function,
            tree1.get_norm(node_1),
            tree2.get_norm(node_2),
        )
        # Both ends of both matched chains share the same color value.
        for colored_node in (node_1, node_2, l_node_1, l_node_2):
            colors[colored_node] = pair_distance

    if ax is None:
        _, ax = plt.subplots(nrows=1, ncols=2, sharey=True)
    cmap = colormaps[colormap]
    c_norm = mcolors.Normalize(0, 1)
    colors = {c: cmap(c_norm(v)) for c, v in colors.items()}
    for root, matched, axis in (
        (n1, matched_left, ax[0]),
        (n2, matched_right, ax[1]),
    ):
        self.plot_subtree(
            self.get_ancestor_at_t(root),
            end_time=end_time,
            size=size,
            selected_nodes=matched,
            color_of_nodes=colors,
            selected_edges=matched,
            color_of_edges=colors,
            default_color=default_color,
            lw=lw,
            ax=axis,
        )
    return ax[0].get_figure(), ax
2070
+
2071
def labelled_mappings(
    self,
    n1: int,
    n2: int,
    end_time: int | None = None,
    norm: Literal["max", "sum", None] = "max",
    style: (
        Literal["simple", "normalized_simple", "full", "downsampled"]
        | type[TreeApproximationTemplate]
    ) = "simple",
    downsample: int = 2,
) -> dict[str, list]:
    """
    Returns the labels or IDs of all the nodes in the subtrees compared.

    Parameters
    ----------
    n1 : int
        id of the first node to compare
    n2 : int
        id of the second node to compare
    end_time : int, optional
        The final time point the comparison algorithm will take into account.
        If None or not provided all nodes will be taken into account.
    norm : {"max", "sum"}, default="max"
        The normalization method to use.
    style : {"simple", "full", "downsampled", "normalized_simple"} or TreeApproximationTemplate subclass, default="simple"
        Which tree approximation is going to be used for the comparisons.
    downsample : int, default=2
        The downsample factor for the downsampled tree approximation.
        Used only when `style="downsampled"`.

    Returns
    -------
    dict mapping str to list
        - 'matched' A list of pairs `(label_1, label_2)`, one per matched pair of the alignment.
        - 'unmatched' A list with the labels of the unmatched nodes of the alignment.

        A node without an entry in `self.labels` is reported by its id.

    Raises
    ------
    Warning
        If `norm` is not a key of `self.norm_dict`.
    """
    parameters = (
        end_time,
        convert_style_to_number(style=style, downsample=downsample),
    )
    # The pair is stored under its sorted ids, so (n1, n2) and (n2, n1)
    # share one cache entry.
    n1, n2 = sorted([n1, n2])
    cached_comparisons = self._comparisons.setdefault(parameters, {})
    tmp = cached_comparisons.get((n1, n2)) or self.__unordereded_backtrace(
        n1, n2, end_time, norm, style, downsample
    )
    btrc = tmp["alignment"]
    tree1, tree2 = tmp["trees"]
    *_, corres1 = tree1.edist
    *_, corres2 = tree2.edist

    if norm not in self.norm_dict:
        raise Warning("Select a viable normalization method (max, sum, None)")

    # In the chain-level styles a tree node is reported through the first
    # node of its chain; "full"/"downsampled" report the node itself.
    chain_level = style not in ("full", "downsampled")
    matched = []
    unmatched = []
    for m in btrc:
        if m._left != -1 and m._right != -1:
            node_1 = corres1[m._left]
            node_2 = corres2[m._right]
            if chain_level:
                # Indexing (rather than pop()) leaves the returned chains intact.
                node_1 = self.get_chain_of_node(node_1)[0]
                node_2 = self.get_chain_of_node(node_2)[0]
            matched.append(
                (
                    self.labels.get(node_1, node_1),
                    self.labels.get(node_2, node_2),
                )
            )
        else:
            # Exactly one side is expected to carry a real index here.
            if chain_level:
                if m._left != -1:
                    lone_node = corres1.get(m._left, "-")
                else:
                    lone_node = corres2.get(m._right, "-")
                lone_node = self.get_chain_of_node(lone_node)[0]
            elif m._left != -1:
                lone_node = corres1[m._left]
            else:
                lone_node = corres2[m._right]
            unmatched.append(self.labels.get(lone_node, lone_node))
    return {"matched": matched, "unmatched": unmatched}
2188
+
2189
def unordered_tree_edit_distance(
    self,
    n1: int,
    n2: int,
    end_time: int | None = None,
    norm: Literal["max", "sum", None] = "max",
    style: (
        Literal["simple", "normalized_simple", "full", "downsampled"]
        | type[TreeApproximationTemplate]
    ) = "simple",
    downsample: int = 2,
    return_norms: bool = False,
) -> float | tuple[float, tuple[float, float]]:
    """
    Compute the unordered tree edit distance (Zhang 1996) between the trees
    spawned by the nodes `n1` and `n2`.

    The alignment of the two tree approximations is taken from the cache of
    comparisons when available and computed otherwise. Its cost is then
    evaluated with the `delta` function of the tree approximation
    (see the edist documentation for more information).

    Parameters
    ----------
    n1 : int
        id of the first node to compare
    n2 : int
        id of the second node to compare
    end_time : int, optional
        The final time point the comparison algorithm will take into account.
        If None or not provided all nodes will be taken into account.
    norm : {"max", "sum"}, default="max"
        The normalization method to use.
    style : {"simple", "normalized_simple", "full", "downsampled"} or TreeApproximationTemplate subclass, default="simple"
        Which tree approximation is going to be used for the comparisons.
    downsample : int, default=2
        The downsample factor for the downsampled tree approximation.
        Used only when `style="downsampled"`.
    return_norms : bool, default=False
        If True the raw cost is returned together with the norms of both trees
        instead of the normalized distance.

    Returns
    -------
    float
        The normalized unordered tree edit distance between `n1` and `n2`
        (when `return_norms` is False).
    tuple of float and tuple of two floats
        The raw cost and the pair of norms of the two trees
        (when `return_norms` is True).

    Raises
    ------
    ValueError
        If `norm` is not a key of `self.norm_dict`.
    """
    cache_key = (
        end_time,
        convert_style_to_number(style=style, downsample=downsample),
    )
    n1, n2 = sorted([n1, n2])
    comparisons = self._comparisons.setdefault(cache_key, {})
    comparison = comparisons.get((n1, n2))
    if not comparison:
        comparison = self.__unordereded_backtrace(
            n1, n2, end_time, norm, style, downsample
        )
    alignment = comparison["alignment"]
    tree1, tree2 = comparison["trees"]
    _, times1 = tree1.tree
    _, times2 = tree2.tree
    # The adjacency lists are not needed to evaluate the cost.
    nodes1, _, corres1 = tree1.edist
    nodes2, _, corres2 = tree2.edist
    node_cost = partial(
        tree1.delta,
        corres1=corres1,
        corres2=corres2,
        times1=times1,
        times2=times2,
    )

    if norm not in self.norm_dict:
        raise ValueError(
            "Select a viable normalization method (max, sum, None)"
        )
    cost = alignment.cost(nodes1, nodes2, node_cost)
    norm_values = (tree1.get_norm(n1), tree2.get_norm(n2))
    if return_norms:
        return cost, norm_values
    normalize = self.norm_dict[norm]
    return cost / normalize(norm_values)
1684
2274
 
1685
2275
  @staticmethod
1686
2276
  def __plot_nodes(
1687
- hier, selected_nodes, color, size, ax, default_color="black", **kwargs
1688
- ):
2277
+ hier: dict,
2278
+ selected_nodes: set,
2279
+ color: str | dict | list,
2280
+ size: int | float,
2281
+ ax: plt.Axes,
2282
+ default_color: str = "black",
2283
+ **kwargs,
2284
+ ) -> None:
1689
2285
  """
1690
2286
  Private method that plots the nodes of the tree.
1691
2287
  """
1692
- hier_unselected = np.array(
1693
- [v for k, v in hier.items() if k not in selected_nodes]
1694
- )
1695
- if hier_unselected.any():
1696
- ax.scatter(
1697
- *hier_unselected.T,
1698
- s=size,
1699
- zorder=10,
1700
- color=default_color,
1701
- **kwargs,
1702
- )
1703
- if selected_nodes.intersection(hier.keys()):
1704
- hier_selected = np.array(
1705
- [v for k, v in hier.items() if k in selected_nodes]
1706
- )
1707
- ax.scatter(
1708
- *hier_selected.T, s=size, zorder=10, color=color, **kwargs
1709
- )
2288
+
2289
+ if isinstance(color, dict):
2290
+ color = [color.get(k, default_color) for k in hier]
2291
+ elif isinstance(color, str | list):
2292
+ color = [
2293
+ color if node in selected_nodes else default_color
2294
+ for node in hier
2295
+ ]
2296
+ hier_pos = np.array(list(hier.values()))
2297
+ ax.scatter(*hier_pos.T, s=size, zorder=10, color=color, **kwargs)
1710
2298
 
1711
2299
  @staticmethod
1712
2300
  def __plot_edges(
1713
- hier,
1714
- lnks_tms,
1715
- selected_edges,
1716
- color,
1717
- ax,
1718
- default_color="black",
2301
+ hier: dict,
2302
+ lnks_tms: dict,
2303
+ selected_edges: Iterable,
2304
+ color: str | dict | list,
2305
+ lw: float,
2306
+ ax: plt.Axes,
2307
+ default_color: str = "black",
1719
2308
  **kwargs,
1720
- ):
2309
+ ) -> None:
1721
2310
  """
1722
2311
  Private method that plots the edges of the tree.
1723
2312
  """
1724
- x, y = [], []
2313
+ if isinstance(color, dict):
2314
+ selected_edges = color.keys()
2315
+ lines = []
2316
+ c = []
1725
2317
  for pred, succs in lnks_tms["links"].items():
1726
- for succ in succs:
1727
- if pred not in selected_edges or succ not in selected_edges:
1728
- x.extend((hier[succ][0], hier[pred][0], None))
1729
- y.extend((hier[succ][1], hier[pred][1], None))
1730
- ax.plot(x, y, linewidth=0.3, zorder=0.1, c=default_color, **kwargs)
1731
- x, y = [], []
1732
- for pred, succs in lnks_tms["links"].items():
1733
- for succ in succs:
1734
- if pred in selected_edges and succ in selected_edges:
1735
- x.extend((hier[succ][0], hier[pred][0], None))
1736
- y.extend((hier[succ][1], hier[pred][1], None))
1737
- ax.plot(x, y, linewidth=0.3, zorder=0.2, c=color, **kwargs)
2318
+ for suc in succs:
2319
+ lines.append(
2320
+ [
2321
+ [hier[suc][0], hier[suc][1]],
2322
+ [hier[pred][0], hier[pred][1]],
2323
+ ]
2324
+ )
2325
+ if pred in selected_edges:
2326
+ if isinstance(color, str | list):
2327
+ c.append(color)
2328
+ elif isinstance(color, dict):
2329
+ c.append(color[pred])
2330
+ else:
2331
+ c.append(default_color)
2332
+ lc = LineCollection(lines, colors=c, linewidth=lw, **kwargs)
2333
+ ax.add_collection(lc)
1738
2334
 
1739
2335
  def draw_tree_graph(
1740
2336
  self,
1741
- hier,
1742
- lnks_tms,
1743
- selected_nodes=None,
1744
- selected_edges=None,
1745
- color_of_nodes="magenta",
1746
- color_of_edges=None,
1747
- size=10,
1748
- ax=None,
1749
- default_color="black",
2337
+ hier: dict[int, tuple[int, int]],
2338
+ lnks_tms: dict[str, dict[int, list | int]],
2339
+ selected_nodes: list | set | None = None,
2340
+ selected_edges: list | set | None = None,
2341
+ color_of_nodes: str | dict = "magenta",
2342
+ color_of_edges: str | dict = "magenta",
2343
+ size: int | float = 10,
2344
+ lw: float = 0.3,
2345
+ ax: plt.Axes | None = None,
2346
+ default_color: str = "black",
1750
2347
  **kwargs,
1751
- ):
2348
+ ) -> tuple[plt.Figure, plt.Axes]:
1752
2349
  """Function to plot the tree graph.
1753
2350
 
1754
- Args:
1755
- hier (dict): Dictinary that contains the positions of all nodes.
1756
- lnks_tms (dict): 2 dictionaries: 1 contains all links from start of life cycle to end of life cycle and
1757
- the succesors of each cell.
1758
- 1 contains the length of each life cycle.
1759
- selected_nodes (list|set, optional): Which cells are to be selected (Painted with a different color). Defaults to None.
1760
- selected_edges (list|set, optional): Which edges are to be selected (Painted with a different color). Defaults to None.
1761
- color_of_nodes (str, optional): Color of selected nodes. Defaults to "magenta".
1762
- color_of_edges (_type_, optional): Color of selected edges. Defaults to None.
1763
- size (int, optional): Size of the nodes. Defaults to 10.
1764
- ax (_type_, optional): Plot the graph on existing ax. Defaults to None.
1765
- figure (_type_, optional): _description_. Defaults to None.
1766
- default_color (str, optional): Default color of nodes. Defaults to "black".
1767
-
1768
- Returns:
1769
- figure, ax: The matplotlib figure and ax object.
2351
+ Parameters
2352
+ ----------
2353
+ hier : dict mapping int to tuple of int
2354
+ Dictionary that contains the positions of all nodes.
2355
+ lnks_tms : dict mapping string to dictionaries mapping int to list or int
2356
+ - 'links' : contains the hierarchy of the nodes (only start and end of each chain)
2357
+ - 'times' : contains the distance between the start and the end of each chain.
2358
+ selected_nodes : list or set, optional
2359
+ Which nodes are to be selected (Painted with a different color, according to 'color_of_nodes')
2360
+ selected_edges : list or set, optional
2361
+ Which edges are to be selected (Painted with a different color, according to 'color_of_edges')
2362
+ color_of_nodes : str, default="magenta"
2363
+ Color of selected nodes
2364
+ color_of_edges : str, default="magenta"
2365
+ Color of selected edges
2366
+ size : int, default=10
2367
+ Size of the nodes, defaults to 10
2368
+ lw : float, default=0.3
2369
+ The width of the edges of the tree graph, defaults to 0.3
2370
+ ax : plt.Axes, optional
2371
+ Plot the graph on existing ax. If not provided or None a new ax is going to be created.
2372
+ default_color : str, default="black"
2373
+ Default color of nodes
2374
+
2375
+ Returns
2376
+ -------
2377
+ plt.Figure
2378
+ The matplotlib figure
2379
+ plt.Axes
2380
+ The matplotlib ax
1770
2381
  """
1771
2382
  if selected_nodes is None:
1772
2383
  selected_nodes = []
1773
2384
  if selected_edges is None:
1774
2385
  selected_edges = []
1775
2386
  if ax is None:
1776
- figure, ax = plt.subplots()
2387
+ _, ax = plt.subplots()
1777
2388
  else:
1778
2389
  ax.clear()
1779
2390
  if not isinstance(selected_nodes, set):
1780
2391
  selected_nodes = set(selected_nodes)
1781
2392
  if not isinstance(selected_edges, set):
1782
2393
  selected_edges = set(selected_edges)
1783
- self.__plot_nodes(
1784
- hier,
1785
- selected_nodes,
1786
- color_of_nodes,
1787
- size=size,
1788
- ax=ax,
1789
- default_color=default_color,
1790
- **kwargs,
1791
- )
2394
+ if 0 < size:
2395
+ self.__plot_nodes(
2396
+ hier,
2397
+ selected_nodes,
2398
+ color_of_nodes,
2399
+ size=size,
2400
+ ax=ax,
2401
+ default_color=default_color,
2402
+ **kwargs,
2403
+ )
1792
2404
  if not color_of_edges:
1793
2405
  color_of_edges = color_of_nodes
1794
2406
  self.__plot_edges(
@@ -1796,62 +2408,107 @@ class lineageTree(lineageTreeLoaders):
1796
2408
  lnks_tms,
1797
2409
  selected_edges,
1798
2410
  color_of_edges,
2411
+ lw,
1799
2412
  ax,
1800
2413
  default_color=default_color,
1801
2414
  **kwargs,
1802
2415
  )
2416
+ ax.autoscale()
2417
+ plt.draw()
1803
2418
  ax.get_yaxis().set_visible(False)
1804
2419
  ax.get_xaxis().set_visible(False)
1805
2420
  return ax.get_figure(), ax
1806
2421
 
1807
- def to_simple_graph(self, node=None, start_time: int = None):
2422
def _create_dict_of_plots(
    self,
    node: int | Iterable[int] | None = None,
    start_time: int | None = None,
    end_time: int | None = None,
) -> dict[int, dict]:
    """Build one graph per spawning node with `create_links_and_chains`.

    Parameters
    ----------
    node : int or Iterable of int, optional
        The id of the node/nodes spawning the graphs. If None or not provided,
        every root whose time is at most `start_time` is used.
    start_time : int, optional
        Only used when `node` is None, to select the roots.
        If None or not provided it is set to `self.t_b`.
    end_time : int, optional
        The last timepoint to be considered, passed on to
        `create_links_and_chains`. If None or not provided it is set to `self.t_e`.

    Returns
    -------
    dict mapping int to dict
        The keys are index values 0-n (position of the spawning node)
        and the values are the graphs produced.
    """
    start_time = self.t_b if start_time is None else start_time
    end_time = self.t_e if end_time is None else end_time
    if node is None:
        mothers = [
            root for root in self.roots if self._time[root] <= start_time
        ]
    elif isinstance(node, Iterable):
        mothers = node
    else:
        mothers = [node]
    graphs = {}
    for index, mother in enumerate(mothers):
        graphs[index] = create_links_and_chains(
            self, mother, end_time=end_time
        )
    return graphs
1831
2464
 
1832
2465
  def plot_all_lineages(
1833
2466
  self,
1834
- nodes: list = None,
1835
- last_time_point_to_consider: int = None,
1836
- nrows=2,
1837
- figsize=(10, 15),
1838
- dpi=100,
1839
- fontsize=15,
1840
- axes=None,
1841
- vert_gap=1,
2467
+ nodes: list | None = None,
2468
+ last_time_point_to_consider: int | None = None,
2469
+ nrows: int = 2,
2470
+ figsize: tuple[int, int] = (10, 15),
2471
+ dpi: int = 100,
2472
+ fontsize: int = 15,
2473
+ axes: plt.Axes | None = None,
2474
+ vert_gap: int = 1,
1842
2475
  **kwargs,
1843
- ):
2476
+ ) -> tuple[plt.Figure, plt.Axes, dict[plt.Axes, int]]:
1844
2477
  """Plots all lineages.
1845
2478
 
1846
- Args:
1847
- last_time_point_to_consider (int, optional): Which timepoints and upwards are the graphs to be plotted.
1848
- For example if start_time is 10, then all trees that begin
1849
- on tp 10 or before are calculated. Defaults to None, where
1850
- it will plot all the roots that exist on self.t_b.
1851
- nrows (int): How many rows of plots should be printed.
1852
- kwargs: args accepted by matplotlib
2479
+ Parameters
2480
+ ----------
2481
+ nodes : list, optional
2482
+ The nodes spawning the graphs to be plotted.
2483
+ last_time_point_to_consider : int, optional
2484
+ Which timepoints and upwards are the graphs to be plotted.
2485
+ For example if start_time is 10, then all trees that begin
2486
+ on tp 10 or before are calculated. Defaults to None, where
2487
+ it will plot all the roots that exist on `self.t_b`.
2488
+ nrows : int, default=2
2489
+ How many rows of plots should be printed.
2490
+ figsize : tuple, default=(10, 15)
2491
+ The size of the figure.
2492
+ dpi : int, default=100
2493
+ The dpi of the figure.
2494
+ fontsize : int, default=15
2495
+ The fontsize of the labels.
2496
+ axes : plt.Axes, optional
2497
+ The axes to plot the graphs on. If None or not provided new axes are going to be created.
2498
+ vert_gap : int, default=1
2499
+ space between the nodes, defaults to 1
2500
+ **kwargs:
2501
+ kwargs accepted by matplotlib.pyplot.plot, matplotlib.pyplot.scatter
2502
+
2503
+ Returns
2504
+ -------
2505
+ plt.Figure
2506
+ The figure
2507
+ plt.Axes
2508
+ The axes
2509
+ dict of plt.Axes to int
2510
+ A dictionary that maps the axes to the root of the tree.
1853
2511
  """
1854
-
1855
2512
  nrows = int(nrows)
1856
2513
  if last_time_point_to_consider is None:
1857
2514
  last_time_point_to_consider = self.t_b
@@ -1860,17 +2517,18 @@ class lineageTree(lineageTreeLoaders):
1860
2517
  raise Warning("Number of rows has to be at least 1")
1861
2518
  if nodes:
1862
2519
  graphs = {
1863
- i: self.to_simple_graph(node) for i, node in enumerate(nodes)
2520
+ i: self._create_dict_of_plots(node)
2521
+ for i, node in enumerate(nodes)
1864
2522
  }
1865
2523
  else:
1866
- graphs = self.to_simple_graph(
2524
+ graphs = self._create_dict_of_plots(
1867
2525
  start_time=last_time_point_to_consider
1868
2526
  )
1869
2527
  pos = {
1870
2528
  i: hierarchical_pos(
1871
2529
  g,
1872
2530
  g["root"],
1873
- ycenter=-int(self.time[g["root"]]),
2531
+ ycenter=-int(self._time[g["root"]]),
1874
2532
  vert_gap=vert_gap,
1875
2533
  )
1876
2534
  for i, g in graphs.items()
@@ -1925,135 +2583,147 @@ class lineageTree(lineageTreeLoaders):
1925
2583
  [figure.delaxes(ax) for ax in axes.flatten() if not ax.has_data()]
1926
2584
  return axes.flatten()[0].get_figure(), axes, ax2root
1927
2585
 
1928
- def plot_node(
2586
def plot_subtree(
    self,
    node: int,
    end_time: int | None = None,
    figsize: tuple[int, int] = (4, 7),
    dpi: int = 150,
    vert_gap: int = 2,
    selected_nodes: list | None = None,
    selected_edges: list | None = None,
    color_of_nodes: str | dict = "magenta",
    color_of_edges: str | dict = "magenta",
    size: int | float = 10,
    lw: float = 0.1,
    default_color: str = "black",
    ax: plt.Axes | None = None,
) -> tuple[plt.Figure, plt.Axes]:
    """Plots the subtree spawned by a node.

    Parameters
    ----------
    node : int
        The id of the node that is going to be plotted.
    end_time : int, optional
        The last timepoint to be considered, if None or not provided the last
        timepoint of the dataset (t_e) is considered.
    figsize : tuple of 2 ints, default=(4,7)
        The size of the figure (used only when `ax` is not provided).
    dpi : int, default=150
        The dpi of the figure (used only when `ax` is not provided).
    vert_gap : int, default=2
        The vertical gap of a node when it divides.
    selected_nodes : list, optional
        The nodes to be colored with `color_of_nodes`.
    selected_edges : list, optional
        The edges to be colored with `color_of_edges`.
    color_of_nodes : str or dict, default="magenta"
        The color of the selected nodes.
    color_of_edges : str or dict, default="magenta"
        The color of the selected edges.
    size : int, default=10
        The size of the nodes.
    lw : float, default=0.1
        The width of the edges of the tree graph.
    default_color : str, default="black"
        The color of the nodes and edges that are not selected.
    ax : plt.Axes, optional
        The ax where the plot is going to be drawn, if not provided or None
        new axes will be created.

    Returns
    -------
    plt.Figure
        The matplotlib figure
    plt.Axes
        The matplotlib axes

    Raises
    ------
    Warning
        If more than one node is received
    """
    graphs = self._create_dict_of_plots(node, end_time=end_time)
    if len(graphs) > 1:
        raise Warning(
            "Please use lT.plot_all_lineages(nodes) for plotting multiple nodes."
        )
    graph = graphs[0]
    if not ax:
        _, ax = plt.subplots(nrows=1, ncols=1, figsize=figsize, dpi=dpi)
    # The root of the drawing sits at minus the time of the plotted node.
    positions = hierarchical_pos(
        graph,
        graph["root"],
        vert_gap=vert_gap,
        ycenter=-int(self._time[node]),
    )
    self.draw_tree_graph(
        hier=positions,
        lnks_tms=graph,
        selected_nodes=selected_nodes,
        selected_edges=selected_edges,
        color_of_nodes=color_of_nodes,
        color_of_edges=color_of_edges,
        default_color=default_color,
        size=size,
        lw=lw,
        ax=ax,
    )
    return ax.get_figure(), ax
1962
2671
 
1963
- # def DTW(self, t1, t2, max_w=None, start_delay=None, end_delay=None,
1964
- # metric='euclidian', **kwargs):
1965
- # """ Computes the dynamic time warping distance between the tracks t1 and t2
1966
-
1967
- # Args:
1968
- # t1 ([int, ]): list of node ids for the first track
1969
- # t2 ([int, ]): list of node ids for the second track
1970
- # w (int): maximum wapring allowed (default infinite),
1971
- # if w=1 then the DTW is the distance between t1 and t2
1972
- # start_delay (int): maximum number of time points that can be
1973
- # skipped at the beginning of the track
1974
- # end_delay (int): minimum number of time points that can be
1975
- # skipped at the beginning of the track
1976
- # metric (str): str or callable, optional The distance metric to use.
1977
- # Default='euclidean'. Refer to the documentation for
1978
- # scipy.spatial.distance.cdist. Some examples:
1979
- # 'braycurtis', 'canberra', 'chebyshev', 'cityblock', 'correlation',
1980
- # 'cosine', 'dice', 'euclidean', 'hamming', 'jaccard', 'kulsinski',
1981
- # 'mahalanobis', 'matching', 'minkowski', 'rogerstanimoto', 'russellrao',
1982
- # 'seuclidean', 'sokalmichener', 'sokalsneath', 'sqeuclidean',
1983
- # 'wminkowski', 'yule'
1984
- # **kwargs (dict): Extra arguments to `metric`: refer to each metric
1985
- # documentation in scipy.spatial.distance (optional)
1986
-
1987
- # Returns:
1988
- # float: the dynamic time warping distance between the two tracks
1989
- # """
1990
- # from scipy.sparse import
1991
- # pos_t1 = [self.pos[ti] for ti in t1]
1992
- # pos_t2 = [self.pos[ti] for ti in t2]
1993
- # distance_matrix = np.zeros((len(t1), len(t2))) + np.inf
1994
-
1995
- # c = distance.cdist(exp_data, num_data, metric=metric, **kwargs)
1996
-
1997
- # d = np.zeros(c.shape)
1998
- # d[0, 0] = c[0, 0]
1999
- # n, m = c.shape
2000
- # for i in range(1, n):
2001
- # d[i, 0] = d[i-1, 0] + c[i, 0]
2002
- # for j in range(1, m):
2003
- # d[0, j] = d[0, j-1] + c[0, j]
2004
- # for i in range(1, n):
2005
- # for j in range(1, m):
2006
- # d[i, j] = c[i, j] + min((d[i-1, j], d[i, j-1], d[i-1, j-1]))
2007
- # return d[-1, -1], d
2008
-
2009
- def __getitem__(self, item):
2010
- if isinstance(item, str):
2011
- return self.__dict__[item]
2012
- elif np.issubdtype(type(item), np.integer):
2013
- return self.successor.get(item, [])
2014
- else:
2015
- raise KeyError(
2016
- "Only integer or string are valid key for lineageTree"
2017
- )
2018
-
2019
- def get_cells_at_t_from_root(self, r: [int, list], t: int = None) -> list:
2020
- """
2021
- Returns the list of cells at time `t` that are spawn by the node(s) `r`.
2022
-
2023
- Args:
2024
- r (int | list): id or list of ids of the spawning node
2025
- t (int): target time, if None goes as far as possible
2026
- (default None)
2027
-
2028
- Returns:
2029
- (list) list of nodes at time `t` spawned by `r`
2030
- """
2031
- if not isinstance(r, list):
2672
def nodes_at_t(
    self,
    t: int | None,
    r: int | Iterable[int] | None = None,
) -> list[int]:
    """
    Returns the list of nodes at time `t` that are spawned by the node(s) `r`.

    Parameters
    ----------
    t : int
        target time, if `None` the last time point (`self.t_e`) is used
    r : int or Iterable of int, optional
        id or ids of the spawning node(s). If None (or empty), every root
        whose time is at most `t` is used.

    Returns
    -------
    list of int
        list of ids of the nodes at time `t` spawned by `r`. If no such node
        is found, the spawning nodes themselves are returned.
    """
    # Resolve the default time first: the root filter below compares
    # against `t` and cannot work with None.
    if t is None:
        t = self.t_e
    if isinstance(r, (int, np.integer)):
        # NumPy integers are accepted as single ids as well.
        r = [r]
    elif r is None or (hasattr(r, "__len__") and len(r) == 0):
        r = {root for root in self.roots if self._time[root] <= t}
    # Materialize once so that an iterator can still be returned below.
    spawning = list(r)
    to_do = list(spawning)
    final_nodes = []
    while to_do:
        curr = to_do.pop()
        for _next in self._successor[curr]:
            if self._time[_next] < t:
                to_do.append(_next)
            elif self._time[_next] == t:
                final_nodes.append(_next)
    if not final_nodes:
        return spawning
    return final_nodes
2045
2710
 
2046
2711
  @staticmethod
2047
- def __calculate_diag_line(dist_mat: np.ndarray) -> (float, float):
2712
+ def __calculate_diag_line(dist_mat: np.ndarray) -> tuple[float, float]:
2048
2713
  """
2049
2714
  Calculate the line that centers the band w.
2050
2715
 
2051
- Args:
2052
- dist_mat (matrix): distance matrix obtained by the function calculate_dtw
2716
+ Parameters
2717
+ ----------
2718
+ dist_mat : np.ndarray
2719
+ distance matrix obtained by the function calculate_dtw
2053
2720
 
2054
- Returns:
2055
- (float) Slope
2056
- (float) intercept of the line
2721
+ Returns
2722
+ -------
2723
+ float
2724
+ The slope of the curve
2725
+ float
2726
+ The intercept of the curve
2057
2727
  """
2058
2728
  i, j = dist_mat.shape
2059
2729
  x1 = max(0, i - j) / 2
@@ -2073,22 +2743,33 @@ class lineageTree(lineageTreeLoaders):
2073
2743
  fast: bool = False,
2074
2744
  w: int = 0,
2075
2745
  centered_band: bool = True,
2076
- ) -> (((int, int), ...), np.ndarray):
2746
+ ) -> tuple[list[int], np.ndarray, float]:
2077
2747
  """
2078
2748
  Find DTW minimum cost between two series using dynamic programming.
2079
2749
 
2080
- Args:
2081
- dist_mat (matrix): distance matrix obtained by the function calculate_dtw
2082
- start_d (int): start delay
2083
- back_d (int): end delay
2084
- w (int): window constrain
2085
- slope (float): to calculate window - givem by the function __calculate_diag_line
2086
- intercept (flost): to calculate window - givem by the function __calculate_diag_line
2087
- use_absolute (boolean): if the window constraing is calculate by the absolute difference between points (uncentered)
2088
-
2089
- Returns:
2090
- (tuple of tuples) Aligment path
2091
- (matrix) Cost matrix
2750
+ Parameters
2751
+ ----------
2752
+ dist_mat : np.ndarray
2753
+ distance matrix obtained by the function calculate_dtw
2754
+ start_d : int, default=0
2755
+ start delay
2756
+ back_d : int, default=0
2757
+ end delay
2758
+ fast : bool, default=False
2759
+ if `True`, the algorithm will use a faster version but might not find the optimal alignment
2760
+ w : int, default=0
2761
+ window constrain
2762
+ centered_band : bool, default=True
2763
+ if `True`, the band will be centered around the diagonal
2764
+
2765
+ Returns
2766
+ -------
2767
+ tuple of tuples of int
2768
+ Aligment path
2769
+ np.ndarray
2770
+ cost matrix
2771
+ float
2772
+ optimal cost
2092
2773
  """
2093
2774
  N, M = dist_mat.shape
2094
2775
  w_limit = max(w, abs(N - M)) # Calculate the Sakoe-Chiba band width
@@ -2209,7 +2890,6 @@ class lineageTree(lineageTreeLoaders):
2209
2890
 
2210
2891
  # special reflection case
2211
2892
  if np.linalg.det(R) < 0:
2212
- # print("det(R) < R, reflection detected!, correcting for it ...")
2213
2893
  Vt[2, :] *= -1
2214
2894
  R = Vt.T @ U.T
2215
2895
 
@@ -2218,52 +2898,59 @@ class lineageTree(lineageTreeLoaders):
2218
2898
  return R, t
2219
2899
 
2220
2900
  def __interpolate(
2221
- self, track1: list, track2: list, threshold: int
2222
- ) -> (np.ndarray, np.ndarray):
2901
+ self, chain1: list, chain2: list, threshold: int
2902
+ ) -> tuple[np.ndarray, np.ndarray]:
2223
2903
  """
2224
2904
  Interpolate two series that have different lengths
2225
2905
 
2226
- Args:
2227
- track1 (list): list of nodes of the first cell cycle to compare
2228
- track2 (list): list of nodes of the second cell cycle to compare
2229
- threshold (int): set a maximum number of points a track can have
2230
-
2231
- Returns:
2232
- (list of list) x, y, z postions for track1
2233
- (list of list) x, y, z postions for track2
2906
+ Parameters
2907
+ ----------
2908
+ chain1 : list of int
2909
+ list of nodes of the first chain to compare
2910
+ chain2 : list of int
2911
+ list of nodes of the second chain to compare
2912
+ threshold : int
2913
+ set a maximum number of points a chain can have
2914
+
2915
+ Returns
2916
+ -------
2917
+ list of np.ndarray
2918
+ `x`, `y`, `z` postions for `chain1`
2919
+ list of np.ndarray
2920
+ `x`, `y`, `z` postions for `chain2`
2234
2921
  """
2235
2922
  inter1_pos = []
2236
2923
  inter2_pos = []
2237
2924
 
2238
- track1_pos = np.array([self.pos[c_id] for c_id in track1])
2239
- track2_pos = np.array([self.pos[c_id] for c_id in track2])
2925
+ chain1_pos = np.array([self.pos[c_id] for c_id in chain1])
2926
+ chain2_pos = np.array([self.pos[c_id] for c_id in chain2])
2240
2927
 
2241
- # Both tracks have the same length and size below the threshold - nothing is done
2242
- if len(track1) == len(track2) and (
2243
- len(track1) <= threshold or len(track2) <= threshold
2928
+ # Both chains have the same length and size below the threshold - nothing is done
2929
+ if len(chain1) == len(chain2) and (
2930
+ len(chain1) <= threshold or len(chain2) <= threshold
2244
2931
  ):
2245
- return track1_pos, track2_pos
2246
- # Both tracks have the same length but one or more sizes are above the threshold
2247
- elif len(track1) > threshold or len(track2) > threshold:
2932
+ return chain1_pos, chain2_pos
2933
+ # Both chains have the same length but one or more sizes are above the threshold
2934
+ elif len(chain1) > threshold or len(chain2) > threshold:
2248
2935
  sampling = threshold
2249
- # Tracks have different lengths and the sizes are below the threshold
2936
+ # chains have different lengths and the sizes are below the threshold
2250
2937
  else:
2251
- sampling = max(len(track1), len(track2))
2938
+ sampling = max(len(chain1), len(chain2))
2252
2939
 
2253
2940
  for pos in range(3):
2254
- track1_interp = InterpolatedUnivariateSpline(
2255
- np.linspace(0, 1, len(track1_pos[:, pos])),
2256
- track1_pos[:, pos],
2941
+ chain1_interp = InterpolatedUnivariateSpline(
2942
+ np.linspace(0, 1, len(chain1_pos[:, pos])),
2943
+ chain1_pos[:, pos],
2257
2944
  k=1,
2258
2945
  )
2259
- inter1_pos.append(track1_interp(np.linspace(0, 1, sampling)))
2946
+ inter1_pos.append(chain1_interp(np.linspace(0, 1, sampling)))
2260
2947
 
2261
- track2_interp = InterpolatedUnivariateSpline(
2262
- np.linspace(0, 1, len(track2_pos[:, pos])),
2263
- track2_pos[:, pos],
2948
+ chain2_interp = InterpolatedUnivariateSpline(
2949
+ np.linspace(0, 1, len(chain2_pos[:, pos])),
2950
+ chain2_pos[:, pos],
2264
2951
  k=1,
2265
2952
  )
2266
- inter2_pos.append(track2_interp(np.linspace(0, 1, sampling)))
2953
+ inter2_pos.append(chain2_interp(np.linspace(0, 1, sampling)))
2267
2954
 
2268
2955
  return np.column_stack(inter1_pos), np.column_stack(inter2_pos)
2269
2956
 
@@ -2279,46 +2966,66 @@ class lineageTree(lineageTreeLoaders):
2279
2966
  w: int = 0,
2280
2967
  centered_band: bool = True,
2281
2968
  cost_mat_p: bool = False,
2282
- ) -> (float, tuple, np.ndarray, np.ndarray, np.ndarray):
2283
- """
2284
- Calculate DTW distance between two cell cycles
2285
-
2286
- Args:
2287
- nodes1 (int): node to compare distance
2288
- nodes2 (int): node to compare distance
2289
- threshold: set a maximum number of points a track can have
2290
- regist (boolean): Rotate and translate trajectories
2291
- start_d (int): start delay
2292
- back_d (int): end delay
2293
- fast (boolean): True if the user wants to run the fast algorithm with window restrains
2294
- w (int): window size
2295
- centered_band (boolean): if running the fast algorithm, True if the windown is centered
2296
- cost_mat_p (boolean): True if print the not normalized cost matrix
2297
-
2298
- Returns:
2299
- (float) DTW distance
2300
- (tuple of tuples) Aligment path
2301
- (matrix) Cost matrix
2302
- (list of lists) pos_cycle1: rotated and translated trajectories positions
2303
- (list of lists) pos_cycle2: rotated and translated trajectories positions
2969
+ ) -> (
2970
+ tuple[float, tuple, np.ndarray, np.ndarray, np.ndarray]
2971
+ | tuple[float, tuple]
2972
+ ):
2304
2973
  """
2305
- nodes1_cycle = self.get_cycle(nodes1)
2306
- nodes2_cycle = self.get_cycle(nodes2)
2307
-
2308
- interp_cycle1, interp_cycle2 = self.__interpolate(
2309
- nodes1_cycle, nodes2_cycle, threshold
2974
+ Calculate DTW distance between two chains
2975
+
2976
+ Parameters
2977
+ ----------
2978
+ nodes1 : int
2979
+ node to compare distance
2980
+ nodes2 : int
2981
+ node to compare distance
2982
+ threshold : int, default=1000
2983
+ set a maximum number of points a chain can have
2984
+ regist : bool, default=True
2985
+ Rotate and translate trajectories
2986
+ start_d : int, default=0
2987
+ start delay
2988
+ back_d : int, default=0
2989
+ end delay
2990
+ fast : bool, default=False
2991
+ if `True`, the algorithm will use a faster version but might not find the optimal alignment
2992
+ w : int, default=0
2993
+ window size
2994
+ centered_band : bool, default=True
2995
+ when running the fast algorithm, `True` if the windown is centered
2996
+ cost_mat_p : bool, default=False
2997
+ True if print the not normalized cost matrix
2998
+
2999
+ Returns
3000
+ -------
3001
+ float
3002
+ DTW distance
3003
+ tuple of tuples
3004
+ Aligment path
3005
+ matrix
3006
+ Cost matrix
3007
+ list of lists
3008
+ rotated and translated trajectories positions
3009
+ list of lists
3010
+ rotated and translated trajectories positions
3011
+ """
3012
+ nodes1_chain = self.get_chain_of_node(nodes1)
3013
+ nodes2_chain = self.get_chain_of_node(nodes2)
3014
+
3015
+ interp_chain1, interp_chain2 = self.__interpolate(
3016
+ nodes1_chain, nodes2_chain, threshold
2310
3017
  )
2311
3018
 
2312
- pos_cycle1 = np.array([self.pos[c_id] for c_id in nodes1_cycle])
2313
- pos_cycle2 = np.array([self.pos[c_id] for c_id in nodes2_cycle])
3019
+ pos_chain1 = np.array([self.pos[c_id] for c_id in nodes1_chain])
3020
+ pos_chain2 = np.array([self.pos[c_id] for c_id in nodes2_chain])
2314
3021
 
2315
3022
  if regist:
2316
3023
  R, t = self.__rigid_transform_3D(
2317
- np.transpose(interp_cycle1), np.transpose(interp_cycle2)
3024
+ np.transpose(interp_chain1), np.transpose(interp_chain2)
2318
3025
  )
2319
- pos_cycle1 = np.transpose(np.dot(R, pos_cycle1.T) + t)
3026
+ pos_chain1 = np.transpose(np.dot(R, pos_chain1.T) + t)
2320
3027
 
2321
- dist_mat = distance.cdist(pos_cycle1, pos_cycle2, "euclidean")
3028
+ dist_mat = distance.cdist(pos_chain1, pos_chain2, "euclidean")
2322
3029
 
2323
3030
  path, cost_mat, final_cost = self.__dp(
2324
3031
  dist_mat,
@@ -2331,7 +3038,7 @@ class lineageTree(lineageTreeLoaders):
2331
3038
  cost = final_cost / len(path)
2332
3039
 
2333
3040
  if cost_mat_p:
2334
- return cost, path, cost_mat, pos_cycle1, pos_cycle2
3041
+ return cost, path, cost_mat, pos_chain1, pos_chain2
2335
3042
  else:
2336
3043
  return cost, path
2337
3044
 
@@ -2346,24 +3053,39 @@ class lineageTree(lineageTreeLoaders):
2346
3053
  fast: bool = False,
2347
3054
  w: int = 0,
2348
3055
  centered_band: bool = True,
2349
- ) -> (float, plt.figure):
2350
- """
2351
- Plot DTW cost matrix between two cell cycles in heatmap format
2352
-
2353
- Args:
2354
- nodes1 (int): node to compare distance
2355
- nodes2 (int): node to compare distance
2356
- start_d (int): start delay
2357
- back_d (int): end delay
2358
- fast (boolean): True if the user wants to run the fast algorithm with window restrains
2359
- w (int): window size
2360
- centered_band (boolean): if running the fast algorithm, True if the windown is centered
2361
-
2362
- Returns:
2363
- (float) DTW distance
2364
- (figure) Heatmap of cost matrix with opitimal path
2365
- """
2366
- cost, path, cost_mat, pos_cycle1, pos_cycle2 = self.calculate_dtw(
3056
+ ) -> tuple[float, plt.Figure]:
3057
+ """
3058
+ Plot DTW cost matrix between two chains in heatmap format
3059
+
3060
+ Parameters
3061
+ ----------
3062
+ nodes1 : int
3063
+ node to compare distance
3064
+ nodes2 : int
3065
+ node to compare distance
3066
+ threshold : int, default=1000
3067
+ set a maximum number of points a chain can have
3068
+ regist : bool, default=True
3069
+ Rotate and translate trajectories
3070
+ start_d : int, default=0
3071
+ start delay
3072
+ back_d : int, default=0
3073
+ end delay
3074
+ fast : bool, default=False
3075
+ if `True`, the algorithm will use a faster version but might not find the optimal alignment
3076
+ w : int, default=0
3077
+ window size
3078
+ centered_band : bool, default=True
3079
+ when running the fast algorithm, `True` if the windown is centered
3080
+
3081
+ Returns
3082
+ -------
3083
+ float
3084
+ DTW distance
3085
+ plt.Figure
3086
+ Heatmap of cost matrix with opitimal path
3087
+ """
3088
+ cost, path, cost_mat, pos_chain1, pos_chain2 = self.calculate_dtw(
2367
3089
  nodes1,
2368
3090
  nodes2,
2369
3091
  threshold,
@@ -2385,32 +3107,32 @@ class lineageTree(lineageTreeLoaders):
2385
3107
  ax.set_title("Heatmap of DTW Cost Matrix")
2386
3108
  ax.set_xlabel("Tree 1")
2387
3109
  ax.set_ylabel("tree 2")
2388
- x_path, y_path = zip(*path)
3110
+ x_path, y_path = zip(*path, strict=True)
2389
3111
  ax.plot(y_path, x_path, color="black")
2390
3112
 
2391
3113
  return cost, fig
2392
3114
 
2393
3115
  @staticmethod
2394
3116
  def __plot_2d(
2395
- pos_cycle1,
2396
- pos_cycle2,
2397
- nodes1,
2398
- nodes2,
2399
- ax,
2400
- x_idx,
2401
- y_idx,
2402
- x_label,
2403
- y_label,
2404
- ):
3117
+ pos_chain1: np.ndarray,
3118
+ pos_chain2: np.ndarray,
3119
+ nodes1: list[int],
3120
+ nodes2: list[int],
3121
+ ax: plt.Axes,
3122
+ x_idx: list[int],
3123
+ y_idx: list[int],
3124
+ x_label: str,
3125
+ y_label: str,
3126
+ ) -> None:
2405
3127
  ax.plot(
2406
- pos_cycle1[:, x_idx],
2407
- pos_cycle1[:, y_idx],
3128
+ pos_chain1[:, x_idx],
3129
+ pos_chain1[:, y_idx],
2408
3130
  "-",
2409
3131
  label=f"root = {nodes1}",
2410
3132
  )
2411
3133
  ax.plot(
2412
- pos_cycle2[:, x_idx],
2413
- pos_cycle2[:, y_idx],
3134
+ pos_chain2[:, x_idx],
3135
+ pos_chain2[:, y_idx],
2414
3136
  "-",
2415
3137
  label=f"root = {nodes2}",
2416
3138
  )
@@ -2428,40 +3150,55 @@ class lineageTree(lineageTreeLoaders):
2428
3150
  fast: bool = False,
2429
3151
  w: int = 0,
2430
3152
  centered_band: bool = True,
2431
- projection: str = None,
3153
+ projection: Literal["3d", "xy", "xz", "yz", "pca", None] = None,
2432
3154
  alig: bool = False,
2433
- ) -> (float, plt.figure):
2434
- """
2435
- Plots DTW trajectories aligment between two cell cycles in 2D or 3D
2436
-
2437
- Args:
2438
- nodes1 (int): node to compare distance
2439
- nodes2 (int): node to compare distance
2440
- threshold (int): set a maximum number of points a track can have
2441
- regist (boolean): Rotate and translate trajectories
2442
- start_d (int): start delay
2443
- back_d (int): end delay
2444
- w (int): window size
2445
- fast (boolean): True if the user wants to run the fast algorithm with window restrains
2446
- centered_band (boolean): if running the fast algorithm, True if the windown is centered
2447
- projection (string): specify which 2D to plot ->
2448
- '3d' : for the 3d visualization
2449
- 'xy' or None (default) : 2D projection of axis x and y
2450
- 'xz' : 2D projection of axis x and z
2451
- 'yz' : 2D projection of axis y and z
2452
- 'pca' : PCA projection
2453
- alig (boolean): True to show alignment on plot
2454
-
2455
- Returns:
2456
- (float) DTW distance
2457
- (figue) Trajectories Plot
3155
+ ) -> tuple[float, plt.Figure]:
3156
+ """
3157
+ Plots DTW trajectories aligment between two chains in 2D or 3D
3158
+
3159
+ Parameters
3160
+ ----------
3161
+ nodes1 : int
3162
+ node to compare distance
3163
+ nodes2 : int
3164
+ node to compare distance
3165
+ threshold : int, default=1000
3166
+ set a maximum number of points a chain can have
3167
+ regist : bool, default=True
3168
+ Rotate and translate trajectories
3169
+ start_d : int, default=0
3170
+ start delay
3171
+ back_d : int, default=0
3172
+ end delay
3173
+ w : int, default=0
3174
+ window size
3175
+ fast : bool, default=False
3176
+ True if the user wants to run the fast algorithm with window restrains
3177
+ centered_band : bool, default=True
3178
+ if running the fast algorithm, True if the windown is centered
3179
+ projection : {"3d", "xy", "xz", "yz", "pca"}, optional
3180
+ specify which 2D to plot ->
3181
+ "3d" : for the 3d visualization
3182
+ "xy" or None (default) : 2D projection of axis x and y
3183
+ "xz" : 2D projection of axis x and z
3184
+ "yz" : 2D projection of axis y and z
3185
+ "pca" : PCA projection
3186
+ alig : bool
3187
+ True to show alignment on plot
3188
+
3189
+ Returns
3190
+ -------
3191
+ float
3192
+ DTW distance
3193
+ figure
3194
+ Trajectories Plot
2458
3195
  """
2459
3196
  (
2460
3197
  distance,
2461
3198
  alignment,
2462
3199
  cost_mat,
2463
- pos_cycle1,
2464
- pos_cycle2,
3200
+ pos_chain1,
3201
+ pos_chain2,
2465
3202
  ) = self.calculate_dtw(
2466
3203
  nodes1,
2467
3204
  nodes2,
@@ -2484,16 +3221,16 @@ class lineageTree(lineageTreeLoaders):
2484
3221
 
2485
3222
  if projection == "3d":
2486
3223
  ax.plot(
2487
- pos_cycle1[:, 0],
2488
- pos_cycle1[:, 1],
2489
- pos_cycle1[:, 2],
3224
+ pos_chain1[:, 0],
3225
+ pos_chain1[:, 1],
3226
+ pos_chain1[:, 2],
2490
3227
  "-",
2491
3228
  label=f"root = {nodes1}",
2492
3229
  )
2493
3230
  ax.plot(
2494
- pos_cycle2[:, 0],
2495
- pos_cycle2[:, 1],
2496
- pos_cycle2[:, 2],
3231
+ pos_chain2[:, 0],
3232
+ pos_chain2[:, 1],
3233
+ pos_chain2[:, 2],
2497
3234
  "-",
2498
3235
  label=f"root = {nodes2}",
2499
3236
  )
@@ -2503,8 +3240,8 @@ class lineageTree(lineageTreeLoaders):
2503
3240
  else:
2504
3241
  if projection == "xy" or projection == "yx" or projection is None:
2505
3242
  self.__plot_2d(
2506
- pos_cycle1,
2507
- pos_cycle2,
3243
+ pos_chain1,
3244
+ pos_chain2,
2508
3245
  nodes1,
2509
3246
  nodes2,
2510
3247
  ax,
@@ -2515,8 +3252,8 @@ class lineageTree(lineageTreeLoaders):
2515
3252
  )
2516
3253
  elif projection == "xz" or projection == "zx":
2517
3254
  self.__plot_2d(
2518
- pos_cycle1,
2519
- pos_cycle2,
3255
+ pos_chain1,
3256
+ pos_chain2,
2520
3257
  nodes1,
2521
3258
  nodes2,
2522
3259
  ax,
@@ -2527,8 +3264,8 @@ class lineageTree(lineageTreeLoaders):
2527
3264
  )
2528
3265
  elif projection == "yz" or projection == "zy":
2529
3266
  self.__plot_2d(
2530
- pos_cycle1,
2531
- pos_cycle2,
3267
+ pos_chain1,
3268
+ pos_chain2,
2532
3269
  nodes1,
2533
3270
  nodes2,
2534
3271
  ax,
@@ -2542,24 +3279,25 @@ class lineageTree(lineageTreeLoaders):
2542
3279
  from sklearn.decomposition import PCA
2543
3280
  except ImportError:
2544
3281
  Warning(
2545
- "scikit-learn is not installed, the PCA orientation cannot be used. You can install scikit-learn with pip install"
3282
+ "scikit-learn is not installed, the PCA orientation cannot be used."
3283
+ "You can install scikit-learn with pip install"
2546
3284
  )
2547
3285
 
2548
3286
  # Apply PCA
2549
3287
  pca = PCA(n_components=2)
2550
- pca.fit(np.vstack([pos_cycle1, pos_cycle2]))
2551
- pos_cycle1_2d = pca.transform(pos_cycle1)
2552
- pos_cycle2_2d = pca.transform(pos_cycle2)
3288
+ pca.fit(np.vstack([pos_chain1, pos_chain2]))
3289
+ pos_chain1_2d = pca.transform(pos_chain1)
3290
+ pos_chain2_2d = pca.transform(pos_chain2)
2553
3291
 
2554
3292
  ax.plot(
2555
- pos_cycle1_2d[:, 0],
2556
- pos_cycle1_2d[:, 1],
3293
+ pos_chain1_2d[:, 0],
3294
+ pos_chain1_2d[:, 1],
2557
3295
  "-",
2558
3296
  label=f"root = {nodes1}",
2559
3297
  )
2560
3298
  ax.plot(
2561
- pos_cycle2_2d[:, 0],
2562
- pos_cycle2_2d[:, 1],
3299
+ pos_chain2_2d[:, 0],
3300
+ pos_chain2_2d[:, 1],
2563
3301
  "-",
2564
3302
  label=f"root = {nodes2}",
2565
3303
  )
@@ -2588,7 +3326,7 @@ class lineageTree(lineageTreeLoaders):
2588
3326
  'pca' : PCA projection"""
2589
3327
  )
2590
3328
 
2591
- connections = [[pos_cycle1[i], pos_cycle2[j]] for i, j in alignment]
3329
+ connections = [[pos_chain1[i], pos_chain2[j]] for i, j in alignment]
2592
3330
 
2593
3331
  for connection in connections:
2594
3332
  xyz1 = connection[0]
@@ -2616,108 +3354,197 @@ class lineageTree(lineageTreeLoaders):
2616
3354
 
2617
3355
  return distance, fig
2618
3356
 
2619
- def first_labelling(self):
2620
- self.labels = {i: "Unlabeled" for i in self.time_nodes[0]}
3357
+ def get_subtree(self, node_list: set[int]) -> lineageTree:
3358
+ new_successors = {
3359
+ n: tuple(vi for vi in self.successor[n] if vi in node_list)
3360
+ for n in node_list
3361
+ }
3362
+ return lineageTree(
3363
+ successor=new_successors,
3364
+ time=self._time,
3365
+ pos=self.pos,
3366
+ name=self.name,
3367
+ root_leaf_value=[
3368
+ (),
3369
+ ],
3370
+ **{
3371
+ name: self.__dict__[name]
3372
+ for name in self._custom_property_list
3373
+ },
3374
+ )
2621
3375
 
2622
3376
  def __init__(
2623
3377
  self,
2624
- file_format: str = None,
2625
- tb: int = None,
2626
- te: int = None,
2627
- z_mult: float = 1.0,
2628
- file_type: str = "",
2629
- delim: str = ",",
2630
- eigen: bool = False,
2631
- shape: tuple = None,
2632
- raw_size: tuple = None,
2633
- reorder: bool = False,
2634
- xml_attributes: tuple = None,
2635
- name: str = None,
2636
- time_resolution: Union[int, None] = None,
3378
+ *,
3379
+ successor: dict[int, Sequence] | None = None,
3380
+ predecessor: dict[int, int | Sequence] | None = None,
3381
+ time: dict[int, int] | None = None,
3382
+ starting_time: int | None = None,
3383
+ pos: dict[int, Iterable] | None = None,
3384
+ name: str | None = None,
3385
+ root_leaf_value: Sequence | None = None,
3386
+ **kwargs,
2637
3387
  ):
2638
- """
2639
- TODO: complete the doc
2640
- Main library to build tree graph representation of lineage tree data
2641
- It can read TGMM, ASTEC, SVF, MaMuT and TrackMate outputs.
2642
-
2643
- Args:
2644
- file_format (str): either - path format to TGMM xmls
2645
- - path to the MaMuT xml
2646
- - path to the binary file
2647
- tb (int, optional):first time point (necessary for TGMM xmls only)
2648
- te (int, optional): last time point (necessary for TGMM xmls only)
2649
- z_mult (float, optional):z aspect ratio if necessary (usually only for TGMM xmls)
2650
- file_type (str, optional):type of input file. Accepts:
2651
- 'TGMM, 'ASTEC', MaMuT', 'TrackMate', 'csv', 'celegans', 'binary'
2652
- default is 'binary'
2653
- delim (str, optional): _description_. Defaults to ",".
2654
- eigen (bool, optional): _description_. Defaults to False.
2655
- shape (tuple, optional): _description_. Defaults to None.
2656
- raw_size (tuple, optional): _description_. Defaults to None.
2657
- reorder (bool, optional): _description_. Defaults to False.
2658
- xml_attributes (tuple, optional): _description_. Defaults to None.
2659
- name (str, optional): The name of the dataset. Defaults to None.
2660
- time_resolution (Union[int, None], optional): Time resolution in mins (If time resolution is smaller than one minute input the time in ms). Defaults to None.
2661
- """
3388
+ """Create a lineageTree object from minimal information, without reading from a file.
3389
+ Either `successor` or `predecessor` should be specified.
3390
+
3391
+ Parameters
3392
+ ----------
3393
+ successor : dict mapping int to Iterable
3394
+ Dictionary assigning nodes to their successors.
3395
+ predecessor : dict mapping int to int or Iterable
3396
+ Dictionary assigning nodes to their predecessors.
3397
+ time : dict mapping int to int, optional
3398
+ Dictionary assigning nodes to the time point they were recorded to.
3399
+ Defaults to None, in which case all times are set to `starting_time`.
3400
+ starting_time : int, optional
3401
+ Starting time of the lineage tree. Defaults to 0.
3402
+ pos : dict mapping int to Iterable, optional
3403
+ Dictionary assigning nodes to their positions. Defaults to None.
3404
+ name : str, optional
3405
+ Name of the lineage tree. Defaults to None.
3406
+ root_leaf_value : Iterable, optional
3407
+ Iterable of values of roots' predecessors and leaves' successors in the successor and predecessor dictionaries.
3408
+ Defaults are `[None, (), [], set()]`.
3409
+ **kwargs:
3410
+ Supported keyword arguments are dictionaries assigning nodes to any custom property.
3411
+ The property must be specified for every node, and named differently from lineageTree's own attributes.
3412
+ """
3413
+ self.__version__ = __version__
3414
+ self.name = str(name) if name is not None else None
3415
+ if successor is not None and predecessor is not None:
3416
+ raise ValueError(
3417
+ "You cannot have both successors and predecessors."
3418
+ )
2662
3419
 
2663
- self.name = name
2664
- self.time_nodes = {}
2665
- self.time_edges = {}
2666
- self.max_id = -1
2667
- self.next_id = []
2668
- self.nodes = set()
2669
- self.successor = {}
2670
- self.predecessor = {}
2671
- self.pos = {}
2672
- self.time_id = {}
2673
- self.time = {}
2674
- if time_resolution is not None:
2675
- self._time_resolution = time_resolution
2676
- self.kdtrees = {}
2677
- self.spatial_density = {}
2678
- if file_type and file_format:
2679
- if xml_attributes is None:
2680
- self.xml_attributes = []
2681
- else:
2682
- self.xml_attributes = xml_attributes
2683
- file_type = file_type.lower()
2684
- if file_type == "tgmm":
2685
- self.read_tgmm_xml(file_format, tb, te, z_mult)
2686
- self.t_b = tb
2687
- self.t_e = te
2688
- elif file_type == "mamut" or file_type == "trackmate":
2689
- self.read_from_mamut_xml(file_format)
2690
- elif file_type == "celegans":
2691
- self.read_from_txt_for_celegans(file_format)
2692
- elif file_type == "celegans_cao":
2693
- self.read_from_txt_for_celegans_CAO(
2694
- file_format,
2695
- reorder=reorder,
2696
- shape=shape,
2697
- raw_size=raw_size,
3420
+ if root_leaf_value is None:
3421
+ root_leaf_value = [None, (), [], set()]
3422
+ elif not isinstance(root_leaf_value, Iterable):
3423
+ raise TypeError(
3424
+ f"root_leaf_value is of type {type(root_leaf_value)}, expected Iterable."
3425
+ )
3426
+ elif len(root_leaf_value) < 1:
3427
+ raise ValueError(
3428
+ "root_leaf_value should have at least one element."
3429
+ )
3430
+ self._successor = {}
3431
+ self._predecessor = {}
3432
+ if successor is not None:
3433
+ for pred, succs in successor.items():
3434
+ if succs in root_leaf_value:
3435
+ self._successor[pred] = ()
3436
+ else:
3437
+ if not isinstance(succs, Iterable):
3438
+ raise TypeError(
3439
+ f"Successors should be Iterable, got {type(succs)}."
3440
+ )
3441
+ if len(succs) == 0:
3442
+ raise ValueError(
3443
+ f"{succs} was not declared as a leaf but was found as a successor.\n"
3444
+ "Please lift the ambiguity."
3445
+ )
3446
+ self._successor[pred] = tuple(succs)
3447
+ for succ in succs:
3448
+ if succ in self._predecessor:
3449
+ raise ValueError(
3450
+ "Node can have at most one predecessor."
3451
+ )
3452
+ self._predecessor[succ] = (pred,)
3453
+ elif predecessor is not None:
3454
+ for succ, pred in predecessor.items():
3455
+ if pred in root_leaf_value:
3456
+ self._predecessor[succ] = ()
3457
+ else:
3458
+ if isinstance(pred, Sequence):
3459
+ if len(pred) == 0:
3460
+ raise ValueError(
3461
+ f"{pred} was not declared as a leaf but was found as a successor.\n"
3462
+ "Please lift the ambiguity."
3463
+ )
3464
+ if 1 < len(pred):
3465
+ raise ValueError(
3466
+ "Node can have at most one predecessor."
3467
+ )
3468
+ pred = pred[0]
3469
+ self._predecessor[succ] = (pred,)
3470
+ self._successor.setdefault(pred, ())
3471
+ self._successor[pred] += (succ,)
3472
+ for root in set(self._successor).difference(self._predecessor):
3473
+ self._predecessor[root] = ()
3474
+ for leaf in set(self._predecessor).difference(self._successor):
3475
+ self._successor[leaf] = ()
3476
+
3477
+ if self.__check_for_cycles():
3478
+ raise ValueError(
3479
+ "Cycles were found in the tree, there should not be any."
3480
+ )
3481
+
3482
+ if pos is None or len(pos) == 0:
3483
+ self.pos = {}
3484
+ else:
3485
+ if self.nodes.difference(pos) != set():
3486
+ raise ValueError("Please provide the position of all nodes.")
3487
+ self.pos = {
3488
+ node: np.array(position) for node, position in pos.items()
3489
+ }
3490
+ if "labels" in kwargs:
3491
+ self._labels = kwargs["labels"]
3492
+ kwargs.pop("labels")
3493
+ if time is None:
3494
+ if starting_time is None:
3495
+ starting_time = 0
3496
+ if not isinstance(starting_time, int):
3497
+ warnings.warn(
3498
+ f"Attribute `starting_time` was a `{type(starting_time)}`, has been casted as an `int`.",
3499
+ stacklevel=2,
2698
3500
  )
2699
- elif file_type == "mastodon":
2700
- if isinstance(file_format, list) and len(file_format) == 2:
2701
- self.read_from_mastodon_csv(file_format)
3501
+ self._time = dict.fromkeys(self.roots, starting_time)
3502
+ queue = list(self.roots)
3503
+ for node in queue:
3504
+ for succ in self._successor[node]:
3505
+ self._time[succ] = self._time[node] + 1
3506
+ queue.append(succ)
3507
+ else:
3508
+ if starting_time is not None:
3509
+ warnings.warn(
3510
+ "Both `time` and `starting_time` were provided, `starting_time` was ignored.",
3511
+ stacklevel=2,
3512
+ )
3513
+ self._time = {n: int(time[n]) for n in self.nodes}
3514
+ if self._time != time:
3515
+ if len(self._time) != len(time):
3516
+ warnings.warn(
3517
+ "The provided `time` dictionary had keys that were not nodes. "
3518
+ "They have been removed",
3519
+ stacklevel=2,
3520
+ )
2702
3521
  else:
2703
- if isinstance(file_format, list):
2704
- file_format = file_format[0]
2705
- self.read_from_mastodon(file_format, name)
2706
- elif file_type == "astec":
2707
- self.read_from_ASTEC(file_format, eigen)
2708
- elif file_type == "csv":
2709
- self.read_from_csv(file_format, z_mult, link=1, delim=delim)
2710
- elif file_type == "bao":
2711
- self.read_C_elegans_bao(file_format)
2712
- elif file_format and file_format.endswith(".lT"):
2713
- with open(file_format, "br") as f:
2714
- tmp = pkl.load(f)
2715
- f.close()
2716
- self.__dict__.update(tmp.__dict__)
2717
- elif file_format is not None:
2718
- self.read_from_binary(file_format)
2719
- if self.name is None:
2720
- try:
2721
- self.name = Path(file_format).stem
2722
- except TypeError:
2723
- self.name = Path(file_format[0]).stem
3522
+ warnings.warn(
3523
+ "The provided `time` dictionary had values that were not `int`. "
3524
+ "These values have been truncated and converted to `int`",
3525
+ stacklevel=2,
3526
+ )
3527
+ if self.nodes.symmetric_difference(self._time) != set():
3528
+ raise ValueError(
3529
+ "Please provide the time of all nodes and only existing nodes."
3530
+ )
3531
+ if not all(
3532
+ self._time[node] < self._time[s]
3533
+ for node, succ in self._successor.items()
3534
+ for s in succ
3535
+ ):
3536
+ raise ValueError(
3537
+ "Provided times are not strictly increasing. Setting times to default."
3538
+ )
3539
+ # custom properties
3540
+ self._custom_property_list = []
3541
+ for name, d in kwargs.items():
3542
+ if name in self.__dict__:
3543
+ warnings.warn(
3544
+ f"Attribute name {name} is reserved.", stacklevel=2
3545
+ )
3546
+ continue
3547
+ setattr(self, name, d)
3548
+ self._custom_property_list.append(name)
3549
+ if not hasattr(self, "_comparisons"):
3550
+ self._comparisons = {}