LineageTree 1.5.1__py3-none-any.whl → 1.6.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -2,12 +2,10 @@
2
2
  # This file is subject to the terms and conditions defined in
3
3
  # file 'LICENCE', which is part of this source code package.
4
4
  # Author: Leo Guignard (leo.guignard...@AT@...gmail.com)
5
- import csv
6
5
  import os
7
6
  import pickle as pkl
8
7
  import struct
9
8
  import warnings
10
- import xml.etree.ElementTree as ET
11
9
  from collections.abc import Iterable
12
10
  from functools import partial
13
11
  from itertools import combinations
@@ -15,6 +13,7 @@ from numbers import Number
15
13
  from pathlib import Path
16
14
  from typing import TextIO, Union
17
15
 
16
+ from .loaders import lineageTreeLoaders
18
17
  from .tree_styles import tree_style
19
18
 
20
19
  try:
@@ -24,16 +23,18 @@ except ImportError:
24
23
  "No edist installed therefore you will not be able to compute the tree edit distance."
25
24
  )
26
25
  import matplotlib.pyplot as plt
27
- import networkx as nx
28
26
  import numpy as np
29
27
  from scipy.interpolate import InterpolatedUnivariateSpline
30
28
  from scipy.spatial import Delaunay, distance
31
29
  from scipy.spatial import cKDTree as KDTree
32
30
 
33
- from .utils import hierarchy_pos, postions_of_nx
31
+ from .utils import (
32
+ create_links_and_cycles,
33
+ hierarchical_pos,
34
+ )
34
35
 
35
36
 
36
- class lineageTree:
37
+ class lineageTree(lineageTreeLoaders):
37
38
  def __eq__(self, other):
38
39
  if isinstance(other, lineageTree):
39
40
  return other.successor == self.successor
@@ -381,7 +382,18 @@ class lineageTree:
381
382
  @property
382
383
  def labels(self):
383
384
  if not hasattr(self, "_labels"):
384
- self._labels = {i: "Unlabeled" for i in self.roots}
385
+ if hasattr(self, "cell_name"):
386
+ self._labels = {
387
+ i: self.cell_name.get(i, "Unlabeled") for i in self.roots
388
+ }
389
+ else:
390
+ self._labels = {
391
+ i: "Unlabeled"
392
+ for i in self.roots
393
+ for l in self.find_leaves(i)
394
+ if abs(self.time[l] - self.time[i])
395
+ >= abs(self.t_e - self.t_b) / 4
396
+ }
385
397
  return self._labels
386
398
 
387
399
  def _write_header_am(self, f: TextIO, nb_points: int, length: int):
@@ -732,61 +744,6 @@ class lineageTree:
732
744
  )
733
745
  dwg.save()
734
746
 
