LineageTree-1.8.0-py3-none-any.whl → LineageTree-2.0.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,9 +1,17 @@
1
+ from __future__ import annotations
2
+
1
3
  import os
2
4
  import pickle as pkl
3
5
  import warnings
6
+ from collections.abc import Callable
4
7
  from functools import partial
8
+ from typing import TYPE_CHECKING, Literal
5
9
 
10
+ import matplotlib.colors as mcolors
6
11
  import numpy as np
12
+ from matplotlib import colormaps
13
+
14
+ from .tree_approximation import tree_style
7
15
 
8
16
  try:
9
17
  from edist import uted
@@ -12,51 +20,97 @@ except ImportError:
12
20
  "No edist installed therefore you will not be able to compute the tree edit distance.",
13
21
  stacklevel=2,
14
22
  )
23
+ import matplotlib.pyplot as plt
24
+ from edist import uted
25
+
15
26
  from LineageTree import lineageTree
27
+ from LineageTree.tree_approximation import TreeApproximationTemplate
28
+
29
+ from .utils import convert_style_to_number
16
30
 
17
- from .tree_styles import tree_style
31
+ if TYPE_CHECKING:
32
+ from edist.alignment import Alignment
18
33
 
19
34
 
20
35
  class lineageTreeManager:
36
+ norm_dict = {"max": max, "sum": sum, None: lambda x: 1}
37
+
21
38
  def __init__(self):
22
39
  self.lineagetrees = {}
23
40
  self.lineageTree_counter = 0
24
41
  self.registered = {}
42
+ self._comparisons = {}
25
43
 
26
44
  def __next__(self):
27
45
  self.lineageTree_counter += 1
28
46
  return self.lineageTree_counter - 1
29
47
 
48
+ def __len__(self):
49
+ """Returns how many lineagetrees are in the manager.
50
+
51
+ Returns
52
+ -------
53
+ int
54
+ The number of trees inside the manager
55
+ """
56
+ return len(self.lineagetrees)
57
+
58
+ def __iter__(
59
+ self,
60
+ ):
61
+ yield from self.lineagetrees.items()
62
+
63
+ def __getitem__(self, key):
64
+ if key in self.lineagetrees:
65
+ return self.lineagetrees[key]
66
+ else:
67
+ raise KeyError(f"'{key}' not found in the manager")
68
+
30
69
  @property
31
- def gcd(self):
32
- if len(self.lineagetrees) >= 1:
70
+ def gcd(self) -> int:
71
+ """Calculates the greatesτ common divisor between all lineagetree resolutions in the manager.
72
+
73
+ Returns
74
+ -------
75
+ int
76
+ The overall greatest common divisor.
77
+ """
78
+ if len(self) > 1:
33
79
  all_time_res = [
34
80
  embryo._time_resolution
35
81
  for embryo in self.lineagetrees.values()
36
82
  ]
37
83
  return np.gcd.reduce(all_time_res)
84
+ elif len(self):
85
+ return 1
86
+ else:
87
+ raise ValueError(
88
+ "You cannot calculate the greatest common divisor of time resolutions with an empty manager."
89
+ )
38
90
 
39
- def add(
40
- self, other_tree: lineageTree, name: str = "", classification: str = ""
41
- ):
91
+ def add(self, other_tree: lineageTree, name: str = ""):
42
92
  """Function that adds a new lineagetree object to the class.
43
93
  Can be added either by .add or by using the + operator. If a name is
44
94
  specified it will also add it as this specific name, otherwise it will
45
95
  use the already existing name of the lineagetree.
46
96
 
47
- Args:
48
- other_tree (lineageTree): Thelineagetree to be added.
49
- name (str, optional): Then name of. Defaults to "".
50
-
97
+ Parameters
98
+ ----------
99
+ other_tree : LineageTree
100
+ Thelineagetree to be added.
101
+ name : str, default=""
102
+ Then name of the lineagetree to be added, defaults to ''.
103
+ (Usually lineageTrees have the name of the path they are read from,
104
+ so this is going to be the name most of the times.)
51
105
  """
52
- if isinstance(other_tree, lineageTree) and other_tree.time_resolution:
106
+ if isinstance(other_tree, lineageTree):
53
107
  for tree in self.lineagetrees.values():
54
108
  if tree == other_tree:
55
109
  return False
56
110
  if name:
57
111
  self.lineagetrees[name] = other_tree
58
112
  else:
59
- if hasattr(other_tree, "name"):
113
+ if other_tree.name:
60
114
  name = other_tree.name
