LineageTree 1.4.3__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,170 @@
1
import os
import pickle as pkl
import warnings  # required by the optional-dependency fallback below
from functools import partial

# edist is an optional dependency: without it the module still imports, but
# computing a tree edit distance will fail when `uted` is looked up.
try:
    from edist import uted
except ImportError:
    warnings.warn(
        "No edist installed therefore you will not be able to compute the tree edit distance."
    )
from LineageTree import lineageTree

from .tree_styles import tree_style
14
+
15
+
16
class lineageTreeManager:
    """Container that stores several lineageTree objects under unique names
    and offers comparisons between trees of different entries.
    """

    def __init__(self):
        # Maps the registration name to the stored lineageTree.
        self.lineagetrees = {}
        # Number of automatic "Lineagetree <n>" names handed out so far.
        self.lineageTree_counter = 0
        # Not read or written by any other method of this class.
        self.registered = {}

    def __next__(self):
        """Return the next unused automatic index and advance the counter."""
        index = self.lineageTree_counter
        self.lineageTree_counter += 1
        return index

    def add(
        self, other_tree: lineageTree, name: str = "", classification: str = ""
    ):
        """Function that adds a new lineagetree object to the class.
        Can be added either by .add or by using the + operator. If a name is
        specified it will also add it as this specific name, otherwise it will
        use the already existing name of the lineagetree. If the lineagetree
        has no name, an automatic "Lineagetree <n>" name is generated and also
        stored on the lineagetree itself.

        Args:
            other_tree (lineageTree): The lineagetree to be added. Objects
                that are not lineageTree instances are ignored.
            name (str, optional): The name under which the lineagetree is
                registered. Defaults to "".
            classification (str, optional): Currently unused. Defaults to "".

        Returns:
            (bool|None): False if an equal lineagetree is already registered,
                None otherwise.
        """
        if not isinstance(other_tree, lineageTree):
            return None
        if any(tree == other_tree for tree in self.lineagetrees.values()):
            return False
        if not name:
            if hasattr(other_tree, "name"):
                name = other_tree.name
            else:
                name = f"Lineagetree {next(self)}"
                other_tree.name = name
        self.lineagetrees[name] = other_tree
        return None

    def __add__(self, other):
        """Register `other` through `add` and return the manager itself.

        Returning the manager keeps `manager = manager + tree` and
        `manager += tree` from rebinding the name to None.
        """
        self.add(other)
        return self

    def write(self, fname: str):
        """Saves the manager

        Args:
            fname (str): The path and name of the file that is to be saved.
                The ".ltM" extension is appended when it is missing.
        """
        if os.path.splitext(fname)[-1] != ".ltM":
            fname = os.path.extsep.join((fname, "ltM"))
        # The context manager closes the file; no explicit close is needed.
        with open(fname, "bw") as f:
            pkl.dump(self, f)

    def remove_embryo(self, key):
        """Removes the embryo from the manager.

        A key that is not registered is silently ignored.

        Args:
            key (str): The name of the lineagetree to be removed
        """
        self.lineagetrees.pop(key, None)

    @classmethod
    def load(cls, fname: str):
        """
        Loading a lineage tree Manager from a ".ltM" file.

        Args:
            fname (str): path to and name of the file to read

        Returns:
            The object stored in the file (a lineageTreeManager when the file
            was produced by `write`).
        """
        # NOTE(security): unpickling executes code embedded in the file; only
        # load files that come from a trusted source.
        with open(fname, "br") as f:
            ltm = pkl.load(f)
        return ltm

    def cross_lineage_edit_distance(
        self,
        n1: int,
        embryo_1,
        end_time1: int,
        n2: int,
        embryo_2,
        end_time2: int,
        style="fragmented",
        node_lengths: tuple = (1, 5, 7),
        registration=None,
    ):
        """Compute the unordered tree edit distance from Zhang 1996 between the trees spawned
        by two nodes `n1` from lineagetree1 and `n2` lineagetree2. The topology of the trees
        are compared and the matching cost is given by the function delta (see edist doc for
        more information). The distance is normed by the larger of the two norms returned by
        the tree styles built from `n1` and `n2`.

        Args:
            n1 (int): Node of the first Lineagetree
            embryo_1 (str): The key/name of the first Lineagetree
            end_time1 (int): End time of first Lineagetree
            n2 (int): Node of the second Lineagetree
            embryo_2 (str): The key/name of the second Lineagetree
            end_time2 (int): End time of second lineagetree
            style (str, optional): Name of the `tree_style` member used to
                approximate the trees. Defaults to "fragmented".
            node_lengths (tuple, optional): Forwarded as `node_length` to the
                tree style. Defaults to (1, 5, 7).
            registration (optional): Currently unused. Defaults to None.

        Returns:
            (int|float): 0 if both trees are empty, otherwise the normalized
                edit distance.
        """
        tree = tree_style[style].value
        tree1 = tree(
            lT=self.lineagetrees[embryo_1],
            node_length=node_lengths,
            end_time=end_time1,
            root=n1,
        )
        tree2 = tree(
            lT=self.lineagetrees[embryo_2],
            node_length=node_lengths,
            end_time=end_time2,
            root=n2,
        )
        delta = tree1.delta
        _, times1 = tree1.tree
        _, times2 = tree2.tree
        nodes1, adj1, corres1 = tree1.edist
        nodes2, adj2, corres2 = tree2.edist
        if len(nodes1) == len(nodes2) == 0:
            return 0

        # Bind the per-tree lookups so edist only has to supply the two nodes.
        delta_tmp = partial(
            delta,
            corres1=corres1,
            times1=times1,
            corres2=corres2,
            times2=times2,
        )
        return uted.uted(nodes1, adj1, nodes2, adj2, delta=delta_tmp) / max(
            tree1.get_norm(), tree2.get_norm()
        )