735
- def to_treex(
736
- self,
737
- sampling: int = 1,
738
- start: int = 0,
739
- finish: int = 10000,
740
- many: bool = True,
741
- ):
742
- """
743
- TODO: finish the doc
744
- Convert the lineage tree into a treex file.
745
-
746
- start/finish refer to first index in the new array times_to_consider
747
-
748
- """
749
- from warnings import warn
750
-
751
- from treex.tree import Tree
752
-
753
- if finish - start <= 0:
754
- warn("Will return None, because start = finish", stacklevel=2)
755
- return None
756
- id_to_tree = {_id: Tree() for _id in self.nodes}
757
- times_to_consider = sorted(
758
- [t for t, n in self.time_nodes.items() if len(n) > 0]
759
- )
760
- times_to_consider = times_to_consider[start:finish:sampling]
761
- start_time = times_to_consider[0]
762
- for t in times_to_consider:
763
- for id_mother in self.time_nodes[t]:
764
- ids_daughters = self[id_mother]
765
- new_ids_daughters = ids_daughters.copy()
766
- for _ in range(sampling - 1):
767
- tmp = []
768
- for d in new_ids_daughters:
769
- tmp.extend(self.successor.get(d, [d]))
770
- new_ids_daughters = tmp
771
- for (
772
- daugther
773
- ) in (
774
- new_ids_daughters
775
- ): ## For each daughter in the list of daughters
776
- id_to_tree[id_mother].add_subtree(
777
- id_to_tree[daugther]
778
- ) ## Add the Treex daughter as a subtree of the Treex mother
779
- roots = [id_to_tree[_id] for _id in set(self.time_nodes[start_time])]
780
- for root, ids in zip(roots, set(self.time_nodes[start_time])):
781
- root.add_attribute_to_id("ID", ids)
782
- if not many:
783
- reroot = Tree()
784
- for root in roots:
785
- reroot.add_subtree(root)
786
- return reroot
787
- else:
788
- return roots
789
-
790
747
  def to_tlp(
791
748
  self,
792
749
  fname: str,
@@ -993,726 +950,6 @@ class lineageTree:
993
950
  f.write(")")
994
951
  f.close()
995
952
 
996
- def read_from_csv(
997
- self, file_path: str, z_mult: float, link: int = 1, delim: str = ","
998
- ):
999
- """
1000
- TODO: write doc
1001
- """
1002
-
1003
- def convert_for_csv(v):
1004
- if v.isdigit():
1005
- return int(v)
1006
- else:
1007
- return float(v)
1008
-
1009
- with open(file_path) as f:
1010
- lines = f.readlines()
1011
- f.close()
1012
- self.time_nodes = {}
1013
- self.time_edges = {}
1014
- unique_id = 0
1015
- self.nodes = set()
1016
- self.successor = {}
1017
- self.predecessor = {}
1018
- self.pos = {}
1019
- self.time_id = {}
1020
- self.time = {}
1021
- self.lin = {}
1022
- self.C_lin = {}
1023
- if not link:
1024
- self.displacement = {}
1025
- lines_to_int = []
1026
- corres = {}
1027
- for line in lines:
1028
- lines_to_int += [
1029
- [convert_for_csv(v.strip()) for v in line.split(delim)]
1030
- ]
1031
- lines_to_int = np.array(lines_to_int)
1032
- if link == 2:
1033
- lines_to_int = lines_to_int[np.argsort(lines_to_int[:, 0])]
1034
- else:
1035
- lines_to_int = lines_to_int[np.argsort(lines_to_int[:, 1])]
1036
- for line in lines_to_int:
1037
- if link == 1:
1038
- id_, t, z, y, x, pred, lin_id = line
1039
- elif link == 2:
1040
- t, z, y, x, id_, pred, lin_id = line
1041
- else:
1042
- id_, t, z, y, x, dz, dy, dx = line
1043
- pred = None
1044
- lin_id = None
1045
- t = int(t)
1046
- pos = np.array([x, y, z])
1047
- C = unique_id
1048
- corres[id_] = C
1049
- pos[-1] = pos[-1] * z_mult
1050
- if pred in corres:
1051
- M = corres[pred]
1052
- self.predecessor[C] = [M]
1053
- self.successor.setdefault(M, []).append(C)
1054
- self.time_edges.setdefault(t, set()).add((M, C))
1055
- self.lin.setdefault(lin_id, []).append(C)
1056
- self.C_lin[C] = lin_id
1057
- self.pos[C] = pos
1058
- self.nodes.add(C)
1059
- self.time_nodes.setdefault(t, set()).add(C)
1060
- # self.time_id[(t, cell_id)] = C
1061
- self.time[C] = t
1062
- if not link:
1063
- self.displacement[C] = np.array([dx, dy, dz * z_mult])
1064
- unique_id += 1
1065
- self.max_id = unique_id - 1
1066
- self.t_b = min(self.time_nodes)
1067
- self.t_e = max(self.time_nodes)
1068
-
1069
- def read_from_ASTEC(self, file_path: str, eigen: bool = False):
1070
- """
1071
- Read an `xml` or `pkl` file produced by the ASTEC algorithm.
1072
-
1073
- Args:
1074
- file_path (str): path to an output generated by ASTEC
1075
- eigen (bool): whether or not to read the eigen values, default False
1076
- """
1077
- self._astec_keydictionary = {
1078
- "cell_lineage": [
1079
- "lineage_tree",
1080
- "lin_tree",
1081
- "Lineage tree",
1082
- "cell_lineage",
1083
- ],
1084
- "cell_h_min": ["cell_h_min", "h_mins_information"],
1085
- "cell_volume": [
1086
- "cell_volume",
1087
- "volumes_information",
1088
- "volumes information",
1089
- "vol",
1090
- ],
1091
- "cell_surface": ["cell_surface", "cell surface"],
1092
- "cell_compactness": [
1093
- "cell_compactness",
1094
- "Cell Compactness",
1095
- "compacity",
1096
- "cell_sphericity",
1097
- ],
1098
- "cell_sigma": ["cell_sigma", "sigmas_information", "sigmas"],
1099
- "cell_labels_in_time": [
1100
- "cell_labels_in_time",
1101
- "Cells labels in time",
1102
- "time_labels",
1103
- ],
1104
- "cell_barycenter": [
1105
- "cell_barycenter",
1106
- "Barycenters",
1107
- "barycenters",
1108
- ],
1109
- "cell_fate": ["cell_fate", "Fate"],
1110
- "cell_fate_2": ["cell_fate_2", "Fate2"],
1111
- "cell_fate_3": ["cell_fate_3", "Fate3"],
1112
- "cell_fate_4": ["cell_fate_4", "Fate4"],
1113
- "all_cells": [
1114
- "all_cells",
1115
- "All Cells",
1116
- "All_Cells",
1117
- "all cells",
1118
- "tot_cells",
1119
- ],
1120
- "cell_principal_values": [
1121
- "cell_principal_values",
1122
- "Principal values",
1123
- ],
1124
- "cell_name": ["cell_name", "Names", "names", "cell_names"],
1125
- "cell_contact_surface": [
1126
- "cell_contact_surface",
1127
- "cell_cell_contact_information",
1128
- ],
1129
- "cell_history": [
1130
- "cell_history",
1131
- "Cells history",
1132
- "cell_life",
1133
- "life",
1134
- ],
1135
- "cell_principal_vectors": [
1136
- "cell_principal_vectors",
1137
- "Principal vectors",
1138
- ],
1139
- "cell_naming_score": ["cell_naming_score", "Scores", "scores"],
1140
- "problematic_cells": ["problematic_cells"],
1141
- "unknown_key": ["unknown_key"],
1142
- }
1143
-
1144
- if os.path.splitext(file_path)[-1] == ".xml":
1145
- tmp_data = self._read_from_ASTEC_xml(file_path)
1146
- else:
1147
- tmp_data = self._read_from_ASTEC_pkl(file_path, eigen)
1148
-
1149
- # make sure these are all named liked they are in tmp_data (or change dictionary above)
1150
- self.name = {}
1151
- if "cell_volume" in tmp_data:
1152
- self.volume = {}
1153
- if "cell_fate" in tmp_data:
1154
- self.fates = {}
1155
- if "cell_barycenter" in tmp_data:
1156
- self.pos = {}
1157
- self.lT2pkl = {}
1158
- self.pkl2lT = {}
1159
- self.contact = {}
1160
- self.prob_cells = set()
1161
- self.image_label = {}
1162
-
1163
- lt = tmp_data["cell_lineage"]
1164
-
1165
- if "cell_contact_surface" in tmp_data:
1166
- do_surf = True
1167
- surfaces = tmp_data["cell_contact_surface"]
1168
- else:
1169
- do_surf = False
1170
-
1171
- inv = {vi: [c] for c, v in lt.items() for vi in v}
1172
- nodes = set(lt).union(inv)
1173
-
1174
- unique_id = 0
1175
-
1176
- for n in nodes:
1177
- t = n // 10**4
1178
- self.image_label[unique_id] = n % 10**4
1179
- self.lT2pkl[unique_id] = n
1180
- self.pkl2lT[n] = unique_id
1181
- self.time_nodes.setdefault(t, set()).add(unique_id)
1182
- self.nodes.add(unique_id)
1183
- self.time[unique_id] = t
1184
- if "cell_volume" in tmp_data:
1185
- self.volume[unique_id] = tmp_data["cell_volume"].get(n, 0.0)
1186
- if "cell_fate" in tmp_data:
1187
- self.fates[unique_id] = tmp_data["cell_fate"].get(n, "")
1188
- if "cell_barycenter" in tmp_data:
1189
- self.pos[unique_id] = tmp_data["cell_barycenter"].get(
1190
- n, np.zeros(3)
1191
- )
1192
-
1193
- unique_id += 1
1194
- if do_surf:
1195
- for c in nodes:
1196
- if c in surfaces and c in self.pkl2lT:
1197
- self.contact[self.pkl2lT[c]] = {
1198
- self.pkl2lT.get(n, -1): s
1199
- for n, s in surfaces[c].items()
1200
- if n % 10**4 == 1 or n in self.pkl2lT
1201
- }
1202
-
1203
- for n, new_id in self.pkl2lT.items():
1204
- if n in inv:
1205
- self.predecessor[new_id] = [self.pkl2lT[ni] for ni in inv[n]]
1206
- if n in lt:
1207
- self.successor[new_id] = [
1208
- self.pkl2lT[ni] for ni in lt[n] if ni in self.pkl2lT
1209
- ]
1210
-
1211
- for ni in self.successor[new_id]:
1212
- self.time_edges.setdefault(t - 1, set()).add((new_id, ni))
1213
-
1214
- self.t_b = min(self.time_nodes)
1215
- self.t_e = max(self.time_nodes)
1216
- self.max_id = unique_id
1217
-
1218
- # do this in the end of the process, skip lineage tree and whatever is stored already
1219
- discard = {
1220
- "cell_volume",
1221
- "cell_fate",
1222
- "cell_barycenter",
1223
- "cell_contact_surface",
1224
- "cell_lineage",
1225
- "all_cells",
1226
- "cell_history",
1227
- "problematic_cells",
1228
- "cell_labels_in_time",
1229
- }
1230
- self.specific_properties = []
1231
- for prop_name, prop_values in tmp_data.items():
1232
- if not (prop_name in discard or hasattr(self, prop_name)):
1233
- if isinstance(prop_values, dict):
1234
- dictionary = {
1235
- self.pkl2lT.get(k, -1): v
1236
- for k, v in prop_values.items()
1237
- }
1238
- # is it a regular dictionary or a dictionary with dictionaries inside?
1239
- for key, value in dictionary.items():
1240
- if isinstance(value, dict):
1241
- # rename all ids from old to new
1242
- dictionary[key] = {
1243
- self.pkl2lT.get(k, -1): v
1244
- for k, v in value.items()
1245
- }
1246
- self.__dict__[prop_name] = dictionary
1247
- self.specific_properties.append(prop_name)
1248
- # is any of this necessary? Or does it mean it anyways does not contain
1249
- # information about the id and a simple else: is enough?
1250
- elif (
1251
- isinstance(prop_values, (list, set, np.ndarray))
1252
- and prop_name not in []
1253
- ):
1254
- self.__dict__[prop_name] = prop_values
1255
- self.specific_properties.append(prop_name)
1256
-
1257
- # what else could it be?
1258
-
1259
- # add a list of all available properties
1260
-
1261
- def _read_from_ASTEC_xml(self, file_path: str):
1262
- def _set_dictionary_value(root):
1263
- if len(root) == 0:
1264
- if root.text is None:
1265
- return None
1266
- else:
1267
- return eval(root.text)
1268
- else:
1269
- dictionary = {}
1270
- for child in root:
1271
- key = child.tag
1272
- if child.tag == "cell":
1273
- key = int(child.attrib["cell-id"])
1274
- dictionary[key] = _set_dictionary_value(child)
1275
- return dictionary
1276
-
1277
- tree = ET.parse(file_path)
1278
- root = tree.getroot()
1279
- dictionary = {}
1280
-
1281
- for k, _v in self._astec_keydictionary.items():
1282
- if root.tag == k:
1283
- dictionary[str(root.tag)] = _set_dictionary_value(root)
1284
- break
1285
- else:
1286
- for child in root:
1287
- value = _set_dictionary_value(child)
1288
- if value is not None:
1289
- dictionary[str(child.tag)] = value
1290
- return dictionary
1291
-
1292
- def _read_from_ASTEC_pkl(self, file_path: str, eigen: bool = False):
1293
- with open(file_path, "rb") as f:
1294
- tmp_data = pkl.load(f, encoding="latin1")
1295
- f.close()
1296
- new_ref = {}
1297
- for k, v in self._astec_keydictionary.items():
1298
- for key in v:
1299
- new_ref[key] = k
1300
- new_dict = {}
1301
-
1302
- for k, v in tmp_data.items():
1303
- if k in new_ref:
1304
- new_dict[new_ref[k]] = v
1305
- else:
1306
- new_dict[k] = v
1307
- return new_dict
1308
-
1309
- def read_from_txt_for_celegans(self, file: str):
1310
- """
1311
- Read a C. elegans lineage tree
1312
-
1313
- Args:
1314
- file (str): Path to the file to read
1315
- """
1316
- implicit_l_t = {
1317
- "AB": "P0",
1318
- "P1": "P0",
1319
- "EMS": "P1",
1320
- "P2": "P1",
1321
- "MS": "EMS",
1322
- "E": "EMS",
1323
- "C": "P2",
1324
- "P3": "P2",
1325
- "D": "P3",
1326
- "P4": "P3",
1327
- "Z2": "P4",
1328
- "Z3": "P4",
1329
- }
1330
- with open(file) as f:
1331
- raw = f.readlines()[1:]
1332
- f.close()
1333
- self.name = {}
1334
-
1335
- unique_id = 0
1336
- for line in raw:
1337
- t = int(line.split("\t")[0])
1338
- self.name[unique_id] = line.split("\t")[1]
1339
- position = np.array(line.split("\t")[2:5], dtype=float)
1340
- self.time_nodes.setdefault(t, set()).add(unique_id)
1341
- self.nodes.add(unique_id)
1342
- self.pos[unique_id] = position
1343
- self.time[unique_id] = t
1344
- unique_id += 1
1345
-
1346
- self.t_b = min(self.time_nodes)
1347
- self.t_e = max(self.time_nodes)
1348
-
1349
- for t, cells in self.time_nodes.items():
1350
- if t != self.t_b:
1351
- prev_cells = self.time_nodes[t - 1]
1352
- name_to_id = {self.name[c]: c for c in prev_cells}
1353
- for c in cells:
1354
- if self.name[c] in name_to_id:
1355
- p = name_to_id[self.name[c]]
1356
- elif self.name[c][:-1] in name_to_id:
1357
- p = name_to_id[self.name[c][:-1]]
1358
- elif implicit_l_t.get(self.name[c]) in name_to_id:
1359
- p = name_to_id[implicit_l_t.get(self.name[c])]
1360
- else:
1361
- print(
1362
- "error, cell %s has no predecessors" % self.name[c]
1363
- )
1364
- p = None
1365
- self.predecessor.setdefault(c, []).append(p)
1366
- self.successor.setdefault(p, []).append(c)
1367
- self.time_edges.setdefault(t - 1, set()).add((p, c))
1368
- self.max_id = unique_id
1369
-
1370
- def read_from_txt_for_celegans_CAO(
1371
- self,
1372
- file: str,
1373
- reorder: bool = False,
1374
- raw_size: float = None,
1375
- shape: float = None,
1376
- ):
1377
- """
1378
- Read a C. elegans lineage tree from Cao et al.
1379
-
1380
- Args:
1381
- file (str): Path to the file to read
1382
- """
1383
-
1384
- implicit_l_t = {
1385
- "AB": "P0",
1386
- "P1": "P0",
1387
- "EMS": "P1",
1388
- "P2": "P1",
1389
- "MS": "EMS",
1390
- "E": "EMS",
1391
- "C": "P2",
1392
- "P3": "P2",
1393
- "D": "P3",
1394
- "P4": "P3",
1395
- "Z2": "P4",
1396
- "Z3": "P4",
1397
- }
1398
-
1399
- def split_line(line):
1400
- return (
1401
- line.split()[0],
1402
- eval(line.split()[1]),
1403
- eval(line.split()[2]),
1404
- eval(line.split()[3]),
1405
- eval(line.split()[4]),
1406
- )
1407
-
1408
- with open(file) as f:
1409
- raw = f.readlines()[1:]
1410
- f.close()
1411
- self.name = {}
1412
-
1413
- unique_id = 0
1414
- for name, t, z, x, y in map(split_line, raw):
1415
- self.name[unique_id] = name
1416
- position = np.array([x, y, z], dtype=np.float)
1417
- self.time_nodes.setdefault(t, set()).add(unique_id)
1418
- self.nodes.add(unique_id)
1419
- if reorder:
1420
-
1421
- def flip(x):
1422
- return np.array([x[0], x[1], raw_size[2] - x[2]])
1423
-
1424
- def adjust(x):
1425
- return (shape / raw_size * flip(x))[[1, 0, 2]]
1426
-
1427
- self.pos[unique_id] = adjust(position)
1428
- else:
1429
- self.pos[unique_id] = position
1430
- self.time[unique_id] = t
1431
- unique_id += 1
1432
-
1433
- self.t_b = min(self.time_nodes)
1434
- self.t_e = max(self.time_nodes)
1435
-
1436
- for t, cells in self.time_nodes.items():
1437
- if t != self.t_b:
1438
- prev_cells = self.time_nodes[t - 1]
1439
- name_to_id = {self.name[c]: c for c in prev_cells}
1440
- for c in cells:
1441
- if self.name[c] in name_to_id:
1442
- p = name_to_id[self.name[c]]
1443
- elif self.name[c][:-1] in name_to_id:
1444
- p = name_to_id[self.name[c][:-1]]
1445
- elif implicit_l_t.get(self.name[c]) in name_to_id:
1446
- p = name_to_id[implicit_l_t.get(self.name[c])]
1447
- else:
1448
- print(
1449
- "error, cell %s has no predecessors" % self.name[c]
1450
- )
1451
- p = None
1452
- self.predecessor.setdefault(c, []).append(p)
1453
- self.successor.setdefault(p, []).append(c)
1454
- self.time_edges.setdefault(t - 1, set()).add((p, c))
1455
- self.max_id = unique_id
1456
-
1457
- def read_tgmm_xml(
1458
- self, file_format: str, tb: int, te: int, z_mult: float = 1.0
1459
- ):
1460
- """Reads a lineage tree from TGMM xml output.
1461
-
1462
- Args:
1463
- file_format (str): path to the xmls location.
1464
- it should be written as follow:
1465
- path/to/xml/standard_name_t{t:06d}.xml where (as an example)
1466
- {t:06d} means a series of 6 digits representing the time and
1467
- if the time values is smaller that 6 digits, the missing
1468
- digits are filed with 0s
1469
- tb (int): first time point to read
1470
- te (int): last time point to read
1471
- z_mult (float): aspect ratio
1472
- """
1473
- self.time_nodes = {}
1474
- self.time_edges = {}
1475
- unique_id = 0
1476
- self.nodes = set()
1477
- self.successor = {}
1478
- self.predecessor = {}
1479
- self.pos = {}
1480
- self.time_id = {}
1481
- self.time = {}
1482
- self.mother_not_found = []
1483
- self.ind_cells = {}
1484
- self.svIdx = {}
1485
- self.lin = {}
1486
- self.C_lin = {}
1487
- self.coeffs = {}
1488
- self.intensity = {}
1489
- self.W = {}
1490
- for t in range(tb, te + 1):
1491
- print(t, end=" ")
1492
- if t % 10 == 0:
1493
- print()
1494
- tree = ET.parse(file_format.format(t=t))
1495
- root = tree.getroot()
1496
- self.time_nodes[t] = set()
1497
- self.time_edges[t] = set()
1498
- for it in root:
1499
- if (
1500
- "-1.#IND" not in it.attrib["m"]
1501
- and "nan" not in it.attrib["m"]
1502
- ):
1503
- M_id, pos, cell_id, svIdx, lin_id = (
1504
- int(it.attrib["parent"]),
1505
- [
1506
- float(v)
1507
- for v in it.attrib["m"].split(" ")
1508
- if v != ""
1509
- ],
1510
- int(it.attrib["id"]),
1511
- [
1512
- int(v)
1513
- for v in it.attrib["svIdx"].split(" ")
1514
- if v != ""
1515
- ],
1516
- int(it.attrib["lineage"]),
1517
- )
1518
- try:
1519
- alpha, W, nu, alphaPrior = (
1520
- float(it.attrib["alpha"]),
1521
- [
1522
- float(v)
1523
- for v in it.attrib["W"].split(" ")
1524
- if v != ""
1525
- ],
1526
- float(it.attrib["nu"]),
1527
- float(it.attrib["alphaPrior"]),
1528
- )
1529
- pos = np.array(pos)
1530
- C = unique_id
1531
- pos[-1] = pos[-1] * z_mult
1532
- if (t - 1, M_id) in self.time_id:
1533
- M = self.time_id[(t - 1, M_id)]
1534
- self.successor.setdefault(M, []).append(C)
1535
- self.predecessor.setdefault(C, []).append(M)
1536
- self.time_edges[t].add((M, C))
1537
- else:
1538
- if M_id != -1:
1539
- self.mother_not_found.append(C)
1540
- self.pos[C] = pos
1541
- self.nodes.add(C)
1542
- self.time_nodes[t].add(C)
1543
- self.time_id[(t, cell_id)] = C
1544
- self.time[C] = t
1545
- self.svIdx[C] = svIdx
1546
- self.lin.setdefault(lin_id, []).append(C)
1547
- self.C_lin[C] = lin_id
1548
- self.intensity[C] = max(alpha - alphaPrior, 0)
1549
- tmp = list(np.array(W) * nu)
1550
- self.W[C] = np.array(W).reshape(3, 3)
1551
- self.coeffs[C] = (
1552
- tmp[:3] + tmp[4:6] + tmp[8:9] + list(pos)
1553
- )
1554
- unique_id += 1
1555
- except Exception:
1556
- pass
1557
- else:
1558
- if t in self.ind_cells:
1559
- self.ind_cells[t] += 1
1560
- else:
1561
- self.ind_cells[t] = 1
1562
- self.max_id = unique_id - 1
1563
-
1564
- def read_from_mastodon(self, path: str, name: str):
1565
- """
1566
- TODO: write doc
1567
- """
1568
- from mastodon_reader import MastodonReader
1569
-
1570
- mr = MastodonReader(path)
1571
- spots, links = mr.read_tables()
1572
-
1573
- self.node_name = {}
1574
-
1575
- for c in spots.iloc:
1576
- unique_id = c.name
1577
- x, y, z = c.x, c.y, c.z
1578
- t = c.t
1579
- n = c[name] if name is not None else ""
1580
- self.time_nodes.setdefault(t, set()).add(unique_id)
1581
- self.nodes.add(unique_id)
1582
- self.time[unique_id] = t
1583
- self.node_name[unique_id] = n
1584
- self.pos[unique_id] = np.array([x, y, z])
1585
-
1586
- for e in links.iloc:
1587
- source = e.source_idx
1588
- target = e.target_idx
1589
- self.predecessor.setdefault(target, []).append(source)
1590
- self.successor.setdefault(source, []).append(target)
1591
- self.time_edges.setdefault(self.time[source], set()).add(
1592
- (source, target)
1593
- )
1594
- self.t_b = min(self.time_nodes.keys())
1595
- self.t_e = max(self.time_nodes.keys())
1596
-
1597
- def read_from_mastodon_csv(self, path: str):
1598
- """
1599
- TODO: Write doc
1600
- """
1601
- spots = []
1602
- links = []
1603
- self.node_name = {}
1604
-
1605
- with open(path[0], encoding="utf-8", errors="ignore") as file:
1606
- csvreader = csv.reader(file)
1607
- for row in csvreader:
1608
- spots.append(row)
1609
- spots = spots[3:]
1610
-
1611
- with open(path[1], encoding="utf-8", errors="ignore") as file:
1612
- csvreader = csv.reader(file)
1613
- for row in csvreader:
1614
- links.append(row)
1615
- links = links[3:]
1616
-
1617
- for spot in spots:
1618
- unique_id = int(spot[1])
1619
- x, y, z = spot[5:8]
1620
- t = int(spot[4])
1621
- self.time_nodes.setdefault(t, set()).add(unique_id)
1622
- self.nodes.add(unique_id)
1623
- self.time[unique_id] = t
1624
- self.node_name[unique_id] = spot[1]
1625
- self.pos[unique_id] = np.array([x, y, z], dtype=float)
1626
-
1627
- for link in links:
1628
- source = int(float(link[4]))
1629
- target = int(float(link[5]))
1630
- self.predecessor.setdefault(target, []).append(source)
1631
- self.successor.setdefault(source, []).append(target)
1632
- self.time_edges.setdefault(self.time[source], set()).add(
1633
- (source, target)
1634
- )
1635
- self.t_b = min(self.time_nodes.keys())
1636
- self.t_e = max(self.time_nodes.keys())
1637
-
1638
- def read_from_mamut_xml(self, path: str):
1639
- """Read a lineage tree from a MaMuT xml.
1640
-
1641
- Args:
1642
- path (str): path to the MaMut xml
1643
- """
1644
- tree = ET.parse(path)
1645
- for elem in tree.getroot():
1646
- if elem.tag == "Model":
1647
- Model = elem
1648
- FeatureDeclarations, AllSpots, AllTracks, FilteredTracks = list(Model)
1649
-
1650
- for attr in self.xml_attributes:
1651
- self.__dict__[attr] = {}
1652
- self.time_nodes = {}
1653
- self.time_edges = {}
1654
- self.nodes = set()
1655
- self.pos = {}
1656
- self.time = {}
1657
- self.node_name = {}
1658
- for frame in AllSpots:
1659
- t = int(frame.attrib["frame"])
1660
- self.time_nodes[t] = set()
1661
- for cell in frame:
1662
- cell_id, n, x, y, z = (
1663
- int(cell.attrib["ID"]),
1664
- cell.attrib["name"],
1665
- float(cell.attrib["POSITION_X"]),
1666
- float(cell.attrib["POSITION_Y"]),
1667
- float(cell.attrib["POSITION_Z"]),
1668
- )
1669
- self.time_nodes[t].add(cell_id)
1670
- self.nodes.add(cell_id)
1671
- self.pos[cell_id] = np.array([x, y, z])
1672
- self.time[cell_id] = t
1673
- self.node_name[cell_id] = n
1674
- if "TISSUE_NAME" in cell.attrib:
1675
- if not hasattr(self, "fate"):
1676
- self.fate = {}
1677
- self.fate[cell_id] = cell.attrib["TISSUE_NAME"]
1678
- if "TISSUE_TYPE" in cell.attrib:
1679
- if not hasattr(self, "fate_nb"):
1680
- self.fate_nb = {}
1681
- self.fate_nb[cell_id] = eval(cell.attrib["TISSUE_TYPE"])
1682
- for attr in cell.attrib:
1683
- if attr in self.xml_attributes:
1684
- self.__dict__[attr][cell_id] = eval(cell.attrib[attr])
1685
-
1686
- tracks = {}
1687
- self.successor = {}
1688
- self.predecessor = {}
1689
- self.track_name = {}
1690
- for track in AllTracks:
1691
- if "TRACK_DURATION" in track.attrib:
1692
- t_id, _ = (
1693
- int(track.attrib["TRACK_ID"]),
1694
- float(track.attrib["TRACK_DURATION"]),
1695
- )
1696
- else:
1697
- t_id = int(track.attrib["TRACK_ID"])
1698
- t_name = track.attrib["name"]
1699
- tracks[t_id] = []
1700
- for edge in track:
1701
- s, t = (
1702
- int(edge.attrib["SPOT_SOURCE_ID"]),
1703
- int(edge.attrib["SPOT_TARGET_ID"]),
1704
- )
1705
- if s in self.nodes and t in self.nodes:
1706
- if self.time[s] > self.time[t]:
1707
- s, t = t, s
1708
- self.successor.setdefault(s, []).append(t)
1709
- self.predecessor.setdefault(t, []).append(s)
1710
- self.track_name[s] = t_name
1711
- self.track_name[t] = t_name
1712
- tracks[t_id].append((s, t))
1713
- self.t_b = min(self.time_nodes.keys())
1714
- self.t_e = max(self.time_nodes.keys())
1715
-
1716
953
  def to_binary(self, fname: str, starting_points: list = None):
1717
954
  """Writes the lineage tree (a forest) as a binary structure
1718
955
  (assuming it is a binary tree, it would not work for *n* ary tree with 2 < *n*).
@@ -2151,12 +1388,12 @@ class lineageTree:
2151
1388
  if not end_time:
2152
1389
  end_time = self.t_e
2153
1390
  branches = [self.get_successors(node)]
2154
- to_do = self[branches[0][-1]].copy()
1391
+ to_do = list(self[branches[0][-1]])
2155
1392
  while to_do:
2156
1393
  current = to_do.pop()
2157
- track = self.get_cycle(current, end_time=end_time)
1394
+ track = self.get_successors(current, end_time=end_time)
2158
1395
  branches += [track]
2159
- to_do.extend(self[track[-1]])
1396
+ to_do += self[track[-1]]
2160
1397
  return branches
2161
1398
 
2162
1399
  def get_all_tracks(self, force_recompute: bool = False) -> list:
@@ -2196,6 +1433,29 @@ class lineageTree:
2196
1433
  to_do.extend(self[track[-1]])
2197
1434
  return tracks
2198
1435
 
1436
def find_leaves(self, roots: Union[int, set, list, tuple]):
    """Finds the leaves of a tree spawned by one or more nodes.

    A leaf is a node that has no successors, i.e. it is either missing
    from `self.successor` or mapped to an empty collection there.

    Args:
        roots (Union[int,set,list,tuple]): The roots of the trees.
            A single node id or any iterable of node ids.

    Returns:
        set: The leaves of one or more trees.
    """
    # A lone node id is wrapped in a list; an iterable of ids is copied so
    # that the caller's container is never mutated by the traversal below.
    to_do = list(roots) if isinstance(roots, Iterable) else [roots]
    leaves = set()
    while to_do:
        curr = to_do.pop()
        succ = self.successor.get(curr)
        if succ:
            # Internal node: keep descending through its successors.
            to_do.extend(succ)
        else:
            # No successors (absent key or empty list): this is a leaf.
            leaves.add(curr)
    return leaves
1458
+
2199
1459
  def get_sub_tree(
2200
1460
  self,
2201
1461
  x: Union[int, Iterable],
@@ -2314,9 +1574,7 @@ class lineageTree:
2314
1574
  list: A list that contains the array of eigenvalues and eigenvectors.
2315
1575
  """
2316
1576
  if time is None:
2317
- time = np.argmax(
2318
- [len(self.time_nodes[t]) for t in range(int(self.t_e))]
2319
- )
1577
+ time = max(self.time_nodes, key=lambda x: len(self.time_nodes[x]))
2320
1578
  pos = np.array([self.pos[node] for node in self.time_nodes[time]])
2321
1579
  pos = pos - np.mean(pos, axis=0)
2322
1580
  cov = np.cov(np.array(pos).T)
@@ -2510,56 +1768,100 @@ class lineageTree:
2510
1768
  tree1.get_norm(), tree2.get_norm()
2511
1769
  )
