speconsense-0.7.2-py3-none-any.whl

@@ -0,0 +1,669 @@
+ """HAC clustering and variant selection for speconsense-summarize.
+ 
+ Provides hierarchical agglomerative clustering to separate specimens from variants
+ and variant selection strategies.
+ """
+ 
+ import itertools
+ import logging
+ from typing import List, Dict, Set, Tuple, Optional
+ from collections import defaultdict
+ 
+ from tqdm import tqdm
+ 
+ from speconsense.types import ConsensusInfo
+ from speconsense.scalability import (
+     VsearchCandidateFinder,
+     ScalablePairwiseOperation,
+     ScalabilityConfig,
+ )
+ 
+ from .iupac import (
+     primers_are_same,
+     calculate_adjusted_identity_distance,
+     calculate_overlap_aware_distance,
+     create_variant_summary,
+ )
+ 
+ 
+ def _complete_linkage_subset(
+     indices: List[int],
+     seq_distances: Dict[Tuple[int, int], float],
+     distance_threshold: float,
+     seq_adjacency: Dict[int, Set[int]]
+ ) -> List[List[int]]:
+     """Run complete linkage HAC on a subset of sequences.
+ 
+     First partitions the subset into connected components (based on seq_adjacency),
+     then runs HAC within each component. This matches the behavior of the original
+     complete linkage code.
+ 
+     Args:
+         indices: List of sequence indices to cluster
+         seq_distances: Precomputed distances between sequence pairs
+         distance_threshold: Maximum distance for merging (1.0 - identity)
+         seq_adjacency: Adjacency dict showing which sequences have edges
+ 
+     Returns:
+         List of clusters, where each cluster is a list of original indices
+     """
+     if len(indices) <= 1:
+         return [indices]
+ 
+     component_set = set(indices)
+ 
+     # First, partition into connected components using union-find
+     # This matches the original complete linkage behavior
+     parent: Dict[int, int] = {i: i for i in indices}
+ 
+     def find(x: int) -> int:
+         if parent[x] != x:
+             parent[x] = find(parent[x])
+         return parent[x]
+ 
+     def union(x: int, y: int) -> None:
+         px, py = find(x), find(y)
+         if px != py:
+             parent[px] = py
+ 
+     for i in indices:
+         for j in seq_adjacency.get(i, set()):
+             if j in component_set and i < j:
+                 union(i, j)
+ 
+     # Group indices by component
+     components: Dict[int, List[int]] = defaultdict(list)
+     for i in indices:
+         components[find(i)].append(i)
+ 
+     # Run HAC within each connected component
+     all_clusters: List[List[int]] = []
+     for component_indices in components.values():
+         if len(component_indices) == 1:
+             all_clusters.append(component_indices)
+             continue
+ 
+         # Run HAC on this component
+         component_clusters = _run_hac_on_component(
+             component_indices, seq_distances, distance_threshold, seq_adjacency
+         )
+         all_clusters.extend(component_clusters)
+ 
+     return all_clusters
+ 
+ 
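+ # Illustrative sketch (not part of the original module): with a 0.02 distance
+ # threshold, complete linkage merges 0 and 1 (d = 0.005) but keeps 2 apart,
+ # because d(0, 2) = 0.08 exceeds the threshold even though d(1, 2) = 0.01 is
+ # within it -- chaining through an intermediate is blocked:
+ #
+ #     dists = {(0, 1): 0.005, (1, 2): 0.01, (0, 2): 0.08}
+ #     adj = {0: {1, 2}, 1: {0, 2}, 2: {0, 1}}
+ #     _complete_linkage_subset([0, 1, 2], dists, 0.02, adj)
+ #     # -> [[0, 1], [2]]
+ 
+ 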
+ def _run_hac_on_component(
+     indices: List[int],
+     seq_distances: Dict[Tuple[int, int], float],
+     distance_threshold: float,
+     seq_adjacency: Dict[int, Set[int]]
+ ) -> List[List[int]]:
+     """Run HAC on a single connected component.
+ 
+     This is the inner HAC loop, separated from component partitioning.
+     """
+     if len(indices) <= 1:
+         return [indices]
+ 
+     component_set = set(indices)
+ 
+     # Build local adjacency for this subset
+     local_adjacency: Dict[int, Set[int]] = defaultdict(set)
+     for i in indices:
+         for j in seq_adjacency.get(i, set()):
+             if j in component_set:
+                 local_adjacency[i].add(j)
+ 
+     # Initialize each sequence as its own cluster
+     seq_to_cluster: Dict[int, int] = {i: i for i in indices}
+     cluster_map: Dict[int, List[int]] = {i: [i] for i in indices}
+ 
+     def get_cluster_adjacency() -> Set[Tuple[int, int]]:
+         adjacent_pairs: Set[Tuple[int, int]] = set()
+         for seq_i in indices:
+             cluster_i = seq_to_cluster[seq_i]
+             for seq_j in local_adjacency[seq_i]:
+                 cluster_j = seq_to_cluster[seq_j]
+                 if cluster_i != cluster_j:
+                     pair = (min(cluster_i, cluster_j), max(cluster_i, cluster_j))
+                     adjacent_pairs.add(pair)
+         return adjacent_pairs
+ 
+     def cluster_distance(cluster1: List[int], cluster2: List[int]) -> float:
+         # Complete linkage: max distance, early exit on missing edge or threshold
+         max_dist = 0.0
+         for i in cluster1:
+             for j in cluster2:
+                 if i == j:
+                     continue
+                 if j not in local_adjacency[i]:
+                     return 1.0  # Missing edge = max distance
+                 key = (i, j) if (i, j) in seq_distances else (j, i)
+                 dist = seq_distances.get(key, 1.0)
+                 if dist >= distance_threshold:
+                     return 1.0  # Early exit
+                 max_dist = max(max_dist, dist)
+         return max_dist
+ 
+     # HAC merging loop
+     while len(cluster_map) > 1:
+         adjacent_pairs = get_cluster_adjacency()
+         if not adjacent_pairs:
+             break
+ 
+         min_distance = float('inf')
+         merge_pair = None
+ 
+         for cluster_i, cluster_j in adjacent_pairs:
+             if cluster_i not in cluster_map or cluster_j not in cluster_map:
+                 continue
+             dist = cluster_distance(cluster_map[cluster_i], cluster_map[cluster_j])
+             if dist < min_distance:
+                 min_distance = dist
+                 merge_pair = (cluster_i, cluster_j)
+ 
+         if min_distance >= distance_threshold or merge_pair is None:
+             break
+ 
+         ci, cj = merge_pair
+         merged = cluster_map[ci] + cluster_map[cj]
+         for seq_idx in cluster_map[cj]:
+             seq_to_cluster[seq_idx] = ci
+         cluster_map[ci] = merged
+         del cluster_map[cj]
+ 
+     return list(cluster_map.values())
+ 
+ 
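+ # Illustrative sketch (not part of the original module): a missing edge between
+ # any cross pair pins the complete-linkage distance at 1.0, so two clusters
+ # never merge even when every observed pairwise distance is small:
+ #
+ #     dists = {(0, 1): 0.005, (1, 2): 0.01}    # no (0, 2) entry
+ #     adj = {0: {1}, 1: {0, 2}, 2: {1}}        # edge (0, 2) absent
+ #     _run_hac_on_component([0, 1, 2], dists, 0.02, adj)
+ #     # -> [[0, 1], [2]]
+ 
+ 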
+ def perform_hac_clustering(consensus_list: List[ConsensusInfo],
+                            variant_group_identity: float,
+                            min_overlap_bp: int = 0,
+                            scalability_config: Optional[ScalabilityConfig] = None,
+                            output_dir: str = ".") -> Dict[int, List[ConsensusInfo]]:
+     """
+     Perform Hierarchical Agglomerative Clustering.
+     Separates specimens from variants based on identity threshold.
+     Returns groups of consensus sequences.
+ 
+     Linkage strategy:
+     - When min_overlap_bp > 0 (overlap mode): Uses HYBRID linkage:
+       - Phase 1: COMPLETE linkage within each primer set (prevents chaining)
+       - Phase 2: SINGLE linkage across primer sets (allows ITS1+full+ITS2 merging)
+       This prevents sequences with the same primers from chaining through
+       intermediates while still allowing different-primer sequences to merge
+       via overlap regions.
+     - When min_overlap_bp == 0 (standard mode): Uses COMPLETE linkage, which
+       requires ALL pairs to be within threshold. More conservative for same-length
+       sequences.
+ 
+     When min_overlap_bp > 0, also uses overlap-aware distance calculation that
+     allows sequences of different lengths to be grouped together if they
+     share sufficient overlap with good identity.
+     """
+     if len(consensus_list) <= 1:
+         return {0: consensus_list}
+ 
+     # Determine linkage strategy based on overlap mode
+     use_hybrid_linkage = min_overlap_bp > 0
+     linkage_type = "hybrid" if use_hybrid_linkage else "complete"
+ 
+     if min_overlap_bp > 0:
+         logging.debug(f"Performing HAC clustering with {variant_group_identity} identity threshold "
+                       f"({linkage_type} linkage, overlap-aware mode, min_overlap={min_overlap_bp}bp)")
+     else:
+         logging.debug(f"Performing HAC clustering with {variant_group_identity} identity threshold "
+                       f"({linkage_type} linkage)")
+ 
+     n = len(consensus_list)
+     logging.debug(f"perform_hac_clustering: {n} sequences, threshold={variant_group_identity}")
+     distance_threshold = 1.0 - variant_group_identity
+ 
+     # Initialize each sequence as its own cluster
+     clusters = [[i] for i in range(n)]
+ 
+     # Build initial distance matrix between individual sequences
+     seq_distances = {}
+ 
+     # Use scalability if enabled and we have enough sequences
+     use_scalable = (
+         scalability_config is not None and
+         scalability_config.enabled and
+         n >= scalability_config.activation_threshold and
+         n > 50
+     )
+     logging.debug(f"perform_hac_clustering: use_scalable={use_scalable}")
+ 
+     if use_scalable:
+         # Build sequence dict with index keys
+         sequences = {str(i): consensus_list[i].sequence for i in range(n)}
+ 
+         # Build primers lookup by ID for the scoring function
+         primers_lookup = {str(i): consensus_list[i].primers for i in range(n)}
+ 
+         # Create scoring function that returns similarity (1 - distance)
+         # Use overlap-aware distance when min_overlap_bp > 0
+         if min_overlap_bp > 0:
+             def score_func(seq1: str, seq2: str, id1: str, id2: str) -> float:
+                 # Check if primers match - same primers require global distance
+                 p1, p2 = primers_lookup.get(id1), primers_lookup.get(id2)
+                 if primers_are_same(p1, p2):
+                     # Same primers -> global distance (no overlap merging)
+                     return 1.0 - calculate_adjusted_identity_distance(seq1, seq2)
+                 else:
+                     # Different primers -> overlap-aware distance
+                     return 1.0 - calculate_overlap_aware_distance(seq1, seq2, min_overlap_bp)
+         else:
+             def score_func(seq1: str, seq2: str, id1: str, id2: str) -> float:
+                 return 1.0 - calculate_adjusted_identity_distance(seq1, seq2)
+ 
+         candidate_finder = VsearchCandidateFinder(num_threads=scalability_config.max_threads)
+         if candidate_finder.is_available:
+             try:
+                 operation = ScalablePairwiseOperation(
+                     candidate_finder=candidate_finder,
+                     scoring_function=score_func,
+                     config=scalability_config
+                 )
+                 distances = operation.compute_distance_matrix(sequences, output_dir, variant_group_identity)
+ 
+                 # Convert to integer-keyed distances
+                 for (id1, id2), dist in distances.items():
+                     i, j = int(id1), int(id2)
+                     seq_distances[(i, j)] = dist
+                     seq_distances[(j, i)] = dist
+             finally:
+                 candidate_finder.cleanup()
+         else:
+             logging.warning("Scalability enabled but vsearch not available. Using brute-force.")
+             use_scalable = False
+ 
+     if not use_scalable:
+         # Standard brute-force calculation
+         for i, j in itertools.combinations(range(n), 2):
+             if min_overlap_bp > 0:
+                 # Check if primers match - same primers require global distance
+                 p1, p2 = consensus_list[i].primers, consensus_list[j].primers
+                 if primers_are_same(p1, p2):
+                     # Same primers -> global distance (no overlap merging)
+                     dist = calculate_adjusted_identity_distance(
+                         consensus_list[i].sequence,
+                         consensus_list[j].sequence
+                     )
+                 else:
+                     # Different primers -> overlap-aware distance for primer pool scenarios
+                     dist = calculate_overlap_aware_distance(
+                         consensus_list[i].sequence,
+                         consensus_list[j].sequence,
+                         min_overlap_bp
+                     )
+             else:
+                 # Use standard global distance
+                 dist = calculate_adjusted_identity_distance(
+                     consensus_list[i].sequence,
+                     consensus_list[j].sequence
+                 )
+             seq_distances[(i, j)] = dist
+             seq_distances[(j, i)] = dist
+ 
+     # Build sequence adjacency from computed distances (works for both paths)
+     # Only include edges where distance < 1.0 (excludes failed alignments and non-candidates)
+     seq_adjacency: Dict[int, Set[int]] = defaultdict(set)
+     for (i, j), dist in seq_distances.items():
+         if dist < 1.0 and i != j:
+             seq_adjacency[i].add(j)
+             seq_adjacency[j].add(i)
+ 
+     logging.debug(f"Built adjacency: {len(seq_adjacency)} sequences with edges, "
+                   f"{sum(len(v) for v in seq_adjacency.values()) // 2} unique edges")
+ 
+     # Union-find helper functions
+     parent: Dict[int, int] = {i: i for i in range(n)}
+ 
+     def find(x: int) -> int:
+         if parent[x] != x:
+             parent[x] = find(parent[x])  # Path compression
+         return parent[x]
+ 
+     def union(x: int, y: int) -> None:
+         px, py = find(x), find(y)
+         if px != py:
+             parent[px] = py
+ 
+     if use_hybrid_linkage:
+         # HYBRID LINKAGE: Complete within primer sets, single across primer sets
+         # This prevents chaining within same-primer sequences while allowing
+         # different-primer sequences (e.g., ITS1 + full ITS + ITS2) to merge.
+         logging.debug("Hybrid linkage: complete within primer sets, single across")
+ 
+         # Phase 1: Group sequences by primer set
+         primer_groups: Dict[Tuple[str, ...], List[int]] = defaultdict(list)
+         for i, cons in enumerate(consensus_list):
+             primer_key = tuple(sorted(cons.primers)) if cons.primers else ('_none_',)
+             primer_groups[primer_key].append(i)
+ 
+         logging.debug(f"Found {len(primer_groups)} distinct primer sets")
+ 
+         # Run complete linkage HAC within each primer group
+         primer_coherent_clusters: List[Tuple[Tuple[str, ...], List[int]]] = []
+ 
+         # Log info about the work to be done
+         max_group_size = max(len(indices) for indices in primer_groups.values())
+         if max_group_size > 1000:
+             logging.info(f"Running HAC on {len(primer_groups)} primer groups "
+                          f"(largest has {max_group_size} sequences, this may take several minutes)")
+ 
+         for primer_key, indices in primer_groups.items():
+             if len(indices) == 1:
+                 primer_coherent_clusters.append((primer_key, indices))
+             else:
+                 # Run complete linkage on this primer subset
+                 sub_clusters = _complete_linkage_subset(
+                     indices, seq_distances, distance_threshold, seq_adjacency
+                 )
+                 for cluster in sub_clusters:
+                     primer_coherent_clusters.append((primer_key, cluster))
+ 
+         logging.debug(f"Phase 1 complete: {len(primer_coherent_clusters)} primer-coherent clusters")
+ 
+         # Phase 2: Connect clusters with different primers using single linkage
+         # BUT: prevent transitive chaining that would connect same-primer clusters
+         # via different-primer intermediates
+         n_clusters = len(primer_coherent_clusters)
+ 
+         # Track which clusters are in each group (set of cluster indices per group)
+         groups: List[Set[int]] = [{i} for i in range(n_clusters)]
+         cluster_to_group: Dict[int, int] = {i: i for i in range(n_clusters)}
+ 
+         def get_group_primers(group_idx: int) -> Dict[Tuple[str, ...], List[int]]:
+             """Get all primer keys and their cluster indices in a group."""
+             result: Dict[Tuple[str, ...], List[int]] = defaultdict(list)
+             for cluster_idx in groups[group_idx]:
+                 primer_key = primer_coherent_clusters[cluster_idx][0]
+                 result[primer_key].append(cluster_idx)
+             return result
+ 
+         def can_merge_groups(group_a: int, group_b: int) -> bool:
+             """Check if merging would violate complete linkage for same-primer clusters."""
+             # Get all primer->clusters mappings for the merged group
+             primers_a = get_group_primers(group_a)
+             primers_b = get_group_primers(group_b)
+ 
+             # Check each primer key that appears in both groups
+             for primer_key in primers_a:
+                 if primer_key in primers_b:
+                     # Same primer key in both groups - need complete linkage check
+                     clusters_a = [primer_coherent_clusters[i][1] for i in primers_a[primer_key]]
+                     clusters_b = [primer_coherent_clusters[i][1] for i in primers_b[primer_key]]
+ 
+                     # All pairs between clusters_a and clusters_b must satisfy complete
+                     # linkage: every cross-cluster distance must stay below the threshold
+                     for ca in clusters_a:
+                         for cb in clusters_b:
+                             for si in ca:
+                                 for sj in cb:
+                                     dist = seq_distances.get((si, sj), seq_distances.get((sj, si), 1.0))
+                                     if dist >= distance_threshold:
+                                         return False  # Would violate complete linkage
+             return True
+ 
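+         # Illustrative note (not part of the original module): if primer set P
+         # has cluster {s1, s2} in one group and cluster {s3} in another, and
+         # d(s2, s3) >= distance_threshold, can_merge_groups() returns False:
+         # a cross-primer edge must not transitively fuse same-primer clusters
+         # that phase 1 deliberately kept apart.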
+         def merge_groups(group_a: int, group_b: int) -> None:
+             """Merge group_b into group_a."""
+             if group_a == group_b:
+                 return
+             for cluster_idx in groups[group_b]:
+                 cluster_to_group[cluster_idx] = group_a
+                 groups[group_a].add(cluster_idx)
+             groups[group_b] = set()
+ 
+         cross_primer_edges = 0
+         cross_primer_blocked = 0
+         for i in range(n_clusters):
+             for j in range(i + 1, n_clusters):
+                 primer_i, cluster_i = primer_coherent_clusters[i]
+                 primer_j, cluster_j = primer_coherent_clusters[j]
+ 
+                 # Skip same-primer pairs (already handled in phase 1)
+                 if primer_i == primer_j:
+                     continue
+ 
+                 # Different primers: check single linkage distance
+                 min_dist = 1.0
+                 for si in cluster_i:
+                     for sj in cluster_j:
+                         dist = seq_distances.get((si, sj), seq_distances.get((sj, si), 1.0))
+                         min_dist = min(min_dist, dist)
+                         if min_dist < distance_threshold:
+                             break
+                     if min_dist < distance_threshold:
+                         break
+ 
+                 if min_dist < distance_threshold:
+                     group_i = cluster_to_group[i]
+                     group_j = cluster_to_group[j]
+                     if group_i != group_j:
+                         # Check if merge would create invalid same-primer connections
+                         if can_merge_groups(group_i, group_j):
+                             merge_groups(group_i, group_j)
+                             cross_primer_edges += 1
+                         else:
+                             cross_primer_blocked += 1
+ 
+         logging.debug(f"Phase 2: {cross_primer_edges} cross-primer connections, "
+                       f"{cross_primer_blocked} blocked by complete linkage constraint")
+ 
+         # Collect final groups using the new group structure
+         final_groups: Dict[int, List[int]] = defaultdict(list)
+         for i, (_, cluster) in enumerate(primer_coherent_clusters):
+             group_idx = cluster_to_group[i]
+             final_groups[group_idx].extend(cluster)
+ 
+         clusters = list(final_groups.values())
+         logging.info(f"Found {len(clusters)} sequence groups (hybrid linkage)")
+ 
+     else:
+         # Complete linkage: partition by connected components first
+         # Clusters from different components can never merge (missing edge = dist 1.0)
+         logging.debug("Complete linkage: partitioning into connected components")
+ 
+         for i in range(n):
+             for j in seq_adjacency[i]:
+                 if i < j:
+                     union(i, j)
+ 
+         # Group sequences by component
+         components: Dict[int, List[int]] = defaultdict(list)
+         for i in range(n):
+             components[find(i)].append(i)
+ 
+         # Count singletons vs multi-sequence components
+         singletons = sum(1 for c in components.values() if len(c) == 1)
+         multi_seq = len(components) - singletons
+         logging.info(f"Found {len(components)} sequence groups "
+                      f"({singletons} single-sequence, {multi_seq} multi-sequence)")
+ 
+         # Run HAC within each component; the inner HAC loop is implemented
+         # by _run_hac_on_component, so delegate to that helper
+         clusters: List[List[int]] = []
+         for component_seqs in tqdm(components.values(), desc="HAC per component"):
+             if len(component_seqs) == 1:
+                 clusters.append(component_seqs)
+                 continue
+             component_clusters = _run_hac_on_component(
+                 component_seqs, seq_distances, distance_threshold, seq_adjacency
+             )
+             clusters.extend(component_clusters)
+ 
+     # Convert clusters to groups of ConsensusInfo
+     groups = {}
+     for group_id, cluster_indices in enumerate(clusters):
+         group_members = [consensus_list[idx] for idx in cluster_indices]
+         groups[group_id] = group_members
+ 
+     logging.debug(f"HAC clustering created {len(groups)} groups")
+     for group_id, group_members in groups.items():
+         member_names = [m.sample_name for m in group_members]
+         # Convert group_id to final naming (group 0 -> 1, group 1 -> 2, etc.)
+         final_group_name = group_id + 1
+         logging.debug(f"Group {final_group_name}: {member_names}")
+ 
+     return groups
+ 
+ 
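+ # Illustrative usage (not part of the original module; keyword values are
+ # examples only):
+ #
+ #     groups = perform_hac_clustering(
+ #         consensus_list,
+ #         variant_group_identity=0.99,  # merge at >= 99% adjusted identity
+ #         min_overlap_bp=50,            # enables hybrid linkage + overlap-aware distance
+ #     )
+ #     for group_id, members in groups.items():
+ #         print(group_id + 1, [m.sample_name for m in members])
+ 
+ 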
+ def select_variants(group: List[ConsensusInfo],
+                     max_variants: int,
+                     variant_selection: str,
+                     group_number: Optional[int] = None) -> List[ConsensusInfo]:
+     """
+     Select variants from a group based on the specified strategy.
+     Always includes the largest variant first.
+     A max_variants of 0 or any negative value means no limit (return all variants).
+ 
+     Logs variant summaries for ALL variants in the group, including those
+     that will be skipped in the final output.
+ 
+     Args:
+         group: List of ConsensusInfo to select from
+         max_variants: Maximum total variants per group (0 or negative for no limit)
+         variant_selection: Selection strategy ("size" or "diversity")
+         group_number: Group number for logging prefix (optional)
+     """
+     # Sort by size, largest first
+     sorted_group = sorted(group, key=lambda x: x.size, reverse=True)
+ 
+     if not sorted_group:
+         return []
+ 
+     # The primary variant is always the largest
+     primary_variant = sorted_group[0]
+ 
+     # Build prefix for logging
+     prefix = f"Group {group_number}: " if group_number is not None else ""
+ 
+     # Only log Primary when there are other variants to compare against
+     if len(sorted_group) > 1:
+         logging.info(f"{prefix}Primary: {primary_variant.sample_name} (size={primary_variant.size}, ric={primary_variant.ric})")
+ 
+     # Handle no limit case (0 or negative means unlimited)
+     if max_variants <= 0:
+         selected = sorted_group
+     elif len(group) <= max_variants:
+         selected = sorted_group
+     else:
+         # Always include the largest (main) variant
+         selected = [primary_variant]
+         candidates = sorted_group[1:]
+ 
+         if variant_selection == "size":
+             # Select by size (max_variants - 1 because we already have primary)
+             selected.extend(candidates[:max_variants - 1])
+         else:  # diversity
+             # Select by diversity (maximum distance from already selected)
+             while len(selected) < max_variants and candidates:
+                 best_candidate = None
+                 best_min_distance = -1
+ 
+                 for candidate in candidates:
+                     # Calculate minimum distance to all selected variants
+                     min_distance = min(
+                         calculate_adjusted_identity_distance(candidate.sequence, sel.sequence)
+                         for sel in selected
+                     )
+ 
+                     if min_distance > best_min_distance:
+                         best_min_distance = min_distance
+                         best_candidate = candidate
+ 
+                 if best_candidate is not None:
+                     selected.append(best_candidate)
+                     candidates.remove(best_candidate)
+ 
+     # Now generate variant summaries, showing selected variants first in their
+     # final order, then the skipped variants
+ 
+     # Log selected variants first (excluding primary, which is already logged)
+     selected_secondary = selected[1:]  # Exclude primary variant
+     for i, variant in enumerate(selected_secondary, 1):
+         variant_summary = create_variant_summary(primary_variant.sequence, variant.sequence)
+         logging.info(f"{prefix}Variant {i}: (size={variant.size}, ric={variant.ric}) - {variant_summary}")
+ 
+     # Log skipped variants
+     selected_names = {variant.sample_name for variant in selected}
+     skipped_variants = [v for v in sorted_group[1:] if v.sample_name not in selected_names]
+ 
+     for variant in skipped_variants:
+         # Report the variant number it would have had in the original sorted order
+         original_position = next(j for j, v in enumerate(sorted_group) if v.sample_name == variant.sample_name)
+         variant_summary = create_variant_summary(primary_variant.sequence, variant.sequence)
+         logging.info(f"{prefix}Variant {original_position}: (size={variant.size}, ric={variant.ric}) - {variant_summary} - skipping")
+ 
+     return selected
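+ 
+ 
+ # Illustrative usage (not part of the original module): keep the largest
+ # variant plus the two most divergent others from a clustered group.
+ #
+ #     kept = select_variants(groups[0], max_variants=3,
+ #                            variant_selection="diversity", group_number=1)
+ #     # kept[0] is always the largest variant; the remaining picks maximize
+ #     # their minimum adjusted-identity distance to everything already selected.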