LineageTree 1.8.0__py3-none-any.whl → 2.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,9 +1,17 @@
1
+ from __future__ import annotations
2
+
1
3
  import os
2
4
  import pickle as pkl
3
5
  import warnings
6
+ from collections.abc import Callable, Generator, Iterable
4
7
  from functools import partial
8
+ from typing import TYPE_CHECKING, Literal
5
9
 
10
+ import matplotlib.colors as mcolors
6
11
  import numpy as np
12
+ from matplotlib import colormaps
13
+
14
+ from .tree_approximation import tree_style
7
15
 
8
16
  try:
9
17
  from edist import uted
@@ -12,51 +20,104 @@ except ImportError:
12
20
  "No edist installed therefore you will not be able to compute the tree edit distance.",
13
21
  stacklevel=2,
14
22
  )
23
+ import matplotlib.pyplot as plt
24
+ from edist import uted
25
+
15
26
  from LineageTree import lineageTree
27
+ from LineageTree.tree_approximation import TreeApproximationTemplate
28
+
29
+ from .utils import convert_style_to_number
16
30
 
17
- from .tree_styles import tree_style
31
+ if TYPE_CHECKING:
32
+ from edist.alignment import Alignment
18
33
 
19
34
 
20
35
  class lineageTreeManager:
21
- def __init__(self):
22
- self.lineagetrees = {}
23
- self.lineageTree_counter = 0
24
- self.registered = {}
36
+ norm_dict = {"max": max, "sum": sum, None: lambda x: 1}
25
37
 
26
- def __next__(self):
38
+ def __init__(self, lineagetree_list: Iterable[lineageTree] = ()):
39
+ """Creates a lineageTreeManager
40
+ :TODO: write the docstring
41
+
42
+ Parameters
43
+ ----------
44
+ lineagetree_list: Iterable of lineageTree
45
+ List of lineage trees to be in the lineageTreeManager
46
+ """
47
+ self.lineagetrees: dict[str, lineageTree] = {}
48
+ self.lineageTree_counter: int = 0
49
+ self._comparisons: dict = {}
50
+ for lT in lineagetree_list:
51
+ self.add(lT)
52
+
53
+ def __next__(self) -> int:
27
54
  self.lineageTree_counter += 1
28
55
  return self.lineageTree_counter - 1
29
56
 
57
+ def __len__(self) -> int:
58
+ """Returns how many lineagetrees are in the manager.
59
+
60
+ Returns
61
+ -------
62
+ int
63
+ The number of trees inside the manager
64
+ """
65
+ return len(self.lineagetrees)
66
+
67
+ def __iter__(self) -> Generator[tuple[str, lineageTree]]:
68
+ yield from self.lineagetrees.items()
69
+
70
+ def __getitem__(self, key: str) -> lineageTree:
71
+ if key in self.lineagetrees:
72
+ return self.lineagetrees[key]
73
+ else:
74
+ raise KeyError(f"'{key}' not found in the manager")
75
+
30
76
  @property
31
- def gcd(self):
32
- if len(self.lineagetrees) >= 1:
77
+ def gcd(self) -> int:
78
+ """Calculates the greatesτ common divisor between all lineagetree resolutions in the manager.
79
+
80
+ Returns
81
+ -------
82
+ int
83
+ The overall greatest common divisor.
84
+ """
85
+ if len(self) > 1:
33
86
  all_time_res = [
34
87
  embryo._time_resolution
35
88
  for embryo in self.lineagetrees.values()
36
89
  ]
37
90
  return np.gcd.reduce(all_time_res)
91
+ elif len(self):
92
+ return 1
93
+ else:
94
+ raise ValueError(
95
+ "You cannot calculate the greatest common divisor of time resolutions with an empty manager."
96
+ )
38
97
 
