repare 0.0.2__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of repare might be problematic. Click here for more details.

repare/pedigree.py CHANGED
@@ -4,6 +4,7 @@ from collections import defaultdict, deque
4
4
 
5
5
  import matplotlib.pyplot as plt
6
6
  import networkx as nx
7
+ from matplotlib.colors import to_rgba
7
8
 
8
9
 
9
10
  class Pedigree:
@@ -12,12 +13,12 @@ class Pedigree:
12
13
  """
13
14
 
14
15
  def __init__(self) -> None:
15
- self.num_placeholders = 0
16
+ self.num_placeholders: int = 0
16
17
  self.node_to_data: dict[str, dict[str, str | bool | float]] = dict()
17
- self.node_to_father = defaultdict(str)
18
- self.node_to_mother = defaultdict(str)
19
- self.node_to_children = defaultdict(set)
20
- self.node_to_siblings = defaultdict(set)
18
+ self.node_to_father: dict[str, str] = dict()
19
+ self.node_to_mother: dict[str, str] = dict()
20
+ self.node_to_children: dict[str, set[str]] = dict()
21
+ self.node_to_siblings: dict[str, set[str]] = dict()
21
22
 
22
23
  def __deepcopy__(self, memo: dict) -> "Pedigree":
23
24
  """
@@ -34,10 +35,10 @@ class Pedigree:
34
35
 
35
36
  new_pedigree.node_to_father = self.node_to_father.copy()
36
37
  new_pedigree.node_to_mother = self.node_to_mother.copy()
37
- new_pedigree.node_to_children = defaultdict(set)
38
+ new_pedigree.node_to_children = dict()
38
39
  for k, v in self.node_to_children.items():
39
40
  new_pedigree.node_to_children[k] = v.copy()
40
- new_pedigree.node_to_siblings = defaultdict(set)
41
+ new_pedigree.node_to_siblings = dict()
41
42
  for k, v in self.node_to_siblings.items():
42
43
  new_pedigree.node_to_siblings[k] = v.copy()
43
44
  return new_pedigree
@@ -46,23 +47,22 @@ class Pedigree:
46
47
  """
47
48
  Gets pedigree topological sort of the Pedigree. See https://doi.org/10.1089/cmb.2011.0254.
48
49
  """
49
- leaf_nodes: list[str] = sorted([node for node in self.node_to_data if not self.node_to_children[node]])
50
+ leaf_nodes: list[str] = sorted([node for node in self.node_to_data if not self.get_children(node)])
50
51
  result: list[str] = []
51
52
 
52
- # DFS
53
- def visit(node: str) -> None:
53
+ def dfs(node: str) -> None:
54
54
  # No father == no mother
55
- if not self.node_to_father[node]:
56
- assert not self.node_to_mother[node]
55
+ if not self.get_father(node):
56
+ assert not self.get_mother(node)
57
57
  result.append(node)
58
58
  else:
59
59
  # Visit father first
60
- visit(self.node_to_father[node])
61
- visit(self.node_to_mother[node])
60
+ dfs(self.get_father(node))
61
+ dfs(self.get_mother(node))
62
62
  result.append(node)
63
63
 
64
64
  for node in leaf_nodes:
65
- visit(node)
65
+ dfs(node)
66
66
  # Break between each node's path
67
67
  result.append("")
68
68
 
@@ -72,15 +72,56 @@ class Pedigree:
72
72
  if node.isnumeric():
73
73
  if node not in placeholder_to_idx:
74
74
  placeholder_to_idx[node] = len(placeholder_to_idx)
75
- mt_haplogroup = self.node_to_data[node]["mt_haplogroup"]
76
- if self.node_to_data[node]["sex"] == "M":
77
- y_haplogroup = self.node_to_data[node]["y_haplogroup"]
75
+
76
+ node_data = self.get_data(node)
77
+ mt_haplogroup = node_data["mt_haplogroup"]
78
+ if node_data["sex"] == "M":
79
+ y_haplogroup = node_data["y_haplogroup"]
78
80
  # Unique identifier for placeholder
79
81
  result[i] = f"M {placeholder_to_idx[node]} {mt_haplogroup} {y_haplogroup}"
80
82
  else:
81
83
  result[i] = f"F {placeholder_to_idx[node]} {mt_haplogroup}"
82
84
  return tuple(result)
83
85
 
