LineageTree 1.4.4__py3-none-any.whl → 1.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -2,35 +2,244 @@
2
2
  # This file is subject to the terms and conditions defined in
3
3
  # file 'LICENCE', which is part of this source code package.
4
4
  # Author: Leo Guignard (leo.guignard...@AT@...gmail.com)
5
-
6
5
  import csv
7
6
  import os
8
7
  import pickle as pkl
9
8
  import struct
9
+ import warnings
10
10
  import xml.etree.ElementTree as ET
11
+ from collections.abc import Iterable
11
12
  from functools import partial
12
13
  from itertools import combinations
13
14
  from numbers import Number
14
- from typing import TextIO
15
-
15
+ from pathlib import Path
16
+ from typing import TextIO, Union
17
+
18
+ from .tree_styles import tree_style
19
+
20
+ try:
21
+ from edist import uted
22
+ except ImportError:
23
+ warnings.warn(
24
+ "No edist installed therefore you will not be able to compute the tree edit distance."
25
+ )
26
+ import matplotlib.pyplot as plt
27
+ import networkx as nx
16
28
  import numpy as np
17
- from scipy.spatial import Delaunay
29
+ from scipy.interpolate import InterpolatedUnivariateSpline
30
+ from scipy.spatial import Delaunay, distance
18
31
  from scipy.spatial import cKDTree as KDTree
19
32
 
33
+ from .utils import hierarchy_pos, postions_of_nx
34
+
20
35
 
21
36
  class lineageTree:
37
+ def __eq__(self, other):
38
+ if isinstance(other, lineageTree):
39
+ return other.successor == self.successor
40
+ return False
41
+
22
42
  def get_next_id(self):
23
43
  """Computes the next authorized id.
24
44
 
25
45
  Returns:
26
46
  int: next authorized id
27
47
  """
48
+ if self.max_id == -1 and self.nodes:
49
+ self.max_id = max(self.nodes)
28
50
  if self.next_id == []:
29
51
  self.max_id += 1
30
52
  return self.max_id
31
53
  else:
32
54
  return self.next_id.pop()
33
55
 
56
def complete_lineage(self, nodes: Union[int, set] = None):
    """Lengthen the leaf branches of the selected trees, useful for tree
    edit distance algorithms.

    For every leaf found in the subtree of each selected node, a successor
    branch of `self.t_e - self.time[leaf]` nodes is appended with
    `add_branch(..., reverse=True)`.

    Args:
        nodes (int or set, optional): Node(s) whose subtrees should be
            "completed". If None, every root of the dataset is used.
            Defaults to None.
    """
    if nodes is None:
        targets = set(self.roots)
    elif isinstance(nodes, int):
        targets = {nodes}
    else:
        targets = nodes
    for start in targets:
        # The leaves are re-read for each start node because add_branch
        # changes which nodes are leaves.
        subtree_leaves = set(self.get_sub_tree(start)).intersection(
            self.leaves
        )
        for leaf in subtree_leaves:
            missing = self.t_e - self.time[leaf]
            self.add_branch(leaf, missing, reverse=True)
72
+
73
+ ###TODO pos can be callable and stay motionless (copy the position of the succ node, use something like optical flow)
74
def add_branch(
    self,
    pred: int,
    length: int,
    move_timepoints: bool = True,
    pos: Union[callable, None] = None,
    reverse: bool = False,
):
    """Adds a linear branch of `length` new nodes to the node `pred`.

    With `reverse=False` (default) the new nodes are chained on top of
    `pred` as predecessors; with `reverse=True` they are chained below
    `pred` as successors. Every new node copies the position of `pred`.

    Args:
        pred (int): Id of the node the branch is attached to.
        length (int): Number of nodes to add. With 0 nothing is changed.
        move_timepoints (bool): Only used when `reverse` is False.
            If True, the subtree spawned by `pred` is shifted `length`
            time points later and the new nodes fill the freed time points.
            If False, nothing is shifted and the new nodes are placed at
            `time[pred]`, `time[pred] - 1`, ...
        pos (callable, optional): Currently unused. Defaults to None.
        reverse (bool): If True adds a successor branch instead of a
            predecessor branch.

    Returns:
        int: Id of the last node created (top of a predecessor branch,
            tip of a successor branch), or `pred` when `length` is 0.

    Raises:
        Warning: if `pred` already has a predecessor and `reverse` is False.
    """
    if length == 0:
        return pred
    if self.predecessor.get(pred) and not reverse:
        raise Warning("Cannot add 2 predecessors to a node")
    time = self.time[pred]
    original = pred
    if reverse:
        # Successors start one time point after the node they extend;
        # starting at `time` would create a child sharing its parent's
        # time point.
        for t in range(1, length + 1):
            pred = self.add_node(
                time + t, succ=pred, pos=self.pos[original], reverse=False
            )
    elif move_timepoints:
        # Shift the whole subtree down to make room for the new nodes.
        for node in set(self.get_sub_tree(pred)):
            old_time = self.time[node]
            self.time_nodes[old_time].remove(node)
            self.time_nodes.setdefault(old_time + length, set()).add(node)
            self.time[node] = old_time + length
        for t in range(length - 1, -1, -1):
            pred = self.add_node(
                time + t,
                succ=pred,
                pos=self.pos[original],
                reverse=True,
            )
    else:
        # NOTE(review): the first node created here shares the time point
        # of `pred` (time - 0). Kept as is because cut_tree asks for
        # len(cycle) + 1 nodes and appears to rely on it -- confirm.
        for t in range(length):
            pred = self.add_node(
                time - t,
                succ=pred,
                pos=self.pos[original],
                reverse=True,
            )
    self.labels[pred] = "New branch"
    if not reverse:
        # Only a predecessor branch produces a node without predecessor,
        # so only then can the returned node be a root.
        if self.time[pred] == self.t_b:
            self.roots.add(pred)
        if original in self.roots:
            # The new top takes over the root role from the old root.
            self.roots.add(pred)
            self.roots.remove(original)
            self.labels.pop(original, -1)
    self.t_e = max(self.time_nodes)
    return pred
147
+
148
def cut_tree(self, root):
    """Splits a lineage in two at the end of the branch starting at `root`.

    The branch is followed with `get_successors(root)`. The last listed
    successor of its final node is detached and receives a new predecessor
    branch of `len(branch) + 1` nodes (added without moving time points),
    whose top node is registered as a root. The first node of the branch
    is labelled "L-Split ..." and the new root "R-Split ...".

    Args:
        root (int): The id of the node, which will be cut.

    Returns:
        int: The id of the new tree

    Raises:
        Warning: if the last node of the branch has no entry in
            `self.successor`.
    """
    cycle = self.get_successors(root)
    first_cell, last_cell = cycle[0], cycle[-1]
    if last_cell not in self.successor:
        raise Warning("No division of the branch")
    new_lT = self.successor[last_cell].pop()
    self.predecessor.pop(new_lT)
    label_of_root = self.labels.get(first_cell, first_cell)
    self.labels[first_cell] = f"L-Split {label_of_root}"
    new_tr = self.add_branch(new_lT, len(cycle) + 1, move_timepoints=False)
    self.roots.add(new_tr)
    self.labels[new_tr] = f"R-Split {label_of_root}"
    return new_tr
173
+
174
def fuse_lineage_tree(
    self,
    l1_root: int,
    l2_root: int,
    length_l1: int = 0,
    length_l2: int = 0,
    length: int = 1,
):
    """Fuses 2 lineages from the lineagetree object.

    A predecessor branch can first be added on top of each lineage. The
    top node of the first lineage is then removed and its first successor
    is attached to the top node of the second lineage. Finally a branch of
    `length` nodes is added on top of the result.

    Args:
        l1_root (int): Id of the first root
        l2_root (int): Id of the second root
        length_l1 (int, optional): The length of the branch that will be
            added on top of the first lineage. Defaults to 0, which adds
            no node.
        length_l2 (int, optional): The length of the branch that will be
            added on top of the second lineage. Defaults to 0, which adds
            no node.
        length (int, optional): The length of the branch that will be
            added on top of the resulting lineage. Defaults to 1.

    Returns:
        int: The id of the root of the new lineage.

    Raises:
        ValueError: if one of the two nodes has a predecessor.
    """
    if self.predecessor.get(l1_root) or self.predecessor.get(l2_root):
        raise ValueError("Please select 2 roots.")
    if self.time[l1_root] != self.time[l2_root]:
        warnings.warn(
            "Using lineagetrees that do not exist in the same timepoint. The operation will continue"
        )
    new_root1 = self.add_branch(l1_root, length_l1)
    new_root2 = self.add_branch(l2_root, length_l2)
    # Only the first successor of the first lineage's top node is kept;
    # the top node itself is dropped.
    next_root1 = self[new_root1][0]
    self.remove_nodes(new_root1)
    # setdefault: the second root may have no successor entry yet.
    self.successor.setdefault(new_root2, []).append(next_root1)
    self.predecessor[next_root1] = [new_root2]
    new_branch = self.add_branch(new_root2, length)
    self.labels[new_branch] = f"Fusion of {new_root1} and {new_root2}"
    return new_branch
210
+
211
def copy_lineage(self, root):
    """
    Copies the structure of a tree and makes a new with new nodes.
    The predecessor of the root node is not copied: the copy of `root`
    has no predecessor. Positions of the copies are the original
    positions shifted by 0.5.

    Args:
        root (int): The root of the tree to be copied

    Returns:
        int: The root of the new tree.
    """
    new_nodes = {
        old_node: self.get_next_id()
        for old_node in self.get_sub_tree(root)
    }
    self.nodes.update(new_nodes.values())
    for old_node, new_node in new_nodes.items():
        self.time[new_node] = self.time[old_node]
        succ = self.successor.get(old_node)
        if succ:
            self.successor[new_node] = [new_nodes[n] for n in succ]
        # A predecessor outside the copied subtree (the one of `root`)
        # has no copy; looking it up used to raise KeyError.
        copied_pred = [
            new_nodes[n]
            for n in self.predecessor.get(old_node, [])
            if n in new_nodes
        ]
        if copied_pred:
            self.predecessor[new_node] = copied_pred
        self.pos[new_node] = self.pos[old_node] + 0.5
        self.time_nodes[self.time[old_node]].add(new_node)
    new_root = new_nodes[root]
    self.labels[new_root] = f"Copy of {root}"
    # NOTE(review): the copy is only registered as a root at time 0,
    # not at `self.t_b` -- confirm this is intended.
    if self.time[new_root] == 0:
        self.roots.add(new_root)
    return new_root