39
- def add(
40
- self, other_tree: lineageTree, name: str = "", classification: str = ""
41
- ):
98
+ def add(self, other_tree: lineageTree, name: str = ""):
42
99
  """Function that adds a new lineagetree object to the class.
43
100
  Can be added either by .add or by using the + operator. If a name is
44
101
  specified it will also add it as this specific name, otherwise it will
45
102
  use the already existing name of the lineagetree.
46
103
 
47
- Args:
48
- other_tree (lineageTree): Thelineagetree to be added.
49
- name (str, optional): Then name of. Defaults to "".
50
-
104
+ Parameters
105
+ ----------
106
+ other_tree : LineageTree
107
+ Thelineagetree to be added.
108
+ name : str, default=""
109
+ Then name of the lineagetree to be added, defaults to ''.
110
+ (Usually lineageTrees have the name of the path they are read from,
111
+ so this is going to be the name most of the times.)
51
112
  """
52
- if isinstance(other_tree, lineageTree) and other_tree.time_resolution:
113
+ if isinstance(other_tree, lineageTree):
53
114
  for tree in self.lineagetrees.values():
54
115
  if tree == other_tree:
55
116
  return False
56
117
  if name:
57
118
  self.lineagetrees[name] = other_tree
58
119
  else:
59
- if hasattr(other_tree, "name"):
120
+ if other_tree.name:
60
121
  name = other_tree.name
61
122
  self.lineagetrees[name] = other_tree
62
123
  else:
@@ -68,101 +129,146 @@ class lineageTreeManager:
68
129
  "Please add a LineageTree object or add time resolution to the LineageTree added."
69
130
  )
70
131
 
71
- def __add__(self, other):
132
+ def __add__(self, other: lineageTree):
72
133
  self.add(other)
73
134
 
74
135
  def write(self, fname: str):
75
136
  """Saves the manager
76
137
 
77
- Args:
78
- fname (str): The path and name of the file that is to be saved.
138
+ Parameters
139
+ ----------
140
+ fname : str
141
+ The path and name of the file that is to be saved.
79
142
  """
80
- if os.path.splitext(fname)[-1] != ".ltM":
81
- fname = os.path.extsep.join((fname, "ltM"))
143
+ if os.path.splitext(fname)[-1].upper() != ".LTM":
144
+ fname = os.path.extsep.join((fname, "lTM"))
145
+ for _, lT in self:
146
+ if hasattr(lT, "_protected_predecessor"):
147
+ del lT._protected_predecessor
148
+ if hasattr(lT, "_protected_successor"):
149
+ del lT._protected_successor
150
+ if hasattr(lT, "_protected_time"):
151
+ del lT._protected_time
82
152
  with open(fname, "bw") as f:
83
153
  pkl.dump(self, f)
84
154
  f.close()
85
155
 
86
- def remove_embryo(self, key):
156
+ def remove_embryo(self, key: str):
87
157
  """Removes the embryo from the manager.
88
158
 
89
- Args:
90
- key (str): The name of the lineagetree to be removed
159
+ Parameters
160
+ ----------
161
+ key : str
162
+ The name of the lineagetree to be removed
91
163
 
92
- Raises:
93
- Exception: If there is not such a lineagetree
164
+ Raises
165
+ ------
166
+ IndexError
167
+ If there is no such lineagetree
94
168
  """
95
169
  self.lineagetrees.pop(key, None)
96
170
 
97
171
  @classmethod
98
- def load(cls, fname: str):
99
- """
100
- Loading a lineage tree Manager from a ".ltm" file.
172
+ def load(cls, fname: str) -> lineageTreeManager:
173
+ """Loading a lineage tree Manager from a ".ltm" file.
101
174
 
102
- Args:
103
- fname (str): path to and name of the file to read
175
+ Parameters
176
+ ----------
177
+ fname : str
178
+ path to and name of the file to read
104
179
 
105
- Returns:
106
- (lineageTree): loaded file
180
+ Returns
181
+ -------
182
+ lineageTreeManager
183
+ loaded file
107
184
  """
108
185
  with open(fname, "br") as f:
109
186
  ltm = pkl.load(f)
110
187
  f.close()
111
188
  return ltm
112
189
 
113
- def cross_lineage_edit_distance(
190
+ def __cross_lineage_edit_backtrace(
114
191
  self,
115
192
  n1: int,
116
- embryo_1,
117
- end_time1: int,
193
+ embryo_1: str,
118
194
  n2: int,
119
- embryo_2,
120
- end_time2: int,
121
- style="simple",
195
+ embryo_2: str,
196
+ end_time1: int | None = None,
197
+ end_time2: int | None = None,
198
+ style: (
199
+ Literal["simple", "normalized_simple", "full", "downsampled"]
200
+ | type[TreeApproximationTemplate]
201
+ ) = "simple",
202
+ norm: Literal["max", "sum", None] = "max",
122
203
  downsample: int = 2,
123
- registration=None, # will be added as a later feature
124
- ):
204
+ ) -> dict[
205
+ str,
206
+ Alignment
207
+ | tuple[TreeApproximationTemplate, TreeApproximationTemplate],
208
+ ]:
125
209
  """Compute the unordered tree edit distance from Zhang 1996 between the trees spawned