@@ -0,0 +1,305 @@
1
+ from abc import ABC, abstractmethod
2
+ from enum import Enum
3
+
4
+ import numpy as np
5
+
6
+ from LineageTree import lineageTree
7
+
8
+
9
class abstract_trees(ABC):
    """Base class of the tree styles used to compute tree edit distances.

    Building an instance immediately calls `get_tree` and converts its
    adjacency dictionary to the edist format; the results are stored in
    `self.tree` and `self.edist`.
    """

    def __init__(
        self, lT: lineageTree, root: int, node_length: tuple, end_time: int
    ):
        """
        Args:
            lT (lineageTree): The lineagetree the tree is extracted from.
            root (int): The node that spawns the tree.
            node_length (tuple|list): Style specific node lengths, stored
                as is; each style decides whether and how to use them.
            end_time (int): The last time point to consider.
        """
        self.lT = lT
        self.root = root
        self.node_length = node_length
        self.end_time = end_time
        self.tree = self.get_tree()
        self.edist = self._edist_format(self.tree[0])

    @abstractmethod
    def get_tree(self):
        """
        Get a tree version of the tree spawned by the node `self.root`,
        restricted by `self.end_time`.

        Returns:
            (dict) {m (int): [d1 (int), d2 (int)]}: an adjacency dictionary
                where the ids are the ids of the cells in the original tree.
            (dict|None) {m (int): duration}: the length attributed to the
                node `m`, or None for styles that do not use lengths.
        """

    @abstractmethod
    def delta(self, x, y, corres1, corres2, times1, times2):
        """The distance of two nodes inside a tree. Behaves like a staticmethod.
        The corres1/2 and time1/2 should always be provided and will be handled accordingly by the specific
        delta of each tree style.

        Args:
            x (int): The first node to compare, takes the names provided by the edist.
            y (int): The second node to compare, takes the names provided by the edist
            corres1 (dict): Correspondence between node1 and its name in the real tree.
            corres2 (dict): Correspondence between node2 and its name in the real tree.
            times1 (dict): The dictionary of the branch lengths of the tree that n1 is spawned from.
            times2 (dict): The dictionary of the branch lengths of the tree that n2 is spawned from.

        Returns:
            (int|float): The distance between these 2 nodes.
        """
        if x is None and y is None:
            return 0
        # A missing node on one side costs the full length of the other node.
        if x is None:
            return times2[corres2[y]]
        if y is None:
            return times1[corres1[x]]
        len_x = times1[corres1[x]]
        len_y = times2[corres2[y]]
        return np.abs(len_x - len_y)

    @abstractmethod
    def get_norm(self):
        """
        Returns the valid value for normalizing the edit distance.
        Returns:
            (int|float): The number of nodes of each tree according to each style.
        """

    def _edist_format(self, adj_dict: dict):
        """Formatting the custom tree style to the format needed by edist.
        SHOULD NOT BE CHANGED.

        Nodes are renumbered 0..n-1 in depth-first pre-order, visiting the
        children of a node in the order in which they are listed.

        Args:
            adj_dict (dict): {node: [children]} adjacency dictionary.

        Returns:
            (list): the new node ids, in visiting order.
            (list): for each new id, the list of the new ids of its children.
            (dict): {new id: original node id}.
        """
        children = {child for kids in adj_dict.values() for child in kids}
        roots = set(adj_dict).difference(children)
        nid2list = {}
        list2nid = {}
        nodes = []
        for r in roots:
            # Stack based pre-order; pushing the children reversed makes the
            # first listed child the next node to be visited.
            to_do = [r]
            while to_do:
                curr = to_do.pop()
                curr_id = len(nodes)
                nid2list[curr] = curr_id
                list2nid[curr_id] = curr
                nodes.append(curr_id)
                to_do.extend(reversed(adj_dict.get(curr, [])))
        adj_list = [
            [nid2list[d] for d in adj_dict.get(list2nid[_id], [])]
            for _id in nodes
        ]
        return nodes, adj_list, list2nid