242
+
34
243
  def add_node(
35
244
  self,
36
245
  t: int = None,
@@ -55,109 +264,126 @@ class lineageTree:
55
264
  int: id of the new node.
56
265
  """
57
266
  C_next = self.get_next_id() if nid is None else nid
58
- self.time_nodes.setdefault(t, []).append(C_next)
267
+ self.time_nodes.setdefault(t, set()).add(C_next)
59
268
  if succ is not None and not reverse:
60
269
  self.successor.setdefault(succ, []).append(C_next)
61
270
  self.predecessor.setdefault(C_next, []).append(succ)
62
- self.edges.add((succ, C_next))
63
271
  elif succ is not None:
64
272
  self.predecessor.setdefault(succ, []).append(C_next)
65
273
  self.successor.setdefault(C_next, []).append(succ)
66
- self.edges.add((C_next, succ))
67
274
  self.nodes.add(C_next)
68
275
  self.pos[C_next] = pos
69
- self.progeny[C_next] = 0
70
276
  self.time[C_next] = t
71
277
  return C_next
72
278
 
73
- def remove_track(self, track: list):
74
- self.nodes.difference_update(track)
75
- times = {self.time[n] for n in track}
76
- for t in times:
77
- self.time_nodes[t] = list(
78
- set(self.time_nodes[t]).difference(track)
79
- )
80
- for i, c in enumerate(track):
81
- self.pos.pop(c)
82
- if i != 0:
83
- self.predecessor.pop(c)
84
- if i < len(track) - 1:
85
- self.successor.pop(c)
86
- self.time.pop(c)
87
-
88
- def remove_node(self, c: int) -> tuple:
89
- """Removes a node and update the lineageTree accordingly
279
+ def remove_nodes(self, group: Union[int, set, list]):
280
+ """Removes a group of nodes from the LineageTree
90
281
 
91
282
  Args:
92
- c (int): id of the node to remove
93
- """
94
- self.nodes.remove(c)
95
- self.time_nodes[self.time[c]].remove(c)
96
- # self.time_nodes.pop(c, 0)
97
- pos = self.pos.pop(c, 0)
98
- e_to_remove = [e for e in self.edges if c in e]
99
- for e in e_to_remove:
100
- self.edges.remove(e)
101
- if c in self.roots:
102
- self.roots.remove(c)
103
- succ = self.successor.pop(c, [])
104
- s_to_remove = [s for s, ci in self.successor.items() if c in ci]
105
- for s in s_to_remove:
106
- self.successor[s].remove(c)
107
-
108
- pred = self.predecessor.pop(c, [])
109
- p_to_remove = [s for s, ci in self.predecessor.items() if ci == c]
110
- for s in p_to_remove:
111
- self.predecessor[s].remove(c)
112
-
113
- self.time.pop(c, 0)
114
- self.spatial_density.pop(c, 0)
115
-
116
- self.next_id.append(c)
117
- return e_to_remove, succ, s_to_remove, pred, p_to_remove, pos
118
-
119
- def fuse_nodes(self, c1: int, c2: int):
120
- """Fuses together two nodes that belong to the same time point
121
- and update the lineageTree accordingly.
283
+ group (set|list|int): One or more nodes that are to be removed.
284
+ """
285
+ if isinstance(group, int):
286
+ group = {group}
287
+ if isinstance(group, list):
288
+ group = set(group)
289
+ group = group.intersection(self.nodes)
290
+ self.nodes.difference_update(group)
291
+ times = {self.time.pop(n) for n in group}
292
+ for t in times:
293
+ self.time_nodes[t] = set(self.time_nodes[t]).difference(group)
294
+ for node in group:
295
+ self.pos.pop(node)
296
+ if self.predecessor.get(node):
297
+ pred = self.predecessor[node][0]
298
+ siblings = self.successor.pop(pred, [])
299
+ if len(siblings) == 2:
300
+ siblings.remove(node)
301
+ self.successor[pred] = siblings
302
+ self.predecessor.pop(node, [])
303
+ for succ in self.successor.get(node, []):
304
+ self.predecessor.pop(succ, [])
305
+ self.successor.pop(node, [])
306
+ self.labels.pop(node, 0)
307
+ if node in self.roots:
308
+ self.roots.remove(node)
309
+
310
def modify_branch(self, node, new_length):
    """Changes the length of a branch, so it adds or removes nodes
    to make the correct length of the cycle.

    Args:
        node (int): Any node of the branch to be modified.
        new_length (int): The new length of the branch, must be more
            than 1 (otherwise a warning is issued and nothing is done).

    Returns:
        None
    """
    if new_length <= 1:
        warnings.warn("New length should be more than 1")
        return None
    cycle = self.get_cycle(node)
    length = len(cycle)
    successors = self.successor.get(cycle[-1])
    if length == 1:
        # The branch already holds `node`, so new_length - 1 nodes are
        # missing. The predecessor is detached first because add_branch
        # refuses nodes that already have one.
        pred = self.predecessor.pop(node, None)
        new_node = self.add_branch(
            node,
            length=new_length - 1,
            move_timepoints=True,
            reverse=False,
        )
        if pred:
            # Re-attach the lengthened branch in both directions.
            self.successor[pred[0]].remove(node)
            self.successor[pred[0]].append(new_node)
            self.predecessor[new_node] = [pred[0]]
    elif self.leaves.intersection(cycle) and new_length < length:
        # Branch ending on a leaf: simply drop its tail.
        self.remove_nodes(cycle[new_length:])
    elif new_length < length:
        to_remove = length - new_length
        last_cell = cycle[new_length - 1]
        # Nodes below the branch, collected before the links are cut.
        subtree = self.get_sub_tree(cycle[-1])[1:]
        self.remove_nodes(cycle[new_length:])
        if successors:
            self.successor[last_cell] = successors
            for succ in successors:
                self.predecessor[succ] = [last_cell]
        kept_tail = set(cycle[new_length - 1 :])
        # Move everything below the shortened branch earlier in time.
        for moved in subtree:
            if moved not in kept_tail:
                old_time = self.time[moved]
                self.time[moved] = old_time - to_remove
                self.time_nodes[old_time].remove(moved)
                self.time_nodes.setdefault(
                    old_time - to_remove, set()
                ).add(moved)
    elif length < new_length:
        to_add = new_length - length
        last_cell = cycle[-1]
        # Cut the link above the last node, insert the new nodes there,
        # then reconnect the top of the new nodes to the branch.
        self.successor.pop(cycle[-2])
        self.predecessor.pop(last_cell)
        succ = self.add_branch(
            last_cell, length=to_add, move_timepoints=True, reverse=False
        )
        self.predecessor[succ] = [cycle[-2]]
        self.successor[cycle[-2]] = [succ]
        self.time[last_cell] = (
            self.time[self.predecessor[last_cell][0]] + 1
        )
    return None
150
366
 
151
367
  @property
152
368
  def roots(self):
153
369
  if not hasattr(self, "_roots"):
154
- self._roots = set(self.successor).difference(self.predecessor)
370
+ self._roots = set(self.nodes).difference(self.predecessor)
155
371
  return self._roots
156
372
 
373
+ @property
374
+ def edges(self):
375
+ return {(k, vi) for k, v in self.successor.items() for vi in v}
376
+
157
377
  @property
158
378
  def leaves(self):
159
379
  return set(self.predecessor).difference(self.successor)
160
380
 
381
+ @property
382
+ def labels(self):
383
+ if not hasattr(self, "_labels"):
384
+ self._labels = {i: "Unlabeled" for i in self.roots}
385
+ return self._labels
386
+
161
387
  def _write_header_am(self, f: TextIO, nb_points: int, length: int):
162
388
  """Header for Amira .am files"""
163
389
  f.write("# AmiraMesh 3D ASCII 2.0\n")
@@ -470,7 +696,7 @@ class lineageTree:
470
696
  stroke=svgwrite.rgb(0, 0, 0),
471
697
  )
472
698
  )
473
- for si in self.successor.get(c_cycle[-1], []):
699
+ for si in self[c_cycle[-1]]:
474
700
  x3, y3 = positions[si]
