LineageTree 1.4.3__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -2,35 +2,244 @@
2
2
  # This file is subject to the terms and conditions defined in
3
3
  # file 'LICENCE', which is part of this source code package.
4
4
  # Author: Leo Guignard (leo.guignard...@AT@...gmail.com)
5
-
6
5
  import csv
7
6
  import os
8
7
  import pickle as pkl
9
8
  import struct
9
+ import warnings
10
10
  import xml.etree.ElementTree as ET
11
+ from collections.abc import Iterable
11
12
  from functools import partial
12
13
  from itertools import combinations
13
14
  from numbers import Number
14
- from typing import TextIO
15
-
15
+ from pathlib import Path
16
+ from typing import TextIO, Union
17
+
18
+ from .tree_styles import tree_style
19
+
20
+ try:
21
+ from edist import uted
22
+ except ImportError:
23
+ warnings.warn(
24
+ "No edist installed therefore you will not be able to compute the tree edit distance."
25
+ )
26
+ import matplotlib.pyplot as plt
27
+ import networkx as nx
16
28
  import numpy as np
17
- from scipy.spatial import Delaunay
29
+ from scipy.interpolate import InterpolatedUnivariateSpline
30
+ from scipy.spatial import Delaunay, distance
18
31
  from scipy.spatial import cKDTree as KDTree
19
32
 
33
+ from .utils import hierarchy_pos, postions_of_nx
34
+
20
35
 
21
36
  class lineageTree:
37
+ def __eq__(self, other):
38
+ if isinstance(other, lineageTree):
39
+ return other.successor == self.successor
40
+ return False
41
+
22
42
  def get_next_id(self):
23
43
  """Computes the next authorized id.
24
44
 
25
45
  Returns:
26
46
  int: next authorized id
27
47
  """
48
+ if self.max_id == -1 and self.nodes:
49
+ self.max_id = max(self.nodes)
28
50
  if self.next_id == []:
29
51
  self.max_id += 1
30
52
  return self.max_id
31
53
  else:
32
54
  return self.next_id.pop()
33
55
 
56
+ def complete_lineage(self, nodes: Union[int, set] = None):
57
+ """Makes all leaf branches longer so that they reach the last timepoint( self.t_e), useful
58
+ for tree edit distance algorithms.
59
+
60
+ Args:
61
+ nodes (int,set), optional): Which trees should be "completed", if None it will complete the whole dataset. Defaults to None.
62
+ """
63
+ if nodes is None:
64
+ nodes = set(self.roots)
65
+ elif isinstance(nodes, int):
66
+ nodes = {nodes}
67
+ for node in nodes:
68
+ sub = set(self.get_sub_tree(node))
69
+ specific_leaves = sub.intersection(self.leaves)
70
+ for leaf in specific_leaves:
71
+ self.add_branch(leaf, self.t_e - self.time[leaf], reverse=True)
72
+
73
+ ###TODO pos can be callable and stay motionless (copy the position of the succ node, use something like optical flow)
74
+ def add_branch(
75
+ self,
76
+ pred: int,
77
+ length: int,
78
+ move_timepoints: bool = True,
79
+ pos: Union[callable, None] = None,
80
+ reverse: bool = False,
81
+ ):
82
+ """Adds a branch of specific length to a node either as a successor or as a predecessor.
83
+ If it is placed on top of a tree all the nodes will move timepoints #length down.
84
+
85
+ Args:
86
+ pred (int): Id of the successor (predecessor if reverse is False)
87
+ length (int): The length of the new branch.
88
+ pos (np.ndarray, optional): The new position of the branch. Defaults to None.
89
+ move_timepoints (bool): Moves the ti Important only if reverse= True
90
+ reverese (bool): If reverse will add a successor branch instead of a predecessor branch
91
+ Returns:
92
+ (int): Id of the first node of the sublineage.
93
+ """
94
+ if length == 0:
95
+ return pred
96
+ if self.predecessor.get(pred) and not reverse:
97
+ raise Warning("Cannot add 2 predecessors to a node")
98
+ time = self.time[pred]
99
+ original = pred
100
+ if not reverse:
101
+ if move_timepoints:
102
+ nodes_to_move = set(self.get_sub_tree(pred))
103
+ new_times = {
104
+ node: self.time[node] + length for node in nodes_to_move
105
+ }
106
+ for node in nodes_to_move:
107
+ old_time = self.time[node]
108
+ self.time_nodes[old_time].remove(node)
109
+ self.time_nodes.setdefault(old_time + length, set()).add(
110
+ node
111
+ )
112
+ self.time.update(new_times)
113
+ for t in range(length - 1, -1, -1):
114
+ _next = self.add_node(
115
+ time + t,
116
+ succ=pred,
117
+ pos=self.pos[original],
118
+ reverse=True,
119
+ )
120
+ pred = _next
121
+ else:
122
+ for t in range(length):
123
+ _next = self.add_node(
124
+ time - t,
125
+ succ=pred,
126
+ pos=self.pos[original],
127
+ reverse=True,
128
+ )
129
+ pred = _next
130
+ else:
131
+ for t in range(length):
132
+ _next = self.add_node(
133
+ time + t, succ=pred, pos=self.pos[original], reverse=False
134
+ )
135
+ pred = _next
136
+ self.labels[pred] = "New branch"
137
+ if self.time[pred] == self.t_b:
138
+ self.roots.add(pred)
139
+ self.labels[pred] = "New branch"
140
+ if original in self.roots and reverse is True:
141
+ self.roots.add(pred)
142
+ self.labels[pred] = "New branch"
143
+ self.roots.remove(original)
144
+ self.labels.pop(original, -1)
145
+ self.t_e = max(self.time_nodes)
146
+ return pred
147
+
148
+ def cut_tree(self, root):
149
+ """It transforms a lineage that has at least 2 divisions into 2 independent lineages,
150
+ that spawn from the time point of the first node. (splits a tree into 2)
151
+
152
+ Args:
153
+ root (int): The id of the node, which will be cut.
154
+
155
+ Returns:
156
+ int: The id of the new tree
157
+ """
158
+ cycle = self.get_successors(root)
159
+ last_cell = cycle[-1]
160
+ if last_cell in self.successor:
161
+ new_lT = self.successor[last_cell].pop()
162
+ self.predecessor.pop(new_lT)
163
+ label_of_root = self.labels.get(cycle[0], cycle[0])
164
+ self.labels[cycle[0]] = f"L-Split {label_of_root}"
165
+ new_tr = self.add_branch(
166
+ new_lT, len(cycle) + 1, move_timepoints=False
167
+ )
168
+ self.roots.add(new_tr)
169
+ self.labels[new_tr] = f"R-Split {label_of_root}"
170
+ return new_tr
171
+ else:
172
+ raise Warning("No division of the branch")
173
+
174
+ def fuse_lineage_tree(
175
+ self,
176
+ l1_root: int,
177
+ l2_root: int,
178
+ length_l1: int = 0,
179
+ length_l2: int = 0,
180
+ length: int = 1,
181
+ ):
182
+ """Fuses 2 lineages from the lineagetree object. The 2 lineages that are to be fused can have a longer
183
+ first node and the node of the resulting lineage can also be longer.
184
+
185
+ Args:
186
+ l1_root (int): Id of the first root
187
+ l2_root (int): Id of the second root
188
+ length_l1 (int, optional): The length of the branch that will be added on top of the first lineage. Defaults to 0, which means only one node will be added.
189
+ length_l2 (int, optional): The length of the branch that will be added on top of the second lineage. Defaults to 0, which means only one node will be added.
190
+ length (int, optional): The length of the branch that will be added on top of the resulting lineage. Defaults to 1.
191
+
192
+ Returns:
193
+ int: The id of the root of the new lineage.
194
+ """
195
+ if self.predecessor.get(l1_root) or self.predecessor.get(l2_root):
196
+ raise ValueError("Please select 2 roots.")
197
+ if self.time[l1_root] != self.time[l2_root]:
198
+ warnings.warn(
199
+ "Using lineagetrees that do not exist in the same timepoint. The operation will continue"
200
+ )
201
+ new_root1 = self.add_branch(l1_root, length_l1)
202
+ new_root2 = self.add_branch(l2_root, length_l2)
203
+ next_root1 = self[new_root1][0]
204
+ self.remove_nodes(new_root1)
205
+ self.successor[new_root2].append(next_root1)
206
+ self.predecessor[next_root1] = [new_root2]
207
+ new_branch = self.add_branch(new_root2, length)
208
+ self.labels[new_branch] = f"Fusion of {new_root1} and {new_root2}"
209
+ return new_branch
210
+
211
+ def copy_lineage(self, root):
212
+ """
213
+ Copies the structure of a tree and makes a new with new nodes.
214
+ Warning does not take into account the predecessor of the root node.
215
+
216
+ Args:
217
+ root (int): The root of the tree to be copied
218
+
219
+ Returns:
220
+ int: The root of the new tree.
221
+ """
222
+ new_nodes = {
223
+ old_node: self.get_next_id()
224
+ for old_node in self.get_sub_tree(root)
225
+ }
226
+ self.nodes.update(new_nodes.values())
227
+ for old_node, new_node in new_nodes.items():
228
+ self.time[new_node] = self.time[old_node]
229
+ succ = self.successor.get(old_node)
230
+ if succ:
231
+ self.successor[new_node] = [new_nodes[n] for n in succ]
232
+ pred = self.predecessor.get(old_node)
233
+ if pred:
234
+ self.predecessor[new_node] = [new_nodes[n] for n in pred]
235
+ self.pos[new_node] = self.pos[old_node] + 0.5
236
+ self.time_nodes[self.time[old_node]].add(new_nodes[old_node])
237
+ new_root = new_nodes[root]
238
+ self.labels[new_root] = f"Copy of {root}"
239
+ if self.time[new_root] == 0:
240
+ self.roots.add(new_root)
241
+ return new_root
242
+
34
243
  def add_node(
35
244
  self,
36
245
  t: int = None,
@@ -55,109 +264,126 @@ class lineageTree:
55
264
  int: id of the new node.
56
265
  """
57
266
  C_next = self.get_next_id() if nid is None else nid
58
- self.time_nodes.setdefault(t, []).append(C_next)
267
+ self.time_nodes.setdefault(t, set()).add(C_next)
59
268
  if succ is not None and not reverse:
60
269
  self.successor.setdefault(succ, []).append(C_next)
61
270
  self.predecessor.setdefault(C_next, []).append(succ)
62
- self.edges.add((succ, C_next))
63
271
  elif succ is not None:
64
272
  self.predecessor.setdefault(succ, []).append(C_next)
65
273
  self.successor.setdefault(C_next, []).append(succ)
66
- self.edges.add((C_next, succ))
67
274
  self.nodes.add(C_next)
68
275
  self.pos[C_next] = pos
69
- self.progeny[C_next] = 0
70
276
  self.time[C_next] = t
71
277
  return C_next
72
278
 
73
- def remove_track(self, track: list):
74
- self.nodes.difference_update(track)
75
- times = {self.time[n] for n in track}
76
- for t in times:
77
- self.time_nodes[t] = list(
78
- set(self.time_nodes[t]).difference(track)
79
- )
80
- for i, c in enumerate(track):
81
- self.pos.pop(c)
82
- if i != 0:
83
- self.predecessor.pop(c)
84
- if i < len(track) - 1:
85
- self.successor.pop(c)
86
- self.time.pop(c)
87
-
88
- def remove_node(self, c: int) -> tuple:
89
- """Removes a node and update the lineageTree accordingly
279
+ def remove_nodes(self, group: Union[int, set, list]):
280
+ """Removes a group of nodes from the LineageTree
90
281
 
91
282
  Args:
92
- c (int): id of the node to remove
93
- """
94
- self.nodes.remove(c)
95
- self.time_nodes[self.time[c]].remove(c)
96
- # self.time_nodes.pop(c, 0)
97
- pos = self.pos.pop(c, 0)
98
- e_to_remove = [e for e in self.edges if c in e]
99
- for e in e_to_remove:
100
- self.edges.remove(e)
101
- if c in self.roots:
102
- self.roots.remove(c)
103
- succ = self.successor.pop(c, [])
104
- s_to_remove = [s for s, ci in self.successor.items() if c in ci]
105
- for s in s_to_remove:
106
- self.successor[s].remove(c)
107
-
108
- pred = self.predecessor.pop(c, [])
109
- p_to_remove = [s for s, ci in self.predecessor.items() if ci == c]
110
- for s in p_to_remove:
111
- self.predecessor[s].remove(c)
112
-
113
- self.time.pop(c, 0)
114
- self.spatial_density.pop(c, 0)
115
-
116
- self.next_id.append(c)
117
- return e_to_remove, succ, s_to_remove, pred, p_to_remove, pos
118
-
119
- def fuse_nodes(self, c1: int, c2: int):
120
- """Fuses together two nodes that belong to the same time point
121
- and update the lineageTree accordingly.
283
+ group (set|list|int): One or more nodes that are to be removed.
284
+ """
285
+ if isinstance(group, int):
286
+ group = {group}
287
+ if isinstance(group, list):
288
+ group = set(group)
289
+ group = group.intersection(self.nodes)
290
+ self.nodes.difference_update(group)
291
+ times = {self.time.pop(n) for n in group}
292
+ for t in times:
293
+ self.time_nodes[t] = set(self.time_nodes[t]).difference(group)
294
+ for node in group:
295
+ self.pos.pop(node)
296
+ if self.predecessor.get(node):
297
+ pred = self.predecessor[node][0]
298
+ siblings = self.successor.pop(pred, [])
299
+ if len(siblings) == 2:
300
+ siblings.remove(node)
301
+ self.successor[pred] = siblings
302
+ self.predecessor.pop(node, [])
303
+ for succ in self.successor.get(node, []):
304
+ self.predecessor.pop(succ, [])
305
+ self.successor.pop(node, [])
306
+ self.labels.pop(node, 0)
307
+ if node in self.roots:
308
+ self.roots.remove(node)
309
+
310
+ def modify_branch(self, node, new_length):
311
+ """Changes the length of a branch, so it adds or removes nodes
312
+ to make the correct length of the cycle.
122
313
 
123
314
  Args:
124
- c1 (int): id of the first node to fuse
125
- c2 (int): id of the second node to fuse
315
+ node (int): Any node of the branch to be modified/
316
+ new_length (int): The new length of the tree.
126
317
  """
127
- (
128
- e_to_remove,
129
- succ,
130
- s_to_remove,
131
- pred,
132
- p_to_remove,
133
- c2_pos,
134
- ) = self.remove_node(c2)
135
- for e in e_to_remove:
136
- new_e = [c1] + [other_c for other_c in e if e != c2]
137
- self.edges.add(new_e)
138
-
139
- self.successor.setdefault(c1, []).extend(succ)
140
- self.predecessor.setdefault(c1, []).extend(pred)
141
-
142
- for s in s_to_remove:
143
- self.successor[s].append(c1)
144
-
145
- for p in p_to_remove:
146
- self.predecessor[p].append(c1)
147
-
148
- self.pos[c1] = np.mean([self.pos[c1], c2_pos], axis=0)
149
- self.progeny[c1] += 1
318
+ if new_length <= 1:
319
+ warnings.warn("New length should be more than 1")
320
+ return None
321
+ cycle = self.get_cycle(node)
322
+ length = len(cycle)
323
+ successors = self.successor.get(cycle[-1])
324
+ if length == 1 and new_length != 1:
325
+ pred = self.predecessor.pop(node, None)
326
+ new_node = self.add_branch(
327
+ node, length=new_length, move_timepoints=True, reverse=False
328
+ )
329
+ if pred:
330
+ self.successor[pred[0]].remove(node)
331
+ self.successor[pred[0]].append(new_node)
332
+ elif self.leaves.intersection(cycle) and new_length < length:
333
+ self.remove_nodes(cycle[new_length:])
334
+ elif new_length < length:
335
+ to_remove = length - new_length
336
+ last_cell = cycle[new_length - 1]
337
+ subtree = self.get_sub_tree(cycle[-1])[1:]
338
+ self.remove_nodes(cycle[new_length:])
339
+ self.successor[last_cell] = successors
340
+ if successors:
341
+ for succ in successors:
342
+ self.predecessor[succ] = [last_cell]
343
+ for node in subtree:
344
+ if node not in cycle[new_length - 1 :]:
345
+ old_time = self.time[node]
346
+ self.time[node] = old_time - to_remove
347
+ self.time_nodes[old_time].remove(node)
348
+ self.time_nodes.setdefault(
349
+ old_time - to_remove, set()
350
+ ).add(node)
351
+ elif length < new_length:
352
+ to_add = new_length - length
353
+ last_cell = cycle[-1]
354
+ self.successor.pop(cycle[-2])
355
+ self.predecessor.pop(last_cell)
356
+ succ = self.add_branch(
357
+ last_cell, length=to_add, move_timepoints=True, reverse=False
358
+ )
359
+ self.predecessor[succ] = [cycle[-2]]
360
+ self.successor[cycle[-2]] = [succ]
361
+ self.time[last_cell] = (
362
+ self.time[self.predecessor[last_cell][0]] + 1
363
+ )
364
+ else:
365
+ return None
150
366
 
151
367
  @property
152
368
  def roots(self):
153
369
  if not hasattr(self, "_roots"):
154
- self._roots = set(self.successor).difference(self.predecessor)
370
+ self._roots = set(self.nodes).difference(self.predecessor)
155
371
  return self._roots
156
372
 
373
+ @property
374
+ def edges(self):
375
+ return {(k, vi) for k, v in self.successor.items() for vi in v}
376
+
157
377
  @property
158
378
  def leaves(self):
159
379
  return set(self.predecessor).difference(self.successor)
160
380
 
381
+ @property
382
+ def labels(self):
383
+ if not hasattr(self, "_labels"):
384
+ self._labels = {i: "Unlabeled" for i in self.roots}
385
+ return self._labels
386
+
161
387
  def _write_header_am(self, f: TextIO, nb_points: int, length: int):
162
388
  """Header for Amira .am files"""
163
389
  f.write("# AmiraMesh 3D ASCII 2.0\n")
@@ -470,7 +696,7 @@ class lineageTree:
470
696
  stroke=svgwrite.rgb(0, 0, 0),
471
697
  )