104
+
105
+
106
class mini_tree(abstract_trees):
    """Each branch is converted to a node of length 1, it is useful for comparing synchronous developing cells, extremely fast.
    Mainly used for testing.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def get_tree(self):
        """Build the adjacency dictionary with one node per branch.

        Returns:
            (dict): {branch start: daughters}, the daughters being listed only
                when the branch ends with more than one successor.
            None: this style does not use node lengths.
        """
        if self.end_time is None:
            self.end_time = self.lT.t_e
        adjacency = {}
        self.times = {}
        stack = [self.root]
        while stack:
            branch_start = stack.pop()
            branch = np.array(self.lT.get_successors(branch_start))
            branch_times = np.array([self.lT.time[cell] for cell in branch])
            # Keep only the part of the branch that is not after end_time.
            branch = branch[branch_times <= self.end_time]
            if not branch.size:
                continue
            daughters = self.lT[branch[-1]]
            if len(daughters) > 1:
                adjacency[branch_start] = daughters
                stack.extend(daughters)
            else:
                adjacency[branch_start] = []
        self.length = len(adjacency)
        return adjacency, None

    def get_norm(self):
        """Return the number of branches spawned by the root up to end_time."""
        branches = self.lT.get_all_branches_of_node(
            self.root, end_time=self.end_time
        )
        return len(branches)

    def _edist_format(self, adj_dict: dict):
        # Same conversion as the base class.
        return super()._edist_format(adj_dict)

    def delta(self, x, y, corres1, corres2, times1, times2):
        """Return 1 when exactly one of the two nodes is missing, else 0."""
        return int((x is None) != (y is None))
151
+
152
+
153
class simple_tree(abstract_trees):
    """Each branch is converted to one node with length the same as the life cycle of the cell.
    This method is fast, but imprecise, especially for small trees (recommended height of the trees should be 100 at least).
    Use with CAUTION.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def get_tree(self):
        """Build the adjacency dictionary with one node per branch.

        Returns:
            (dict): {branch start: daughters of the branch}.
            (dict): {branch start: number of cells of the branch that are not
                after end_time}.
        """
        if self.end_time is None:
            self.end_time = self.lT.t_e
        adjacency = {}
        self.times = {}
        stack = [self.root]
        while stack:
            branch_start = stack.pop()
            branch = np.array(self.lT.get_successors(branch_start))
            branch_times = np.array([self.lT.time[cell] for cell in branch])
            branch = branch[branch_times <= self.end_time]
            if branch.size:
                last_cell = branch[-1]
                daughters = self.lT[last_cell]
                # Daughters are followed only when the branch really divides
                # and does so strictly before end_time.
                divides = (
                    len(daughters) > 1
                    and self.lT.time[last_cell] < self.end_time
                )
                adjacency[branch_start] = daughters if divides else []
                if divides:
                    stack.extend(daughters)
            # * time_resolution will be fixed when working on registered trees.
            self.times[branch_start] = len(branch)
        return adjacency, self.times

    def delta(self, x, y, corres1, corres2, times1, times2):
        # Same node distance as the base class.
        return super().delta(x, y, corres1, corres2, times1, times2)

    def get_norm(self):
        """Return the size of the sub tree spawned by the root up to end_time."""
        sub_tree = self.lT.get_sub_tree(self.root, end_time=self.end_time)
        return len(sub_tree)