475
701
  dwg.add(
476
702
  dwg.line(
@@ -483,7 +709,7 @@ class lineageTree:
483
709
  else:
484
710
  for c in treated_cells:
485
711
  x1, y1 = positions[c]
486
- for si in self.successor.get(c, []):
712
+ for si in self[c]:
487
713
  x2, y2 = positions[si]
488
714
  if draw_edges:
489
715
  dwg.add(
@@ -535,7 +761,7 @@ class lineageTree:
535
761
  start_time = times_to_consider[0]
536
762
  for t in times_to_consider:
537
763
  for id_mother in self.time_nodes[t]:
538
- ids_daughters = self.successor.get(id_mother, [])
764
+ ids_daughters = self[id_mother]
539
765
  new_ids_daughters = ids_daughters.copy()
540
766
  for _ in range(sampling - 1):
541
767
  tmp = []
@@ -659,12 +885,12 @@ class lineageTree:
659
885
  edges_to_use += list(s_edges)
660
886
  else:
661
887
  edges_to_use = []
888
+ nodes_to_use = set(nodes_to_use)
662
889
  if temporal:
663
- edges_to_use += [
664
- e
665
- for e in self.edges
666
- if e[0] in nodes_to_use and e[1] in nodes_to_use
667
- ]
890
+ for n in nodes_to_use:
891
+ for d in self.successor.get(n, []):
892
+ if d in nodes_to_use:
893
+ edges_to_use.append((n, d))
668
894
  if spatial:
669
895
  edges_to_use += [
670
896
  e for e in s_edges if t_min < self.time[e[0]] < t_max
@@ -787,7 +1013,6 @@ class lineageTree:
787
1013
  self.time_edges = {}
788
1014
  unique_id = 0
789
1015
  self.nodes = set()
790
- self.edges = set()
791
1016
  self.successor = {}
792
1017
  self.predecessor = {}
793
1018
  self.pos = {}
@@ -826,7 +1051,6 @@ class lineageTree:
826
1051
  M = corres[pred]
827
1052
  self.predecessor[C] = [M]
828
1053
  self.successor.setdefault(M, []).append(C)
829
- self.edges.add((M, C))
830
1054
  self.time_edges.setdefault(t, set()).add((M, C))
831
1055
  self.lin.setdefault(lin_id, []).append(C)
832
1056
  self.C_lin[C] = lin_id
@@ -962,8 +1186,9 @@ class lineageTree:
962
1186
  if "cell_fate" in tmp_data:
963
1187
  self.fates[unique_id] = tmp_data["cell_fate"].get(n, "")
964
1188
  if "cell_barycenter" in tmp_data:
965
- self.pos[unique_id] = tmp_data["cell_barycenter"].get(n, np.zeros(3))
966
-
1189
+ self.pos[unique_id] = tmp_data["cell_barycenter"].get(
1190
+ n, np.zeros(3)
1191
+ )
967
1192
 
968
1193
  unique_id += 1
969
1194
  if do_surf:
@@ -982,9 +1207,7 @@ class lineageTree:
982
1207
  self.successor[new_id] = [
983
1208
  self.pkl2lT[ni] for ni in lt[n] if ni in self.pkl2lT
984
1209
  ]
985
- self.edges.update(
986
- [(new_id, ni) for ni in self.successor[new_id]]
987
- )
1210
+
988
1211
  for ni in self.successor[new_id]:
989
1212
  self.time_edges.setdefault(t - 1, set()).add((new_id, ni))
990
1213
 
@@ -993,31 +1216,43 @@ class lineageTree:
993
1216
  self.max_id = unique_id
994
1217
 
995
1218
  # do this in the end of the process, skip lineage tree and whatever is stored already
996
- pre_treated_prop = [
1219
+ discard = {
997
1220
  "cell_volume",
998
1221
  "cell_fate",
999
1222
  "cell_barycenter",
1000
1223
  "cell_contact_surface",
1001
1224
  "cell_lineage",
1002
- ]
1225
+ "all_cells",
1226
+ "cell_history",
1227
+ "problematic_cells",
1228
+ "cell_labels_in_time",
1229
+ }
1230
+ self.specific_properties = []
1003
1231
  for prop_name, prop_values in tmp_data.items():
1004
- if not (prop_name in pre_treated_prop or hasattr(self, prop_name)):
1232
+ if not (prop_name in discard or hasattr(self, prop_name)):
1005
1233
  if isinstance(prop_values, dict):
1006
1234
  dictionary = {
1007
- self.pkl2lT.get(k, -1): v for k, v in prop_values.items()
1235
+ self.pkl2lT.get(k, -1): v
1236
+ for k, v in prop_values.items()
1008
1237
  }
1009
1238
  # is it a regular dictionary or a dictionary with dictionaries inside?
1010
1239
  for key, value in dictionary.items():
1011
1240
  if isinstance(value, dict):
1012
1241
  # rename all ids from old to new
1013
1242
  dictionary[key] = {
1014
- self.pkl2lT.get(k, -1): v for k, v in value.items()
1243
+ self.pkl2lT.get(k, -1): v
1244
+ for k, v in value.items()
1015
1245
  }
1016
1246
  self.__dict__[prop_name] = dictionary
1247
+ self.specific_properties.append(prop_name)
1017
1248
  # is any of this necessary? Or does it mean it anyways does not contain
1018
1249
  # information about the id and a simple else: is enough?
1019
- elif isinstance(prop_values, (list, set, np.ndarray)):
1250
+ elif (
1251
+ isinstance(prop_values, (list, set, np.ndarray))
1252
+ and prop_name not in []
1253
+ ):
1020
1254
  self.__dict__[prop_name] = prop_values
1255
+ self.specific_properties.append(prop_name)
1021
1256
 
1022
1257
  # what else could it be?
1023
1258
 
@@ -1129,7 +1364,6 @@ class lineageTree:
1129
1364
  p = None
1130
1365
  self.predecessor.setdefault(c, []).append(p)
1131
1366
  self.successor.setdefault(p, []).append(c)
1132
- self.edges.add((p, c))
1133
1367
  self.time_edges.setdefault(t - 1, set()).add((p, c))
1134
1368
  self.max_id = unique_id
1135
1369
 
@@ -1217,7 +1451,6 @@ class lineageTree:
1217
1451
  p = None
1218
1452
  self.predecessor.setdefault(c, []).append(p)
1219
1453
  self.successor.setdefault(p, []).append(c)
1220
- self.edges.add((p, c))
1221
1454
  self.time_edges.setdefault(t - 1, set()).add((p, c))
1222
1455
  self.max_id = unique_id
1223
1456
 
@@ -1241,7 +1474,6 @@ class lineageTree:
1241
1474
  self.time_edges = {}
1242
1475
  unique_id = 0
1243
1476
  self.nodes = set()
1244
- self.edges = set()
1245
1477
  self.successor = {}
1246
1478
  self.predecessor = {}
1247
1479
  self.pos = {}
@@ -1301,7 +1533,6 @@ class lineageTree:
1301
1533
  M = self.time_id[(t - 1, M_id)]
1302
1534
  self.successor.setdefault(M, []).append(C)
1303
1535
  self.predecessor.setdefault(C, []).append(M)
1304
- self.edges.add((M, C))
1305
1536
  self.time_edges[t].add((M, C))
1306
1537
  else:
1307
1538
  if M_id != -1:
@@ -1338,7 +1569,6 @@ class lineageTree:
1338
1569
 
1339
1570
  mr = MastodonReader(path)
1340
1571
  spots, links = mr.read_tables()
1341
- mr.read_tags(spots, links)
1342
1572
 
1343
1573
  self.node_name = {}
1344
1574
 
@@ -1358,7 +1588,6 @@ class lineageTree:
1358
1588
  target = e.target_idx
1359
1589
  self.predecessor.setdefault(target, []).append(source)
1360
1590
  self.successor.setdefault(source, []).append(target)
1361
- self.edges.add((source, target))
1362
1591
  self.time_edges.setdefault(self.time[source], set()).add(
1363
1592
  (source, target)
1364
1593
  )
@@ -1393,14 +1622,13 @@ class lineageTree:
1393
1622
  self.nodes.add(unique_id)
1394
1623
  self.time[unique_id] = t
1395
1624
  self.node_name[unique_id] = spot[1]
1396
- self.pos[unique_id] = np.array([x, y, z])
1625
+ self.pos[unique_id] = np.array([x, y, z], dtype=float)
1397
1626
 
1398
1627
  for link in links:
1399
1628
  source = int(float(link[4]))
1400
1629
  target = int(float(link[5]))
1401
1630
  self.predecessor.setdefault(target, []).append(source)
1402
1631
  self.successor.setdefault(source, []).append(target)
1403
- self.edges.add((source, target))
1404
1632
  self.time_edges.setdefault(self.time[source], set()).add(
1405
1633
  (source, target)
1406
1634
  )
@@ -1455,23 +1683,24 @@ class lineageTree:
1455
1683
  if attr in self.xml_attributes:
1456
1684
  self.__dict__[attr][cell_id] = eval(cell.attrib[attr])
1457
1685
 
1458
- self.edges = set()
1459
1686
  tracks = {}
1460
1687
  self.successor = {}
1461
1688
  self.predecessor = {}
1462
1689
  self.track_name = {}
1463
1690
  for track in AllTracks:
1464
1691
  if "TRACK_DURATION" in track.attrib:
1465
- t_id, _ = int(track.attrib["TRACK_ID"]), float(
1466
- track.attrib["TRACK_DURATION"]
1692
+ t_id, _ = (
1693
+ int(track.attrib["TRACK_ID"]),
1694
+ float(track.attrib["TRACK_DURATION"]),
1467
1695
  )
1468
1696
  else:
1469
1697
  t_id = int(track.attrib["TRACK_ID"])
1470
1698
  t_name = track.attrib["name"]
1471
1699
  tracks[t_id] = []
1472
1700
  for edge in track:
1473
- s, t = int(edge.attrib["SPOT_SOURCE_ID"]), int(
1474
- edge.attrib["SPOT_TARGET_ID"]
1701
+ s, t = (
1702
+ int(edge.attrib["SPOT_SOURCE_ID"]),
1703
+ int(edge.attrib["SPOT_TARGET_ID"]),
1475
1704
  )
1476
1705
  if s in self.nodes and t in self.nodes:
1477
1706
  if self.time[s] > self.time[t]:
@@ -1481,7 +1710,6 @@ class lineageTree:
1481
1710
  self.track_name[s] = t_name
1482
1711
  self.track_name[t] = t_name
1483
1712
  tracks[t_id].append((s, t))
1484
- self.edges.add((s, t))
1485
1713
  self.t_b = min(self.time_nodes.keys())
1486
1714
  self.t_e = max(self.time_nodes.keys())
1487
1715
 
@@ -1519,7 +1747,7 @@ class lineageTree:
1519
1747
  curr_c = to_treat.pop()
1520
1748
  number_sequence.append(curr_c)
1521
1749
  pos_sequence += list(self.pos[curr_c])
1522
- if self.successor.get(curr_c, []) == []:
1750
+ if self[curr_c] == []:
1523
1751
  number_sequence.append(-1)
1524
1752
  elif len(self.successor[curr_c]) == 1:
1525
1753
  to_treat += self.successor[curr_c]
@@ -1681,7 +1909,6 @@ class lineageTree:
1681
1909
  self.time_edges = time_edges
1682
1910
  self.pos = pos
1683
1911
  self.nodes = set(nodes)
1684
- self.edges = set(edges)
1685
1912
  self.t_b = min(time_nodes.keys())
1686
1913
  self.t_e = max(time_nodes.keys())
1687
1914
  self.is_root = is_root
@@ -1701,7 +1928,7 @@ class lineageTree:
1701
1928
  f.close()
1702
1929
 
1703
1930
  @classmethod
1704
- def load(clf, fname: str):
1931
+ def load(clf, fname: str, rm_empty_lists=True):
1705
1932
  """
1706
1933
  Loading a lineage tree from a ".lT" file.
1707
1934
 
@@ -1714,6 +1941,18 @@ class lineageTree:
1714
1941
  with open(fname, "br") as f:
1715
1942
  lT = pkl.load(f)
1716
1943
  f.close()
1944
+ if rm_empty_lists:
1945
+ if [] in lT.successor.values():
1946
+ for node, succ in lT.successor.items():
1947
+ if succ == []:
1948
+ lT.successor.pop(node)
1949
+ if [] in lT.predecessor.values():
1950
+ for node, succ in lT.predecessor.items():
1951
+ if succ == []:
1952
+ lT.predecessor.pop(node)
1953
+ lT.t_e = max(lT.time_nodes)
1954
+ lT.t_b = min(lT.time_nodes)
1955
+ warnings.warn("Empty lists have been removed")
1717
1956
  return lT
1718
1957
 
1719
1958
  def get_idx3d(self, t: int) -> tuple:
@@ -1795,7 +2034,9 @@ class lineageTree:
1795
2034
 
1796
2035
  return self.Gabriel_graph[t]
1797
2036
 
1798
- def get_predecessors(self, x: int, depth: int = None) -> list:
2037
+ def get_predecessors(
2038
+ self, x: int, depth: int = None, start_time: int = None, end_time=None
2039
+ ) -> list:
1799
2040
  """Computes the predecessors of the node `x` up to
1800
2041
  `depth` predecessors or the begining of the life of `x`.
1801
2042
  The ordered list of ids is returned.
@@ -1806,20 +2047,34 @@ class lineageTree:
1806
2047
  Returns:
1807
2048
  [int, ]: list of ids, the last id is `x`
1808
2049
  """
1809
- cycle = [x]
2050
+ if not start_time:
2051
+ start_time = self.t_b
2052
+ if not end_time:
2053
+ end_time = self.t_e
2054
+ unconstrained_cycle = [x]
2055
+ cycle = [x] if start_time <= self.time[x] <= end_time else []
1810
2056
  acc = 0
1811
2057
  while (
1812
- len(
1813
- self.successor.get(self.predecessor.get(cycle[0], [-1])[0], [])
1814
- )
2058
+ len(self[self.predecessor.get(unconstrained_cycle[0], [-1])[0]])
1815
2059
  == 1
1816
2060
  and acc != depth
2061
+ and start_time
2062
+ <= self.time.get(
2063
+ self.predecessor.get(unconstrained_cycle[0], [-1])[0], -1
2064
+ )
1817
2065
  ):
1818
- cycle.insert(0, self.predecessor[cycle[0]][0])
2066
+ unconstrained_cycle.insert(
2067
+ 0, self.predecessor[unconstrained_cycle[0]][0]
2068
+ )
1819
2069
  acc += 1
2070
+ if start_time <= self.time[unconstrained_cycle[0]] <= end_time:
2071
+ cycle.insert(0, unconstrained_cycle[0])
2072
+
1820
2073
  return cycle
1821
2074
 
1822
- def get_successors(self, x: int, depth: int = None) -> list:
2075
+ def get_successors(
2076
+ self, x: int, depth: int = None, end_time: int = None
2077
+ ) -> list:
1823
2078
  """Computes the successors of the node `x` up to
1824
2079
  `depth` successors or the end of the life of `x`.
1825
2080
  The ordered list of ids is returned.
@@ -1830,11 +2085,18 @@ class lineageTree:
1830
2085
  Returns:
1831
2086
  [int, ]: list of ids, the first id is `x`
1832
2087
  """
2088
+ if end_time is None:
2089
+ end_time = self.t_e
1833
2090
  cycle = [x]
1834
2091
  acc = 0
1835
- while len(self.successor.get(cycle[-1], [])) == 1 and acc != depth:
2092
+ while (
2093
+ len(self[cycle[-1]]) == 1
2094
+ and acc != depth
2095
+ and self.time[cycle[-1]] < end_time
2096
+ ):
1836
2097
  cycle += self.successor[cycle[-1]]
1837
2098
  acc += 1
2099
+
1838
2100
  return cycle
1839
2101
 
1840
2102
  def get_cycle(
@@ -1843,12 +2105,14 @@ class lineageTree:
1843
2105
  depth: int = None,
1844
2106
  depth_pred: int = None,
1845
2107
  depth_succ: int = None,
2108
+ end_time: int = None,
1846
2109
  ) -> list:
1847
2110
  """Computes the predecessors and successors of the node `x` up to
1848
2111
  `depth_pred` predecessors plus `depth_succ` successors.
1849
2112
  If the value `depth` is provided and not None,
1850
2113
  `depth_pred` and `depth_succ` are overwriten by `depth`.
1851
2114
  The ordered list of ids is returned.
2115
+ If all `depth` are None, the full cycle is returned.
1852
2116
 
1853
2117
  Args:
1854
2118
  x (int): id of the node to compute
@@ -1858,11 +2122,13 @@ class lineageTree:
1858
2122
  Returns:
1859
2123
  [int, ]: list of ids
1860
2124
  """
2125
+ if end_time is None:
2126
+ end_time = self.t_e
1861
2127
  if depth is not None:
1862
2128
  depth_pred = depth_succ = depth
1863
- return self.get_predecessors(x, depth_pred)[:-1] + self.get_successors(
1864
- x, depth_succ
1865
- )
2129
+ return self.get_predecessors(x, depth_pred, end_time=end_time)[
2130
+ :-1
2131
+ ] + self.get_successors(x, depth_succ, end_time=end_time)
1866
2132
 
1867
2133
  @property
1868
2134
  def all_tracks(self):
@@ -1870,6 +2136,29 @@ class lineageTree:
1870
2136
  self._all_tracks = self.get_all_tracks()
1871
2137
  return self._all_tracks
1872
2138
 
2139
def get_all_branches_of_node(
    self, node: int, end_time: int = None
) -> list:
    """Computes all the tracks of the subtree spawn by a given node.
    Similar to get_all_tracks().

    Args:
        node (int): The node that we want to get its branches.
        end_time (int, optional): Last time point to consider, applied to
            every branch including the first one. Branches starting after
            it are left out. Defaults to None, meaning `self.t_e`.

    Returns:
        ([[int, ...], ...]): list of lists containing track cell ids
    """
    # `is None` so that an explicit end_time of 0 is respected.
    if end_time is None:
        end_time = self.t_e
    branches = [self.get_successors(node, end_time=end_time)]
    to_do = list(self[branches[0][-1]])
    while to_do:
        current = to_do.pop()
        if self.time[current] > end_time:
            # Reached through a branch that was cut at end_time.
            continue
        track = self.get_cycle(current, end_time=end_time)
        branches.append(track)
        to_do.extend(self[track[-1]])
    return branches
2161
+
1873
2162
  def get_all_tracks(self, force_recompute: bool = False) -> list:
1874
2163
  """Computes all the tracks of a given lineage tree,
1875
2164
  stores it in `self.all_tracks` and returns it.
@@ -1877,17 +2166,42 @@ class lineageTree:
1877
2166
  Returns:
1878
2167
  ([[int, ...], ...]): list of lists containing track cell ids
1879
2168
  """
1880
- if not hasattr(self, "_all_tracks"):
2169
+ if not hasattr(self, "_all_tracks") or force_recompute:
1881
2170
  self._all_tracks = []
1882
- to_do = set(self.nodes)
2171
+ to_do = list(self.roots)
1883
2172
  while len(to_do) != 0:
1884
2173
  current = to_do.pop()
1885
2174
  track = self.get_cycle(current)
1886
2175
  self._all_tracks += [track]
1887
- to_do -= set(track)
2176
+ to_do.extend(self[track[-1]])
1888
2177
  return self._all_tracks
1889
2178
 
1890
- def get_sub_tree(self, x: int, preorder: bool = False) -> list:
2179
+ def get_tracks(self, roots: list = None) -> list:
2180
+ """Computes the tracks given by the list of nodes `roots` and returns it.
2181
+
2182
+ Args:
2183
+ roots (list): list of ids of the roots to be computed
2184
+ Returns:
2185
+ ([[int, ...], ...]): list of lists containing track cell ids
2186
+ """
2187
+ if roots is None:
2188
+ return self.get_all_tracks(force_recompute=True)
2189
+ else:
2190
+ tracks = []
2191
+ to_do = list(roots)
2192
+ while len(to_do) != 0:
2193
+ current = to_do.pop()
2194
+ track = self.get_cycle(current)
2195
+ tracks.append(track)
2196
+ to_do.extend(self[track[-1]])
2197
+ return tracks
2198
+
2199
+ def get_sub_tree(
2200
+ self,
2201
+ x: Union[int, Iterable],
2202
+ end_time: Union[int, None] = None,
2203
+ preorder: bool = False,
2204
+ ) -> list:
1891
2205
  """Computes the list of cells from the subtree spawned by *x*
1892
2206
  The default output order is breadth first traversal.
1893
2207
  Unless preorder is `True` in that case the order is
@@ -1899,16 +2213,24 @@ class lineageTree:
1899
2213
  Returns:
1900
2214
  ([int, ...]): the ordered list of node ids
1901
2215
  """
1902
- to_do = [x]
2216
+ if not end_time:
2217
+ end_time = self.t_e
2218
+ if not isinstance(x, Iterable):
2219
+ to_do = [x]
2220
+ elif isinstance(x, Iterable):
2221
+ to_do = list(x)
1903
2222
  sub_tree = []
1904
- while len(to_do) > 0:
1905
- curr = to_do.pop(0)
2223
+ while to_do:
2224
+ curr = to_do.pop()
1906
2225
  succ = self.successor.get(curr, [])
2226
+ if succ and end_time < self.time.get(curr, end_time):
2227
+ succ = []
2228
+ continue
1907
2229
  if preorder:
1908
2230
  to_do = succ + to_do
1909
2231
  else:
1910
2232
  to_do += succ
1911
- sub_tree += [curr]
2233
+ sub_tree += [curr]
1912
2234
  return sub_tree
1913
2235
 
1914
2236
  def compute_spatial_density(
@@ -1979,6 +2301,70 @@ class lineageTree:
1979
2301
  )
1980
2302
  return self.th_edges
1981
2303
 
2304
+ def main_axes(self, time: int = None):
2305
+ """Finds the main axes for a timepoint.
2306
+ If None, selects the timepoint with the highest number of cells.
2307
+
2308
+ Args:
2309
+ time (int, optional): The timepoint to find the main axes.
2310
+ If None will find the timepoint
2311
+ with the largest number of cells.
2312
+
2313
+ Returns:
2314
+ list: A list that contains the array of eigenvalues and eigenvectors.
2315
+ """
2316
+ if time is None:
2317
+ time = np.argmax(
2318
+ [len(self.time_nodes[t]) for t in range(int(self.t_e))]
2319
+ )
2320
+ pos = np.array([self.pos[node] for node in self.time_nodes[time]])
2321
+ pos = pos - np.mean(pos, axis=0)
2322
+ cov = np.cov(np.array(pos).T)
2323
+ eig_val, eig_vec = np.linalg.eig(cov)
2324
+ srt = np.argsort(eig_val)[::-1]
2325
+ self.eig_val, self.eig_vec = eig_val[srt], eig_vec[:, srt]
2326
+ return eig_val[srt], eig_vec[:, srt]
2327
+
2328
+ def scale_embryo(self, scale=1000):
2329
+ """Scale the embryo using their eigenvalues.
2330
+
2331
+ Args:
2332
+ scale (int, optional): The resulting scale you want to achieve. Defaults to 1000.
2333
+
2334
+ Returns:
2335
+ float: The scale factor.
2336
+ """
2337
+ eig = self.main_axes()[0]
2338
+ return scale / (np.sqrt(eig[0]))
2339
+
2340
+ @staticmethod
2341
+ def __rodrigues_rotation_matrix(vector1, vector2=(0, 1, 0)):
2342
+ """Calculates the rodrigues matrix of a dataset. It should use vectors from the find_main_axes(eigenvectors) function of LineagTree.
2343
+ Uses the Rodrigues rotation formula.
2344
+
2345
+ Args:
2346
+ vector1 (list|np.array): The vector that should be rotated to be aligned to the second vector
2347
+ vector2 (list|np.array, optional): The second vector. Defaults to (0, 1, 0).
2348
+
2349
+ Returns:
2350
+ np.array: The rotation matrix.
2351
+ """
2352
+ vector1 = vector1 / np.linalg.norm(vector1)
2353
+ vector2 = vector2 / np.linalg.norm(vector2)
2354
+ if vector1 @ vector2 == 1:
2355
+ return np.eye(3)
2356
+ angle = np.arccos(vector1 @ vector2)
2357
+ axis = np.cross(vector1, vector2)
2358
+ axis = axis / np.linalg.norm(axis)
2359
+ K = np.array(
2360
+ [
2361
+ [0, -axis[2], axis[1]],
2362
+ [axis[2], 0, -axis[0]],
2363
+ [-axis[1], axis[0], 0],
2364
+ ]
2365
+ )
2366
+ return np.eye(3) + np.sin(angle) * K + (1 - np.cos(angle)) * K @ K
2367
+
1982
2368
  def get_ancestor_at_t(self, n: int, time: int = None):
1983
2369
  """
1984
2370
  Find the id of the ancestor of a give node `n`
@@ -2005,62 +2391,27 @@ class lineageTree:
2005
2391
  ancestor = self.predecessor.get(ancestor, [-1])[0]
2006
2392
  return ancestor
2007
2393
 
2008
- def get_simple_tree(self, r: int, time_resolution: int = 1) -> tuple:
2009
- """
2010
- Get a "simple" version of the tree spawned by the node `r`
2011
- This simple version is just one node per cell (as opposed to
2012
- one node per cell per time-point). The life time duration of
2013
- a cell `c` is stored in `self.cycle_time` and return by this
2014
- function
2394
+ def get_labelled_ancestor(self, node: int):
2395
+ """Finds the first labelled ancestor and returns its ID otherwise returns None
2015
2396
 
2016
2397
  Args:
2017
- r (int): root of the tree to spawn
2018
- time_resolution (float): the time between two consecutive time points
2398
+ node (int): The id of the node
2019
2399
 
2020
2400
  Returns:
2021
- (dict) {m (int): [d1 (int), d2 (int)]}: a adjacency dictionnary
2022
- where the ids are the ids of the cells in the original tree
2023
- at their first time point (except for the cell `r` if it was
2024
- not the first time point).
2025
- (dict) {m (int): duration (float)}: life time duration of the cell `m`
2026
- """
2027
- if not hasattr(self, "cycle_time"):
2028
- self.cycle_time = {}
2029
- out_dict = {}
2030
- to_do = [r]
2031
- while to_do:
2032
- current = to_do.pop()
2033
- cycle = self.get_successors(current)
2034
- _next = self.successor.get(cycle[-1], [])
2035
- if _next:
2036
- out_dict[current] = _next
2037
- to_do.extend(_next)
2038
- self.cycle_time[current] = len(cycle) * time_resolution
2039
- return out_dict, self.cycle_time
2040
-
2041
- @staticmethod
2042
- def __edist_format(adj_dict: dict):
2043
- inv_adj = {vi: k for k, v in adj_dict.items() for vi in v}
2044
- roots = set(adj_dict).difference(inv_adj)
2045
- nid2list = {}
2046
- list2nid = {}
2047
- nodes = []
2048
- adj_list = []
2049
- curr_id = 0
2050
- for r in roots:
2051
- to_do = [r]
2052
- while to_do:
2053
- curr = to_do.pop(0)
2054
- nid2list[curr] = curr_id
2055
- list2nid[curr_id] = curr
2056
- nodes.append(curr_id)
2057
- to_do = adj_dict.get(curr, []) + to_do
2058
- curr_id += 1
2059
- adj_list = [
2060
- [nid2list[d] for d in adj_dict.get(list2nid[_id], [])]
2061
- for _id in nodes
2062
- ]
2063
- return nodes, adj_list, list2nid
2401
+ [None,int]: Returns the first ancestor found that has a label otherwise
2402
+ None.
2403
+ """
2404
+ if node not in self.nodes:
2405
+ return None
2406
+ ancestor = node
2407
+ while (
2408
+ self.t_b <= self.time.get(ancestor, self.t_b - 1)
2409
+ and ancestor != -1
2410
+ ):
2411
+ if ancestor in self.labels:
2412
+ return ancestor
2413
+ ancestor = self.predecessor.get(ancestor, [-1])[0]
2414
+ return
2064
2415
 
2065
2416
  def unordered_tree_edit_distances_at_time_t(
2066
2417
  self,
@@ -2068,6 +2419,7 @@ class lineageTree:
2068
2419
  delta: callable = None,
2069
2420
  norm: callable = None,
2070
2421
  recompute: bool = False,
2422
+ end_time: int = None,
2071
2423
  ) -> dict:
2072
2424
  """
2073
2425
  Compute all the pairwise unordered tree edit distances from Zhang 1996 between the trees spawned at time `t`
@@ -2079,6 +2431,8 @@ class lineageTree:
2079
2431
  of the tree spawned by `n1` and the number of nodes
2080
2432
  of the tree spawned by `n2` as arguments.
2081
2433
  recompute (bool): if True, forces to recompute the distances (default: False)
2434
+ end_time (int): The final time point the comparison algorithm will take into account. If None all nodes
2435
+ will be taken into account.
2082
2436
 
2083
2437
  Returns:
2084
2438
  (dict) a dictionary that maps a pair of cell ids at time `t` to their unordered tree edit distance
@@ -2092,14 +2446,20 @@ class lineageTree:
2092
2446
  for n1, n2 in combinations(roots, 2):
2093
2447
  key = tuple(sorted((n1, n2)))
2094
2448
  self.uted[t][key] = self.unordered_tree_edit_distance(
2095
- n1, n2, delta=delta, norm=norm
2449
+ n1, n2, end_time=end_time
2096
2450
  )
2097
2451
  return self.uted[t]
2098
2452
 
2099
2453
  def unordered_tree_edit_distance(
2100
- self, n1: int, n2: int, delta: callable = None, norm: callable = None
2454
+ self,
2455
+ n1: int,
2456
+ n2: int,
2457
+ end_time: int = None,
2458
+ style="fragmented",
2459
+ node_lengths: tuple = (1, 5, 7),
2101
2460
  ) -> float:
2102
2461
  """
2462
+ TODO: Add option for choosing which tree approximation should be used (Full, simple, comp)
2103
2463
  Compute the unordered tree edit distance from Zhang 1996 between the trees spawned
2104
2464
  by two nodes `n1` and `n2`. The topology of the trees are compared and the matching
2105
2465
  cost is given by the function delta (see edist doc for more information).
@@ -2109,48 +2469,178 @@ class lineageTree:
2109
2469
  Args:
2110
2470
  n1 (int): id of the first node to compare
2111
2471
  n2 (int): id of the second node to compare
2112
- delta (callable): comparison function (see edist doc for more information)
2113
- norm (callable): norming function that takes the number of nodes
2114
- of the tree spawned by `n1` and the number of nodes
2115
- of the tree spawned by `n2` as arguments.
2472
+ style ("mini","simple","fragmented","full"): Which tree approximation is going to be used for the comparisons.
2473
+ Defaults to "fragmented".
2116
2474
 
2117
2475
  Returns:
2118
2476
  (float) The normed unordered tree edit distance
2119
2477
  """
2120
2478
 
2121
- from edist.uted import uted
2122
-
2123
- if delta is None or not callable(delta):
2479
+ tree = tree_style[style].value
2480
+ tree1 = tree(
2481
+ lT=self, node_length=node_lengths, end_time=end_time, root=n1
2482
+ )
2483
+ tree2 = tree(
2484
+ lT=self, node_length=node_lengths, end_time=end_time, root=n2
2485
+ )
2486
+ delta = tree1.delta
2487
+ _, times1 = tree1.tree
2488
+ _, times2 = tree2.tree
2489
+ (
2490
+ nodes1,
2491
+ adj1,
2492
+ corres1,
2493
+ ) = tree1.edist
2494
+ (
2495
+ nodes2,
2496
+ adj2,
2497
+ corres2,
2498
+ ) = tree2.edist
2499
+ if len(nodes1) == len(nodes2) == 0:
2500
+ return 0
2501
+ delta_tmp = partial(
2502
+ delta,
2503
+ corres1=corres1,
2504
+ corres2=corres2,
2505
+ times1=times1,
2506
+ times2=times2,
2507
+ )
2124
2508
 
2125
- def delta(x, y, corres1, corres2, times):
2126
- if x is None or y is None:
2127
- return 1
2128
- len_x = times[corres1[x]]
2129
- len_y = times[corres2[y]]
2130
- return np.abs(len_x - len_y) / (len_x + len_y)
2509
+ return uted.uted(nodes1, adj1, nodes2, adj2, delta=delta_tmp) / max(
2510
+ tree1.get_norm(), tree2.get_norm()
2511
+ )
2131
2512
 
2132
- if norm is None or not callable(norm):
2513
+ def to_simple_networkx(
2514
+ self, node: Union[int, list, set, tuple] = None, start_time: int = 0
2515
+ ):
2516
+ """
2517
+ Creates a simple networkx tree graph (every branch is a cell lifetime). This function is to be used for producing nx.graph objects(
2518
+ they can be used for visualization or other tasks),
2519
+ so only the start and the end of a branch are calculated, all cells in between are not taken into account.
2520
+ Args:
2521
+ start_time (int): From which timepoints are the graphs to be calculated.
2522
+ For example if start_time is 10, then all trees that begin
2523
+ on tp 10 or before are calculated.
2524
+ returns:
2525
+ G : list(nx.Digraph(),...)
2526
+ pos : list(dict(id:position))
2527
+ """
2133
2528
 
2134
- def norm(x, y):
2135
- return max(len(x), len(y))
2529
+ if node is None:
2530
+ mothers = [
2531
+ root for root in self.roots if self.time[root] <= start_time
2532
+ ]
2533
+ else:
2534
+ mothers = node if isinstance(node, (list, set)) else [node]
2535
+ graph = {}
2536
+ all_nodes = {}
2537
+ all_edges = {}
2538
+ for mom in mothers:
2539
+ edges = set()
2540
+ nodes = set()
2541
+ for branch in self.get_all_branches_of_node(mom):
2542
+ nodes.update((branch[0], branch[-1]))
2543
+ if len(branch) > 1:
2544
+ edges.add((branch[0], branch[-1]))
2545
+ for suc in self[branch[-1]]:
2546
+ edges.add((branch[-1], suc))
2547
+ all_edges[mom] = edges
2548
+ all_nodes[mom] = nodes
2549
+ for i, mother in enumerate(mothers):
2550
+ graph[i] = nx.DiGraph()
2551
+ graph[i].add_nodes_from(all_nodes[mother])
2552
+ graph[i].add_edges_from(all_edges[mother])
2553
+
2554
+ return graph
2555
+
2556
+ def plot_all_lineages(
2557
+ self,
2558
+ starting_point: int = 0,
2559
+ nrows=2,
2560
+ figsize=(10, 15),
2561
+ dpi=70,
2562
+ fontsize=22,
2563
+ figure=None,
2564
+ axes=None,
2565
+ **kwargs,
2566
+ ):
2567
+ """Plots all lineages.
2136
2568
 
2137
- if norm is False:
2569
+ Args:
2570
+ starting_point (int, optional): Which timepoints and upwards are the graphs to be calculated.
2571
+ For example if starting_point is 10, then all trees that begin
2572
+ on tp 10 or before are calculated. Defaults to 0.
2573
+ nrows (int): How many rows of plots should be printed.
2574
+ kwargs: args accepted by networkx
2575
+ """
2138
2576
 
2139
- def norm(*args):
2140
- return 1
2577
+ nrows = int(nrows)
2578
+ if nrows < 1 or not nrows:
2579
+ nrows = 1
2580
+ raise Warning("Number of rows has to be at least 1")
2141
2581
 
2142
- simple_tree_1, _ = self.get_simple_tree(n1)
2143
- simple_tree_2, _ = self.get_simple_tree(n2)
2144
- nodes1, adj1, corres1 = self.__edist_format(simple_tree_1)
2145
- nodes2, adj2, corres2 = self.__edist_format(simple_tree_2)
2146
- if len(nodes1) == len(nodes2) == 0:
2147
- return 0
2148
- delta_tmp = partial(
2149
- delta, corres1=corres1, corres2=corres2, times=self.cycle_time
2582
+ graphs = self.to_simple_networkx(start_time=starting_point)
2583
+ ncols = int(len(graphs) // nrows) + (+np.sign(len(graphs) % nrows))
2584
+ pos = postions_of_nx(self, graphs)
2585
+ figure, axes = plt.subplots(
2586
+ figsize=figsize, nrows=nrows, ncols=ncols, dpi=dpi, sharey=True
2150
2587
  )
2151
- return uted(nodes1, adj1, nodes2, adj2, delta=delta_tmp) / norm(
2152
- nodes1, nodes2
2588
+ flat_axes = axes.flatten()
2589
+ ax2root = {}
2590
+ for i, graph in enumerate(graphs.values()):
2591
+ nx.draw_networkx(
2592
+ graph,
2593
+ pos[i],
2594
+ with_labels=False,
2595
+ arrows=False,
2596
+ **kwargs,
2597
+ ax=flat_axes[i],
2598
+ )
2599
+ root = [n for n, d in graph.in_degree() if d == 0][0]
2600
+ label = self.labels.get(root, "Unlabeled")
2601
+ xlim = flat_axes[i].get_xlim()
2602
+ ylim = flat_axes[i].get_ylim()
2603
+ x_pos = (xlim[1]) / 10
2604
+ y_pos = ylim[0] + 15
2605
+ ax2root[flat_axes[i]] = root
2606
+ flat_axes[i].text(
2607
+ x_pos,
2608
+ y_pos,
2609
+ label,
2610
+ fontsize=fontsize,
2611
+ color="black",
2612
+ ha="center",
2613
+ va="center",
2614
+ bbox={
2615
+ "facecolor": "white",
2616
+ "edgecolor": "green",
2617
+ "boxstyle": "round",
2618
+ },
2619
+ )
2620
+ [figure.delaxes(ax) for ax in axes.flatten() if not ax.has_data()]
2621
+ return figure, axes, ax2root
2622
+
2623
+ def plot_node(self, node, figsize=(4, 7), dpi=150, **kwargs):
2624
+ """Plots the subtree spawn by a node.
2625
+
2626
+ Args:
2627
+ node (int): The id of the node that is going to be plotted.
2628
+ kwargs: args accepted by networkx
2629
+ """
2630
+ graph = self.to_simple_networkx(node)
2631
+ if len(graph) > 1:
2632
+ raise Warning("Please enter only one node")
2633
+ graph = graph[list(graph)[0]]
2634
+ figure, ax = plt.subplots(nrows=1, ncols=1)
2635
+ nx.draw_networkx(
2636
+ graph,
2637
+ hierarchy_pos(graph, self, node),
2638
+ with_labels=False,
2639
+ arrows=False,
2640
+ ax=ax,
2641
+ **kwargs,
2153
2642
  )
2643
+ return figure, ax
2154
2644
 
2155
2645
  # def DTW(self, t1, t2, max_w=None, start_delay=None, end_delay=None,
2156
2646
  # metric='euclidian', **kwargs):
@@ -2231,11 +2721,584 @@ class lineageTree:
2231
2721
  to_do.append(_next)
2232
2722
  elif self.time[_next] == t:
2233
2723
  final_nodes.append(_next)
2234
- if not final_nodes: return list(r)
2724
+ if not final_nodes:
2725
+ return list(r)
2235
2726
  return final_nodes
2236
2727
 
2728
+ @staticmethod
2729
+ def __calculate_diag_line(dist_mat: np.ndarray) -> (float, float):
2730
+ """
2731
+ Calculate the line that centers the band w.
2732
+
2733
+ Args:
2734
+ dist_mat (matrix): distance matrix obtained by the function calculate_dtw
2735
+
2736
+ Returns:
2737
+ (float) Slope
2738
+ (float) intercept of the line
2739
+ """
2740
+ i, j = dist_mat.shape
2741
+ x1 = max(0, i - j) / 2
2742
+ x2 = (i + min(i, j)) / 2
2743
+ y1 = max(0, j - i) / 2
2744
+ y2 = (j + min(i, j)) / 2
2745
+ slope = (y1 - y2) / (x1 - x2)
2746
+ intercept = y1 - slope * x1
2747
+ return slope, intercept
2748
+
2749
+ # Reference: https://github.com/kamperh/lecture_dtw_notebook/blob/main/dtw.ipynb
2750
+ def __dp(
2751
+ self,
2752
+ dist_mat: np.ndarray,
2753
+ start_d: int = 0,
2754
+ back_d: int = 0,
2755
+ fast: bool = False,
2756
+ w: int = 0,
2757
+ centered_band: bool = True,
2758
+ ) -> (((int, int), ...), np.ndarray):
2759
+ """
2760
+ Find DTW minimum cost between two series using dynamic programming.
2761
+
2762
+ Args:
2763
+ dist_mat (matrix): distance matrix obtained by the function calculate_dtw
2764
+ start_d (int): start delay
2765
+ back_d (int): end delay
2766
+ w (int): window constraint
2767
+ slope (float): to calculate window - given by the function __calculate_diag_line
2768
+ intercept (float): to calculate window - given by the function __calculate_diag_line
2769
+ use_absolute (boolean): if the window constraint is calculated by the absolute difference between points (uncentered)
2770
+
2771
+ Returns:
2772
+ (tuple of tuples) Alignment path
2773
+ (matrix) Cost matrix
2774
+ """
2775
+ N, M = dist_mat.shape
2776
+ w_limit = max(w, abs(N - M)) # Calculate the Sakoe-Chiba band width
2777
+
2778
+ if centered_band:
2779
+ slope, intercept = self.__calculate_diag_line(dist_mat)
2780
+ square_root = np.sqrt((slope**2) + 1)
2781
+
2782
+ # Initialize the cost matrix
2783
+ cost_mat = np.full((N + 1, M + 1), np.inf)
2784
+ cost_mat[0, 0] = 0
2785
+
2786
+ # Fill the cost matrix while keeping traceback information
2787
+ traceback_mat = np.zeros((N, M))
2788
+
2789
+ cost_mat[: start_d + 1, 0] = 0
2790
+ cost_mat[0, : start_d + 1] = 0
2791
+
2792
+ cost_mat[N - back_d :, M] = 0
2793
+ cost_mat[N, M - back_d :] = 0
2794
+
2795
+ for i in range(N):
2796
+ for j in range(M):
2797
+ if fast and not centered_band:
2798
+ condition = abs(i - j) <= w_limit
2799
+ elif fast:
2800
+ condition = (
2801
+ abs(slope * i - j + intercept) / square_root <= w_limit
2802
+ )
2803
+ else:
2804
+ condition = True
2805
+
2806
+ if condition:
2807
+ penalty = [
2808
+ cost_mat[i, j], # match (0)
2809
+ cost_mat[i, j + 1], # insertion (1)
2810
+ cost_mat[i + 1, j], # deletion (2)
2811
+ ]
2812
+ i_penalty = np.argmin(penalty)
2813
+ cost_mat[i + 1, j + 1] = (
2814
+ dist_mat[i, j] + penalty[i_penalty]
2815
+ )
2816
+ traceback_mat[i, j] = i_penalty
2817
+
2818
+ min_index1 = np.argmin(cost_mat[N - back_d :, M])
2819
+ min_index2 = np.argmin(cost_mat[N, M - back_d :])
2820
+
2821
+ if (
2822
+ cost_mat[N, M - back_d + min_index2]
2823
+ < cost_mat[N - back_d + min_index1, M]
2824
+ ):
2825
+ i = N - 1
2826
+ j = M - back_d + min_index2 - 1
2827
+ final_cost = cost_mat[i + 1, j + 1]
2828
+ else:
2829
+ i = N - back_d + min_index1 - 1
2830
+ j = M - 1
2831
+ final_cost = cost_mat[i + 1, j + 1]
2832
+
2833
+ path = [(i, j)]
2834
+
2835
+ while (
2836
+ start_d != 0
2837
+ and ((start_d < i and j > 0) or (i > 0 and start_d < j))
2838
+ ) or (start_d == 0 and (i > 0 or j > 0)):
2839
+ tb_type = traceback_mat[i, j]
2840
+ if tb_type == 0:
2841
+ # Match
2842
+ i -= 1
2843
+ j -= 1
2844
+ elif tb_type == 1:
2845
+ # Insertion
2846
+ i -= 1
2847
+ elif tb_type == 2:
2848
+ # Deletion
2849
+ j -= 1
2850
+
2851
+ path.append((i, j))
2852
+
2853
+ # Strip infinity edges from cost_mat before returning
2854
+ cost_mat = cost_mat[1:, 1:]
2855
+ return path[::-1], cost_mat, final_cost
2856
+
2857
+ # Reference: https://github.com/nghiaho12/rigid_transform_3D
2858
+ @staticmethod
2859
+ def __rigid_transform_3D(A, B):
2860
+ assert A.shape == B.shape
2861
+
2862
+ num_rows, num_cols = A.shape
2863
+ if num_rows != 3:
2864
+ raise Exception(
2865
+ f"matrix A is not 3xN, it is {num_rows}x{num_cols}"
2866
+ )
2867
+
2868
+ num_rows, num_cols = B.shape
2869
+ if num_rows != 3:
2870
+ raise Exception(
2871
+ f"matrix B is not 3xN, it is {num_rows}x{num_cols}"
2872
+ )
2873
+
2874
+ # find mean column wise
2875
+ centroid_A = np.mean(A, axis=1)
2876
+ centroid_B = np.mean(B, axis=1)
2877
+
2878
+ # ensure centroids are 3x1
2879
+ centroid_A = centroid_A.reshape(-1, 1)
2880
+ centroid_B = centroid_B.reshape(-1, 1)
2881
+
2882
+ # subtract mean
2883
+ Am = A - centroid_A
2884
+ Bm = B - centroid_B
2885
+
2886
+ H = Am @ np.transpose(Bm)
2887
+
2888
+ # find rotation
2889
+ U, S, Vt = np.linalg.svd(H)
2890
+ R = Vt.T @ U.T
2891
+
2892
+ # special reflection case
2893
+ if np.linalg.det(R) < 0:
2894
+ # print("det(R) < R, reflection detected!, correcting for it ...")
2895
+ Vt[2, :] *= -1
2896
+ R = Vt.T @ U.T
2897
+
2898
+ t = -R @ centroid_A + centroid_B
2899
+
2900
+ return R, t
2901
+
2902
+ def __interpolate(
2903
+ self, track1: list, track2: list, threshold: int
2904
+ ) -> (np.ndarray, np.ndarray):
2905
+ """
2906
+ Interpolate two series that have different lengths
2907
+
2908
+ Args:
2909
+ track1 (list): list of nodes of the first cell cycle to compare
2910
+ track2 (list): list of nodes of the second cell cycle to compare
2911
+ threshold (int): set a maximum number of points a track can have
2912
+
2913
+ Returns:
2914
+ (list of list) x, y, z positions for track1
2915
+ (list of list) x, y, z positions for track2
2916
+ """
2917
+ inter1_pos = []
2918
+ inter2_pos = []
2919
+
2920
+ track1_pos = np.array([self.pos[c_id] for c_id in track1])
2921
+ track2_pos = np.array([self.pos[c_id] for c_id in track2])
2922
+
2923
+ # Both tracks have the same length and size below the threshold - nothing is done
2924
+ if len(track1) == len(track2) and (
2925
+ len(track1) <= threshold or len(track2) <= threshold
2926
+ ):
2927
+ return track1_pos, track2_pos
2928
+ # Both tracks have the same length but one or more sizes are above the threshold
2929
+ elif len(track1) > threshold or len(track2) > threshold:
2930
+ sampling = threshold
2931
+ # Tracks have different lengths and the sizes are below the threshold
2932
+ else:
2933
+ sampling = max(len(track1), len(track2))
2934
+
2935
+ for pos in range(3):
2936
+ track1_interp = InterpolatedUnivariateSpline(
2937
+ np.linspace(0, 1, len(track1_pos[:, pos])),
2938
+ track1_pos[:, pos],
2939
+ k=1,
2940
+ )
2941
+ inter1_pos.append(track1_interp(np.linspace(0, 1, sampling)))
2942
+
2943
+ track2_interp = InterpolatedUnivariateSpline(
2944
+ np.linspace(0, 1, len(track2_pos[:, pos])),
2945
+ track2_pos[:, pos],
2946
+ k=1,
2947
+ )
2948
+ inter2_pos.append(track2_interp(np.linspace(0, 1, sampling)))
2949
+
2950
+ return np.column_stack(inter1_pos), np.column_stack(inter2_pos)
2951
+
2952
+ def calculate_dtw(
2953
+ self,
2954
+ nodes1: int,
2955
+ nodes2: int,
2956
+ threshold: int = 1000,
2957
+ regist: bool = True,
2958
+ start_d: int = 0,
2959
+ back_d: int = 0,
2960
+ fast: bool = False,
2961
+ w: int = 0,
2962
+ centered_band: bool = True,
2963
+ cost_mat_p: bool = False,
2964
+ ) -> (float, tuple, np.ndarray, np.ndarray, np.ndarray):
2965
+ """
2966
+ Calculate DTW distance between two cell cycles
2967
+
2968
+ Args:
2969
+ nodes1 (int): node to compare distance
2970
+ nodes2 (int): node to compare distance
2971
+ threshold: set a maximum number of points a track can have
2972
+ regist (boolean): Rotate and translate trajectories
2973
+ start_d (int): start delay
2974
+ back_d (int): end delay
2975
+ fast (boolean): True if the user wants to run the fast algorithm with window constraints
2976
+ w (int): window size
2977
+ centered_band (boolean): if running the fast algorithm, True if the window is centered
2978
+ cost_mat_p (boolean): True to also return the not normalized cost matrix and the positions of both trajectories
2979
+
2980
+ Returns:
2981
+ (float) DTW distance
2982
+ (tuple of tuples) Alignment path
2983
+ (matrix) Cost matrix
2984
+ (list of lists) pos_cycle1: rotated and translated trajectories positions
2985
+ (list of lists) pos_cycle2: rotated and translated trajectories positions
2986
+ """
2987
+ nodes1_cycle = self.get_cycle(nodes1)
2988
+ nodes2_cycle = self.get_cycle(nodes2)
2989
+
2990
+ interp_cycle1, interp_cycle2 = self.__interpolate(
2991
+ nodes1_cycle, nodes2_cycle, threshold
2992
+ )
2993
+
2994
+ pos_cycle1 = np.array([self.pos[c_id] for c_id in nodes1_cycle])
2995
+ pos_cycle2 = np.array([self.pos[c_id] for c_id in nodes2_cycle])
2996
+
2997
+ if regist:
2998
+ R, t = self.__rigid_transform_3D(
2999
+ np.transpose(interp_cycle1), np.transpose(interp_cycle2)
3000
+ )
3001
+ pos_cycle1 = np.transpose(np.dot(R, pos_cycle1.T) + t)
3002
+
3003
+ dist_mat = distance.cdist(pos_cycle1, pos_cycle2, "euclidean")
3004
+
3005
+ path, cost_mat, final_cost = self.__dp(
3006
+ dist_mat,
3007
+ start_d,
3008
+ back_d,
3009
+ w=w,
3010
+ fast=fast,
3011
+ centered_band=centered_band,
3012
+ )
3013
+ cost = final_cost / len(path)
3014
+
3015
+ if cost_mat_p:
3016
+ return cost, path, cost_mat, pos_cycle1, pos_cycle2
3017
+ else:
3018
+ return cost, path
3019
+
3020
+ def plot_dtw_heatmap(
3021
+ self,
3022
+ nodes1: int,
3023
+ nodes2: int,
3024
+ threshold: int = 1000,
3025
+ regist: bool = True,
3026
+ start_d: int = 0,
3027
+ back_d: int = 0,
3028
+ fast: bool = False,
3029
+ w: int = 0,
3030
+ centered_band: bool = True,
3031
+ ) -> (float, plt.figure):
3032
+ """
3033
+ Plot DTW cost matrix between two cell cycles in heatmap format
3034
+
3035
+ Args:
3036
+ nodes1 (int): node to compare distance
3037
+ nodes2 (int): node to compare distance
3038
+ start_d (int): start delay
3039
+ back_d (int): end delay
3040
+ fast (boolean): True if the user wants to run the fast algorithm with window constraints
3041
+ w (int): window size
3042
+ centered_band (boolean): if running the fast algorithm, True if the window is centered
3043
+
3044
+ Returns:
3045
+ (float) DTW distance
3046
+ (figure) Heatmap of cost matrix with optimal path
3047
+ """
3048
+ cost, path, cost_mat, pos_cycle1, pos_cycle2 = self.calculate_dtw(
3049
+ nodes1,
3050
+ nodes2,
3051
+ threshold,
3052
+ regist,
3053
+ start_d,
3054
+ back_d,
3055
+ fast,
3056
+ w,
3057
+ centered_band,
3058
+ cost_mat_p=True,
3059
+ )
3060
+
3061
+ fig = plt.figure(figsize=(8, 6))
3062
+ ax = fig.add_subplot(1, 1, 1)
3063
+ im = ax.imshow(
3064
+ cost_mat, cmap="viridis", origin="lower", interpolation="nearest"
3065
+ )
3066
+ plt.colorbar(im)
3067
+ ax.set_title("Heatmap of DTW Cost Matrix")
3068
+ ax.set_xlabel("Tree 1")
3069
+ ax.set_ylabel("tree 2")
3070
+ x_path, y_path = zip(*path)
3071
+ ax.plot(y_path, x_path, color="black")
3072
+
3073
+ return cost, fig
3074
+
3075
+ @staticmethod
3076
+ def __plot_2d(
3077
+ pos_cycle1,
3078
+ pos_cycle2,
3079
+ nodes1,
3080
+ nodes2,
3081
+ ax,
3082
+ x_idx,
3083
+ y_idx,
3084
+ x_label,
3085
+ y_label,
3086
+ ):
3087
+ ax.plot(
3088
+ pos_cycle1[:, x_idx],
3089
+ pos_cycle1[:, y_idx],
3090
+ "-",
3091
+ label=f"root = {nodes1}",
3092
+ )
3093
+ ax.plot(
3094
+ pos_cycle2[:, x_idx],
3095
+ pos_cycle2[:, y_idx],
3096
+ "-",
3097
+ label=f"root = {nodes2}",
3098
+ )
3099
+ ax.set_xlabel(x_label)
3100
+ ax.set_ylabel(y_label)
3101
+
3102
+ def plot_dtw_trajectory(
3103
+ self,
3104
+ nodes1: int,
3105
+ nodes2: int,
3106
+ threshold: int = 1000,
3107
+ regist: bool = True,
3108
+ start_d: int = 0,
3109
+ back_d: int = 0,
3110
+ fast: bool = False,
3111
+ w: int = 0,
3112
+ centered_band: bool = True,
3113
+ projection: str = None,
3114
+ alig: bool = False,
3115
+ ) -> (float, plt.figure):
3116
+ """
3117
+ Plots DTW trajectories alignment between two cell cycles in 2D or 3D
3118
+
3119
+ Args:
3120
+ nodes1 (int): node to compare distance
3121
+ nodes2 (int): node to compare distance
3122
+ threshold (int): set a maximum number of points a track can have
3123
+ regist (boolean): Rotate and translate trajectories
3124
+ start_d (int): start delay
3125
+ back_d (int): end delay
3126
+ w (int): window size
3127
+ fast (boolean): True if the user wants to run the fast algorithm with window constraints
3128
+ centered_band (boolean): if running the fast algorithm, True if the window is centered
3129
+ projection (string): specify which 2D to plot ->
3130
+ '3d' : for the 3d visualization
3131
+ 'xy' or None (default) : 2D projection of axis x and y
3132
+ 'xz' : 2D projection of axis x and z
3133
+ 'yz' : 2D projection of axis y and z
3134
+ 'pca' : PCA projection
3135
+ alig (boolean): True to show alignment on plot
3136
+
3137
+ Returns:
3138
+ (float) DTW distance
3139
+ (figure) Trajectories Plot
3140
+ """
3141
+ (
3142
+ distance,
3143
+ alignment,
3144
+ cost_mat,
3145
+ pos_cycle1,
3146
+ pos_cycle2,
3147
+ ) = self.calculate_dtw(
3148
+ nodes1,
3149
+ nodes2,
3150
+ threshold,
3151
+ regist,
3152
+ start_d,
3153
+ back_d,
3154
+ fast,
3155
+ w,
3156
+ centered_band,
3157
+ cost_mat_p=True,
3158
+ )
3159
+
3160
+ fig = plt.figure(figsize=(10, 6))
3161
+
3162
+ if projection == "3d":
3163
+ ax = fig.add_subplot(1, 1, 1, projection="3d")
3164
+ else:
3165
+ ax = fig.add_subplot(1, 1, 1)
3166
+
3167
+ if projection == "3d":
3168
+ ax.plot(
3169
+ pos_cycle1[:, 0],
3170
+ pos_cycle1[:, 1],
3171
+ pos_cycle1[:, 2],
3172
+ "-",
3173
+ label=f"root = {nodes1}",
3174
+ )
3175
+ ax.plot(
3176
+ pos_cycle2[:, 0],
3177
+ pos_cycle2[:, 1],
3178
+ pos_cycle2[:, 2],
3179
+ "-",
3180
+ label=f"root = {nodes2}",
3181
+ )
3182
+ ax.set_ylabel("y position")
3183
+ ax.set_xlabel("x position")
3184
+ ax.set_zlabel("z position")
3185
+ else:
3186
+ if projection == "xy" or projection == "yx" or projection is None:
3187
+ self.__plot_2d(
3188
+ pos_cycle1,
3189
+ pos_cycle2,
3190
+ nodes1,
3191
+ nodes2,
3192
+ ax,
3193
+ 0,
3194
+ 1,
3195
+ "x position",
3196
+ "y position",
3197
+ )
3198
+ elif projection == "xz" or projection == "zx":
3199
+ self.__plot_2d(
3200
+ pos_cycle1,
3201
+ pos_cycle2,
3202
+ nodes1,
3203
+ nodes2,
3204
+ ax,
3205
+ 0,
3206
+ 2,
3207
+ "x position",
3208
+ "z position",
3209
+ )
3210
+ elif projection == "yz" or projection == "zy":
3211
+ self.__plot_2d(
3212
+ pos_cycle1,
3213
+ pos_cycle2,
3214
+ nodes1,
3215
+ nodes2,
3216
+ ax,
3217
+ 1,
3218
+ 2,
3219
+ "y position",
3220
+ "z position",
3221
+ )
3222
+ elif projection == "pca":
3223
+ try:
3224
+ from sklearn.decomposition import PCA
3225
+ except ImportError:
3226
+ Warning(
3227
+ "scikit-learn is not installed, the PCA orientation cannot be used. You can install scikit-learn with pip install"
3228
+ )
3229
+
3230
+ # Apply PCA
3231
+ pca = PCA(n_components=2)
3232
+ pca.fit(np.vstack([pos_cycle1, pos_cycle2]))
3233
+ pos_cycle1_2d = pca.transform(pos_cycle1)
3234
+ pos_cycle2_2d = pca.transform(pos_cycle2)
3235
+
3236
+ ax.plot(
3237
+ pos_cycle1_2d[:, 0],
3238
+ pos_cycle1_2d[:, 1],
3239
+ "-",
3240
+ label=f"root = {nodes1}",
3241
+ )
3242
+ ax.plot(
3243
+ pos_cycle2_2d[:, 0],
3244
+ pos_cycle2_2d[:, 1],
3245
+ "-",
3246
+ label=f"root = {nodes2}",
3247
+ )
3248
+
3249
+ # Set axis labels
3250
+ axes = ["x", "y", "z"]
3251
+ x_label = axes[np.argmax(np.abs(pca.components_[0]))]
3252
+ y_label = axes[np.argmax(np.abs(pca.components_[1]))]
3253
+ x_percent = 100 * (
3254
+ np.max(np.abs(pca.components_[0]))
3255
+ / np.sum(np.abs(pca.components_[0]))
3256
+ )
3257
+ y_percent = 100 * (
3258
+ np.max(np.abs(pca.components_[1]))
3259
+ / np.sum(np.abs(pca.components_[1]))
3260
+ )
3261
+ ax.set_xlabel(f"{x_percent:.0f}% of {x_label} position")
3262
+ ax.set_ylabel(f"{y_percent:.0f}% of {y_label} position")
3263
+ else:
3264
+ raise ValueError(
3265
+ """Error: available projections are:
3266
+ '3d' : for the 3d visualization
3267
+ 'xy' or None (default) : 2D projection of axis x and y
3268
+ 'xz' : 2D projection of axis x and z
3269
+ 'yz' : 2D projection of axis y and z
3270
+ 'pca' : PCA projection"""
3271
+ )
3272
+
3273
+ connections = [[pos_cycle1[i], pos_cycle2[j]] for i, j in alignment]
3274
+
3275
+ for connection in connections:
3276
+ xyz1 = connection[0]
3277
+ xyz2 = connection[1]
3278
+ x_pos = [xyz1[0], xyz2[0]]
3279
+ y_pos = [xyz1[1], xyz2[1]]
3280
+ z_pos = [xyz1[2], xyz2[2]]
3281
+
3282
+ if alig and projection != "pca":
3283
+ if projection == "3d":
3284
+ ax.plot(x_pos, y_pos, z_pos, "k--", color="grey")
3285
+ else:
3286
+ ax.plot(x_pos, y_pos, "k--", color="grey")
3287
+
3288
+ ax.set_aspect("equal")
3289
+ ax.legend()
3290
+ fig.tight_layout()
3291
+
3292
+ if alig and projection == "pca":
3293
+ warnings.warn(
3294
+ "Error: not possible to show alignment in PCA projection !",
3295
+ UserWarning,
3296
+ )
3297
+
3298
+ return distance, fig
3299
+
2237
3300
  def first_labelling(self):
2238
- self.labels = {i:"Enter_Label" for i in self.time_nodes[0]}
3301
+ self.labels = {i: "Unlabeled" for i in self.time_nodes[0]}
2239
3302
 
2240
3303
  def __init__(
2241
3304
  self,
@@ -2268,12 +3331,12 @@ class lineageTree:
2268
3331
  'TGMM', 'ASTEC', 'MaMuT', 'TrackMate', 'csv', 'celegans', 'binary'
2269
3332
  default is 'binary'
2270
3333
  """
3334
+ self.name = name
2271
3335
  self.time_nodes = {}
2272
3336
  self.time_edges = {}
2273
3337
  self.max_id = -1
2274
3338
  self.next_id = []
2275
3339
  self.nodes = set()
2276
- self.edges = set()
2277
3340
  self.successor = {}
2278
3341
  self.predecessor = {}
2279
3342
  self.pos = {}
@@ -2281,40 +3344,57 @@ class lineageTree:
2281
3344
  self.time = {}
2282
3345
  self.kdtrees = {}
2283
3346
  self.spatial_density = {}
2284
- self.progeny = {}
2285
- self.labels = {}
2286
- if xml_attributes is None:
2287
- self.xml_attributes = []
2288
- else:
2289
- self.xml_attributes = xml_attributes
2290
- file_type = file_type.lower()
2291
- if file_type == "tgmm":
2292
- self.read_tgmm_xml(file_format, tb, te, z_mult)
2293
- self.t_b = tb
2294
- self.t_e = te
2295
- elif file_type == "mamut" or file_type == "trackmate":
2296
- self.read_from_mamut_xml(file_format)
2297
- elif file_type == "celegans":
2298
- self.read_from_txt_for_celegans(file_format)
2299
- elif file_type == "celegans_cao":
2300
- self.read_from_txt_for_celegans_CAO(
2301
- file_format, reorder=reorder, shape=shape, raw_size=raw_size
2302
- )
2303
- elif file_type == "mastodon":
2304
- if isinstance(file_format, list) and len(file_format) == 2:
2305
- self.read_from_mastodon_csv(file_format)
3347
+ if file_type and file_format:
3348
+ if xml_attributes is None:
3349
+ self.xml_attributes = []
2306
3350
  else:
2307
- if isinstance(file_format, list):
2308
- file_format = file_format[0]
2309
- self.read_from_mastodon(file_format, name)
2310
- elif file_type == "astec":
2311
- self.read_from_ASTEC(file_format, eigen)
2312
- elif file_type == "csv":
2313
- self.read_from_csv(file_format, z_mult, link=1, delim=delim)
2314
- elif file_format and file_format.endswith(".lT"):
2315
- with open(file_format, "br") as f:
2316
- tmp = pkl.load(f)
2317
- f.close()
2318
- self.__dict__.update(tmp.__dict__)
2319
- elif file_format is not None:
2320
- self.read_from_binary(file_format)
3351
+ self.xml_attributes = xml_attributes
3352
+ file_type = file_type.lower()
3353
+ if file_type == "tgmm":
3354
+ self.read_tgmm_xml(file_format, tb, te, z_mult)
3355
+ self.t_b = tb
3356
+ self.t_e = te
3357
+ elif file_type == "mamut" or file_type == "trackmate":
3358
+ self.read_from_mamut_xml(file_format)
3359
+ elif file_type == "celegans":
3360
+ self.read_from_txt_for_celegans(file_format)
3361
+ elif file_type == "celegans_cao":
3362
+ self.read_from_txt_for_celegans_CAO(
3363
+ file_format,
3364
+ reorder=reorder,
3365
+ shape=shape,
3366
+ raw_size=raw_size,
3367
+ )
3368
+ elif file_type == "mastodon":
3369
+ if isinstance(file_format, list) and len(file_format) == 2:
3370
+ self.read_from_mastodon_csv(file_format)
3371
+ else:
3372
+ if isinstance(file_format, list):
3373
+ file_format = file_format[0]
3374
+ self.read_from_mastodon(file_format, name)
3375
+ elif file_type == "astec":
3376
+ self.read_from_ASTEC(file_format, eigen)
3377
+ elif file_type == "csv":
3378
+ self.read_from_csv(file_format, z_mult, link=1, delim=delim)
3379
+ elif file_format and file_format.endswith(".lT"):
3380
+ with open(file_format, "br") as f:
3381
+ tmp = pkl.load(f)
3382
+ f.close()
3383
+ self.__dict__.update(tmp.__dict__)
3384
+ elif file_format is not None:
3385
+ self.read_from_binary(file_format)
3386
+ if self.name is None:
3387
+ try:
3388
+ self.name = Path(file_format).stem
3389
+ except:
3390
+ self.name = Path(file_format[0]).stem
3391
+ if [] in self.successor.values():
3392
+ successors = list(self.successor.keys())
3393
+ for succ in successors:
3394
+ if self[succ] == []:
3395
+ self.successor.pop(succ)
3396
+ if [] in self.predecessor.values():
3397
+ predecessors = list(self.predecessor.keys())
3398
+ for succ in predecessors:
3399
+ if self[succ] == []:
3400
+ self.predecessor.pop(succ)