86
+ def get_data(self, node: str) -> dict[str, str | bool | float]:
87
+ """
88
+ Returns the data for the given node.
89
+ """
90
+ assert node in self.node_to_data
91
+ return self.node_to_data[node]
92
+
93
+ def get_father(self, node: str) -> str:
94
+ """
95
+ Returns the father of the given node.
96
+ If the node has no father, returns "".
97
+ """
98
+ assert node in self.node_to_data
99
+ return self.node_to_father.get(node, "")
100
+
101
+ def get_mother(self, node: str) -> str:
102
+ """
103
+ Returns the mother of the given node.
104
+ If the node has no mother, returns "".
105
+ """
106
+ assert node in self.node_to_data
107
+ return self.node_to_mother.get(node, "")
108
+
109
+ def get_children(self, node: str) -> set[str]:
110
+ """
111
+ Returns the children of the given node.
112
+ If the node has no children, returns an empty set.
113
+ """
114
+ assert node in self.node_to_data
115
+ return self.node_to_children.get(node, set())
116
+
117
+ def get_siblings(self, node: str) -> set[str]:
118
+ """
119
+ Returns the siblings of the given node.
120
+ If the node has no siblings, returns an empty set.
121
+ """
122
+ assert node in self.node_to_data
123
+ return self.node_to_siblings.get(node, set())
124
+
84
125
  def add_node(
85
126
  self,
86
127
  node_id: str,
@@ -94,7 +135,7 @@ class Pedigree:
94
135
  """
95
136
  Add a node to the pedigree. If haplogroup is unknown, set argument to empty string ("").
96
137
  """
97
- self.node_to_data[node_id]: dict[str, str | bool | float] = dict()
138
+ self.node_to_data[node_id] = dict()
98
139
  self.node_to_data[node_id]["sex"] = sex
99
140
  if y_haplogroup and sex == "F":
100
141
  raise ValueError("Only males can have y_haplogroup values.")
@@ -113,20 +154,33 @@ class Pedigree:
113
154
  assert node1 != node2
114
155
  assert node1 in self.node_to_data and node2 in self.node_to_data
115
156
 
116
- if self.node_to_data[node1]["sex"] == "M":
157
+ # Remove child from original parent
158
+ if self.get_data(node1)["sex"] == "M":
159
+ if self.get_father(node2):
160
+ self.node_to_children[self.get_father(node2)].remove(node2)
117
161
  self.node_to_father[node2] = node1
118
162
  else:
163
+ if self.get_mother(node2):
164
+ self.node_to_children[self.get_mother(node2)].remove(node2)
119
165
  self.node_to_mother[node2] = node1
166
+
167
+ for sibling in self.get_siblings(node2):
168
+ self.node_to_siblings[sibling].remove(node2)
169
+ if node2 in self.node_to_siblings:
170
+ del self.node_to_siblings[node2]
171
+
172
+ if node1 not in self.node_to_children:
173
+ self.node_to_children[node1] = set()
120
174
  self.node_to_children[node1].add(node2)
121
175
 
122
176
  # Add any sibling relations that are created
123
- node2_parents = [self.node_to_father[node2], self.node_to_mother[node2]]
177
+ node2_parents = [self.get_father(node2), self.get_mother(node2)]
124
178
  # Make sure Node 2 has 2 known parents
125
179
  if node2_parents[0] and node2_parents[1]:
126
- for node1_child in self.node_to_children[node1]:
180
+ for node1_child in self.get_children(node1):
127
181
  if (
128
182
  node1_child != node2
129
- and [self.node_to_father[node1_child], self.node_to_mother[node1_child]] == node2_parents
183
+ and [self.get_father(node1_child), self.get_mother(node1_child)] == node2_parents
130
184
  ):
131
185
  self.add_sibling_relation(node1_child, node2)
132
186
 
@@ -138,15 +192,19 @@ class Pedigree:
138
192
  assert node1 != node2
139
193
  assert node1 in self.node_to_data and node2 in self.node_to_data
140
194
 
195
+ if node1 not in self.node_to_siblings:
196
+ self.node_to_siblings[node1] = set()
197
+ if node2 not in self.node_to_siblings:
198
+ self.node_to_siblings[node2] = set()
141
199
  self.node_to_siblings[node1].add(node2)
142
200
  self.node_to_siblings[node2].add(node1)
143
201
 
144
202
  # Update siblings of siblings
145
- for node1_sibling in self.node_to_siblings[node1]:
203
+ for node1_sibling in self.get_siblings(node1):
146
204
  if node1_sibling != node2:
147
205
  self.node_to_siblings[node1_sibling].add(node2)
148
206
  self.node_to_siblings[node2].add(node1_sibling)
149
- for node2_sibling in self.node_to_siblings[node2]:
207
+ for node2_sibling in self.get_siblings(node2):
150
208
  if node1 != node2_sibling:
151
209
  self.node_to_siblings[node1].add(node2_sibling)
152
210
  self.node_to_siblings[node2_sibling].add(node1)
@@ -168,37 +226,46 @@ class Pedigree:
168
226
  name_to_discard = node2 if name_to_keep == node1 else node1
169
227
 
170
228
  # Update relations for relatives of the discarded node
171
- name_to_discard_father = self.node_to_father[name_to_discard]
172
- name_to_discard_mother = self.node_to_mother[name_to_discard]
229
+ name_to_discard_father = self.get_father(name_to_discard)
230
+ name_to_discard_mother = self.get_mother(name_to_discard)
173
231
  if name_to_discard_father:
232
+ assert self.get_children(name_to_discard_father)
174
233
  self.node_to_children[name_to_discard_father].remove(name_to_discard)
175
234
  if name_to_discard_mother:
235
+ assert self.get_children(name_to_discard_mother)
176
236
  self.node_to_children[name_to_discard_mother].remove(name_to_discard)
177
237
 
178
- for child in self.node_to_children[name_to_discard]:
238
+ name_to_discard_children = set()
239
+ for child in self.get_children(name_to_discard):
179
240
  # Merging a parent and child - we will see this when there is inbreeding
241
+ # Note: can’t merge a parent and a child if the child has siblings,
242
+ # because then the child becomes the parent of its siblings (this is
243
+ # handled by check_invalid_parent_child_merge)
180
244
  if name_to_keep == child:
181
- if self.node_to_data[name_to_keep]["sex"] == "M":
245
+ if self.get_data(name_to_keep)["sex"] == "M":
182
246
  del self.node_to_father[name_to_keep]
183
247
  else:
184
248
  del self.node_to_mother[name_to_keep]
185
249
  else:
186
- # This also handles having the correct name of the merged parents from last loop iteration
187
- self.add_parent_relation(name_to_keep, child)
250
+ name_to_discard_children.add(child)
251
+
252
+ for child in name_to_discard_children:
253
+ # This step also handles having the correct name of the merged parents from last loop iteration
254
+ self.add_parent_relation(name_to_keep, child)
188
255
 
189
256
  # Remove all occurrences of name_to_discard in its sibling's sibling sets first
190
- # so that self.add_sibling_relation() does not add it back in
191
- for sibling in self.node_to_siblings[name_to_discard]:
257
+ # so that add_sibling_relation does not add it back in.
258
+ for sibling in self.get_siblings(name_to_discard):
192
259
  self.node_to_siblings[sibling].remove(name_to_discard)
193
- for sibling in self.node_to_siblings[name_to_discard]:
260
+ for sibling in self.get_siblings(name_to_discard):
194
261
  if sibling != name_to_keep:
195
262
  self.add_sibling_relation(sibling, name_to_keep)
196
263
 
197
264
  # Recursively merge parent relations of Node 1 and Node 2
198
- father1 = self.node_to_father[name_to_keep]
199
- father2 = self.node_to_father[name_to_discard]
200
- mother1 = self.node_to_mother[name_to_keep]
201
- mother2 = self.node_to_mother[name_to_discard]
265
+ father1 = self.get_father(name_to_keep)
266
+ father2 = self.get_father(name_to_discard)
267
+ mother1 = self.get_mother(name_to_keep)
268
+ mother2 = self.get_mother(name_to_discard)
202
269
  if father1 and father2:
203
270
  pair_queue.append((father1, father2))
204
271
  elif father2:
@@ -225,69 +292,163 @@ class Pedigree:
225
292
  elif node2 == name_to_discard:
226
293
  pair_queue[idx] = (node1, name_to_keep)
227
294
 
228
- del self.node_to_data[name_to_discard]
229
- del self.node_to_father[name_to_discard]
230
- del self.node_to_mother[name_to_discard]
231
- del self.node_to_children[name_to_discard]
232
- del self.node_to_siblings[name_to_discard]
295
+ for data_dict in [
296
+ self.node_to_data,
297
+ self.node_to_father,
298
+ self.node_to_mother,
299
+ self.node_to_children,
300
+ self.node_to_siblings,
301
+ ]:
302
+ if name_to_discard in data_dict:
303
+ del data_dict[name_to_discard]
233
304
 
234
- def check_cycles_if_merged(self, node1: str, node2: str) -> bool:
305
+ def check_valid_merge(self, node1: str, node2: str) -> bool:
235
306
  """
236
- Returns True if merging Node 1 and Node 2 (and their ancestors) would result in a cycle.
307
+ Returns True if merging Node 1 and Node 2 (and their ancestors) is a valid operation.
237
308
  """
309
+ assert node1 in self.node_to_data and node2 in self.node_to_data
238
310
  # Get sets of nodes that would be merged if Node 1 and Node 2 were merged (i.e. ancestors of Node 1 and Node 2)
239
311
  # Note that we get sets and not just pairs because of potential inbreeding,
240
312
  # for example if one node is both a parent and grandparent of another node
241
313
  merge_sets: list[set[str]] = []
242
- merge_queue = deque([(node1, node2)])
314
+ merge_queue: deque[set[str]] = deque([set([node1, node2])])
315
+ included_nodes: set[str] = set()
316
+
243
317
  while merge_queue:
244
- curr_node1, curr_node2 = merge_queue.popleft()
318
+ curr_nodes = merge_queue.popleft()
319
+ # If all nodes are the same, skip
320
+ if len(set(curr_nodes)) == 1:
321
+ continue
322
+
245
323
  # Update merge sets
246
- if curr_node1 != curr_node2:
247
- updated = False
248
- for merge_set in merge_sets:
249
- if curr_node1 in merge_set or curr_node2 in merge_set:
250
- merge_set.update([curr_node1, curr_node2])
251
- updated = True
252
- break
253
- if not updated:
254
- merge_sets.append({curr_node1, curr_node2})
255
- # Add parents to the queue
256
- curr_father1 = self.node_to_father[curr_node1]
257
- curr_father2 = self.node_to_father[curr_node2]
258
- curr_mother1 = self.node_to_mother[curr_node1]
259
- curr_mother2 = self.node_to_mother[curr_node2]
260
- if curr_father1 and curr_father2:
261
- merge_queue.append((curr_father1, curr_father2))
262
- if curr_mother1 and curr_mother2:
263
- merge_queue.append((curr_mother1, curr_mother2))
324
+ updated = False
325
+ for merge_set in merge_sets:
326
+ if merge_set == curr_nodes:
327
+ updated = True
328
+ break
329
+
330
+ if any(node in merge_set for node in curr_nodes):
331
+ merge_set.update(curr_nodes)
332
+ # Include all merged nodes in the current set of nodes
333
+ curr_nodes = merge_set
334
+ updated = True
335
+ break
336
+ if not updated:
337
+ merge_sets.append(set(curr_nodes))
338
+ included_nodes.update(curr_nodes)
339
+
340
+ # Add parents to the queue
341
+ curr_fathers = set([self.get_father(node) for node in curr_nodes if self.get_father(node)])
342
+ curr_mothers = set([self.get_mother(node) for node in curr_nodes if self.get_mother(node)])
343
+ if len(curr_fathers) > 1 and not curr_fathers.issubset(included_nodes):
344
+ merge_queue.append(curr_fathers)
345
+ if len(curr_mothers) > 1 and not curr_mothers.issubset(included_nodes):
346
+ merge_queue.append(curr_mothers)
347
+
348
+ if self.check_named_node_merge(merge_sets):
349
+ return False
350
+
351
+ if self.check_invalid_parent_child_merge(merge_sets):
352
+ return False
353
+
354
+ if self.check_invalid_sibling_child_merge(merge_sets):
355
+ return False
356
+
357
+ if self.check_cycle_merge(merge_sets):
358
+ return False
359
+ return True
360
+
361
+ def check_named_node_merge(self, merge_sets: list[set[str]]) -> bool:
362
+ """
363
+ Returns True if merging Node 1 and Node 2 (and their ancestors) would result in
364
+ at least one named node being lost.
365
+ """
366
+ for nodes_to_merge in merge_sets:
367
+ named_nodes = [node for node in nodes_to_merge if not node.isnumeric()]
368
+ if len(named_nodes) > 1:
369
+ return True
370
+ return False
371
+
372
+ def check_invalid_parent_child_merge(self, merge_sets: list[set[str]]) -> bool:
373
+ """
374
+ Returns True if merging Node 1 and Node 2 (and their ancestors) would result in an invalid parent-child merge.
375
+ We can't merge a parent and a child where the child has siblings, because then the child becomes the parent
376
+ of its siblings which makes the pedigree internally inconsistent. merge_sets is a list of node sets that would
377
+ be merged if Node 1 and Node 2 were merged.
378
+ """
379
+ for nodes_to_merge in merge_sets:
380
+ for curr_node1 in nodes_to_merge:
381
+ for curr_node2 in nodes_to_merge:
382
+ if curr_node1 == curr_node2:
383
+ continue
384
+
385
+ # If Node 1 is a parent of Node 2, and Node 2 has siblings, don't merge
386
+ if self.get_father(curr_node2) == curr_node1 or self.get_mother(curr_node2) == curr_node1:
387
+ if self.get_siblings(curr_node2):
388
+ return True
389
+ # If Node 2 is a parent of Node 1, and Node 1 has siblings, don't merge
390
+ if self.get_father(curr_node1) == curr_node2 or self.get_mother(curr_node1) == curr_node2:
391
+ if self.get_siblings(curr_node1):
392
+ return True
393
+ return False
394
+
395
+ def check_invalid_sibling_child_merge(self, merge_sets: list[set[str]]) -> bool:
396
+ """
397
+ Returns True if merging Node 1 and Node 2 (and their ancestors) would result in an invalid sibling-child merge.
398
+ We can't merge a sibling and a child of a common node, because then the common node becomes parent and sibling
399
+ of the merged node which makes the pedigree internally inconsistent. merge_sets is a list of node sets that
400
+ would be merged if Node 1 and Node 2 were merged.
401
+ """
402
+ for nodes_to_merge in merge_sets:
403
+ for curr_node1 in nodes_to_merge:
404
+ for curr_node2 in nodes_to_merge:
405
+ if curr_node1 == curr_node2:
406
+ continue
407
+
408
+ # A third node is both a parent of Node 1 and a sibling of Node 2
409
+ if self.get_father(curr_node1) in self.get_siblings(curr_node2) or self.get_mother(
410
+ curr_node1
411
+ ) in self.get_siblings(curr_node2):
412
+ return True
413
+ # A third node is both a parent of Node 2 and a sibling of Node 1
414
+ if self.get_father(curr_node2) in self.get_siblings(curr_node1) or self.get_mother(
415
+ curr_node2
416
+ ) in self.get_siblings(curr_node1):
417
+ return True
418
+ return False
419
+
420
+ def check_cycle_merge(self, merge_sets: list[set[str]]) -> bool:
421
+ """
422
+ Returns True if merging Node 1 and Node 2 (and their ancestors) would result in a cycle.
423
+ merge_sets is a list of node sets that would be merged if Node 1 and Node 2 were merged.
424
+ """
264
425
 
265
426
  # DFS cycle detection
266
- def dfs(node):
267
- merged_nodes: set[str] | None = None
427
+ def dfs(node) -> bool:
428
+ nodes_to_merge: set[str] | None = None
268
429
  for merge_set in merge_sets:
269
430
  if node in merge_set:
270
- merged_nodes = merge_set
431
+ nodes_to_merge = merge_set
271
432
  break
272
433
 
273
- if merged_nodes:
434
+ if nodes_to_merge:
274
435
  if node in in_progress:
275
436
  return True
276
437
  if node in explored:
277
438
  return False
278
- in_progress.update(merged_nodes)
279
- for child in [child for node in merged_nodes for child in self.node_to_children[node]]:
439
+ in_progress.update(nodes_to_merge)
440
+ for child in [child for node in nodes_to_merge for child in self.get_children(node)]:
280
441
  if dfs(child):
281
442
  return True
282
- in_progress.difference_update(merged_nodes)
283
- explored.update(merged_nodes)
443
+ in_progress.difference_update(nodes_to_merge)
444
+ explored.update(nodes_to_merge)
284
445
  else:
285
446
  if node in in_progress:
286
447
  return True
287
448
  if node in explored:
288
449
  return False
289
450
  in_progress.add(node)
290
- for child in self.node_to_children[node]:
451
+ for child in self.get_children(node):
291
452
  if dfs(child):
292
453
  return True
293
454
  in_progress.remove(node)
@@ -310,8 +471,8 @@ class Pedigree:
310
471
  """
311
472
  assert node in self.node_to_data
312
473
 
313
- father = self.node_to_father[node]
314
- mother = self.node_to_mother[node]
474
+ father = self.get_father(node)
475
+ mother = self.get_mother(node)
315
476
 
316
477
  if not father:
317
478
  father_id = str(self.num_placeholders)
@@ -326,7 +487,7 @@ class Pedigree:
326
487
  )
327
488
 
328
489
  self.add_parent_relation(father_id, node)
329
- for sibling in self.node_to_siblings[node]:
490
+ for sibling in self.get_siblings(node):
330
491
  self.add_parent_relation(father_id, sibling)
331
492
  self.num_placeholders += 1
332
493
 
@@ -343,7 +504,7 @@ class Pedigree:
343
504
  )
344
505
 
345
506
  self.add_parent_relation(mother_id, node)
346
- for sibling in self.node_to_siblings[node]:
507
+ for sibling in self.get_siblings(node):
347
508
  self.add_parent_relation(mother_id, sibling)
348
509
  self.num_placeholders += 1
349
510
 
@@ -352,56 +513,96 @@ class Pedigree:
352
513
  Update haplogroups of placeholder nodes.
353
514
  """
354
515
  for node in self.node_to_data:
355
- y_haplogroup: str = self.node_to_data[node]["y_haplogroup"]
516
+ y_haplogroup: str = self.get_data(node)["y_haplogroup"]
356
517
  y_lineage: deque[str] = deque(
357
- [self.node_to_father[node]]
358
- + [child for child in self.node_to_children[node] if self.node_to_data[child]["sex"] == "M"]
518
+ [self.get_father(node)]
519
+ + [child for child in self.get_children(node) if self.get_data(child)["sex"] == "M"]
359
520
  )
360
521
 
361
522
  while y_lineage:
362
523
  curr_node = y_lineage.popleft()
363
524
  if (
364
525
  not curr_node
365
- or "*" not in self.node_to_data[curr_node]["y_haplogroup"]
366
- or self.node_to_data[curr_node]["y_haplogroup"].rstrip("*") == y_haplogroup.rstrip("*")
526
+ or "*" not in self.get_data(curr_node)["y_haplogroup"]
527
+ or self.get_data(curr_node)["y_haplogroup"].rstrip("*") == y_haplogroup.rstrip("*")
367
528
  ):
368
529
  continue
369
530
  # Overwrite/extend Y haplogroup if it contains a "*" and is a strict subset of the "leaf" haplogroup
370
- if y_haplogroup.startswith(self.node_to_data[curr_node]["y_haplogroup"].rstrip("*")):
531
+ if y_haplogroup.startswith(self.get_data(curr_node)["y_haplogroup"].rstrip("*")):
371
532
  self.node_to_data[curr_node]["y_haplogroup"] = (
372
533
  y_haplogroup if y_haplogroup[-1] == "*" else y_haplogroup + "*"
373
534
  )
374
- y_lineage.append(self.node_to_father[curr_node])
375
- for curr_node_child in self.node_to_children[curr_node]:
535
+ y_lineage.append(self.get_father(curr_node))
536
+ for curr_node_child in self.get_children(curr_node):
376
537
  # Only males have Y chromosome
377
- if self.node_to_data[curr_node_child]["sex"] == "M":
538
+ if self.get_data(curr_node_child)["sex"] == "M":
378
539
  y_lineage.append(curr_node_child)
379
540
 
380
- mt_haplogroup: str = self.node_to_data[node]["mt_haplogroup"]
381
- mt_lineage: deque[str] = deque([self.node_to_mother[node]])
541
+ mt_haplogroup: str = self.get_data(node)["mt_haplogroup"]
542
+ mt_lineage: deque[str] = deque([self.get_mother(node)])
382
543
  # Only females pass on mitochondrial DNA to children
383
- if self.node_to_data[node]["sex"] == "F":
384
- mt_lineage.extend(self.node_to_children[node])
544
+ if self.get_data(node)["sex"] == "F":
545
+ mt_lineage.extend(self.get_children(node))
385
546
 
386
547
  while mt_lineage:
387
548
  curr_node = mt_lineage.popleft()
388
549
  if (
389
550
  not curr_node
390
- or "*" not in self.node_to_data[curr_node]["mt_haplogroup"]
391
- or self.node_to_data[curr_node]["mt_haplogroup"].rstrip("*") == mt_haplogroup.rstrip("*")
551
+ or "*" not in self.get_data(curr_node)["mt_haplogroup"]
552
+ or self.get_data(curr_node)["mt_haplogroup"].rstrip("*") == mt_haplogroup.rstrip("*")
392
553
  ):
393
554
  continue
394
555
  # Overwrite/extend mitochondrial haplogroup if it contains a "*"
395
556
  # and is a strict subset of the "leaf" haplogroup
396
- if mt_haplogroup.startswith(self.node_to_data[curr_node]["mt_haplogroup"].rstrip("*")):
557
+ if mt_haplogroup.startswith(self.get_data(curr_node)["mt_haplogroup"].rstrip("*")):
397
558
  self.node_to_data[curr_node]["mt_haplogroup"] = (
398
559
  mt_haplogroup if mt_haplogroup[-1] == "*" else mt_haplogroup + "*"
399
560
  )
400
- mt_lineage.append(self.node_to_mother[curr_node])
401
- if self.node_to_data[curr_node]["sex"] == "F":
402
- mt_lineage.extend(self.node_to_children[curr_node])
561
+ mt_lineage.append(self.get_mother(curr_node))
562
+ if self.get_data(curr_node)["sex"] == "F":
563
+ mt_lineage.extend(self.get_children(curr_node))
403
564
 
404
- def validate_members(self, members: set) -> bool:
565
+ def validate_structure(self) -> bool:
566
+ """
567
+ Validates pedigree structure and consistency of internal data.
568
+ """
569
+ for child, father in self.node_to_father.items():
570
+ if child not in self.node_to_children[father]:
571
+ return False
572
+ if child == father:
573
+ return False
574
+
575
+ for child, mother in self.node_to_mother.items():
576
+ if child not in self.node_to_children[mother]:
577
+ return False
578
+ if child == mother:
579
+ return False
580
+
581
+ for parent, children in self.node_to_children.items():
582
+ for child in children:
583
+ if self.get_data(parent)["sex"] == "M":
584
+ if parent != self.node_to_father[child]:
585
+ return False
586
+ else:
587
+ if parent != self.node_to_mother[child]:
588
+ return False
589
+ if parent == child:
590
+ return False
591
+
592
+ for node, siblings in self.node_to_siblings.items():
593
+ for sibling in siblings:
594
+ if node not in self.node_to_siblings[sibling]:
595
+ return False
596
+ if (
597
+ self.node_to_father[node] != self.node_to_father[sibling]
598
+ or self.node_to_mother[node] != self.node_to_mother[sibling]
599
+ ):
600
+ return False
601
+ if node == sibling:
602
+ return False
603
+ return True
604
+
605
+ def validate_members(self, members: set[str]) -> bool:
405
606
  """
406
607
  Validates this tree based on the member nodes it should contain.
407
608
  """
@@ -427,15 +628,11 @@ class Pedigree:
427
628
  return haplogroup1.startswith(haplogroup2.rstrip("*"))
428
629
 
429
630
  for parent, child in self.get_parent_child_pairs():
430
- if self.node_to_data[parent]["sex"] == "F":
431
- if not haplogroups_agree(
432
- self.node_to_data[parent]["mt_haplogroup"], self.node_to_data[child]["mt_haplogroup"]
433
- ):
631
+ if self.get_data(parent)["sex"] == "F":
632
+ if not haplogroups_agree(self.get_data(parent)["mt_haplogroup"], self.get_data(child)["mt_haplogroup"]):
434
633
  return False
435
- elif self.node_to_data[parent]["sex"] == "M" and self.node_to_data[child]["sex"] == "M":
436
- if not haplogroups_agree(
437
- self.node_to_data[parent]["y_haplogroup"], self.node_to_data[child]["y_haplogroup"]
438
- ):
634
+ elif self.get_data(parent)["sex"] == "M" and self.get_data(child)["sex"] == "M":
635
+ if not haplogroups_agree(self.get_data(parent)["y_haplogroup"], self.get_data(child)["y_haplogroup"]):
439
636
  return False
440
637
  return True
441
638
 
@@ -444,7 +641,7 @@ class Pedigree:
444
641
  Validates that nodes that cannot have children do not have children.
445
642
  """
446
643
  for node in self.get_non_placeholder_nodes():
447
- if len(self.node_to_children[node]) > 0 and not self.node_to_data[node]["can_have_children"]:
644
+ if len(self.get_children(node)) > 0 and not self.get_data(node)["can_have_children"]:
448
645
  return False
449
646
  return True
450
647
 
@@ -454,9 +651,9 @@ class Pedigree:
454
651
  """
455
652
  related_pairs = self.get_related_pairs()
456
653
  for node in self.get_non_placeholder_nodes():
457
- if not self.node_to_data[node]["can_be_inbred"]:
458
- father = self.node_to_father[node]
459
- mother = self.node_to_mother[node]
654
+ if not self.get_data(node)["can_be_inbred"]:
655
+ father = self.get_father(node)
656
+ mother = self.get_mother(node)
460
657
  if (father, mother) in related_pairs or (mother, father) in related_pairs:
461
658
  return False
462
659
  return True
@@ -465,11 +662,10 @@ class Pedigree:
465
662
  """
466
663
  Validates that nodes do not postdate their descendants.
467
664
  """
468
- leaf_nodes: list[str] = [node for node in self.node_to_data if not self.node_to_children[node]]
665
+ leaf_nodes: list[str] = [node for node in self.node_to_data if not self.get_children(node)]
469
666
 
470
- # DFS
471
- def visit(node: str, curr_years_before_present: float) -> None:
472
- years_before_present = self.node_to_data[node]["years_before_present"]
667
+ def dfs(node: str, curr_years_before_present: float) -> bool:
668
+ years_before_present = self.get_data(node)["years_before_present"]
473
669
  if not math.isnan(years_before_present):
474
670
  # Node postdates its descendants
475
671
  if years_before_present < curr_years_before_present:
@@ -477,16 +673,16 @@ class Pedigree:
477
673
  else:
478
674
  curr_years_before_present = years_before_present
479
675
 
480
- if self.node_to_father[node]:
481
- assert self.node_to_mother[node]
482
- if not visit(self.node_to_father[node], curr_years_before_present):
676
+ if self.get_father(node):
677
+ assert self.get_mother(node)
678
+ if not dfs(self.get_father(node), curr_years_before_present):
483
679
  return False
484
- if not visit(self.node_to_mother[node], curr_years_before_present):
680
+ if not dfs(self.get_mother(node), curr_years_before_present):
485
681
  return False
486
682
  return True
487
683
 
488
684
  for node in leaf_nodes:
489
- if not visit(node, float("-inf")):
685
+ if not dfs(node, float("-inf")):
490
686
  return False
491
687
  return True
492
688
 
@@ -508,7 +704,7 @@ class Pedigree:
508
704
  pair_to_constraints: defaultdict[tuple[str, str], list[tuple[str, ...]]],
509
705
  pair_to_relations_so_far: defaultdict[tuple[str, str], list[tuple[str, str, bool]]],
510
706
  check_half_siblings: bool,
511
- ) -> tuple[int, list[tuple[str, str, str]]]:
707
+ ) -> tuple[int, list[tuple[str, str, str, str]]]:
512
708
  """
513
709
  Validates this tree based on the input relation data.
514
710
  If check_half_siblings is False, don't check for extraneous half-sibling relations
@@ -537,7 +733,9 @@ class Pedigree:
537
733
  pair_to_constraints_seen_entries[(node1, node2)].add(idx)
538
734
  break
539
735
 
540
- def validate_relation(node1: str, node2: str, relation: str, strike_log: list[str, str, str, str]) -> None:
736
+ def validate_relation(
737
+ node1: str, node2: str, relation: str, strike_log: list[tuple[str, str, str, str]]
738
+ ) -> None:
541
739
  relation_to_degree = {
542
740
  "parent-child": "1",
543
741
  "child-parent": "1",
@@ -575,7 +773,7 @@ class Pedigree:
575
773
  remove_relation_from_input_data(node1, node2, relation)
576
774
  remove_relation_from_input_data(node2, node1, flipped_relations[relation])
577
775
 
578
- strike_log: list[str, str, str, str] = [] # (node1, node2, +/- relation degree, constraints)
776
+ strike_log: list[tuple[str, str, str, str]] = [] # (node1, node2, +/- relation degree, constraints)
579
777
  # Check that relations in the pedigree are present in the input data
580
778
  for parent, child in self.get_parent_child_pairs(include_placeholders=False):
581
779
  validate_relation(parent, child, "parent-child", strike_log)
@@ -622,7 +820,7 @@ class Pedigree:
622
820
  if not self.is_relation_in_pedigree(node1, node2, constraints.split(";")):
623
821
  strike_log.append((node1, node2, f"-{degree}", constraints))
624
822
  else:
625
- pedigree_shared_relations: defaultdict(int) = self.get_relations_between_nodes(
823
+ pedigree_shared_relations: defaultdict[str, int] = self.get_relations_between_nodes(
626
824
  node1, node2, include_maternal_paternal=True
627
825
  )
628
826
  for degree, constraints, _ in degrees_constraints:
@@ -654,7 +852,7 @@ class Pedigree:
654
852
  node_pair_strike_balances[(node2, node1)] -= 1
655
853
  return strike_count, strike_log
656
854
 
657
- def count_third_degree_inconcistencies(
855
+ def count_third_degree_inconsistencies(
658
856
  self, pair_to_constraints: defaultdict[tuple[str, str], list[tuple[str, ...]]]
659
857
  ) -> int:
660
858
  """