192
+
193
+
194
class fragmented_tree(abstract_trees):
    """Similar idea to simple tree, but tries to correct its flaws.
    Instead of having branches with length == life cycle of cell, nodes of specific length are added on the
    edges of the branch, providing both accurate results and speed.
    It's the recommended method for calculating edit distances on developing embryos.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def get_tree(self):
        """Build the adjacency dictionary where every branch is cut in fragments.

        A branch that is long enough is replaced by the fragments
        `node_length[:k]`, one middle node holding the remaining length and
        the same fragments mirrored. Shorter branches keep one node per cell.

        Returns:
            (dict): {node: [children]} adjacency dictionary.
            (dict): {node: length attributed to the node}.
        """
        if self.end_time is None:
            self.end_time = self.lT.t_e
        self.out_dict = {}
        self.times = {}
        to_do = [self.root]
        if not isinstance(self.node_length, list):
            self.node_length = list(self.node_length)
        # Minimal branch length needed to fit 1, 2, ... fragments on both
        # ends; it does not depend on the branch, so it is computed once.
        cumul_sum_of_nodes = np.cumsum(self.node_length) * 2 + 1
        while to_do:
            current = to_do.pop()
            cycle = np.array(
                self.lT.get_successors(current, end_time=self.end_time)
            )
            if cycle.size > 0:
                max_number_fragments = int(
                    np.count_nonzero(cumul_sum_of_nodes < len(cycle))
                )
                if max_number_fragments > 0:
                    current_node_lengths = self.node_length[
                        :max_number_fragments
                    ]
                    length_middle_node = (
                        len(cycle) - sum(current_node_lengths) * 2
                    )
                    times_tmp = (
                        current_node_lengths
                        + [length_middle_node]
                        + current_node_lengths[::-1]
                    )
                    # Index, inside the branch, of the first cell of each node.
                    pos_all_nodes = np.concatenate(
                        [[0], np.cumsum(times_tmp[:-1])]
                    )
                    track = cycle[pos_all_nodes]
                    self.out_dict.update(
                        {k: [v] for k, v in zip(track[:-1], track[1:])}
                    )
                    self.times.update(zip(track, times_tmp))
                else:
                    for i, cell in enumerate(cycle[:-1]):
                        self.out_dict[cell] = [cycle[i + 1]]
                        self.times[cell] = 1
                # NOTE(review): the last fragment starts at cycle[-1] only
                # when node_length[0] == 1 (true for the (1, 5, 7) used by
                # the manager); confirm other values are intended.
                current = cycle[-1]
                _next = self.lT[current]
                self.times[current] = 1
                if _next and self.lT.time[_next[0]] <= self.end_time:
                    to_do.extend(_next)
                    self.out_dict[current] = _next
                else:
                    self.out_dict[current] = []

        return self.out_dict, self.times

    def get_norm(self):
        """Return the size of the sub tree spawned by the root up to end_time."""
        return len(
            self.lT.get_sub_tree(self.root, end_time=self.end_time)
        )

    def delta(self, x, y, corres1, corres2, times1, times2):
        # Same node distance as the base class.
        return super().delta(x, y, corres1, corres2, times1, times2)
264
+
265
+
266
class full_tree(abstract_trees):
    """No approximations the whole tree is used here. Perfect accuracy, but heavy on ram and speed.
    Not recommended to use on napari.

    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def get_tree(self) -> tuple:
        """Build the adjacency dictionary with one node per cell.

        Returns:
            (dict): {cell: successors}, the successors being listed only when
                the first of them is not after end_time.
            (dict): {cell: 1}, every node has length 1.
        """
        # Same default as the other tree styles: without it the comparison
        # with end_time below fails when end_time is None.
        if self.end_time is None:
            self.end_time = self.lT.t_e
        self.out_dict = {}
        self.times = {}
        to_do = [self.root]
        while to_do:
            current = to_do.pop()
            _next = self.lT.successor.get(current, [])
            if _next and self.lT.time[_next[0]] <= self.end_time:
                self.out_dict[current] = _next
                to_do.extend(_next)
            else:
                self.out_dict[current] = []
            self.times[current] = 1
        return self.out_dict, self.times

    def get_norm(self):
        """Return the size of the sub tree spawned by the root up to end_time."""
        return len(self.lT.get_sub_tree(self.root, end_time=self.end_time))

    def delta(self, x, y, corres1, corres2, times1, times2):
        # Same node distance as the base class.
        return super().delta(x, y, corres1, corres2, times1, times2)