126
210
  by two nodes `n1` from lineagetree1 and `n2` lineagetree2. The topology of the trees
127
211
  are compared and the matching cost is given by the function delta (see edist doc for
128
212
  more information).The distance is normed by the function norm that takes the two list
129
213
  of nodes spawned by the trees `n1` and `n2`.
130
214
 
131
- Args:
132
- n1 (int): Node of the first Lineagetree
133
- embryo_1 (str): The key/name of the first Lineagetree
134
- end_time1 (int): End time of first Lineagetree
135
- n2 (int): The key/name of the first Lineagetree
136
- embryo_2 (str): Node of the second Lineagetree
137
- end_time2 (int): End time of second lineagetree
138
- registration (_type_, optional): _description_. Defaults to None.
139
- """
215
+ Parameters
216
+ ----------
217
+ n1 : int
218
+ Node of the first Lineagetree
219
+ embryo_1 : str
220
+ The key/name of the first Lineagetree
221
+ n2 : int
222
+ The key/name of the first Lineagetree
223
+ embryo_2 : str
224
+ Node of the second Lineagetree
225
+ end_time1 : int, optional
226
+ The final time point the comparison algorithm will take into account for the first dataset.
227
+ If None or not provided all nodes will be taken into account.
228
+ end_time2 : int, optional
229
+ The final time point the comparison algorithm will take into account for the second dataset.
230
+ If None or not provided all nodes will be taken into account.
231
+ style : {"simple", "normalized_simple", "full", "downsampled"} or TreeApproximationTemplate subclass, default="simple"
232
+ The approximation used to calculate the tree.
233
+ norm : {"max","sum", "None"}, default="max"
234
+ The normalization method used (Not important for this function)
235
+ downsample : int, default==2
236
+ The downsample factor for the downsampled tree approximation.
237
+ Used only when `style="downsampled"`.
140
238
 
141
- tree = tree_style[style].value
142
- lcm = (
143
- self.lineagetrees[embryo_1]._time_resolution
144
- * self.lineagetrees[embryo_2]._time_resolution
145
- ) / self.gcd
146
- if style == "downsampled":
147
- if downsample % (lcm / 10) != 0:
148
- raise Exception(
149
- f"Use a valid downsampling rate (multiple of {lcm/10})"
150
- )
151
- time_res = [
152
- downsample / self.lineagetrees[embryo_2].time_resolution,
153
- downsample / self.lineagetrees[embryo_1].time_resolution,
154
- ]
155
- elif style == "full":
156
- time_res = [
157
- lcm / 10 / self.lineagetrees[embryo_2].time_resolution,
158
- lcm / 10 / self.lineagetrees[embryo_1].time_resolution,
159
- ]
239
+ Returns
240
+ -------
241
+ dict mapping str to Alignment or tuple of [TreeApproximationTemplate, TreeApproximationTemplate]
242
+ - 'alignment'
243
+ The alignment between the nodes by the subtrees spawned by the nodes n1,n2 and the normalization function.`
244
+ - 'trees'
245
+ A list of the two trees that have been mapped to each other.
246
+ """
247
+ if (
248
+ self[embryo_1].time_resolution <= 0
249
+ or self[embryo_2].time_resolution <= 0
250
+ ):
251
+ raise Warning("Resolution cannot be <=0 ")
252
+ parameters = (
253
+ (end_time1, end_time2),
254
+ convert_style_to_number(style, downsample),
255
+ )
256
+ n1_embryo, n2_embryo = sorted(
257
+ ((n1, embryo_1), (n2, embryo_2)), key=lambda x: x[0]
258
+ )
259
+ self._comparisons.setdefault(parameters, {})
260
+ if isinstance(style, str):
261
+ tree = tree_style[style].value
262
+ elif issubclass(style, TreeApproximationTemplate):
263
+ tree = style
160
264
  else:
161
- time_res = [
162
- self.lineagetrees[embryo_1]._time_resolution,
163
- self.lineagetrees[embryo_2]._time_resolution,
164
- ]
165
- time_res = [i / self.gcd for i in time_res]
265
+ raise Warning("Use a valid approximation.")
266
+ time_res = tree.handle_resolutions(
267
+ time_resolution1=self.lineagetrees[embryo_1]._time_resolution,
268
+ time_resolution2=self.lineagetrees[embryo_2]._time_resolution,
269
+ gcd=self.gcd,
270
+ downsample=downsample,
271
+ )
166
272
  tree1 = tree(
167
273
  lT=self.lineagetrees[embryo_1],
168
274
  downsample=downsample,
@@ -184,8 +290,11 @@ class lineageTreeManager:
184
290
  nodes1, adj1, corres1 = tree1.edist
185
291
  nodes2, adj2, corres2 = tree2.edist
186
292
  if len(nodes1) == len(nodes2) == 0:
187
- return 0
188
-
293
+ self._comparisons[parameters][(n1_embryo, n2_embryo)] = {
294
+ "alignment": (),
295
+ "trees": (),
296
+ }
297
+ return self._comparisons[parameters][(n1_embryo, n2_embryo)]
189
298
  delta_tmp = partial(
190
299
  delta,
191
300
  corres1=corres1,
@@ -193,6 +302,585 @@ class lineageTreeManager:
193
302
  corres2=corres2,
194
303
  times2=times2,
195
304
  )
196
- return uted.uted(nodes1, adj1, nodes2, adj2, delta=delta_tmp) / max(
197
- tree1.get_norm(), tree2.get_norm()
305
+ btrc = uted.uted_backtrace(nodes1, adj1, nodes2, adj2, delta=delta_tmp)
306
+
307
+ self._comparisons[parameters][(n1_embryo, n2_embryo)] = {
308
+ "alignment": btrc,
309
+ "trees": (tree1, tree2),
310
+ }
311
+ return self._comparisons[parameters][(n1_embryo, n2_embryo)]
312
+
313
+ def __calculate_distance_of_sub_tree(
314
+ self,
315
+ node1: int,
316
+ lT1: lineageTree,
317
+ node2: int,
318
+ lT2: lineageTree,
319
+ alignment: Alignment,
320
+ corres1: dict,
321
+ corres2: dict,
322
+ delta_tmp: Callable,
323
+ norm: Callable,
324
+ norm1: int | float,
325
+ norm2: int | float,
326
+ ) -> float:
327
+ """Calculates the distance of the subtree of a node matched in a comparison.
328
+ DOES NOT CALCULATE THE DISTANCE FROM SCRATCH BUT USING THE ALIGNMENT.
329
+
330
+ TODO ITS BOUND TO CHANGE
331
+
332
+ Parameters
333
+ ----------
334
+ node1 : int
335
+ The root of the first subtree
336
+ lT1 : lineageTree
337
+ The dataset the first lineage exists
338
+ node2 : int
339
+ The root of the first subtree
340
+ lT2 : lineageTree
341
+ The dataset the second lineage exists
342
+ alignment : Alignment
343
+ The alignment of the subtree
344
+ corres1 : dict
345
+ The correspndance dictionary of the first lineage
346
+ corres2 : dict
347
+ The correspondance dictionary of the second lineage
348
+ delta_tmp : Callable
349
+ The delta function for the comparisons
350
+ norm : Callable
351
+ How should the lineages be normalized
352
+ norm1 : int or float
353
+ The result of the normalization of the first tree
354
+ norm2 : int or float
355
+ The result of the normalization of the second tree
356
+
357
+ Returns
358
+ -------
359
+ float
360
+ The result of the comparison of the subtree
361
+ """
362
+ sub_tree_1 = set(lT1.get_subtree_nodes(node1))
363
+ sub_tree_2 = set(lT2.get_subtree_nodes(node2))
364
+ res = 0
365
+ for m in alignment:
366
+ if (
367
+ corres1.get(m._left, -1) in sub_tree_1
368
+ or corres2.get(m._right, -1) in sub_tree_2
369
+ ):
370
+ res += delta_tmp(
371
+ m._left if m._left != -1 else None,
372
+ m._right if m._right != -1 else None,
373
+ )
374
+ return res / norm([norm1, norm2])
375
+
376
+ def clear_comparisons(self):
377
+ self._comparisons.clear()
378
+
379
def cross_lineage_edit_distance(
    self,
    n1: int,
    embryo_1: str,
    n2: int,
    embryo_2: str,
    end_time1: int | None = None,
    end_time2: int | None = None,
    norm: Literal["max", "sum", None] = "max",
    style: (
        Literal["simple", "normalized_simple", "full", "downsampled"]
        | type[TreeApproximationTemplate]
    ) = "simple",
    downsample: int = 2,
    return_norms: bool = False,
) -> float | tuple[float, tuple[float, float]]:
    """
    Compute the unordered tree edit distance from Zhang 1996 between the trees spawned
    by two nodes `n1` and `n2`. The topology of the trees are compared and the matching
    cost is given by the function delta (see edist doc for more information).

    Parameters
    ----------
    n1 : int
        id of the first node to compare
    embryo_1 : str
        the name of the first embryo to be used. (from lTm.lineagetrees.keys())
    n2 : int
        id of the second node to compare
    embryo_2 : str
        the name of the second embryo to be used. (from lTm.lineagetrees.keys())
    end_time1 : int, optional
        the final time point the comparison algorithm will take into account for the first dataset.
        If None or not provided all nodes will be taken into account.
    end_time2 : int, optional
        the final time point the comparison algorithm will take into account for the second dataset.
        If None or not provided all nodes will be taken into account.
    norm : {"max", "sum", None}, default="max"
        The normalization method to use; must be a key of `norm_dict`.
    style : {"simple", "normalized_simple", "full", "downsampled"} or TreeApproximationTemplate subclass, default="simple"
        Which tree approximation is going to be used for the comparisons.
    downsample : int, default=2
        The downsample factor for the downsampled tree approximation.
        Used only when `style="downsampled"`.
    return_norms : bool, default=False
        Decide if the norms will be returned explicitly (mainly used for the napari plugin)

    Returns
    -------
    float
        The cost of the alignment divided by the chosen normalization of
        the norms of the two trees. Returned if `return_norms` is `False`.
    tuple of (float, tuple of (float, float))
        The raw (not normalized) cost and the norms of the two trees.
        Returned if `return_norms` is `True`.

    Raises
    ------
    ValueError
        If `norm` is not a key of `norm_dict`.
    """
    # Reject an invalid norm before doing any expensive computation.
    if norm not in self.norm_dict:
        raise ValueError(
            "Select a viable normalization method (max, sum, None)"
        )
    parameters = (
        (end_time1, end_time2),
        convert_style_to_number(style, downsample),
    )
    n1_embryo, n2_embryo = sorted(
        ((n1, embryo_1), (n2, embryo_2)), key=lambda x: x[0]
    )
    # Stored comparisons are keyed by the sorted (node, embryo) pair.
    cached = self._comparisons.setdefault(parameters, {}).get(
        (n1_embryo, n2_embryo)
    )
    cached_trees = cached["trees"] if cached else ()
    # The key is order-independent but the stored trees follow the argument
    # order of the call that computed them. A stored comparison is reused
    # only when its orientation provably matches this call.
    reusable = bool(cached) and (
        not cached_trees
        or (
            (embryo_1 != embryo_2 or n1 == n2)
            and cached_trees[0].lT is self[embryo_1]
            and cached_trees[1].lT is self[embryo_2]
        )
    )
    if reusable:
        tmp = cached
    else:
        tmp = self.__cross_lineage_edit_backtrace(
            n1,
            embryo_1,
            n2,
            embryo_2,
            end_time1,
            end_time2,
            style,
            norm,
            downsample,
        )
        # Count the stored comparisons themselves, not the parameter sets.
        if sum(map(len, self._comparisons.values())) > 100:
            warnings.warn(
                "More than 100 comparisons are saved, use clear_comparisons() to delete them.",
                stacklevel=2,
            )
    if not tmp["trees"]:
        # Both subtrees are empty (no trees were stored): nothing to edit.
        # NOTE(review): zero norms are assumed for empty subtrees - confirm.
        return (0.0, (0.0, 0.0)) if return_norms else 0.0
    btrc = tmp["alignment"]
    tree1, tree2 = tmp["trees"]
    _, times1 = tree1.tree
    _, times2 = tree2.tree
    nodes1, _, corres1 = tree1.edist
    nodes2, _, corres2 = tree2.edist
    delta_tmp = partial(
        tree1.delta,
        corres1=corres1,
        corres2=corres2,
        times1=times1,
        times2=times2,
    )
    cost = btrc.cost(nodes1, nodes2, delta_tmp)
    norm_values = (tree1.get_norm(n1), tree2.get_norm(n2))
    if return_norms:
        return cost, norm_values
    return cost / self.norm_dict[norm](norm_values)
+
500
def plot_tree_distance_graphs(
    self,
    n1: int,
    embryo_1: str,
    n2: int,
    embryo_2: str,
    end_time1: int | None = None,
    end_time2: int | None = None,
    norm: Literal["max", "sum", None] = "max",
    style: (
        Literal["simple", "normalized_simple", "full", "downsampled"]
        | type[TreeApproximationTemplate]
    ) = "simple",
    downsample: int = 2,
    colormap: str = "cool",
    default_color: str = "black",
    size: float = 10,
    lw: float = 0.3,
    ax: np.ndarray | None = None,
) -> tuple[plt.Figure, plt.Axes]:
    """
    Plots the subtrees compared and colors them according to the quality of the matching of their subtree.

    Parameters
    ----------
    n1 : int
        id of the first node to compare
    embryo_1 : str
        the name of the first embryo
    n2 : int
        id of the second node to compare
    embryo_2 : str
        the name of the second embryo
    end_time1 : int, optional
        the final time point the comparison algorithm will take into account for the first dataset.
        If None or not provided all nodes will be taken into account.
    end_time2 : int, optional
        the final time point the comparison algorithm will take into account for the second dataset.
        If None or not provided all nodes will be taken into account.
    norm : {"max", "sum", None}, default="max"
        The normalization method to use; must be a key of `norm_dict`.
    style : {"simple", "normalized_simple", "full", "downsampled"} or TreeApproximationTemplate subclass, default="simple"
        Which tree approximation is going to be used for the comparisons.
    downsample : int, default=2
        The downsample factor for the downsampled tree approximation.
        Used only when `style="downsampled"`.
    colormap : str, default="cool"
        The colormap used for matched nodes
    default_color : str, default="black"
        The color of the unmatched nodes
    size : float, default=10
        The size of the nodes
    lw : float, default=0.3
        The width of the edges
    ax : np.ndarray, optional
        Two axes (`ax[0]` for the first tree, `ax[1]` for the second).
        If not provided a new figure with two axes is created.

    Returns
    -------
    plt.Figure
        The matplotlib figure holding `ax[0]`
    plt.Axes
        The matplotlib axes

    Raises
    ------
    Warning
        If `norm` is not a key of `norm_dict` (exception type kept for
        backward compatibility).
    ValueError
        If both compared subtrees are empty.
    """
    # Reject an invalid norm before doing any expensive computation.
    if norm not in self.norm_dict:
        raise Warning(
            "Select a viable normalization method (max, sum, None)"
        )
    parameters = (
        (end_time1, end_time2),
        convert_style_to_number(style, downsample),
    )
    n1_embryo, n2_embryo = sorted(
        ((n1, embryo_1), (n2, embryo_2)), key=lambda x: x[0]
    )
    # Stored comparisons are keyed by the sorted (node, embryo) pair.
    cached = self._comparisons.setdefault(parameters, {}).get(
        (n1_embryo, n2_embryo)
    )
    cached_trees = cached["trees"] if cached else ()
    # The key is order-independent but the stored trees follow the argument
    # order of the call that computed them. A stored comparison is reused
    # only when its orientation provably matches this call.
    reusable = bool(cached) and (
        not cached_trees
        or (
            (embryo_1 != embryo_2 or n1 == n2)
            and cached_trees[0].lT is self[embryo_1]
            and cached_trees[1].lT is self[embryo_2]
        )
    )
    if reusable:
        tmp = cached
    else:
        tmp = self.__cross_lineage_edit_backtrace(
            n1,
            embryo_1,
            n2,
            embryo_2,
            end_time1,
            end_time2,
            style,
            norm,
            downsample,
        )
    if not tmp["trees"]:
        # No trees are stored when both subtrees are empty.
        raise ValueError("Cannot plot the comparison of two empty subtrees.")
    btrc = tmp["alignment"]
    tree1, tree2 = tmp["trees"]
    _, times1 = tree1.tree
    _, times2 = tree2.tree
    *_, corres1 = tree1.edist
    *_, corres2 = tree2.edist
    delta_tmp = partial(
        tree1.delta,
        corres1=corres1,
        corres2=corres2,
        times1=times1,
        times2=times2,
    )
    norm_function = self.norm_dict[norm]
    colors1 = {}
    colors2 = {}

    def _color_chains(node_1, l_node_1, node_2, l_node_2):
        # Both ends of the two matched chains share one subtree distance.
        distance = self.__calculate_distance_of_sub_tree(
            node_1,
            tree1.lT,
            node_2,
            tree2.lT,
            btrc,
            corres1,
            corres2,
            delta_tmp,
            norm_function,
            tree1.get_norm(node_1),
            tree2.get_norm(node_2),
        )
        colors1[node_1] = distance
        colors2[node_2] = distance
        colors1[l_node_1] = distance
        colors2[l_node_2] = distance

    per_timepoint_style = style in ("full", "downsampled")
    for m in btrc:
        if m._left == -1 or m._right == -1:
            continue  # only matched pairs are colored
        chain1 = tree1.lT.get_chain_of_node(corres1[m._left])
        chain2 = tree2.lT.get_chain_of_node(corres2[m._right])
        if not per_timepoint_style:
            # Read the chain ends by index so the returned list is not mutated.
            _color_chains(chain1[0], chain1[-1], chain2[0], chain2[-1])
            continue
        node_1 = chain1[0]
        node_2 = chain2[0]
        start_chain1 = tree1.lT.get_chain_of_node(node_1)
        start_chain2 = tree2.lT.get_chain_of_node(node_2)
        # Parenthesized on purpose: without the parentheses `and` binds
        # tighter than `or`, which disabled the "not colored yet" check.
        if (start_chain1[0] == node_1 or start_chain2[0] == node_2) and (
            node_1 not in colors1 or node_2 not in colors2
        ):
            _color_chains(node_1, start_chain1[-1], node_2, start_chain2[-1])

    if ax is None:
        _, ax = plt.subplots(nrows=1, ncols=2)
    cmap = colormaps[colormap]
    c_norm = mcolors.Normalize(0, 1)
    colors1 = {c: cmap(c_norm(v)) for c, v in colors1.items()}
    colors2 = {c: cmap(c_norm(v)) for c, v in colors2.items()}
    tree1.lT.plot_subtree(
        tree1.lT.get_ancestor_at_t(n1),
        end_time=end_time1,
        size=size,
        color_of_nodes=colors1,
        color_of_edges=colors1,
        default_color=default_color,
        lw=lw,
        ax=ax[0],
    )
    tree2.lT.plot_subtree(
        tree2.lT.get_ancestor_at_t(n2),
        end_time=end_time2,
        size=size,
        color_of_nodes=colors2,
        color_of_edges=colors2,
        default_color=default_color,
        lw=lw,
        ax=ax[1],
    )
    return ax[0].get_figure(), ax
+
728
def labelled_mappings(
    self,
    n1: int,
    embryo_1: str,
    n2: int,
    embryo_2: str,
    end_time1: int | None = None,
    end_time2: int | None = None,
    norm: Literal["max", "sum", None] = "max",
    style: (
        Literal["simple", "normalized_simple", "full", "downsampled"]
        | type[TreeApproximationTemplate]
    ) = "simple",
    downsample: int = 2,
) -> dict[str, list[tuple]]:
    """
    Returns the labels or IDs of all the nodes in the subtrees compared.

    Parameters
    ----------
    n1 : int
        id of the first node to compare
    embryo_1 : str
        the name of the first lineage
    n2 : int
        id of the second node to compare
    embryo_2 : str
        the name of the second lineage
    end_time1 : int, optional
        the final time point the comparison algorithm will take into account for the first dataset.
        If None or not provided all nodes will be taken into account.
    end_time2 : int, optional
        the final time point the comparison algorithm will take into account for the second dataset.
        If None or not provided all nodes will be taken into account.
    norm : {"max", "sum", None}, default="max"
        The normalization method to use; must be a key of `norm_dict`.
    style : {"simple", "normalized_simple", "full", "downsampled"} or TreeApproximationTemplate subclass, default="simple"
        Which tree approximation is going to be used for the comparisons.
    downsample : int, default=2
        The downsample factor for the downsampled tree approximation.
        Used only when `style="downsampled"`.

    Returns
    -------
    dict mapping str to lists of tuples
        - 'matched' One `(label in first tree, label in second tree)`
          tuple per matched pair of the alignment.
        - 'unmatched' One `(label, name of the lineage tree)` tuple per
          unmatched node of the alignment.
        A node without a label is reported by its id. Both lists are empty
        when the two compared subtrees are empty.

    Raises
    ------
    Warning
        If `norm` is not a key of `norm_dict` (exception type kept for
        backward compatibility).
    """
    # Reject an invalid norm before doing any expensive computation.
    if norm not in self.norm_dict:
        raise Warning(
            "Select a viable normalization method (max, sum, None)"
        )
    parameters = (
        (end_time1, end_time2),
        convert_style_to_number(style, downsample),
    )
    n1_embryo, n2_embryo = sorted(
        ((n1, embryo_1), (n2, embryo_2)), key=lambda x: x[0]
    )
    # Stored comparisons are keyed by the sorted (node, embryo) pair.
    cached = self._comparisons.setdefault(parameters, {}).get(
        (n1_embryo, n2_embryo)
    )
    cached_trees = cached["trees"] if cached else ()
    # The key is order-independent but the stored trees follow the argument
    # order of the call that computed them. A stored comparison is reused
    # only when its orientation provably matches this call.
    reusable = bool(cached) and (
        not cached_trees
        or (
            (embryo_1 != embryo_2 or n1 == n2)
            and cached_trees[0].lT is self[embryo_1]
            and cached_trees[1].lT is self[embryo_2]
        )
    )
    if reusable:
        tmp = cached
    else:
        tmp = self.__cross_lineage_edit_backtrace(
            n1,
            embryo_1,
            n2,
            embryo_2,
            end_time1,
            end_time2,
            style,
            norm,
            downsample,
        )
    if not tmp["trees"]:
        # No trees are stored when both subtrees are empty.
        return {"matched": [], "unmatched": []}
    btrc = tmp["alignment"]
    tree1, tree2 = tmp["trees"]
    *_, corres1 = tree1.edist
    *_, corres2 = tree2.edist
    # Outside the "full"/"downsampled" styles a matched node is reported
    # through the first node of its chain.
    report_chain_start = style not in ("full", "downsampled")
    matched = []
    unmatched = []
    for m in btrc:
        if m._left != -1 and m._right != -1:
            node_1 = corres1[m._left]
            node_2 = corres2[m._right]
            if report_chain_start:
                node_1 = tree1.lT.get_chain_of_node(node_1)[0]
                node_2 = tree2.lT.get_chain_of_node(node_2)[0]
            matched.append(
                (
                    tree1.lT.labels.get(node_1, node_1),
                    tree2.lT.labels.get(node_2, node_2),
                )
            )
            continue
        # Unmatched entry: it is attributed to the left tree when it has a
        # left index, to the right tree otherwise.
        if m._left != -1:
            owner, corres, index = tree1, corres1, m._left
        else:
            owner, corres, index = tree2, corres2, m._right
        tmp_node = owner.lT.get_chain_of_node(corres.get(index, "-"))[0]
        unmatched.append(
            (owner.lT.labels.get(tmp_node, tmp_node), owner.lT.name)
        )
    return {"matched": matched, "unmatched": unmatched}