61
115
  self.lineagetrees[name] = other_tree
62
116
  else:
@@ -74,8 +128,10 @@ class lineageTreeManager:
74
128
  def write(self, fname: str):
75
129
  """Saves the manager
76
130
 
77
- Args:
78
- fname (str): The path and name of the file that is to be saved.
131
+ Parameters
132
+ ----------
133
+ fname : str
134
+ The path and name of the file that is to be saved.
79
135
  """
80
136
  if os.path.splitext(fname)[-1] != ".ltM":
81
137
  fname = os.path.extsep.join((fname, "ltM"))
@@ -86,83 +142,119 @@ class lineageTreeManager:
86
142
  def remove_embryo(self, key):
87
143
  """Removes the embryo from the manager.
88
144
 
89
- Args:
90
- key (str): The name of the lineagetree to be removed
145
+ Parameters
146
+ ----------
147
+ key : str
148
+ The name of the lineagetree to be removed
91
149
 
92
- Raises:
93
- Exception: If there is not such a lineagetree
150
+ Raises
151
+ ------
152
+ IndexError
153
+ If there is no such lineagetree
94
154
  """
95
155
  self.lineagetrees.pop(key, None)
96
156
 
97
157
  @classmethod
98
- def load(cls, fname: str):
99
- """
100
- Loading a lineage tree Manager from a ".ltm" file.
158
+ def load(cls, fname: str) -> lineageTreeManager:
159
+ """Loading a lineage tree Manager from a ".ltm" file.
101
160
 
102
- Args:
103
- fname (str): path to and name of the file to read
161
+ Parameters
162
+ ----------
163
+ fname : str
164
+ path to and name of the file to read
104
165
 
105
- Returns:
106
- (lineageTree): loaded file
166
+ Returns
167
+ -------
168
+ lineageTreeManager
169
+ loaded file
107
170
  """
108
171
  with open(fname, "br") as f:
109
172
  ltm = pkl.load(f)
110
173
  f.close()
111
174
  return ltm
112
175
 
113
- def cross_lineage_edit_distance(
176
+ def __cross_lineage_edit_backtrace(
114
177
  self,
115
178
  n1: int,
116
- embryo_1,
117
- end_time1: int,
179
+ embryo_1: str,
118
180
  n2: int,
119
- embryo_2,
120
- end_time2: int,
121
- style="simple",
181
+ embryo_2: str,
182
+ end_time1: int | None = None,
183
+ end_time2: int | None = None,
184
+ style: (
185
+ Literal["simple", "normalized_simple", "full", "downsampled"]
186
+ | type[TreeApproximationTemplate]
187
+ ) = "simple",
188
+ norm: Literal["max", "sum", None] = "max",
122
189
  downsample: int = 2,
123
- registration=None, # will be added as a later feature
124
- ):
190
+ ) -> dict[
191
+ str,
192
+ Alignment
193
+ | tuple[TreeApproximationTemplate, TreeApproximationTemplate],
194
+ ]:
125
195
  """Compute the unordered tree edit distance from Zhang 1996 between the trees spawned
126
196
  by two nodes `n1` from lineagetree1 and `n2` lineagetree2. The topology of the trees
127
197
  are compared and the matching cost is given by the function delta (see edist doc for
128
198
  more information).The distance is normed by the function norm that takes the two list
129
199
  of nodes spawned by the trees `n1` and `n2`.
130
200
 
