LineageTree 1.8.0__py3-none-any.whl → 2.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -2,643 +2,538 @@
2
2
  # This file is subject to the terms and conditions defined in
3
3
  # file 'LICENCE', which is part of this source code package.
4
4
  # Author: Leo Guignard (leo.guignard...@AT@...gmail.com)
5
+
6
+ from __future__ import annotations
7
+
8
+ import importlib.metadata
5
9
  import os
6
10
  import pickle as pkl
7
11
  import struct
8
12
  import warnings
9
- from collections.abc import Iterable
10
- from functools import partial
13
+ from collections.abc import Callable, Iterable, Sequence
14
+ from functools import partial, wraps
11
15
  from itertools import combinations
12
16
  from numbers import Number
13
- from pathlib import Path
14
- from typing import TextIO, Union
15
-
16
- from .loaders import lineageTreeLoaders
17
- from .tree_styles import tree_style
18
-
19
- try:
20
- from edist import uted
21
- except ImportError:
22
- warnings.warn(
23
- "No edist installed therefore you will not be able to compute the tree edit distance.",
24
- stacklevel=2,
25
- )
17
+ from types import MappingProxyType
18
+ from typing import TYPE_CHECKING, Literal
19
+
20
+ import matplotlib.colors as mcolors
26
21
  import matplotlib.pyplot as plt
27
22
  import numpy as np
23
+ import svgwrite
24
+ from edist import uted
25
+ from matplotlib import colormaps
26
+ from matplotlib.collections import LineCollection
27
+ from packaging.version import Version
28
28
  from scipy.interpolate import InterpolatedUnivariateSpline
29
- from scipy.spatial import Delaunay, distance
30
- from scipy.spatial import cKDTree as KDTree
29
+ from scipy.sparse import dok_array
30
+ from scipy.spatial import Delaunay, KDTree, distance
31
31
 
32
+ from .tree_approximation import TreeApproximationTemplate, tree_style
32
33
  from .utils import (
33
- create_links_and_cycles,
34
+ convert_style_to_number,
35
+ create_links_and_chains,
34
36
  hierarchical_pos,
35
37
  )
36
38
 
39
+ if TYPE_CHECKING:
40
+ from edist.alignment import Alignment
41
+
42
+
43
+ class dynamic_property(property):
44
+ def __init__(
45
+ self, fget=None, fset=None, fdel=None, doc=None, protected_name=None
46
+ ):
47
+ super().__init__(fget, fset, fdel, doc)
48
+ self.protected_name = protected_name
49
+
50
+ def __set_name__(self, owner, name):
51
+ self.name = name
52
+ if self.protected_name is None:
53
+ self.protected_name = f"_{name}"
54
+ if not hasattr(owner, "_protected_dynamic_properties"):
55
+ owner._protected_dynamic_properties = []
56
+ owner._protected_dynamic_properties.append(self.protected_name)
57
+ if not hasattr(owner, "_dynamic_properties"):
58
+ owner._dynamic_properties = []
59
+ owner._dynamic_properties += [name, self.protected_name]
60
+ setattr(owner, self.protected_name, None)
61
+
62
+ def __get__(self, instance, owner):
63
+ if instance is None:
64
+ return self
65
+ instance._has_been_reset = False
66
+ if getattr(instance, self.protected_name) is None:
67
+ value = super().__get__(instance, owner)
68
+ setattr(instance, self.protected_name, value)
69
+ return value
70
+ else:
71
+ return getattr(instance, self.protected_name)
72
+
73
+
74
+ class lineageTree:
75
+ norm_dict = {"max": max, "sum": sum, None: lambda x: 1}
76
+
77
+ def modifier(wrapped_func):
78
+ @wraps(wrapped_func)
79
+ def raising_flag(self, *args, **kwargs):
80
+ should_reset = (
81
+ not hasattr(self, "_has_been_reset")
82
+ or not self._has_been_reset
83
+ )
84
+ out_func = wrapped_func(self, *args, **kwargs)
85
+ if should_reset:
86
+ for prop in self._protected_dynamic_properties:
87
+ self.__dict__[prop] = None
88
+ self._has_been_reset = True
89
+ return out_func
90
+
91
+ return raising_flag
92
+
93
+ def __check_cc_cycles(self, n: int) -> tuple[bool, set[int]]:
94
+ """Check if the connected component of a given node `n` has a cycle.
95
+
96
+ Returns
97
+ -------
98
+ bool
99
+ True if the tree has cycles, False otherwise.
100
+ set of int
101
+ The set of nodes that have been checked.
102
+ """
103
+ to_do = [n]
104
+ no_cycle = True
105
+ already_done = set()
106
+ while to_do and no_cycle:
107
+ current = to_do.pop(-1)
108
+ if current not in already_done:
109
+ already_done.add(current)
110
+ else:
111
+ no_cycle = False
112
+ to_do.extend(self._successor[current])
113
+ to_do = list(self._predecessor[n])
114
+ while to_do and no_cycle:
115
+ current = to_do.pop(-1)
116
+ if current not in already_done:
117
+ already_done.add(current)
118
+ else:
119
+ no_cycle = False
120
+ to_do.extend(self._predecessor[current])
121
+ return not no_cycle, already_done
37
122
 
38
- class lineageTree(lineageTreeLoaders):
39
- def __eq__(self, other):
123
+ def __check_for_cycles(self) -> bool:
124
+ """Check if the tree has cycles.
125
+
126
+ Returns
127
+ -------
128
+ bool
129
+ True if the tree has cycles, False otherwise.
130
+ """
131
+ to_do = set(self.nodes)
132
+ found_cycle = False
133
+ while to_do and not found_cycle:
134
+ current = to_do.pop()
135
+ found_cycle, done = self.__check_cc_cycles(current)
136
+ to_do.difference_update(done)
137
+ return found_cycle
138
+
139
+ def __eq__(self, other) -> bool:
40
140
  if isinstance(other, lineageTree):
41
- return other.successor == self.successor
42
- return False
141
+ return (
142
+ other._successor == self._successor
143
+ and other._predecessor == self._predecessor
144
+ and other._time == self._time
145
+ )
146
+ else:
147
+ return False
43
148
 
44
- def get_next_id(self):
45
- """Computes the next authorized id.
149
+ def get_next_id(self) -> int:
150
+ """Computes the next authorized id and assign it.
46
151
 
47
- Returns:
48
- int: next authorized id
152
+ Returns
153
+ -------
154
+ int
155
+ next authorized id
49
156
  """
50
- if self.max_id == -1 and self.nodes:
51
- self.max_id = max(self.nodes)
52
- if self.next_id == []:
157
+ if not hasattr(self, "max_id") or (self.max_id == -1 and self.nodes):
158
+ self.max_id = max(self.nodes) if len(self.nodes) else 0
159
+ if not hasattr(self, "next_id") or self.next_id == []:
53
160
  self.max_id += 1
54
161
  return self.max_id
55
162
  else:
56
163
  return self.next_id.pop()
57
164
 
58
- def complete_lineage(self, nodes: Union[int, set] = None):
59
- """Makes all leaf branches longer so that they reach the last timepoint( self.t_e), useful
60
- for tree edit distance algorithms.
61
-
62
- Args:
63
- nodes (int,set), optional): Which trees should be "completed", if None it will complete the whole dataset. Defaults to None.
64
- """
65
- if nodes is None:
66
- nodes = set(self.roots)
67
- elif isinstance(nodes, int):
68
- nodes = {nodes}
69
- for node in nodes:
70
- sub = set(self.get_sub_tree(node))
71
- specific_leaves = sub.intersection(self.leaves)
72
- for leaf in specific_leaves:
73
- self.add_branch(
74
- leaf,
75
- (self.t_e - self.time[leaf]),
76
- reverse=True,
77
- move_timepoints=True,
78
- )
79
-
80
165
  ###TODO pos can be callable and stay motionless (copy the position of the succ node, use something like optical flow)
81
- def add_branch(
166
+ @modifier
167
+ def add_chain(
82
168
  self,
83
- pred: int,
169
+ node: int,
84
170
  length: int,
85
- move_timepoints: bool = True,
86
- pos: Union[callable, None] = None,
87
- reverse: bool = False,
88
- ):
89
- """Adds a branch of specific length to a node either as a successor or as a predecessor.
171
+ downstream: bool,
172
+ pos: Callable | None = None,
173
+ ) -> int:
174
+ """Adds a chain of specific length to a node either as a successor or as a predecessor.
90
175
  If it is placed on top of a tree all the nodes will move timepoints #length down.
91
176
 
92
- Args:
93
- pred (int): Id of the successor (predecessor if reverse is False)
94
- length (int): The length of the new branch.
95
- pos (np.ndarray, optional): The new position of the branch. Defaults to None.
96
- move_timepoints (bool): Moves the time, important only if reverse= True
97
- reverese (bool): If True will create a branch that goes forwards in time otherwise backwards.
98
- Returns:
99
- (int): Id of the first node of the sublineage.
177
+ Parameters
178
+ ----------
179
+ node : int
180
+ Id of the successor (predecessor if `downstream==False`)
181
+ length : int
182
+ The length of the new chain.
183
+ downstream : bool, default=True
184
+ If `True` will create a chain that goes forwards in time otherwise backwards.
185
+ pos : np.ndarray, optional
186
+ The new position of the chain. Defaults to None.
187
+
188
+ Returns
189
+ -------
190
+ int
191
+ Id of the first node of the sublineage.
100
192
  """
101
193
  if length == 0:
102
- return pred
103
- if self.predecessor.get(pred) and not reverse:
104
- raise Warning("Cannot add 2 predecessors to a node")
105
- time = self.time[pred]
106
- original = pred
107
- if not reverse:
108
- if move_timepoints:
109
- nodes_to_move = set(self.get_sub_tree(pred))
110
- new_times = {
111
- node: self.time[node] + length for node in nodes_to_move
112
- }
113
- for node in nodes_to_move:
114
- old_time = self.time[node]
115
- self.time_nodes[old_time].remove(node)
116
- self.time_nodes.setdefault(old_time + length, set()).add(
117
- node
118
- )
119
- self.time.update(new_times)
120
- for t in range(length - 1, -1, -1):
121
- _next = self.add_node(
122
- time + t,
123
- succ=pred,
124
- pos=self.pos[original],
125
- reverse=True,
126
- )
127
- pred = _next
128
- else:
129
- for t in range(length):
130
- _next = self.add_node(
131
- time - t,
132
- succ=pred,
133
- pos=self.pos[original],
134
- reverse=True,
135
- )
136
- pred = _next
194
+ return node
195
+ if length < 1:
196
+ raise ValueError("Length cannot be <1")
197
+ if downstream:
198
+ for _ in range(int(length)):
199
+ old_node = node
200
+ node = self._add_node(pred=[old_node])
201
+ self._time[node] = self._time[old_node] + 1
137
202
  else:
138
- for _ in range(length):
139
- _next = self.add_node(
140
- self.time[pred] + 1,
141
- succ=pred,
142
- pos=self.pos[original],
143
- reverse=False,
203
+ if self._predecessor[node]:
204
+ raise Warning("The node already has a predecessor.")
205
+ if self._time[node] - length < self.t_b:
206
+ raise Warning(
207
+ "A node cannot created outside the lower bound of the dataset. (It is possible to change it by lT.t_b = int(...))"
144
208
  )
145
- pred = _next
146
- self.successor[self.get_cycle(pred)[-1]] = []
147
- self.labels[pred] = "New branch"
148
- if self.time[pred] == self.t_b:
149
- self.labels[pred] = "New branch"
150
- if original in self.roots and reverse is True:
151
- self.labels[pred] = "New branch"
152
- self.labels.pop(original, -1)
153
- self.t_e = max(self.time_nodes)
154
- return pred
155
-
156
- def cut_tree(self, root):
157
- """It transforms a lineage that has at least 2 divisions into 2 independent lineages,
158
- that spawn from the time point of the first node. (splits a tree into 2)
159
-
160
- Args:
161
- root (int): The id of the node, which will be cut.
162
-
163
- Returns:
164
- int: The id of the new tree
165
- """
166
- cycle = self.get_successors(root)
167
- last_cell = cycle[-1]
168
- if last_cell in self.successor:
169
- new_lT = self.successor[last_cell].pop()
170
- self.predecessor.pop(new_lT)
171
- label_of_root = self.labels.get(cycle[0], cycle[0])
172
- self.labels[cycle[0]] = f"L-Split {label_of_root}"
173
- new_tr = self.add_branch(new_lT, len(cycle), move_timepoints=False)
174
- self.roots.add(new_tr)
175
- self.labels[new_tr] = f"R-Split {label_of_root}"
176
- return new_tr
177
- else:
178
- raise Warning("No division of the branch")
179
-
180
- def fuse_lineage_tree(
181
- self,
182
- l1_root: int,
183
- l2_root: int,
184
- length_l1: int = 0,
185
- length_l2: int = 0,
186
- length: int = 1,
187
- ):
188
- """Fuses 2 lineages from the lineagetree object. The 2 lineages that are to be fused can have a longer
189
- first node and the node of the resulting lineage can also be longer.
190
-
191
- Args:
192
- l1_root (int): Id of the first root
193
- l2_root (int): Id of the second root
194
- length_l1 (int, optional): The length of the branch that will be added on top of the first lineage. Defaults to 0, which means only one node will be added.
195
- length_l2 (int, optional): The length of the branch that will be added on top of the second lineage. Defaults to 0, which means only one node will be added.
196
- length (int, optional): The length of the branch that will be added on top of the resulting lineage. Defaults to 1.
197
-
198
- Returns:
199
- int: The id of the root of the new lineage.
200
- """
201
- if self.predecessor.get(l1_root) or self.predecessor.get(l2_root):
202
- raise ValueError("Please select 2 roots.")
203
- if self.time[l1_root] != self.time[l2_root]:
204
- warnings.warn(
205
- "Using lineagetrees that do not exist in the same timepoint. The operation will continue"
206
- )
207
- new_root1 = self.add_branch(l1_root, length_l1)
208
- new_root2 = self.add_branch(l2_root, length_l2)
209
- next_root1 = self[new_root1][0]
210
- self.remove_nodes(new_root1)
211
- self.successor[new_root2].append(next_root1)
212
- self.predecessor[next_root1] = [new_root2]
213
- new_branch = self.add_branch(
214
- new_root2,
215
- length - 1,
216
- )
217
- self.labels[new_branch] = f"Fusion of {new_root1} and {new_root2}"
218
- return new_branch
219
-
220
- def copy_lineage(self, root):
221
- """
222
- Copies the structure of a tree and makes a new with new nodes.
223
- Warning does not take into account the predecessor of the root node.
224
-
225
- Args:
226
- root (int): The root of the tree to be copied
209
+ for _ in range(int(length)):
210
+ old_node = node
211
+ node = self._add_node(succ=[old_node])
212
+ self._time[node] = self._time[old_node] - 1
213
+ return node
214
+
215
+ @modifier
216
+ def add_root(self, t: int, pos: list | None = None) -> int:
217
+ """Adds a root to a specific timepoint.
218
+
219
+ Parameters
220
+ ----------
221
+ t :int
222
+ The timepoint the node is going to be added.
223
+ pos : list
224
+ The position of the new node.
225
+ Returns
226
+ -------
227
+ int
228
+ The id of the new root.
229
+ """
230
+ C_next = self.get_next_id()
231
+ self._successor[C_next] = ()
232
+ self._predecessor[C_next] = ()
233
+ self._time[C_next] = t
234
+ self.pos[C_next] = pos if isinstance(pos, list) else []
235
+ self._changed_roots = True
236
+ return C_next
227
237
 
228
- Returns:
229
- int: The root of the new tree.
230
- """
231
- new_nodes = {
232
- old_node: self.get_next_id()
233
- for old_node in self.get_sub_tree(root)
234
- }
235
- self.nodes.update(new_nodes.values())
236
- for old_node, new_node in new_nodes.items():
237
- self.time[new_node] = self.time[old_node]
238
- succ = self.successor.get(old_node)
239
- if succ:
240
- self.successor[new_node] = [new_nodes[n] for n in succ]
241
- pred = self.predecessor.get(old_node)
242
- if pred:
243
- self.predecessor[new_node] = [new_nodes[n] for n in pred]
244
- self.pos[new_node] = self.pos[old_node] + 0.5
245
- self.time_nodes[self.time[old_node]].add(new_nodes[old_node])
246
- new_root = new_nodes[root]
247
- self.labels[new_root] = f"Copy of {root}"
248
- if self.time[new_root] == 0:
249
- self.roots.add(new_root)
250
- return new_root
251
-
252
- def add_node(
238
+ def _add_node(
253
239
  self,
254
- t: int = None,
255
- succ: int = None,
256
- pos: np.ndarray = None,
257
- nid: int = None,
258
- reverse: bool = False,
240
+ succ: list | None = None,
241
+ pred: list | None = None,
242
+ pos: np.ndarray | None = None,
243
+ nid: int | None = None,
259
244
  ) -> int:
260
- """Adds a node to the lineageTree and update it accordingly.
261
-
262
- Args:
263
- t (int): int, time to which to add the node
264
- succ (int): id of the node the new node is a successor to
265
- pos ([float, ]): list of three floats representing the 3D
266
- spatial position of the node
267
- nid (int): id value of the new node, to be used carefully,
268
- if None is provided the new id is automatically computed.
269
- reverse (bool): True if in this lineageTree the predecessors
270
- are the successors and reciprocally.
271
- This is there for bacward compatibility, should be left at False.
272
- Returns:
273
- int: id of the new node.
274
- """
245
+ """Adds a node to the LineageTree object that is either a successor or a predecessor of another node.
246
+ Does not handle time! You cannot enter both a successor and a predecessor.
247
+
248
+ Parameters
249
+ ----------
250
+ succ : list
251
+ list of ids of the nodes the new node is a successor to
252
+ pred : list
253
+ list of ids of the nodes the new node is a predecessor to
254
+ pos : np.ndarray, optional
255
+ position of the new node
256
+ nid : int, optional
257
+ id value of the new node, to be used carefully,
258
+ if None is provided the new id is automatically computed.
259
+
260
+ Returns
261
+ -------
262
+ int
263
+ id of the new node.
264
+ """
265
+ if not succ and not pred:
266
+ raise Warning(
267
+ "Please enter a successor or a predecessor, otherwise use the add_roots() function."
268
+ )
275
269
  C_next = self.get_next_id() if nid is None else nid
276
- self.time_nodes.setdefault(t, set()).add(C_next)
277
- if succ is not None and not reverse:
278
- self.successor.setdefault(succ, []).append(C_next)
279
- self.predecessor.setdefault(C_next, []).append(succ)
280
- elif succ is not None:
281
- self.predecessor.setdefault(succ, []).append(C_next)
282
- self.successor.setdefault(C_next, []).append(succ)
283
- self.nodes.add(C_next)
284
- self.pos[C_next] = pos
285
- self.time[C_next] = t
270
+ if succ:
271
+ self._successor[C_next] = succ
272
+ for suc in succ:
273
+ self._predecessor[suc] = (C_next,)
274
+ else:
275
+ self._successor[C_next] = ()
276
+ if pred:
277
+ self._predecessor[C_next] = pred
278
+ self._successor[pred[0]] = self._successor.setdefault(
279
+ pred[0], ()
280
+ ) + (C_next,)
281
+ else:
282
+ self._predecessor[C_next] = ()
283
+ if isinstance(pos, list):
284
+ self.pos[C_next] = pos
286
285
  return C_next
287
286
 
288
- def remove_nodes(self, group: Union[int, set, list]):
287
+ @modifier
288
+ def remove_nodes(self, group: int | set | list) -> None:
289
289
  """Removes a group of nodes from the LineageTree
290
290
 
291
- Args:
292
- group (set|list|int): One or more nodes that are to be removed.
291
+ Parameters
292
+ ----------
293
+ group : set of int or list of int or int
294
+ One or more nodes that are to be removed.
293
295
  """
294
- if isinstance(group, int):
296
+ if isinstance(group, int | float):
295
297
  group = {group}
296
298
  if isinstance(group, list):
297
299
  group = set(group)
298
- group = group.intersection(self.nodes)
299
- self.nodes.difference_update(group)
300
- times = {self.time.pop(n) for n in group}
301
- for t in times:
302
- self.time_nodes[t] = set(self.time_nodes[t]).difference(group)
300
+ group = self.nodes.intersection(group)
303
301
  for node in group:
304
- self.pos.pop(node)
305
- if self.predecessor.get(node):
306
- pred = self.predecessor[node][0]
307
- siblings = self.successor.pop(pred, [])
308
- if len(siblings) == 2:
309
- siblings.remove(node)
310
- self.successor[pred] = siblings
311
- self.predecessor.pop(node, [])
312
- for succ in self.successor.get(node, []):
313
- self.predecessor.pop(succ, [])
314
- self.successor.pop(node, [])
315
- self.labels.pop(node, 0)
316
- if node in self.roots:
317
- self.roots.remove(node)
318
-
319
- def modify_branch(self, node, new_length):
320
- """Changes the length of a branch, so it adds or removes nodes
321
- to make the correct length of the cycle.
322
-
323
- Args:
324
- node (int): Any node of the branch to be modified/
325
- new_length (int): The new length of the tree.
326
- """
327
- if new_length <= 1:
328
- warnings.warn("New length should be more than 1", stacklevel=2)
329
- return None
330
- cycle = self.get_cycle(node)
331
- length = len(cycle)
332
- successors = self.successor.get(cycle[-1])
333
- if length == 1 and new_length != 1:
334
- pred = self.predecessor.pop(node, None)
335
- new_node = self.add_branch(
336
- node,
337
- length=new_length - 1,
338
- move_timepoints=True,
339
- reverse=False,
340
- )
341
- if pred:
342
- self.successor[pred[0]].remove(node)
343
- self.successor[pred[0]].append(new_node)
344
- elif self.leaves.intersection(cycle) and new_length < length:
345
- self.remove_nodes(cycle[new_length:])
346
- elif new_length < length:
347
- to_remove = length - new_length
348
- last_cell = cycle[new_length - 1]
349
- subtree = self.get_sub_tree(cycle[-1])[1:]
350
- self.remove_nodes(cycle[new_length:])
351
- self.successor[last_cell] = successors
352
- if successors:
353
- for succ in successors:
354
- self.predecessor[succ] = [last_cell]
355
- for node in subtree:
356
- if node not in cycle[new_length - 1 :]:
357
- old_time = self.time[node]
358
- self.time[node] = old_time - to_remove
359
- self.time_nodes[old_time].remove(node)
360
- self.time_nodes.setdefault(
361
- old_time - to_remove, set()
362
- ).add(node)
363
- elif length < new_length:
364
- to_add = new_length - length
365
- last_cell = cycle[-1]
366
- self.successor.pop(cycle[-2])
367
- self.predecessor.pop(last_cell)
368
- succ = self.add_branch(
369
- last_cell, length=to_add, move_timepoints=True, reverse=False
370
- )
371
- self.predecessor[succ] = [cycle[-2]]
372
- self.successor[cycle[-2]] = [succ]
373
- self.time[last_cell] = (
374
- self.time[self.predecessor[last_cell][0]] + 1
375
- )
376
- else:
377
- return None
378
-
379
- @property
380
- def time_resolution(self):
381
- if not hasattr(self, "_time_resolution"):
382
- self.time_resolution = 1
383
- return self._time_resolution / 10
384
-
385
- @time_resolution.setter
386
- def time_resolution(self, time_resolution):
387
- if time_resolution is not None:
388
- self._time_resolution = int(time_resolution * 10)
389
- else:
390
- warnings.warn("Time resolution set to default 0", stacklevel=2)
391
- self._time_resolution = 10
392
-
393
- @property
394
- def depth(self):
395
- if not hasattr(self, "_depth"):
396
- self._depth = {}
397
- for leaf in self.leaves:
398
- self._depth[leaf] = 1
399
- while leaf in self.predecessor:
400
- parent = self.predecessor[leaf][0]
401
- current_depth = self._depth.get(parent, 0)
402
- self._depth[parent] = max(
403
- self._depth[leaf] + 1, current_depth
404
- )
405
- leaf = parent
406
- for root in self.roots - set(self._depth):
407
- self._depth[root] = 1
408
- return self._depth
302
+ for attr in self.__dict__:
303
+ attr_value = self.__getattribute__(attr)
304
+ if isinstance(attr_value, dict) and attr not in [
305
+ "successor",
306
+ "predecessor",
307
+ "_successor",
308
+ "_predecessor",
309
+ "_time",
310
+ ]:
311
+ attr_value.pop(node, ())
312
+ if self._predecessor.get(node):
313
+ self._successor[self._predecessor[node][0]] = tuple(
314
+ set(
315
+ self._successor[self._predecessor[node][0]]
316
+ ).difference(group)
317
+ )
318
+ for p_node in self._successor.get(node, []):
319
+ self._predecessor[p_node] = ()
320
+ self._predecessor.pop(node, ())
321
+ self._successor.pop(node, ())
409
322
 
410
323
  @property
411
- def roots(self):
412
- return set(self.nodes).difference(self.predecessor)
324
+ def successor(self) -> MappingProxyType[int, tuple[int]]:
325
+ """The successor of the tree."""
326
+ if not hasattr(self, "_protected_successor"):
327
+ self._protected_successor = MappingProxyType(self._successor)
328
+ return self._protected_successor
413
329
 
414
330
  @property
415
- def edges(self):
416
- return {(k, vi) for k, v in self.successor.items() for vi in v}
331
+ def predecessor(self) -> MappingProxyType[int, tuple[int]]:
332
+ """The predecessor of the tree."""
333
+ if not hasattr(self, "_protected_predecessor"):
334
+ self._protected_predecessor = MappingProxyType(self._predecessor)
335
+ return self._protected_predecessor
417
336
 
418
337
  @property
419
- def leaves(self):
420
- return {p for p, s in self.successor.items() if s == []}
338
+ def time(self) -> MappingProxyType[int, int]:
339
+ """The time of the tree."""
340
+ if not hasattr(self, "_protected_time"):
341
+ self._protected_time = MappingProxyType(self._time)
342
+ return self._protected_time
343
+
344
+ @dynamic_property
345
+ def t_b(self) -> int:
346
+ """The first timepoint of the tree."""
347
+ return min(self._time.values())
348
+
349
+ @dynamic_property
350
+ def t_e(self) -> int:
351
+ """The last timepoint of the tree."""
352
+ return max(self._time.values())
353
+
354
+ @dynamic_property
355
+ def nodes(self) -> frozenset[int]:
356
+ """Nodes of the tree"""
357
+ return frozenset(self._successor.keys())
358
+
359
+ @dynamic_property
360
+ def depth(self) -> dict[int, int]:
361
+ """The depth of each node in the tree."""
362
+ _depth = {}
363
+ for leaf in self.leaves:
364
+ _depth[leaf] = 1
365
+ while leaf in self._predecessor and self._predecessor[leaf]:
366
+ parent = self._predecessor[leaf][0]
367
+ current_depth = _depth.get(parent, 0)
368
+ _depth[parent] = max(_depth[leaf] + 1, current_depth)
369
+ leaf = parent
370
+ for root in self.roots - set(_depth):
371
+ _depth[root] = 1
372
+ return _depth
373
+
374
+ @dynamic_property
375
+ def roots(self) -> frozenset[int]:
376
+ """Set of roots of the tree"""
377
+ return frozenset({s for s, p in self._predecessor.items() if p == ()})
378
+
379
+ @dynamic_property
380
+ def leaves(self) -> frozenset[int]:
381
+ """Set of leaves"""
382
+ return frozenset({p for p, s in self._successor.items() if s == ()})
383
+
384
+ @dynamic_property
385
+ def edges(self) -> tuple[tuple[int, int]]:
386
+ """Set of edges"""
387
+ return tuple((p, si) for p, s in self._successor.items() for si in s)
421
388
 
422
389
  @property
423
- def labels(self):
390
+ def labels(self) -> dict[int, str]:
391
+ """The labels of the nodes."""
424
392
  if not hasattr(self, "_labels"):
425
- if hasattr(self, "cell_name"):
393
+ if hasattr(self, "node_name"):
426
394
  self._labels = {
427
- i: self.cell_name.get(i, "Unlabeled") for i in self.roots
395
+ i: self.node_name.get(i, "Unlabeled") for i in self.roots
428
396
  }
429
397
  else:
430
398
  self._labels = {
431
399
  root: "Unlabeled"
432
400
  for root in self.roots
433
401
  for leaf in self.find_leaves(root)
434
- if abs(self.time[leaf] - self.time[root])
402
+ if abs(self._time[leaf] - self._time[root])
435
403
  >= abs(self.t_e - self.t_b) / 4
436
404
  }
437
405
  return self._labels
438
406
 
439
- def _write_header_am(self, f: TextIO, nb_points: int, length: int):
440
- """Header for Amira .am files"""
441
- f.write("# AmiraMesh 3D ASCII 2.0\n")
442
- f.write("define VERTEX %d\n" % (nb_points * 2))
443
- f.write("define EDGE %d\n" % nb_points)
444
- f.write("define POINT %d\n" % ((length) * nb_points))
445
- f.write("Parameters {\n")
446
- f.write('\tContentType "HxSpatialGraph"\n')
447
- f.write("}\n")
448
-
449
- f.write("VERTEX { float[3] VertexCoordinates } @1\n")
450
- f.write("EDGE { int[2] EdgeConnectivity } @2\n")
451
- f.write("EDGE { int NumEdgePoints } @3\n")
452
- f.write("POINT { float[3] EdgePointCoordinates } @4\n")
453
- f.write("VERTEX { float Vcolor } @5\n")
454
- f.write("VERTEX { int Vbool } @6\n")
455
- f.write("EDGE { float Ecolor } @7\n")
456
- f.write("VERTEX { int Vbool2 } @8\n")
457
-
458
- def write_to_am(
459
- self,
460
- path_format: str,
461
- t_b: int = None,
462
- t_e: int = None,
463
- length: int = 5,
464
- manual_labels: dict = None,
465
- default_label: int = 5,
466
- new_pos: np.ndarray = None,
467
- ):
468
- """Writes a lineageTree into an Amira readable data (.am format).
469
-
470
- Args:
471
- path_format (str): path to the output. It should contain 1 %03d where the time step will be entered
472
- t_b (int): first time point to write (if None, min(LT.to_take_time) is taken)
473
- t_e (int): last time point to write (if None, max(LT.to_take_time) is taken)
474
- note, if there is no 'to_take_time' attribute, self.time_nodes
475
- is considered instead (historical)
476
- length (int): length of the track to print (how many time before).
477
- manual_labels ({id: label, }): dictionary that maps cell ids to
478
- default_label (int): default value for the manual label
479
- new_pos ({id: [x, y, z]}): dictionary that maps a 3D position to a cell ID.
480
- if new_pos == None (default) then self.pos is considered.
481
- """
482
- if not hasattr(self, "to_take_time"):
483
- self.to_take_time = self.time_nodes
484
- if t_b is None:
485
- t_b = min(self.to_take_time.keys())
486
- if t_e is None:
487
- t_e = max(self.to_take_time.keys())
488
- if new_pos is None:
489
- new_pos = self.pos
490
-
491
- if manual_labels is None:
492
- manual_labels = {}
493
- for t in range(t_b, t_e + 1):
494
- with open(path_format % t, "w") as f:
495
- nb_points = len(self.to_take_time[t])
496
- self._write_header_am(f, nb_points, length)
497
- points_v = {}
498
- for C in self.to_take_time[t]:
499
- C_tmp = C
500
- positions = []
501
- for _ in range(length):
502
- C_tmp = self.predecessor.get(C_tmp, [C_tmp])[0]
503
- positions.append(new_pos[C_tmp])
504
- points_v[C] = positions
505
-
506
- f.write("@1\n")
507
- for C in self.to_take_time[t]:
508
- f.write("{:f} {:f} {:f}\n".format(*tuple(points_v[C][0])))
509
- f.write("{:f} {:f} {:f}\n".format(*tuple(points_v[C][-1])))
510
-
511
- f.write("@2\n")
512
- for i, _ in enumerate(self.to_take_time[t]):
513
- f.write("%d %d\n" % (2 * i, 2 * i + 1))
514
-
515
- f.write("@3\n")
516
- for _ in self.to_take_time[t]:
517
- f.write("%d\n" % (length))
518
-
519
- f.write("@4\n")
520
- for C in self.to_take_time[t]:
521
- for p in points_v[C]:
522
- f.write("{:f} {:f} {:f}\n".format(*tuple(p)))
523
-
524
- f.write("@5\n")
525
- for C in self.to_take_time[t]:
526
- f.write(f"{manual_labels.get(C, default_label):f}\n")
527
- f.write(f"{0:f}\n")
528
-
529
- f.write("@6\n")
530
- for C in self.to_take_time[t]:
531
- f.write(
532
- "%d\n"
533
- % (
534
- int(
535
- manual_labels.get(C, default_label)
536
- != default_label
537
- )
538
- )
539
- )
540
- f.write("%d\n" % (0))
541
-
542
- f.write("@7\n")
543
- for C in self.to_take_time[t]:
544
- f.write(
545
- f"{np.linalg.norm(points_v[C][0] - points_v[C][-1]):f}\n"
546
- )
547
-
548
- f.write("@8\n")
549
- for _ in self.to_take_time[t]:
550
- f.write("%d\n" % (1))
551
- f.write("%d\n" % (0))
552
- f.close()
407
+ @property
408
+ def time_resolution(self) -> float:
409
+ if not hasattr(self, "_time_resolution"):
410
+ self._time_resolution = 0
411
+ return self._time_resolution / 10
553
412
 
554
- def _get_height(self, c: int, done: dict):
555
- """Recursively computes the height of a cell within a tree * a space factor.
413
+ @time_resolution.setter
414
+ def time_resolution(self, time_resolution) -> None:
415
+ if time_resolution is not None and time_resolution > 0:
416
+ self._time_resolution = int(time_resolution * 10)
417
+ else:
418
+ warnings.warn("Time resolution set to default 0", stacklevel=2)
419
+ self._time_resolution = 0
420
+
421
+ def __setstate__(self, state):
422
+ if "_successor" not in state:
423
+ state["_successor"] = state["successor"]
424
+ if "_predecessor" not in state:
425
+ state["_predecessor"] = state["predecessor"]
426
+ if "_time" not in state:
427
+ state["_time"] = state["time"]
428
+ self.__dict__.update(state)
429
+
430
def _get_height(self, c: int, done: dict) -> float:
    """Recursively computes the height of a node within a tree times a space factor.
    This function is specific to the function write_to_svg.

    Parameters
    ----------
    c : int
        id of the node whose height is requested
    done : dict mapping int to list of two float
        maps a node id to its `[height, vert_space_factor * time]` pair.
        It is used as a cache and is filled in place for every node
        visited that was not yet present.

    Returns
    -------
    float
        the height of the node `c`: the cached value if `c` is in `done`,
        otherwise the mean of the heights of its successors
    """
    if c not in done:
        successor_heights = [
            self._get_height(child, done) for child in self._successor[c]
        ]
        height = np.mean(successor_heights)
        done[c] = [height, self.vert_space_factor * self._time[c]]
        return height
    return done[c][0]
572
454
 
573
455
  def write_to_svg(
574
456
  self,
575
457
  file_name: str,
576
- roots: list = None,
458
+ roots: list | None = None,
577
459
  draw_nodes: bool = True,
578
460
  draw_edges: bool = True,
579
- order_key: callable = None,
461
+ order_key: Callable | None = None,
580
462
  vert_space_factor: float = 0.5,
581
463
  horizontal_space: float = 1,
582
- node_size: callable = None,
583
- stroke_width: callable = None,
464
+ node_size: Callable | str | None = None,
465
+ stroke_width: Callable | None = None,
584
466
  factor: float = 1.0,
585
- node_color: callable = None,
586
- stroke_color: callable = None,
587
- positions: dict = None,
588
- node_color_map: callable = None,
589
- normalize: bool = True,
590
- ):
591
- ##### remove background? default True background value? default 1
592
-
467
+ node_color: Callable | str | None = None,
468
+ stroke_color: Callable | None = None,
469
+ positions: dict | None = None,
470
+ node_color_map: Callable | str | None = None,
471
+ ) -> None:
593
472
  """Writes the lineage tree to an SVG file.
594
473
  Node and edges coloring and size can be provided.
595
474
 
596
- Args:
597
- file_name (str): filesystem filename valid for `open()`
598
- roots ([int, ...]): list of node ids to be drawn. If `None` all the nodes will be drawn. Default `None`
599
- draw_nodes (bool): wether to print the nodes or not, default `True`
600
- draw_edges (bool): wether to print the edges or not, default `True`
601
- order_key (callable): function that would work for the attribute `key=` for the `sort`/`sorted` function
602
- vert_space_factor (float): the vertical position of a node is its time. `vert_space_factor` is a
603
- multiplier to space more or less nodes in time
604
- horizontal_space (float): space between two consecutive nodes
605
- node_size (callable | str): a function that maps a node id to a `float` value that will determine the
606
- radius of the node. The default function return the constant value `vertical_space_factor/2.1`
607
- If a string is given instead and it is a property of the tree,
608
- the the size will be mapped according to the property
609
- stroke_width (callable): a function that maps a node id to a `float` value that will determine the
610
- width of the daughter edge. The default function return the constant value `vertical_space_factor/2.1`
611
- factor (float): scaling factor for nodes positions, default 1
612
- node_color (callable | str): a function that maps a node id to a triplet between 0 and 255.
613
- The triplet will determine the color of the node. If a string is given instead and it is a property
614
- of the tree, the the color will be mapped according to the property
615
- node_color_map (callable | str): the name of the colormap to use to color the nodes, or a colormap function
616
- stroke_color (callable): a function that maps a node id to a triplet between 0 and 255.
617
- The triplet will determine the color of the stroke of the inward edge.
618
- positions ({int: [float, float], ...}): dictionary that maps a node id to a 2D position.
619
- Default `None`. If provided it will be used to position the nodes.
475
+ Parameters
476
+ ----------
477
+ file_name : str
478
+ filesystem filename valid for `open()`
479
+ roots : list of int, defaults to `self.roots`
480
+ list of node ids to be drawn. If `None` or not provided all the nodes will be drawn. Default `None`
481
+ draw_nodes : bool, default True
482
+ wether to print the nodes or not
483
+ draw_edges : bool, default True
484
+ wether to print the edges or not
485
+ order_key : Callable, optional
486
+ function that would work for the attribute `key=` for the `sort`/`sorted` function
487
+ vert_space_factor : float, default=0.5
488
+ the vertical position of a node is its time. `vert_space_factor` is a
489
+ multiplier to space more or less nodes in time
490
+ horizontal_space : float, default=1
491
+ space between two consecutive nodes
492
+ node_size : Callable or str, optional
493
+ a function that maps a node id to a `float` value that will determine the
494
+ radius of the node. The default function return the constant value `vertical_space_factor/2.1`
495
+ If a string is given instead and it is a property of the tree,
496
+ the the size will be mapped according to the property
497
+ stroke_width : Callable, optional
498
+ a function that maps a node id to a `float` value that will determine the
499
+ width of the daughter edge. The default function return the constant value `vertical_space_factor/2.1`
500
+ factor : float, default=1.0
501
+ scaling factor for nodes positions, default 1
502
+ node_color : Callable or str, optional
503
+ a function that maps a node id to a triplet between 0 and 255.
504
+ The triplet will determine the color of the node. If a string is given instead and it is a property
505
+ of the tree, the the color will be mapped according to the property
506
+ node_color_map : Callable or str, optional
507
+ the name of the colormap to use to color the nodes, or a colormap function
508
+ stroke_color : Callable, optional
509
+ a function that maps a node id to a triplet between 0 and 255.
510
+ The triplet will determine the color of the stroke of the inward edge.
511
+ positions : dict mapping int to list of two float, optional
512
+ dictionary that maps a node id to a 2D position.
513
+ Default `None`. If provided it will be used to position the nodes.
620
514
  """
621
- import svgwrite
622
515
 
623
516
  def normalize_values(v, nodes, _range, shift, mult):
624
517
  min_ = np.percentile(v, 1)
625
518
  max_ = np.percentile(v, 99)
626
519
  values = _range * ((v - min_) / (max_ - min_)) + shift
627
- values_dict_nodes = dict(zip(nodes, values))
520
+ values_dict_nodes = dict(zip(nodes, values, strict=True))
628
521
  return lambda x: values_dict_nodes[x] * mult
629
522
 
630
523
  if roots is None:
631
524
  roots = self.roots
632
525
  if hasattr(self, "image_label"):
633
- roots = [cell for cell in roots if self.image_label[cell] != 1]
526
+ roots = [node for node in roots if self.image_label[node] != 1]
634
527
 
635
528
  if node_size is None:
636
529
 
637
530
  def node_size(x):
638
531
  return vert_space_factor / 2.1
639
532
 
640
- elif isinstance(node_size, str) and node_size in self.__dict__:
641
- values = np.array([self[node_size][c] for c in self.nodes])
533
+ else:
534
+ values = np.array(
535
+ [self._successor[node_size][c] for c in self.nodes]
536
+ )
642
537
  node_size = normalize_values(
643
538
  values, self.nodes, 0.5, 0.5, vert_space_factor / 2.1
644
539
  )
@@ -653,18 +548,19 @@ class lineageTree(lineageTreeLoaders):
653
548
  return 0, 0, 0
654
549
 
655
550
  elif isinstance(node_color, str) and node_color in self.__dict__:
656
- if isinstance(node_color_map, str):
657
- from matplotlib import colormaps
551
+ from matplotlib import colormaps
658
552
 
659
- if node_color_map in colormaps:
660
- node_color_map = colormaps[node_color_map]
661
- else:
662
- node_color_map = colormaps["viridis"]
663
- values = np.array([self[node_color][c] for c in self.nodes])
553
+ if node_color_map in colormaps:
554
+ cmap = colormaps[node_color_map]
555
+ else:
556
+ cmap = colormaps["viridis"]
557
+ values = np.array(
558
+ [self._successor[node_color][c] for c in self.nodes]
559
+ )
664
560
  normed_vals = normalize_values(values, self.nodes, 1, 0, 1)
665
561
 
666
562
  def node_color(x):
667
- return [k * 255 for k in node_color_map(normed_vals(x))[:-1]]
563
+ return [k * 255 for k in cmap(normed_vals(x))[:-1]]
668
564
 
669
565
  coloring_edges = stroke_color is not None
670
566
  if not coloring_edges:
@@ -673,24 +569,25 @@ class lineageTree(lineageTreeLoaders):
673
569
  return 0, 0, 0
674
570
 
675
571
  elif isinstance(stroke_color, str) and stroke_color in self.__dict__:
676
- if isinstance(node_color_map, str):
677
- from matplotlib import colormaps
572
+ from matplotlib import colormaps
678
573
 
679
- if node_color_map in colormaps:
680
- node_color_map = colormaps[node_color_map]
681
- else:
682
- node_color_map = colormaps["viridis"]
683
- values = np.array([self[stroke_color][c] for c in self.nodes])
574
+ if node_color_map in colormaps:
575
+ cmap = colormaps[node_color_map]
576
+ else:
577
+ cmap = colormaps["viridis"]
578
+ values = np.array(
579
+ [self._successor[stroke_color][c] for c in self.nodes]
580
+ )
684
581
  normed_vals = normalize_values(values, self.nodes, 1, 0, 1)
685
582
 
686
583
  def stroke_color(x):
687
- return [k * 255 for k in node_color_map(normed_vals(x))[:-1]]
584
+ return [k * 255 for k in cmap(normed_vals(x))[:-1]]
688
585
 
689
586
  prev_x = 0
690
587
  self.vert_space_factor = vert_space_factor
691
588
  if order_key is not None:
692
589
  roots.sort(key=order_key)
693
- treated_cells = []
590
+ treated_nodes = []
694
591
 
695
592
  pos_given = positions is not None
696
593
  if not pos_given:
@@ -701,25 +598,26 @@ class lineageTree(lineageTreeLoaders):
701
598
  [0.0, 0.0],
702
599
  ]
703
600
  * len(self.nodes),
704
- )
601
+ strict=True,
602
+ ),
705
603
  )
706
604
  for _i, r in enumerate(roots):
707
605
  r_leaves = []
708
606
  to_do = [r]
709
607
  while len(to_do) != 0:
710
608
  curr = to_do.pop(0)
711
- treated_cells += [curr]
712
- if curr in self.successor:
609
+ treated_nodes += [curr]
610
+ if not self._successor[curr]:
713
611
  if order_key is not None:
714
- to_do += sorted(self.successor[curr], key=order_key)
612
+ to_do += sorted(self._successor[curr], key=order_key)
715
613
  else:
716
- to_do += self.successor[curr]
614
+ to_do += self._successor[curr]
717
615
  else:
718
616
  r_leaves += [curr]
719
617
  r_pos = {
720
618
  leave: [
721
619
  prev_x + horizontal_space * (1 + j),
722
- self.vert_space_factor * self.time[leave],
620
+ self.vert_space_factor * self._time[leave],
723
621
  ]
724
622
  for j, leave in enumerate(r_leaves)
725
623
  }
@@ -734,12 +632,12 @@ class lineageTree(lineageTreeLoaders):
734
632
  size=factor * np.max(list(positions.values()), axis=0),
735
633
  )
736
634
  if draw_edges and not draw_nodes and not coloring_edges:
737
- to_do = set(treated_cells)
635
+ to_do = set(treated_nodes)
738
636
  while len(to_do) > 0:
739
637
  curr = to_do.pop()
740
- c_cycle = self.get_cycle(curr)
741
- x1, y1 = positions[c_cycle[0]]
742
- x2, y2 = positions[c_cycle[-1]]
638
+ c_chain = self.get_chain_of_node(curr)
639
+ x1, y1 = positions[c_chain[0]]
640
+ x2, y2 = positions[c_chain[-1]]
743
641
  dwg.add(
744
642
  dwg.line(
745
643
  (factor * x1, factor * y1),
@@ -747,7 +645,7 @@ class lineageTree(lineageTreeLoaders):
747
645
  stroke=svgwrite.rgb(0, 0, 0),
748
646
  )
749
647
  )
750
- for si in self[c_cycle[-1]]:
648
+ for si in self._successor[c_chain[-1]]:
751
649
  x3, y3 = positions[si]
752
650
  dwg.add(
753
651
  dwg.line(
@@ -756,11 +654,11 @@ class lineageTree(lineageTreeLoaders):
756
654
  stroke=svgwrite.rgb(0, 0, 0),
757
655
  )
758
656
  )
759
- to_do.difference_update(c_cycle)
657
+ to_do.difference_update(c_chain)
760
658
  else:
761
- for c in treated_cells:
659
+ for c in treated_nodes:
762
660
  x1, y1 = positions[c]
763
- for si in self[c]:
661
+ for si in self._successor[c]:
764
662
  x2, y2 = positions[si]
765
663
  if draw_edges:
766
664
  dwg.add(
@@ -771,7 +669,7 @@ class lineageTree(lineageTreeLoaders):
771
669
  stroke_width=svgwrite.pt(stroke_width(si)),
772
670
  )
773
671
  )
774
- for c in treated_cells:
672
+ for c in treated_nodes:
775
673
  x1, y1 = positions[c]
776
674
  if draw_nodes:
777
675
  dwg.add(
@@ -788,49 +686,58 @@ class lineageTree(lineageTreeLoaders):
788
686
  fname: str,
789
687
  t_min: int = -1,
790
688
  t_max: int = np.inf,
791
- nodes_to_use: list = None,
689
+ nodes_to_use: list[int] | None = None,
792
690
  temporal: bool = True,
793
- spatial: str = None,
691
+ spatial: str | None = None,
794
692
  write_layout: bool = True,
795
- node_properties: dict = None,
693
+ node_properties: dict | None = None,
796
694
  Names: bool = False,
797
- ):
695
+ ) -> None:
798
696
  """Write a lineage tree into an understable tulip file.
799
697
 
800
- Args:
801
- fname (str): path to the tulip file to create
802
- t_min (int): minimum time to consider, default -1
803
- t_max (int): maximum time to consider, default np.inf
804
- nodes_to_use ([int, ]): list of nodes to show in the graph,
805
- default *None*, then self.nodes is used
806
- (taking into account *t_min* and *t_max*)
807
- temporal (bool): True if the temporal links should be printed, default True
808
- spatial (str): Build spatial edges from a spatial neighbourhood graph.
809
- The graph has to be computed before running this function
810
- 'ball': neighbours at a given distance,
811
- 'kn': k-nearest neighbours,
812
- 'GG': gabriel graph,
813
- None: no spatial edges are writen.
814
- Default None
815
- write_layout (bool): True, write the spatial position as layout,
816
- False, do not write spatial positionm
817
- default True
818
- node_properties ({`p_name`, [{id, p_value}, default]}): a dictionary of properties to write
819
- To a key representing the name of the property is
820
- paired a dictionary that maps a cell id to a property
821
- and a default value for this property
822
- Names (bool): Only works with ASTEC outputs, True to sort the cells by their names
698
+ Parameters
699
+ ----------
700
+ fname : str
701
+ path to the tulip file to create
702
+ t_min : int, default=-1
703
+ minimum time to consider
704
+ t_max : int, default=np.inf
705
+ maximum time to consider
706
+ nodes_to_use : list of int, optional
707
+ list of nodes to show in the graph,
708
+ if `None` then self.nodes is used
709
+ (taking into account `t_min` and `t_max`)
710
+ temporal : bool, default=True
711
+ True if the temporal links should be printed
712
+ spatial : str, optional
713
+ Build spatial edges from a spatial neighbourhood graph.
714
+ The graph has to be computed before running this function
715
+ 'ball': neighbours at a given distance,
716
+ 'kn': k-nearest neighbours,
717
+ 'GG': gabriel graph,
718
+ None: no spatial edges are writen.
719
+ Default None
720
+ write_layout : bool, default=True
721
+ write the spatial position as layout if True
722
+ do not write spatial position otherwise
723
+ node_properties : dict mapping str to list of dict of properties and its default value, optional
724
+ a dictionary of properties to write
725
+ To a key representing the name of the property is
726
+ paired a dictionary that maps a node id to a property
727
+ and a default value for this property
728
+ Names : bool, default=True
729
+ Only works with ASTEC outputs, True to sort the nodes by their names
823
730
  """
824
731
 
825
732
  def format_names(names_which_matter):
826
- """Return an ensured formated cell names"""
733
+ """Return an ensured formated node names"""
827
734
  tmp = {}
828
735
  for k, v in names_which_matter.items():
829
736
  tmp[k] = (
830
737
  v.split(".")[0][0]
831
- + "%02d" % int(v.split(".")[0][1:])
738
+ + "{:02d}".format(int(v.split(".")[0][1:]))
832
739
  + "."
833
- + "%04d" % int(v.split(".")[1][:-1])
740
+ + "{:04d}".format(int(v.split(".")[1][:-1]))
834
741
  + v.split(".")[1][-1]
835
742
  )
836
743
  return tmp
@@ -857,20 +764,20 @@ class lineageTree(lineageTreeLoaders):
857
764
  if not nodes_to_use:
858
765
  if t_max != np.inf or t_min > -1:
859
766
  nodes_to_use = [
860
- n for n in self.nodes if t_min < self.time[n] <= t_max
767
+ n for n in self.nodes if t_min < self._time[n] <= t_max
861
768
  ]
862
769
  edges_to_use = []
863
770
  if temporal:
864
771
  edges_to_use += [
865
772
  e
866
773
  for e in self.edges
867
- if t_min < self.time[e[0]] < t_max
774
+ if t_min < self._time[e[0]] < t_max
868
775
  ]
869
776
  if spatial:
870
777
  edges_to_use += [
871
778
  e
872
779
  for e in s_edges
873
- if t_min < self.time[e[0]] < t_max
780
+ if t_min < self._time[e[0]] < t_max
874
781
  ]
875
782
  else:
876
783
  nodes_to_use = list(self.nodes)
@@ -884,12 +791,12 @@ class lineageTree(lineageTreeLoaders):
884
791
  nodes_to_use = set(nodes_to_use)
885
792
  if temporal:
886
793
  for n in nodes_to_use:
887
- for d in self.successor.get(n, []):
794
+ for d in self._successor[n]:
888
795
  if d in nodes_to_use:
889
796
  edges_to_use.append((n, d))
890
797
  if spatial:
891
798
  edges_to_use += [
892
- e for e in s_edges if t_min < self.time[e[0]] < t_max
799
+ e for e in s_edges if t_min < self._time[e[0]] < t_max
893
800
  ]
894
801
  nodes_to_use = set(nodes_to_use)
895
802
  if Names:
@@ -907,12 +814,12 @@ class lineageTree(lineageTreeLoaders):
907
814
  for k, v in node_properties[Names][0].items():
908
815
  if (
909
816
  len(
910
- self.successor.get(
911
- self.predecessor.get(k, [-1])[0], []
817
+ self._successor.get(
818
+ self._predecessor.get(k, [-1])[0], ()
912
819
  )
913
820
  )
914
821
  != 1
915
- or self.time[k] == t_min + 1
822
+ or self._time[k] == t_min + 1
916
823
  ):
917
824
  tmp_names[k] = v
918
825
  node_properties[Names][0] = tmp_names
@@ -942,7 +849,7 @@ class lineageTree(lineageTreeLoaders):
942
849
  f.write('\t(default "0" "0")\n')
943
850
  for n in nodes_to_use:
944
851
  f.write(
945
- "\t(node " + str(n) + ' "' + str(self.time[n]) + '")\n'
852
+ "\t(node " + str(n) + ' "' + str(self._time[n]) + '")\n'
946
853
  )
947
854
  f.write(")\n")
948
855
 
@@ -989,7 +896,9 @@ class lineageTree(lineageTreeLoaders):
989
896
  f.write(")")
990
897
  f.close()
991
898
 
992
- def to_binary(self, fname: str, starting_points: list = None):
899
+ def to_binary(
900
+ self, fname: str, starting_points: list[int] | None = None
901
+ ) -> None:
993
902
  """Writes the lineage tree (a forest) as a binary structure
994
903
  (assuming it is a binary tree, it would not work for *n* ary tree with 2 < *n*).
995
904
  The binary file is composed of 3 sequences of numbers and
@@ -1004,36 +913,37 @@ class lineageTree(lineageTreeLoaders):
1004
913
  The *time_sequence* is stored as a list of unsigned short (0 -> 2^(8*2)-1)
1005
914
  The *pos_sequence* is stored as a list of double.
1006
915
 
1007
- Args:
1008
- fname (str): name of the binary file
1009
- starting_points ([int, ]): list of the roots to be written.
1010
- If None, all roots are written, default value, None
916
+ Parameters
917
+ ----------
918
+ fname : str
919
+ name of the binary file
920
+ starting_points : list of int, optional
921
+ list of the roots to be written.
922
+ If `None`, all roots are written, default value, None
1011
923
  """
1012
924
  if starting_points is None:
1013
- starting_points = [
1014
- c for c in self.successor if self.predecessor.get(c, []) == []
1015
- ]
925
+ starting_points = list(self.roots)
1016
926
  number_sequence = [-1]
1017
927
  pos_sequence = []
1018
928
  time_sequence = []
1019
929
  for c in starting_points:
1020
- time_sequence.append(self.time.get(c, 0))
930
+ time_sequence.append(self._time.get(c, 0))
1021
931
  to_treat = [c]
1022
- while to_treat != []:
932
+ while to_treat:
1023
933
  curr_c = to_treat.pop()
1024
934
  number_sequence.append(curr_c)
1025
935
  pos_sequence += list(self.pos[curr_c])
1026
- if self[curr_c] == []:
936
+ if self._successor[curr_c] == ():
1027
937
  number_sequence.append(-1)
1028
- elif len(self.successor[curr_c]) == 1:
1029
- to_treat += self.successor[curr_c]
938
+ elif len(self._successor[curr_c]) == 1:
939
+ to_treat += self._successor[curr_c]
1030
940
  else:
1031
941
  number_sequence.append(-2)
1032
- to_treat += self.successor[curr_c]
942
+ to_treat += self._successor[curr_c]
1033
943
  remaining_nodes = set(self.nodes) - set(number_sequence)
1034
944
 
1035
945
  for c in remaining_nodes:
1036
- time_sequence.append(self.time.get(c, 0))
946
+ time_sequence.append(self._time.get(c, 0))
1037
947
  number_sequence.append(c)
1038
948
  pos_sequence += list(self.pos[c])
1039
949
  number_sequence.append(-1)
@@ -1048,50 +958,270 @@ class lineageTree(lineageTreeLoaders):
1048
958
 
1049
959
  f.close()
1050
960
 
1051
- def write(self, fname: str):
1052
- """
1053
- Write a lineage tree on disk as an .lT file.
961
def write(self, fname: str) -> None:
    """Write a lineage tree on disk as an .lT file.

    The `_protected_predecessor`, `_protected_successor` and
    `_protected_time` attributes are removed from the instance, when
    present, before it is pickled.

    Parameters
    ----------
    fname : str
        path to and name of the file to save.
        The `.lT` extension is appended when `fname` does not end with it.
    """
    if os.path.splitext(fname)[-1] != ".lT":
        fname = os.path.extsep.join((fname, "lT"))
    for cached_attr in (
        "_protected_predecessor",
        "_protected_successor",
        "_protected_time",
    ):
        if hasattr(self, cached_attr):
            delattr(self, cached_attr)
    # the context manager closes the file
    with open(fname, "bw") as f:
        pkl.dump(self, f)
980
 
1064
981
@classmethod
def load(clf, fname: str):
    """Loading a lineage tree from a '.lT' file.

    A pickled object that has no `__version__` attribute, or whose
    version is lower than 2.0.0, is rebuilt as a new `lineageTree`
    from its `_successor`, `_time`, `pos` and `name`, together with
    every other dictionary attribute it carries.

    Parameters
    ----------
    fname : str
        path to and name of the file to read

    Returns
    -------
    lineageTree
        loaded file
    """
    # NOTE: unpickling can execute arbitrary code, only load trusted files.
    with open(fname, "br") as f:
        lT = pkl.load(f)
    if not hasattr(lT, "__version__") or Version(lT.__version__) < Version(
        "2.0.0"
    ):
        # attribute names that must not be forwarded as extra properties
        reserved_names = (
            [
                "successor",
                "predecessor",
                "time",
                "_successor",
                "_predecessor",
                "_time",
                "pos",
                "labels",
            ]
            + lineageTree._dynamic_properties
            + lineageTree._protected_dynamic_properties
        )
        properties = {
            prop_name: prop
            for prop_name, prop in lT.__dict__.items()
            if isinstance(prop, dict) and prop_name not in reserved_names
        }
        lT = lineageTree(
            successor=lT._successor,
            time=lT._time,
            pos=lT.pos,
            name=lT.name if hasattr(lT, "name") else None,
            **properties,
        )
    # NOTE(review): if `time_resolution` is the property that creates
    # `_time_resolution` on first access, this check is always true and
    # the default of 1 is never applied -- confirm the intended default.
    if not hasattr(lT, "time_resolution"):
        lT.time_resolution = 1

    return lT
1081
1031
 
1082
- def get_idx3d(self, t: int) -> tuple:
1083
- """Get a 3d kdtree for the dataset at time *t* .
1084
- The kdtree is stored in *self.kdtrees[t]*
1085
-
1086
- Args:
1087
- t (int): time
1088
- Returns:
1089
- (kdtree, [int, ]): the built kdtree and
1090
- the correspondancy list,
1091
- If the query in the kdtree gives you the value i,
1092
- then it corresponds to the id in the tree to_check_self[i]
1032
def get_predecessors(
    self,
    x: int,
    depth: int | None = None,
    start_time: int | None = None,
    end_time: int | None = None,
) -> list[int]:
    """Computes the predecessors of the node `x` up to
    `depth` predecessors or the begining of the life of `x`.
    The ordered list of ids is returned.

    The walk goes back in time from `x` and stops at a node without
    predecessor, at a node whose predecessor has more than one
    successor, or once `start_time` is reached.

    Parameters
    ----------
    x : int
        id of the node to compute
    depth : int, optional
        maximum number of predecessors to return
    start_time : int, optional
        earliest time to consider, defaults to `self.t_b`
    end_time : int, optional
        latest time to consider, defaults to `self.t_e`.
        Nodes later than `end_time` are walked through but not returned.

    Returns
    -------
    list of int
        list of ids ordered in time, the last id is `x`
        when the time of `x` is within `[start_time, end_time]`
    """
    # `is None` and not truthiness, so that an explicit time of 0 is kept
    if start_time is None:
        start_time = self.t_b
    if end_time is None:
        end_time = self.t_e
    current = x
    # collected from `x` backwards and reversed once at the end
    backwards = [x] if start_time <= self._time[x] <= end_time else []
    acc = 0
    while acc != depth and start_time < self._time[current]:
        pred = self._predecessor[current]
        # The predecessor must have `current` as its only successor,
        # otherwise `current` is the first node of its chain.
        # This check is needed even if it looks redundant.
        if pred == () or len(self._successor[pred[0]]) != 1:
            break
        current = pred[0]
        acc += 1
        if start_time <= self._time[current] <= end_time:
            backwards.append(current)
    backwards.reverse()
    return backwards
1085
+
1086
def get_successors(
    self, x: int, depth: int | None = None, end_time: int | None = None
) -> list[int]:
    """Computes the successors of the node `x` up to
    `depth` successors or the end of the life of `x`.
    The ordered list of ids is returned.

    Parameters
    ----------
    x : int
        id of the node to compute
    depth : int, optional
        maximum number of successors to return
    end_time : int, optional
        maximum time to consider, defaults to `self.t_e`

    Returns
    -------
    list of int
        list of ids, the first id is `x`
    """
    if end_time is None:
        end_time = self.t_e
    chain = [x]
    steps = 0
    while True:
        last = chain[-1]
        next_nodes = self._successor[last]
        # stop on a leaf or a division, on the depth limit, or on the time limit
        if (
            len(next_nodes) != 1
            or steps == depth
            or not self._time[last] < end_time
        ):
            break
        chain.extend(next_nodes)
        steps += 1

    return chain
1120
+
1121
def get_chain_of_node(
    self,
    x: int,
    depth: int | None = None,
    depth_pred: int | None = None,
    depth_succ: int | None = None,
    end_time: int | None = None,
) -> list[int]:
    """Computes the predecessors and successors of the node `x` up to
    `depth_pred` predecessors plus `depth_succ` successors.
    If the value `depth` is provided and not None,
    `depth_pred` and `depth_succ` are overwriten by `depth`.
    The ordered list of ids is returned.
    If all `depth` are None, the full chain is returned.

    Parameters
    ----------
    x : int
        id of the node to compute
    depth : int, optional
        maximum number of predecessors and successor to return
    depth_pred : int, optional
        maximum number of predecessors to return
    depth_succ : int, optional
        maximum number of successors to return
    end_time : int, optional
        time limit handed to `get_predecessors` and `get_successors`,
        defaults to `self.t_e`

    Returns
    -------
    list of int
        list of node ids
    """
    time_limit = self.t_e if end_time is None else end_time
    if depth is not None:
        depth_pred, depth_succ = depth, depth
    before = self.get_predecessors(x, depth_pred, end_time=time_limit)
    after = self.get_successors(x, depth_succ, end_time=time_limit)
    # the last item of `before` is dropped, `after` already starts with `x`
    return before[:-1] + after
1159
+
1160
@dynamic_property
def all_chains(self) -> list[list[int]]:
    """List of all chains in the tree, ordered in depth-first search.

    The value is the result of `self._compute_all_chains()`.
    """
    chains = self._compute_all_chains()
    return chains
1164
+
1165
@dynamic_property
def time_nodes(self):
    """Dictionary that maps a time to the set of node ids found at that time.

    It is built from `self._time`, which maps a node id to its time.
    """
    nodes_by_time = {}
    for node, t in self._time.items():
        if t in nodes_by_time:
            nodes_by_time[t].add(node)
        else:
            nodes_by_time[t] = {node}
    return nodes_by_time
1171
+
1172
def m(self, i, j):
    """Memoised number of steps from node `i` down to node `j`.

    Results are cached in `self._tmp_parenting`. The entries
    `self._parenting[i, j]` and `self._parenting[j, i]` are written
    as a side effect, as detailed in the comments below.

    Parameters
    ----------
    i : int
        id of the first node
    j : int
        id of the second node, its predecessors are followed towards `i`

    Returns
    -------
    int or float
        0 when `i == j`, `np.inf` when `j` has no predecessor,
        otherwise one more than the value for `i` and the predecessor of `j`
    """
    pair = (i, j)
    memo = self._tmp_parenting
    if pair in memo:
        return memo[pair]
    if i == j:
        # the distance of a node to itself is 0
        memo[pair] = 0
        self._parenting[i, j] = memo[pair]
    elif not self._predecessor[j]:
        # `j` has no predecessor left to follow: `i` and `j` are not
        # connected, the distance is infinite and nothing is written
        # to `_parenting`
        memo[pair] = np.inf
    else:
        # one step more than the distance between `i` and pred(`j`)
        # NOTE(review): an infinite value coming back from the recursion
        # is also written to `_parenting` here -- confirm this is intended
        memo[pair] = self.m(i, self._predecessor[j][0]) + 1
        self._parenting[i, j] = memo[pair]
        self._parenting[j, i] = -memo[pair]
    return memo[pair]
1188
+
1189
@property
def parenting(self):
    """Sparse matrix of the values computed by `self.m` for every pair of nodes.

    The matrix is built on first access and stored in `self._parenting`;
    later accesses return the stored matrix without recomputing it.
    Its shape is `(max(self.nodes) + 1,) * 2`, it is indexed by node ids.

    Returns
    -------
    scipy.sparse.dok_array
        the matrix filled by the calls to `self.m`
    """
    if not hasattr(self, "_parenting"):
        self._parenting = dok_array((max(self.nodes) + 1,) * 2)
        # cache used by `self.m`, only needed while the matrix is built
        self._tmp_parenting = {}
        for i, j in combinations(self.nodes, 2):
            # order the pair so that `i` is not later in time than `j`
            if self._time[j] < self._time[i]:
                i, j = j, i
            self._tmp_parenting[(i, j)] = self.m(i, j)
        del self._tmp_parenting
    return self._parenting
1200
+
1201
+ def get_idx3d(self, t: int) -> tuple[KDTree, np.ndarray]:
1202
+ """Get a 3d kdtree for the dataset at time `t`.
1203
+ The kdtree is stored in `self.kdtrees[t]` and returned.
1204
+ The correspondancy list is also returned.
1205
+
1206
+ Parameters
1207
+ ----------
1208
+ t : int
1209
+ time
1210
+
1211
+ Returns
1212
+ -------
1213
+ KDTree
1214
+ The KDTree corresponding to the lineage tree at time `t`
1215
+ np.ndarray
1216
+ The correspondancy list in the KDTree.
1217
+ If the query in the kdtree gives you the value `i`,
1218
+ then it corresponds to the id in the tree `to_check_self[i]`
1219
+ """
1220
+ to_check_self = list(self.nodes_at_t(t=t))
1221
+
1222
+ if not hasattr(self, "kdtrees"):
1223
+ self.kdtrees = {}
1224
+
1095
1225
  if t not in self.kdtrees:
1096
1226
  data_corres = {}
1097
1227
  data = []
@@ -1104,16 +1234,21 @@ class lineageTree(lineageTreeLoaders):
1104
1234
  idx3d = self.kdtrees[t]
1105
1235
  return idx3d, np.array(to_check_self)
1106
1236
 
1107
- def get_gabriel_graph(self, t: int) -> dict:
1108
- """Build the Gabriel graph of the given graph for time point `t`
1109
- The Garbiel graph is then stored in self.Gabriel_graph and returned
1110
- *WARNING: the graph is not recomputed if already computed. even if nodes were added*.
1237
+ def get_gabriel_graph(self, t: int) -> dict[int, set[int]]:
1238
+ """Build the Gabriel graph of the given graph for time point `t`.
1239
+ The Garbiel graph is then stored in `self.Gabriel_graph` and returned.
1240
+
1241
+ .. warning:: the graph is not recomputed if already computed, even if the point cloud has changed
1242
+
1243
+ Parameters
1244
+ ----------
1245
+ t : int
1246
+ time
1111
1247
 
1112
- Args:
1113
- t (int): time
1114
- Returns:
1115
- {int, set([int, ])}: a dictionary that maps a node to
1116
- the set of its neighbors
1248
+ Returns
1249
+ -------
1250
+ dict of int to set of int
1251
+ A dictionary that maps a node to the set of its neighbors
1117
1252
  """
1118
1253
  if not hasattr(self, "Gabriel_graph"):
1119
1254
  self.Gabriel_graph = {}
@@ -1158,178 +1293,110 @@ class lineageTree(lineageTreeLoaders):
1158
1293
 
1159
1294
  return self.Gabriel_graph[t]
1160
1295
 
1161
- def get_predecessors(
1162
- self, x: int, depth: int = None, start_time: int = None, end_time=None
1163
- ) -> list:
1164
- """Computes the predecessors of the node `x` up to
1165
- `depth` predecessors or the begining of the life of `x`.
1166
- The ordered list of ids is returned.
1167
-
1168
- Args:
1169
- x (int): id of the node to compute
1170
- depth (int): maximum number of predecessors to return
1171
- Returns:
1172
- [int, ]: list of ids, the last id is `x`
1173
- """
1174
- if not start_time:
1175
- start_time = self.t_b
1176
- if not end_time:
1177
- end_time = self.t_e
1178
- unconstrained_cycle = [x]
1179
- cycle = [x] if start_time <= self.time[x] <= end_time else []
1180
- acc = 0
1181
- while (
1182
- len(self[self.predecessor.get(unconstrained_cycle[0], [-1])[0]])
1183
- == 1
1184
- and acc != depth
1185
- and start_time
1186
- <= self.time.get(
1187
- self.predecessor.get(unconstrained_cycle[0], [-1])[0], -1
1188
- )
1189
- ):
1190
- unconstrained_cycle.insert(
1191
- 0, self.predecessor[unconstrained_cycle[0]][0]
1192
- )
1193
- acc += 1
1194
- if start_time <= self.time[unconstrained_cycle[0]] <= end_time:
1195
- cycle.insert(0, unconstrained_cycle[0])
1196
-
1197
- return cycle
1198
-
1199
- def get_successors(
1200
- self, x: int, depth: int = None, end_time: int = None
1201
- ) -> list:
1202
- """Computes the successors of the node `x` up to
1203
- `depth` successors or the end of the life of `x`.
1204
- The ordered list of ids is returned.
1205
-
1206
- Args:
1207
- x (int): id of the node to compute
1208
- depth (int): maximum number of predecessors to return
1209
- Returns:
1210
- [int, ]: list of ids, the first id is `x`
1211
- """
1212
- if end_time is None:
1213
- end_time = self.t_e
1214
- cycle = [x]
1215
- acc = 0
1216
- while (
1217
- len(self[cycle[-1]]) == 1
1218
- and acc != depth
1219
- and self.time[cycle[-1]] < end_time
1220
- ):
1221
- cycle += self.successor[cycle[-1]]
1222
- acc += 1
1223
-
1224
- return cycle
1225
-
1226
- def get_cycle(
1227
- self,
1228
- x: int,
1229
- depth: int = None,
1230
- depth_pred: int = None,
1231
- depth_succ: int = None,
1232
- end_time: int = None,
1233
- ) -> list:
1234
- """Computes the predecessors and successors of the node `x` up to
1235
- `depth_pred` predecessors plus `depth_succ` successors.
1236
- If the value `depth` is provided and not None,
1237
- `depth_pred` and `depth_succ` are overwriten by `depth`.
1238
- The ordered list of ids is returned.
1239
- If all `depth` are None, the full cycle is returned.
1240
-
1241
- Args:
1242
- x (int): id of the node to compute
1243
- depth (int): maximum number of predecessors and successor to return
1244
- depth_pred (int): maximum number of predecessors to return
1245
- depth_succ (int): maximum number of successors to return
1246
- Returns:
1247
- [int, ]: list of ids
1248
- """
1249
- if end_time is None:
1250
- end_time = self.t_e
1251
- if depth is not None:
1252
- depth_pred = depth_succ = depth
1253
- return self.get_predecessors(x, depth_pred, end_time=end_time)[
1254
- :-1
1255
- ] + self.get_successors(x, depth_succ, end_time=end_time)
1256
-
1257
- @property
1258
- def all_tracks(self):
1259
- if not hasattr(self, "_all_tracks"):
1260
- self._all_tracks = self.get_all_tracks()
1261
- return self._all_tracks
1262
-
1263
- def get_all_branches_of_node(
1264
- self, node: int, end_time: int = None
1265
- ) -> list:
1266
- """Computes all the tracks of the subtree spawn by a given node.
1267
- Similar to get_all_tracks().
1296
def get_all_chains_of_subtree(
    self, node: int, end_time: int | None = None
) -> list[list[int]]:
    """Compute all the chains of the subtree spawned by a given node.

    Parameters
    ----------
    node : int
        The node from which we want to get its chains.
    end_time : int, optional
        The time at which we want to stop the chains.
        If `None`, `self.t_e` is used.

    Returns
    -------
    list of list of int
        list of chains, the first one being the chain that starts at `node`
    """
    # Explicit `is None` test: `end_time=0` is a legitimate time point and
    # must not be silently replaced by `self.t_e`.
    if end_time is None:
        end_time = self.t_e
    # NOTE(review): the first chain is computed without `end_time`, exactly
    # as before -- confirm whether it should be truncated as well.
    chains = [self.get_successors(node)]
    to_do = list(self._successor[chains[0][-1]])
    while to_do:
        current = to_do.pop()
        chain = self.get_successors(current, end_time=end_time)
        # Keep, and explore further, only the chains ending within the window.
        if self._time[chain[-1]] <= end_time:
            chains.append(chain)
            to_do.extend(self._successor[chain[-1]])
    return chains
1325
+
1326
def _compute_all_chains(self) -> list[list[int]]:
    """Build the list of every chain of the lineage tree.

    The roots are stacked so that the earliest one is expanded first.
    The successors of the last node of a chain are pushed on the same
    stack, so they are expanded before the remaining roots.

    Returns
    -------
    list of list of int
        list of chains
    """
    # Latest roots first, so that `pop()` serves the earliest root first.
    pending = sorted(self.roots, key=self.time.get, reverse=True)
    collected = []
    while pending:
        chain = self.get_chain_of_node(pending.pop())
        collected.append(chain)
        pending += self._successor[chain[-1]]
    return collected
1343
+
1344
def __get_chains(  # TODO: Probably should be removed, might be used by DTW. Might also be a @dynamic_property
    self, nodes: Iterable | int | None = None
) -> dict[int, list[list[int]]]:
    """Return, for each given node, the chains collected from `self.all_chains`.

    For a node `n`, `self.all_chains` is scanned in order until the chain
    whose first element is `self.get_predecessors(n)[0]` is met. That chain
    is replaced by `self.get_successors(n)`, and the chains that follow are
    collected as they are. The scan stops at the first later chain whose
    starting time is not greater than the time of `n`.

    Parameters
    ----------
    nodes : Iterable or int, optional
        id or Iterable of ids of the nodes to be computed, if `None` all roots are used

    Returns
    -------
    dict mapping int to list of list of int
        dictionary mapping the node ids to a list of chains
    """
    all_chains = self.all_chains
    if nodes is None:
        nodes = self.roots
    # A single id is wrapped so that the loop below always gets an iterable.
    if not isinstance(nodes, Iterable):
        nodes = [nodes]
    output_chains = {}
    for n in nodes:
        # presumably the first node of the chain that contains `n` -- TODO confirm
        starting_node = self.get_predecessors(n)[0]
        found = False  # becomes True once the chain starting at `starting_node` is met
        done = False  # becomes True when the scan has to stop
        starting_time = self.time[n]
        i = 0
        current_chain = []
        while not done and i < len(all_chains):
            curr_found = all_chains[i][0] == starting_node
            found = found or curr_found
            if found:
                # Stop at the first chain, other than the one holding `n`,
                # that does not start strictly after `n`.
                done = (
                    self.time[all_chains[i][0]] <= starting_time
                ) and not curr_found
                if not done:
                    if curr_found:
                        # Only the part of that chain computed from `n` is kept.
                        current_chain.append(self.get_successors(n))
                    else:
                        current_chain.append(all_chains[i])
            i += 1
        output_chains[n] = current_chain
    return output_chains
1387
+
1388
+ def find_leaves(self, roots: int | Iterable) -> set[int]:
1326
1389
  """Finds the leaves of a tree spawned by one or more nodes.
1327
1390
 
1328
- Args:
1329
- roots (Union[int,set,list,tuple]): The roots of the trees.
1391
+ Parameters
1392
+ ----------
1393
+ roots : int or Iterable
1394
+ The roots of the trees spawning the leaves
1330
1395
 
1331
- Returns:
1332
- set: The leaves of one or more trees.
1396
+ Returns
1397
+ -------
1398
+ set
1399
+ The leaves of one or more trees.
1333
1400
  """
1334
1401
  if not isinstance(roots, Iterable):
1335
1402
  to_do = [roots]
@@ -1338,28 +1405,34 @@ class lineageTree(lineageTreeLoaders):
1338
1405
  leaves = set()
1339
1406
  while to_do:
1340
1407
  curr = to_do.pop()
1341
- succ = self.successor.get(curr, [])
1408
+ succ = self._successor[curr]
1342
1409
  if not succ:
1343
1410
  leaves.add(curr)
1344
1411
  to_do += succ
1345
1412
  return leaves
1346
1413
 
1347
- def get_sub_tree(
1414
+ def get_subtree_nodes(
1348
1415
  self,
1349
- x: Union[int, Iterable],
1350
- end_time: Union[int, None] = None,
1416
+ x: int | Iterable,
1417
+ end_time: int | None = None,
1351
1418
  preorder: bool = False,
1352
- ) -> list:
1353
- """Computes the list of cells from the subtree spawned by *x*
1354
- The default output order is breadth first traversal.
1419
+ ) -> list[int]:
1420
+ """Computes the list of nodes from the subtree spawned by *x*
1421
+ The default output order is Breadth First Traversal.
1355
1422
  Unless preorder is `True` in that case the order is
1356
- Depth first traversal preordered.
1423
+ Depth First Traversal (DFT) preordered.
1357
1424
 
1358
- Args:
1359
- x (int): id of root node
1360
- preorder (bool): if True the output is preorder DFT
1361
- Returns:
1362
- ([int, ...]): the ordered list of node ids
1425
+ Parameters
1426
+ ----------
1427
+ x : int
1428
+ id of root node
1429
+ preorder : bool, default=False
1430
+ if True the output preorder is DFT
1431
+
1432
+ Returns
1433
+ -------
1434
+ list of int
1435
+ the ordered list of node ids
1363
1436
  """
1364
1437
  if not end_time:
1365
1438
  end_time = self.t_e
@@ -1367,233 +1440,258 @@ class lineageTree(lineageTreeLoaders):
1367
1440
  to_do = [x]
1368
1441
  elif isinstance(x, Iterable):
1369
1442
  to_do = list(x)
1370
- sub_tree = []
1443
+ subtree = []
1371
1444
  while to_do:
1372
1445
  curr = to_do.pop()
1373
- succ = self.successor.get(curr, [])
1374
- if succ and end_time < self.time.get(curr, end_time):
1446
+ succ = self._successor[curr]
1447
+ if succ and end_time < self._time.get(curr, end_time):
1375
1448
  succ = []
1376
1449
  continue
1377
1450
  if preorder:
1378
1451
  to_do = succ + to_do
1379
1452
  else:
1380
1453
  to_do += succ
1381
- sub_tree += [curr]
1382
- return sub_tree
1454
+ subtree += [curr]
1455
+ return subtree
1383
1456
 
1384
1457
  def compute_spatial_density(
1385
- self, t_b: int = None, t_e: int = None, th: float = 50
1386
- ) -> dict:
1387
- """Computes the spatial density of cells between `t_b` and `t_e`.
1388
- The spatial density is computed as follow:
1389
- #cell/(4/3*pi*th^3)
1390
- The results is stored in self.spatial_density is returned.
1391
-
1392
- Args:
1393
- t_b (int): starting time to look at, default first time point
1394
- t_e (int): ending time to look at, default last time point
1395
- th (float): size of the neighbourhood
1396
- Returns:
1397
- {int, float}: dictionary that maps a cell id to its spatial density
1398
- """
1458
+ self, t_b: int | None = None, t_e: int | None = None, th: float = 50
1459
+ ) -> dict[int, float]:
1460
+ """Computes the spatial density of nodes between `t_b` and `t_e`.
1461
+ The results is stored in `self.spatial_density` and returned.
1462
+
1463
+ Parameters
1464
+ ----------
1465
+ t_b : int, optional
1466
+ starting time to look at, default first time point
1467
+ t_e : int, optional
1468
+ ending time to look at, default last time point
1469
+ th : float, default=50
1470
+ size of the neighbourhood
1471
+
1472
+ Returns
1473
+ -------
1474
+ dict mapping int to float
1475
+ dictionary that maps a node id to its spatial density
1476
+ """
1477
+ if not hasattr(self, "spatial_density"):
1478
+ self.spatial_density = {}
1399
1479
  s_vol = 4 / 3.0 * np.pi * th**3
