repare 0.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of repare might be problematic. Click here for more details.

@@ -0,0 +1,874 @@
1
+ import copy
2
+ import logging
3
+ import os
4
+ import random
5
+ import time
6
+ from collections import defaultdict
7
+ from typing import Any
8
+
9
+ import pandas as pd
10
+ from tqdm import tqdm
11
+
12
+ from repare.pedigree import Pedigree
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
+ class PedigreeReconstructor:
18
+ """
19
+ Manages and builds up a collection of potential Pedigrees.
20
+ """
21
+
22
    def __init__(
        self,
        relations_path: str,
        nodes_path: str,
        outputs_dir: str,
        max_candidate_pedigrees: int = 1000,
        epsilon: float = 0.2,
        plot: bool = True,
        write_alternate_pedigrees: bool = False,
        random_seed: Any = 42,
    ) -> None:
        """
        Load, validate, and preprocess the input CSVs and initialize algorithm state.

        Args:
            relations_path: Path to the CSV of pairwise relations (id1, id2, degree, constraints, ...).
            nodes_path: Path to the CSV of per-sample node data (id, sex, haplogroups, ...).
            outputs_dir: Directory where reconstruction outputs are written.
            max_candidate_pedigrees: Number of candidate pedigrees kept after each iteration.
            epsilon: Exploration rate for epsilon-greedy sampling when pruning pedigrees.
            plot: Whether to plot the reconstructed pedigree(s).
            write_alternate_pedigrees: Whether to write corrected relations and plots of
                alternate final pedigrees.
            random_seed: Seed for the global ``random`` module and pandas shuffling.
        """
        self._start_time = time.time()
        # Order matters: node data must be loaded before relation data is
        # validated, because relation IDs are checked against node IDs.
        self._validate_node_data(nodes_path)
        self._process_node_data()
        self._validate_relation_data(relations_path)
        self._process_relation_data()

        self._outputs_dir = outputs_dir
        # Number of pedigrees to downsample to after each iteration of algorithm
        self._max_candidate_pedigrees = max_candidate_pedigrees
        # Parameter for epsilon-greedy sampling when pruning pedigrees
        self._epsilon = epsilon
        # Whether to plot the reconstructed pedigree(s)
        self._plot = plot
        # Whether to write corrected relations and plots of alternate final pedigrees
        self._write_alternate_pedigrees = write_alternate_pedigrees
        self._random_seed = random_seed
        # Seeds the module-level random used for shuffling/sampling candidates
        random.seed(self._random_seed)

        # Maximum number of times to run the algorithm if no valid pedigree is found
        self._MAX_RUNS = 10
        # Start from a single pedigree containing every node and no relations
        self._candidate_pedigrees: list[Pedigree] = [self._get_initial_pedigree()]
        self._pair_to_constraints: defaultdict[tuple[str, str], list[tuple[str, ...]]] = self._get_pair_to_constraints()
        # Chosen by find_best_pedigree() once a valid reconstruction exists
        self._final_pedigree: Pedigree | None = None