2512
1770
 
2513
- def to_simple_networkx(
2514
- self, node: Union[int, list, set, tuple] = None, start_time: int = 0
1771
+ def draw_tree_graph(
1772
+ self,
1773
+ hier,
1774
+ lnks_tms,
1775
+ selected_cells=None,
1776
+ color="magenta",
1777
+ ax=None,
1778
+ figure=None,
1779
+ **kwargs,
2515
1780
  ):
2516
- """
2517
- Creates a simple networkx tree graph (every branch is a cell lifetime). This function is to be used for producing nx.graph objects(
2518
- they can be used for visualization or other tasks),
2519
- so only the start and the end of a branch are calculated, all cells in between are not taken into account.
1781
+ if selected_cells is None:
1782
+ selected_cells = []
1783
+ if ax is None:
1784
+ figure, ax = plt.subplots()
1785
+ else:
1786
+ ax.clear()
1787
+ if not isinstance(selected_cells, set):
1788
+ selected_cells = set(selected_cells)
1789
+ hier_unselected = {
1790
+ k: v for k, v in hier.items() if k not in selected_cells
1791
+ }
1792
+ hier_selected = {k: v for k, v in hier.items() if k in selected_cells}
1793
+ unselected = np.array(tuple(hier_unselected.values()))
1794
+ x = []
1795
+ y = []
1796
+ if hier_unselected:
1797
+ for pred, succs in lnks_tms["links"].items():
1798
+ if pred not in selected_cells:
1799
+ for succ in succs:
1800
+ x.extend((hier[succ][0], hier[pred][0], None))
1801
+ y.extend((hier[succ][1], hier[pred][1], None))
1802
+ ax.plot(x, y, c="black", linewidth=0.3, zorder=0.5, **kwargs)
1803
+ ax.scatter(
1804
+ *unselected.T,
1805
+ s=0.1,
1806
+ c="black",
1807
+ zorder=1,
1808
+ **kwargs,
1809
+ )
1810
+ if selected_cells:
1811
+ selected = np.array(tuple(hier_selected.values()))
1812
+ x = []
1813
+ y = []
1814
+ for pred, succs in lnks_tms["links"].items():
1815
+ if pred in selected_cells:
1816
+ for succ in succs:
1817
+ x.extend((hier[succ][0], hier[pred][0], None))
1818
+ y.extend((hier[succ][1], hier[pred][1], None))
1819
+ ax.plot(x, y, c=color, linewidth=0.3, zorder=0.4, **kwargs)
1820
+ ax.scatter(
1821
+ selected.T[0],
1822
+ selected.T[1],
1823
+ s=0.1,
1824
+ c=color,
1825
+ zorder=0.9,
1826
+ **kwargs,
1827
+ )
1828
+ ax.get_yaxis().set_visible(False)
1829
+ ax.get_xaxis().set_visible(False)
1830
+ return figure, ax
1831
+
1832
+ def to_simple_graph(self, node=None, start_time: int = None):
1833
+ """Generates a dictionary of graphs where the keys are the index of the graph and
1834
+ the values are the graphs themselves which are produced by create_links_and _cycles
1835
+
2520
1836
  Args:
2521
- start_time (int): From which timepoints are the graphs to be calculated.
2522
- For example if start_time is 10, then all trees that begin
2523
- on tp 10 or before are calculated.
2524
- returns:
2525
- G : list(nx.Digraph(),...)
2526
- pos : list(dict(id:position))
2527
- """
1837
+ node (_type_, optional): The id of the node/nodes to produce the simple graphs. Defaults to None.
1838
+ start_time (int, optional): Important only if there are no nodes it will produce the graph of every
1839
+ root that starts before or at start time. Defaults to None.
2528
1840
 
1841
+ Returns:
1842
+ (dict): The keys are just index values 0-n and the values are the graphs produced.
1843
+ """
1844
+ if start_time is None:
1845
+ start_time = self.t_b
2529
1846
  if node is None:
2530
1847
  mothers = [
2531
1848
  root for root in self.roots if self.time[root] <= start_time
2532
1849
  ]
2533
1850
  else:
2534
1851
  mothers = node if isinstance(node, (list, set)) else [node]
2535
- graph = {}
2536
- all_nodes = {}
2537
- all_edges = {}
2538
- for mom in mothers:
2539
- edges = set()
2540
- nodes = set()
2541
- for branch in self.get_all_branches_of_node(mom):
2542
- nodes.update((branch[0], branch[-1]))
2543
- if len(branch) > 1:
2544
- edges.add((branch[0], branch[-1]))
2545
- for suc in self[branch[-1]]:
2546
- edges.add((branch[-1], suc))
2547
- all_edges[mom] = edges
2548
- all_nodes[mom] = nodes
2549
- for i, mother in enumerate(mothers):
2550
- graph[i] = nx.DiGraph()
2551
- graph[i].add_nodes_from(all_nodes[mother])
2552
- graph[i].add_edges_from(all_edges[mother])
2553
-
2554
- return graph
1852
+ return {
1853
+ i: create_links_and_cycles(self, mother)
1854
+ for i, mother in enumerate(mothers)
1855
+ }
2555
1856
 
2556
1857
  def plot_all_lineages(
2557
1858
  self,
2558
- starting_point: int = 0,
1859
+ nodes: list = None,
1860
+ last_time_point_to_consider: int = None,
2559
1861
  nrows=2,
2560
1862
  figsize=(10, 15),
2561
- dpi=70,
2562
- fontsize=22,
1863
+ dpi=100,
1864
+ fontsize=15,
2563
1865
  figure=None,
2564
1866
  axes=None,
2565
1867
  **kwargs,
@@ -2567,78 +1869,98 @@ class lineageTree:
2567
1869
  """Plots all lineages.
2568
1870
 
2569
1871
  Args:
2570
- starting_point (int, optional): Which timepoints and upwards are the graphs to be calculated.
2571
- For example if start_time is 10, then all trees that begin
2572
- on tp 10 or before are calculated. Defaults to None.
1872
+ last_time_point_to_consider (int, optional): Latest time point at which a root may start for its tree to be plotted.
1873
+ For example if last_time_point_to_consider is 10, then all trees that begin
1874
+ on tp 10 or before are plotted. Only used when `nodes` is not given. Defaults to None, where
1875
+ it will plot all the roots that exist on self.t_b.
2573
1876
  nrows (int): How many rows of plots should be printed.
2574
- kwargs: args accepted by networkx
1877
+ kwargs: args accepted by matplotlib
2575
1878
  """
2576
1879
 
2577
1880
  nrows = int(nrows)
1881
+ if last_time_point_to_consider is None:
1882
+ last_time_point_to_consider = self.t_b
2578
1883
  if nrows < 1 or not nrows:
2579
1884
  nrows = 1
2580
1885
  raise Warning("Number of rows has to be at least 1")
2581
-
2582
- graphs = self.to_simple_networkx(start_time=starting_point)
1886
+ if nodes:
1887
+ graphs = {
1888
+ i: self.to_simple_graph(node) for i, node in enumerate(nodes)
1889
+ }
1890
+ else:
1891
+ graphs = self.to_simple_graph(
1892
+ start_time=last_time_point_to_consider
1893
+ )
1894
+ pos = {
1895
+ i: hierarchical_pos(
1896
+ g, g["root"], ycenter=-int(self.time[g["root"]])
1897
+ )
1898
+ for i, g in graphs.items()
1899
+ }
2583
1900
  ncols = int(len(graphs) // nrows) + (+np.sign(len(graphs) % nrows))
2584
- pos = postions_of_nx(self, graphs)
2585
1901
  figure, axes = plt.subplots(
2586
1902
  figsize=figsize, nrows=nrows, ncols=ncols, dpi=dpi, sharey=True
2587
1903
  )
2588
1904
  flat_axes = axes.flatten()
2589
1905
  ax2root = {}
2590
- for i, graph in enumerate(graphs.values()):
2591
- nx.draw_networkx(
2592
- graph,
2593
- pos[i],
2594
- with_labels=False,
2595
- arrows=False,
2596
- **kwargs,
2597
- ax=flat_axes[i],
1906
+ min_width, min_height = float("inf"), float("inf")
1907
+ for ax in flat_axes:
1908
+ bbox = ax.get_window_extent().transformed(
1909
+ figure.dpi_scale_trans.inverted()
2598
1910
  )
2599
- root = [n for n, d in graph.in_degree() if d == 0][0]
1911
+ min_width = min(min_width, bbox.width)
1912
+ min_height = min(min_height, bbox.height)
1913
+
1914
+ adjusted_fontsize = fontsize * min(min_width, min_height) / 5
1915
+ for i, graph in graphs.items():
1916
+ self.draw_tree_graph(
1917
+ hier=pos[i], lnks_tms=graph, ax=flat_axes[i], **kwargs
1918
+ )
1919
+ root = graph["root"]
1920
+ ax2root[flat_axes[i]] = root
2600
1921
  label = self.labels.get(root, "Unlabeled")
2601
1922
  xlim = flat_axes[i].get_xlim()
2602
1923
  ylim = flat_axes[i].get_ylim()
2603
- x_pos = (xlim[1]) / 10
2604
- y_pos = ylim[0] + 15
2605
- ax2root[flat_axes[i]] = root
1924
+ x_pos = (xlim[0] + xlim[1]) / 2
1925
+ y_pos = ylim[1] * 0.8
2606
1926
  flat_axes[i].text(
2607
1927
  x_pos,
2608
1928
  y_pos,
2609
1929
  label,
2610
- fontsize=fontsize,
1930
+ fontsize=adjusted_fontsize,
2611
1931
  color="black",
2612
1932
  ha="center",
2613
1933
  va="center",
2614
1934
  bbox={
2615
1935
  "facecolor": "white",
1936
+ "alpha": 0.5,
2616
1937
  "edgecolor": "green",
2617
- "boxstyle": "round",
2618
1938
  },
2619
1939
  )
2620
1940
  [figure.delaxes(ax) for ax in axes.flatten() if not ax.has_data()]
2621
1941
  return figure, axes, ax2root
2622
1942
 
2623
- def plot_node(self, node, figsize=(4, 7), dpi=150, **kwargs):
1943
+ def plot_node(self, node, figsize=(4, 7), dpi=150, vert_gap=2, **kwargs):
2624
1944
  """Plots the subtree spawned by a node.
2625
1945
 
2626
1946
  Args:
2627
1947
  node (int): The id of the node that is going to be plotted.
2628
- kwargs: args accepted by networkx
1948
+ kwargs: args accepted by matplotlib
2629
1949
  """
2630
- graph = self.to_simple_networkx(node)
1950
+ graph = self.to_simple_graph(node)
2631
1951
  if len(graph) > 1:
2632
1952
  raise Warning("Please enter only one node")
2633
- graph = graph[list(graph)[0]]
2634
- figure, ax = plt.subplots(nrows=1, ncols=1)
2635
- nx.draw_networkx(
2636
- graph,
2637
- hierarchy_pos(graph, self, node),
2638
- with_labels=False,
2639
- arrows=False,
1953
+ graph = graph[0]
1954
+ figure, ax = plt.subplots(nrows=1, ncols=1, figsize=figsize, dpi=dpi)
1955
+ self.draw_tree_graph(
1956
+ hier=hierarchical_pos(
1957
+ graph,
1958
+ graph["root"],
1959
+ vert_gap=vert_gap,
1960
+ ycenter=-int(self.time[node]),
1961
+ ),
1962
+ lnks_tms=graph,
2640
1963
  ax=ax,
2641
- **kwargs,
2642
1964
  )
2643
1965
  return figure, ax
2644
1966