1400
- time_range = set(range(t_b, t_e + 1)).intersection(self.time_nodes)
1480
+ if t_b is None:
1481
+ t_b = self.t_b
1482
+ if t_e is None:
1483
+ t_e = self.t_e
1484
+ time_range = set(range(t_b, t_e)).intersection(self._time.values())
1401
1485
  for t in time_range:
1402
1486
  idx3d, nodes = self.get_idx3d(t)
1403
1487
  nb_ni = [
1404
1488
  (len(ni) - 1) / s_vol
1405
1489
  for ni in idx3d.query_ball_tree(idx3d, th)
1406
1490
  ]
1407
- self.spatial_density.update(dict(zip(nodes, nb_ni)))
1491
+ self.spatial_density.update(dict(zip(nodes, nb_ni, strict=True)))
1408
1492
  return self.spatial_density
1409
1493
 
1410
- def compute_k_nearest_neighbours(self, k: int = 10) -> dict:
1494
+ def compute_k_nearest_neighbours(self, k: int = 10) -> dict[int, set[int]]:
1411
1495
  """Computes the k-nearest neighbors
1412
1496
  Writes the output in the attribute `kn_graph`
1413
1497
  and returns it.
1414
1498
 
1415
- Args:
1416
- k (float): number of nearest neighours
1417
- Returns:
1418
- {int, set([int, ...])}: dictionary that maps
1419
- a cell id to its `k` nearest neighbors
1499
+ Parameters
1500
+ ----------
1501
+ k : float
1502
+ number of nearest neighours
1503
+
1504
+ Returns
1505
+ -------
1506
+ dict mapping int to set of int
1507
+ dictionary that maps
1508
+ a node id to its `k` nearest neighbors
1420
1509
  """