56
+
57
+ def _validate_node_data(self, nodes_path: str) -> None:
58
+ """
59
+ Validate node data input.
60
+ """
61
+ self._node_data = pd.read_csv(nodes_path, dtype=str, comment="#", keep_default_na=False)
62
+ for mandatory_column in ["id", "sex", "y_haplogroup", "mt_haplogroup"]:
63
+ if mandatory_column not in self._node_data.columns:
64
+ raise ValueError(f'Column "{mandatory_column}" not found in input node data.')
65
+
66
+ for optional_column in ["can_have_children", "can_be_inbred", "years_before_present"]:
67
+ if optional_column not in self._node_data.columns:
68
+ self._node_data[optional_column] = ""
69
+
70
+ # Numeric IDs are used for placeholder nodes
71
+ if self._node_data["id"].str.isnumeric().any():
72
+ raise ValueError("Sample IDs cannot be completely numeric.")
73
+
74
+ if not self._node_data["sex"].isin(["M", "F"]).all():
75
+ raise ValueError('Node sex must be "M" or "F".')
76
+
77
+ for haplogroup_column in ["y_haplogroup", "mt_haplogroup"]:
78
+ for haplogroup in self._node_data[haplogroup_column]:
79
+ if "*" in haplogroup[:-1]:
80
+ raise ValueError(
81
+ "Expandable haplogroups should contain one trailing asterisk. "
82
+ "No other asterisks are allowed in haplogroups."
83
+ )
84
+
85
+ if not self._node_data["can_have_children"].isin(["True", "False", ""]).all():
86
+ raise ValueError('can_have_children value must be "True", "False", or empty.')
87
+ if not self._node_data["can_be_inbred"].isin(["True", "False", ""]).all():
88
+ raise ValueError('can_be_inbred value must be "True", "False", or empty.')
89
+ if not self._node_data["years_before_present"].apply(lambda x: x.isnumeric() or x == "").all():
90
+ raise ValueError("years_before_present value must be integer or empty.")
91
+
92
+ def _process_node_data(self) -> None:
93
+ """
94
+ Process node data input.
95
+ """
96
+ # Reorder node data columns and remove unnecessary columns
97
+ self._node_data = self._node_data[
98
+ ["id", "sex", "y_haplogroup", "mt_haplogroup", "can_have_children", "can_be_inbred", "years_before_present"]
99
+ ]
100
+ # Convert "can_have_children" and "can_be_inbred" columns to booleans
101
+ self._node_data["can_have_children"] = self._node_data["can_have_children"].map(
102
+ {"False": False, "True": True, "": True}
103
+ )
104
+ self._node_data["can_be_inbred"] = self._node_data["can_be_inbred"].map(
105
+ {"False": False, "True": True, "": True}
106
+ )
107
+ # Convert "years_before_present" column to floats
108
+ self._node_data["years_before_present"] = pd.to_numeric(
109
+ self._node_data["years_before_present"], errors="coerce"
110
+ )
111
+
112
+ def _validate_relation_data(self, relations_path: str) -> None:
113
+ """
114
+ Validate relation data input.
115
+ """
116
+ self._relation_data = pd.read_csv(relations_path, dtype=str, comment="#", keep_default_na=False)
117
+ for column_name in ["id1", "id2", "degree", "constraints"]:
118
+ if column_name not in self._relation_data.columns:
119
+ raise ValueError(f'Column "{column_name}" not found in input relation data.')
120
+
121
+ for optional_column in ["force_constraints"]:
122
+ if optional_column not in self._relation_data.columns:
123
+ self._relation_data[optional_column] = ""
124
+
125
+ excess_relation_nodes = set(self._relation_data["id1"]).union(set(self._relation_data["id2"])) - set(
126
+ self._node_data["id"]
127
+ )
128
+ if excess_relation_nodes:
129
+ raise ValueError(f"All node IDs in relation data must be present in node data: {excess_relation_nodes}.")
130
+
131
+ if not self._relation_data["degree"].isin(["1", "2", "3"]).all():
132
+ raise ValueError("Degree must be 1, 2, or 3.")
133
+ if not self._relation_data["force_constraints"].isin(["True", "False", ""]).all():
134
+ raise ValueError('can_have_children value must be "True", "False", or empty.')
135
+
136
+ self._relation_data["pair_degree"] = self._relation_data.apply(
137
+ lambda row: tuple(sorted([row["id1"], row["id2"], row["degree"]])), axis=1
138
+ )
139
+ grouped_relations = self._relation_data.groupby("pair_degree")
140
+ # Check for groups with multiple non-empty constraints, which can lead to issues when counting inconsistencies
141
+ invalid_groups = grouped_relations.filter(lambda group: (group["constraints"] != "").sum() > 1)
142
+ if not invalid_groups.empty:
143
+ raise ValueError("Node pairs cannot have multiple non-empty constraints of the same degree.")
144
+ self._relation_data.drop("pair_degree", axis=1, inplace=True)
145
+
146
+ def split_and_validate_constraints(constraints: str) -> None:
147
+ allowed_constraints = {
148
+ "parent-child",
149
+ "child-parent",
150
+ "siblings",
151
+ "maternal aunt/uncle-nephew/niece",
152
+ "maternal nephew/niece-aunt/uncle",
153
+ "paternal aunt/uncle-nephew/niece",
154
+ "paternal nephew/niece-aunt/uncle",
155
+ "maternal grandparent-grandchild",
156
+ "maternal grandchild-grandparent",
157
+ "paternal grandparent-grandchild",
158
+ "paternal grandchild-grandparent",
159
+ "maternal half-siblings",
160
+ "paternal half-siblings",
161
+ }
162
+ if constraints:
163
+ constraints_list = [c for c in constraints.split(";")]
164
+ if any(c not in allowed_constraints for c in constraints_list):
165
+ raise ValueError(
166
+ f"Invalid constraints found: {[c for c in constraints_list if c not in allowed_constraints]}."
167
+ )
168
+
169
+ self._relation_data["constraints"].apply(split_and_validate_constraints)
170
+
171
    def _process_relation_data(self) -> None:
        """
        Normalize the validated relation table.

        Reorders columns, converts force_constraints to booleans, canonicalizes
        each row so id1 <= id2 (flipping directional constraints accordingly),
        fills empty constraints with the full default set for the degree, and
        finally splits the table into per-degree DataFrames used by the
        reconstruction loop.
        """
        # Reorder relation data columns and remove unnecessary columns
        self._relation_data = self._relation_data[["id1", "id2", "degree", "constraints", "force_constraints"]]
        # Convert "force_constraints" column to booleans; empty means "do not force"
        self._relation_data["force_constraints"] = self._relation_data["force_constraints"].map(
            {"False": False, "True": True, "": False}
        )

        def sort_nodes(row: pd.Series) -> pd.Series:
            """
            Ensure id1 and id2 are in a fixed (sorted) order and flip constraints as needed.
            """
            # Map constraints to their flipped value
            flipped_constraints = {
                "parent-child": "child-parent",
                "child-parent": "parent-child",
                "maternal aunt/uncle-nephew/niece": "maternal nephew/niece-aunt/uncle",
                "paternal aunt/uncle-nephew/niece": "paternal nephew/niece-aunt/uncle",
                "maternal nephew/niece-aunt/uncle": "maternal aunt/uncle-nephew/niece",
                "paternal nephew/niece-aunt/uncle": "paternal aunt/uncle-nephew/niece",
                "maternal grandparent-grandchild": "maternal grandchild-grandparent",
                "paternal grandparent-grandchild": "paternal grandchild-grandparent",
                "maternal grandchild-grandparent": "maternal grandparent-grandchild",
                "paternal grandchild-grandparent": "paternal grandparent-grandchild",
                "siblings": "siblings",  # Symmetric
                "maternal half-siblings": "maternal half-siblings",  # Symmetric
                "paternal half-siblings": "paternal half-siblings",  # Symmetric
            }
            if row["id2"] < row["id1"]:
                constraints = row["constraints"]
                # Split constraints and map each to its flipped value
                if constraints:
                    constraints_list = [c.strip() for c in constraints.split(";")]
                    flipped = [flipped_constraints[c] for c in constraints_list]
                    relation_flipped_constraints = ";".join(flipped)
                else:
                    relation_flipped_constraints = ""
                # Swap id1 and id2, and flip constraints
                return pd.Series(
                    {
                        "id1": row["id2"],
                        "id2": row["id1"],
                        "degree": row["degree"],
                        "constraints": relation_flipped_constraints,
                        "force_constraints": row["force_constraints"],
                    }
                )
            else:
                return row

        self._relation_data = self._relation_data.apply(sort_nodes, axis=1)

        # Note: We don't use maternal/paternal 3rd-degree relations because those are not well-defined
        self._DEFAULT_CONSTRAINTS = {
            "1": ("parent-child;child-parent;siblings"),
            "2": (
                "maternal aunt/uncle-nephew/niece;"
                "maternal nephew/niece-aunt/uncle;"
                "paternal aunt/uncle-nephew/niece;"
                "paternal nephew/niece-aunt/uncle;"
                "maternal grandparent-grandchild;"
                "maternal grandchild-grandparent;"
                "paternal grandparent-grandchild;"
                "paternal grandchild-grandparent;"
                "maternal half-siblings;"
                "paternal half-siblings"
            ),
            "3": (
                "half aunt/uncle-half nephew/niece;"
                "half nephew/niece-half aunt/uncle;"
                "greatgrandparent-greatgrandchild;"
                "greatgrandchild-greatgrandparent;"
                "grandaunt/granduncle-grandnephew/grandniece;"
                "grandnephew/grandniece-grandaunt/granduncle;"
                "first cousins"
            ),
        }

        def fill_constraints(row: pd.Series) -> pd.Series:
            # An empty constraints field means any relation of that degree is possible
            if not row["constraints"]:
                constraints = self._DEFAULT_CONSTRAINTS[row["degree"]]
                return pd.Series(
                    {
                        "id1": row["id1"],
                        "id2": row["id2"],
                        "degree": row["degree"],
                        "constraints": constraints,
                        "force_constraints": row["force_constraints"],
                    }
                )
            return row

        self._relation_data = self._relation_data.apply(fill_constraints, axis=1)

        # Split by degree; reset_index so .iloc positions match processing order
        self._first_degree_relations = self._relation_data[self._relation_data["degree"] == "1"].reset_index(drop=True)
        self._second_degree_relations = self._relation_data[self._relation_data["degree"] == "2"].reset_index(drop=True)
        self._third_degree_relations = self._relation_data[self._relation_data["degree"] == "3"].reset_index(drop=True)
        self._first_and_second_degree_relations = pd.concat(
            [self._first_degree_relations, self._second_degree_relations]
        ).reset_index(drop=True)
        self._all_relations = pd.concat(
            [self._first_degree_relations, self._second_degree_relations, self._third_degree_relations]
        ).reset_index(drop=True)