131
- Args:
132
- n1 (int): Node of the first Lineagetree
133
- embryo_1 (str): The key/name of the first Lineagetree
134
- end_time1 (int): End time of first Lineagetree
135
- n2 (int): The key/name of the first Lineagetree
136
- embryo_2 (str): Node of the second Lineagetree
137
- end_time2 (int): End time of second lineagetree
138
- registration (_type_, optional): _description_. Defaults to None.
139
- """
201
+ Parameters
202
+ ----------
203
+ n1 : int
204
+ Node of the first Lineagetree
205
+ embryo_1 : str
206
+ The key/name of the first Lineagetree
207
+ n2 : int
208
+ The key/name of the first Lineagetree
209
+ embryo_2 : str
210
+ Node of the second Lineagetree
211
+ end_time1 : int, optional
212
+ The final time point the comparison algorithm will take into account for the first dataset.
213
+ If None or not provided all nodes will be taken into account.
214
+ end_time2 : int, optional
215
+ The final time point the comparison algorithm will take into account for the second dataset.
216
+ If None or not provided all nodes will be taken into account.
217
+ style : {"simple", "normalized_simple", "full", "downsampled"} or TreeApproximationTemplate subclass, default="simple"
218
+ The approximation used to calculate the tree.
219
+ norm : {"max","sum", "None"}, default="max"
220
+ The normalization method used (Not important for this function)
221
+ downsample : int, default==2
222
+ The downsample factor for the downsampled tree approximation.
223
+ Used only when `style="downsampled"`.
140
224
 
141
- tree = tree_style[style].value
142
- lcm = (
143
- self.lineagetrees[embryo_1]._time_resolution
144
- * self.lineagetrees[embryo_2]._time_resolution
145
- ) / self.gcd
146
- if style == "downsampled":
147
- if downsample % (lcm / 10) != 0:
148
- raise Exception(
149
- f"Use a valid downsampling rate (multiple of {lcm/10})"
150
- )
151
- time_res = [
152
- downsample / self.lineagetrees[embryo_2].time_resolution,
153
- downsample / self.lineagetrees[embryo_1].time_resolution,
154
- ]
155
- elif style == "full":
156
- time_res = [
157
- lcm / 10 / self.lineagetrees[embryo_2].time_resolution,
158
- lcm / 10 / self.lineagetrees[embryo_1].time_resolution,
159
- ]
225
+ Returns
226
+ -------
227
+ dict mapping str to Alignment or tuple of [TreeApproximationTemplate, TreeApproximationTemplate]
228
+ - 'alignment'
229
+ The alignment between the nodes by the subtrees spawned by the nodes n1,n2 and the normalization function.`
230
+ - 'trees'
231
+ A list of the two trees that have been mapped to each other.
232
+ """
233
+ if (
234
+ self[embryo_1].time_resolution <= 0
235
+ or self[embryo_2].time_resolution <= 0
236
+ ):
237
+ raise Warning("Resolution cannot be <=0 ")
238
+ parameters = (
239
+ (end_time1, end_time2),
240
+ convert_style_to_number(style, downsample),
241
+ )
242
+ n1_embryo, n2_embryo = sorted(
243
+ ((n1, embryo_1), (n2, embryo_2)), key=lambda x: x[0]
244
+ )
245
+ self._comparisons.setdefault(parameters, {})
246
+ if isinstance(style, str):
247
+ tree = tree_style[style].value
248
+ elif issubclass(style, TreeApproximationTemplate):
249
+ tree = style
160
250
  else:
161
- time_res = [
162
- self.lineagetrees[embryo_1]._time_resolution,
163
- self.lineagetrees[embryo_2]._time_resolution,
164
- ]
165
- time_res = [i / self.gcd for i in time_res]
251
+ raise Warning("Use a valid approximation.")
252
+ time_res = tree.handle_resolutions(
253
+ time_resolution1=self.lineagetrees[embryo_1]._time_resolution,
254
+ time_resolution2=self.lineagetrees[embryo_2]._time_resolution,
255
+ gcd=self.gcd,
256
+ downsample=downsample,
257
+ )
166
258
  tree1 = tree(
167
259
  lT=self.lineagetrees[embryo_1],
168
260
  downsample=downsample,
@@ -184,8 +276,11 @@ class lineageTreeManager:
184
276
  nodes1, adj1, corres1 = tree1.edist
185
277
  nodes2, adj2, corres2 = tree2.edist
186
278
  if len(nodes1) == len(nodes2) == 0:
187
- return 0
188
-
279
+ self._comparisons[parameters][(n1_embryo, n2_embryo)] = {
280
+ "alignment": (),
281
+ "trees": (),
282
+ }
283
+ return self._comparisons[parameters][(n1_embryo, n2_embryo)]
189
284
  delta_tmp = partial(
190
285
  delta,
191
286
  corres1=corres1,
@@ -193,6 +288,590 @@ class lineageTreeManager:
193
288
  corres2=corres2,
194
289
  times2=times2,
195
290
  )
196
- return uted.uted(nodes1, adj1, nodes2, adj2, delta=delta_tmp) / max(
197
- tree1.get_norm(), tree2.get_norm()
291
+ btrc = uted.uted_backtrace(nodes1, adj1, nodes2, adj2, delta=delta_tmp)
292
+
293
+ self._comparisons[parameters][(n1_embryo, n2_embryo)] = {
294
+ "alignment": btrc,
295
+ "trees": (tree1, tree2),
296
+ }
297
+ return self._comparisons[parameters][(n1_embryo, n2_embryo)]
298
+
299
+ def __calculate_distance_of_sub_tree(
300
+ self,
301
+ node1: int,
302
+ lT1: lineageTree,
303
+ node2: int,
304
+ lT2: lineageTree,
305
+ alignment: Alignment,
306
+ corres1: dict,
307
+ corres2: dict,
308
+ delta_tmp: Callable,
309
+ norm: Callable,
310
+ norm1: int | float,
311
+ norm2: int | float,
312
+ ) -> float:
313
+ """Calculates the distance of the subtree of a node matched in a comparison.
314
+ DOES NOT CALCULATE THE DISTANCE FROM SCRATCH BUT USING THE ALIGNMENT.
315
+
316
+ TODO ITS BOUND TO CHANGE
317
+
318
+ Parameters
319
+ ----------
320
+ node1 : int
321
+ The root of the first subtree
322
+ lT1 : lineageTree
323
+ The dataset the first lineage exists
324
+ node2 : int
325
+ The root of the first subtree
326
+ lT2 : lineageTree
327
+ The dataset the second lineage exists
328
+ alignment : Alignment
329
+ The alignment of the subtree
330
+ corres1 : dict
331
+ The correspndance dictionary of the first lineage
332
+ corres2 : dict
333
+ The correspondance dictionary of the second lineage
334
+ delta_tmp : Callable
335
+ The delta function for the comparisons
336
+ norm : Callable
337
+ How should the lineages be normalized
338
+ norm1 : int or float
339
+ The result of the normalization of the first tree
340
+ norm2 : int or float
341
+ The result of the normalization of the second tree
342
+
343
+ Returns
344
+ -------
345
+ float
346
+ The result of the comparison of the subtree
347
+ """
348
+ sub_tree_1 = set(lT1.get_subtree_nodes(node1))
349
+ sub_tree_2 = set(lT2.get_subtree_nodes(node2))
350
+ res = 0
351
+ for m in alignment:
352
+ if (
353
+ corres1.get(m._left, -1) in sub_tree_1
354
+ or corres2.get(m._right, -1) in sub_tree_2
355
+ ):
356
+ res += delta_tmp(
357
+ m._left if m._left != -1 else None,
358
+ m._right if m._right != -1 else None,
359
+ )
360
+ return res / norm([norm1, norm2])
361
+
362
+ def clear_comparisons(self):
363
+ self._comparisons.clear()
364
+
365
+ def cross_lineage_edit_distance(
366
+ self,
367
+ n1: int,
368
+ embryo_1: str,
369
+ n2: int,
370
+ embryo_2: str,
371
+ end_time1: int | None = None,
372
+ end_time2: int | None = None,
373
+ norm: Literal["max", "sum", None] = "max",
374
+ style: (
375
+ Literal["simple", "normalized_simple", "full", "downsampled"]
376
+ | type[TreeApproximationTemplate]
377
+ ) = "simple",
378
+ downsample: int = 2,
379
+ return_norms: bool = False,
380
+ ) -> float | tuple[float, tuple[float, float]]:
381
+ """
382
+ Compute the unordered tree edit backtrace from Zhang 1996 between the trees spawned
383
+ by two nodes `n1` and `n2`. The topology of the trees are compared and the matching
384
+ cost is given by the function delta (see edist doc for more information). There are
385
+ 5 styles available (tree approximations) and the user may add their own.
386
+
387
+ Parameters
388
+ ----------
389
+ n1 : int
390
+ id of the first node to compare
391
+ embryo_1 : str
392
+ the name of the first embryo to be used. (from lTm.lineagetrees.keys())
393
+ n2 : int
394
+ id of the second node to compare
395
+ embryo_2 : str
396
+ the name of the second embryo to be used. (from lTm.lineagetrees.keys())
397
+ end_time_1 : int, optional
398
+ the final time point the comparison algorithm will take into account for the first dataset.
399
+ If None or not provided all nodes will be taken into account.
400
+ end_time_2 : int, optional
401
+ the final time point the comparison algorithm will take into account for the second dataset.
402
+ If None or not provided all nodes will be taken into account.
403
+ norm : {"max", "sum"}, default="max"
404
+ The normalization method to use, defaults to 'max'.
405
+ style : {"simple", "normalized_simple", "full", "downsampled"} or TreeApproximationTemplate subclass, default="simple"
406
+ Which tree approximation is going to be used for the comparisons, defaults to 'simple'.
407
+ downsample : int, default=2
408
+ The downsample factor for the downsampled tree approximation.
409
+ Used only when `style="downsampled"`.
410
+ return_norms : bool
411
+ Decide if the norms will be returned explicitly (mainly used for the napari plugin)
412
+
413
+ Returns
414
+ -------
415
+ Alignment
416
+ The alignment between the nodes by the subtrees spawned by the nodes n1,n2 and the normalization function.`
417
+ --
418
+ ΟΡ
419
+ --
420
+
421
+ Alignment
422
+ The alignment between the nodes by the subtrees spawned by the nodes n1,n2 and the normalization function.`
423
+ tuple(tree,tree)
424
+ The two trees that have been mapped to each other.
425
+ """
426
+
427
+ parameters = (
428
+ (end_time1, end_time2),
429
+ convert_style_to_number(style, downsample),
430
+ )
431
+ n1_embryo, n2_embryo = sorted(
432
+ ((n1, embryo_1), (n2, embryo_2)), key=lambda x: x[0]
433
+ )
434
+ self._comparisons.setdefault(parameters, {})
435
+ if self._comparisons[parameters].get((n1, n2)):
436
+ tmp = self._comparisons[parameters][(n1_embryo, n2_embryo)]
437
+ else:
438
+ tmp = self.__cross_lineage_edit_backtrace(
439
+ n1,
440
+ embryo_1,
441
+ n2,
442
+ embryo_2,
443
+ end_time1,
444
+ end_time2,
445
+ style,
446
+ norm,
447
+ downsample,
448
+ )
449
+ if len(self._comparisons) > 100:
450
+ warnings.warn(
451
+ "More than 100 comparisons are saved, use clear_comparisons() to delete them.",
452
+ stacklevel=2,
453
+ )
454
+ btrc = tmp["alignment"]
455
+ tree1, tree2 = tmp["trees"]
456
+ _, times1 = tree1.tree
457
+ _, times2 = tree2.tree
458
+ (
459
+ nodes1,
460
+ adj1,
461
+ corres1,
462
+ ) = tree1.edist
463
+ (
464
+ nodes2,
465
+ adj2,
466
+ corres2,
467
+ ) = tree2.edist
468
+ if len(nodes1) == len(nodes2) == 0:
469
+ self._comparisons[hash(frozenset(parameters))] = {
470
+ "alignment": (),
471
+ "trees": (),
472
+ }
473
+ return self._comparisons[hash(frozenset(parameters))]
474
+ delta_tmp = partial(
475
+ tree1.delta,
476
+ corres1=corres1,
477
+ corres2=corres2,
478
+ times1=times1,
479
+ times2=times2,
480
+ )
481
+ if norm not in self.norm_dict:
482
+ raise ValueError(
483
+ "Select a viable normalization method (max, sum, None)"
484
+ )
485
+ cost = btrc.cost(nodes1, nodes2, delta_tmp)
486
+ norm_values = (tree1.get_norm(n1), tree2.get_norm(n2))
487
+ if return_norms:
488
+ return cost, norm_values
489
+ return cost / self.norm_dict[norm](norm_values)
490
+
491
+ def plot_tree_distance_graphs(
492
+ self,
493
+ n1: int,
494
+ embryo_1: str,
495
+ n2: int,
496
+ embryo_2: str,
497
+ end_time1: int | None = None,
498
+ end_time2: int | None = None,
499
+ norm: Literal["max", "sum", None] = "max",
500
+ style: (
501
+ Literal["simple", "normalized_simple", "full", "downsampled"]
502
+ | type[TreeApproximationTemplate]
503
+ ) = "simple",
504
+ downsample: int = 2,
505
+ colormap: str = "cool",
506
+ default_color: str = "black",
507
+ size: float = 10,
508
+ lw: float = 0.3,
509
+ ax: np.ndarray | None = None,
510
+ ) -> tuple[plt.figure, plt.Axes]:
511
+ """
512
+ Plots the subtrees compared and colors them according to the quality of the matching of their subtree.
513
+
514
+ Parameters
515
+ ----------
516
+ n1 : int
517
+ id of the first node to compare
518
+ embryo_1 : str
519
+ the name of the first embryo
520
+ n2 : int
521
+ id of the second node to compare
522
+ embryo_2 : str
523
+ the name of the second embryo
524
+ end_time1 : int, optional
525
+ the final time point the comparison algorithm will take into account for the first dataset.
526
+ If None or not provided all nodes will be taken into account.
527
+ end_time2 : int, optional
528
+ the final time point the comparison algorithm will take into account for the second dataset.
529
+ If None or not provided all nodes will be taken into account.
530
+ norm : {"max", "sum"}, default="max"
531
+ The normalization method to use.
532
+ style : {"simple", "normalized_simple", "full", "downsampled"} or TreeApproximationTemplate subclass, default="simple"
533
+ Which tree approximation is going to be used for the comparisons.
534
+ downsample : int, default=2
535
+ The downsample factor for the downsampled tree approximation.
536
+ Used only when `style="downsampled"`.
537
+ colormap : str, default="cool"
538
+ The colormap used for matched nodes, defaults to "cool"
539
+ default_color : str
540
+ The color of the unmatched nodes, defaults to "black"
541
+ size : float
542
+ The size of the nodes, defaults to 10
543
+ lw : float
544
+ The width of the edges, defaults to 0.3
545
+ ax : np.ndarray, optional
546
+ The axes used, if not provided another set of axes is produced, defaults to None
547
+
548
+ Returns
549
+ -------
550
+ plt.Figure
551
+ The matplotlib figure
552
+ plt.Axes
553
+ The matplotlib axes
554
+ """
555
+
556
+ parameters = (
557
+ (end_time1, end_time2),
558
+ convert_style_to_number(style, downsample),
559
+ )
560
+ n1_embryo, n2_embryo = sorted(
561
+ ((n1, embryo_1), (n2, embryo_2)), key=lambda x: x[0]
562
+ )
563
+ self._comparisons.setdefault(parameters, {})
564
+ if self._comparisons[parameters].get((n1, n2)):
565
+ tmp = self._comparisons[parameters][(n1_embryo, n2_embryo)]
566
+ else:
567
+ tmp = self.__cross_lineage_edit_backtrace(
568
+ n1,
569
+ embryo_1,
570
+ n2,
571
+ embryo_2,
572
+ end_time1,
573
+ end_time2,
574
+ style,
575
+ norm,
576
+ downsample,
577
+ )
578
+ btrc = tmp["alignment"]
579
+ tree1, tree2 = tmp["trees"]
580
+ _, times1 = tree1.tree
581
+ _, times2 = tree2.tree
582
+ (
583
+ *_,
584
+ corres1,
585
+ ) = tree1.edist
586
+ (
587
+ *_,
588
+ corres2,
589
+ ) = tree2.edist
590
+ delta_tmp = partial(
591
+ tree1.delta,
592
+ corres1=corres1,
593
+ corres2=corres2,
594
+ times1=times1,
595
+ times2=times2,
596
+ )
597
+ if norm not in self.norm_dict:
598
+ raise Warning(
599
+ "Select a viable normalization method (max, sum, None)"
600
+ )
601
+ matched_right = []
602
+ matched_left = []
603
+ colors1 = {}
604
+ colors2 = {}
605
+ if style not in ("full", "downsampled"):
606
+ for m in btrc:
607
+ if m._left != -1 and m._right != -1:
608
+ cyc1 = tree1.lT.get_chain_of_node(corres1[m._left])
609
+ if len(cyc1) > 1:
610
+ node_1, *_, l_node_1 = cyc1
611
+ matched_left.append(node_1)
612
+ matched_left.append(l_node_1)
613
+ elif len(cyc1) == 1:
614
+ node_1 = l_node_1 = cyc1.pop()
615
+ matched_left.append(node_1)
616
+
617
+ cyc2 = tree2.lT.get_chain_of_node(corres2[m._right])
618
+ if len(cyc2) > 1:
619
+ node_2, *_, l_node_2 = cyc2
620
+ matched_right.append(node_2)
621
+ matched_right.append(l_node_2)
622
+
623
+ elif len(cyc2) == 1:
624
+ node_2 = l_node_2 = cyc2.pop()
625
+ matched_right.append(node_2)
626
+
627
+ colors1[node_1] = self.__calculate_distance_of_sub_tree(
628
+ node_1,
629
+ tree1.lT,
630
+ node_2,
631
+ tree2.lT,
632
+ btrc,
633
+ corres1,
634
+ corres2,
635
+ delta_tmp,
636
+ self.norm_dict[norm],
637
+ tree1.get_norm(node_1),
638
+ tree2.get_norm(node_2),
639
+ )
640
+ colors2[node_2] = colors1[node_1]
641
+ colors1[l_node_1] = colors1[node_1]
642
+ colors2[l_node_2] = colors2[node_2]
643
+
644
+ else:
645
+ for m in btrc:
646
+ if m._left != -1 and m._right != -1:
647
+ node_1 = tree1.lT.get_chain_of_node(corres1[m._left])[0]
648
+ node_2 = tree2.lT.get_chain_of_node(corres2[m._right])[0]
649
+ if (
650
+ tree1.lT.get_chain_of_node(node_1)[0] == node_1
651
+ or tree2.lT.get_chain_of_node(node_2)[0] == node_2
652
+ and (node_1 not in colors1 or node_2 not in colors2)
653
+ ):
654
+ matched_left.append(node_1)
655
+ l_node_1 = tree1.lT.get_chain_of_node(node_1)[-1]
656
+ matched_left.append(l_node_1)
657
+ matched_right.append(node_2)
658
+ l_node_2 = tree2.lT.get_chain_of_node(node_2)[-1]
659
+ matched_right.append(l_node_2)
660
+ colors1[node_1] = (
661
+ self.__calculate_distance_of_sub_tree(
662
+ node_1,
663
+ tree1.lT,
664
+ node_2,
665
+ tree2.lT,
666
+ btrc,
667
+ corres1,
668
+ corres2,
669
+ delta_tmp,
670
+ self.norm_dict[norm],
671
+ tree1.get_norm(node_1),
672
+ tree2.get_norm(node_2),
673
+ )
674
+ )
675
+ colors2[node_2] = colors1[node_1]
676
+ colors1[tree1.lT.get_chain_of_node(node_1)[-1]] = (
677
+ colors1[node_1]
678
+ )
679
+ colors2[tree2.lT.get_chain_of_node(node_2)[-1]] = (
680
+ colors2[node_2]
681
+ )
682
+
683
+ if tree1.lT.get_chain_of_node(node_1)[-1] != node_1:
684
+ matched_left.append(
685
+ tree1.lT.get_chain_of_node(node_1)[-1]
686
+ )
687
+ if tree2.lT.get_chain_of_node(node_2)[-1] != node_2:
688
+ matched_right.append(
689
+ tree2.lT.get_chain_of_node(node_2)[-1]
690
+ )
691
+ if ax is None:
692
+ fig, ax = plt.subplots(nrows=1, ncols=2)
693
+ cmap = colormaps[colormap]
694
+ c_norm = mcolors.Normalize(0, 1)
695
+ colors1 = {c: cmap(c_norm(v)) for c, v in colors1.items()}
696
+ colors2 = {c: cmap(c_norm(v)) for c, v in colors2.items()}
697
+ tree1.lT.plot_subtree(
698
+ tree1.lT.get_ancestor_at_t(n1),
699
+ end_time=end_time1,
700
+ size=size,
701
+ color_of_nodes=colors1,
702
+ color_of_edges=colors1,
703
+ default_color=default_color,
704
+ lw=lw,
705
+ ax=ax[0],
706
+ )
707
+ tree2.lT.plot_subtree(
708
+ tree2.lT.get_ancestor_at_t(n2),
709
+ end_time=end_time2,
710
+ size=size,
711
+ color_of_nodes=colors2,
712
+ color_of_edges=colors2,
713
+ default_color=default_color,
714
+ lw=lw,
715
+ ax=ax[1],
198
716
  )
717
+ return ax[0].get_figure(), ax
718
+
719
+ def labelled_mappings(
720
+ self,
721
+ n1: int,
722
+ embryo_1: str,
723
+ n2: int,
724
+ embryo_2: str,
725
+ end_time1: int | None = None,
726
+ end_time2: int | None = None,
727
+ norm: Literal["max", "sum", None] = "max",
728
+ style: (
729
+ Literal["simple", "normalized_simple", "full", "downsampled"]
730
+ | type[TreeApproximationTemplate]
731
+ ) = "simple",
732
+ downsample: int = 2,
733
+ ) -> dict[str, list[str]]:
734
+ """
735
+ Returns the labels or IDs of all the nodes in the subtrees compared.
736
+
737
+ Parameters
738
+ ----------
739
+ n1 : int
740
+ id of the first node to compare
741
+ embryo_1 : str
742
+ the name of the first lineage
743
+ n2 : int
744
+ id of the second node to compare
745
+ embryo_2: str
746
+ the name of the second lineage
747
+ end_time1 : int, optional
748
+ the final time point the comparison algorithm will take into account for the first dataset.
749
+ If None or not provided all nodes will be taken into account.
750
+ end_time2 : int, optional
751
+ the final time point the comparison algorithm will take into account for the first dataset.
752
+ If None or not provided all nodes will be taken into account.
753
+ norm : {"max", "sum"}, default="max"
754
+ The normalization method to use.
755
+ style : {"simple", "normalized_simple", "full", "downsampled"} or TreeApproximationTemplate subclass, default="simple"
756
+ Which tree approximation is going to be used for the comparisons.
757
+ downsample : int, default=2
758
+ The downsample factor for the downsampled tree approximation.
759
+ Used only when `style="downsampled"`.
760
+
761
+ Returns
762
+ -------
763
+ dict mapping str to lists of str
764
+ - 'matched' The labels of the matched nodes of the alignment.
765
+ - 'unmatched' The labels of the unmatched nodes of the alginment.
766
+ """
767
+
768
+ parameters = (
769
+ (end_time1, end_time2),
770
+ convert_style_to_number(style, downsample),
771
+ )
772
+ n1_embryo, n2_embryo = sorted(
773
+ ((n1, embryo_1), (n2, embryo_2)), key=lambda x: x[0]
774
+ )
775
+ self._comparisons.setdefault(parameters, {})
776
+ if self._comparisons[parameters].get((n1, n2)):
777
+ tmp = self._comparisons[parameters][(n1_embryo, n2_embryo)]
778
+ else:
779
+ tmp = self.__cross_lineage_edit_backtrace(
780
+ n1,
781
+ embryo_1,
782
+ n2,
783
+ embryo_2,
784
+ end_time1,
785
+ end_time2,
786
+ style,
787
+ norm,
788
+ downsample,
789
+ )
790
+ btrc = tmp["alignment"]
791
+ tree1, tree2 = tmp["trees"]
792
+ _, times1 = tree1.tree
793
+ _, times2 = tree2.tree
794
+ (
795
+ *_,
796
+ corres1,
797
+ ) = tree1.edist
798
+ (
799
+ *_,
800
+ corres2,
801
+ ) = tree2.edist
802
+ if norm not in self.norm_dict:
803
+ raise Warning(
804
+ "Select a viable normalization method (max, sum, None)"
805
+ )
806
+ matched = []
807
+ unmatched = []
808
+ if style not in ("full", "downsampled"):
809
+ for m in btrc:
810
+ if m._left != -1 and m._right != -1:
811
+ cyc1 = tree1.lT.get_chain_of_node(corres1[m._left])
812
+ if len(cyc1) > 1:
813
+ node_1, *_ = cyc1
814
+ elif len(cyc1) == 1:
815
+ node_1 = cyc1.pop()
816
+
817
+ cyc2 = tree2.lT.get_chain_of_node(corres2[m._right])
818
+ if len(cyc2) > 1:
819
+ node_2, *_ = cyc2
820
+
821
+ elif len(cyc2) == 1:
822
+ node_2 = cyc2.pop()
823
+
824
+ matched.append(
825
+ (
826
+ tree1.lT.labels.get(node_1, node_1),
827
+ tree2.lT.labels.get(node_2, node_2),
828
+ )
829
+ )
830
+ else:
831
+ if m._left != -1:
832
+ tmp_node = tree1.lT.get_chain_of_node(
833
+ corres1.get(m._left, "-")
834
+ )[0]
835
+ node_1 = (
836
+ tree1.lT.labels.get(tmp_node, tmp_node),
837
+ tree1.lT.name,
838
+ )
839
+ else:
840
+ tmp_node = tree2.lT.get_chain_of_node(
841
+ corres2.get(m._right, "-")
842
+ )[0]
843
+ node_1 = (
844
+ tree2.lT.labels.get(tmp_node, tmp_node),
845
+ tree2.lT.name,
846
+ )
847
+ unmatched.append(node_1)
848
+ else:
849
+ for m in btrc:
850
+ if m._left != -1 and m._right != -1:
851
+ node_1 = corres1[m._left]
852
+ node_2 = corres2[m._right]
853
+ matched.append(
854
+ (
855
+ tree1.lT.labels.get(node_1, node_1),
856
+ tree2.lT.labels.get(node_2, node_2),
857
+ )
858
+ )
859
+ else:
860
+ if m._left != -1:
861
+ tmp_node = tree1.lT.get_chain_of_node(
862
+ corres1.get(m._left, "-")
863
+ )[0]
864
+ node_1 = (
865
+ tree1.lT.labels.get(tmp_node, tmp_node),
866
+ tree1.lT.name,
867
+ )
868
+ else:
869
+ tmp_node = tree2.lT.get_chain_of_node(
870
+ corres2.get(m._right, "-")
871
+ )[0]
872
+ node_1 = (
873
+ tree2.lT.labels.get(tmp_node, tmp_node),
874
+ tree2.lT.name,
875
+ )
876
+ unmatched.append(node_1)
877
+ return {"matched": matched, "unmatched": unmatched}