295
+
296
+
297
class tree_style(Enum):
    """Available tree styles; the value of each member is the class
    implementing the style.
    """

    mini = mini_tree
    simple = simple_tree
    fragmented = fragmented_tree
    full = full_tree

    @classmethod
    def list_names(cls):
        """Return the names of all the styles, in definition order."""
        return [member.name for member in cls]
LineageTree/utils.py ADDED
@@ -0,0 +1,211 @@
1
+ import csv
2
+ import random
3
+ import warnings
4
+
5
+ import networkx as nx
6
+
7
+ from LineageTree import lineageTree
8
+
9
+ try:
10
+ import motile
11
+ except ImportError:
12
+ warnings.warn(
13
+ "No motile installed therefore you will not be able to produce links with motile."
14
+ )
15
+
16
+
17
def hierarchy_pos(
    G,
    a,
    root=None,
    width=2000.0,
    vert_gap=0.5,
    vert_loc=0,
    xcenter=0,
):
    """
    From Joel's answer at https://stackoverflow.com/a/29597209/2966723.
    Licensed under Creative Commons Attribution-Share Alike

    If the graph is a tree this will return the positions to plot this in a
    hierarchical layout.

    The graph represents the lifetimes of cells, so there is no new point for
    each timepoint. Each lifetime is represented by length.

    G: the graph (must be a tree)

    a: the object queried for the lineage information; the attributes
    `successor` and `predecessor` and the methods `get_cycle`,
    `get_successors` and `get_predecessors` are used.

    root: the root node of current branch
    - if the tree is directed and this is not given,
      the root will be found and used
    - if the tree is directed and this is given, then
      the positions will be just for the descendants of this node.
    - if the tree is undirected and not given,
      then a random choice will be used.

    width: horizontal space allocated for this branch - avoids overlap with other branches

    vert_gap: gap between levels of hierarchy; only used below the root, the
    gap below every other node is computed from `a`.

    vert_loc: vertical location of root

    xcenter: horizontal location of root

    Returns a dict {node: (x, y)}.

    Raises TypeError if `G` is not a tree.
    """
    if not nx.is_tree(G):
        raise TypeError(
            "cannot use hierarchy_pos on a graph that is not a tree"
        )

    if root is None:
        if isinstance(G, nx.DiGraph):
            root = next(
                iter(nx.topological_sort(G))
            )  # allows back compatibility with nx version 1.11
        else:
            root = random.choice(list(G.nodes))

    def lengths(cell):
        # Vertical gap to leave below `cell`: a fixed 0.7 when `cell` has two
        # or more successors in `a`, otherwise a value derived from the
        # number of successors returned by `a.get_successors`.
        succ = a.successor.get(cell, [])
        if len(succ) < 2:
            # No neighbor in the graph: nothing is drawn below this node.
            if list(G.neighbors(cell)) == []:
                return 0
            # The first graph neighbor belongs to the cycle of `cell`: only
            # the part between the two nodes is counted.
            if list(G.neighbors(cell))[0] in a.get_cycle(cell):
                return (
                    len(a.get_successors(cell))
                    - len(a.get_successors(list(G.neighbors(cell))[0]))
                    - 1
                )
            return len(a.get_successors(cell))
        else:
            return 0.7

    def _hierarchy_pos(
        G,
        root,
        width=2.0,
        a=a,
        vert_gap=0.5,
        vert_loc=0,
        xcenter=0.5,
        pos=None,
        parent=None,
    ):
        """
        see hierarchy_pos docstring for most arguments

        pos: a dict saying where all nodes go if they have been assigned
        parent: parent of this branch. - only affects it if non-directed

        """
        if pos is None:
            pos = {root: (xcenter, vert_loc)}
        elif not a.predecessor.get(a.get_predecessors(root)[0]):
            # The first node returned by get_predecessors has no predecessor
            # entry: shift the node by the number of predecessors returned.
            vert_loc = vert_loc - len(a.get_predecessors(root))
            pos[root] = (xcenter, vert_loc)
        else:
            pos[root] = (xcenter, vert_loc)
        children = list(G.neighbors(root))

        # In an undirected graph the parent is also a neighbor.
        if not isinstance(G, nx.DiGraph) and parent is not None:
            children.remove(parent)
        if len(children) != 0:
            # Split the available width evenly between the children.
            dx = width / len(children)
            nextx = xcenter - width / 2 - dx / 2
            for child in children:
                nextx += dx
                pos = _hierarchy_pos(
                    G,
                    child,
                    width=dx,
                    vert_gap=lengths(child),
                    vert_loc=vert_loc - vert_gap,
                    xcenter=nextx,
                    pos=pos,
                    a=a,
                    parent=root,
                )
        return pos

    return _hierarchy_pos(G, root, width, a, vert_gap, vert_loc, xcenter)