1421
1510
  self.kn_graph = {}
1422
- for t, nodes in self.time_nodes.items():
1423
- use_k = k if k < len(nodes) else len(nodes)
1424
- idx3d, nodes = self.get_idx3d(t)
1425
- pos = [self.pos[c] for c in nodes]
1426
- _, neighbs = idx3d.query(pos, use_k)
1427
- out = dict(zip(nodes, [set(nodes[ni[1:]]) for ni in neighbs]))
1428
- self.kn_graph.update(out)
1511
+ for t in set(self._time.values()):
1512
+ nodes = self.nodes_at_t(t)
1513
+ if 1 < len(nodes):
1514
+ use_k = k if k < len(nodes) else len(nodes)
1515
+ idx3d, nodes = self.get_idx3d(t)
1516
+ pos = [self.pos[c] for c in nodes]
1517
+ _, neighbs = idx3d.query(pos, use_k)
1518
+ out = dict(
1519
+ zip(
1520
+ nodes,
1521
+ map(set, nodes[neighbs]),
1522
+ strict=True,
1523
+ )
1524
+ )
1525
+ self.kn_graph.update(out)
1526
+ else:
1527
+ n = nodes.pop
1528
+ self.kn_graph.update({n: {n}})
1429
1529
  return self.kn_graph
1430
1530
 
1431
- def compute_spatial_edges(self, th: int = 50) -> dict:
1531
+ def compute_spatial_edges(self, th: int = 50) -> dict[int, set[int]]:
1432
1532
  """Computes the neighbors at a distance `th`
1433
1533
  Writes the output in the attribute `th_edge`
1434
1534
  and returns it.
1435
1535
 
1436
- Args:
1437
- th (float): distance to consider neighbors
1438
- Returns:
1439
- {int, set([int, ...])}: dictionary that maps
1440
- a cell id to its neighbors at a distance `th`
1536
+ Parameters
1537
+ ----------
1538
+ th : float, default=50
1539
+ distance to consider neighbors
1540
+
1541
+ Returns
1542
+ -------
1543
+ dict mapping int to set of int
1544
+ dictionary that maps a node id to its neighbors at a distance `th`
1441
1545
  """
1442
1546
  self.th_edges = {}
1443
- for t, _ in self.time_nodes.items():
1547
+ for t in set(self._time.values()):
1548
+ nodes = self.nodes_at_t(t)
1444
1549
  idx3d, nodes = self.get_idx3d(t)
