LineageTree 1.5.1__py3-none-any.whl → 1.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -2,12 +2,10 @@
2
2
  # This file is subject to the terms and conditions defined in
3
3
  # file 'LICENCE', which is part of this source code package.
4
4
  # Author: Leo Guignard (leo.guignard...@AT@...gmail.com)
5
- import csv
6
5
  import os
7
6
  import pickle as pkl
8
7
  import struct
9
8
  import warnings
10
- import xml.etree.ElementTree as ET
11
9
  from collections.abc import Iterable
12
10
  from functools import partial
13
11
  from itertools import combinations
@@ -15,6 +13,7 @@ from numbers import Number
15
13
  from pathlib import Path
16
14
  from typing import TextIO, Union
17
15
 
16
+ from .loaders import lineageTreeLoaders
18
17
  from .tree_styles import tree_style
19
18
 
20
19
  try:
@@ -24,16 +23,18 @@ except ImportError:
24
23
  "No edist installed therefore you will not be able to compute the tree edit distance."
25
24
  )
26
25
  import matplotlib.pyplot as plt
27
- import networkx as nx
28
26
  import numpy as np
29
27
  from scipy.interpolate import InterpolatedUnivariateSpline
30
28
  from scipy.spatial import Delaunay, distance
31
29
  from scipy.spatial import cKDTree as KDTree
32
30
 
33
- from .utils import hierarchy_pos, postions_of_nx
31
+ from .utils import (
32
+ create_links_and_cycles,
33
+ hierarchical_pos,
34
+ )
34
35
 
35
36
 
36
- class lineageTree:
37
+ class lineageTree(lineageTreeLoaders):
37
38
  def __eq__(self, other):
38
39
  if isinstance(other, lineageTree):
39
40
  return other.successor == self.successor
@@ -381,7 +382,18 @@ class lineageTree:
381
382
  @property
382
383
  def labels(self):
383
384
  if not hasattr(self, "_labels"):
384
- self._labels = {i: "Unlabeled" for i in self.roots}
385
+ if hasattr(self, "cell_name"):
386
+ self._labels = {
387
+ i: self.cell_name.get(i, "Unlabeled") for i in self.roots
388
+ }
389
+ else:
390
+ self._labels = {
391
+ i: "Unlabeled"
392
+ for i in self.roots
393
+ for l in self.find_leaves(i)
394
+ if abs(self.time[l] - self.time[i])
395
+ >= abs(self.t_e - self.t_b) / 4
396
+ }
385
397
  return self._labels
386
398
 
387
399
  def _write_header_am(self, f: TextIO, nb_points: int, length: int):
@@ -732,61 +744,6 @@ class lineageTree:
732
744
  )
733
745
  dwg.save()
734
746
 