277
+
278
+ def _shuffle_relations(self) -> None:
279
+ """
280
+ Shuffle relation DataFrames (when we want to restart the algorithm).
281
+ """
282
+ self._first_degree_relations = self._first_degree_relations.sample(
283
+ frac=1, random_state=self._random_seed
284
+ ).reset_index(drop=True)
285
+ self._second_degree_relations = self._second_degree_relations.sample(
286
+ frac=1, random_state=self._random_seed
287
+ ).reset_index(drop=True)
288
+ self._third_degree_relations = self._third_degree_relations.sample(
289
+ frac=1, random_state=self._random_seed
290
+ ).reset_index(drop=True)
291
+ self._first_and_second_degree_relations = pd.concat(
292
+ [self._first_degree_relations, self._second_degree_relations]
293
+ ).reset_index(drop=True)
294
+ self._all_relations = pd.concat(
295
+ [self._first_degree_relations, self._second_degree_relations, self._third_degree_relations]
296
+ ).reset_index(drop=True)
297
+
298
+ def _get_initial_pedigree(self):
299
+ """
300
+ Create the initial pedigree and add all nodes.
301
+ """
302
+ initial_pedigree = Pedigree()
303
+ for (
304
+ node_id,
305
+ sex,
306
+ y_haplogroup,
307
+ mt_haplogroup,
308
+ can_have_children,
309
+ can_be_inbred,
310
+ years_before_present,
311
+ ) in self._node_data.itertuples(index=False):
312
+ initial_pedigree.add_node(
313
+ node_id, sex, y_haplogroup, mt_haplogroup, can_have_children, can_be_inbred, years_before_present
314
+ )
315
+ return initial_pedigree
316
+
317
    def find_best_pedigree(self) -> Pedigree:
        """
        Finds the configuration of relations that yields the "best" pedigree.
        Writes to output_dir the set of relations with the least changes from the original input data.

        Returns:
            One pedigree sampled uniformly from the equally-good final pedigrees.

        Raises:
            RuntimeError: If no valid pedigree is found after self._MAX_RUNS runs.
        """
        for _ in range(self._MAX_RUNS):
            progress_bar = tqdm(
                self._first_and_second_degree_relations.iterrows(),
                total=self._first_and_second_degree_relations.shape[0],
                smoothing=0.5,
                bar_format="{l_bar}{bar} | {n_fmt}/{total_fmt} [{elapsed}<{remaining}]",
            )
            for idx, row in progress_bar:
                node1, node2, degree, constraints, force_constraints = row
                logger.info(f"Current relation: {node1}, {node2}, {degree}")
                progress_bar.set_description(f"Processing relation {{{node1}, {node2}, {degree}}}")
                # Expand every candidate pedigree with all ways of realizing this relation
                self._add_relation(
                    node1, node2, degree=degree, constraints=constraints, force_constraints=force_constraints
                )
                self._clean_relation_dicts()

                # Rows of _all_relations up to and including the current one
                processed_relations = self._all_relations.iloc[: idx + 1]
                pair_to_relations_so_far = self._get_pair_to_relations_so_far(processed_relations)
                # NOTE(review): half-sibling checks are skipped mid-run for 1st-degree
                # rows — presumably half-sibling structure isn't settled yet; confirm
                if degree == "1" and len(processed_relations) < len(self._first_and_second_degree_relations):
                    self._prune_pedigrees(pair_to_relations_so_far, check_half_siblings=False)
                else:
                    self._prune_pedigrees(pair_to_relations_so_far, check_half_siblings=True)
                logger.info(
                    f"Remaining pedigrees after pruning: {len(self._candidate_pedigrees)}"
                    "\t\tElapsed: "
                    f"{round(time.time() - self._start_time, 1)} s\n"
                )

            # _prune_pedigrees populates self._final_pedigrees on its final iteration
            if not self._final_pedigrees:
                logger.warning("No valid pedigree found. Shuffling relations and restarting algorithm.\n")
                self._candidate_pedigrees = [self._get_initial_pedigree()]
                self._shuffle_relations()
            else:
                break

        if not self._final_pedigrees:
            logger.error(f"No valid pedigree found after {self._MAX_RUNS} runs. Exiting.")
            raise RuntimeError(f"No valid pedigree found after {self._MAX_RUNS} runs.")

        # Plot and write outputs of sample pedigree
        sample_idx = random.randint(0, len(self._final_pedigrees) - 1)
        self._sample_pedigree = self._final_pedigrees[sample_idx]
        self._sample_strike_count = self._final_strike_counts[sample_idx]
        self._sample_strike_log = self._final_strike_logs[sample_idx]
        self._write_corrected_input_relations(
            self._sample_strike_count,
            self._sample_strike_log,
            os.path.join(self._outputs_dir, "corrected_input_relations.csv"),
        )
        self._sample_pedigree.write_exact_relations(
            os.path.join(self._outputs_dir, "reconstructed_exact_relations.csv")
        )
        if self._plot:
            try:
                self._sample_pedigree.plot(os.path.join(self._outputs_dir, "reconstructed_pedigree.png"))
                pygraphviz_found = True
            except ImportError:
                logger.warning(
                    "PyGraphviz (https://pygraphviz.github.io/) must be installed to plot pedigrees. "
                    "Skipping plotting of reconstructed pedigree(s)."
                )
                pygraphviz_found = False

        # Plot and write outputs of alternate pedigrees
        if self._write_alternate_pedigrees:
            os.makedirs(os.path.join(self._outputs_dir, "alternate_pedigrees"), exist_ok=True)
            for idx, (pedigree, strike_count, strike_log) in enumerate(
                zip(self._final_pedigrees, self._final_strike_counts, self._final_strike_logs)
            ):
                self._write_corrected_input_relations(
                    strike_count,
                    strike_log,
                    os.path.join(
                        self._outputs_dir, "alternate_pedigrees", f"pedigree_{idx}_corrected_input_relations.csv"
                    ),
                )
                pedigree.write_exact_relations(
                    os.path.join(self._outputs_dir, "alternate_pedigrees", f"pedigree_{idx}_exact_relations.csv")
                )
                # pygraphviz_found is only bound when self._plot is True, which guards this branch
                if self._plot and pygraphviz_found:
                    pedigree.plot(os.path.join(self._outputs_dir, "alternate_pedigrees", f"pedigree_{idx}.png"))
        return self._sample_pedigree