1445
1550
  neighbs = idx3d.query_ball_tree(idx3d, th)
1446
- out = dict(zip(nodes, [set(nodes[ni]) for ni in neighbs]))
1551
+ out = dict(
1552
+ zip(nodes, [set(nodes[ni]) for ni in neighbs], strict=True)
1553
+ )
1447
1554
  self.th_edges.update(
1448
1555
  {k: v.difference([k]) for k, v in out.items()}
1449
1556
  )
1450
1557
  return self.th_edges
1451
1558
 
1452
- def main_axes(self, time: int = None):
1453
- """Finds the main axes for a timepoint.
1454
- If none will select the timepoint with the highest amound of cells.
1455
-
1456
- Args:
1457
- time (int, optional): The timepoint to find the main axes.
1458
- If None will find the timepoint
1459
- with the largest number of cells.
1460
-
1461
- Returns:
1462
- list: A list that contains the array of eigenvalues and eigenvectors.
1463
- """
1464
- if time is None:
1465
- time = max(self.time_nodes, key=lambda x: len(self.time_nodes[x]))
1466
- pos = np.array([self.pos[node] for node in self.time_nodes[time]])
1467
- pos = pos - np.mean(pos, axis=0)
1468
- cov = np.cov(np.array(pos).T)
1469
- eig_val, eig_vec = np.linalg.eig(cov)
1470
- srt = np.argsort(eig_val)[::-1]
1471
- self.eig_val, self.eig_vec = eig_val[srt], eig_vec[:, srt]
1472
- return eig_val[srt], eig_vec[:, srt]
1473
-
1474
- def scale_embryo(self, scale=1000):
1475
- """Scale the embryo using their eigenvalues.
1476
-
1477
- Args:
1478
- scale (int, optional): The resulting scale you want to achieve. Defaults to 1000.
1479
-
1480
- Returns:
1481
- float: The scale factor.
1482
- """
1483
- eig = self.main_axes()[0]
1484
- return scale / (np.sqrt(eig[0]))
1485
-
1486
- @staticmethod
1487
- def __rodrigues_rotation_matrix(vector1, vector2=(0, 1, 0)):
1488
- """Calculates the rodrigues matrix of a dataset. It should use vectors from the find_main_axes(eigenvectors) function of LineagTree.
1489
- Uses the Rodrigues rotation formula.
1490
-
1491
- Args:
1492
- vector1 (list|np.array): The vector that should be rotated to be aligned to the second vector
1493
- vector2 (list|np.array, optional): The second vector. Defaults to [1,0,0].
1494
-
1495
- Returns:
1496
- np.array: The rotation matrix.
1497
- """
1498
- vector1 = vector1 / np.linalg.norm(vector1)
1499
- vector2 = vector2 / np.linalg.norm(vector2)
1500
- if vector1 @ vector2 == 1:
1501
- return np.eye(3)
1502
- angle = np.arccos(vector1 @ vector2)
1503
- axis = np.cross(vector1, vector2)
1504
- axis = axis / np.linalg.norm(axis)
1505
- K = np.array(
1506
- [
1507
- [0, -axis[2], axis[1]],
1508
- [axis[2], 0, -axis[0]],
1509
- [-axis[1], axis[0], 0],
1510
- ]
1511
- )
1512
- return np.eye(3) + np.sin(angle) * K + (1 - np.cos(angle)) * K @ K
1513
-
1514
def get_ancestor_at_t(self, n: int, time: int | None = None) -> int:
    """Find the id of the ancestor of a given node `n` at a given time `time`.

    The predecessors of `n` are followed until a node that is not later
    than `time` is reached, or until a node without predecessor is reached.

    Parameters
    ----------
    n : int
        node for which to look the ancestor
    time : int, optional
        time at which the ancestor has to be found.
        If `None`, `self.t_b` is used.

    Returns
    -------
    int
        the id of the ancestor at time `time`,
        `-1` if `n` is not a node of the tree or if there is no ancestor
        at that exact time.
    """
    if n not in self.nodes:
        return -1
    target = self.t_b if time is None else time
    current = n
    while True:
        # Nodes without a recorded time are treated as older than `t_b`.
        current_time = self._time.get(current, self.t_b - 1)
        if not target < current_time:
            break
        parents = self._predecessor[current]
        if not parents:
            break
        current = parents[0]
    return current if current_time == target else -1
1543
1596
 
1544
def get_labelled_ancestor(self, node: int) -> int:
    """Find the first labelled ancestor of a node, the node itself included.

    Parameters
    ----------
    node : int
        The id of the node

    Returns
    -------
    int
        The first ancestor found that has a label, otherwise `-1`.
    """
    if node not in self.nodes:
        return -1
    ancestor = node
    while (
        self.t_b <= self._time.get(ancestor, self.t_b - 1)
        and ancestor != -1
    ):
        if ancestor in self.labels:
            return ancestor
        # A root can be stored with an empty predecessor entry: indexing it
        # would raise an IndexError, so it is mapped to the `-1` sentinel.
        preds = self._predecessor.get(ancestor)
        ancestor = preds[0] if preds else -1
    return -1
1621
+
1622
def get_ancestor_with_attribute(self, node: int, attribute: str) -> int:
    """Search the first ancestor of a node that is a key of a dict attribute.

    The node itself is considered first, then its predecessors one by one.

    Parameters
    ----------
    node : int
        The id of the node
    attribute : str
        The name of an attribute of the lineage tree that is a `dict`

    Returns
    -------
    int
        The first ancestor found in the attribute, otherwise `-1`.

    Raises
    ------
    ValueError
        If the attribute is not a `dict`.
    """
    holders = self.__getattribute__(attribute)
    if not isinstance(holders, dict):
        raise ValueError("Please select a dict attribute")
    if node not in self.nodes:
        return -1
    if node in holders:
        return node
    if node in self.roots:
        return -1
    current = node
    while True:
        if current in holders:
            return current
        parents = self._predecessor.get(current, [-1])
        # Either an empty entry or the `[-1]` default ends the search.
        if not parents or parents == [-1]:
            return -1
        current = parents[0]
1565
1652
 
1566
1653
  def unordered_tree_edit_distances_at_time_t(
1567
1654
  self,
1568
1655
  t: int,
1569
- end_time: int = None,
1570
- style="simple",
1656
+ end_time: int | None = None,
1657
+ style: (
1658
+ Literal["simple", "full", "downsampled", "normalized_simple"]
1659
+ | type[TreeApproximationTemplate]
1660
+ ) = "simple",
1571
1661
  downsample: int = 2,
1572
- normalize: bool = True,
1662
+ norm: Literal["max", "sum", None] = "max",
1573
1663
  recompute: bool = False,
1574
- ) -> dict:
1575
- """
1576
- Compute all the pairwise unordered tree edit distances from Zhang 996 between the trees spawned at time `t`
1577
-
1578
- Args:
1579
- t (int): time to look at
1580
- delta (callable): comparison function (see edist doc for more information)
1581
- norm (callable): norming function that takes the number of nodes
1582
- of the tree spawned by `n1` and the number of nodes
1583
- of the tree spawned by `n2` as arguments.
1584
- recompute (bool): if True, forces to recompute the distances (default: False)
1585
- end_time (int): The final time point the comparison algorithm will take into account. If None all nodes
1586
- will be taken into account.
1587
-
1588
- Returns:
1589
- (dict) a dictionary that maps a pair of cell ids at time `t` to their unordered tree edit distance
1664
+ ) -> dict[tuple[int, int], float]:
1665
+ """Compute all the pairwise unordered tree edit distances from Zhang 996 between the trees spawned at time `t`
1666
+
1667
+ Parameters
1668
+ ----------
1669
+ t : int
1670
+ time to look at
1671
+ end_time : int
1672
+ The final time point the comparison algorithm will take into account.
1673
+ If None all nodes will be taken into account.
1674
+ style : {"simple", "full", "downsampled", "normalized_simple"} or TreeApproximationTemplate subclass, default="simple"
1675
+ Which tree approximation is going to be used for the comparisons.
1676
+ downsample : int, default=2
1677
+ The downsample factor for the downsampled tree approximation.
1678
+ Used only when `style="downsampled"`.
1679
+ norm : {"max", "sum"}, default="max"
1680
+ The normalization method to use.
1681
+ recompute : bool, default=False
1682
+ If True, forces to recompute the distances
1683
+
1684
+ Returns
1685
+ -------
1686
+ dict mapping a tuple of tuple that contains 2 ints to float
1687
+ a dictionary that maps a pair of node ids at time `t` to their unordered tree edit distance
1590
1688
  """
1591
1689
  if not hasattr(self, "uted"):
1592
1690
  self.uted = {}
1593
1691
  elif t in self.uted and not recompute:
1594
1692
  return self.uted[t]
1595
1693
  self.uted[t] = {}
1596
- roots = self.time_nodes[t]
1694
+ roots = self.nodes_at_t(t=t)
1597
1695
  for n1, n2 in combinations(roots, 2):
1598
1696
  key = tuple(sorted((n1, n2)))
1599
1697
  self.uted[t][key] = self.unordered_tree_edit_distance(
@@ -1602,37 +1700,132 @@ class lineageTree(lineageTreeLoaders):
1602
1700
  end_time=end_time,
1603
1701
  style=style,
1604
1702
  downsample=downsample,
1605
- normalize=normalize,
1703
+ norm=norm,
1606
1704
  )
1607
1705
  return self.uted[t]
1608
1706
 
1609
- def unordered_tree_edit_distance(
1707
def __calculate_distance_of_sub_tree(
    self,
    node1: int,
    node2: int,
    alignment: Alignment,
    corres1: dict[int, int],
    corres2: dict[int, int],
    delta_tmp: Callable,
    norm: Callable,
    norm1: int | float,
    norm2: int | float,
) -> float:
    """Compute the distance between two matched subtrees from an existing alignment.

    Nothing is recomputed from scratch: the costs of the matches of
    `alignment` that involve a node of either subtree are summed, then
    divided by `norm([norm1, norm2])`.

    TODO ITS BOUND TO CHANGE

    Parameters
    ----------
    node1 : int
        The root of the first subtree
    node2 : int
        The root of the second subtree
    alignment : Alignment
        The alignment of the subtree
    corres1 : dict
        The correspondence dictionary of the first lineage
    corres2 : dict
        The correspondence dictionary of the second lineage
    delta_tmp : Callable
        The delta function for the comparisons
    norm : Callable
        How should the lineages be normalized
    norm1 : int or float
        The result of the normalization of the first tree
    norm2 : int or float
        The result of the normalization of the second tree

    Returns
    -------
    float
        The result of the comparison of the subtree
    """
    first_nodes = set(self.get_subtree_nodes(node1))
    second_nodes = set(self.get_subtree_nodes(node2))

    total = 0
    for match in alignment:
        left, right = match._left, match._right
        in_first = corres1.get(left, -1) in first_nodes
        if not in_first and corres2.get(right, -1) not in second_nodes:
            continue
        # `-1` marks an unmatched side, which the delta function expects as `None`.
        total += delta_tmp(
            None if left == -1 else left,
            None if right == -1 else right,
        )
    return total / norm([norm1, norm2])
1761
+
1762
def clear_comparisons(self) -> None:
    """Empty, in place, the comparisons stored in `self._comparisons`."""
    stored = self._comparisons
    stored.clear()
1764
+
1765
+ def __unordereded_backtrace(
1610
1766
  self,
1611
1767
  n1: int,
1612
1768
  n2: int,
1613
- end_time: int = None,
1614
- norm: Union["max", "sum", None] = "max",
1615
- style="simple",
1769
+ end_time: int | None = None,
1770
+ norm: Literal["max", "sum", None] = "max",
1771
+ style: (
1772
+ Literal["simple", "normalized_simple", "full", "downsampled"]
1773
+ | type[TreeApproximationTemplate]
1774
+ ) = "simple",
1616
1775
  downsample: int = 2,
1617
- ) -> float:
1776
+ ) -> dict[
1777
+ str,
1778
+ Alignment
1779
+ | tuple[TreeApproximationTemplate, TreeApproximationTemplate],
1780
+ ]:
1618
1781
  """
1619
- Compute the unordered tree edit distance from Zhang 1996 between the trees spawned
1782
+ Compute the unordered tree edit backtrace from Zhang 1996 between the trees spawned
1620
1783
  by two nodes `n1` and `n2`. The topology of the trees are compared and the matching
1621
1784
  cost is given by the function delta (see edist doc for more information).
1622
- The distance is normed by the function norm that takes the two list of nodes
1623
- spawned by the trees `n1` and `n2`.
1624
-
1625
- Args:
1626
- n1 (int): id of the first node to compare
1627
- n2 (int): id of the second node to compare
1628
- tree_style ("mini","simple","fragmented","full"): Which tree approximation is going to be used for the comparisons.
1629
- Defaults to "fragmented".
1630
-
1631
- Returns:
1632
- (float) The normed unordered tree edit distance
1633
- """
1634
1785
 
1635
- tree = tree_style[style].value
1786
+ Parameters
1787
+ ----------
1788
+ n1 : int
1789
+ id of the first node to compare
1790
+ n2 : int
1791
+ id of the second node to compare
1792
+ end_time : int
1793
+ The final time point the comparison algorithm will take into account.
1794
+ If None all nodes will be taken into account.
1795
+ norm : {"max", "sum"}, default="max"
1796
+ The normalization method to use.
1797
+ style : {"simple", "full", "downsampled", "normalized_simple"} or TreeApproximationTemplate subclass, default="simple"
1798
+ Which tree approximation is going to be used for the comparisons.
1799
+ downsample : int, default=2
1800
+ The downsample factor for the downsampled tree approximation.
1801
+ Used only when `style="downsampled"`.
1802
+
1803
+ Returns
1804
+ -------
1805
+ dict mapping str to Alignment or tuple of [TreeApproximationTemplate, TreeApproximationTemplate]
1806
+ - 'alignment'
1807
+ The alignment between the nodes by the subtrees spawned by the nodes n1,n2 and the normalization function.
1808
+ - 'trees'
1809
+ A list of the two trees that have been mapped to each other.
1810
+ """
1811
+
1812
+ parameters = (
1813
+ end_time,
1814
+ convert_style_to_number(style=style, downsample=downsample),
1815
+ )
1816
+ n1, n2 = sorted([n1, n2])
1817
+ self._comparisons.setdefault(parameters, {})
1818
+ if len(self._comparisons) > 100:
1819
+ warnings.warn(
1820
+ "More than 100 comparisons are saved, use clear_comparisons() to delete them.",
1821
+ stacklevel=2,
1822
+ )
1823
+ if isinstance(style, str):
1824
+ tree = tree_style[style].value
1825
+ elif issubclass(style, TreeApproximationTemplate):
1826
+ tree = style
1827
+ else:
1828
+ raise ValueError("Please use a valid approximation.")
1636
1829
  tree1 = tree(
1637
1830
  lT=self,
1638
1831
  downsample=downsample,
@@ -1661,7 +1854,11 @@ class lineageTree(lineageTreeLoaders):
1661
1854
  corres2,
1662
1855
  ) = tree2.edist
1663
1856
  if len(nodes1) == len(nodes2) == 0:
1664
- return 0
1857
+ self._comparisons[parameters][(n1, n2)] = {
1858
+ "alignment": (),
1859
+ "trees": (),
1860
+ }
1861
+ return self._comparisons[parameters][(n1, n2)]
1665
1862
  delta_tmp = partial(
1666
1863
  delta,
1667
1864
  corres1=corres1,
@@ -1669,126 +1866,538 @@ class lineageTree(lineageTreeLoaders):
1669
1866
  times1=times1,
1670
1867
  times2=times2,
1671
1868
  )
1672
- norm1 = tree1.get_norm()
1673
- norm2 = tree2.get_norm()
1674
- norm_dict = {"max": max, "sum": sum, "None": lambda x: 1}
1675
- if norm is None:
1676
- norm = "None"
1677
- if norm not in norm_dict:
1869
+ btrc = uted.uted_backtrace(nodes1, adj1, nodes2, adj2, delta=delta_tmp)
1870
+
1871
+ self._comparisons[parameters][(n1, n2)] = {
1872
+ "alignment": btrc,
1873
+ "trees": (tree1, tree2),
1874
+ }
1875
+ return self._comparisons[parameters][(n1, n2)]
1876
+
1877
def plot_tree_distance_graphs(
    self,
    n1: int,
    n2: int,
    end_time: int | None = None,
    norm: Literal["max", "sum", None] = "max",
    style: (
        Literal["simple", "normalized_simple", "full", "downsampled"]
        | type[TreeApproximationTemplate]
    ) = "simple",
    downsample: int = 2,
    colormap: str = "cool",
    default_color: str = "black",
    size: float = 10,
    lw: float = 0.3,
    ax: list[plt.Axes] | None = None,
) -> tuple[plt.Figure, plt.Axes]:
    """
    Plots the subtrees compared and colors them according to the quality of the matching of their subtree.

    Parameters
    ----------
    n1 : int
        id of the first node to compare
    n2 : int
        id of the second node to compare
    end_time : int, optional
        The final time point the comparison algorithm will take into account.
        If None all nodes will be taken into account.
    norm : {"max", "sum"}, default="max"
        The normalization method to use.
    style : {"simple", "full", "downsampled", "normalized_simple"} or TreeApproximationTemplate subclass, default="simple"
        Which tree approximation is going to be used for the comparisons.
    downsample : int, default=2
        The downsample factor for the downsampled tree approximation.
        Used only when `style="downsampled"`.
    colormap : str, default="cool"
        The colormap used for matched nodes.
    default_color : str, default="black"
        The color of the unmatched nodes.
    size : float, default=10
        The size of the nodes.
    lw : float, default=0.3
        The width of the edges.
    ax : sequence of two plt.Axes, optional
        The axes used (`ax[0]` for `n1`, `ax[1]` for `n2`). If not provided
        a new pair of axes sharing their y axis is created.

    Returns
    -------
    plt.Figure
        The figure of the plot
    plt.Axes
        The axes of the plot

    Raises
    ------
    Warning
        If `norm` is not a key of `self.norm_dict`.
    """
    # Validate first so that an invalid `norm` does not trigger (and cache)
    # an expensive comparison. The exception type is kept as-is for callers.
    if norm not in self.norm_dict:
        raise Warning("Select a viable normalization method (max, sum, None)")

    parameters = (
        end_time,
        convert_style_to_number(style=style, downsample=downsample),
    )
    n1, n2 = sorted([n1, n2])
    comparison = self._comparisons.setdefault(parameters, {}).get((n1, n2))
    if not comparison:
        comparison = self.__unordereded_backtrace(
            n1, n2, end_time, norm, style, downsample
        )
    btrc: Alignment = comparison["alignment"]

    matched_left = []
    matched_right = []
    colors = {}
    # The backtrace stores an empty `trees` tuple when neither subtree has
    # nodes; unpacking it would raise, so in that case nothing is colored.
    if comparison["trees"]:
        tree1, tree2 = comparison["trees"]
        _, times1 = tree1.tree
        _, times2 = tree2.tree
        *_, corres1 = tree1.edist
        *_, corres2 = tree2.edist
        delta_tmp = partial(
            tree1.delta,
            corres1=corres1,
            corres2=corres2,
            times1=times1,
            times2=times2,
        )
        norm_function = self.norm_dict[norm]
        # Every style except "full"/"downsampled" represents a whole chain
        # by a single tree node, so matches are reported on chain starts.
        collapse_chains = style not in ("full", "downsampled")

        for m in btrc:
            if m._left == -1 or m._right == -1:
                continue  # insertion/deletion: nothing to color
            # NOTE(review): assumes a chain always contains at least the
            # queried node itself - confirm against get_chain_of_node.
            chain1 = self.get_chain_of_node(corres1[m._left])
            chain2 = self.get_chain_of_node(corres2[m._right])
            if collapse_chains:
                node_1, node_2 = chain1[0], chain2[0]
            else:
                node_1, node_2 = corres1[m._left], corres2[m._right]
                starts_a_chain = chain1[0] == node_1 or chain2[0] == node_2
                not_colored_yet = node_1 not in colors or node_2 not in colors
                # The original `a or b and c` parsed as `a or (b and c)`;
                # the grouping below is the presumably intended one.
                if not (starts_a_chain and not_colored_yet):
                    continue
            l_node_1, l_node_2 = chain1[-1], chain2[-1]

            # Duplicates are harmless: the lists are turned into sets
            # by draw_tree_graph before being used.
            matched_left.extend((node_1, l_node_1))
            matched_right.extend((node_2, l_node_2))
            sub_distance = self.__calculate_distance_of_sub_tree(
                node_1,
                node_2,
                btrc,
                corres1,
                corres2,
                delta_tmp,
                norm_function,
                tree1.get_norm(node_1),
                tree2.get_norm(node_2),
            )
            for node in (node_1, l_node_1, node_2, l_node_2):
                colors[node] = sub_distance

    if ax is None:
        _, ax = plt.subplots(nrows=1, ncols=2, sharey=True)
    cmap = colormaps[colormap]
    c_norm = mcolors.Normalize(0, 1)
    colors = {node: cmap(c_norm(value)) for node, value in colors.items()}
    for axis, node, matched in (
        (ax[0], n1, matched_left),
        (ax[1], n2, matched_right),
    ):
        self.plot_subtree(
            self.get_ancestor_at_t(node),
            end_time=end_time,
            size=size,
            selected_nodes=matched,
            color_of_nodes=colors,
            selected_edges=matched,
            color_of_edges=colors,
            default_color=default_color,
            lw=lw,
            ax=axis,
        )
    return ax[0].get_figure(), ax