139
+
140
+
141
def to_motile(
    lT: lineageTree, crop: int = None, max_dist=200, max_skip_frames=1
):
    """Build a candidate graph for motile from the cells of a lineagetree.

    Every cell of the time points 0 to `crop` - 1 becomes a node carrying its
    time `t`, its position `pos` and a `score` of 1. The candidate edges are
    then added by `motile.add_cand_edges`.

    Args:
        lT (lineageTree): The lineagetree providing the cells.
        crop (int, optional): Number of time points to export. `lT.t_e` is
            used when it is None or 0.
        max_dist (int, optional): Forwarded to `motile.add_cand_edges`.
            Defaults to 200.
        max_skip_frames (int, optional): Forwarded to
            `motile.add_cand_edges`. Defaults to 1.

    Returns:
        (nx.DiGraph): The graph holding the nodes and the candidate edges.
    """
    graph = nx.DiGraph()
    nb_time_points = crop if crop else lT.t_e
    for time in range(nb_time_points):
        for cell in lT.time_nodes[time]:
            attributes = {
                "t": lT.time[cell],
                "pos": lT.pos[cell],
                "score": 1,
            }
            graph.add_node(cell, **attributes)

    motile.add_cand_edges(graph, max_dist, max_skip_frames=max_skip_frames)
    return graph
164
+
165
+
166
def write_csv_from_lT_to_lineaja(
    lT, path_to, start: int = 0, finish: int = 300
):
    """Write the cells of a lineagetree to a csv file.

    One row is written per cell found between the time points `start`
    (included) and `finish` (excluded), with the columns time, positions_x,
    positions_y, positions_z and id. The position of a cell is read as
    (z, y, x).

    Args:
        lT (lineageTree): The lineagetree to export; `time_nodes` and `pos`
            are read.
        path_to (str): Path of the csv file to write.
        start (int, optional): First time point to export. Defaults to 0.
        finish (int, optional): Time point at which the export stops. Time
            points that do not exist in the lineagetree are skipped, so the
            default also works for shorter lineagetrees. Defaults to 300.
    """
    csv_dict = {}
    for time in range(start, finish):
        try:
            cells = lT.time_nodes[time]
        except (KeyError, IndexError):
            # The requested range may exceed the recorded time points.
            continue
        for node in cells:
            csv_dict[node] = {"pos": lT.pos[node], "t": time}
    fieldnames = [
        "time",
        "positions_x",
        "positions_y",
        "positions_z",
        "id",
    ]
    # newline="" is the mode documented for files handed to the csv module;
    # like the former "\n" it writes the rows without newline translation.
    with open(path_to, "w", newline="") as file:
        writer = csv.DictWriter(file, fieldnames=fieldnames)
        writer.writeheader()
        for node, entry in csv_dict.items():
            pos_z, pos_y, pos_x = entry["pos"][0], entry["pos"][1], entry["pos"][2]
            writer.writerow(
                {
                    "time": entry["t"],
                    "positions_z": pos_z,
                    "positions_y": pos_y,
                    "positions_x": pos_x,
                    "id": node,
                }
            )
193
+
194
+
195
def postions_of_nx(lt, graphs):
    """Calculates the positions of the Lineagetree to be plotted.

    Args:
        lt (lineageTree): The lineagetree the graphs were produced from.
        graphs (nx.Digraph): Graphs produced by export_nx_simple_graph,
            accessed by their index from 0 to len(graphs) - 1.

    Returns:
        pos (dict): For each graph index, the positions of the nodes of the
            graph for plotting
    """

    def _first_root(graph):
        # First node of the graph without any incoming edge.
        return [node for node, degree in graph.in_degree() if degree == 0][0]

    return {
        index: hierarchy_pos(
            graphs[index],
            lt,
            root=_first_root(graphs[index]),
        )
        for index in range(len(graphs))
    }