472
698
  )
473
- for si in self.successor.get(c_cycle[-1], []):
699
+ for si in self[c_cycle[-1]]:
474
700
  x3, y3 = positions[si]
475
701
  dwg.add(
476
702
  dwg.line(
@@ -483,7 +709,7 @@ class lineageTree:
483
709
  else:
484
710
  for c in treated_cells:
485
711
  x1, y1 = positions[c]
486
- for si in self.successor.get(c, []):
712
+ for si in self[c]:
487
713
  x2, y2 = positions[si]
488
714
  if draw_edges:
489
715
  dwg.add(
@@ -535,7 +761,7 @@ class lineageTree:
535
761
  start_time = times_to_consider[0]
536
762
  for t in times_to_consider:
537
763
  for id_mother in self.time_nodes[t]:
538
- ids_daughters = self.successor.get(id_mother, [])
764
+ ids_daughters = self[id_mother]
539
765
  new_ids_daughters = ids_daughters.copy()
540
766
  for _ in range(sampling - 1):
541
767
  tmp = []
@@ -659,12 +885,12 @@ class lineageTree:
659
885
  edges_to_use += list(s_edges)
660
886
  else:
661
887
  edges_to_use = []
888
+ nodes_to_use = set(nodes_to_use)
662
889
  if temporal:
663
- edges_to_use += [
664
- e
665
- for e in self.edges
666
- if e[0] in nodes_to_use and e[1] in nodes_to_use
667
- ]
890
+ for n in nodes_to_use:
891
+ for d in self.successor.get(n, []):
892
+ if d in nodes_to_use:
893
+ edges_to_use.append((n, d))
668
894
  if spatial:
669
895
  edges_to_use += [
670
896
  e for e in s_edges if t_min < self.time[e[0]] < t_max
@@ -787,7 +1013,6 @@ class lineageTree:
787
1013
  self.time_edges = {}
788
1014
  unique_id = 0
789
1015
  self.nodes = set()
790
- self.edges = set()
791
1016
  self.successor = {}
792
1017
  self.predecessor = {}
793
1018
  self.pos = {}
@@ -826,7 +1051,6 @@ class lineageTree:
826
1051
  M = corres[pred]
827
1052
  self.predecessor[C] = [M]
828
1053
  self.successor.setdefault(M, []).append(C)
829
- self.edges.add((M, C))
830
1054
  self.time_edges.setdefault(t, set()).add((M, C))
831
1055
  self.lin.setdefault(lin_id, []).append(C)
832
1056
  self.C_lin[C] = lin_id
@@ -924,7 +1148,12 @@ class lineageTree:
924
1148
 
925
1149
  # make sure these are all named liked they are in tmp_data (or change dictionary above)
926
1150
  self.name = {}
927
- self.volume = {}
1151
+ if "cell_volume" in tmp_data:
1152
+ self.volume = {}
1153
+ if "cell_fate" in tmp_data:
1154
+ self.fates = {}
1155
+ if "cell_barycenter" in tmp_data:
1156
+ self.pos = {}
928
1157
  self.lT2pkl = {}
929
1158
  self.pkl2lT = {}
930
1159
  self.contact = {}
@@ -955,8 +1184,11 @@ class lineageTree:
955
1184
  if "cell_volume" in tmp_data:
956
1185
  self.volume[unique_id] = tmp_data["cell_volume"].get(n, 0.0)
957
1186
  if "cell_fate" in tmp_data:
958
- self.fates = {}
959
1187
  self.fates[unique_id] = tmp_data["cell_fate"].get(n, "")
1188
+ if "cell_barycenter" in tmp_data:
1189
+ self.pos[unique_id] = tmp_data["cell_barycenter"].get(
1190
+ n, np.zeros(3)
1191
+ )
960
1192
 
961
1193
  unique_id += 1
962
1194
  if do_surf:
@@ -975,9 +1207,7 @@ class lineageTree:
975
1207
  self.successor[new_id] = [
976
1208
  self.pkl2lT[ni] for ni in lt[n] if ni in self.pkl2lT
977
1209
  ]
978
- self.edges.update(
979
- [(new_id, ni) for ni in self.successor[new_id]]
980
- )
1210
+
981
1211
  for ni in self.successor[new_id]:
982
1212
  self.time_edges.setdefault(t - 1, set()).add((new_id, ni))
983
1213
 
@@ -986,30 +1216,43 @@ class lineageTree:
986
1216
  self.max_id = unique_id
987
1217
 
988
1218
  # do this in the end of the process, skip lineage tree and whatever is stored already
1219
+ discard = {
1220
+ "cell_volume",
1221
+ "cell_fate",
1222
+ "cell_barycenter",
1223
+ "cell_contact_surface",
1224
+ "cell_lineage",
1225
+ "all_cells",
1226
+ "cell_history",
1227
+ "problematic_cells",
1228
+ "cell_labels_in_time",
1229
+ }
1230
+ self.specific_properties = []
989
1231
  for prop_name, prop_values in tmp_data.items():
990
- if hasattr(self, prop_name):
991
- continue
992
- else:
1232
+ if not (prop_name in discard or hasattr(self, prop_name)):
993
1233
  if isinstance(prop_values, dict):
994
1234
  dictionary = {
995
- self.pkl2lT[k]: v for k, v in prop_values.items()
1235
+ self.pkl2lT.get(k, -1): v
1236
+ for k, v in prop_values.items()
996
1237
  }
997
1238
  # is it a regular dictionary or a dictionary with dictionaries inside?
998
1239
  for key, value in dictionary.items():
999
1240
  if isinstance(value, dict):
1000
1241
  # rename all ids from old to new
1001
1242
  dictionary[key] = {
1002
- self.pkl2lT[k]: v for k, v in value
1243
+ self.pkl2lT.get(k, -1): v
1244
+ for k, v in value.items()
1003
1245
  }
1004
1246
  self.__dict__[prop_name] = dictionary
1247
+ self.specific_properties.append(prop_name)
1005
1248
  # is any of this necessary? Or does it mean it anyways does not contain
1006
1249
  # information about the id and a simple else: is enough?
1007
1250
  elif (
1008
- prop_values.isinstance(set)
1009
- or prop_values.isinstance(list)
1010
- or prop_values.isinstance(np.array)
1251
+ isinstance(prop_values, (list, set, np.ndarray))
1252
+ and prop_name not in []
1011
1253
  ):
1012
1254
  self.__dict__[prop_name] = prop_values
1255
+ self.specific_properties.append(prop_name)
1013
1256
 
1014
1257
  # what else could it be?
1015
1258
 
@@ -1121,7 +1364,6 @@ class lineageTree:
1121
1364
  p = None
1122
1365
  self.predecessor.setdefault(c, []).append(p)
1123
1366
  self.successor.setdefault(p, []).append(c)
1124
- self.edges.add((p, c))
1125
1367
  self.time_edges.setdefault(t - 1, set()).add((p, c))
1126
1368
  self.max_id = unique_id
1127
1369
 
@@ -1209,7 +1451,6 @@ class lineageTree:
1209
1451
  p = None
1210
1452
  self.predecessor.setdefault(c, []).append(p)
1211
1453
  self.successor.setdefault(p, []).append(c)
1212
- self.edges.add((p, c))
1213
1454
  self.time_edges.setdefault(t - 1, set()).add((p, c))
1214
1455
  self.max_id = unique_id
1215
1456
 
@@ -1233,7 +1474,6 @@ class lineageTree:
1233
1474
  self.time_edges = {}
1234
1475
  unique_id = 0
1235
1476
  self.nodes = set()
1236
- self.edges = set()
1237
1477
  self.successor = {}
1238
1478
  self.predecessor = {}
1239
1479
  self.pos = {}
@@ -1293,7 +1533,6 @@ class lineageTree:
1293
1533
  M = self.time_id[(t - 1, M_id)]
1294
1534
  self.successor.setdefault(M, []).append(C)
1295
1535
  self.predecessor.setdefault(C, []).append(M)
1296
- self.edges.add((M, C))
1297
1536
  self.time_edges[t].add((M, C))
1298
1537
  else:
1299
1538
  if M_id != -1:
@@ -1330,7 +1569,6 @@ class lineageTree:
1330
1569
 
1331
1570
  mr = MastodonReader(path)
1332
1571
  spots, links = mr.read_tables()
1333
- mr.read_tags(spots, links)
1334
1572
 
1335
1573
  self.node_name = {}
1336
1574
 
@@ -1350,7 +1588,6 @@ class lineageTree:
1350
1588
  target = e.target_idx
1351
1589
  self.predecessor.setdefault(target, []).append(source)
1352
1590
  self.successor.setdefault(source, []).append(target)
1353
- self.edges.add((source, target))
1354
1591
  self.time_edges.setdefault(self.time[source], set()).add(
1355
1592
  (source, target)
1356
1593
  )
@@ -1385,14 +1622,13 @@ class lineageTree:
1385
1622
  self.nodes.add(unique_id)
1386
1623
  self.time[unique_id] = t
1387
1624
  self.node_name[unique_id] = spot[1]
1388
- self.pos[unique_id] = np.array([x, y, z])
1625
+ self.pos[unique_id] = np.array([x, y, z], dtype=float)
1389
1626
 
1390
1627
  for link in links:
1391
1628
  source = int(float(link[4]))
1392
1629
  target = int(float(link[5]))
1393
1630
  self.predecessor.setdefault(target, []).append(source)
1394
1631
  self.successor.setdefault(source, []).append(target)
1395
- self.edges.add((source, target))
1396
1632
  self.time_edges.setdefault(self.time[source], set()).add(
1397
1633
  (source, target)
1398
1634
  )
@@ -1447,23 +1683,24 @@ class lineageTree:
1447
1683
  if attr in self.xml_attributes:
1448
1684
  self.__dict__[attr][cell_id] = eval(cell.attrib[attr])
1449
1685
 
1450
- self.edges = set()
1451
1686
  tracks = {}
1452
1687
  self.successor = {}
1453
1688
  self.predecessor = {}
1454
1689
  self.track_name = {}
1455
1690
  for track in AllTracks:
1456
1691
  if "TRACK_DURATION" in track.attrib:
1457
- t_id, _ = int(track.attrib["TRACK_ID"]), float(
1458
- track.attrib["TRACK_DURATION"]
1692
+ t_id, _ = (
1693
+ int(track.attrib["TRACK_ID"]),
1694
+ float(track.attrib["TRACK_DURATION"]),
1459
1695
  )
1460
1696
  else:
1461
1697
  t_id = int(track.attrib["TRACK_ID"])
1462
1698
  t_name = track.attrib["name"]
1463
1699
  tracks[t_id] = []
1464
1700
  for edge in track:
1465
- s, t = int(edge.attrib["SPOT_SOURCE_ID"]), int(
1466
- edge.attrib["SPOT_TARGET_ID"]
1701
+ s, t = (
1702
+ int(edge.attrib["SPOT_SOURCE_ID"]),
1703
+ int(edge.attrib["SPOT_TARGET_ID"]),
1467
1704
  )
1468
1705
  if s in self.nodes and t in self.nodes:
1469
1706
  if self.time[s] > self.time[t]:
@@ -1473,7 +1710,6 @@ class lineageTree:
1473
1710
  self.track_name[s] = t_name
1474
1711
  self.track_name[t] = t_name
1475
1712
  tracks[t_id].append((s, t))
1476
- self.edges.add((s, t))
1477
1713
  self.t_b = min(self.time_nodes.keys())
1478
1714
  self.t_e = max(self.time_nodes.keys())
1479
1715
 
@@ -1511,7 +1747,7 @@ class lineageTree:
1511
1747
  curr_c = to_treat.pop()
1512
1748
  number_sequence.append(curr_c)
1513
1749
  pos_sequence += list(self.pos[curr_c])
1514
- if self.successor.get(curr_c, []) == []:
1750
+ if self[curr_c] == []:
1515
1751
  number_sequence.append(-1)
1516
1752
  elif len(self.successor[curr_c]) == 1:
1517
1753
  to_treat += self.successor[curr_c]
@@ -1673,7 +1909,6 @@ class lineageTree:
1673
1909
  self.time_edges = time_edges
1674
1910
  self.pos = pos
1675
1911
  self.nodes = set(nodes)
1676
- self.edges = set(edges)
1677
1912
  self.t_b = min(time_nodes.keys())
1678
1913
  self.t_e = max(time_nodes.keys())
1679
1914
  self.is_root = is_root
@@ -1693,7 +1928,7 @@ class lineageTree:
1693
1928
  f.close()
1694
1929
 
1695
1930
  @classmethod
1696
- def load(clf, fname: str):
1931
+ def load(clf, fname: str, rm_empty_lists=True):
1697
1932
  """
1698
1933
  Loading a lineage tree from a ".lT" file.
1699
1934
 
@@ -1706,6 +1941,18 @@ class lineageTree:
1706
1941
  with open(fname, "br") as f:
1707
1942
  lT = pkl.load(f)
1708
1943
  f.close()
1944
+ if rm_empty_lists:
1945
+ if [] in lT.successor.values():
1946
+ for node, succ in lT.successor.items():
1947
+ if succ == []:
1948
+ lT.successor.pop(node)
1949
+ if [] in lT.predecessor.values():
1950
+ for node, succ in lT.predecessor.items():
1951
+ if succ == []:
1952
+ lT.predecessor.pop(node)
1953
+ lT.t_e = max(lT.time_nodes)
1954
+ lT.t_b = min(lT.time_nodes)
1955
+ warnings.warn("Empty lists have been removed")
1709
1956
  return lT
1710
1957
 
1711
1958
  def get_idx3d(self, t: int) -> tuple:
@@ -1787,7 +2034,9 @@ class lineageTree:
1787
2034
 
1788
2035
  return self.Gabriel_graph[t]
1789
2036
 
1790
- def get_predecessors(self, x: int, depth: int = None) -> list:
2037
+ def get_predecessors(
2038
+ self, x: int, depth: int = None, start_time: int = None, end_time=None
2039
+ ) -> list:
1791
2040
  """Computes the predecessors of the node `x` up to
1792
2041
  `depth` predecessors or the begining of the life of `x`.
1793
2042
  The ordered list of ids is returned.
@@ -1798,20 +2047,34 @@ class lineageTree:
1798
2047
  Returns:
1799
2048
  [int, ]: list of ids, the last id is `x`
1800
2049
  """
1801
- cycle = [x]
2050
+ if not start_time:
2051
+ start_time = self.t_b
2052
+ if not end_time:
2053
+ end_time = self.t_e
2054
+ unconstrained_cycle = [x]
2055
+ cycle = [x] if start_time <= self.time[x] <= end_time else []
1802
2056
  acc = 0
1803
2057
  while (
1804
- len(
1805
- self.successor.get(self.predecessor.get(cycle[0], [-1])[0], [])
1806
- )
2058
+ len(self[self.predecessor.get(unconstrained_cycle[0], [-1])[0]])
1807
2059
  == 1
1808
2060
  and acc != depth
2061
+ and start_time
2062
+ <= self.time.get(
2063
+ self.predecessor.get(unconstrained_cycle[0], [-1])[0], -1
2064
+ )
1809
2065
  ):
1810
- cycle.insert(0, self.predecessor[cycle[0]][0])
2066
+ unconstrained_cycle.insert(
2067
+ 0, self.predecessor[unconstrained_cycle[0]][0]
2068
+ )
1811
2069
  acc += 1
2070
+ if start_time <= self.time[unconstrained_cycle[0]] <= end_time:
2071
+ cycle.insert(0, unconstrained_cycle[0])
2072
+
1812
2073
  return cycle
1813
2074
 
1814
- def get_successors(self, x: int, depth: int = None) -> list:
2075
+ def get_successors(
2076
+ self, x: int, depth: int = None, end_time: int = None
2077
+ ) -> list:
1815
2078
  """Computes the successors of the node `x` up to
1816
2079
  `depth` successors or the end of the life of `x`.
1817
2080
  The ordered list of ids is returned.
@@ -1822,11 +2085,18 @@ class lineageTree:
1822
2085
  Returns:
1823
2086
  [int, ]: list of ids, the first id is `x`
1824
2087
  """
2088
+ if end_time is None:
2089
+ end_time = self.t_e
1825
2090
  cycle = [x]
1826
2091
  acc = 0
1827
- while len(self.successor.get(cycle[-1], [])) == 1 and acc != depth:
2092
+ while (
2093
+ len(self[cycle[-1]]) == 1
2094
+ and acc != depth
2095
+ and self.time[cycle[-1]] < end_time
2096
+ ):
1828
2097
  cycle += self.successor[cycle[-1]]
1829
2098
  acc += 1
2099
+
1830
2100
  return cycle
1831
2101
 
1832
2102
  def get_cycle(
@@ -1835,12 +2105,14 @@ class lineageTree:
1835
2105
  depth: int = None,
1836
2106
  depth_pred: int = None,
1837
2107
  depth_succ: int = None,
2108
+ end_time: int = None,
1838
2109
  ) -> list:
1839
2110
  """Computes the predecessors and successors of the node `x` up to
1840
2111
  `depth_pred` predecessors plus `depth_succ` successors.
1841
2112
  If the value `depth` is provided and not None,
1842
2113
  `depth_pred` and `depth_succ` are overwriten by `depth`.
1843
2114
  The ordered list of ids is returned.
2115
+ If all `depth` are None, the full cycle is returned.
1844
2116
 
1845
2117
  Args:
1846
2118
  x (int): id of the node to compute
@@ -1850,11 +2122,13 @@ class lineageTree:
1850
2122
  Returns:
1851
2123
  [int, ]: list of ids
1852
2124
  """
2125
+ if end_time is None:
2126
+ end_time = self.t_e
1853
2127
  if depth is not None:
1854
2128
  depth_pred = depth_succ = depth
1855
- return self.get_predecessors(x, depth_pred)[:-1] + self.get_successors(
1856
- x, depth_succ
1857
- )
2129
+ return self.get_predecessors(x, depth_pred, end_time=end_time)[
2130
+ :-1
2131
+ ] + self.get_successors(x, depth_succ, end_time=end_time)
1858
2132
 
1859
2133
  @property
1860
2134
  def all_tracks(self):
@@ -1862,6 +2136,29 @@ class lineageTree:
1862
2136
  self._all_tracks = self.get_all_tracks()
1863
2137
  return self._all_tracks
1864
2138
 
2139
+ def get_all_branches_of_node(
2140
+ self, node: int, end_time: int = None
2141
+ ) -> list:
2142
+ """Computes all the tracks of the subtree spawn by a given node.
2143
+ Similar to get_all_tracks().
2144
+
2145
+ Args:
2146
+ node (int, optional): The node that we want to get its branches.
2147
+
2148
+ Returns:
2149
+ ([[int, ...], ...]): list of lists containing track cell ids
2150
+ """
2151
+ if not end_time:
2152
+ end_time = self.t_e
2153
+ branches = [self.get_successors(node)]
2154
+ to_do = self[branches[0][-1]].copy()
2155
+ while to_do:
2156
+ current = to_do.pop()
2157
+ track = self.get_cycle(current, end_time=end_time)
2158
+ branches += [track]
2159
+ to_do.extend(self[track[-1]])
2160
+ return branches
2161
+
1865
2162
  def get_all_tracks(self, force_recompute: bool = False) -> list:
1866
2163
  """Computes all the tracks of a given lineage tree,
1867
2164
  stores it in `self.all_tracks` and returns it.
@@ -1869,17 +2166,42 @@ class lineageTree:
1869
2166
  Returns:
1870
2167
  ([[int, ...], ...]): list of lists containing track cell ids
1871
2168
  """
1872
- if not hasattr(self, "_all_tracks"):
2169
+ if not hasattr(self, "_all_tracks") or force_recompute:
1873
2170
  self._all_tracks = []
1874
- to_do = set(self.nodes)
2171
+ to_do = list(self.roots)
1875
2172
  while len(to_do) != 0:
1876
2173
  current = to_do.pop()
1877
2174
  track = self.get_cycle(current)
1878
2175
  self._all_tracks += [track]
1879
- to_do -= set(track)
2176
+ to_do.extend(self[track[-1]])
1880
2177
  return self._all_tracks
1881
2178
 
1882
- def get_sub_tree(self, x: int, preorder: bool = False) -> list:
2179
+ def get_tracks(self, roots: list = None) -> list:
2180
+ """Computes the tracks given by the list of nodes `roots` and returns it.
2181
+
2182
+ Args:
2183
+ roots (list): list of ids of the roots to be computed
2184
+ Returns:
2185
+ ([[int, ...], ...]): list of lists containing track cell ids
2186
+ """
2187
+ if roots is None:
2188
+ return self.get_all_tracks(force_recompute=True)
2189
+ else:
2190
+ tracks = []
2191
+ to_do = list(roots)
2192
+ while len(to_do) != 0:
2193
+ current = to_do.pop()
2194
+ track = self.get_cycle(current)
2195
+ tracks.append(track)
2196
+ to_do.extend(self[track[-1]])
2197
+ return tracks
2198
+
2199
+ def get_sub_tree(
2200
+ self,
2201
+ x: Union[int, Iterable],
2202
+ end_time: Union[int, None] = None,
2203
+ preorder: bool = False,
2204
+ ) -> list:
1883
2205
  """Computes the list of cells from the subtree spawned by *x*
1884
2206
  The default output order is breadth first traversal.
1885
2207
  Unless preorder is `True` in that case the order is
@@ -1891,16 +2213,24 @@ class lineageTree:
1891
2213
  Returns:
1892
2214
  ([int, ...]): the ordered list of node ids
1893
2215
  """
1894
- to_do = [x]
2216
+ if not end_time:
2217
+ end_time = self.t_e
2218
+ if not isinstance(x, Iterable):
2219
+ to_do = [x]
2220
+ elif isinstance(x, Iterable):
2221
+ to_do = list(x)
1895
2222
  sub_tree = []
1896
- while len(to_do) > 0:
1897
- curr = to_do.pop(0)
2223
+ while to_do:
2224
+ curr = to_do.pop()
1898
2225
  succ = self.successor.get(curr, [])
2226
+ if succ and end_time < self.time.get(curr, end_time):
2227
+ succ = []
2228
+ continue
1899
2229
  if preorder:
1900
2230
  to_do = succ + to_do
1901
2231
  else:
1902
2232
  to_do += succ
1903
- sub_tree += [curr]
2233
+ sub_tree += [curr]
1904
2234
  return sub_tree
1905
2235
 
1906
2236
  def compute_spatial_density(
@@ -1971,6 +2301,70 @@ class lineageTree:
1971
2301
  )
1972
2302
  return self.th_edges
1973
2303
 
2304
+ def main_axes(self, time: int = None):
2305
+ """Finds the main axes for a timepoint.
2306
+ If None, selects the timepoint with the highest amount of cells.
2307
+
2308
+ Args:
2309
+ time (int, optional): The timepoint to find the main axes.
2310
+ If None will find the timepoint
2311
+ with the largest number of cells.
2312
+
2313
+ Returns:
2314
+ list: A list that contains the array of eigenvalues and eigenvectors.
2315
+ """
2316
+ if time is None:
2317
+ time = np.argmax(
2318
+ [len(self.time_nodes[t]) for t in range(int(self.t_e))]
2319
+ )
2320
+ pos = np.array([self.pos[node] for node in self.time_nodes[time]])
2321
+ pos = pos - np.mean(pos, axis=0)
2322
+ cov = np.cov(np.array(pos).T)
2323
+ eig_val, eig_vec = np.linalg.eig(cov)
2324
+ srt = np.argsort(eig_val)[::-1]
2325
+ self.eig_val, self.eig_vec = eig_val[srt], eig_vec[:, srt]
2326
+ return eig_val[srt], eig_vec[:, srt]
2327
+
2328
+ def scale_embryo(self, scale=1000):
2329
+ """Scale the embryo using their eigenvalues.
2330
+
2331
+ Args:
2332
+ scale (int, optional): The resulting scale you want to achieve. Defaults to 1000.
2333
+
2334
+ Returns:
2335
+ float: The scale factor.
2336
+ """
2337
+ eig = self.main_axes()[0]
2338
+ return scale / (np.sqrt(eig[0]))
2339
+
2340
+ @staticmethod
2341
+ def __rodrigues_rotation_matrix(vector1, vector2=(0, 1, 0)):
2342
+ """Calculates the rodrigues matrix of a dataset. It should use vectors from the find_main_axes(eigenvectors) function of LineagTree.
2343
+ Uses the Rodrigues rotation formula.
2344
+
2345
+ Args:
2346
+ vector1 (list|np.array): The vector that should be rotated to be aligned to the second vector
2347
+ vector2 (list|np.array, optional): The second vector. Defaults to (0, 1, 0).
2348
+
2349
+ Returns:
2350
+ np.array: The rotation matrix.
2351
+ """
2352
+ vector1 = vector1 / np.linalg.norm(vector1)
2353
+ vector2 = vector2 / np.linalg.norm(vector2)
2354
+ if vector1 @ vector2 == 1:
2355
+ return np.eye(3)
2356
+ angle = np.arccos(vector1 @ vector2)
2357
+ axis = np.cross(vector1, vector2)
2358
+ axis = axis / np.linalg.norm(axis)
2359
+ K = np.array(
2360
+ [
2361
+ [0, -axis[2], axis[1]],
2362
+ [axis[2], 0, -axis[0]],
2363
+ [-axis[1], axis[0], 0],
2364
+ ]
2365
+ )
2366
+ return np.eye(3) + np.sin(angle) * K + (1 - np.cos(angle)) * K @ K
2367
+
1974
2368
  def get_ancestor_at_t(self, n: int, time: int = None):
1975
2369
  """
1976
2370
  Find the id of the ancestor of a give node `n`
@@ -1997,62 +2391,27 @@ class lineageTree:
1997
2391
  ancestor = self.predecessor.get(ancestor, [-1])[0]
1998
2392
  return ancestor
1999
2393
 
2000
- def get_simple_tree(self, r: int, time_resolution: int = 1) -> tuple:
2001
- """
2002
- Get a "simple" version of the tree spawned by the node `r`
2003
- This simple version is just one node per cell (as opposed to
2004
- one node per cell per time-point). The life time duration of
2005
- a cell `c` is stored in `self.cycle_time` and return by this
2006
- function
2394
+ def get_labelled_ancestor(self, node: int):
2395
+ """Finds the first labelled ancestor and returns its ID otherwise returns None
2007
2396
 
2008
2397
  Args:
2009
- r (int): root of the tree to spawn
2010
- time_resolution (float): the time between two consecutive time points
2398
+ node (int): The id of the node
2011
2399
 
2012
2400
  Returns:
2013
- (dict) {m (int): [d1 (int), d2 (int)]}: a adjacency dictionnary
2014
- where the ids are the ids of the cells in the original tree
2015
- at their first time point (except for the cell `r` if it was
2016
- not the first time point).
2017
- (dict) {m (int): duration (float)}: life time duration of the cell `m`
2018
- """
2019
- if not hasattr(self, "cycle_time"):
2020
- self.cycle_time = {}
2021
- out_dict = {}
2022
- to_do = [r]
2023
- while to_do:
2024
- current = to_do.pop()
2025
- cycle = self.get_successors(current)
2026
- _next = self.successor.get(cycle[-1], [])
2027
- if _next:
2028
- out_dict[current] = _next
2029
- to_do.extend(_next)
2030
- self.cycle_time[current] = len(cycle) * time_resolution
2031
- return out_dict, self.cycle_time
2032
-
2033
- @staticmethod
2034
- def __edist_format(adj_dict: dict):
2035
- inv_adj = {vi: k for k, v in adj_dict.items() for vi in v}
2036
- roots = set(adj_dict).difference(inv_adj)
2037
- nid2list = {}
2038
- list2nid = {}
2039
- nodes = []
2040
- adj_list = []
2041
- curr_id = 0
2042
- for r in roots:
2043
- to_do = [r]
2044
- while to_do:
2045
- curr = to_do.pop(0)
2046
- nid2list[curr] = curr_id
2047
- list2nid[curr_id] = curr
2048
- nodes.append(curr_id)
2049
- to_do = adj_dict.get(curr, []) + to_do
2050
- curr_id += 1
2051
- adj_list = [
2052
- [nid2list[d] for d in adj_dict.get(list2nid[_id], [])]
2053
- for _id in nodes
2054
- ]
2055
- return nodes, adj_list, list2nid
2401
+ [None,int]: Returns the first ancestor found that has a label otherwise
2402
+ None.
2403
+ """
2404
+ if node not in self.nodes:
2405
+ return None
2406
+ ancestor = node
2407
+ while (
2408
+ self.t_b <= self.time.get(ancestor, self.t_b - 1)
2409
+ and ancestor != -1
2410
+ ):
2411
+ if ancestor in self.labels:
2412
+ return ancestor
2413
+ ancestor = self.predecessor.get(ancestor, [-1])[0]
2414
+ return
2056
2415
 
2057
2416
  def unordered_tree_edit_distances_at_time_t(
2058
2417
  self,
@@ -2060,9 +2419,10 @@ class lineageTree:
2060
2419
  delta: callable = None,
2061
2420
  norm: callable = None,
2062
2421
  recompute: bool = False,
2422
+ end_time: int = None,
2063
2423
  ) -> dict:
2064
2424
  """
2065
- Compute all the pairwise unordered tree edit distances from Zhang 1996 between the trees spawned at time `t`
2425
+ Compute all the pairwise unordered tree edit distances from Zhang 1996 between the trees spawned at time `t`
2066
2426
 
2067
2427
  Args:
2068
2428
  t (int): time to look at
@@ -2071,6 +2431,8 @@ class lineageTree:
2071
2431
  of the tree spawned by `n1` and the number of nodes
2072
2432
  of the tree spawned by `n2` as arguments.
2073
2433
  recompute (bool): if True, forces to recompute the distances (default: False)
2434
+ end_time (int): The final time point the comparison algorithm will take into account. If None all nodes
2435
+ will be taken into account.
2074
2436
 
2075
2437
  Returns:
2076
2438
  (dict) a dictionary that maps a pair of cell ids at time `t` to their unordered tree edit distance
@@ -2084,14 +2446,20 @@ class lineageTree:
2084
2446
  for n1, n2 in combinations(roots, 2):
2085
2447
  key = tuple(sorted((n1, n2)))
2086
2448
  self.uted[t][key] = self.unordered_tree_edit_distance(
2087
- n1, n2, delta=delta, norm=norm
2449
+ n1, n2, end_time=end_time
2088
2450
  )
2089
2451
  return self.uted[t]
2090
2452
 
2091
2453
  def unordered_tree_edit_distance(
2092
- self, n1: int, n2: int, delta: callable = None, norm: callable = None
2454
+ self,
2455
+ n1: int,
2456
+ n2: int,
2457
+ end_time: int = None,
2458
+ style="fragmented",
2459
+ node_lengths: tuple = (1, 5, 7),
2093
2460
  ) -> float:
2094
2461
  """
2462
+ TODO: Add option for choosing which tree approximation should be used (Full, simple, comp)
2095
2463
  Compute the unordered tree edit distance from Zhang 1996 between the trees spawned
2096
2464
  by two nodes `n1` and `n2`. The topology of the trees are compared and the matching
2097
2465
  cost is given by the function delta (see edist doc for more information).
@@ -2101,48 +2469,178 @@ class lineageTree:
2101
2469
  Args:
2102
2470
  n1 (int): id of the first node to compare
2103
2471
  n2 (int): id of the second node to compare
2104
- delta (callable): comparison function (see edist doc for more information)
2105
- norm (callable): norming function that takes the number of nodes
2106
- of the tree spawned by `n1` and the number of nodes
2107
- of the tree spawned by `n2` as arguments.
2472
+ style ("mini","simple","fragmented","full"): Which tree approximation is going to be used for the comparisons.
2473
+ Defaults to "fragmented".
2108
2474
 
2109
2475
  Returns:
2110
2476
  (float) The normed unordered tree edit distance
2111
2477
  """
2112
2478
 
2113
- from edist.uted import uted
2114
-
2115
- if delta is None or not callable(delta):
2479
+ tree = tree_style[style].value
2480
+ tree1 = tree(
2481
+ lT=self, node_length=node_lengths, end_time=end_time, root=n1
2482
+ )
2483
+ tree2 = tree(
2484
+ lT=self, node_length=node_lengths, end_time=end_time, root=n2
2485
+ )
2486
+ delta = tree1.delta
2487
+ _, times1 = tree1.tree
2488
+ _, times2 = tree2.tree
2489
+ (
2490
+ nodes1,
2491
+ adj1,
2492
+ corres1,
2493
+ ) = tree1.edist
2494
+ (
2495
+ nodes2,
2496
+ adj2,
2497
+ corres2,
2498
+ ) = tree2.edist
2499
+ if len(nodes1) == len(nodes2) == 0:
2500
+ return 0
2501
+ delta_tmp = partial(
2502
+ delta,
2503
+ corres1=corres1,
2504
+ corres2=corres2,
2505
+ times1=times1,
2506
+ times2=times2,
2507
+ )
2116
2508
 
2117
- def delta(x, y, corres1, corres2, times):
2118
- if x is None or y is None:
2119
- return 1
2120
- len_x = times[corres1[x]]
2121
- len_y = times[corres2[y]]
2122
- return np.abs(len_x - len_y) / (len_x + len_y)
2509
+ return uted.uted(nodes1, adj1, nodes2, adj2, delta=delta_tmp) / max(
2510
+ tree1.get_norm(), tree2.get_norm()
2511
+ )
2123
2512
 
2124
- if norm is None or not callable(norm):
2513
+ def to_simple_networkx(
2514
+ self, node: Union[int, list, set, tuple] = None, start_time: int = 0
2515
+ ):
2516
+ """
2517
+ Creates a simple networkx tree graph (every branch is a cell lifetime). This function is to be used for producing nx.graph objects(
2518
+ they can be used for visualization or other tasks),
2519
+ so only the start and the end of a branch are calculated, all cells in between are not taken into account.
2520
+ Args:
2521
+ start_time (int): From which timepoints are the graphs to be calculated.
2522
+ For example if start_time is 10, then all trees that begin
2523
+ on tp 10 or before are calculated.
2524
+ returns:
2525
+ G : list(nx.Digraph(),...)
2526
+ pos : list(dict(id:position))
2527
+ """
2125
2528
 
2126
- def norm(x, y):
2127
- return max(len(x), len(y))
2529
+ if node is None:
2530
+ mothers = [
2531
+ root for root in self.roots if self.time[root] <= start_time
2532
+ ]
2533
+ else:
2534
+ mothers = node if isinstance(node, (list, set)) else [node]
2535
+ graph = {}
2536
+ all_nodes = {}
2537
+ all_edges = {}
2538
+ for mom in mothers:
2539
+ edges = set()
2540
+ nodes = set()
2541
+ for branch in self.get_all_branches_of_node(mom):
2542
+ nodes.update((branch[0], branch[-1]))
2543
+ if len(branch) > 1:
2544
+ edges.add((branch[0], branch[-1]))
2545
+ for suc in self[branch[-1]]:
2546
+ edges.add((branch[-1], suc))
2547
+ all_edges[mom] = edges
2548
+ all_nodes[mom] = nodes
2549
+ for i, mother in enumerate(mothers):
2550
+ graph[i] = nx.DiGraph()
2551
+ graph[i].add_nodes_from(all_nodes[mother])
2552
+ graph[i].add_edges_from(all_edges[mother])
2553
+
2554
+ return graph
2555
+
2556
+ def plot_all_lineages(
2557
+ self,
2558
+ starting_point: int = 0,
2559
+ nrows=2,
2560
+ figsize=(10, 15),
2561
+ dpi=70,
2562
+ fontsize=22,
2563
+ figure=None,
2564
+ axes=None,
2565
+ **kwargs,
2566
+ ):
2567
+ """Plots all lineages.
2128
2568
 
2129
- if norm is False:
2569
+ Args:
2570
+ starting_point (int, optional): Which timepoints and upwards are the graphs to be calculated.
2571
+ For example if start_time is 10, then all trees that begin
2572
+ on tp 10 or before are calculated. Defaults to 0.
2573
+ nrows (int): How many rows of plots should be printed.
2574
+ kwargs: args accepted by networkx
2575
+ """
2130
2576
 
2131
- def norm(*args):
2132
- return 1
2577
+ nrows = int(nrows)
2578
+ if nrows < 1 or not nrows:
2579
+ nrows = 1
2580
+ raise Warning("Number of rows has to be at least 1")
2133
2581
 
2134
- simple_tree_1, _ = self.get_simple_tree(n1)
2135
- simple_tree_2, _ = self.get_simple_tree(n2)
2136
- nodes1, adj1, corres1 = self.__edist_format(simple_tree_1)
2137
- nodes2, adj2, corres2 = self.__edist_format(simple_tree_2)
2138
- if len(nodes1) == len(nodes2) == 0:
2139
- return 0
2140
- delta_tmp = partial(
2141
- delta, corres1=corres1, corres2=corres2, times=self.cycle_time
2582
+ graphs = self.to_simple_networkx(start_time=starting_point)
2583
+ ncols = int(len(graphs) // nrows) + (+np.sign(len(graphs) % nrows))
2584
+ pos = postions_of_nx(self, graphs)
2585
+ figure, axes = plt.subplots(
2586
+ figsize=figsize, nrows=nrows, ncols=ncols, dpi=dpi, sharey=True
2142
2587
  )
2143
- return uted(nodes1, adj1, nodes2, adj2, delta=delta_tmp) / norm(
2144
- nodes1, nodes2
2588
+ flat_axes = axes.flatten()
2589
+ ax2root = {}
2590
+ for i, graph in enumerate(graphs.values()):
2591
+ nx.draw_networkx(
2592
+ graph,
2593
+ pos[i],
2594
+ with_labels=False,
2595
+ arrows=False,
2596
+ **kwargs,
2597
+ ax=flat_axes[i],
2598
+ )
2599
+ root = [n for n, d in graph.in_degree() if d == 0][0]
2600
+ label = self.labels.get(root, "Unlabeled")
2601
+ xlim = flat_axes[i].get_xlim()
2602
+ ylim = flat_axes[i].get_ylim()
2603
+ x_pos = (xlim[1]) / 10
2604
+ y_pos = ylim[0] + 15
2605
+ ax2root[flat_axes[i]] = root
2606
+ flat_axes[i].text(
2607
+ x_pos,
2608
+ y_pos,
2609
+ label,
2610
+ fontsize=fontsize,
2611
+ color="black",
2612
+ ha="center",
2613
+ va="center",
2614
+ bbox={
2615
+ "facecolor": "white",
2616
+ "edgecolor": "green",
2617
+ "boxstyle": "round",
2618
+ },
2619
+ )
2620
+ [figure.delaxes(ax) for ax in axes.flatten() if not ax.has_data()]
2621
+ return figure, axes, ax2root
2622
+
2623
+ def plot_node(self, node, figsize=(4, 7), dpi=150, **kwargs):
2624
+ """Plots the subtree spawn by a node.
2625
+
2626
+ Args:
2627
+ node (int): The id of the node that is going to be plotted.
2628
+ kwargs: args accepted by networkx
2629
+ """
2630
+ graph = self.to_simple_networkx(node)
2631
+ if len(graph) > 1:
2632
+ raise Warning("Please enter only one node")
2633
+ graph = graph[list(graph)[0]]
2634
+ figure, ax = plt.subplots(nrows=1, ncols=1)
2635
+ nx.draw_networkx(
2636
+ graph,
2637
+ hierarchy_pos(graph, self, node),
2638
+ with_labels=False,
2639
+ arrows=False,
2640
+ ax=ax,
2641
+ **kwargs,
2145
2642
  )
2643
+ return figure, ax
2146
2644
 
2147
2645
  # def DTW(self, t1, t2, max_w=None, start_delay=None, end_delay=None,
2148
2646
  # metric='euclidian', **kwargs):
@@ -2223,10 +2721,584 @@ class lineageTree:
2223
2721
  to_do.append(_next)
2224
2722
  elif self.time[_next] == t:
2225
2723
  final_nodes.append(_next)
2724
+ if not final_nodes:
2725
+ return list(r)
2226
2726
  return final_nodes
2227
2727
 
2728
+ @staticmethod
2729
+ def __calculate_diag_line(dist_mat: np.ndarray) -> (float, float):
2730
+ """
2731
+ Calculate the line that centers the band w.
2732
+
2733
+ Args:
2734
+ dist_mat (matrix): distance matrix obtained by the function calculate_dtw
2735
+
2736
+ Returns:
2737
+ (float) Slope
2738
+ (float) intercept of the line
2739
+ """
2740
+ i, j = dist_mat.shape
2741
+ x1 = max(0, i - j) / 2
2742
+ x2 = (i + min(i, j)) / 2
2743
+ y1 = max(0, j - i) / 2
2744
+ y2 = (j + min(i, j)) / 2
2745
+ slope = (y1 - y2) / (x1 - x2)
2746
+ intercept = y1 - slope * x1
2747
+ return slope, intercept
2748
+
2749
+ # Reference: https://github.com/kamperh/lecture_dtw_notebook/blob/main/dtw.ipynb
2750
+ def __dp(
2751
+ self,
2752
+ dist_mat: np.ndarray,
2753
+ start_d: int = 0,
2754
+ back_d: int = 0,
2755
+ fast: bool = False,
2756
+ w: int = 0,
2757
+ centered_band: bool = True,
2758
+ ) -> (((int, int), ...), np.ndarray):
2759
+ """
2760
+ Find DTW minimum cost between two series using dynamic programming.
2761
+
2762
+ Args:
2763
+ dist_mat (matrix): distance matrix obtained by the function calculate_dtw
2764
+ start_d (int): start delay
2765
+ back_d (int): end delay
2766
+ w (int): window constrain
2767
+ fast (bool): if True, only the cells inside the window constraint are computed
2768
+ centered_band (bool): if True, the window is centered on the line given by the function __calculate_diag_line
2769
+ (slope and intercept); if False, the window uses the absolute difference between points (uncentered)
2770
+
2771
+ Returns:
2772
+ (tuple of tuples) Alignment path
2773
+ (matrix) Cost matrix
2774
+ """
2775
+ N, M = dist_mat.shape
2776
+ w_limit = max(w, abs(N - M)) # Calculate the Sakoe-Chiba band width
2777
+
2778
+ if centered_band:
2779
+ slope, intercept = self.__calculate_diag_line(dist_mat)
2780
+ square_root = np.sqrt((slope**2) + 1)
2781
+
2782
+ # Initialize the cost matrix
2783
+ cost_mat = np.full((N + 1, M + 1), np.inf)
2784
+ cost_mat[0, 0] = 0
2785
+
2786
+ # Fill the cost matrix while keeping traceback information
2787
+ traceback_mat = np.zeros((N, M))
2788
+
2789
+ cost_mat[: start_d + 1, 0] = 0
2790
+ cost_mat[0, : start_d + 1] = 0
2791
+
2792
+ cost_mat[N - back_d :, M] = 0
2793
+ cost_mat[N, M - back_d :] = 0
2794
+
2795
+ for i in range(N):
2796
+ for j in range(M):
2797
+ if fast and not centered_band:
2798
+ condition = abs(i - j) <= w_limit
2799
+ elif fast:
2800
+ condition = (
2801
+ abs(slope * i - j + intercept) / square_root <= w_limit
2802
+ )
2803
+ else:
2804
+ condition = True
2805
+
2806
+ if condition:
2807
+ penalty = [
2808
+ cost_mat[i, j], # match (0)
2809
+ cost_mat[i, j + 1], # insertion (1)
2810
+ cost_mat[i + 1, j], # deletion (2)
2811
+ ]
2812
+ i_penalty = np.argmin(penalty)
2813
+ cost_mat[i + 1, j + 1] = (
2814
+ dist_mat[i, j] + penalty[i_penalty]
2815
+ )
2816
+ traceback_mat[i, j] = i_penalty
2817
+
2818
+ min_index1 = np.argmin(cost_mat[N - back_d :, M])
2819
+ min_index2 = np.argmin(cost_mat[N, M - back_d :])
2820
+
2821
+ if (
2822
+ cost_mat[N, M - back_d + min_index2]
2823
+ < cost_mat[N - back_d + min_index1, M]
2824
+ ):
2825
+ i = N - 1
2826
+ j = M - back_d + min_index2 - 1
2827
+ final_cost = cost_mat[i + 1, j + 1]
2828
+ else:
2829
+ i = N - back_d + min_index1 - 1
2830
+ j = M - 1
2831
+ final_cost = cost_mat[i + 1, j + 1]
2832
+
2833
+ path = [(i, j)]
2834
+
2835
+ while (
2836
+ start_d != 0
2837
+ and ((start_d < i and j > 0) or (i > 0 and start_d < j))
2838
+ ) or (start_d == 0 and (i > 0 or j > 0)):
2839
+ tb_type = traceback_mat[i, j]
2840
+ if tb_type == 0:
2841
+ # Match
2842
+ i -= 1
2843
+ j -= 1
2844
+ elif tb_type == 1:
2845
+ # Insertion
2846
+ i -= 1
2847
+ elif tb_type == 2:
2848
+ # Deletion
2849
+ j -= 1
2850
+
2851
+ path.append((i, j))
2852
+
2853
+ # Strip infinity edges from cost_mat before returning
2854
+ cost_mat = cost_mat[1:, 1:]
2855
+ return path[::-1], cost_mat, final_cost
2856
+
2857
+ # Reference: https://github.com/nghiaho12/rigid_transform_3D
2858
+ @staticmethod
2859
+ def __rigid_transform_3D(A, B):
2860
+ assert A.shape == B.shape
2861
+
2862
+ num_rows, num_cols = A.shape
2863
+ if num_rows != 3:
2864
+ raise Exception(
2865
+ f"matrix A is not 3xN, it is {num_rows}x{num_cols}"
2866
+ )
2867
+
2868
+ num_rows, num_cols = B.shape
2869
+ if num_rows != 3:
2870
+ raise Exception(
2871
+ f"matrix B is not 3xN, it is {num_rows}x{num_cols}"
2872
+ )
2873
+
2874
+ # find mean column wise
2875
+ centroid_A = np.mean(A, axis=1)
2876
+ centroid_B = np.mean(B, axis=1)
2877
+
2878
+ # ensure centroids are 3x1
2879
+ centroid_A = centroid_A.reshape(-1, 1)
2880
+ centroid_B = centroid_B.reshape(-1, 1)
2881
+
2882
+ # subtract mean
2883
+ Am = A - centroid_A
2884
+ Bm = B - centroid_B
2885
+
2886
+ H = Am @ np.transpose(Bm)
2887
+
2888
+ # find rotation
2889
+ U, S, Vt = np.linalg.svd(H)
2890
+ R = Vt.T @ U.T
2891
+
2892
+ # special reflection case
2893
+ if np.linalg.det(R) < 0:
2894
+ # print("det(R) < R, reflection detected!, correcting for it ...")
2895
+ Vt[2, :] *= -1
2896
+ R = Vt.T @ U.T
2897
+
2898
+ t = -R @ centroid_A + centroid_B
2899
+
2900
+ return R, t
2901
+
2902
+ def __interpolate(
2903
+ self, track1: list, track2: list, threshold: int
2904
+ ) -> (np.ndarray, np.ndarray):
2905
+ """
2906
+ Interpolate two series that have different lengths
2907
+
2908
+ Args:
2909
+ track1 (list): list of nodes of the first cell cycle to compare
2910
+ track2 (list): list of nodes of the second cell cycle to compare
2911
+ threshold (int): set a maximum number of points a track can have
2912
+
2913
+ Returns:
2914
+ (list of list) x, y, z postions for track1
2915
+ (list of list) x, y, z postions for track2
2916
+ """
2917
+ inter1_pos = []
2918
+ inter2_pos = []
2919
+
2920
+ track1_pos = np.array([self.pos[c_id] for c_id in track1])
2921
+ track2_pos = np.array([self.pos[c_id] for c_id in track2])
2922
+
2923
+ # Both tracks have the same length and size below the threshold - nothing is done
2924
+ if len(track1) == len(track2) and (
2925
+ len(track1) <= threshold or len(track2) <= threshold
2926
+ ):
2927
+ return track1_pos, track2_pos
2928
+ # Both tracks have the same length but one or more sizes are above the threshold
2929
+ elif len(track1) > threshold or len(track2) > threshold:
2930
+ sampling = threshold
2931
+ # Tracks have different lengths and the sizes are below the threshold
2932
+ else:
2933
+ sampling = max(len(track1), len(track2))
2934
+
2935
+ for pos in range(3):
2936
+ track1_interp = InterpolatedUnivariateSpline(
2937
+ np.linspace(0, 1, len(track1_pos[:, pos])),
2938
+ track1_pos[:, pos],
2939
+ k=1,
2940
+ )
2941
+ inter1_pos.append(track1_interp(np.linspace(0, 1, sampling)))
2942
+
2943
+ track2_interp = InterpolatedUnivariateSpline(
2944
+ np.linspace(0, 1, len(track2_pos[:, pos])),
2945
+ track2_pos[:, pos],
2946
+ k=1,
2947
+ )
2948
+ inter2_pos.append(track2_interp(np.linspace(0, 1, sampling)))
2949
+
2950
+ return np.column_stack(inter1_pos), np.column_stack(inter2_pos)
2951
+
2952
+ def calculate_dtw(
2953
+ self,
2954
+ nodes1: int,
2955
+ nodes2: int,
2956
+ threshold: int = 1000,
2957
+ regist: bool = True,
2958
+ start_d: int = 0,
2959
+ back_d: int = 0,
2960
+ fast: bool = False,
2961
+ w: int = 0,
2962
+ centered_band: bool = True,
2963
+ cost_mat_p: bool = False,
2964
+ ) -> (float, tuple, np.ndarray, np.ndarray, np.ndarray):
2965
+ """
2966
+ Calculate DTW distance between two cell cycles
2967
+
2968
+ Args:
2969
+ nodes1 (int): node to compare distance
2970
+ nodes2 (int): node to compare distance
2971
+ threshold: set a maximum number of points a track can have
2972
+ regist (boolean): Rotate and translate trajectories
2973
+ start_d (int): start delay
2974
+ back_d (int): end delay
2975
+ fast (boolean): True if the user wants to run the fast algorithm with window restrains
2976
+ w (int): window size
2977
+ centered_band (boolean): if running the fast algorithm, True if the window is centered
2978
+ cost_mat_p (boolean): True if print the not normalized cost matrix
2979
+
2980
+ Returns:
2981
+ (float) DTW distance
2982
+ (tuple of tuples) Alignment path
2983
+ (matrix) Cost matrix
2984
+ (list of lists) pos_cycle1: rotated and translated trajectories positions
2985
+ (list of lists) pos_cycle2: rotated and translated trajectories positions
2986
+ """
2987
+ nodes1_cycle = self.get_cycle(nodes1)
2988
+ nodes2_cycle = self.get_cycle(nodes2)
2989
+
2990
+ interp_cycle1, interp_cycle2 = self.__interpolate(
2991
+ nodes1_cycle, nodes2_cycle, threshold
2992
+ )
2993
+
2994
+ pos_cycle1 = np.array([self.pos[c_id] for c_id in nodes1_cycle])
2995
+ pos_cycle2 = np.array([self.pos[c_id] for c_id in nodes2_cycle])
2996
+
2997
+ if regist:
2998
+ R, t = self.__rigid_transform_3D(
2999
+ np.transpose(interp_cycle1), np.transpose(interp_cycle2)
3000
+ )
3001
+ pos_cycle1 = np.transpose(np.dot(R, pos_cycle1.T) + t)
3002
+
3003
+ dist_mat = distance.cdist(pos_cycle1, pos_cycle2, "euclidean")
3004
+
3005
+ path, cost_mat, final_cost = self.__dp(
3006
+ dist_mat,
3007
+ start_d,
3008
+ back_d,
3009
+ w=w,
3010
+ fast=fast,
3011
+ centered_band=centered_band,
3012
+ )
3013
+ cost = final_cost / len(path)
3014
+
3015
+ if cost_mat_p:
3016
+ return cost, path, cost_mat, pos_cycle1, pos_cycle2
3017
+ else:
3018
+ return cost, path
3019
+
3020
+ def plot_dtw_heatmap(
3021
+ self,
3022
+ nodes1: int,
3023
+ nodes2: int,
3024
+ threshold: int = 1000,
3025
+ regist: bool = True,
3026
+ start_d: int = 0,
3027
+ back_d: int = 0,
3028
+ fast: bool = False,
3029
+ w: int = 0,
3030
+ centered_band: bool = True,
3031
+ ) -> (float, plt.figure):
3032
+ """
3033
+ Plot DTW cost matrix between two cell cycles in heatmap format
3034
+
3035
+ Args:
3036
+ nodes1 (int): node to compare distance
3037
+ nodes2 (int): node to compare distance
3038
+ start_d (int): start delay
3039
+ back_d (int): end delay
3040
+ fast (boolean): True if the user wants to run the fast algorithm with window restrains
3041
+ w (int): window size
3042
+ centered_band (boolean): if running the fast algorithm, True if the window is centered
3043
+
3044
+ Returns:
3045
+ (float) DTW distance
3046
+ (figure) Heatmap of cost matrix with optimal path
3047
+ """
3048
+ cost, path, cost_mat, pos_cycle1, pos_cycle2 = self.calculate_dtw(
3049
+ nodes1,
3050
+ nodes2,
3051
+ threshold,
3052
+ regist,
3053
+ start_d,
3054
+ back_d,
3055
+ fast,
3056
+ w,
3057
+ centered_band,
3058
+ cost_mat_p=True,
3059
+ )
3060
+
3061
+ fig = plt.figure(figsize=(8, 6))
3062
+ ax = fig.add_subplot(1, 1, 1)
3063
+ im = ax.imshow(
3064
+ cost_mat, cmap="viridis", origin="lower", interpolation="nearest"
3065
+ )
3066
+ plt.colorbar(im)
3067
+ ax.set_title("Heatmap of DTW Cost Matrix")
3068
+ ax.set_xlabel("Tree 1")
3069
+ ax.set_ylabel("tree 2")
3070
+ x_path, y_path = zip(*path)
3071
+ ax.plot(y_path, x_path, color="black")
3072
+
3073
+ return cost, fig
3074
+
3075
+ @staticmethod
3076
+ def __plot_2d(
3077
+ pos_cycle1,
3078
+ pos_cycle2,
3079
+ nodes1,
3080
+ nodes2,
3081
+ ax,
3082
+ x_idx,
3083
+ y_idx,
3084
+ x_label,
3085
+ y_label,
3086
+ ):
3087
+ ax.plot(
3088
+ pos_cycle1[:, x_idx],
3089
+ pos_cycle1[:, y_idx],
3090
+ "-",
3091
+ label=f"root = {nodes1}",
3092
+ )
3093
+ ax.plot(
3094
+ pos_cycle2[:, x_idx],
3095
+ pos_cycle2[:, y_idx],
3096
+ "-",
3097
+ label=f"root = {nodes2}",
3098
+ )
3099
+ ax.set_xlabel(x_label)
3100
+ ax.set_ylabel(y_label)
3101
+
3102
+ def plot_dtw_trajectory(
3103
+ self,
3104
+ nodes1: int,
3105
+ nodes2: int,
3106
+ threshold: int = 1000,
3107
+ regist: bool = True,
3108
+ start_d: int = 0,
3109
+ back_d: int = 0,
3110
+ fast: bool = False,
3111
+ w: int = 0,
3112
+ centered_band: bool = True,
3113
+ projection: str = None,
3114
+ alig: bool = False,
3115
+ ) -> (float, plt.figure):
3116
+ """
3117
+ Plots DTW trajectories alignment between two cell cycles in 2D or 3D
3118
+
3119
+ Args:
3120
+ nodes1 (int): node to compare distance
3121
+ nodes2 (int): node to compare distance
3122
+ threshold (int): set a maximum number of points a track can have
3123
+ regist (boolean): Rotate and translate trajectories
3124
+ start_d (int): start delay
3125
+ back_d (int): end delay
3126
+ w (int): window size
3127
+ fast (boolean): True if the user wants to run the fast algorithm with window restrains
3128
+ centered_band (boolean): if running the fast algorithm, True if the window is centered
3129
+ projection (string): specify which 2D to plot ->
3130
+ '3d' : for the 3d visualization
3131
+ 'xy' or None (default) : 2D projection of axis x and y
3132
+ 'xz' : 2D projection of axis x and z
3133
+ 'yz' : 2D projection of axis y and z
3134
+ 'pca' : PCA projection
3135
+ alig (boolean): True to show alignment on plot
3136
+
3137
+ Returns:
3138
+ (float) DTW distance
3139
+ (figure) Trajectories Plot
3140
+ """
3141
+ (
3142
+ distance,
3143
+ alignment,
3144
+ cost_mat,
3145
+ pos_cycle1,
3146
+ pos_cycle2,
3147
+ ) = self.calculate_dtw(
3148
+ nodes1,
3149
+ nodes2,
3150
+ threshold,
3151
+ regist,
3152
+ start_d,
3153
+ back_d,
3154
+ fast,
3155
+ w,
3156
+ centered_band,
3157
+ cost_mat_p=True,
3158
+ )
3159
+
3160
+ fig = plt.figure(figsize=(10, 6))
3161
+
3162
+ if projection == "3d":
3163
+ ax = fig.add_subplot(1, 1, 1, projection="3d")
3164
+ else:
3165
+ ax = fig.add_subplot(1, 1, 1)
3166
+
3167
+ if projection == "3d":
3168
+ ax.plot(
3169
+ pos_cycle1[:, 0],
3170
+ pos_cycle1[:, 1],
3171
+ pos_cycle1[:, 2],
3172
+ "-",
3173
+ label=f"root = {nodes1}",
3174
+ )
3175
+ ax.plot(
3176
+ pos_cycle2[:, 0],
3177
+ pos_cycle2[:, 1],
3178
+ pos_cycle2[:, 2],
3179
+ "-",
3180
+ label=f"root = {nodes2}",
3181
+ )
3182
+ ax.set_ylabel("y position")
3183
+ ax.set_xlabel("x position")
3184
+ ax.set_zlabel("z position")
3185
+ else:
3186
+ if projection == "xy" or projection == "yx" or projection is None:
3187
+ self.__plot_2d(
3188
+ pos_cycle1,
3189
+ pos_cycle2,
3190
+ nodes1,
3191
+ nodes2,
3192
+ ax,
3193
+ 0,
3194
+ 1,
3195
+ "x position",
3196
+ "y position",
3197
+ )
3198
+ elif projection == "xz" or projection == "zx":
3199
+ self.__plot_2d(
3200
+ pos_cycle1,
3201
+ pos_cycle2,
3202
+ nodes1,
3203
+ nodes2,
3204
+ ax,
3205
+ 0,
3206
+ 2,
3207
+ "x position",
3208
+ "z position",
3209
+ )
3210
+ elif projection == "yz" or projection == "zy":
3211
+ self.__plot_2d(
3212
+ pos_cycle1,
3213
+ pos_cycle2,
3214
+ nodes1,
3215
+ nodes2,
3216
+ ax,
3217
+ 1,
3218
+ 2,
3219
+ "y position",
3220
+ "z position",
3221
+ )
3222
+ elif projection == "pca":
3223
+ try:
3224
+ from sklearn.decomposition import PCA
3225
+ except ImportError:
3226
+ Warning(
3227
+ "scikit-learn is not installed, the PCA orientation cannot be used. You can install scikit-learn with pip install"
3228
+ )
3229
+
3230
+ # Apply PCA
3231
+ pca = PCA(n_components=2)
3232
+ pca.fit(np.vstack([pos_cycle1, pos_cycle2]))
3233
+ pos_cycle1_2d = pca.transform(pos_cycle1)
3234
+ pos_cycle2_2d = pca.transform(pos_cycle2)
3235
+
3236
+ ax.plot(
3237
+ pos_cycle1_2d[:, 0],
3238
+ pos_cycle1_2d[:, 1],
3239
+ "-",
3240
+ label=f"root = {nodes1}",
3241
+ )
3242
+ ax.plot(
3243
+ pos_cycle2_2d[:, 0],
3244
+ pos_cycle2_2d[:, 1],
3245
+ "-",
3246
+ label=f"root = {nodes2}",
3247
+ )
3248
+
3249
+ # Set axis labels
3250
+ axes = ["x", "y", "z"]
3251
+ x_label = axes[np.argmax(np.abs(pca.components_[0]))]
3252
+ y_label = axes[np.argmax(np.abs(pca.components_[1]))]
3253
+ x_percent = 100 * (
3254
+ np.max(np.abs(pca.components_[0]))
3255
+ / np.sum(np.abs(pca.components_[0]))
3256
+ )
3257
+ y_percent = 100 * (
3258
+ np.max(np.abs(pca.components_[1]))
3259
+ / np.sum(np.abs(pca.components_[1]))
3260
+ )
3261
+ ax.set_xlabel(f"{x_percent:.0f}% of {x_label} position")
3262
+ ax.set_ylabel(f"{y_percent:.0f}% of {y_label} position")
3263
+ else:
3264
+ raise ValueError(
3265
+ """Error: available projections are:
3266
+ '3d' : for the 3d visualization
3267
+ 'xy' or None (default) : 2D projection of axis x and y
3268
+ 'xz' : 2D projection of axis x and z
3269
+ 'yz' : 2D projection of axis y and z
3270
+ 'pca' : PCA projection"""
3271
+ )
3272
+
3273
+ connections = [[pos_cycle1[i], pos_cycle2[j]] for i, j in alignment]
3274
+
3275
+ for connection in connections:
3276
+ xyz1 = connection[0]
3277
+ xyz2 = connection[1]
3278
+ x_pos = [xyz1[0], xyz2[0]]
3279
+ y_pos = [xyz1[1], xyz2[1]]
3280
+ z_pos = [xyz1[2], xyz2[2]]
3281
+
3282
+ if alig and projection != "pca":
3283
+ if projection == "3d":
3284
+ ax.plot(x_pos, y_pos, z_pos, "k--", color="grey")
3285
+ else:
3286
+ ax.plot(x_pos, y_pos, "k--", color="grey")
3287
+
3288
+ ax.set_aspect("equal")
3289
+ ax.legend()
3290
+ fig.tight_layout()
3291
+
3292
+ if alig and projection == "pca":
3293
+ warnings.warn(
3294
+ "Error: not possible to show alignment in PCA projection !",
3295
+ UserWarning,
3296
+ )
3297
+
3298
+ return distance, fig
3299
+
2228
3300
  def first_labelling(self):
2229
- self.labels = {i:"Enter_Label" for i in self.time_nodes[0]}
3301
+ self.labels = {i: "Unlabeled" for i in self.time_nodes[0]}
2230
3302
 
2231
3303
  def __init__(
2232
3304
  self,
@@ -2259,12 +3331,12 @@ class lineageTree:
2259
3331
  'TGMM, 'ASTEC', MaMuT', 'TrackMate', 'csv', 'celegans', 'binary'
2260
3332
  default is 'binary'
2261
3333
  """
3334
+ self.name = name
2262
3335
  self.time_nodes = {}
2263
3336
  self.time_edges = {}
2264
3337
  self.max_id = -1
2265
3338
  self.next_id = []
2266
3339
  self.nodes = set()
2267
- self.edges = set()
2268
3340
  self.successor = {}
2269
3341
  self.predecessor = {}
2270
3342
  self.pos = {}
@@ -2272,40 +3344,57 @@ class lineageTree:
2272
3344
  self.time = {}
2273
3345
  self.kdtrees = {}
2274
3346
  self.spatial_density = {}
2275
- self.progeny = {}
2276
- self.labels = {}
2277
- if xml_attributes is None:
2278
- self.xml_attributes = []
2279
- else:
2280
- self.xml_attributes = xml_attributes
2281
- file_type = file_type.lower()
2282
- if file_type == "tgmm":
2283
- self.read_tgmm_xml(file_format, tb, te, z_mult)
2284
- self.t_b = tb
2285
- self.t_e = te
2286
- elif file_type == "mamut" or file_type == "trackmate":
2287
- self.read_from_mamut_xml(file_format)
2288
- elif file_type == "celegans":
2289
- self.read_from_txt_for_celegans(file_format)
2290
- elif file_type == "celegans_cao":
2291
- self.read_from_txt_for_celegans_CAO(
2292
- file_format, reorder=reorder, shape=shape, raw_size=raw_size
2293
- )
2294
- elif file_type == "mastodon":
2295
- if isinstance(file_format, list) and len(file_format) == 2:
2296
- self.read_from_mastodon_csv(file_format)
3347
+ if file_type and file_format:
3348
+ if xml_attributes is None:
3349
+ self.xml_attributes = []
2297
3350
  else:
2298
- if isinstance(file_format, list):
2299
- file_format = file_format[0]
2300
- self.read_from_mastodon(file_format, name)
2301
- elif file_type == "astec":
2302
- self.read_from_ASTEC(file_format, eigen)
2303
- elif file_type == "csv":
2304
- self.read_from_csv(file_format, z_mult, link=1, delim=delim)
2305
- elif file_format and file_format.endswith(".lT"):
2306
- with open(file_format, "br") as f:
2307
- tmp = pkl.load(f)
2308
- f.close()
2309
- self.__dict__.update(tmp.__dict__)
2310
- elif file_format is not None:
2311
- self.read_from_binary(file_format)
3351
+ self.xml_attributes = xml_attributes
3352
+ file_type = file_type.lower()
3353
+ if file_type == "tgmm":
3354
+ self.read_tgmm_xml(file_format, tb, te, z_mult)
3355
+ self.t_b = tb
3356
+ self.t_e = te
3357
+ elif file_type == "mamut" or file_type == "trackmate":
3358
+ self.read_from_mamut_xml(file_format)
3359
+ elif file_type == "celegans":
3360
+ self.read_from_txt_for_celegans(file_format)
3361
+ elif file_type == "celegans_cao":
3362
+ self.read_from_txt_for_celegans_CAO(
3363
+ file_format,
3364
+ reorder=reorder,
3365
+ shape=shape,
3366
+ raw_size=raw_size,
3367
+ )
3368
+ elif file_type == "mastodon":
3369
+ if isinstance(file_format, list) and len(file_format) == 2:
3370
+ self.read_from_mastodon_csv(file_format)
3371
+ else:
3372
+ if isinstance(file_format, list):
3373
+ file_format = file_format[0]
3374
+ self.read_from_mastodon(file_format, name)
3375
+ elif file_type == "astec":
3376
+ self.read_from_ASTEC(file_format, eigen)
3377
+ elif file_type == "csv":
3378
+ self.read_from_csv(file_format, z_mult, link=1, delim=delim)
3379
+ elif file_format and file_format.endswith(".lT"):
3380
+ with open(file_format, "br") as f:
3381
+ tmp = pkl.load(f)
3382
+ f.close()
3383
+ self.__dict__.update(tmp.__dict__)
3384
+ elif file_format is not None:
3385
+ self.read_from_binary(file_format)
3386
+ if self.name is None:
3387
+ try:
3388
+ self.name = Path(file_format).stem
3389
+ except:
3390
+ self.name = Path(file_format[0]).stem
3391
+ if [] in self.successor.values():
3392
+ successors = list(self.successor.keys())
3393
+ for succ in successors:
3394
+ if self[succ] == []:
3395
+ self.successor.pop(succ)
3396
+ if [] in self.predecessor.values():
3397
+ predecessors = list(self.predecessor.keys())
3398
+ for succ in predecessors:
3399
+ if self[succ] == []:
3400
+ self.predecessor.pop(succ)