2067
+
2068
def labelled_mappings(
    self,
    n1: int,
    n2: int,
    end_time: int | None = None,
    norm: Literal["max", "sum", None] = "max",
    style: (
        Literal["simple", "normalized_simple", "full", "downsampled"]
        | type[TreeApproximationTemplate]
    ) = "simple",
    downsample: int = 2,
) -> dict[str, list]:
    """
    Returns the labels or IDs of all the nodes in the subtrees compared.

    Parameters
    ----------
    n1 : int
        id of the first node to compare
    n2 : int
        id of the second node to compare
    end_time : int, optional
        The final time point the comparison algorithm will take into account.
        If None or not provided all nodes will be taken into account.
    norm : {"max", "sum"}, default="max"
        The normalization method to use.
    style : {"simple", "full", "downsampled", "normalized_simple"} or TreeApproximationTemplate subclass, default="simple"
        Which tree approximation is going to be used for the comparisons.
    downsample : int, default=2
        The downsample factor for the downsampled tree approximation.
        Used only when `style="downsampled"`.

    Returns
    -------
    dict mapping str to list
        - 'matched' A list of pairs `(label of node in first tree, label of node in second tree)`
          for the matched nodes of the alignment.
        - 'unmatched' The labels of the unmatched nodes of the alignment.
        A node without an entry in `self.labels` is reported by its id.

    Raises
    ------
    Warning
        If `norm` is not a key of `self.norm_dict`.
    """
    # Validate first so that an invalid `norm` does not trigger (and cache)
    # an expensive comparison. The exception type is kept as-is for callers.
    if norm not in self.norm_dict:
        raise Warning("Select a viable normalization method (max, sum, None)")

    parameters = (
        end_time,
        convert_style_to_number(style=style, downsample=downsample),
    )
    n1, n2 = sorted([n1, n2])
    comparison = self._comparisons.setdefault(parameters, {}).get((n1, n2))
    if not comparison:
        comparison = self.__unordereded_backtrace(
            n1, n2, end_time, norm, style, downsample
        )
    matched = []
    unmatched = []
    # The backtrace stores an empty `trees` tuple when neither subtree has
    # nodes; unpacking it would raise, and there is nothing to label.
    if not comparison["trees"]:
        return {"matched": matched, "unmatched": unmatched}

    btrc = comparison["alignment"]
    tree1, tree2 = comparison["trees"]
    *_, corres1 = tree1.edist
    *_, corres2 = tree2.edist
    # Every style except "full"/"downsampled" represents a whole chain by a
    # single tree node, so nodes are reported through their chain start.
    collapse_chains = style not in ("full", "downsampled")

    def label(node):
        return self.labels.get(node, node)

    for m in btrc:
        has_left = m._left != -1
        has_right = m._right != -1
        if has_left and has_right:
            node_1 = corres1[m._left]
            node_2 = corres2[m._right]
            if collapse_chains:
                # Indexing (instead of the former `.pop()`) leaves the list
                # returned by get_chain_of_node untouched.
                node_1 = self.get_chain_of_node(node_1)[0]
                node_2 = self.get_chain_of_node(node_2)[0]
            matched.append((label(node_1), label(node_2)))
        elif collapse_chains:
            # NOTE(review): the "-" fallback is forwarded to
            # get_chain_of_node exactly as before; what that call does
            # with it is not visible here - confirm.
            if has_left:
                raw_node = corres1.get(m._left, "-")
            else:
                raw_node = corres2.get(m._right, "-")
            unmatched.append(label(self.get_chain_of_node(raw_node)[0]))
        else:
            node = corres1[m._left] if has_left else corres2[m._right]
            unmatched.append(label(node))
    return {"matched": matched, "unmatched": unmatched}
2185
+
2186
def unordered_tree_edit_distance(
    self,
    n1: int,
    n2: int,
    end_time: int | None = None,
    norm: Literal["max", "sum", None] = "max",
    style: (
        Literal["simple", "normalized_simple", "full", "downsampled"]
        | type[TreeApproximationTemplate]
    ) = "simple",
    downsample: int = 2,
    return_norms: bool = False,
) -> float | tuple[float, tuple[float, float]]:
    """
    Compute the unordered tree edit distance from Zhang 1996 between the trees spawned
    by two nodes `n1` and `n2`. The topology of the trees are compared and the matching
    cost is given by the function delta (see edist doc for more information).
    The distance is normed by the function norm that takes the two list of nodes
    spawned by the trees `n1` and `n2`.

    Parameters
    ----------
    n1 : int
        id of the first node to compare
    n2 : int
        id of the second node to compare
    end_time : int, optional
        The final time point the comparison algorithm will take into account.
        If None or not provided all nodes will be taken into account.
    norm : {"max", "sum"}, default="max"
        The normalization method to use.
    style : {"simple", "normalized_simple", "full", "downsampled"} or TreeApproximationTemplate subclass, default="simple"
        Which tree approximation is going to be used for the comparisons.
    downsample : int, default=2
        The downsample factor for the downsampled tree approximation.
        Used only when `style="downsampled"`.
    return_norms : bool, default=False
        If True, the raw (un-normalized) cost is returned together with the
        norm of each tree instead of the normalized distance.

    Returns
    -------
    float or tuple
        The normalized unordered tree edit distance between `n1` and `n2`,
        or `(cost, (norm of first tree, norm of second tree))` when
        `return_norms` is True.

    Raises
    ------
    ValueError
        If `norm` is not a key of `self.norm_dict`.
    """
    # Validate first so that an invalid `norm` does not trigger (and cache)
    # an expensive comparison.
    if norm not in self.norm_dict:
        raise ValueError(
            "Select a viable normalization method (max, sum, None)"
        )

    parameters = (
        end_time,
        convert_style_to_number(style=style, downsample=downsample),
    )
    n1, n2 = sorted([n1, n2])
    comparison = self._comparisons.setdefault(parameters, {}).get((n1, n2))
    if not comparison:
        comparison = self.__unordereded_backtrace(
            n1, n2, end_time, norm, style, downsample
        )

    # The backtrace stores an empty `trees` tuple when neither subtree has
    # nodes. Unpacking it would raise; two empty trees are at distance 0,
    # which is what earlier releases returned for this case.
    if not comparison["trees"]:
        # NOTE(review): norms of (0.0, 0.0) for empty trees are an
        # assumption - confirm against TreeApproximationTemplate.get_norm.
        return (0.0, (0.0, 0.0)) if return_norms else 0.0

    btrc = comparison["alignment"]
    tree1, tree2 = comparison["trees"]
    _, times1 = tree1.tree
    _, times2 = tree2.tree
    nodes1, _, corres1 = tree1.edist
    nodes2, _, corres2 = tree2.edist
    delta_tmp = partial(
        tree1.delta,
        corres1=corres1,
        corres2=corres2,
        times1=times1,
        times2=times2,
    )

    cost = btrc.cost(nodes1, nodes2, delta_tmp)
    norm_values = (tree1.get_norm(n1), tree2.get_norm(n2))
    if return_norms:
        return cost, norm_values
    return cost / self.norm_dict[norm](norm_values)
1684
2271
 
1685
2272
@staticmethod
def __plot_nodes(
    hier: dict,
    selected_nodes: set,
    color: str | dict | list,
    size: int | float,
    ax: plt.Axes,
    default_color: str = "black",
    **kwargs,
) -> None:
    """
    Private method that plots the nodes of the tree.

    Parameters
    ----------
    hier : dict
        Maps each node to its position; one point is drawn per entry.
    selected_nodes : set
        Nodes drawn with `color` when `color` is a str or a list.
        Ignored when `color` is a dict.
    color : str or dict or list
        Either one color shared by all selected nodes (str or list), or a
        dict mapping nodes to their color.
    size : int or float
        Marker size forwarded to `ax.scatter` as `s`.
    ax : plt.Axes
        The axes to draw on.
    default_color : str, default="black"
        Color of the nodes that are not selected (or missing from the dict).
    **kwargs
        Forwarded to `ax.scatter`.
    """
    # With no positions `*hier_pos.T` unpacks to nothing and `ax.scatter`
    # would be called without its positional arguments.
    if not hier:
        return

    if isinstance(color, dict):
        color = [color.get(node, default_color) for node in hier]
    elif isinstance(color, (str, list)):
        color = [
            color if node in selected_nodes else default_color
            for node in hier
        ]
    # Any other type is handed to matplotlib unchanged.
    hier_pos = np.array(list(hier.values()))
    ax.scatter(*hier_pos.T, s=size, zorder=10, color=color, **kwargs)
1710
2295
 
1711
2296
@staticmethod
def __plot_edges(
    hier: dict,
    lnks_tms: dict,
    selected_edges: Iterable,
    color: str | dict | list,
    lw: float,
    ax: plt.Axes,
    default_color: str = "black",
    **kwargs,
) -> None:
    """
    Private method that plots the edges of the tree.

    Parameters
    ----------
    hier : dict
        Maps each node to its position.
    lnks_tms : dict
        Its 'links' entry maps a node to its successors; one segment is
        drawn per (node, successor) pair.
    selected_edges : Iterable
        Edges whose source node is in this collection are drawn with `color`.
        Replaced by the keys of `color` when `color` is a dict.
    color : str or dict or list
        Either one color shared by all selected edges, or a dict mapping the
        source node of an edge to its color.
    lw : float
        Line width of the edges.
    ax : plt.Axes
        The axes to draw on.
    default_color : str, default="black"
        Color of the edges that are not selected.
    **kwargs
        Forwarded to `LineCollection`.
    """
    color_by_node = isinstance(color, dict)
    if color_by_node:
        selected_edges = color.keys()

    lines = []
    line_colors = []
    for pred, succs in lnks_tms["links"].items():
        # The color depends only on the source node of the edge.
        if pred not in selected_edges:
            edge_color = default_color
        elif color_by_node:
            edge_color = color[pred]
        else:
            # Formerly only str/list were appended here; any other type
            # (e.g. an RGBA tuple) was dropped, leaving fewer colors than
            # segments.
            edge_color = color
        for suc in succs:
            lines.append(
                [
                    [hier[suc][0], hier[suc][1]],
                    [hier[pred][0], hier[pred][1]],
                ]
            )
            line_colors.append(edge_color)
    lc = LineCollection(lines, colors=line_colors, linewidth=lw, **kwargs)
    ax.add_collection(lc)
1738
2331
 
1739
2332
  def draw_tree_graph(
1740
2333
  self,
1741
- hier,
1742
- lnks_tms,
1743
- selected_nodes=None,
1744
- selected_edges=None,
1745
- color_of_nodes="magenta",
1746
- color_of_edges=None,
1747
- size=10,
1748
- ax=None,
1749
- default_color="black",
2334
+ hier: dict[int, tuple[int, int]],
2335
+ lnks_tms: dict[str, dict[int, list | int]],
2336
+ selected_nodes: list | set | None = None,
2337
+ selected_edges: list | set | None = None,
2338
+ color_of_nodes: str | dict = "magenta",
2339
+ color_of_edges: str | dict = "magenta",
2340
+ size: int | float = 10,
2341
+ lw: float = 0.3,
2342
+ ax: plt.Axes | None = None,
2343
+ default_color: str = "black",
1750
2344
  **kwargs,
1751
- ):
2345
+ ) -> tuple[plt.Figure, plt.Axes]:
1752
2346
  """Function to plot the tree graph.
1753
2347
 
1754
- Args:
1755
- hier (dict): Dictinary that contains the positions of all nodes.
1756
- lnks_tms (dict): 2 dictionaries: 1 contains all links from start of life cycle to end of life cycle and
1757
- the succesors of each cell.
1758
- 1 contains the length of each life cycle.
1759
- selected_nodes (list|set, optional): Which cells are to be selected (Painted with a different color). Defaults to None.
1760
- selected_edges (list|set, optional): Which edges are to be selected (Painted with a different color). Defaults to None.
1761
- color_of_nodes (str, optional): Color of selected nodes. Defaults to "magenta".
1762
- color_of_edges (_type_, optional): Color of selected edges. Defaults to None.
1763
- size (int, optional): Size of the nodes. Defaults to 10.
1764
- ax (_type_, optional): Plot the graph on existing ax. Defaults to None.
1765
- figure (_type_, optional): _description_. Defaults to None.
1766
- default_color (str, optional): Default color of nodes. Defaults to "black".
1767
-
1768
- Returns:
1769
- figure, ax: The matplotlib figure and ax object.
2348
+ Parameters
2349
+ ----------
2350
+ hier : dict mapping int to tuple of int
2351
+ Dictionary that contains the positions of all nodes.
2352
+ lnks_tms : dict mapping string to dictionaries mapping int to list or int
2353
+ - 'links' : conatains the hierarchy of the nodes (only start and end of each chain)
2354
+ - 'times' : contains the distance between the start and the end of each chain.
2355
+ selected_nodes : list or set, optional
2356
+ Which nodes are to be selected (Painted with a different color, according to 'color_'of_nodes')
2357
+ selected_edges : list or set, optional
2358
+ Which edges are to be selected (Painted with a different color, according to 'color_'of_edges')
2359
+ color_of_nodes : str, default="magenta"
2360
+ Color of selected nodes
2361
+ color_of_edges : str, default="magenta"
2362
+ Color of selected edges
2363
+ size : int, default=10
2364
+ Size of the nodes, defaults to 10
2365
+ lw : float, default=0.3
2366
+ The width of the edges of the tree graph, defaults to 0.3
2367
+ ax : plt.Axes, optional
2368
+ Plot the graph on existing ax. If not provided or None a new ax is going to be created.
2369
+ default_color : str, default="black"
2370
+ Default color of nodes
2371
+
2372
+ Returns
2373
+ -------
2374
+ plt.Figure
2375
+ The matplotlib figure
2376
+ plt.Axes
2377
+ The matplotlib ax
1770
2378
  """
1771
2379
  if selected_nodes is None:
1772
2380
  selected_nodes = []
1773
2381
  if selected_edges is None:
1774
2382
  selected_edges = []
1775
2383
  if ax is None:
1776
- figure, ax = plt.subplots()
2384
+ _, ax = plt.subplots()
1777
2385
  else:
1778
2386
  ax.clear()
1779
2387
  if not isinstance(selected_nodes, set):
1780
2388
  selected_nodes = set(selected_nodes)
1781
2389
  if not isinstance(selected_edges, set):
1782
2390
  selected_edges = set(selected_edges)
1783
- self.__plot_nodes(
1784
- hier,
1785
- selected_nodes,
1786
- color_of_nodes,
1787
- size=size,
1788
- ax=ax,
1789
- default_color=default_color,
1790
- **kwargs,
1791
- )
2391
+ if 0 < size:
2392
+ self.__plot_nodes(
2393
+ hier,
2394
+ selected_nodes,
2395
+ color_of_nodes,
2396
+ size=size,
2397
+ ax=ax,
2398
+ default_color=default_color,
2399
+ **kwargs,
2400
+ )
1792
2401
  if not color_of_edges:
1793
2402
  color_of_edges = color_of_nodes
1794
2403
  self.__plot_edges(
@@ -1796,62 +2405,107 @@ class lineageTree(lineageTreeLoaders):
1796
2405
  lnks_tms,
1797
2406
  selected_edges,
1798
2407
  color_of_edges,
2408
+ lw,
1799
2409
  ax,
1800
2410
  default_color=default_color,
1801
2411
  **kwargs,
1802
2412
  )
2413
+ ax.autoscale()
2414
+ plt.draw()
1803
2415
  ax.get_yaxis().set_visible(False)
1804
2416
  ax.get_xaxis().set_visible(False)
1805
2417
  return ax.get_figure(), ax
1806
2418
 
1807
- def to_simple_graph(self, node=None, start_time: int = None):
2419
def _create_dict_of_plots(
    self,
    node: int | Iterable[int] | None = None,
    start_time: int | None = None,
    end_time: int | None = None,
) -> dict[int, dict]:
    """Build one graph per spawning node with `create_links_and_chains`.

    Parameters
    ----------
    node : int or Iterable of int, optional
        The id of the node/nodes to produce the graphs for. If not provided
        or None, every root whose time is at most `start_time` is used.
    start_time : int, optional
        Only used when `node` is None. If not provided or None it defaults
        to `self.t_b`.
    end_time : int, optional
        The last timepoint to be considered, forwarded to
        `create_links_and_chains`. If not provided or None it defaults
        to `self.t_e`.

    Returns
    -------
    dict mapping int to dict
        The keys are index values 0-n (in iteration order of the spawning
        nodes) and the values are the graphs produced.
    """
    first_time = self.t_b if start_time is None else start_time
    last_time = self.t_e if end_time is None else end_time

    if node is None:
        spawning_nodes = [
            root for root in self.roots if self._time[root] <= first_time
        ]
    elif isinstance(node, Iterable):
        spawning_nodes = node
    else:
        spawning_nodes = [node]

    plots = {}
    for index, mother in enumerate(spawning_nodes):
        plots[index] = create_links_and_chains(
            self, mother, end_time=last_time
        )
    return plots
1831
2461
 
1832
2462
  def plot_all_lineages(
1833
2463
  self,
1834
- nodes: list = None,
1835
- last_time_point_to_consider: int = None,
1836
- nrows=2,
1837
- figsize=(10, 15),
1838
- dpi=100,
1839
- fontsize=15,
1840
- axes=None,
1841
- vert_gap=1,
2464
+ nodes: list | None = None,
2465
+ last_time_point_to_consider: int | None = None,
2466
+ nrows: int = 2,
2467
+ figsize: tuple[int, int] = (10, 15),
2468
+ dpi: int = 100,
2469
+ fontsize: int = 15,
2470
+ axes: plt.Axes | None = None,
2471
+ vert_gap: int = 1,
1842
2472
  **kwargs,
1843
- ):
2473
+ ) -> tuple[plt.Figure, plt.Axes, dict[plt.Axes, int]]:
1844
2474
  """Plots all lineages.
1845
2475
 
1846
- Args:
1847
- last_time_point_to_consider (int, optional): Which timepoints and upwards are the graphs to be plotted.
1848
- For example if start_time is 10, then all trees that begin
1849
- on tp 10 or before are calculated. Defaults to None, where
1850
- it will plot all the roots that exist on self.t_b.
1851
- nrows (int): How many rows of plots should be printed.
1852
- kwargs: args accepted by matplotlib
2476
+ Parameters
2477
+ ----------
2478
+ nodes : list, optional
2479
+ The nodes spawning the graphs to be plotted.
2480
+ last_time_point_to_consider : int, optional
2481
+ Which timepoints and upwards are the graphs to be plotted.
2482
+ For example if start_time is 10, then all trees that begin
2483
+ on tp 10 or before are calculated. Defaults to None, where
2484
+ it will plot all the roots that exist on `self.t_b`.
2485
+ nrows : int, default=2
2486
+ How many rows of plots should be printed.
2487
+ figsize : tuple, default=(10, 15)
2488
+ The size of the figure.
2489
+ dpi : int, default=100
2490
+ The dpi of the figure.
2491
+ fontsize : int, default=15
2492
+ The fontsize of the labels.
2493
+ axes : plt.Axes, optional
2494
+ The axes to plot the graphs on. If None or not provided new axes are going to be created.
2495
+ vert_gap : int, default=1
2496
+ space between the nodes, defaults to 1
2497
+ **kwargs:
2498
+ kwargs accepted by matplotlib.pyplot.plot, matplotlib.pyplot.scatter
2499
+
2500
+ Returns
2501
+ -------
2502
+ plt.Figure
2503
+ The figure
2504
+ plt.Axes
2505
+ The axes
2506
+ dict of plt.Axes to int
2507
+ A dictionary that maps the axes to the root of the tree.
1853
2508
  """
1854
-
1855
2509
  nrows = int(nrows)
1856
2510
  if last_time_point_to_consider is None:
1857
2511
  last_time_point_to_consider = self.t_b
@@ -1860,17 +2514,18 @@ class lineageTree(lineageTreeLoaders):
1860
2514
  raise Warning("Number of rows has to be at least 1")
1861
2515
  if nodes:
1862
2516
  graphs = {
1863
- i: self.to_simple_graph(node) for i, node in enumerate(nodes)
2517
+ i: self._create_dict_of_plots(node)
2518
+ for i, node in enumerate(nodes)
1864
2519
  }
1865
2520
  else:
1866
- graphs = self.to_simple_graph(
2521
+ graphs = self._create_dict_of_plots(
1867
2522
  start_time=last_time_point_to_consider
1868
2523
  )
1869
2524
  pos = {
1870
2525
  i: hierarchical_pos(
1871
2526
  g,
1872
2527
  g["root"],
1873
- ycenter=-int(self.time[g["root"]]),
2528
+ ycenter=-int(self._time[g["root"]]),
1874
2529
  vert_gap=vert_gap,
1875
2530
  )
1876
2531
  for i, g in graphs.items()
@@ -1925,135 +2580,147 @@ class lineageTree(lineageTreeLoaders):
1925
2580
  [figure.delaxes(ax) for ax in axes.flatten() if not ax.has_data()]
1926
2581
  return axes.flatten()[0].get_figure(), axes, ax2root
1927
2582
 
1928
- def plot_node(
2583
def plot_subtree(
    self,
    node: int,
    end_time: int | None = None,
    figsize: tuple[int, int] = (4, 7),
    dpi: int = 150,
    vert_gap: int = 2,
    selected_nodes: list | None = None,
    selected_edges: list | None = None,
    color_of_nodes: str | dict = "magenta",
    color_of_edges: str | dict = "magenta",
    size: int | float = 10,
    lw: float = 0.1,
    default_color: str = "black",
    ax: plt.Axes | None = None,
) -> tuple[plt.Figure, plt.Axes]:
    """Plots the subtree spawn by a node.

    Parameters
    ----------
    node : int
        The id of the node that is going to be plotted.
    end_time : int, optional
        The last timepoint to be considered, if None or not provided the last timepoint of the dataset (t_e) is considered.
    figsize : tuple of 2 ints, default=(4,7)
        The size of the figure. Used only when `ax` is not provided.
    dpi : int, default=150
        The dpi of the figure. Used only when `ax` is not provided.
    vert_gap : int, default=2
        The vertical gap of a node when it divides.
    selected_nodes : list, optional
        The nodes that are selected by the user to be colored in a different color.
    selected_edges : list, optional
        The edges that are selected by the user to be colored in a different color.
    color_of_nodes : str or dict, default="magenta"
        The color of the selected nodes.
    color_of_edges : str or dict, default="magenta"
        The color of the selected edges.
    size : int, default=10
        The size of the nodes.
    lw : float, default=0.1
        The width of the edges of the tree graph.
    default_color : str, default="black"
        The default color of nodes and edges.
    ax : plt.Axes, optional
        The ax where the plot is going to be applied, if not provided or None new axes will be created.

    Returns
    -------
    plt.Figure
        The matplotlib figure
    plt.Axes
        The matplotlib axes

    Raises
    ------
    Warning
        If more than one node is received
    """
    graph = self._create_dict_of_plots(node, end_time=end_time)
    if len(graph) > 1:
        raise Warning(
            "Please use lT.plot_all_lineages(nodes) for plotting multiple nodes."
        )
    graph = graph[0]
    # Explicit None test, as in draw_tree_graph, instead of relying on the
    # truthiness of an Axes object.
    if ax is None:
        _, ax = plt.subplots(nrows=1, ncols=1, figsize=figsize, dpi=dpi)
    positions = hierarchical_pos(
        graph,
        graph["root"],
        vert_gap=vert_gap,
        # The root is placed at minus its time.
        ycenter=-int(self._time[node]),
    )
    self.draw_tree_graph(
        hier=positions,
        selected_edges=selected_edges,
        selected_nodes=selected_nodes,
        color_of_edges=color_of_edges,
        color_of_nodes=color_of_nodes,
        default_color=default_color,
        size=size,
        lw=lw,
        lnks_tms=graph,
        ax=ax,
    )
    return ax.get_figure(), ax
1962
2668
 
1963
- # def DTW(self, t1, t2, max_w=None, start_delay=None, end_delay=None,
1964
- # metric='euclidian', **kwargs):
1965
- # """ Computes the dynamic time warping distance between the tracks t1 and t2
1966
-
1967
- # Args:
1968
- # t1 ([int, ]): list of node ids for the first track
1969
- # t2 ([int, ]): list of node ids for the second track
1970
- # w (int): maximum wapring allowed (default infinite),
1971
- # if w=1 then the DTW is the distance between t1 and t2
1972
- # start_delay (int): maximum number of time points that can be
1973
- # skipped at the beginning of the track
1974
- # end_delay (int): minimum number of time points that can be
1975
- # skipped at the beginning of the track
1976
- # metric (str): str or callable, optional The distance metric to use.
1977
- # Default='euclidean'. Refer to the documentation for
1978
- # scipy.spatial.distance.cdist. Some examples:
1979
- # 'braycurtis', 'canberra', 'chebyshev', 'cityblock', 'correlation',
1980
- # 'cosine', 'dice', 'euclidean', 'hamming', 'jaccard', 'kulsinski',
1981
- # 'mahalanobis', 'matching', 'minkowski', 'rogerstanimoto', 'russellrao',
1982
- # 'seuclidean', 'sokalmichener', 'sokalsneath', 'sqeuclidean',
1983
- # 'wminkowski', 'yule'
1984
- # **kwargs (dict): Extra arguments to `metric`: refer to each metric
1985
- # documentation in scipy.spatial.distance (optional)
1986
-
1987
- # Returns:
1988
- # float: the dynamic time warping distance between the two tracks
1989
- # """
1990
- # from scipy.sparse import
1991
- # pos_t1 = [self.pos[ti] for ti in t1]
1992
- # pos_t2 = [self.pos[ti] for ti in t2]
1993
- # distance_matrix = np.zeros((len(t1), len(t2))) + np.inf
1994
-
1995
- # c = distance.cdist(exp_data, num_data, metric=metric, **kwargs)
1996
-
1997
- # d = np.zeros(c.shape)
1998
- # d[0, 0] = c[0, 0]
1999
- # n, m = c.shape
2000
- # for i in range(1, n):
2001
- # d[i, 0] = d[i-1, 0] + c[i, 0]
2002
- # for j in range(1, m):
2003
- # d[0, j] = d[0, j-1] + c[0, j]
2004
- # for i in range(1, n):
2005
- # for j in range(1, m):
2006
- # d[i, j] = c[i, j] + min((d[i-1, j], d[i, j-1], d[i-1, j-1]))
2007
- # return d[-1, -1], d
2008
-
2009
- def __getitem__(self, item):
2010
- if isinstance(item, str):
2011
- return self.__dict__[item]
2012
- elif np.issubdtype(type(item), np.integer):
2013
- return self.successor.get(item, [])
2014
- else:
2015
- raise KeyError(
2016
- "Only integer or string are valid key for lineageTree"
2017
- )
2018
-
2019
- def get_cells_at_t_from_root(self, r: [int, list], t: int = None) -> list:
2669
def nodes_at_t(
    self,
    t: int,
    r: int | Iterable[int] | None = None,
) -> list:
    """
    Returns the list of nodes at time `t` that are spawn by the node(s) `r`.

    Parameters
    ----------
    t : int
        target time, if `None` the last time point (`self.t_e`) is used
    r : int or Iterable of int, optional
        id or list of ids of the spawning node. If not provided, None or
        empty, every root whose time is at most `t` is used.

    Returns
    -------
    list
        list of nodes at time `t` spawned by `r`. If no such node is found
        the spawning nodes themselves are returned.
    """
    # Resolve the default target time first: it is needed below to select
    # the roots, and comparing a time against None would raise.
    if t is None:
        t = self.t_e

    if r is None:
        spawning = []
    elif isinstance(r, (int, np.integer)):
        # NumPy integers are accepted as single ids too.
        spawning = [r]
    else:
        spawning = list(r)
    if not spawning:
        spawning = [root for root in self.roots if self._time[root] <= t]

    to_do = list(spawning)
    final_nodes = []
    while to_do:
        curr = to_do.pop()
        for _next in self._successor[curr]:
            if self._time[_next] < t:
                to_do.append(_next)
            elif self._time[_next] == t:
                final_nodes.append(_next)
            # successors later than `t` are not explored
    if not final_nodes:
        return spawning
    return final_nodes
2045
2707
 
2046
2708
  @staticmethod
2047
- def __calculate_diag_line(dist_mat: np.ndarray) -> (float, float):
2709
+ def __calculate_diag_line(dist_mat: np.ndarray) -> tuple[float, float]:
2048
2710
  """
2049
2711
  Calculate the line that centers the band w.
2050
2712
 
2051
- Args:
2052
- dist_mat (matrix): distance matrix obtained by the function calculate_dtw
2713
+ Parameters
2714
+ ----------
2715
+ dist_mat : np.ndarray
2716
+ distance matrix obtained by the function calculate_dtw
2053
2717
 
2054
- Returns:
2055
- (float) Slope
2056
- (float) intercept of the line
2718
+ Returns
2719
+ -------
2720
+ float
2721
+ The slope of the curve
2722
+ float
2723
+ The intercept of the curve
2057
2724
  """
2058
2725
  i, j = dist_mat.shape
2059
2726
  x1 = max(0, i - j) / 2
@@ -2073,22 +2740,33 @@ class lineageTree(lineageTreeLoaders):
2073
2740
  fast: bool = False,
2074
2741
  w: int = 0,
2075
2742
  centered_band: bool = True,
2076
- ) -> (((int, int), ...), np.ndarray):
2743
+ ) -> tuple[list[int], np.ndarray, float]:
2077
2744
  """
2078
2745
  Find DTW minimum cost between two series using dynamic programming.
2079
2746
 
2080
- Args:
2081
- dist_mat (matrix): distance matrix obtained by the function calculate_dtw
2082
- start_d (int): start delay
2083
- back_d (int): end delay
2084
- w (int): window constrain
2085
- slope (float): to calculate window - givem by the function __calculate_diag_line
2086
- intercept (flost): to calculate window - givem by the function __calculate_diag_line
2087
- use_absolute (boolean): if the window constraing is calculate by the absolute difference between points (uncentered)
2088
-
2089
- Returns:
2090
- (tuple of tuples) Aligment path
2091
- (matrix) Cost matrix
2747
+ Parameters
2748
+ ----------
2749
+ dist_mat : np.ndarray
2750
+ distance matrix obtained by the function calculate_dtw
2751
+ start_d : int, default=0
2752
+ start delay
2753
+ back_d : int, default=0
2754
+ end delay
2755
+ fast : bool, default=False
2756
+ if `True`, the algorithm will use a faster version but might not find the optimal alignment
2757
+ w : int, default=0
2758
+ window constrain
2759
+ centered_band : bool, default=True
2760
+ if `True`, the band will be centered around the diagonal
2761
+
2762
+ Returns
2763
+ -------
2764
+ tuple of tuples of int
2765
+ Aligment path
2766
+ np.ndarray
2767
+ cost matrix
2768
+ float
2769
+ optimal cost
2092
2770
  """
2093
2771
  N, M = dist_mat.shape
2094
2772
  w_limit = max(w, abs(N - M)) # Calculate the Sakoe-Chiba band width
@@ -2209,7 +2887,6 @@ class lineageTree(lineageTreeLoaders):
2209
2887
 
2210
2888
  # special reflection case
2211
2889
  if np.linalg.det(R) < 0:
2212
- # print("det(R) < R, reflection detected!, correcting for it ...")
2213
2890
  Vt[2, :] *= -1
2214
2891
  R = Vt.T @ U.T
2215
2892
 
@@ -2218,52 +2895,59 @@ class lineageTree(lineageTreeLoaders):
2218
2895
  return R, t
2219
2896
 
2220
2897
  def __interpolate(
2221
- self, track1: list, track2: list, threshold: int
2222
- ) -> (np.ndarray, np.ndarray):
2898
+ self, chain1: list, chain2: list, threshold: int
2899
+ ) -> tuple[np.ndarray, np.ndarray]:
2223
2900
  """
2224
2901
  Interpolate two series that have different lengths
2225
2902
 
2226
- Args:
2227
- track1 (list): list of nodes of the first cell cycle to compare
2228
- track2 (list): list of nodes of the second cell cycle to compare
2229
- threshold (int): set a maximum number of points a track can have
2230
-
2231
- Returns:
2232
- (list of list) x, y, z postions for track1
2233
- (list of list) x, y, z postions for track2
2903
+ Parameters
2904
+ ----------
2905
+ chain1 : list of int
2906
+ list of nodes of the first chain to compare
2907
+ chain2 : list of int
2908
+ list of nodes of the second chain to compare
2909
+ threshold : int
2910
+ set a maximum number of points a chain can have
2911
+
2912
+ Returns
2913
+ -------
2914
+ list of np.ndarray
2915
+ `x`, `y`, `z` postions for `chain1`
2916
+ list of np.ndarray
2917
+ `x`, `y`, `z` postions for `chain2`
2234
2918
  """
2235
2919
  inter1_pos = []
2236
2920
  inter2_pos = []
2237
2921
 
2238
- track1_pos = np.array([self.pos[c_id] for c_id in track1])
2239
- track2_pos = np.array([self.pos[c_id] for c_id in track2])
2922
+ chain1_pos = np.array([self.pos[c_id] for c_id in chain1])
2923
+ chain2_pos = np.array([self.pos[c_id] for c_id in chain2])
2240
2924
 