@@ -727,92 +925,88 @@ class Pedigree:
727
925
 
728
926
  for relation in relations_list:
729
927
  if relation == "parent-child":
730
- if node2 in self.node_to_children[node1]:
928
+ if node2 in self.get_children(node1):
731
929
  return True
732
930
  if relation == "child-parent":
733
- if node1 in self.node_to_children[node2]:
931
+ if node1 in self.get_children(node2):
734
932
  return True
735
933
  if relation == "siblings":
736
- if node2 in self.node_to_siblings[node1]:
737
- assert node1 in self.node_to_siblings[node2]
934
+ if node2 in self.get_siblings(node1):
935
+ assert node1 in self.get_siblings(node2)
738
936
  return True
739
937
 
740
938
  if relation == "aunt/uncle-nephew/niece":
741
- for sibling in self.node_to_siblings[node1]:
742
- if node2 in self.node_to_children[sibling]:
939
+ for sibling in self.get_siblings(node1):
940
+ if node2 in self.get_children(sibling):
743
941
  return True
744
942
  if relation == "nephew/niece-aunt/uncle":
745
- for sibling in self.node_to_siblings[node2]:
746
- if node1 in self.node_to_children[sibling]:
943
+ for sibling in self.get_siblings(node2):
944
+ if node1 in self.get_children(sibling):
747
945
  return True