404
+
405
+ def _add_relation(self, node1: str, node2: str, degree: str, constraints: str, force_constraints: bool) -> None:
406
+ """
407
+ Connects two nodes in every pedigree.
408
+ """
409
+ assert degree in ["1", "2"]
410
+
411
+ new_pedigrees: list[Pedigree] = []
412
+ for pedigree in self._candidate_pedigrees:
413
+ if degree == "1":
414
+ if not force_constraints:
415
+ new_pedigrees.extend(
416
+ PedigreeReconstructor._connect_first_degree_relation(
417
+ pedigree, node1, node2, constraints=self._DEFAULT_CONSTRAINTS["1"]
418
+ )
419
+ )
420
+ new_pedigrees.extend(
421
+ PedigreeReconstructor._connect_second_degree_relation(
422
+ pedigree, node1, node2, constraints=self._DEFAULT_CONSTRAINTS["2"]
423
+ )
424
+ )
425
+ else:
426
+ new_pedigrees.extend(
427
+ PedigreeReconstructor._connect_first_degree_relation(
428
+ pedigree, node1, node2, constraints=constraints
429
+ )
430
+ )
431
+
432
+ elif degree == "2":
433
+ if not force_constraints:
434
+ new_pedigrees.append(pedigree) # No relation (i.e. false positive)
435
+ new_pedigrees.extend(
436
+ PedigreeReconstructor._connect_first_degree_relation(
437
+ pedigree, node1, node2, constraints=self._DEFAULT_CONSTRAINTS["1"]
438
+ )
439
+ )
440
+ new_pedigrees.extend(
441
+ PedigreeReconstructor._connect_second_degree_relation(
442
+ pedigree, node1, node2, constraints=self._DEFAULT_CONSTRAINTS["2"]
443
+ )
444
+ )
445
+ else:
446
+ new_pedigrees.extend(
447
+ PedigreeReconstructor._connect_second_degree_relation(
448
+ pedigree, node1, node2, constraints=constraints
449
+ )
450
+ )
451
+ self._candidate_pedigrees = new_pedigrees
452
+
453
+ @staticmethod
454
+ def _connect_first_degree_relation(pedigree: Pedigree, node1: str, node2: str, constraints: str) -> None:
455
+ """
456
+ Update pedigree with a first-degree relation.
457
+ """
458
+ assert node1 in pedigree.node_to_data and node2 in pedigree.node_to_data
459
+
460
+ new_pedigrees: list[Pedigree] = []
461
+ possible_relations: list[str] = constraints.split(";")
462
+
463
+ for relation in possible_relations:
464
+ if relation == "parent-child":
465
+ new_pedigrees.extend(PedigreeReconstructor._connect_parent_relation(pedigree, node1, node2))
466
+ if relation == "child-parent":
467
+ new_pedigrees.extend(PedigreeReconstructor._connect_parent_relation(pedigree, node2, node1))
468
+ if relation == "siblings":
469
+ new_pedigrees.extend(PedigreeReconstructor._connect_sibling_relation(pedigree, node1, node2))
470
+ return new_pedigrees
471
+
472
+ @staticmethod
473
+ def _connect_second_degree_relation(pedigree: Pedigree, node1: str, node2: str, constraints: str) -> None:
474
+ """
475
+ Update pedigree with a second-degree relation.
476
+ """
477
+ assert node1 in pedigree.node_to_data and node2 in pedigree.node_to_data
478
+
479
+ new_pedigrees: list[Pedigree] = []
480
+ possible_relations: list[str] = constraints.split(";")
481
+
482
+ for relation in possible_relations:
483
+ if relation == "maternal aunt/uncle-nephew/niece":
484
+ new_pedigrees.extend(
485
+ PedigreeReconstructor._connect_aunt_uncle_relation(pedigree, node1, node2, shared_relative_sex="F")
486
+ )
487
+ if relation == "maternal nephew/niece-aunt/uncle":
488
+ new_pedigrees.extend(
489
+ PedigreeReconstructor._connect_aunt_uncle_relation(pedigree, node2, node1, shared_relative_sex="F")
490
+ )
491
+ if relation == "paternal aunt/uncle-nephew/niece":
492
+ new_pedigrees.extend(
493
+ PedigreeReconstructor._connect_aunt_uncle_relation(pedigree, node1, node2, shared_relative_sex="M")
494
+ )
495
+ if relation == "paternal nephew/niece-aunt/uncle":
496
+ new_pedigrees.extend(
497
+ PedigreeReconstructor._connect_aunt_uncle_relation(pedigree, node2, node1, shared_relative_sex="M")
498
+ )
499
+
500
+ if relation == "maternal grandparent-grandchild":
501
+ new_pedigrees.extend(
502
+ PedigreeReconstructor._connect_grandparent_relation(pedigree, node1, node2, shared_relative_sex="F")
503
+ )
504
+ if relation == "maternal grandchild-grandparent":
505
+ new_pedigrees.extend(
506
+ PedigreeReconstructor._connect_grandparent_relation(pedigree, node2, node1, shared_relative_sex="F")
507
+ )
508
+ if relation == "paternal grandparent-grandchild":
509
+ new_pedigrees.extend(
510
+ PedigreeReconstructor._connect_grandparent_relation(pedigree, node1, node2, shared_relative_sex="M")
511
+ )
512
+ if relation == "paternal grandchild-grandparent":
513
+ new_pedigrees.extend(
514
+ PedigreeReconstructor._connect_grandparent_relation(pedigree, node2, node1, shared_relative_sex="M")
515
+ )
516
+
517
+ if relation == "maternal half-siblings":
518
+ new_pedigrees.extend(
519
+ PedigreeReconstructor._connect_half_sibling_relation(
520
+ pedigree, node1, node2, shared_relative_sex="F"
521
+ )
522
+ )
523
+ if relation == "paternal half-siblings":
524
+ new_pedigrees.extend(
525
+ PedigreeReconstructor._connect_half_sibling_relation(
526
+ pedigree, node1, node2, shared_relative_sex="M"
527
+ )
528
+ )
529
+ return new_pedigrees
530
+
531
+ @staticmethod
532
+ def _connect_parent_relation(pedigree: Pedigree, node1: str, node2: str) -> list[Pedigree]:
533
+ """
534
+ Adds a parent-child relation and merges nodes appropriately.
535
+ Returns a list containing the resulting Pedigree, if successful.
536
+ """
537
+ assert node1 in pedigree.node_to_data and node2 in pedigree.node_to_data
538
+
539
+ ret: list[Pedigree] = []
540
+ new_pedigree = copy.deepcopy(pedigree)
541
+ new_pedigree.fill_node_parents(node2)
542
+ original_parent: str
543
+ if new_pedigree.node_to_data[node1]["sex"] == "M":
544
+ original_parent = new_pedigree.node_to_father[node2]
545
+ else:
546
+ original_parent = new_pedigree.node_to_mother[node2]
547
+
548
+ if not new_pedigree.check_cycles_if_merged(node1, original_parent):
549
+ new_pedigree.merge_nodes(node1, original_parent)
550
+ ret.append(new_pedigree)
551
+ return ret
552
+
553
+ @staticmethod
554
+ def _connect_sibling_relation(pedigree: Pedigree, node1: str, node2: str) -> list[Pedigree]:
555
+ """
556
+ Adds a sibling relation and merges nodes appropriately.
557
+ Returns a list containing the resulting Pedigree, if successful.
558
+ """
559
+ assert node1 in pedigree.node_to_data and node2 in pedigree.node_to_data
560
+
561
+ ret: list[Pedigree] = []
562
+ new_pedigree = copy.deepcopy(pedigree)
563
+ new_pedigree.fill_node_parents(node1)
564
+ new_pedigree.fill_node_parents(node2)
565
+
566
+ father1 = new_pedigree.node_to_father[node1]
567
+ father2 = new_pedigree.node_to_father[node2]
568
+ if not new_pedigree.check_cycles_if_merged(father1, father2):
569
+ new_pedigree.merge_nodes(father1, father2)
570
+ mother1 = new_pedigree.node_to_mother[node1]
571
+ mother2 = new_pedigree.node_to_mother[node2]
572
+ if not new_pedigree.check_cycles_if_merged(mother1, mother2):
573
+ new_pedigree.merge_nodes(mother1, mother2)
574
+ new_pedigree.add_sibling_relation(node1, node2)
575
+ ret.append(new_pedigree)
576
+ return ret
577
+
578
+ @staticmethod
579
+ def _connect_aunt_uncle_relation(
580
+ pedigree: Pedigree, node1: str, node2: str, shared_relative_sex: str | None = None
581
+ ) -> list[Pedigree]:
582
+ """
583
+ Adds an aunt/uncle-nephew/niece relation and merges nodes appropriately.
584
+ Returns a list containing the resulting Pedigree(s), if successful.
585
+ """
586
+ assert node1 in pedigree.node_to_data and node2 in pedigree.node_to_data
587
+ assert shared_relative_sex in ["M", "F", None]
588
+
589
+ ret: list[Pedigree] = []
590
+ new_pedigree = copy.deepcopy(pedigree)
591
+ new_pedigree.fill_node_parents(node2)
592
+
593
+ node2_parents: list[str]
594
+ if shared_relative_sex == "M":
595
+ node2_parents = [new_pedigree.node_to_father[node2]]
596
+ elif shared_relative_sex == "F":
597
+ node2_parents = [new_pedigree.node_to_mother[node2]]
598
+ else:
599
+ node2_parents = [new_pedigree.node_to_father[node2], new_pedigree.node_to_mother[node2]]
600
+
601
+ for node2_parent in node2_parents:
602
+ if node1 != node2_parent:
603
+ ret.extend(PedigreeReconstructor._connect_sibling_relation(new_pedigree, node1, node2_parent))
604
+ return ret
605
+
606
+ @staticmethod
607
+ def _connect_grandparent_relation(
608
+ pedigree: Pedigree, node1: str, node2: str, shared_relative_sex: str | None = None
609
+ ) -> list[Pedigree]:
610
+ """
611
+ Adds a grandparent-grandchild relation and merges nodes appropriately.
612
+ Returns a list containing the resulting Pedigree(s), if successful.
613
+ """
614
+ assert node1 in pedigree.node_to_data and node2 in pedigree.node_to_data
615
+ assert shared_relative_sex in ["M", "F", None]
616
+
617
+ ret: list[Pedigree] = []
618
+ new_pedigree = copy.deepcopy(pedigree)
619
+ new_pedigree.fill_node_parents(node2)
620
+
621
+ node2_parents: list[str]
622
+ if shared_relative_sex == "M":
623
+ node2_parents = [new_pedigree.node_to_father[node2]]
624
+ elif shared_relative_sex == "F":
625
+ node2_parents = [new_pedigree.node_to_mother[node2]]
626
+ else:
627
+ node2_parents = [new_pedigree.node_to_father[node2], new_pedigree.node_to_mother[node2]]
628
+
629
+ for node2_parent in node2_parents:
630
+ if node1 != node2_parent:
631
+ ret.extend(PedigreeReconstructor._connect_parent_relation(new_pedigree, node1, node2_parent))
632
+ return ret
633
+
634
+ @staticmethod
635
+ def _connect_half_sibling_relation(
636
+ pedigree: Pedigree, node1: str, node2: str, shared_relative_sex: str | None = None
637
+ ) -> list[Pedigree]:
638
+ """
639
+ Adds a half-sibling relation and merges nodes appropriately.
640
+ Returns a list containing the resulting Pedigree(s), if successful.
641
+ """
642
+ assert node1 in pedigree.node_to_data and node2 in pedigree.node_to_data
643
+
644
+ ret: list[Pedigree] = []
645
+ new_pedigree = copy.deepcopy(pedigree)
646
+ new_pedigree.fill_node_parents(node1)
647
+ new_pedigree.fill_node_parents(node2)
648
+
649
+ node1_parents: list[str]
650
+ node2_parents: list[str]
651
+ if shared_relative_sex == "M":
652
+ node1_parents = [new_pedigree.node_to_father[node1]]
653
+ node2_parents = [new_pedigree.node_to_father[node2]]
654
+ elif shared_relative_sex == "F":
655
+ node1_parents = [new_pedigree.node_to_mother[node1]]
656
+ node2_parents = [new_pedigree.node_to_mother[node2]]
657
+ else:
658
+ node1_parents = [new_pedigree.node_to_father[node1], new_pedigree.node_to_mother[node1]]
659
+ node2_parents = [new_pedigree.node_to_father[node2], new_pedigree.node_to_mother[node2]]
660
+
661
+ # Node 1 and Node 2 are half-siblings via one of Node 1's parents
662
+ for node1_parent in node1_parents:
663
+ if node1_parent != node2:
664
+ ret.extend(PedigreeReconstructor._connect_parent_relation(new_pedigree, node1_parent, node2))
665
+ # Node 1 and Node 2 are half-siblings via one of Node 2's parents
666
+ for node2_parent in node2_parents:
667
+ if node2_parent != node1:
668
+ ret.extend(PedigreeReconstructor._connect_parent_relation(new_pedigree, node2_parent, node1))
669
+ return ret
670
+
671
+ def _clean_relation_dicts(self) -> None:
672
+ """
673
+ Remove unnecessary entries in Pedigree dicts.
674
+ """
675
+ for pedigree in self._candidate_pedigrees:
676
+ pedigree.clean_up_relations()
677
+
678
+ def _get_pair_to_constraints(self) -> defaultdict[tuple[str, str], list[tuple[str, ...]]]:
679
+ """
680
+ Turn DataFrame of relations/constraints into dict(s) of {node pairs: list of possible relations}.
681
+ Dict values are lists of tuples (as opposed to just tuples)
682
+ because a pair of nodes can share more than 1 relation.
683
+ """
684
+ pair_to_constraints: defaultdict[tuple[str, str], list[tuple[str, ...]]] = defaultdict(list)
685
+ for node1, node2, _, constraints, _ in self._all_relations.itertuples(index=False):
686
+ pair_to_constraints[(node1, node2)].append(tuple(constraints.split(";")))
687
+ for node_pair in pair_to_constraints:
688
+ # Sort by number of constraints so specific constraints are checked first when pruning
689
+ pair_to_constraints[node_pair].sort(key=lambda x: len(x))
690
+ return pair_to_constraints
691
+
692
+ def _get_pair_to_relations_so_far(
693
+ self, processed_relations: pd.DataFrame
694
+ ) -> defaultdict[tuple[str, str], list[tuple[str, str, bool]]]:
695
+ """
696
+ Turn DataFrame of relations/constraints processed so far
697
+ into dict(s) of {node pairs: list of (degree, constraints) tuples}.
698
+ """
699
+ pair_to_relations_so_far: defaultdict[tuple[str, str], list[tuple[str, str, bool]]] = defaultdict(list)
700
+ for node1, node2, degree, constraints, force_constraints in processed_relations.itertuples(index=False):
701
+ pair_to_relations_so_far[(node1, node2)].append((degree, constraints, force_constraints))
702
+ return pair_to_relations_so_far
703
+
704
+ def _prune_pedigrees(
705
+ self,
706
+ pair_to_relations_so_far: defaultdict[tuple[str, str], list[tuple[str, str, bool]]],
707
+ check_half_siblings: bool,
708
+ ) -> None:
709
+ """
710
+ Remove pedigrees with inconsistencies.
711
+ """
712
+ seen_topologies = set()
713
+ new_potential_pedigrees = []
714
+ for pedigree in self._candidate_pedigrees:
715
+ if (
716
+ pedigree.validate_members(set(self._node_data["id"]))
717
+ and pedigree.validate_can_have_children()
718
+ and pedigree.validate_inbreeding()
719
+ and pedigree.validate_years_before_present()
720
+ and pedigree.validate_forced_constraints(pair_to_relations_so_far)
721
+ ):
722
+ pedigree.update_haplogroups()
723
+ if pedigree.validate_haplogroups():
724
+ topology = pedigree.get_topo_sort()
725
+ if topology not in seen_topologies:
726
+ new_potential_pedigrees.append(pedigree)
727
+ seen_topologies.add(topology)
728
+ # Shuffle to avoid ordering bias in epsilon-greedy sampling
729
+ random.shuffle(new_potential_pedigrees)
730
+
731
+ strikes = []
732
+ third_degree_strikes = []
733
+ counts = defaultdict(int)
734
+ for pedigree in new_potential_pedigrees:
735
+ num_strikes, _ = pedigree.count_inconsistencies(
736
+ self._pair_to_constraints, pair_to_relations_so_far, check_half_siblings
737
+ )
738
+ num_third_degree_strikes = pedigree.count_third_degree_inconcistencies(self._pair_to_constraints)
739
+ strikes.append(num_strikes)
740
+ third_degree_strikes.append(num_third_degree_strikes)
741
+ counts[num_strikes] += 1
742
+ logger.info(f"Strike counts before pruning: {str(dict(sorted(counts.items())))}")
743
+
744
+ def epsilon_greedy_sample(
745
+ pedigrees: list[Pedigree],
746
+ strikes: list[int],
747
+ third_degree_strikes: list[int],
748
+ epsilon: float,
749
+ max_candidate_pedigrees: int,
750
+ ) -> list[Pedigree]:
751
+ assert len(pedigrees) == len(strikes)
752
+ if len(pedigrees) <= max_candidate_pedigrees:
753
+ return pedigrees
754
+
755
+ sorted_pedigrees = [
756
+ pedigree
757
+ for pedigree, _, _ in sorted(zip(pedigrees, strikes, third_degree_strikes), key=lambda x: (x[1], x[2]))
758
+ ]
759
+ exploitation_max_candidate_pedigrees = int((1 - epsilon) * max_candidate_pedigrees)
760
+ exploration_max_candidate_pedigrees = max_candidate_pedigrees - exploitation_max_candidate_pedigrees
761
+
762
+ exploitation_pedigrees = sorted_pedigrees[:exploitation_max_candidate_pedigrees]
763
+ exploration_pedigrees = random.sample(
764
+ sorted_pedigrees[exploitation_max_candidate_pedigrees:], exploration_max_candidate_pedigrees
765
+ )
766
+ return exploitation_pedigrees + exploration_pedigrees
767
+
768
+ num_processed_relations = sum(len(relations) for relations in pair_to_relations_so_far.values())
769
+ if num_processed_relations < len(self._first_and_second_degree_relations):
770
+ self._candidate_pedigrees = epsilon_greedy_sample(
771
+ new_potential_pedigrees,
772
+ strikes,
773
+ third_degree_strikes,
774
+ epsilon=self._epsilon,
775
+ max_candidate_pedigrees=self._max_candidate_pedigrees,
776
+ )
777
+ else:
778
+ # Final iteration
779
+ best_pedigrees = [
780
+ pedigree
781
+ for pedigree, num_strikes in zip(new_potential_pedigrees, strikes)
782
+ if num_strikes == min(strikes)
783
+ ]
784
+ # Use 3rd-degree strikes as tiebreaker
785
+ third_degree_strikes = [
786
+ pedigree.count_third_degree_inconcistencies(self._pair_to_constraints) for pedigree in best_pedigrees
787
+ ]
788
+
789
+ self._final_pedigrees = [
790
+ pedigree
791
+ for pedigree, num_strikes in zip(best_pedigrees, third_degree_strikes)
792
+ if num_strikes == min(third_degree_strikes)
793
+ ]
794
+ self._final_strike_counts = []
795
+ self._final_strike_logs = []
796
+ for pedigree in self._final_pedigrees:
797
+ strike_count, strike_log = pedigree.count_inconsistencies(
798
+ self._pair_to_constraints, pair_to_relations_so_far, check_half_siblings=True
799
+ )
800
+ self._final_strike_counts.append(strike_count)
801
+ self._final_strike_logs.append(strike_log)
802
+ pedigree.clean_up_relations()
803
+
804
    def _write_corrected_input_relations(
        self, strike_count: int, strike_log: list[tuple[str, str, str, str]], path: str
    ) -> None:
        """
        Write corrected input relations to file. Includes information about added/removed/changed input relations.

        Args:
            strike_count: Final strike count for the pedigree; written as a comment header line.
            strike_log: Entries of (node1, node2, degree, constraints), where ``degree``
                is prefixed with "+" (relation added) or "-" (relation removed).
            path: Output CSV path (overwritten if it exists).
        """
        # Split the strike log into added vs. removed relations using the
        # "+"/"-" prefix on the degree field.
        added_relations = []
        removed_relations = []
        for node1, node2, degree, constraints in strike_log:
            if degree[0] == "+":
                # degree[1] keeps only the character after the sign — assumes
                # single-character degree codes (e.g. "+1" -> "1"); TODO confirm
                added_relations.append((node1, node2, degree[1], constraints))
            else:
                removed_relations.append((node1, node2, degree[1], constraints))
        removed_relations_set = set(removed_relations)

        # Separate out *changed* relations (added relation + removed relation pair, e.g., 1st-degree -> 2nd-degree)
        # A pair counts as "changed" if it appears, in either node order, in both lists.
        changed_node_pairs = set()
        for add_node1, add_node2, _, _ in added_relations:
            for remove_node1, remove_node2, _, _ in removed_relations:
                if (add_node1 == remove_node1 and add_node2 == remove_node2) or (
                    add_node2 == remove_node1 and add_node1 == remove_node2
                ):
                    changed_node_pairs.add((add_node1, add_node2))

        with open(path, "w") as file:
            file.write("id1,id2,degree,constraints\n")  # Header line
            file.write(f"# Final strike count: {strike_count}\n")

            def write_relations_line(node1, node2, degree, constraints, commented=False):
                # Emit one CSV row; prefix with "# " when it records a removal.
                if constraints == self._DEFAULT_CONSTRAINTS[degree]:
                    # Don't write default constraints to file
                    constraints = ""
                if commented:
                    file.write("# ")
                file.write(f"{node1},{node2},{degree},{constraints}\n")

            file.write("# Added relations\n")
            # Sort for consistency
            for node1, node2, degree, constraints in sorted(added_relations):
                # Pairs that were merely *changed* get their own section below.
                if (node1, node2) not in changed_node_pairs and (node2, node1) not in changed_node_pairs:
                    write_relations_line(node1, node2, degree, constraints)

            file.write("\n# Removed relations\n")
            for node1, node2, degree, constraints in sorted(removed_relations):
                if (node1, node2) not in changed_node_pairs and (node2, node1) not in changed_node_pairs:
                    write_relations_line(node1, node2, degree, constraints, commented=True)

            file.write("\n# Changed relations\n")
            # Pair up changed relations (add + remove)
            for node1, node2 in sorted(changed_node_pairs):
                # We want to write the two nodes in the correct (original) order
                node1_to_write = None
                node2_to_write = None
                for node1_remove, node2_remove, degree_remove, constraints_remove in removed_relations:
                    if (node1_remove, node2_remove) == (node1, node2) or (node2_remove, node1_remove) == (node1, node2):
                        write_relations_line(
                            node1_remove, node2_remove, degree_remove, constraints_remove, commented=True
                        )
                        # The removed nodes follow the original input order
                        node1_to_write = node1_remove
                        node2_to_write = node2_remove
                for node1_add, node2_add, degree_add, constraints_add in added_relations:
                    if (node1_add, node2_add) == (node1, node2) or (node2_add, node1_add) == (node1, node2):
                        # A matching removal above must have set the original node order.
                        assert node1_to_write and node2_to_write
                        write_relations_line(node1_to_write, node2_to_write, degree_add, constraints_add)

            file.write("\n# Unchanged relations\n")
            # _all_relations rows unpack as (id1, id2, degree, constraints, <5th field>);
            # presumably the ignored 5th field is force_constraints — verify against the loader.
            for node1, node2, degree, constraints, _ in self._all_relations.itertuples(index=False):
                if (node1, node2, degree, constraints) not in removed_relations_set:
                    # A removed relation should never appear only in swapped order.
                    assert (node2, node1, degree, constraints) not in removed_relations_set
                    write_relations_line(node1, node2, degree, constraints)