2241
- # Both tracks have the same length and size below the threshold - nothing is done
2242
- if len(track1) == len(track2) and (
2243
- len(track1) <= threshold or len(track2) <= threshold
2925
+ # Both chains have the same length and size below the threshold - nothing is done
2926
+ if len(chain1) == len(chain2) and (
2927
+ len(chain1) <= threshold or len(chain2) <= threshold
2244
2928
  ):
2245
- return track1_pos, track2_pos
2246
- # Both tracks have the same length but one or more sizes are above the threshold
2247
- elif len(track1) > threshold or len(track2) > threshold:
2929
+ return chain1_pos, chain2_pos
2930
+ # Both chains have the same length but one or more sizes are above the threshold
2931
+ elif len(chain1) > threshold or len(chain2) > threshold:
2248
2932
  sampling = threshold
2249
- # Tracks have different lengths and the sizes are below the threshold
2933
+ # chains have different lengths and the sizes are below the threshold
2250
2934
  else:
2251
- sampling = max(len(track1), len(track2))
2935
+ sampling = max(len(chain1), len(chain2))
2252
2936
 
2253
2937
  for pos in range(3):
2254
- track1_interp = InterpolatedUnivariateSpline(
2255
- np.linspace(0, 1, len(track1_pos[:, pos])),
2256
- track1_pos[:, pos],
2938
+ chain1_interp = InterpolatedUnivariateSpline(
2939
+ np.linspace(0, 1, len(chain1_pos[:, pos])),
2940
+ chain1_pos[:, pos],
2257
2941
  k=1,
2258
2942
  )
2259
- inter1_pos.append(track1_interp(np.linspace(0, 1, sampling)))
2943
+ inter1_pos.append(chain1_interp(np.linspace(0, 1, sampling)))
2260
2944
 
2261
- track2_interp = InterpolatedUnivariateSpline(
2262
- np.linspace(0, 1, len(track2_pos[:, pos])),
2263
- track2_pos[:, pos],
2945
+ chain2_interp = InterpolatedUnivariateSpline(
2946
+ np.linspace(0, 1, len(chain2_pos[:, pos])),
2947
+ chain2_pos[:, pos],
2264
2948
  k=1,
2265
2949
  )
2266
- inter2_pos.append(track2_interp(np.linspace(0, 1, sampling)))
2950
+ inter2_pos.append(chain2_interp(np.linspace(0, 1, sampling)))
2267
2951
 
2268
2952
  return np.column_stack(inter1_pos), np.column_stack(inter2_pos)
2269
2953
 
@@ -2279,46 +2963,66 @@ class lineageTree(lineageTreeLoaders):
2279
2963
  w: int = 0,
2280
2964
  centered_band: bool = True,
2281
2965
  cost_mat_p: bool = False,
2282
- ) -> (float, tuple, np.ndarray, np.ndarray, np.ndarray):
2283
- """
2284
- Calculate DTW distance between two cell cycles
2285
-
2286
- Args:
2287
- nodes1 (int): node to compare distance
2288
- nodes2 (int): node to compare distance
2289
- threshold: set a maximum number of points a track can have
2290
- regist (boolean): Rotate and translate trajectories
2291
- start_d (int): start delay
2292
- back_d (int): end delay
2293
- fast (boolean): True if the user wants to run the fast algorithm with window restrains
2294
- w (int): window size
2295
- centered_band (boolean): if running the fast algorithm, True if the windown is centered
2296
- cost_mat_p (boolean): True if print the not normalized cost matrix
2297
-
2298
- Returns:
2299
- (float) DTW distance
2300
- (tuple of tuples) Aligment path
2301
- (matrix) Cost matrix
2302
- (list of lists) pos_cycle1: rotated and translated trajectories positions
2303
- (list of lists) pos_cycle2: rotated and translated trajectories positions
2966
+ ) -> (
2967
+ tuple[float, tuple, np.ndarray, np.ndarray, np.ndarray]
2968
+ | tuple[float, tuple]
2969
+ ):
2304
2970
  """
2305
- nodes1_cycle = self.get_cycle(nodes1)
2306
- nodes2_cycle = self.get_cycle(nodes2)
2307
-
2308
- interp_cycle1, interp_cycle2 = self.__interpolate(
2309
- nodes1_cycle, nodes2_cycle, threshold
2971
+ Calculate DTW distance between two chains
2972
+
2973
+ Parameters
2974
+ ----------
2975
+ nodes1 : int
2976
+ node to compare distance
2977
+ nodes2 : int
2978
+ node to compare distance
2979
+ threshold : int, default=1000
2980
+ set a maximum number of points a chain can have
2981
+ regist : bool, default=True
2982
+ Rotate and translate trajectories
2983
+ start_d : int, default=0
2984
+ start delay
2985
+ back_d : int, default=0
2986
+ end delay
2987
+ fast : bool, default=False
2988
+ if `True`, the algorithm will use a faster version but might not find the optimal alignment
2989
+ w : int, default=0
2990
+ window size
2991
+ centered_band : bool, default=True
2992
+ when running the fast algorithm, `True` if the windown is centered
2993
+ cost_mat_p : bool, default=False
2994
+ True if print the not normalized cost matrix
2995
+
2996
+ Returns
2997
+ -------
2998
+ float
2999
+ DTW distance
3000
+ tuple of tuples
3001
+ Aligment path
3002
+ matrix
3003
+ Cost matrix
3004
+ list of lists
3005
+ rotated and translated trajectories positions
3006
+ list of lists
3007
+ rotated and translated trajectories positions
3008
+ """
3009
+ nodes1_chain = self.get_chain_of_node(nodes1)
3010
+ nodes2_chain = self.get_chain_of_node(nodes2)
3011
+
3012
+ interp_chain1, interp_chain2 = self.__interpolate(
3013
+ nodes1_chain, nodes2_chain, threshold
2310
3014
  )
2311
3015
 
2312
- pos_cycle1 = np.array([self.pos[c_id] for c_id in nodes1_cycle])
2313
- pos_cycle2 = np.array([self.pos[c_id] for c_id in nodes2_cycle])
3016
+ pos_chain1 = np.array([self.pos[c_id] for c_id in nodes1_chain])
3017
+ pos_chain2 = np.array([self.pos[c_id] for c_id in nodes2_chain])
2314
3018
 
2315
3019
  if regist:
2316
3020
  R, t = self.__rigid_transform_3D(
2317
- np.transpose(interp_cycle1), np.transpose(interp_cycle2)
3021
+ np.transpose(interp_chain1), np.transpose(interp_chain2)
2318
3022
  )
2319
- pos_cycle1 = np.transpose(np.dot(R, pos_cycle1.T) + t)
3023
+ pos_chain1 = np.transpose(np.dot(R, pos_chain1.T) + t)
2320
3024
 
2321
- dist_mat = distance.cdist(pos_cycle1, pos_cycle2, "euclidean")
3025
+ dist_mat = distance.cdist(pos_chain1, pos_chain2, "euclidean")
2322
3026
 
2323
3027
  path, cost_mat, final_cost = self.__dp(
2324
3028
  dist_mat,
@@ -2331,7 +3035,7 @@ class lineageTree(lineageTreeLoaders):
2331
3035
  cost = final_cost / len(path)
2332
3036
 
2333
3037
  if cost_mat_p:
2334
- return cost, path, cost_mat, pos_cycle1, pos_cycle2
3038
+ return cost, path, cost_mat, pos_chain1, pos_chain2
2335
3039
  else:
2336
3040
  return cost, path
2337
3041
 
@@ -2346,24 +3050,39 @@ class lineageTree(lineageTreeLoaders):
2346
3050
  fast: bool = False,
2347
3051
  w: int = 0,
2348
3052
  centered_band: bool = True,
2349
- ) -> (float, plt.figure):
2350
- """
2351
- Plot DTW cost matrix between two cell cycles in heatmap format
2352
-
2353
- Args:
2354
- nodes1 (int): node to compare distance
2355
- nodes2 (int): node to compare distance
2356
- start_d (int): start delay
2357
- back_d (int): end delay
2358
- fast (boolean): True if the user wants to run the fast algorithm with window restrains
2359
- w (int): window size
2360
- centered_band (boolean): if running the fast algorithm, True if the windown is centered
2361
-
2362
- Returns:
2363
- (float) DTW distance
2364
- (figure) Heatmap of cost matrix with opitimal path
2365
- """
2366
- cost, path, cost_mat, pos_cycle1, pos_cycle2 = self.calculate_dtw(
3053
+ ) -> tuple[float, plt.Figure]:
3054
+ """
3055
+ Plot DTW cost matrix between two chains in heatmap format
3056
+
3057
+ Parameters
3058
+ ----------
3059
+ nodes1 : int
3060
+ node to compare distance
3061
+ nodes2 : int
3062
+ node to compare distance
3063
+ threshold : int, default=1000
3064
+ set a maximum number of points a chain can have
3065
+ regist : bool, default=True
3066
+ Rotate and translate trajectories
3067
+ start_d : int, default=0
3068
+ start delay
3069
+ back_d : int, default=0
3070
+ end delay
3071
+ fast : bool, default=False
3072
+ if `True`, the algorithm will use a faster version but might not find the optimal alignment
3073
+ w : int, default=0
3074
+ window size
3075
+ centered_band : bool, default=True
3076
+ when running the fast algorithm, `True` if the windown is centered
3077
+
3078
+ Returns
3079
+ -------
3080
+ float
3081
+ DTW distance
3082
+ plt.Figure
3083
+ Heatmap of cost matrix with opitimal path
3084
+ """
3085
+ cost, path, cost_mat, pos_chain1, pos_chain2 = self.calculate_dtw(
2367
3086
  nodes1,
2368
3087
  nodes2,
2369
3088
  threshold,
@@ -2385,32 +3104,32 @@ class lineageTree(lineageTreeLoaders):
2385
3104
  ax.set_title("Heatmap of DTW Cost Matrix")
2386
3105
  ax.set_xlabel("Tree 1")
2387
3106
  ax.set_ylabel("tree 2")
2388
- x_path, y_path = zip(*path)
3107
+ x_path, y_path = zip(*path, strict=True)
2389
3108
  ax.plot(y_path, x_path, color="black")
2390
3109
 
2391
3110
  return cost, fig
2392
3111
 
2393
3112
  @staticmethod
2394
3113
  def __plot_2d(
2395
- pos_cycle1,
2396
- pos_cycle2,
2397
- nodes1,
2398
- nodes2,
2399
- ax,
2400
- x_idx,
2401
- y_idx,
2402
- x_label,
2403
- y_label,
2404
- ):
3114
+ pos_chain1: np.ndarray,
3115
+ pos_chain2: np.ndarray,
3116
+ nodes1: list[int],
3117
+ nodes2: list[int],
3118
+ ax: plt.Axes,
3119
+ x_idx: list[int],
3120
+ y_idx: list[int],
3121
+ x_label: str,
3122
+ y_label: str,
3123
+ ) -> None:
2405
3124
  ax.plot(
2406
- pos_cycle1[:, x_idx],
2407
- pos_cycle1[:, y_idx],
3125
+ pos_chain1[:, x_idx],
3126
+ pos_chain1[:, y_idx],
2408
3127
  "-",
2409
3128
  label=f"root = {nodes1}",
2410
3129
  )
2411
3130
  ax.plot(
2412
- pos_cycle2[:, x_idx],
2413
- pos_cycle2[:, y_idx],
3131
+ pos_chain2[:, x_idx],
3132
+ pos_chain2[:, y_idx],
2414
3133
  "-",
2415
3134
  label=f"root = {nodes2}",
2416
3135
  )
@@ -2428,40 +3147,55 @@ class lineageTree(lineageTreeLoaders):
2428
3147
  fast: bool = False,
2429
3148
  w: int = 0,
2430
3149
  centered_band: bool = True,
2431
- projection: str = None,
3150
+ projection: Literal["3d", "xy", "xz", "yz", "pca", None] = None,
2432
3151
  alig: bool = False,
2433
- ) -> (float, plt.figure):
2434
- """
2435
- Plots DTW trajectories aligment between two cell cycles in 2D or 3D
2436
-
2437
- Args:
2438
- nodes1 (int): node to compare distance
2439
- nodes2 (int): node to compare distance
2440
- threshold (int): set a maximum number of points a track can have
2441
- regist (boolean): Rotate and translate trajectories
2442
- start_d (int): start delay
2443
- back_d (int): end delay
2444
- w (int): window size
2445
- fast (boolean): True if the user wants to run the fast algorithm with window restrains
2446
- centered_band (boolean): if running the fast algorithm, True if the windown is centered
2447
- projection (string): specify which 2D to plot ->
2448
- '3d' : for the 3d visualization
2449
- 'xy' or None (default) : 2D projection of axis x and y
2450
- 'xz' : 2D projection of axis x and z
2451
- 'yz' : 2D projection of axis y and z
2452
- 'pca' : PCA projection
2453
- alig (boolean): True to show alignment on plot
2454
-
2455
- Returns:
2456
- (float) DTW distance
2457
- (figue) Trajectories Plot
3152
+ ) -> tuple[float, plt.Figure]:
3153
+ """
3154
+ Plots DTW trajectories aligment between two chains in 2D or 3D
3155
+
3156
+ Parameters
3157
+ ----------
3158
+ nodes1 : int
3159
+ node to compare distance
3160
+ nodes2 : int
3161
+ node to compare distance
3162
+ threshold : int, default=1000
3163
+ set a maximum number of points a chain can have
3164
+ regist : bool, default=True
3165
+ Rotate and translate trajectories
3166
+ start_d : int, default=0
3167
+ start delay
3168
+ back_d : int, default=0
3169
+ end delay
3170
+ w : int, default=0
3171
+ window size
3172
+ fast : bool, default=False
3173
+ True if the user wants to run the fast algorithm with window restrains
3174
+ centered_band : bool, default=True
3175
+ if running the fast algorithm, True if the windown is centered
3176
+ projection : {"3d", "xy", "xz", "yz", "pca"}, optional
3177
+ specify which 2D to plot ->
3178
+ "3d" : for the 3d visualization
3179
+ "xy" or None (default) : 2D projection of axis x and y
3180
+ "xz" : 2D projection of axis x and z
3181
+ "yz" : 2D projection of axis y and z
3182
+ "pca" : PCA projection
3183
+ alig : bool
3184
+ True to show alignment on plot
3185
+
3186
+ Returns
3187
+ -------
3188
+ float
3189
+ DTW distance
3190
+ figure
3191
+ Trajectories Plot
2458
3192
  """
2459
3193
  (
2460
3194
  distance,
2461
3195
  alignment,
2462
3196
  cost_mat,
2463
- pos_cycle1,
2464
- pos_cycle2,
3197
+ pos_chain1,
3198
+ pos_chain2,
2465
3199
  ) = self.calculate_dtw(
2466
3200
  nodes1,
2467
3201
  nodes2,
@@ -2484,16 +3218,16 @@ class lineageTree(lineageTreeLoaders):
2484
3218
 
2485
3219
  if projection == "3d":
2486
3220
  ax.plot(
2487
- pos_cycle1[:, 0],
2488
- pos_cycle1[:, 1],
2489
- pos_cycle1[:, 2],
3221
+ pos_chain1[:, 0],
3222
+ pos_chain1[:, 1],
3223
+ pos_chain1[:, 2],
2490
3224
  "-",
2491
3225
  label=f"root = {nodes1}",
2492
3226
  )
2493
3227
  ax.plot(
2494
- pos_cycle2[:, 0],
2495
- pos_cycle2[:, 1],
2496
- pos_cycle2[:, 2],
3228
+ pos_chain2[:, 0],
3229
+ pos_chain2[:, 1],
3230
+ pos_chain2[:, 2],
2497
3231
  "-",
2498
3232
  label=f"root = {nodes2}",
2499
3233
  )
@@ -2503,8 +3237,8 @@ class lineageTree(lineageTreeLoaders):
2503
3237
  else:
2504
3238
  if projection == "xy" or projection == "yx" or projection is None:
2505
3239
  self.__plot_2d(
2506
- pos_cycle1,
2507
- pos_cycle2,
3240
+ pos_chain1,
3241
+ pos_chain2,
2508
3242
  nodes1,
2509
3243
  nodes2,
2510
3244
  ax,
@@ -2515,8 +3249,8 @@ class lineageTree(lineageTreeLoaders):
2515
3249
  )
2516
3250
  elif projection == "xz" or projection == "zx":
2517
3251
  self.__plot_2d(
2518
- pos_cycle1,
2519
- pos_cycle2,
3252
+ pos_chain1,
3253
+ pos_chain2,
2520
3254
  nodes1,
2521
3255
  nodes2,
2522
3256
  ax,
@@ -2527,8 +3261,8 @@ class lineageTree(lineageTreeLoaders):
2527
3261
  )
2528
3262
  elif projection == "yz" or projection == "zy":
2529
3263
  self.__plot_2d(
2530
- pos_cycle1,
2531
- pos_cycle2,
3264
+ pos_chain1,
3265
+ pos_chain2,
2532
3266
  nodes1,
2533
3267
  nodes2,
2534
3268
  ax,
@@ -2542,24 +3276,25 @@ class lineageTree(lineageTreeLoaders):
2542
3276
  from sklearn.decomposition import PCA
2543
3277
  except ImportError:
2544
3278
  Warning(
2545
- "scikit-learn is not installed, the PCA orientation cannot be used. You can install scikit-learn with pip install"
3279
+ "scikit-learn is not installed, the PCA orientation cannot be used."
3280
+ "You can install scikit-learn with pip install"
2546
3281
  )
2547
3282
 
2548
3283
  # Apply PCA
2549
3284
  pca = PCA(n_components=2)
2550
- pca.fit(np.vstack([pos_cycle1, pos_cycle2]))
2551
- pos_cycle1_2d = pca.transform(pos_cycle1)
2552
- pos_cycle2_2d = pca.transform(pos_cycle2)
3285
+ pca.fit(np.vstack([pos_chain1, pos_chain2]))
3286
+ pos_chain1_2d = pca.transform(pos_chain1)
3287
+ pos_chain2_2d = pca.transform(pos_chain2)
2553
3288
 
2554
3289
  ax.plot(
2555
- pos_cycle1_2d[:, 0],
2556
- pos_cycle1_2d[:, 1],
3290
+ pos_chain1_2d[:, 0],
3291
+ pos_chain1_2d[:, 1],
2557
3292
  "-",
2558
3293
  label=f"root = {nodes1}",
2559
3294
  )
2560
3295
  ax.plot(
2561
- pos_cycle2_2d[:, 0],
2562
- pos_cycle2_2d[:, 1],
3296
+ pos_chain2_2d[:, 0],
3297
+ pos_chain2_2d[:, 1],
2563
3298
  "-",
2564
3299
  label=f"root = {nodes2}",
2565
3300
  )
@@ -2588,7 +3323,7 @@ class lineageTree(lineageTreeLoaders):
2588
3323
  'pca' : PCA projection"""
2589
3324
  )
2590
3325
 
2591
- connections = [[pos_cycle1[i], pos_cycle2[j]] for i, j in alignment]
3326
+ connections = [[pos_chain1[i], pos_chain2[j]] for i, j in alignment]
2592
3327
 
2593
3328
  for connection in connections:
2594
3329
  xyz1 = connection[0]
@@ -2616,108 +3351,174 @@ class lineageTree(lineageTreeLoaders):
2616
3351
 
2617
3352
  return distance, fig
2618
3353
 
2619
- def first_labelling(self):
2620
- self.labels = {i: "Unlabeled" for i in self.time_nodes[0]}
2621
-
2622
3354
  def __init__(
2623
3355
  self,
2624
- file_format: str = None,
2625
- tb: int = None,
2626
- te: int = None,
2627
- z_mult: float = 1.0,
2628
- file_type: str = "",
2629
- delim: str = ",",
2630
- eigen: bool = False,
2631
- shape: tuple = None,
2632
- raw_size: tuple = None,
2633
- reorder: bool = False,
2634
- xml_attributes: tuple = None,
2635
- name: str = None,
2636
- time_resolution: Union[int, None] = None,
3356
+ *,
3357
+ successor: dict[int, Sequence] | None = None,
3358
+ predecessor: dict[int, int | Sequence] | None = None,
3359
+ time: dict[int, int] | None = None,
3360
+ starting_time: int | None = None,
3361
+ pos: dict[int, Iterable] | None = None,
3362
+ name: str | None = None,
3363
+ root_leaf_value: Sequence | None = None,
3364
+ **kwargs,
2637
3365
  ):
2638
- """
2639
- TODO: complete the doc
2640
- Main library to build tree graph representation of lineage tree data
2641
- It can read TGMM, ASTEC, SVF, MaMuT and TrackMate outputs.
2642
-
2643
- Args:
2644
- file_format (str): either - path format to TGMM xmls
2645
- - path to the MaMuT xml
2646
- - path to the binary file
2647
- tb (int, optional):first time point (necessary for TGMM xmls only)
2648
- te (int, optional): last time point (necessary for TGMM xmls only)
2649
- z_mult (float, optional):z aspect ratio if necessary (usually only for TGMM xmls)
2650
- file_type (str, optional):type of input file. Accepts:
2651
- 'TGMM, 'ASTEC', MaMuT', 'TrackMate', 'csv', 'celegans', 'binary'
2652
- default is 'binary'
2653
- delim (str, optional): _description_. Defaults to ",".
2654
- eigen (bool, optional): _description_. Defaults to False.
2655
- shape (tuple, optional): _description_. Defaults to None.
2656
- raw_size (tuple, optional): _description_. Defaults to None.
2657
- reorder (bool, optional): _description_. Defaults to False.
2658
- xml_attributes (tuple, optional): _description_. Defaults to None.
2659
- name (str, optional): The name of the dataset. Defaults to None.
2660
- time_resolution (Union[int, None], optional): Time resolution in mins (If time resolution is smaller than one minute input the time in ms). Defaults to None.
2661
- """
3366
+ """Create a lineageTree object from minimal information, without reading from a file.
3367
+ Either `successor` or `predecessor` should be specified.
3368
+
3369
+ Parameters
3370
+ ----------
3371
+ successor : dict mapping int to Iterable
3372
+ Dictionary assigning nodes to their successors.
3373
+ predecessor : dict mapping int to int or Iterable
3374
+ Dictionary assigning nodes to their predecessors.
3375
+ time : dict mapping int to int, optional
3376
+ Dictionary assigning nodes to the time point they were recorded to.
3377
+ Defaults to None, in which case all times are set to `starting_time`.
3378
+ starting_time : int, optional
3379
+ Starting time of the lineage tree. Defaults to 0.
3380
+ pos : dict mapping int to Iterable, optional
3381
+ Dictionary assigning nodes to their positions. Defaults to None.
3382
+ name : str, optional
3383
+ Name of the lineage tree. Defaults to None.
3384
+ root_leaf_value : Iterable, optional
3385
+ Iterable of values of roots' predecessors and leaves' successors in the successor and predecessor dictionaries.
3386
+ Defaults are `[None, (), [], set()]`.
3387
+ **kwargs:
3388
+ Supported keyword arguments are dictionaries assigning nodes to any custom property.
3389
+ The property must be specified for every node, and named differently from lineageTree's own attributes.
3390
+ """
3391
+ self.__version__ = importlib.metadata.version("LineageTree")
3392
+ self.name = str(name) if name is not None else None
3393
+ if successor is not None and predecessor is not None:
3394
+ raise ValueError(
3395
+ "You cannot have both successors and predecessors."
3396
+ )
2662
3397
 
2663
- self.name = name
2664
- self.time_nodes = {}
2665
- self.time_edges = {}
2666
- self.max_id = -1
2667
- self.next_id = []
2668
- self.nodes = set()
2669
- self.successor = {}
2670
- self.predecessor = {}
2671
- self.pos = {}
2672
- self.time_id = {}
2673
- self.time = {}
2674
- if time_resolution is not None:
2675
- self._time_resolution = time_resolution
2676
- self.kdtrees = {}
2677
- self.spatial_density = {}
2678
- if file_type and file_format:
2679
- if xml_attributes is None:
2680
- self.xml_attributes = []
2681
- else:
2682
- self.xml_attributes = xml_attributes
2683
- file_type = file_type.lower()
2684
- if file_type == "tgmm":
2685
- self.read_tgmm_xml(file_format, tb, te, z_mult)
2686
- self.t_b = tb
2687
- self.t_e = te
2688
- elif file_type == "mamut" or file_type == "trackmate":
2689
- self.read_from_mamut_xml(file_format)
2690
- elif file_type == "celegans":
2691
- self.read_from_txt_for_celegans(file_format)
2692
- elif file_type == "celegans_cao":
2693
- self.read_from_txt_for_celegans_CAO(
2694
- file_format,
2695
- reorder=reorder,
2696
- shape=shape,
2697
- raw_size=raw_size,
3398
+ if root_leaf_value is None:
3399
+ root_leaf_value = [None, (), [], set()]
3400
+ elif not isinstance(root_leaf_value, Iterable):
3401
+ raise TypeError(
3402
+ f"root_leaf_value is of type {type(root_leaf_value)}, expected Iterable."
3403
+ )
3404
+ elif len(root_leaf_value) < 1:
3405
+ raise ValueError(
3406
+ "root_leaf_value should have at least one element."
3407
+ )
3408
+ self._successor = {}
3409
+ self._predecessor = {}
3410
+ if successor is not None:
3411
+ for pred, succs in successor.items():
3412
+ if succs in root_leaf_value:
3413
+ self._successor[pred] = ()
3414
+ else:
3415
+ if not isinstance(succs, Iterable):
3416
+ raise TypeError(
3417
+ f"Successors should be Iterable, got {type(succs)}."
3418
+ )
3419
+ if len(succs) == 0:
3420
+ raise ValueError(
3421
+ f"{succs} was not declared as a leaf but was found as a successor.\n"
3422
+ "Please lift the ambiguity."
3423
+ )
3424
+ self._successor[pred] = tuple(succs)
3425
+ for succ in succs:
3426
+ if succ in self._predecessor:
3427
+ raise ValueError(
3428
+ "Node can have at most one predecessor."
3429
+ )
3430
+ self._predecessor[succ] = (pred,)
3431
+ elif predecessor is not None:
3432
+ for succ, pred in predecessor.items():
3433
+ if pred in root_leaf_value:
3434
+ self._predecessor[succ] = ()
3435
+ else:
3436
+ if isinstance(pred, Sequence):
3437
+ if len(pred) == 0:
3438
+ raise ValueError(
3439
+ f"{pred} was not declared as a leaf but was found as a successor.\n"
3440
+ "Please lift the ambiguity."
3441
+ )
3442
+ if 1 < len(pred):
3443
+ raise ValueError(
3444
+ "Node can have at most one predecessor."
3445
+ )
3446
+ pred = pred[0]
3447
+ self._predecessor[succ] = (pred,)
3448
+ self._successor.setdefault(pred, ())
3449
+ self._successor[pred] += (succ,)
3450
+ for root in set(self._successor).difference(self._predecessor):
3451
+ self._predecessor[root] = ()
3452
+ for leaf in set(self._predecessor).difference(self._successor):
3453
+ self._successor[leaf] = ()
3454
+
3455
+ if self.__check_for_cycles():
3456
+ raise ValueError(
3457
+ "Cycles were found in the tree, there should not be any."
3458
+ )
3459
+
3460
+ if pos is None:
3461
+ self.pos = {}
3462
+ else:
3463
+ if self.nodes.difference(pos) != set():
3464
+ raise ValueError("Please provide the position of all nodes.")
3465
+ self.pos = {
3466
+ node: np.array(position) for node, position in pos.items()
3467
+ }
3468
+
3469
+ if time is None:
3470
+ if starting_time is None:
3471
+ starting_time = 0
3472
+ if not isinstance(starting_time, int):
3473
+ warnings.warn(
3474
+ f"Attribute `starting_time` was a `{type(starting_time)}`, has been casted as an `int`.",
3475
+ stacklevel=2,
2698
3476
  )
2699
- elif file_type == "mastodon":
2700
- if isinstance(file_format, list) and len(file_format) == 2:
2701
- self.read_from_mastodon_csv(file_format)
3477
+ self._time = {node: starting_time for node in self.roots}
3478
+ queue = list(self.roots)
3479
+ for node in queue:
3480
+ for succ in self._successor[node]:
3481
+ self._time[succ] = self._time[node] + 1
3482
+ queue.append(succ)
3483
+ else:
3484
+ if starting_time is not None:
3485
+ warnings.warn(
3486
+ "Both `time` and `starting_time` were provided, `starting_time` was ignored.",
3487
+ stacklevel=2,
3488
+ )
3489
+ self._time = {n: int(time[n]) for n in self.nodes}
3490
+ if self._time != time:
3491
+ if len(self._time) != len(time):
3492
+ warnings.warn(
3493
+ "The provided `time` dictionary had keys that were not nodes. "
3494
+ "They have been removed",
3495
+ stacklevel=2,
3496
+ )
2702
3497
  else:
2703
- if isinstance(file_format, list):
2704
- file_format = file_format[0]
2705
- self.read_from_mastodon(file_format, name)
2706
- elif file_type == "astec":
2707
- self.read_from_ASTEC(file_format, eigen)
2708
- elif file_type == "csv":
2709
- self.read_from_csv(file_format, z_mult, link=1, delim=delim)
2710
- elif file_type == "bao":
2711
- self.read_C_elegans_bao(file_format)
2712
- elif file_format and file_format.endswith(".lT"):
2713
- with open(file_format, "br") as f:
2714
- tmp = pkl.load(f)
2715
- f.close()
2716
- self.__dict__.update(tmp.__dict__)
2717
- elif file_format is not None:
2718
- self.read_from_binary(file_format)
2719
- if self.name is None:
2720
- try:
2721
- self.name = Path(file_format).stem
2722
- except TypeError:
2723
- self.name = Path(file_format[0]).stem
3498
+ warnings.warn(
3499
+ "The provided `time` dictionary had values that were not `int`. "
3500
+ "These values have been truncated and converted to `int`",
3501
+ stacklevel=2,
3502
+ )
3503
+ if self.nodes.symmetric_difference(self._time) != set():
3504
+ raise ValueError(
3505
+ "Please provide the time of all nodes and only existing nodes."
3506
+ )
3507
+ if not all(
3508
+ self._time[node] < self._time[s]
3509
+ for node, succ in self._successor.items()
3510
+ for s in succ
3511
+ ):
3512
+ raise ValueError(
3513
+ "Provided times are not strictly increasing. Setting times to default."
3514
+ )
3515
+ # custom properties
3516
+ for name, d in kwargs.items():
3517
+ if name in self.__dict__:
3518
+ warnings.warn(
3519
+ f"Attribute name {name} is reserved.", stacklevel=2
3520
+ )
3521
+ continue
3522
+ setattr(self, name, d)
3523
+ if not hasattr(self, "_comparisons"):
3524
+ self._comparisons = {}