735
- def to_treex(
736
- self,
737
- sampling: int = 1,
738
- start: int = 0,
739
- finish: int = 10000,
740
- many: bool = True,
741
- ):
742
- """
743
- TODO: finish the doc
744
- Convert the lineage tree into a treex file.
745
-
746
- start/finish refer to first index in the new array times_to_consider
747
-
748
- """
749
- from warnings import warn
750
-
751
- from treex.tree import Tree
752
-
753
- if finish - start <= 0:
754
- warn("Will return None, because start = finish", stacklevel=2)
755
- return None
756
- id_to_tree = {_id: Tree() for _id in self.nodes}
757
- times_to_consider = sorted(
758
- [t for t, n in self.time_nodes.items() if len(n) > 0]
759
- )
760
- times_to_consider = times_to_consider[start:finish:sampling]
761
- start_time = times_to_consider[0]
762
- for t in times_to_consider:
763
- for id_mother in self.time_nodes[t]:
764
- ids_daughters = self[id_mother]
765
- new_ids_daughters = ids_daughters.copy()
766
- for _ in range(sampling - 1):
767
- tmp = []
768
- for d in new_ids_daughters:
769
- tmp.extend(self.successor.get(d, [d]))
770
- new_ids_daughters = tmp
771
- for (
772
- daugther
773
- ) in (
774
- new_ids_daughters
775
- ): ## For each daughter in the list of daughters
776
- id_to_tree[id_mother].add_subtree(
777
- id_to_tree[daugther]
778
- ) ## Add the Treex daughter as a subtree of the Treex mother
779
- roots = [id_to_tree[_id] for _id in set(self.time_nodes[start_time])]
780
- for root, ids in zip(roots, set(self.time_nodes[start_time])):
781
- root.add_attribute_to_id("ID", ids)
782
- if not many:
783
- reroot = Tree()
784
- for root in roots:
785
- reroot.add_subtree(root)
786
- return reroot
787
- else:
788
- return roots
789
-
790
747
  def to_tlp(
791
748
  self,
792
749
  fname: str,
@@ -993,726 +950,6 @@ class lineageTree:
993
950
  f.write(")")
994
951
  f.close()
995
952
 
996
- def read_from_csv(
997
- self, file_path: str, z_mult: float, link: int = 1, delim: str = ","
998
- ):
999
- """
1000
- TODO: write doc
1001
- """
1002
-
1003
- def convert_for_csv(v):
1004
- if v.isdigit():
1005
- return int(v)
1006
- else:
1007
- return float(v)
1008
-
1009
- with open(file_path) as f:
1010
- lines = f.readlines()
1011
- f.close()
1012
- self.time_nodes = {}
1013
- self.time_edges = {}
1014
- unique_id = 0
1015
- self.nodes = set()
1016
- self.successor = {}
1017
- self.predecessor = {}
1018
- self.pos = {}
1019
- self.time_id = {}
1020
- self.time = {}
1021
- self.lin = {}
1022
- self.C_lin = {}
1023
- if not link:
1024
- self.displacement = {}
1025
- lines_to_int = []
1026
- corres = {}
1027
- for line in lines:
1028
- lines_to_int += [
1029
- [convert_for_csv(v.strip()) for v in line.split(delim)]
1030
- ]
1031
- lines_to_int = np.array(lines_to_int)
1032
- if link == 2:
1033
- lines_to_int = lines_to_int[np.argsort(lines_to_int[:, 0])]
1034
- else:
1035
- lines_to_int = lines_to_int[np.argsort(lines_to_int[:, 1])]
1036
- for line in lines_to_int:
1037
- if link == 1:
1038
- id_, t, z, y, x, pred, lin_id = line
1039
- elif link == 2:
1040
- t, z, y, x, id_, pred, lin_id = line
1041
- else:
1042
- id_, t, z, y, x, dz, dy, dx = line
1043
- pred = None
1044
- lin_id = None
1045
- t = int(t)
1046
- pos = np.array([x, y, z])
1047
- C = unique_id
1048
- corres[id_] = C
1049
- pos[-1] = pos[-1] * z_mult
1050
- if pred in corres:
1051
- M = corres[pred]
1052
- self.predecessor[C] = [M]
1053
- self.successor.setdefault(M, []).append(C)
1054
- self.time_edges.setdefault(t, set()).add((M, C))
1055
- self.lin.setdefault(lin_id, []).append(C)
1056
- self.C_lin[C] = lin_id
1057
- self.pos[C] = pos
1058
- self.nodes.add(C)
1059
- self.time_nodes.setdefault(t, set()).add(C)
1060
- # self.time_id[(t, cell_id)] = C
1061
- self.time[C] = t
1062
- if not link:
1063
- self.displacement[C] = np.array([dx, dy, dz * z_mult])
1064
- unique_id += 1
1065
- self.max_id = unique_id - 1
1066
- self.t_b = min(self.time_nodes)
1067
- self.t_e = max(self.time_nodes)
1068
-
1069
- def read_from_ASTEC(self, file_path: str, eigen: bool = False):
1070
- """
1071
- Read an `xml` or `pkl` file produced by the ASTEC algorithm.
1072
-
1073
- Args:
1074
- file_path (str): path to an output generated by ASTEC
1075
- eigen (bool): whether or not to read the eigen values, default False
1076
- """
1077
- self._astec_keydictionary = {
1078
- "cell_lineage": [
1079
- "lineage_tree",
1080
- "lin_tree",
1081
- "Lineage tree",
1082
- "cell_lineage",
1083
- ],
1084
- "cell_h_min": ["cell_h_min", "h_mins_information"],
1085
- "cell_volume": [
1086
- "cell_volume",
1087
- "volumes_information",
1088
- "volumes information",
1089
- "vol",
1090
- ],
1091
- "cell_surface": ["cell_surface", "cell surface"],
1092
- "cell_compactness": [
1093
- "cell_compactness",
1094
- "Cell Compactness",
1095
- "compacity",
1096
- "cell_sphericity",
1097
- ],
1098
- "cell_sigma": ["cell_sigma", "sigmas_information", "sigmas"],
1099
- "cell_labels_in_time": [
1100
- "cell_labels_in_time",
1101
- "Cells labels in time",
1102
- "time_labels",
1103
- ],
1104
- "cell_barycenter": [
1105
- "cell_barycenter",
1106
- "Barycenters",
1107
- "barycenters",
1108
- ],
1109
- "cell_fate": ["cell_fate", "Fate"],
1110
- "cell_fate_2": ["cell_fate_2", "Fate2"],
1111
- "cell_fate_3": ["cell_fate_3", "Fate3"],
1112
- "cell_fate_4": ["cell_fate_4", "Fate4"],
1113
- "all_cells": [
1114
- "all_cells",
1115
- "All Cells",
1116
- "All_Cells",
1117
- "all cells",
1118
- "tot_cells",
1119
- ],
1120
- "cell_principal_values": [
1121
- "cell_principal_values",
1122
- "Principal values",
1123
- ],
1124
- "cell_name": ["cell_name", "Names", "names", "cell_names"],
1125
- "cell_contact_surface": [
1126
- "cell_contact_surface",
1127
- "cell_cell_contact_information",
1128
- ],
1129
- "cell_history": [
1130
- "cell_history",
1131
- "Cells history",
1132
- "cell_life",
1133
- "life",
1134
- ],
1135
- "cell_principal_vectors": [
1136
- "cell_principal_vectors",
1137
- "Principal vectors",
1138
- ],
1139
- "cell_naming_score": ["cell_naming_score", "Scores", "scores"],
1140
- "problematic_cells": ["problematic_cells"],
1141
- "unknown_key": ["unknown_key"],
1142
- }
1143
-
1144
- if os.path.splitext(file_path)[-1] == ".xml":
1145
- tmp_data = self._read_from_ASTEC_xml(file_path)
1146
- else:
1147
- tmp_data = self._read_from_ASTEC_pkl(file_path, eigen)
1148
-
1149
- # make sure these are all named liked they are in tmp_data (or change dictionary above)
1150
- self.name = {}
1151
- if "cell_volume" in tmp_data:
1152
- self.volume = {}
1153
- if "cell_fate" in tmp_data:
1154
- self.fates = {}
1155
- if "cell_barycenter" in tmp_data:
1156
- self.pos = {}
1157
- self.lT2pkl = {}
1158
- self.pkl2lT = {}
1159
- self.contact = {}
1160
- self.prob_cells = set()
1161
- self.image_label = {}
1162
-
1163
- lt = tmp_data["cell_lineage"]
1164
-
1165
- if "cell_contact_surface" in tmp_data:
1166
- do_surf = True
1167
- surfaces = tmp_data["cell_contact_surface"]
1168
- else:
1169
- do_surf = False
1170
-
1171
- inv = {vi: [c] for c, v in lt.items() for vi in v}
1172
- nodes = set(lt).union(inv)
1173
-
1174
- unique_id = 0
1175
-
1176
- for n in nodes:
1177
- t = n // 10**4
1178
- self.image_label[unique_id] = n % 10**4
1179
- self.lT2pkl[unique_id] = n
1180
- self.pkl2lT[n] = unique_id
1181
- self.time_nodes.setdefault(t, set()).add(unique_id)
1182
- self.nodes.add(unique_id)
1183
- self.time[unique_id] = t
1184
- if "cell_volume" in tmp_data:
1185
- self.volume[unique_id] = tmp_data["cell_volume"].get(n, 0.0)
1186
- if "cell_fate" in tmp_data:
1187
- self.fates[unique_id] = tmp_data["cell_fate"].get(n, "")
1188
- if "cell_barycenter" in tmp_data:
1189
- self.pos[unique_id] = tmp_data["cell_barycenter"].get(
1190
- n, np.zeros(3)
1191
- )
1192
-
1193
- unique_id += 1
1194
- if do_surf:
1195
- for c in nodes:
1196
- if c in surfaces and c in self.pkl2lT:
1197
- self.contact[self.pkl2lT[c]] = {
1198
- self.pkl2lT.get(n, -1): s
1199
- for n, s in surfaces[c].items()
1200
- if n % 10**4 == 1 or n in self.pkl2lT
1201
- }
1202
-
1203
- for n, new_id in self.pkl2lT.items():
1204
- if n in inv:
1205
- self.predecessor[new_id] = [self.pkl2lT[ni] for ni in inv[n]]
1206
- if n in lt:
1207
- self.successor[new_id] = [
1208
- self.pkl2lT[ni] for ni in lt[n] if ni in self.pkl2lT
1209
- ]
1210
-
1211
- for ni in self.successor[new_id]:
1212
- self.time_edges.setdefault(t - 1, set()).add((new_id, ni))
1213
-
1214
- self.t_b = min(self.time_nodes)
1215
- self.t_e = max(self.time_nodes)
1216
- self.max_id = unique_id
1217
-
1218
- # do this in the end of the process, skip lineage tree and whatever is stored already
1219
- discard = {
1220
- "cell_volume",
1221
- "cell_fate",
1222
- "cell_barycenter",
1223
- "cell_contact_surface",
1224
- "cell_lineage",
1225
- "all_cells",
1226
- "cell_history",
1227
- "problematic_cells",
1228
- "cell_labels_in_time",
1229
- }
1230
- self.specific_properties = []
1231
- for prop_name, prop_values in tmp_data.items():
1232
- if not (prop_name in discard or hasattr(self, prop_name)):
1233
- if isinstance(prop_values, dict):
1234
- dictionary = {
1235
- self.pkl2lT.get(k, -1): v
1236
- for k, v in prop_values.items()
1237
- }
1238
- # is it a regular dictionary or a dictionary with dictionaries inside?
1239
- for key, value in dictionary.items():
1240
- if isinstance(value, dict):
1241
- # rename all ids from old to new
1242
- dictionary[key] = {
1243
- self.pkl2lT.get(k, -1): v
1244
- for k, v in value.items()
1245
- }
1246
- self.__dict__[prop_name] = dictionary
1247
- self.specific_properties.append(prop_name)
1248
- # is any of this necessary? Or does it mean it anyways does not contain
1249
- # information about the id and a simple else: is enough?
1250
- elif (
1251
- isinstance(prop_values, (list, set, np.ndarray))
1252
- and prop_name not in []
1253
- ):
1254
- self.__dict__[prop_name] = prop_values
1255
- self.specific_properties.append(prop_name)
1256
-
1257
- # what else could it be?
1258
-
1259
- # add a list of all available properties
1260
-
1261
- def _read_from_ASTEC_xml(self, file_path: str):
1262
- def _set_dictionary_value(root):
1263
- if len(root) == 0:
1264
- if root.text is None:
1265
- return None
1266
- else:
1267
- return eval(root.text)
1268
- else:
1269
- dictionary = {}
1270
- for child in root:
1271
- key = child.tag
1272
- if child.tag == "cell":
1273
- key = int(child.attrib["cell-id"])
1274
- dictionary[key] = _set_dictionary_value(child)
1275
- return dictionary
1276
-
1277
- tree = ET.parse(file_path)
1278
- root = tree.getroot()
1279
- dictionary = {}
1280
-
1281
- for k, _v in self._astec_keydictionary.items():
1282
- if root.tag == k:
1283
- dictionary[str(root.tag)] = _set_dictionary_value(root)
1284
- break
1285
- else:
1286
- for child in root:
1287
- value = _set_dictionary_value(child)
1288
- if value is not None:
1289
- dictionary[str(child.tag)] = value
1290
- return dictionary
1291
-
1292
- def _read_from_ASTEC_pkl(self, file_path: str, eigen: bool = False):
1293
- with open(file_path, "rb") as f:
1294
- tmp_data = pkl.load(f, encoding="latin1")
1295
- f.close()
1296
- new_ref = {}
1297
- for k, v in self._astec_keydictionary.items():
1298
- for key in v:
1299
- new_ref[key] = k
1300
- new_dict = {}
1301
-
1302
- for k, v in tmp_data.items():
1303
- if k in new_ref:
1304
- new_dict[new_ref[k]] = v
1305
- else:
1306
- new_dict[k] = v
1307
- return new_dict
1308
-
1309
- def read_from_txt_for_celegans(self, file: str):
1310
- """
1311
- Read a C. elegans lineage tree
1312
-
1313
- Args:
1314
- file (str): Path to the file to read
1315
- """
1316
- implicit_l_t = {
1317
- "AB": "P0",
1318
- "P1": "P0",
1319
- "EMS": "P1",
1320
- "P2": "P1",
1321
- "MS": "EMS",
1322
- "E": "EMS",
1323
- "C": "P2",
1324
- "P3": "P2",
1325
- "D": "P3",
1326
- "P4": "P3",
1327
- "Z2": "P4",
1328
- "Z3": "P4",
1329
- }
1330
- with open(file) as f:
1331
- raw = f.readlines()[1:]
1332
- f.close()
1333
- self.name = {}
1334
-
1335
- unique_id = 0
1336
- for line in raw:
1337
- t = int(line.split("\t")[0])
1338
- self.name[unique_id] = line.split("\t")[1]
1339
- position = np.array(line.split("\t")[2:5], dtype=float)
1340
- self.time_nodes.setdefault(t, set()).add(unique_id)
1341
- self.nodes.add(unique_id)
1342
- self.pos[unique_id] = position
1343
- self.time[unique_id] = t
1344
- unique_id += 1
1345
-
1346
- self.t_b = min(self.time_nodes)
1347
- self.t_e = max(self.time_nodes)
1348
-
1349
- for t, cells in self.time_nodes.items():
1350
- if t != self.t_b:
1351
- prev_cells = self.time_nodes[t - 1]
1352
- name_to_id = {self.name[c]: c for c in prev_cells}
1353
- for c in cells:
1354
- if self.name[c] in name_to_id:
1355
- p = name_to_id[self.name[c]]
1356
- elif self.name[c][:-1] in name_to_id:
1357
- p = name_to_id[self.name[c][:-1]]
1358
- elif implicit_l_t.get(self.name[c]) in name_to_id:
1359
- p = name_to_id[implicit_l_t.get(self.name[c])]
1360
- else:
1361
- print(
1362
- "error, cell %s has no predecessors" % self.name[c]
1363
- )
1364
- p = None
1365
- self.predecessor.setdefault(c, []).append(p)
1366
- self.successor.setdefault(p, []).append(c)
1367
- self.time_edges.setdefault(t - 1, set()).add((p, c))
1368
- self.max_id = unique_id
1369
-
1370
- def read_from_txt_for_celegans_CAO(
1371
- self,
1372
- file: str,
1373
- reorder: bool = False,
1374
- raw_size: float = None,
1375
- shape: float = None,
1376
- ):
1377
- """
1378
- Read a C. elegans lineage tree from Cao et al.
1379
-
1380
- Args:
1381
- file (str): Path to the file to read
1382
- """
1383
-
1384
- implicit_l_t = {
1385
- "AB": "P0",
1386
- "P1": "P0",
1387
- "EMS": "P1",
1388
- "P2": "P1",
1389
- "MS": "EMS",
1390
- "E": "EMS",
1391
- "C": "P2",
1392
- "P3": "P2",
1393
- "D": "P3",
1394
- "P4": "P3",
1395
- "Z2": "P4",
1396
- "Z3": "P4",
1397
- }
1398
-
1399
- def split_line(line):
1400
- return (
1401
- line.split()[0],
1402
- eval(line.split()[1]),
1403
- eval(line.split()[2]),
1404
- eval(line.split()[3]),
1405
- eval(line.split()[4]),
1406
- )
1407
-
1408
- with open(file) as f:
1409
- raw = f.readlines()[1:]
1410
- f.close()
1411
- self.name = {}
1412
-
1413
- unique_id = 0
1414
- for name, t, z, x, y in map(split_line, raw):
1415
- self.name[unique_id] = name
1416
- position = np.array([x, y, z], dtype=np.float)
1417
- self.time_nodes.setdefault(t, set()).add(unique_id)
1418
- self.nodes.add(unique_id)
1419
- if reorder:
1420
-
1421
- def flip(x):
1422
- return np.array([x[0], x[1], raw_size[2] - x[2]])
1423
-
1424
- def adjust(x):
1425
- return (shape / raw_size * flip(x))[[1, 0, 2]]
1426
-
1427
- self.pos[unique_id] = adjust(position)
1428
- else:
1429
- self.pos[unique_id] = position
1430
- self.time[unique_id] = t
1431
- unique_id += 1
1432
-
1433
- self.t_b = min(self.time_nodes)
1434
- self.t_e = max(self.time_nodes)
1435
-
1436
- for t, cells in self.time_nodes.items():
1437
- if t != self.t_b:
1438
- prev_cells = self.time_nodes[t - 1]
1439
- name_to_id = {self.name[c]: c for c in prev_cells}
1440
- for c in cells:
1441
- if self.name[c] in name_to_id:
1442
- p = name_to_id[self.name[c]]
1443
- elif self.name[c][:-1] in name_to_id:
1444
- p = name_to_id[self.name[c][:-1]]
1445
- elif implicit_l_t.get(self.name[c]) in name_to_id:
1446
- p = name_to_id[implicit_l_t.get(self.name[c])]
1447
- else:
1448
- print(
1449
- "error, cell %s has no predecessors" % self.name[c]
1450
- )
1451
- p = None
1452
- self.predecessor.setdefault(c, []).append(p)
1453
- self.successor.setdefault(p, []).append(c)
1454
- self.time_edges.setdefault(t - 1, set()).add((p, c))
1455
- self.max_id = unique_id
1456
-
1457
- def read_tgmm_xml(
1458
- self, file_format: str, tb: int, te: int, z_mult: float = 1.0
1459
- ):
1460
- """Reads a lineage tree from TGMM xml output.
1461
-
1462
- Args:
1463
- file_format (str): path to the xmls location.
1464
- it should be written as follow:
1465
- path/to/xml/standard_name_t{t:06d}.xml where (as an example)
1466
- {t:06d} means a series of 6 digits representing the time and
1467
- if the time values is smaller that 6 digits, the missing
1468
- digits are filed with 0s
1469
- tb (int): first time point to read
1470
- te (int): last time point to read
1471
- z_mult (float): aspect ratio
1472
- """
1473
- self.time_nodes = {}
1474
- self.time_edges = {}
1475
- unique_id = 0
1476
- self.nodes = set()
1477
- self.successor = {}
1478
- self.predecessor = {}
1479
- self.pos = {}
1480
- self.time_id = {}
1481
- self.time = {}
1482
- self.mother_not_found = []
1483
- self.ind_cells = {}
1484
- self.svIdx = {}
1485
- self.lin = {}
1486
- self.C_lin = {}
1487
- self.coeffs = {}
1488
- self.intensity = {}
1489
- self.W = {}
1490
- for t in range(tb, te + 1):
1491
- print(t, end=" ")
1492
- if t % 10 == 0:
1493
- print()
1494
- tree = ET.parse(file_format.format(t=t))
1495
- root = tree.getroot()
1496
- self.time_nodes[t] = set()
1497
- self.time_edges[t] = set()
1498
- for it in root:
1499
- if (
1500
- "-1.#IND" not in it.attrib["m"]
1501
- and "nan" not in it.attrib["m"]
1502
- ):
1503
- M_id, pos, cell_id, svIdx, lin_id = (
1504
- int(it.attrib["parent"]),
1505
- [
1506
- float(v)
1507
- for v in it.attrib["m"].split(" ")
1508
- if v != ""
1509
- ],
1510
- int(it.attrib["id"]),
1511
- [
1512
- int(v)
1513
- for v in it.attrib["svIdx"].split(" ")
1514
- if v != ""
1515
- ],
1516
- int(it.attrib["lineage"]),
1517
- )
1518
- try:
1519
- alpha, W, nu, alphaPrior = (
1520
- float(it.attrib["alpha"]),
1521
- [
1522
- float(v)
1523
- for v in it.attrib["W"].split(" ")
1524
- if v != ""
1525
- ],
1526
- float(it.attrib["nu"]),
1527
- float(it.attrib["alphaPrior"]),
1528
- )
1529
- pos = np.array(pos)
1530
- C = unique_id
1531
- pos[-1] = pos[-1] * z_mult
1532
- if (t - 1, M_id) in self.time_id:
1533
- M = self.time_id[(t - 1, M_id)]
1534
- self.successor.setdefault(M, []).append(C)
1535
- self.predecessor.setdefault(C, []).append(M)
1536
- self.time_edges[t].add((M, C))
1537
- else:
1538
- if M_id != -1:
1539
- self.mother_not_found.append(C)
1540
- self.pos[C] = pos
1541
- self.nodes.add(C)
1542
- self.time_nodes[t].add(C)
1543
- self.time_id[(t, cell_id)] = C
1544
- self.time[C] = t
1545
- self.svIdx[C] = svIdx
1546
- self.lin.setdefault(lin_id, []).append(C)
1547
- self.C_lin[C] = lin_id
1548
- self.intensity[C] = max(alpha - alphaPrior, 0)
1549
- tmp = list(np.array(W) * nu)
1550
- self.W[C] = np.array(W).reshape(3, 3)
1551
- self.coeffs[C] = (
1552
- tmp[:3] + tmp[4:6] + tmp[8:9] + list(pos)
1553
- )
1554
- unique_id += 1
1555
- except Exception:
1556
- pass
1557
- else:
1558
- if t in self.ind_cells:
1559
- self.ind_cells[t] += 1
1560
- else:
1561
- self.ind_cells[t] = 1
1562
- self.max_id = unique_id - 1
1563
-
1564
- def read_from_mastodon(self, path: str, name: str):
1565
- """
1566
- TODO: write doc
1567
- """
1568
- from mastodon_reader import MastodonReader
1569
-
1570
- mr = MastodonReader(path)
1571
- spots, links = mr.read_tables()
1572
-
1573
- self.node_name = {}
1574
-
1575
- for c in spots.iloc:
1576
- unique_id = c.name
1577
- x, y, z = c.x, c.y, c.z
1578
- t = c.t
1579
- n = c[name] if name is not None else ""
1580
- self.time_nodes.setdefault(t, set()).add(unique_id)
1581
- self.nodes.add(unique_id)
1582
- self.time[unique_id] = t
1583
- self.node_name[unique_id] = n
1584
- self.pos[unique_id] = np.array([x, y, z])
1585
-
1586
- for e in links.iloc:
1587
- source = e.source_idx
1588
- target = e.target_idx
1589
- self.predecessor.setdefault(target, []).append(source)
1590
- self.successor.setdefault(source, []).append(target)
1591
- self.time_edges.setdefault(self.time[source], set()).add(
1592
- (source, target)
1593
- )
1594
- self.t_b = min(self.time_nodes.keys())
1595
- self.t_e = max(self.time_nodes.keys())
1596
-
1597
- def read_from_mastodon_csv(self, path: str):
1598
- """
1599
- TODO: Write doc
1600
- """
1601
- spots = []
1602
- links = []
1603
- self.node_name = {}
1604
-
1605
- with open(path[0], encoding="utf-8", errors="ignore") as file:
1606
- csvreader = csv.reader(file)
1607
- for row in csvreader:
1608
- spots.append(row)
1609
- spots = spots[3:]
1610
-
1611
- with open(path[1], encoding="utf-8", errors="ignore") as file:
1612
- csvreader = csv.reader(file)
1613
- for row in csvreader:
1614
- links.append(row)
1615
- links = links[3:]
1616
-
1617
- for spot in spots:
1618
- unique_id = int(spot[1])
1619
- x, y, z = spot[5:8]
1620
- t = int(spot[4])
1621
- self.time_nodes.setdefault(t, set()).add(unique_id)
1622
- self.nodes.add(unique_id)
1623
- self.time[unique_id] = t
1624
- self.node_name[unique_id] = spot[1]
1625
- self.pos[unique_id] = np.array([x, y, z], dtype=float)
1626
-
1627
- for link in links:
1628
- source = int(float(link[4]))
1629
- target = int(float(link[5]))
1630
- self.predecessor.setdefault(target, []).append(source)
1631
- self.successor.setdefault(source, []).append(target)
1632
- self.time_edges.setdefault(self.time[source], set()).add(
1633
- (source, target)
1634
- )
1635
- self.t_b = min(self.time_nodes.keys())
1636
- self.t_e = max(self.time_nodes.keys())
1637
-
1638
- def read_from_mamut_xml(self, path: str):
1639
- """Read a lineage tree from a MaMuT xml.
1640
-
1641
- Args:
1642
- path (str): path to the MaMut xml
1643
- """
1644
- tree = ET.parse(path)
1645
- for elem in tree.getroot():
1646
- if elem.tag == "Model":
1647
- Model = elem
1648
- FeatureDeclarations, AllSpots, AllTracks, FilteredTracks = list(Model)
1649
-
1650
- for attr in self.xml_attributes:
1651
- self.__dict__[attr] = {}
1652
- self.time_nodes = {}
1653
- self.time_edges = {}
1654
- self.nodes = set()
1655
- self.pos = {}
1656
- self.time = {}
1657
- self.node_name = {}
1658
- for frame in AllSpots:
1659
- t = int(frame.attrib["frame"])
1660
- self.time_nodes[t] = set()
1661
- for cell in frame:
1662
- cell_id, n, x, y, z = (
1663
- int(cell.attrib["ID"]),
1664
- cell.attrib["name"],
1665
- float(cell.attrib["POSITION_X"]),
1666
- float(cell.attrib["POSITION_Y"]),
1667
- float(cell.attrib["POSITION_Z"]),
1668
- )
1669
- self.time_nodes[t].add(cell_id)
1670
- self.nodes.add(cell_id)
1671
- self.pos[cell_id] = np.array([x, y, z])
1672
- self.time[cell_id] = t
1673
- self.node_name[cell_id] = n
1674
- if "TISSUE_NAME" in cell.attrib:
1675
- if not hasattr(self, "fate"):
1676
- self.fate = {}
1677
- self.fate[cell_id] = cell.attrib["TISSUE_NAME"]
1678
- if "TISSUE_TYPE" in cell.attrib:
1679
- if not hasattr(self, "fate_nb"):
1680
- self.fate_nb = {}
1681
- self.fate_nb[cell_id] = eval(cell.attrib["TISSUE_TYPE"])
1682
- for attr in cell.attrib:
1683
- if attr in self.xml_attributes:
1684
- self.__dict__[attr][cell_id] = eval(cell.attrib[attr])
1685
-
1686
- tracks = {}
1687
- self.successor = {}
1688
- self.predecessor = {}
1689
- self.track_name = {}
1690
- for track in AllTracks:
1691
- if "TRACK_DURATION" in track.attrib:
1692
- t_id, _ = (
1693
- int(track.attrib["TRACK_ID"]),
1694
- float(track.attrib["TRACK_DURATION"]),
1695
- )
1696
- else:
1697
- t_id = int(track.attrib["TRACK_ID"])
1698
- t_name = track.attrib["name"]
1699
- tracks[t_id] = []
1700
- for edge in track:
1701
- s, t = (
1702
- int(edge.attrib["SPOT_SOURCE_ID"]),
1703
- int(edge.attrib["SPOT_TARGET_ID"]),
1704
- )
1705
- if s in self.nodes and t in self.nodes:
1706
- if self.time[s] > self.time[t]:
1707
- s, t = t, s
1708
- self.successor.setdefault(s, []).append(t)
1709
- self.predecessor.setdefault(t, []).append(s)
1710
- self.track_name[s] = t_name
1711
- self.track_name[t] = t_name
1712
- tracks[t_id].append((s, t))
1713
- self.t_b = min(self.time_nodes.keys())
1714
- self.t_e = max(self.time_nodes.keys())
1715
-
1716
953
  def to_binary(self, fname: str, starting_points: list = None):
1717
954
  """Writes the lineage tree (a forest) as a binary structure
1718
955
  (assuming it is a binary tree, it would not work for *n* ary tree with 2 < *n*).
@@ -2151,12 +1388,12 @@ class lineageTree:
2151
1388
  if not end_time:
2152
1389
  end_time = self.t_e
2153
1390
  branches = [self.get_successors(node)]
2154
- to_do = self[branches[0][-1]].copy()
1391
+ to_do = list(self[branches[0][-1]])
2155
1392
  while to_do:
2156
1393
  current = to_do.pop()
2157
- track = self.get_cycle(current, end_time=end_time)
1394
+ track = self.get_successors(current, end_time=end_time)
2158
1395
  branches += [track]
2159
- to_do.extend(self[track[-1]])
1396
+ to_do += self[track[-1]]
2160
1397
  return branches
2161
1398
 
2162
1399
  def get_all_tracks(self, force_recompute: bool = False) -> list:
@@ -2196,6 +1433,29 @@ class lineageTree:
2196
1433
  to_do.extend(self[track[-1]])
2197
1434
  return tracks
2198
1435
 
1436
+ def find_leaves(self, roots: Union[int, set, list, tuple]):
1437
+ """Finds the leaves of a tree spawned by one or more nodes.
1438
+
1439
+ Args:
1440
+ roots (Union[int,set,list,tuple]): The roots of the trees.
1441
+
1442
+ Returns:
1443
+ set: The leaves of one or more trees.
1444
+ """
1445
+ if not isinstance(roots, Iterable):
1446
+ to_do = [roots]
1447
+ elif isinstance(roots, Iterable):
1448
+ to_do = list(roots)
1449
+ leaves = set()
1450
+ while to_do:
1451
+ curr = to_do.pop()
1452
+ succ = self.successor.get(curr)
1453
+ if succ is not None:
1454
+ leaves.add(curr)
1455
+ else:
1456
+ to_do += succ
1457
+ return leaves
1458
+
2199
1459
  def get_sub_tree(
2200
1460
  self,
2201
1461
  x: Union[int, Iterable],
@@ -2314,9 +1574,7 @@ class lineageTree:
2314
1574
  list: A list that contains the array of eigenvalues and eigenvectors.
2315
1575
  """
2316
1576
  if time is None:
2317
- time = np.argmax(
2318
- [len(self.time_nodes[t]) for t in range(int(self.t_e))]
2319
- )
1577
+ time = max(self.time_nodes, key=lambda x: len(self.time_nodes[x]))
2320
1578
  pos = np.array([self.pos[node] for node in self.time_nodes[time]])
2321
1579
  pos = pos - np.mean(pos, axis=0)
2322
1580
  cov = np.cov(np.array(pos).T)
@@ -2510,56 +1768,99 @@ class lineageTree:
2510
1768
  tree1.get_norm(), tree2.get_norm()
2511
1769
  )