748
946
  if relation == "grandparent-grandchild":
749
- for child in self.node_to_children[node1]:
750
- if node2 in self.node_to_children[child]:
947
+ for child in self.get_children(node1):
948
+ if node2 in self.get_children(child):
751
949
  return True
752
950
  if relation == "grandchild-grandparent":
753
- for child in self.node_to_children[node2]:
754
- if node1 in self.node_to_children[child]:
951
+ for child in self.get_children(node2):
952
+ if node1 in self.get_children(child):
755
953
  return True
756
954
  if relation == "half-siblings":
757
- if self.node_to_father[node2]:
758
- if (
759
- node1 in self.node_to_children[self.node_to_father[node2]]
760
- and self.node_to_mother[node1] != self.node_to_mother[node2]
955
+ if self.get_father(node2):
956
+ if node1 in self.get_children(self.get_father(node2)) and self.get_mother(node1) != self.get_mother(
957
+ node2
761
958
  ):
762
959
  return True
763
- if self.node_to_mother[node2]:
764
- if (
765
- node1 in self.node_to_children[self.node_to_mother[node2]]
766
- and self.node_to_father[node1] != self.node_to_father[node2]
960
+ if self.get_mother(node2):
961
+ if node1 in self.get_children(self.get_mother(node2)) and self.get_father(node1) != self.get_father(
962
+ node2
767
963
  ):
768
964
  return True
769
965
 
770
966
  if relation == "maternal aunt/uncle-nephew/niece":
771
- for sibling in self.node_to_siblings[node1]:
772
- if self.node_to_data[sibling]["sex"] == "F" and node2 in self.node_to_children[sibling]:
967
+ for sibling in self.get_siblings(node1):
968
+ if self.get_data(sibling)["sex"] == "F" and node2 in self.get_children(sibling):
773
969
  return True
774
970
  if relation == "paternal aunt/uncle-nephew/niece":
775
- for sibling in self.node_to_siblings[node1]:
776
- if self.node_to_data[sibling]["sex"] == "M" and node2 in self.node_to_children[sibling]:
971
+ for sibling in self.get_siblings(node1):
972
+ if self.get_data(sibling)["sex"] == "M" and node2 in self.get_children(sibling):
777
973
  return True
778
974
  if relation == "maternal nephew/niece-aunt/uncle":
779
- for sibling in self.node_to_siblings[node2]:
780
- if self.node_to_data[sibling]["sex"] == "F" and node1 in self.node_to_children[sibling]:
975
+ for sibling in self.get_siblings(node2):
976
+ if self.get_data(sibling)["sex"] == "F" and node1 in self.get_children(sibling):
781
977
  return True
782
978
  if relation == "paternal nephew/niece-aunt/uncle":
783
- for sibling in self.node_to_siblings[node2]:
784
- if self.node_to_data[sibling]["sex"] == "M" and node1 in self.node_to_children[sibling]:
979
+ for sibling in self.get_siblings(node2):
980
+ if self.get_data(sibling)["sex"] == "M" and node1 in self.get_children(sibling):
785
981
  return True
786
982
 
787
983
  if relation == "maternal grandparent-grandchild":
788
- for child in self.node_to_children[node1]:
789
- if self.node_to_data[child]["sex"] == "F" and node2 in self.node_to_children[child]:
984
+ for child in self.get_children(node1):
985
+ if self.get_data(child)["sex"] == "F" and node2 in self.get_children(child):
790
986
  return True
791
987
  if relation == "paternal grandparent-grandchild":
792
- for child in self.node_to_children[node1]:
793
- if self.node_to_data[child]["sex"] == "M" and node2 in self.node_to_children[child]:
988
+ for child in self.get_children(node1):
989
+ if self.get_data(child)["sex"] == "M" and node2 in self.get_children(child):
794
990
  return True
795
991
  if relation == "maternal grandchild-grandparent":
796
- for child in self.node_to_children[node2]:
797
- if self.node_to_data[child]["sex"] == "F" and node1 in self.node_to_children[child]:
992
+ for child in self.get_children(node2):
993
+ if self.get_data(child)["sex"] == "F" and node1 in self.get_children(child):
798
994
  return True
799
995
  if relation == "paternal grandchild-grandparent":
800
- for child in self.node_to_children[node2]:
801
- if self.node_to_data[child]["sex"] == "M" and node1 in self.node_to_children[child]:
996
+ for child in self.get_children(node2):
997
+ if self.get_data(child)["sex"] == "M" and node1 in self.get_children(child):
802
998
  return True
803
999
 
804
1000
  if relation == "paternal half-siblings":
805
- if self.node_to_father[node2]:
806
- if (
807
- node1 in self.node_to_children[self.node_to_father[node2]]
808
- and self.node_to_mother[node1] != self.node_to_mother[node2]
1001
+ if self.get_father(node2):
1002
+ if node1 in self.get_children(self.get_father(node2)) and self.get_mother(node1) != self.get_mother(
1003
+ node2
809
1004
  ):
810
1005
  return True
811
1006
  if relation == "maternal half-siblings":
812
- if self.node_to_mother[node2]:
813
- if (
814
- node1 in self.node_to_children[self.node_to_mother[node2]]
815
- and self.node_to_father[node1] != self.node_to_father[node2]
1007
+ if self.get_mother(node2):
1008
+ if node1 in self.get_children(self.get_mother(node2)) and self.get_father(node1) != self.get_father(
1009
+ node2
816
1010
  ):
817
1011
  return True
818
1012
  return False
@@ -823,7 +1017,7 @@ class Pedigree:
823
1017
  """
824
1018
  Returns a dictionary of the *1st- and 2nd-degree* relations between node1 and node2.
825
1019
  """
826
- relations = defaultdict(int)
1020
+ relations: defaultdict[str, int] = defaultdict(int)
827
1021
  if self.is_relation_in_pedigree(node1, node2, ["parent-child"]):
828
1022
  relations["parent-child"] += 1
829
1023
  if self.is_relation_in_pedigree(node1, node2, ["child-parent"]):
@@ -896,7 +1090,7 @@ class Pedigree:
896
1090
  """
897
1091
  parent_child_pairs: list[tuple[str, str]] = []
898
1092
  for parent in self.node_to_children:
899
- for child in self.node_to_children[parent]:
1093
+ for child in self.get_children(parent):
900
1094
  if include_placeholders or (not parent.isnumeric() and not child.isnumeric()):
901
1095
  parent_child_pairs.append((parent, child))
902
1096
  return parent_child_pairs
@@ -908,7 +1102,7 @@ class Pedigree:
908
1102
  """
909
1103
  sibling_pairs: list[tuple[str, str]] = []
910
1104
  for sibling1 in self.node_to_siblings:
911
- for sibling2 in self.node_to_siblings[sibling1]:
1105
+ for sibling2 in self.get_siblings(sibling1):
912
1106
  if include_placeholders or (not sibling1.isnumeric() and not sibling2.isnumeric()):
913
1107
  # Don't add symmetric duplicates
914
1108
  if (sibling2, sibling1) not in sibling_pairs:
@@ -925,8 +1119,8 @@ class Pedigree:
925
1119
  """
926
1120
  aunt_uncle_nephew_niece_pairs: list[tuple[str, str]] = []
927
1121
  for parent, child in self.get_parent_child_pairs():
928
- for parent_sibling in self.node_to_siblings[parent]:
929
- if not shared_relative_sex or self.node_to_data[parent]["sex"] == shared_relative_sex:
1122
+ for parent_sibling in self.get_siblings(parent):
1123
+ if not shared_relative_sex or self.get_data(parent)["sex"] == shared_relative_sex:
930
1124
  if include_placeholders or (not parent_sibling.isnumeric() and not child.isnumeric()):
931
1125
  aunt_uncle_nephew_niece_pairs.append((parent_sibling, child))
932
1126
  return aunt_uncle_nephew_niece_pairs
@@ -940,8 +1134,8 @@ class Pedigree:
940
1134
  """
941
1135
  grandparent_grandchild_pairs: list[tuple[str, str]] = []
942
1136
  for parent, child in self.get_parent_child_pairs():
943
- for child_child in self.node_to_children[child]:
944
- if not shared_relative_sex or self.node_to_data[child]["sex"] == shared_relative_sex:
1137
+ for child_child in self.get_children(child):
1138
+ if not shared_relative_sex or self.get_data(child)["sex"] == shared_relative_sex:
945
1139
  if include_placeholders or (not parent.isnumeric() and not child_child.isnumeric()):
946
1140
  grandparent_grandchild_pairs.append((parent, child_child))
947
1141
  return grandparent_grandchild_pairs
@@ -954,9 +1148,9 @@ class Pedigree:
954
1148
  """
955
1149
  half_sibling_pairs: list[tuple[str, str]] = []
956
1150
  for parent, child in self.get_parent_child_pairs():
957
- for other_child in self.node_to_children[parent]:
958
- if child != other_child and other_child not in self.node_to_siblings[child]:
959
- if not shared_relative_sex or self.node_to_data[parent]["sex"] == shared_relative_sex:
1151
+ for other_child in self.get_children(parent):
1152
+ if child != other_child and other_child not in self.get_siblings(child):
1153
+ if not shared_relative_sex or self.get_data(parent)["sex"] == shared_relative_sex:
960
1154
  if include_placeholders or (not child.isnumeric() and not other_child.isnumeric()):
961
1155
  # Don't add symmetric duplicates
962
1156
  if (other_child, child) not in half_sibling_pairs:
@@ -969,12 +1163,12 @@ class Pedigree:
969
1163
  """
970
1164
  half_aunt_uncle_nephew_niece_pairs: list[tuple[str, str]] = []
971
1165
  for half_sibling1, half_sibling2 in self.get_half_sibling_pairs():
972
- for half_sibling1_child in self.node_to_children[half_sibling1]:
1166
+ for half_sibling1_child in self.get_children(half_sibling1):
973
1167
  if half_sibling1_child != half_sibling2:
974
1168
  if include_placeholders or (not half_sibling2.isnumeric() and not half_sibling1_child.isnumeric()):
975
1169
  half_aunt_uncle_nephew_niece_pairs.append((half_sibling2, half_sibling1_child))
976
1170
 
977
- for half_sibling2_child in self.node_to_children[half_sibling2]:
1171
+ for half_sibling2_child in self.get_children(half_sibling2):
978
1172
  if half_sibling2_child != half_sibling1:
979
1173
  if include_placeholders or (not half_sibling1.isnumeric() and not half_sibling2_child.isnumeric()):
980
1174
  half_aunt_uncle_nephew_niece_pairs.append((half_sibling1, half_sibling2_child))
@@ -986,7 +1180,7 @@ class Pedigree:
986
1180
  """
987
1181
  greatgrandparent_greatgrandchild_pairs: list[tuple[str, str]] = []
988
1182
  for grandparent, grandchild in self.get_grandparent_grandchild_pairs():
989
- for grandchild_child in self.node_to_children[grandchild]:
1183
+ for grandchild_child in self.get_children(grandchild):
990
1184
  if include_placeholders or (not grandparent.isnumeric() and not grandchild_child.isnumeric()):
991
1185
  greatgrandparent_greatgrandchild_pairs.append((grandparent, grandchild_child))
992
1186
  return greatgrandparent_greatgrandchild_pairs
@@ -999,7 +1193,7 @@ class Pedigree:
999
1193
  """
1000
1194
  grandaunt_granduncle_grandnephew_grandniece_pairs: list[tuple[str, str]] = []
1001
1195
  for aunt_uncle, nephew_niece in self.get_aunt_uncle_nephew_niece_pairs():
1002
- for nephew_niece_child in self.node_to_children[nephew_niece]:
1196
+ for nephew_niece_child in self.get_children(nephew_niece):
1003
1197
  if include_placeholders or (not aunt_uncle.isnumeric() and not nephew_niece_child.isnumeric()):
1004
1198
  grandaunt_granduncle_grandnephew_grandniece_pairs.append((aunt_uncle, nephew_niece_child))
1005
1199
  return grandaunt_granduncle_grandnephew_grandniece_pairs
@@ -1010,7 +1204,7 @@ class Pedigree:
1010
1204
  """
1011
1205
  cousin_pairs: list[tuple[str, str]] = []
1012
1206
  for aunt_uncle, child in self.get_aunt_uncle_nephew_niece_pairs():
1013
- for aunt_uncle_child in self.node_to_children[aunt_uncle]:
1207
+ for aunt_uncle_child in self.get_children(aunt_uncle):
1014
1208
  if include_placeholders or (not child.isnumeric() and not aunt_uncle_child.isnumeric()):
1015
1209
  # Don't add symmetric duplicates
1016
1210
  if aunt_uncle_child != child and (aunt_uncle_child, child) not in cousin_pairs:
@@ -1041,30 +1235,22 @@ class Pedigree:
1041
1235
  """
1042
1236
  return set([node for node in self.node_to_data if not node.isnumeric()])
1043
1237
 
1044
- def clean_up_relations(self) -> None:
1238
+ def clean_data(self) -> None:
1045
1239
  """
1046
1240
  Remove any empty entries in the relation dictionaries.
1047
1241
  Also remove unnecessary placeholder nodes to standardize topological sort output.
1048
1242
  """
1049
- for relation_dict in [self.node_to_father, self.node_to_mother, self.node_to_children, self.node_to_siblings]:
1050
- keys_to_remove = set()
1051
- for k, v in relation_dict.items():
1052
- if not v:
1053
- keys_to_remove.add(k)
1054
- for key_to_remove in keys_to_remove:
1055
- del relation_dict[key_to_remove]
1056
-
1057
1243
  placeholder_nodes_to_remove: set[str] = set()
1058
1244
  for node in self.node_to_data:
1059
- mother = self.node_to_mother[node]
1060
- father = self.node_to_father[node]
1245
+ mother = self.get_mother(node)
1246
+ father = self.get_father(node)
1061
1247
  if mother.isnumeric() and father.isnumeric():
1062
- if len(self.node_to_children[mother]) == 1 and len(self.node_to_children[father]) == 1:
1248
+ if len(self.get_children(mother)) == 1 and len(self.get_children(father)) == 1:
1063
1249
  if (
1064
- not self.node_to_mother[mother]
1065
- and not self.node_to_father[mother]
1066
- and not self.node_to_mother[father]
1067
- and not self.node_to_father[father]
1250
+ not self.get_mother(mother)
1251
+ and not self.get_father(mother)
1252
+ and not self.get_mother(father)
1253
+ and not self.get_father(father)
1068
1254
  ):
1069
1255
  placeholder_nodes_to_remove.add(mother)
1070
1256
  placeholder_nodes_to_remove.add(father)
@@ -1080,16 +1266,36 @@ class Pedigree:
1080
1266
  if node in data_dict:
1081
1267
  del data_dict[node]
1082
1268
 
1083
- for node in self.node_to_data:
1084
- assert node not in self.node_to_siblings[node] and node not in self.node_to_children[node]
1085
- if self.node_to_father[node] in placeholder_nodes_to_remove:
1086
- del self.node_to_father[node]
1087
- if self.node_to_mother[node] in placeholder_nodes_to_remove:
1088
- del self.node_to_mother[node]
1269
+ for node in self.node_to_data:
1270
+ assert node not in self.get_siblings(node) and node not in self.get_children(node)
1271
+ if self.get_father(node) in placeholder_nodes_to_remove:
1272
+ del self.node_to_father[node]
1273
+ if self.get_mother(node) in placeholder_nodes_to_remove:
1274
+ del self.node_to_mother[node]
1089
1275
 
1090
- def plot(self, path: str) -> None:
1276
+ for relation_dict in [self.node_to_father, self.node_to_mother, self.node_to_children, self.node_to_siblings]:
1277
+ keys_to_remove = set()
1278
+ for k, v in relation_dict.items():
1279
+ if not v:
1280
+ keys_to_remove.add(k)
1281
+ for key_to_remove in keys_to_remove:
1282
+ del relation_dict[key_to_remove]
1283
+
1284
+ def plot(
1285
+ self,
1286
+ path: str,
1287
+ mt_haplogroup_to_color: dict[str, str] | dict[str, tuple[float, float, float, float]] | None = None,
1288
+ nodes_to_remove: list[str] | None = None,
1289
+ edges_to_remove: list[tuple[str, str]] | None = None,
1290
+ dotted_edges_to_add: list[tuple[str, str]] | None = None,
1291
+ ) -> None:
1091
1292
  """
1092
- Plot the pedigree to the given path.
1293
+ Plot the pedigree to the given path. Optionally takes a custom mapping of mt_haplogroups to colors.
1294
+ Also optionally takes arguments to plot uncertain relations.
1295
+ nodes_to_remove is a list of nodes to remove from the plot.
1296
+ edges_to_remove is a list of parent-child edges to remove from the plot.
1297
+ dotted_edges_to_add is a list of node pairs to plot as dotted lines.
1298
+ These arguments can be used in conjunction to replace uncertain relations with dotted lines.
1093
1299
  """
1094
1300
  if not importlib.util.find_spec("pygraphviz"):
1095
1301
  raise ImportError("Plotting pedigree requires PyGraphviz (https://pygraphviz.github.io/).")
@@ -1099,54 +1305,76 @@ class Pedigree:
1099
1305
  for node in self.node_to_data:
1100
1306
  if node not in tree.nodes:
1101
1307
  tree.add_node(node)
1102
- male_named_nodes = [
1103
- node for node in self.node_to_data if self.node_to_data[node]["sex"] == "M" and not node.isnumeric()
1104
- ]
1105
- male_placeholder_nodes = [
1106
- node for node in self.node_to_data if self.node_to_data[node]["sex"] == "M" and node.isnumeric()
1107
- ]
1108
- female_named_nodes = [
1109
- node for node in self.node_to_data if self.node_to_data[node]["sex"] == "F" and not node.isnumeric()
1308
+
1309
+ # Replace relations with dotted edges
1310
+ if nodes_to_remove:
1311
+ tree.remove_nodes_from(nodes_to_remove)
1312
+ if edges_to_remove:
1313
+ tree.remove_edges_from(edges_to_remove)
1314
+ if dotted_edges_to_add:
1315
+ tree.add_edges_from(dotted_edges_to_add, style="dotted")
1316
+ parent_child_edges = [
1317
+ (u, v) for u, v, style in tree.edges.data("style", default="parent_child") if style == "parent_child"
1110
1318
  ]
1319
+ dotted_edges = [(u, v) for u, v, style in tree.edges.data("style", default="parent_child") if style == "dotted"]
1320
+
1321
+ male_named_nodes = [node for node in tree.nodes if self.get_data(node)["sex"] == "M" and not node.isnumeric()]
1322
+ male_placeholder_nodes = [node for node in tree.nodes if self.get_data(node)["sex"] == "M" and node.isnumeric()]
1323
+ female_named_nodes = [node for node in tree.nodes if self.get_data(node)["sex"] == "F" and not node.isnumeric()]
1111
1324
  female_placeholder_nodes = [
1112
- node for node in self.node_to_data if self.node_to_data[node]["sex"] == "F" and node.isnumeric()
1325
+ node for node in tree.nodes if self.get_data(node)["sex"] == "F" and node.isnumeric()
1113
1326
  ]
1114
1327
 
1115
- non_placeholder_labels = dict()
1328
+ node_labels = dict()
1116
1329
  for node in tree.nodes:
1117
- mt_haplogroup = self.node_to_data[node]["mt_haplogroup"][:3]
1118
- y_haplogroup = self.node_to_data[node]["y_haplogroup"][:3]
1330
+ mt_haplogroup = self.get_data(node)["mt_haplogroup"].replace("*", "")[:3]
1331
+ y_haplogroup = self.get_data(node)["y_haplogroup"].replace("*", "")[:3]
1119
1332
  if node.isnumeric():
1120
1333
  if y_haplogroup:
1121
- non_placeholder_labels[node] = f"{mt_haplogroup}\n{y_haplogroup}"
1334
+ node_labels[node] = f"MT: {mt_haplogroup}\nY: {y_haplogroup}"
1122
1335
  else:
1123
- non_placeholder_labels[node] = f"{mt_haplogroup}"
1336
+ node_labels[node] = f"MT: {mt_haplogroup}"
1124
1337
  else:
1125
1338
  if y_haplogroup:
1126
- non_placeholder_labels[node] = f"{node}\n{mt_haplogroup}\n{y_haplogroup}"
1339
+ node_labels[node] = f"{node}\nMT: {mt_haplogroup}\nY: {y_haplogroup}"
1127
1340
  else:
1128
- non_placeholder_labels[node] = f"{node}\n{mt_haplogroup}"
1341
+ node_labels[node] = f"{node}\nMT: {mt_haplogroup}"
1129
1342
 
1130
1343
  # Create colormap for MT haplogroups
1131
- cmap = plt.get_cmap("tab20")
1132
- mt_haplogroups = sorted(
1133
- set([self.node_to_data[node]["mt_haplogroup"] for node in self.node_to_data if not node.isnumeric()])
1134
- )
1135
- mt_haplogroup_to_color = {
1136
- haplogroup: cmap(i / len(mt_haplogroups)) for i, haplogroup in enumerate(mt_haplogroups)
1137
- }
1344
+ if not mt_haplogroup_to_color:
1345
+ cmap = plt.get_cmap("tab20")
1346
+ mt_haplogroups = sorted(
1347
+ set(
1348
+ [
1349
+ self.get_data(node)["mt_haplogroup"].replace("*", "")
1350
+ for node in self.node_to_data
1351
+ if not node.isnumeric()
1352
+ ]
1353
+ )
1354
+ )
1355
+ mt_haplogroup_to_color = {
1356
+ haplogroup: cmap(i / len(mt_haplogroups)) for i, haplogroup in enumerate(mt_haplogroups)
1357
+ }
1358
+
1359
+ # Specify alpha here instead of in nx.draw_networkx_nodes so node borders stay opaque
1360
+ face_alpha = 0.5
1138
1361
  male_named_node_colors = [
1139
- mt_haplogroup_to_color[self.node_to_data[node]["mt_haplogroup"]] for node in male_named_nodes
1362
+ to_rgba(mt_haplogroup_to_color[self.get_data(node)["mt_haplogroup"].replace("*", "")], face_alpha)
1363
+ for node in male_named_nodes
1140
1364
  ]
1141
1365
  female_named_node_colors = [
1142
- mt_haplogroup_to_color[self.node_to_data[node]["mt_haplogroup"]] for node in female_named_nodes
1366
+ to_rgba(mt_haplogroup_to_color[self.get_data(node)["mt_haplogroup"].replace("*", "")], face_alpha)
1367
+ for node in female_named_nodes
1143
1368
  ]
1369
+ male_placeholder_node_colors = [to_rgba("#e5e5e5", face_alpha) for node in male_placeholder_nodes]
1370
+ female_placeholder_node_colors = [to_rgba("#e5e5e5", face_alpha) for node in female_placeholder_nodes]
1144
1371
 
1145
1372
  plt.figure(figsize=(12, 4.8), dpi=1200)
1146
1373
  # Scale sizes based on pedigree node count
1147
- node_size = min(1000, 10000 / len(tree.nodes))
1374
+ node_size = min(1000, 9000 / len(tree.nodes))
1148
1375
  # Matplotlib doesn't allow font size less than 1
1149
- font_size = max(min(4.5, 150 / len(tree.nodes)), 1)
1376
+ font_size = max(math.sqrt(node_size) / 5, 1)
1377
+ line_width = math.sqrt(node_size) / 100
1150
1378
 
1151
1379
  pos = nx.nx_agraph.graphviz_layout(tree, prog="dot")
1152
1380
  nx.draw_networkx_nodes(
@@ -1157,7 +1385,7 @@ class Pedigree:
1157
1385
  node_size=node_size,
1158
1386
  node_color=male_named_node_colors,
1159
1387
  edgecolors="black",
1160
- linewidths=0.2,
1388
+ linewidths=line_width,
1161
1389
  )
1162
1390
  nx.draw_networkx_nodes(
1163
1391
  tree,
@@ -1167,7 +1395,7 @@ class Pedigree:
1167
1395
  node_size=node_size,
1168
1396
  node_color=female_named_node_colors,
1169
1397
  edgecolors="black",
1170
- linewidths=0.2,
1398
+ linewidths=line_width,
1171
1399
  )
1172
1400
  nx.draw_networkx_nodes(
1173
1401
  tree,
@@ -1175,9 +1403,9 @@ class Pedigree:
1175
1403
  nodelist=male_placeholder_nodes,
1176
1404
  node_shape="s",
1177
1405
  node_size=node_size,
1178
- node_color="#e5e5e5",
1406
+ node_color=male_placeholder_node_colors,
1179
1407
  edgecolors="black",
1180
- linewidths=0.2,
1408
+ linewidths=line_width,
1181
1409
  )
1182
1410
  nx.draw_networkx_nodes(
1183
1411
  tree,
@@ -1185,12 +1413,33 @@ class Pedigree:
1185
1413
  nodelist=female_placeholder_nodes,
1186
1414
  node_shape="o",
1187
1415
  node_size=node_size,
1188
- node_color="#e5e5e5",
1416
+ node_color=female_placeholder_node_colors,
1189
1417
  edgecolors="black",
1190
- linewidths=0.2,
1418
+ linewidths=line_width,
1419
+ )
1420
+ nx.draw_networkx_labels(tree, pos=pos, labels=node_labels, font_size=font_size)
1421
+ nx.draw_networkx_edges(
1422
+ tree,
1423
+ edgelist=parent_child_edges,
1424
+ pos=pos,
1425
+ node_shape="s",
1426
+ node_size=node_size,
1427
+ width=line_width,
1428
+ arrowsize=line_width * 30,
1429
+ edge_color="black",
1430
+ )
1431
+ # Setting arrows=False causes edges to overlap their associated nodes for some reason
1432
+ nx.draw_networkx_edges(
1433
+ tree,
1434
+ edgelist=dotted_edges,
1435
+ pos=pos,
1436
+ node_shape="s",
1437
+ node_size=node_size,
1438
+ width=line_width * 1.5,
1439
+ arrowstyle="-",
1440
+ style=(0, (3, 3)),
1441
+ edge_color="blue",
1191
1442
  )
1192
- nx.draw_networkx_labels(tree, pos=pos, labels=non_placeholder_labels, font_size=font_size)
1193
- nx.draw_networkx_edges(tree, pos=pos, node_shape="s", node_size=node_size, width=0.2, arrowsize=font_size * 1.2)
1194
1443
 
1195
1444
  plt.axis("off")
1196
1445
  plt.savefig(path, bbox_inches="tight")