2512
1770
 
2513
- def to_simple_networkx(
2514
- self, node: Union[int, list, set, tuple] = None, start_time: int = 0
1771
+ def draw_tree_graph(
1772
+ self,
1773
+ hier,
1774
+ lnks_tms,
1775
+ selected_cells=None,
1776
+ color="magenta",
1777
+ ax=None,
1778
+ **kwargs,
2515
1779
  ):
2516
- """
2517
- Creates a simple networkx tree graph (every branch is a cell lifetime). This function is to be used for producing nx.graph objects(
2518
- they can be used for visualization or other tasks),
2519
- so only the start and the end of a branch are calculated, all cells in between are not taken into account.
1780
+ if selected_cells is None:
1781
+ selected_cells = []
1782
+ if ax is None:
1783
+ _, ax = plt.subplots()
1784
+ else:
1785
+ ax.clear()
1786
+
1787
+ selected_cells = set(selected_cells)
1788
+ hier_unselected = {
1789
+ k: v for k, v in hier.items() if k not in selected_cells
1790
+ }
1791
+ hier_selected = {k: v for k, v in hier.items() if k in selected_cells}
1792
+ unselected = np.array(tuple(hier_unselected.values()))
1793
+ x = []
1794
+ y = []
1795
+ if hier_unselected:
1796
+ for pred, succs in lnks_tms["links"].items():
1797
+ if pred not in selected_cells:
1798
+ for succ in succs:
1799
+ x.extend((hier[succ][0], hier[pred][0], None))
1800
+ y.extend((hier[succ][1], hier[pred][1], None))
1801
+ ax.plot(x, y, c="black", linewidth=0.3, zorder=0.5, **kwargs)
1802
+ ax.scatter(
1803
+ *unselected.T,
1804
+ s=0.1,
1805
+ c="black",
1806
+ zorder=1,
1807
+ **kwargs,
1808
+ )
1809
+ if selected_cells:
1810
+ selected = np.array(tuple(hier_selected.values()))
1811
+ x = []
1812
+ y = []
1813
+ for pred, succs in lnks_tms["links"].items():
1814
+ if pred in selected_cells:
1815
+ for succ in succs:
1816
+ x.extend((hier[succ][0], hier[pred][0], None))
1817
+ y.extend((hier[succ][1], hier[pred][1], None))
1818
+ ax.plot(x, y, c=color, linewidth=0.3, zorder=0.4, **kwargs)
1819
+ ax.scatter(
1820
+ selected.T[0],
1821
+ selected.T[1],
1822
+ s=0.1,
1823
+ c=color,
1824
+ zorder=0.9,
1825
+ **kwargs,
1826
+ )
1827
+ ax.get_yaxis().set_visible(False)
1828
+ ax.get_xaxis().set_visible(False)
1829
+ return ax
1830
+
1831
+ def to_simple_graph(self, node=None, start_time: int = None):
1832
+ """Generates a dictionary of graphs where the keys are the index of the graph and
1833
+ the values are the graphs themselves which are produced by create_links_and _cycles
1834
+
2520
1835
  Args:
2521
- start_time (int): From which timepoints are the graphs to be calculated.
2522
- For example if start_time is 10, then all trees that begin
2523
- on tp 10 or before are calculated.
2524
- returns:
2525
- G : list(nx.Digraph(),...)
2526
- pos : list(dict(id:position))
2527
- """
1836
+ node (_type_, optional): The id of the node/nodes to produce the simple graphs. Defaults to None.
1837
+ start_time (int, optional): Important only if there are no nodes it will produce the graph of every
1838
+ root that starts before or at start time. Defaults to None.
2528
1839
 
1840
+ Returns:
1841
+ (dict): The keys are just index values 0-n and the values are the graphs produced.
1842
+ """
1843
+ if start_time is None:
1844
+ start_time = self.t_b
2529
1845
  if node is None:
2530
1846
  mothers = [
2531
1847
  root for root in self.roots if self.time[root] <= start_time
2532
1848
  ]
2533
1849
  else:
2534
1850
  mothers = node if isinstance(node, (list, set)) else [node]
2535
- graph = {}
2536
- all_nodes = {}
2537
- all_edges = {}
2538
- for mom in mothers:
2539
- edges = set()
2540
- nodes = set()
2541
- for branch in self.get_all_branches_of_node(mom):
2542
- nodes.update((branch[0], branch[-1]))
2543
- if len(branch) > 1:
2544
- edges.add((branch[0], branch[-1]))
2545
- for suc in self[branch[-1]]:
2546
- edges.add((branch[-1], suc))
2547
- all_edges[mom] = edges
2548
- all_nodes[mom] = nodes
2549
- for i, mother in enumerate(mothers):
2550
- graph[i] = nx.DiGraph()
2551
- graph[i].add_nodes_from(all_nodes[mother])
2552
- graph[i].add_edges_from(all_edges[mother])
2553
-
2554
- return graph
1851
+ return {
1852
+ i: create_links_and_cycles(self, mother)
1853
+ for i, mother in enumerate(mothers)
1854
+ }
2555
1855
 
2556
1856
  def plot_all_lineages(
2557
1857
  self,
2558
- starting_point: int = 0,
1858
+ nodes: list = None,
1859
+ last_time_point_to_consider: int = None,
2559
1860
  nrows=2,
2560
1861
  figsize=(10, 15),
2561
- dpi=70,
2562
- fontsize=22,
1862
+ dpi=100,
1863
+ fontsize=15,
2563
1864
  figure=None,
2564
1865
  axes=None,
2565
1866
  **kwargs,
@@ -2567,78 +1868,98 @@ class lineageTree:
2567
1868
  """Plots all lineages.
2568
1869
 
2569
1870
  Args:
2570
- starting_point (int, optional): Which timepoints and upwards are the graphs to be calculated.
2571
- For example if start_time is 10, then all trees that begin
2572
- on tp 10 or before are calculated. Defaults to None.
1871
+ last_time_point_to_consider (int, optional): Which timepoints and upwards are the graphs to be plotted.
1872
+ For example if start_time is 10, then all trees that begin
1873
+ on tp 10 or before are calculated. Defaults to None, where
1874
+ it will plot all the roots that exist on self.t_b.
2573
1875
  nrows (int): How many rows of plots should be printed.
2574
- kwargs: args accepted by networkx
1876
+ kwargs: args accepted by matplotlib
2575
1877
  """
2576
1878
 
2577
1879
  nrows = int(nrows)
1880
+ if last_time_point_to_consider is None:
1881
+ last_time_point_to_consider = self.t_b
2578
1882
  if nrows < 1 or not nrows:
2579
1883
  nrows = 1
2580
1884
  raise Warning("Number of rows has to be at least 1")
2581
-
2582
- graphs = self.to_simple_networkx(start_time=starting_point)
1885
+ if nodes:
1886
+ graphs = {
1887
+ i: self.to_simple_graph(node) for i, node in enumerate(nodes)
1888
+ }
1889
+ else:
1890
+ graphs = self.to_simple_graph(
1891
+ start_time=last_time_point_to_consider
1892
+ )
1893
+ pos = {
1894
+ i: hierarchical_pos(
1895
+ g, g["root"], ycenter=-int(self.time[g["root"]])
1896
+ )
1897
+ for i, g in graphs.items()
1898
+ }
2583
1899
  ncols = int(len(graphs) // nrows) + (+np.sign(len(graphs) % nrows))
2584
- pos = postions_of_nx(self, graphs)
2585
1900
  figure, axes = plt.subplots(
2586
1901
  figsize=figsize, nrows=nrows, ncols=ncols, dpi=dpi, sharey=True
2587
1902
  )
2588
1903
  flat_axes = axes.flatten()
2589
1904
  ax2root = {}
2590
- for i, graph in enumerate(graphs.values()):
2591
- nx.draw_networkx(
2592
- graph,
2593
- pos[i],
2594
- with_labels=False,
2595
- arrows=False,
2596
- **kwargs,
2597
- ax=flat_axes[i],
1905
+ min_width, min_height = float("inf"), float("inf")
1906
+ for ax in flat_axes:
1907
+ bbox = ax.get_window_extent().transformed(
1908
+ figure.dpi_scale_trans.inverted()
2598
1909
  )
2599
- root = [n for n, d in graph.in_degree() if d == 0][0]
1910
+ min_width = min(min_width, bbox.width)
1911
+ min_height = min(min_height, bbox.height)
1912
+
1913
+ adjusted_fontsize = fontsize * min(min_width, min_height) / 5
1914
+ for i, graph in graphs.items():
1915
+ self.draw_tree_graph(
1916
+ hier=pos[i], lnks_tms=graph, ax=flat_axes[i], **kwargs
1917
+ )
1918
+ root = graph["root"]
1919
+ ax2root[flat_axes[i]] = root
2600
1920
  label = self.labels.get(root, "Unlabeled")
2601
1921
  xlim = flat_axes[i].get_xlim()
2602
1922
  ylim = flat_axes[i].get_ylim()
2603
- x_pos = (xlim[1]) / 10
2604
- y_pos = ylim[0] + 15
2605
- ax2root[flat_axes[i]] = root
1923
+ x_pos = (xlim[0] + xlim[1]) / 2
1924
+ y_pos = ylim[1] * 0.8
2606
1925
  flat_axes[i].text(
2607
1926
  x_pos,
2608
1927
  y_pos,
2609
1928
  label,
2610
- fontsize=fontsize,
1929
+ fontsize=adjusted_fontsize,
2611
1930
  color="black",
2612
1931
  ha="center",
2613
1932
  va="center",
2614
1933
  bbox={
2615
1934
  "facecolor": "white",
1935
+ "alpha": 0.5,
2616
1936
  "edgecolor": "green",
2617
- "boxstyle": "round",
2618
1937
  },
2619
1938
  )
2620
1939
  [figure.delaxes(ax) for ax in axes.flatten() if not ax.has_data()]
2621
1940
  return figure, axes, ax2root
2622
1941
 
2623
def plot_node(self, node, figsize=(4, 7), dpi=150, vert_gap=2, **kwargs):
    """Plots the subtree spawned by a node.

    Args:
        node (int): The id of the node that is going to be plotted. A
            one-element list or set is also accepted.
        figsize (tuple): Figure size, forwarded to ``plt.subplots``.
        dpi (int): Figure resolution, forwarded to ``plt.subplots``.
        vert_gap (int): Vertical gap forwarded to ``hierarchical_pos``.
        kwargs: args forwarded to ``self.draw_tree_graph`` (presumably
            matplotlib drawing options).

    Returns:
        tuple: ``(figure, ax)`` of the created plot.

    Raises:
        Warning: If zero or more than one node is given.
    """
    graphs = self.to_simple_graph(node)
    # Exactly one graph is expected; an empty input would otherwise fail
    # later with an opaque KeyError on graphs[0].
    if len(graphs) != 1:
        raise Warning("Please enter only one node")
    graph = graphs[0]
    # Use the graph's root rather than `node`, which may be a list.
    root = graph["root"]
    figure, ax = plt.subplots(nrows=1, ncols=1, figsize=figsize, dpi=dpi)
    self.draw_tree_graph(
        hier=hierarchical_pos(
            graph,
            root,
            vert_gap=vert_gap,
            ycenter=-int(self.time[root]),
        ),
        lnks_tms=graph,
        ax=ax,
        # Forward the documented drawing kwargs (previously dropped).
        **kwargs,
    )
    return figure, ax
2644
1965