speconsense-0.7.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1565 @@
1
+ """Main SpecimenClusterer class for clustering and consensus generation."""
2
+
3
+ from collections import defaultdict
4
+ import json
5
+ import logging
6
+ import os
7
+ import statistics
8
+ import subprocess
9
+ import tempfile
10
+ from datetime import datetime
11
+ from typing import Dict, List, Optional, Set, Tuple
12
+
13
+ import edlib
14
+ from adjusted_identity import score_alignment, AdjustmentParams, ScoringFormat
15
+ from Bio import SeqIO
16
+ from Bio.Seq import reverse_complement
17
+ from tqdm import tqdm
18
+
19
+ try:
20
+ from speconsense import __version__
21
+ except ImportError:
22
+ __version__ = "dev"
23
+
24
+ from speconsense.msa import ReadAlignment
25
+ from speconsense.scalability import (
26
+ VsearchCandidateFinder,
27
+ ScalablePairwiseOperation,
28
+ ScalabilityConfig,
29
+ )
30
+
31
+ from .workers import (
32
+ ClusterProcessingConfig,
33
+ ConsensusGenerationConfig,
34
+ _run_spoa_worker,
35
+ _process_cluster_worker,
36
+ _generate_cluster_consensus_worker,
37
+ _trim_primers_standalone,
38
+ _phase_reads_by_variants_standalone,
39
+ )
40
+
41
+
42
+ class SpecimenClusterer:
43
+ def __init__(self, min_identity: float = 0.9,
44
+ inflation: float = 4.0,
45
+ min_size: int = 5,
46
+ min_cluster_ratio: float = 0.2,
47
+ max_sample_size: int = 100,
48
+ presample_size: int = 1000,
49
+ k_nearest_neighbors: int = 20,
50
+ sample_name: str = "sample",
51
+ disable_homopolymer_equivalence: bool = False,
52
+ disable_cluster_merging: bool = False,
53
+ output_dir: str = "clusters",
54
+ outlier_identity_threshold: Optional[float] = None,
55
+ enable_secondpass_phasing: bool = True,
56
+ min_variant_frequency: float = 0.10,
57
+ min_variant_count: int = 5,
58
+ min_ambiguity_frequency: float = 0.10,
59
+ min_ambiguity_count: int = 3,
60
+ enable_iupac_calling: bool = True,
61
+ scale_threshold: int = 1001,
62
+ max_threads: int = 1,
63
+ early_filter: bool = False,
64
+ collect_discards: bool = False):
65
+ self.min_identity = min_identity
66
+ self.inflation = inflation
67
+ self.min_size = min_size
68
+ self.min_cluster_ratio = min_cluster_ratio
69
+ self.max_sample_size = max_sample_size
70
+ self.presample_size = presample_size
71
+ self.k_nearest_neighbors = k_nearest_neighbors
72
+ self.sample_name = sample_name
73
+ self.disable_homopolymer_equivalence = disable_homopolymer_equivalence
74
+ self.disable_cluster_merging = disable_cluster_merging
75
+ self.output_dir = output_dir
76
+
77
+ # Auto-calculate outlier identity threshold if not provided
78
+ # Logic: min_identity accounts for 2×error (read-to-read comparison)
79
+ # outlier_identity_threshold accounts for 1×error (read-to-consensus)
80
+ # Therefore: outlier_identity_threshold = (1 + min_identity) / 2
81
+ if outlier_identity_threshold is None:
82
+ self.outlier_identity_threshold = (1.0 + min_identity) / 2.0
83
+ else:
84
+ self.outlier_identity_threshold = outlier_identity_threshold
85
+
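# Worked example (illustrative): min_identity = 0.90 gives
# outlier_identity_threshold = (1 + 0.90) / 2 = 0.95.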
86
+ self.enable_secondpass_phasing = enable_secondpass_phasing
87
+ self.min_variant_frequency = min_variant_frequency
88
+ self.min_variant_count = min_variant_count
89
+ self.min_ambiguity_frequency = min_ambiguity_frequency
90
+ self.min_ambiguity_count = min_ambiguity_count
91
+ self.enable_iupac_calling = enable_iupac_calling
92
+ self.scale_threshold = scale_threshold
93
+ self.max_threads = max_threads
94
+ self.early_filter = early_filter
95
+ self.collect_discards = collect_discards
96
+ self.discarded_read_ids: Set[str] = set() # Track all discarded reads (outliers + filtered)
97
+
98
+ # Initialize scalability configuration
99
+ # scale_threshold: 0=disabled, N>0=enabled for datasets >= N sequences
100
+ self.scalability_config = ScalabilityConfig(
101
+ enabled=scale_threshold > 0,
102
+ activation_threshold=scale_threshold,
103
+ max_threads=max_threads
104
+ )
105
+ self._candidate_finder = None
106
+ if scale_threshold > 0:
107
+ finder = VsearchCandidateFinder(num_threads=max_threads)
108
+ if finder.is_available:
109
+ self._candidate_finder = finder
110
+
111
+ self.sequences = {} # id -> sequence string
112
+ self.records = {} # id -> SeqRecord object
113
+ self.id_map = {} # short_id -> original_id
114
+ self.rev_id_map = {} # original_id -> short_id
115
+
116
+ # Create output directory and debug subdirectory
117
+ os.makedirs(self.output_dir, exist_ok=True)
118
+ self.debug_dir = os.path.join(self.output_dir, "cluster_debug")
119
+ os.makedirs(self.debug_dir, exist_ok=True)
120
+
121
+ # Initialize attributes that may be set later
122
+ self.input_file = None
123
+ self.augment_input = None
124
+ self.algorithm = None
125
+ self.orient_mode = None
126
+ self.primers_file = None
127
+
128
+ def write_metadata(self) -> None:
129
+ """Write run metadata to JSON file for use by post-processing tools."""
130
+ metadata = {
131
+ "version": __version__,
132
+ "timestamp": datetime.now().isoformat(),
133
+ "sample_name": self.sample_name,
134
+ "parameters": {
135
+ "algorithm": self.algorithm,
136
+ "min_identity": self.min_identity,
137
+ "inflation": self.inflation,
138
+ "min_size": self.min_size,
139
+ "min_cluster_ratio": self.min_cluster_ratio,
140
+ "max_sample_size": self.max_sample_size,
141
+ "presample_size": self.presample_size,
142
+ "k_nearest_neighbors": self.k_nearest_neighbors,
143
+ "disable_homopolymer_equivalence": self.disable_homopolymer_equivalence,
144
+ "disable_cluster_merging": self.disable_cluster_merging,
145
+ "outlier_identity_threshold": self.outlier_identity_threshold,
146
+ "enable_secondpass_phasing": self.enable_secondpass_phasing,
147
+ "min_variant_frequency": self.min_variant_frequency,
148
+ "min_variant_count": self.min_variant_count,
149
+ "min_ambiguity_frequency": self.min_ambiguity_frequency,
150
+ "min_ambiguity_count": self.min_ambiguity_count,
151
+ "enable_iupac_calling": self.enable_iupac_calling,
152
+ "scale_threshold": self.scale_threshold,
153
+ "max_threads": self.max_threads,
154
+ "orient_mode": self.orient_mode,
155
+ },
156
+ "input_file": self.input_file,
157
+ "augment_input": self.augment_input,
158
+ }
159
+
160
+ # Add primer information if loaded
161
+ if hasattr(self, 'primers') and self.primers:
162
+ metadata["primers_file"] = self.primers_file
163
+ metadata["primers"] = {}
164
+
165
+ # Store primer sequences (avoid duplicates from RC versions)
166
+ seen_primers = set()
167
+ for primer_name, primer_seq in self.primers:
168
+ # Skip RC versions (they end with _RC)
169
+ if not primer_name.endswith('_RC') and primer_name not in seen_primers:
170
+ metadata["primers"][primer_name] = primer_seq
171
+ seen_primers.add(primer_name)
172
+
173
+ # Write metadata file
174
+ metadata_file = os.path.join(self.debug_dir, f"{self.sample_name}-metadata.json")
175
+ with open(metadata_file, 'w') as f:
176
+ json.dump(metadata, f, indent=2)
177
+
178
+ logging.debug(f"Wrote run metadata to {metadata_file}")
179
+
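A downstream tool can recover the run parameters from this file. A minimal sketch (illustrative, not part of the package), assuming the defaults output_dir="clusters" and sample_name="sample":

    import json
    import os

    metadata_path = os.path.join("clusters", "cluster_debug", "sample-metadata.json")
    with open(metadata_path) as fh:
        meta = json.load(fh)

    # Keys mirror the dict assembled in write_metadata()
    print(meta["version"], meta["sample_name"])
    print(meta["parameters"]["min_identity"], meta["parameters"]["algorithm"])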
180
+ def write_phasing_stats(self, initial_clusters_count: int, after_prephasing_merge_count: int,
181
+ subclusters_count: int, merged_count: int, final_count: int,
182
+ clusters_with_ambiguities: int = 0,
183
+ total_ambiguity_positions: int = 0) -> None:
184
+ """Write phasing statistics to JSON file after clustering completes.
185
+
186
+ Args:
187
+ initial_clusters_count: Number of clusters from initial clustering
188
+ after_prephasing_merge_count: Number of clusters after pre-phasing merge
189
+ subclusters_count: Number of sub-clusters after phasing
190
+ merged_count: Number of clusters after post-phasing merge
191
+ final_count: Number of final clusters after filtering
192
+ clusters_with_ambiguities: Number of clusters with at least one ambiguity code
193
+ total_ambiguity_positions: Total number of ambiguity positions across all clusters
194
+ """
195
+ phasing_stats = {
196
+ "phasing_enabled": self.enable_secondpass_phasing,
197
+ "initial_clusters": initial_clusters_count,
198
+ "after_prephasing_merge": after_prephasing_merge_count,
199
+ "phased_subclusters": subclusters_count,
200
+ "after_postphasing_merge": merged_count,
201
+ "after_filtering": final_count,
202
+ "prephasing_clusters_merged": after_prephasing_merge_count < initial_clusters_count,
203
+ "clusters_split": subclusters_count > after_prephasing_merge_count,
204
+ "postphasing_clusters_merged": merged_count < subclusters_count,
205
+ "net_change": final_count - initial_clusters_count,
206
+ "ambiguity_calling_enabled": self.enable_iupac_calling,
207
+ "clusters_with_ambiguities": clusters_with_ambiguities,
208
+ "total_ambiguity_positions": total_ambiguity_positions
209
+ }
210
+
211
+ # Write phasing stats to separate JSON file
212
+ stats_file = os.path.join(self.debug_dir, f"{self.sample_name}-phasing_stats.json")
213
+ with open(stats_file, 'w') as f:
214
+ json.dump(phasing_stats, f, indent=2)
215
+
216
+ logging.debug(f"Wrote phasing statistics to {stats_file}")
217
+
218
+ def add_sequences(self, records: List[SeqIO.SeqRecord],
219
+ augment_records: Optional[List[SeqIO.SeqRecord]] = None) -> None:
220
+ """Add sequences to be clustered, with optional presampling."""
221
+ all_records = records.copy() # Start with primary records
222
+
223
+ # Track the source of each record for potential logging/debugging
224
+ primary_count = len(records)
225
+ augment_count = 0
226
+
227
+ # Add augmented records if provided
228
+ if augment_records:
229
+ augment_count = len(augment_records)
230
+ all_records.extend(augment_records)
231
+
232
+ if self.presample_size and len(all_records) > self.presample_size:
233
+ logging.info(f"Presampling {self.presample_size} sequences from {len(all_records)} total "
234
+ f"({primary_count} primary, {augment_count} augmented)")
235
+
236
+ # First, sort primary sequences by quality and take as many as possible
237
+ primary_sorted = sorted(
238
+ records,
239
+ key=lambda r: -statistics.mean(r.letter_annotations["phred_quality"])
240
+ )
241
+
242
+ # Determine how many primary sequences to include (all if possible)
243
+ primary_to_include = min(len(primary_sorted), self.presample_size)
244
+ presampled = primary_sorted[:primary_to_include]
245
+
246
+ # If we still have room, add augmented sequences sorted by quality
247
+ remaining_slots = self.presample_size - primary_to_include
248
+ if remaining_slots > 0 and augment_records:
249
+ augment_sorted = sorted(
250
+ augment_records,
251
+ key=lambda r: -statistics.mean(r.letter_annotations["phred_quality"])
252
+ )
253
+ presampled.extend(augment_sorted[:remaining_slots])
254
+
255
+ logging.info(f"Presampled {len(presampled)} sequences "
256
+ f"({primary_to_include} primary, {len(presampled) - primary_to_include} augmented)")
257
+ all_records = presampled
258
+
259
+ # Add all selected records to internal storage
260
+ for record in all_records:
261
+ self.sequences[record.id] = str(record.seq)
262
+ self.records[record.id] = record
263
+
264
+ # Log scalability mode status for large datasets
265
+ if len(self.sequences) >= self.scale_threshold and self.scale_threshold > 0:
266
+ if self._candidate_finder is not None:
267
+ logging.info(f"Scalability mode active for {len(self.sequences)} sequences (threshold: {self.scale_threshold})")
268
+ else:
269
+ logging.warning(f"Dataset has {len(self.sequences)} sequences (>= threshold {self.scale_threshold}) "
270
+ "but vsearch not found. Using brute-force.")
271
+
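The presampling above sorts by mean Phred quality, so callers must supply FASTQ records carrying letter_annotations["phred_quality"]. A minimal sketch of feeding reads in (illustrative; "reads.fastq" and "extra.fastq" are placeholder paths):

    from Bio import SeqIO

    clusterer = SpecimenClusterer(sample_name="sample", output_dir="clusters")

    primary = list(SeqIO.parse("reads.fastq", "fastq"))
    extra = list(SeqIO.parse("extra.fastq", "fastq"))  # optional augment set

    # Primary reads are kept preferentially if presampling has to trim the input
    clusterer.add_sequences(primary, augment_records=extra)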
272
+ def _get_scalable_operation(self) -> ScalablePairwiseOperation:
273
+ """Get a ScalablePairwiseOperation for pairwise comparisons."""
274
+ # Wrap calculate_similarity to match expected signature (seq1, seq2, id1, id2)
275
+ # IDs are unused in core.py - only needed for primer-aware scoring in summarize.py
276
+ return ScalablePairwiseOperation(
277
+ candidate_finder=self._candidate_finder,
278
+ scoring_function=lambda seq1, seq2, id1, id2: self.calculate_similarity(seq1, seq2),
279
+ config=self.scalability_config
280
+ )
281
+
282
+ def write_mcl_input(self, output_file: str) -> None:
283
+ """Write similarity matrix in MCL input format using k-nearest neighbors approach."""
284
+ self._create_id_mapping()
285
+
286
+ n = len(self.sequences)
287
+ k = min(self.k_nearest_neighbors, n - 1) # Connect to at most k neighbors
288
+
289
+ # Use scalable operation to compute K-NN edges
290
+ operation = self._get_scalable_operation()
291
+ knn_edges = operation.compute_top_k_neighbors(
292
+ sequences=self.sequences,
293
+ k=k,
294
+ min_identity=self.min_identity,
295
+ output_dir=self.output_dir,
296
+ min_edges_per_node=3
297
+ )
298
+
299
+ # Write edges to MCL input file
300
+ with open(output_file, 'w') as f:
301
+ for id1, neighbors in sorted(knn_edges.items()):
302
+ short_id1 = self.rev_id_map[id1]
303
+ for id2, sim in neighbors:
304
+ short_id2 = self.rev_id_map[id2]
305
+ # Transform similarity to emphasize differences
306
+ transformed_sim = sim ** 2 # Square the similarity
307
+ f.write(f"{short_id1}\t{short_id2}\t{transformed_sim:.6f}\n")
308
+
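Each line written above is MCL's ABC format (source, target, weight, tab-separated), with the raw identity squared to widen the gap between strong and weak edges. A small illustration with made-up IDs and identity:

    raw_identity = 0.97                  # pairwise identity from calculate_similarity()
    weight = raw_identity ** 2           # squared to 0.9409
    print(f"0\t5\t{weight:.6f}")         # one ABC line: "0\t5\t0.940900"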
309
+ def run_mcl(self, input_file: str, output_file: str) -> None:
310
+ """Run MCL clustering algorithm with optimized parameters."""
311
+ cmd = [
312
+ "mcl",
313
+ input_file,
314
+ "--abc", # Input is in ABC format (node1 node2 weight)
315
+ "-I", str(self.inflation), # Inflation parameter
316
+ "-scheme", "7", # More advanced flow simulation
317
+ "-pct", "50", # Prune weakest 50% of connections during iterations
318
+ "-te", str(self.max_threads), # Number of threads
319
+ "-o", output_file # Output file
320
+ ]
321
+
322
+ try:
323
+ result = subprocess.run(
324
+ cmd,
325
+ capture_output=True,
326
+ text=True,
327
+ check=True
328
+ )
329
+ logging.debug(f"MCL stdout: {result.stdout}")
330
+ logging.debug(f"MCL stderr: {result.stderr}")
331
+
332
+ except subprocess.CalledProcessError as e:
333
+ logging.error(f"MCL failed with return code {e.returncode}")
334
+ logging.error(f"Command: {' '.join(cmd)}")
335
+ logging.error(f"Stderr: {e.stderr}")
336
+ raise
337
+
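With the defaults above (inflation 4.0, one thread), the assembled command is equivalent to the following invocation (input.abc and output.cls are the temporary files supplied by the caller):

    mcl input.abc --abc -I 4.0 -scheme 7 -pct 50 -te 1 -o output.cls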
338
+ def merge_similar_clusters(self, clusters: List[Dict], phase_name: str = "Post-phasing") -> List[Dict]:
339
+ """
340
+ Merge clusters whose consensus sequences are identical or homopolymer-equivalent.
341
+ Preserves provenance metadata through the merging process.
342
+
343
+ This function is used for both pre-phasing merge (combining initial clusters before
344
+ variant detection) and post-phasing merge (combining subclusters after phasing).
345
+
346
+ Note: Primer trimming is performed before comparison to ensure clusters that differ
347
+ only in primer regions are properly merged. Trimmed consensuses are used only for
348
+ comparison and are discarded after merging.
349
+
350
+ Args:
351
+ clusters: List of cluster dictionaries with 'read_ids' and provenance fields
352
+ phase_name: Name of the merge phase for logging (e.g., "Pre-phasing", "Post-phasing")
353
+
354
+ Returns:
355
+ List of merged cluster dictionaries with combined provenance
356
+ """
357
+ if not clusters:
358
+ return []
359
+
360
+ # Sort clusters by size, largest first
361
+ clusters = sorted(clusters, key=lambda c: len(c['read_ids']), reverse=True)
362
+
363
+ # Generate a consensus sequence for each cluster
364
+ logging.debug(f"{phase_name} merge: Generating consensus sequences...")
365
+ consensuses = []
366
+ cluster_to_consensus = {} # Map from cluster index to its consensus
367
+
368
+ # First pass: prepare sampled sequences and handle single-read clusters
369
+ clusters_needing_spoa = [] # (cluster_index, sampled_seqs)
370
+
371
+ for i, cluster_dict in enumerate(clusters):
372
+ cluster_reads = cluster_dict['read_ids']
373
+
374
+ # Skip empty clusters
375
+ if not cluster_reads:
376
+ logging.debug(f"Cluster {i} is empty, skipping")
377
+ continue
378
+
379
+ # Single-read clusters don't need SPOA - use the read directly
380
+ if len(cluster_reads) == 1:
381
+ seq_id = next(iter(cluster_reads))
382
+ consensus = self.sequences[seq_id]
383
+ # Trim primers before comparison
384
+ if hasattr(self, 'primers'):
385
+ consensus, _ = self.trim_primers(consensus)
386
+ consensuses.append(consensus)
387
+ cluster_to_consensus[i] = consensus
388
+ continue
389
+
390
+ # Sample from larger clusters to speed up consensus generation
391
+ if len(cluster_reads) > self.max_sample_size:
392
+ # Sample by quality
393
+ qualities = []
394
+ for seq_id in cluster_reads:
395
+ record = self.records[seq_id]
396
+ mean_quality = statistics.mean(record.letter_annotations["phred_quality"])
397
+ qualities.append((mean_quality, seq_id))
398
+
399
+ # Sort by quality (descending), then by read ID (ascending) for deterministic tie-breaking
400
+ sampled_ids = [seq_id for _, seq_id in
401
+ sorted(qualities, key=lambda x: (-x[0], x[1]))[:self.max_sample_size]]
402
+ sampled_seqs = {seq_id: self.sequences[seq_id] for seq_id in sampled_ids}
403
+ else:
404
+ # Sort all reads by quality for optimal SPOA ordering
405
+ qualities = []
406
+ for seq_id in cluster_reads:
407
+ record = self.records[seq_id]
408
+ mean_quality = statistics.mean(record.letter_annotations["phred_quality"])
409
+ qualities.append((mean_quality, seq_id))
410
+ sorted_ids = [seq_id for _, seq_id in
411
+ sorted(qualities, key=lambda x: (-x[0], x[1]))]
412
+ sampled_seqs = {seq_id: self.sequences[seq_id] for seq_id in sorted_ids}
413
+
414
+ clusters_needing_spoa.append((i, sampled_seqs))
415
+
416
+ # Run SPOA for multi-read clusters
417
+ if clusters_needing_spoa:
418
+ if self.max_threads > 1 and len(clusters_needing_spoa) > 10:
419
+ # Parallel SPOA execution using ProcessPoolExecutor
420
+ from concurrent.futures import ProcessPoolExecutor
421
+
422
+ # Prepare work packages with config
423
+ work_packages = [
424
+ (cluster_idx, sampled_seqs, self.disable_homopolymer_equivalence)
425
+ for cluster_idx, sampled_seqs in clusters_needing_spoa
426
+ ]
427
+
428
+ with ProcessPoolExecutor(max_workers=self.max_threads) as executor:
429
+ results = list(tqdm(
430
+ executor.map(_run_spoa_worker, work_packages),
431
+ total=len(work_packages),
432
+ desc=f"{phase_name} consensus generation"
433
+ ))
434
+
435
+ for cluster_idx, result in results:
436
+ if result is None:
437
+ logging.warning(f"Cluster {cluster_idx} produced no consensus, skipping")
438
+ continue
439
+ consensus = result.consensus
440
+ if hasattr(self, 'primers'):
441
+ consensus, _ = self.trim_primers(consensus)
442
+ consensuses.append(consensus)
443
+ cluster_to_consensus[cluster_idx] = consensus
444
+ else:
445
+ # Sequential SPOA execution using same worker function as parallel path
446
+ for cluster_idx, sampled_seqs in clusters_needing_spoa:
447
+ _, result = _run_spoa_worker((cluster_idx, sampled_seqs, self.disable_homopolymer_equivalence))
448
+ if result is None:
449
+ logging.warning(f"Cluster {cluster_idx} produced no consensus, skipping")
450
+ continue
451
+ consensus = result.consensus
452
+ if hasattr(self, 'primers'):
453
+ consensus, _ = self.trim_primers(consensus)
454
+ consensuses.append(consensus)
455
+ cluster_to_consensus[cluster_idx] = consensus
456
+
457
+ consensus_to_clusters = defaultdict(list)
458
+
459
+ # IMPORTANT: Use cluster_to_consensus.items() instead of enumerate(consensuses)
460
+ # because single-read and multi-read clusters are processed separately above,
461
+ # which changes the order in the consensuses list. The cluster_to_consensus dict
462
+ # maintains the correct mapping from cluster index to consensus.
463
+
464
+ if self.disable_homopolymer_equivalence:
465
+ # Only merge exactly identical sequences
466
+ # Sort by cluster index for deterministic iteration order
467
+ for cluster_idx, consensus in sorted(cluster_to_consensus.items()):
468
+ consensus_to_clusters[consensus].append(cluster_idx)
469
+ else:
470
+ # Group by homopolymer-equivalent sequences
471
+ # Use scalable method when enabled and there are many clusters
472
+ use_scalable = (
473
+ self.scale_threshold > 0 and
474
+ self._candidate_finder is not None and
475
+ self._candidate_finder.is_available and
476
+ len(cluster_to_consensus) > 50
477
+ )
478
+
479
+ if use_scalable:
480
+ # Map cluster indices to string IDs for scalability module
481
+ str_to_index = {str(i): i for i in cluster_to_consensus.keys()}
482
+ consensus_seq_dict = {str(i): seq for i, seq in cluster_to_consensus.items()}
483
+
484
+ # Use scalable equivalence grouping
485
+ operation = self._get_scalable_operation()
486
+ equivalence_groups = operation.compute_equivalence_groups(
487
+ sequences=consensus_seq_dict,
488
+ equivalence_fn=self.are_homopolymer_equivalent,
489
+ output_dir=self.output_dir,
490
+ min_candidate_identity=0.95
491
+ )
492
+
493
+ # Convert groups back to indices and populate consensus_to_clusters
494
+ for group in equivalence_groups:
495
+ if group:
496
+ representative = group[0]
497
+ repr_consensus = consensus_seq_dict[representative]
498
+ for str_id in group:
499
+ consensus_to_clusters[repr_consensus].append(str_to_index[str_id])
500
+ else:
501
+ # Original O(n²) approach for small sets
502
+ # Sort by cluster index to match main branch iteration order
503
+ # (main branch iterates via enumerate(consensuses) where consensuses list
504
+ # order matches clusters order; our dict may have different insertion order
505
+ # due to single-read vs multi-read separation)
506
+ for cluster_idx, consensus in sorted(cluster_to_consensus.items()):
507
+ # Find if this consensus is homopolymer-equivalent to any existing group
508
+ found_group = False
509
+ for existing_consensus in consensus_to_clusters.keys():
510
+ if self.are_homopolymer_equivalent(consensus, existing_consensus):
511
+ consensus_to_clusters[existing_consensus].append(cluster_idx)
512
+ found_group = True
513
+ break
514
+
515
+ if not found_group:
516
+ consensus_to_clusters[consensus].append(cluster_idx)
517
+
518
+ merged = []
519
+ merged_indices = set()
520
+
521
+ # Determine merge type for logging
522
+ merge_type = "identical" if self.disable_homopolymer_equivalence else "homopolymer-equivalent"
523
+
524
+ # Handle clusters with equivalent consensus sequences
525
+ for equivalent_clusters in consensus_to_clusters.values():
526
+ if len(equivalent_clusters) > 1:
527
+ # Merge clusters with equivalent consensus
528
+ merged_read_ids = set()
529
+ merged_from_list = []
530
+
531
+ # Check if we're merging phased subclusters from the same initial cluster
532
+ initial_clusters_involved = set()
533
+ phased_subclusters_merged = []
534
+
535
+ for idx in equivalent_clusters:
536
+ merged_read_ids.update(clusters[idx]['read_ids'])
537
+ merged_indices.add(idx)
538
+
539
+ # Track what we're merging from
540
+ cluster_info = {
541
+ 'initial_cluster_num': clusters[idx]['initial_cluster_num'],
542
+ 'allele_combo': clusters[idx].get('allele_combo'),
543
+ 'size': len(clusters[idx]['read_ids'])
544
+ }
545
+ merged_from_list.append(cluster_info)
546
+
547
+ # Track if phased subclusters are being merged
548
+ if clusters[idx].get('allele_combo') is not None:
549
+ phased_subclusters_merged.append(cluster_info)
550
+ initial_clusters_involved.add(clusters[idx]['initial_cluster_num'])
551
+
552
+ # Log if we're merging phased subclusters that came from the same initial cluster
553
+ # This can happen when SPOA consensus generation doesn't preserve variant differences
554
+ # that were detected during phasing (e.g., due to homopolymer normalization differences)
555
+ if len(phased_subclusters_merged) > 1 and len(initial_clusters_involved) == 1:
556
+ initial_cluster = list(initial_clusters_involved)[0]
557
+ logging.debug(
558
+ f"Merging {len(phased_subclusters_merged)} phased subclusters from initial cluster {initial_cluster} "
559
+ f"back together (consensus sequences are {merge_type})"
560
+ )
561
+ for info in phased_subclusters_merged:
562
+ logging.debug(f" Subcluster: allele_combo='{info['allele_combo']}', size={info['size']}")
563
+
564
+ # Create merged cluster with provenance
565
+ merged_cluster = {
566
+ 'read_ids': merged_read_ids,
567
+ 'initial_cluster_num': None, # Multiple sources
568
+ 'allele_combo': None, # Multiple alleles merged
569
+ 'merged_from': merged_from_list # Track merge provenance
570
+ }
571
+ merged.append(merged_cluster)
572
+
573
+ # Add remaining unmerged clusters
574
+ for i, cluster_dict in enumerate(clusters):
575
+ if i not in merged_indices:
576
+ merged.append(cluster_dict)
577
+
578
+ if len(merged) < len(clusters):
579
+ logging.info(f"{phase_name} merge: Combined {len(clusters)} clusters into {len(merged)} "
580
+ f"({len(clusters) - len(merged)} merged due to {merge_type} consensus)")
581
+ else:
582
+ logging.info(f"{phase_name} merge: No clusters merged (no {merge_type} consensus found)")
583
+
584
+ return merged
585
+
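The dictionaries flowing through this merge carry provenance alongside the read IDs. A sketch of the shape (field names taken from the code above; values are illustrative):

    cluster = {
        "read_ids": {"read_0017", "read_0042", "read_0103"},  # set of read IDs
        "initial_cluster_num": 3,   # Phase-1 cluster this came from (None after merging)
        "allele_combo": None,       # haplotype label assigned during phasing, if any
    }
    # Merged clusters additionally carry a "merged_from" list of
    # {"initial_cluster_num": ..., "allele_combo": ..., "size": ...} entries.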
586
+ def _find_root(self, merged_to: List[int], i: int) -> int:
587
+ """Find the root index of a merged cluster using path compression."""
588
+ if merged_to[i] != i:
589
+ merged_to[i] = self._find_root(merged_to, merged_to[i])
590
+ return merged_to[i]
591
+
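A small illustration of the union-find bookkeeping this helper expects: merged_to maps each cluster index to the index it was merged into, with roots pointing at themselves (illustrative values):

    merged_to = [0, 0, 1, 3]     # 1 merged into 0, 2 merged into 1, 3 is its own root
    clusterer = SpecimenClusterer()
    assert clusterer._find_root(merged_to, 2) == 0   # follows 2 -> 1 -> 0
    assert merged_to[2] == 0                         # path compression rewrote the link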
592
+ def write_cluster_files(self, cluster_num: int, cluster: Set[str],
593
+ consensus: str, trimmed_consensus: Optional[str] = None,
594
+ found_primers: Optional[List[str]] = None,
595
+ rid: Optional[float] = None,
596
+ rid_min: Optional[float] = None,
597
+ actual_size: Optional[int] = None,
598
+ consensus_fasta_handle = None,
599
+ sampled_ids: Optional[Set[str]] = None,
600
+ msa: Optional[str] = None,
601
+ sorted_cluster_ids: Optional[List[str]] = None,
602
+ sorted_sampled_ids: Optional[List[str]] = None,
603
+ iupac_count: int = 0) -> None:
604
+ """Write cluster files: reads FASTQ, MSA, and consensus FASTA.
605
+
606
+ Read identity metrics measure internal cluster consistency (not accuracy vs. ground truth):
607
+ - rid: Mean read identity - measures average agreement between reads and consensus
608
+ - rid_min: Minimum read identity - captures worst-case outlier reads
609
+
610
+ High identity values indicate homogeneous clusters with consistent reads.
611
+ Low values may indicate heterogeneity, outliers, or poor consensus (especially at low RiC).
612
+ """
613
+ cluster_size = len(cluster)
614
+ ric_size = min(actual_size or cluster_size, self.max_sample_size)
615
+
616
+ # Create info string with size first
617
+ info_parts = [f"size={cluster_size}", f"ric={ric_size}"]
618
+
619
+ # Add read identity metrics (as percentages for readability)
620
+ if rid is not None:
621
+ info_parts.append(f"rid={rid*100:.1f}")
622
+ if rid_min is not None:
623
+ info_parts.append(f"rid_min={rid_min*100:.1f}")
624
+
625
+ if found_primers:
626
+ info_parts.append(f"primers={','.join(found_primers)}")
627
+ if iupac_count > 0:
628
+ info_parts.append(f"ambig={iupac_count}")
629
+ info_str = " ".join(info_parts)
630
+
631
+ # Write reads FASTQ to debug directory with new naming convention
632
+ # Use sorted order (by quality descending) if available, matching MSA order
633
+ reads_file = os.path.join(self.debug_dir, f"{self.sample_name}-c{cluster_num}-RiC{ric_size}-reads.fastq")
634
+ with open(reads_file, 'w') as f:
635
+ read_ids_to_write = sorted_cluster_ids if sorted_cluster_ids is not None else cluster
636
+ for seq_id in read_ids_to_write:
637
+ SeqIO.write(self.records[seq_id], f, "fastq")
638
+
639
+ # Write sampled reads FASTQ (only sequences used for consensus generation)
640
+ # Use sorted order (by quality descending) if available, matching MSA order
641
+ if sampled_ids is not None or sorted_sampled_ids is not None:
642
+ sampled_file = os.path.join(self.debug_dir, f"{self.sample_name}-c{cluster_num}-RiC{ric_size}-sampled.fastq")
643
+ with open(sampled_file, 'w') as f:
644
+ sampled_to_write = sorted_sampled_ids if sorted_sampled_ids is not None else sampled_ids
645
+ for seq_id in sampled_to_write:
646
+ SeqIO.write(self.records[seq_id], f, "fastq")
647
+
648
+ # Write MSA (multiple sequence alignment) to debug directory
649
+ if msa is not None:
650
+ msa_file = os.path.join(self.debug_dir, f"{self.sample_name}-c{cluster_num}-RiC{ric_size}-msa.fasta")
651
+ with open(msa_file, 'w') as f:
652
+ f.write(msa)
653
+
654
+ # Write untrimmed consensus to debug directory
655
+ with open(os.path.join(self.debug_dir, f"{self.sample_name}-c{cluster_num}-RiC{ric_size}-untrimmed.fasta"),
656
+ 'w') as f:
657
+ f.write(f">{self.sample_name}-c{cluster_num} {info_str}\n")
658
+ f.write(consensus + "\n")
659
+
660
+ # Write consensus to main output file if handle is provided
661
+ if consensus_fasta_handle:
662
+ final_consensus = trimmed_consensus if trimmed_consensus else consensus
663
+ consensus_fasta_handle.write(f">{self.sample_name}-c{cluster_num} {info_str}\n")
664
+ consensus_fasta_handle.write(final_consensus + "\n")
665
+
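The consensus headers assembled above follow a fixed key=value layout. An illustrative header for a hypothetical cluster (all values invented for the example):

    >sample-c1 size=143 ric=100 rid=99.2 rid_min=96.5 primers=ITS1F,ITS4 ambig=1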
666
+ def run_mcl_clustering(self, temp_dir: str) -> List[Set[str]]:
667
+ """Run MCL clustering algorithm and return the clusters.
668
+
669
+ Args:
670
+ temp_dir: Path to temporary directory for intermediate files
671
+
672
+ Returns:
673
+ List of clusters, where each cluster is a set of sequence IDs
674
+ """
675
+ mcl_input = os.path.join(temp_dir, "input.abc")
676
+ mcl_output = os.path.join(temp_dir, "output.cls")
677
+
678
+ self.write_mcl_input(mcl_input)
679
+
680
+ logging.info(f"Running MCL algorithm with inflation {self.inflation}...")
681
+ self.run_mcl(mcl_input, mcl_output)
682
+ return self.parse_mcl_output(mcl_output)
683
+
684
+ def run_greedy_clustering(self, temp_dir: str) -> List[Set[str]]:
685
+ """Run greedy clustering algorithm and return the clusters.
686
+
687
+ This algorithm iteratively finds the sequence with the most connections above
688
+ the similarity threshold and forms a cluster around it.
689
+
690
+ Args:
691
+ temp_dir: Path to temporary directory for intermediate files
692
+
693
+ Returns:
694
+ List of clusters, where each cluster is a set of sequence IDs
695
+ """
696
+ logging.info("Running greedy clustering algorithm...")
697
+
698
+ # Build similarity matrix if not already built
699
+ if not hasattr(self, 'alignments'):
700
+ self.alignments = defaultdict(dict)
701
+ self.build_similarity_matrix()
702
+
703
+ # Initial clustering
704
+ clusters = []
705
+ available_ids = set(self.sequences.keys())
706
+
707
+ cluster_count = 0
708
+
709
+ while available_ids:
710
+ center, members = self.find_cluster_center(available_ids)
711
+ available_ids -= members
712
+
713
+ clusters.append(members)
714
+ cluster_count += 1
715
+
716
+ return clusters
717
+
718
+ def build_similarity_matrix(self) -> None:
719
+ """Calculate all pairwise similarities between sequences."""
720
+ logging.info("Calculating pairwise sequence similarities...")
721
+
722
+ # Sort for deterministic order
723
+ seq_ids = sorted(self.sequences.keys())
724
+ total = len(seq_ids) * (len(seq_ids) - 1) // 2
725
+
726
+ with tqdm(total=total, desc="Building similarity matrix") as pbar:
727
+ for i, id1 in enumerate(seq_ids):
728
+ for id2 in seq_ids[i + 1:]:
729
+ sim = self.calculate_similarity(
730
+ self.sequences[id1],
731
+ self.sequences[id2]
732
+ )
733
+
734
+ if sim >= self.min_identity:
735
+ self.alignments[id1][id2] = sim
736
+ self.alignments[id2][id1] = sim
737
+
738
+ pbar.update(1)
739
+
740
+ def find_cluster_center(self, available_ids: Set[str]) -> Tuple[str, Set[str]]:
741
+ """
742
+ Find the sequence with most similar sequences above threshold,
743
+ and return its ID and the IDs of its cluster members.
744
+ """
745
+ best_center = None
746
+ best_members = set()
747
+ best_count = -1
748
+
749
+ # Sort for deterministic iteration (important for tie-breaking)
750
+ for seq_id in sorted(available_ids):
751
+ # Get all sequences that align with current sequence
752
+ members = {other_id for other_id in self.alignments.get(seq_id, {})
753
+ if other_id in available_ids}
754
+
755
+ # Use > (not >=) so first alphabetically wins ties
756
+ if len(members) > best_count:
757
+ best_count = len(members)
758
+ best_center = seq_id
759
+ best_members = members
760
+
761
+ if best_center is None:
762
+ # No alignments found, create singleton cluster with smallest ID
763
+ singleton_id = min(available_ids)
764
+ return singleton_id, {singleton_id}
765
+
766
+ best_members.add(best_center) # Include center in cluster
767
+ return best_center, best_members
768
+
769
+
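The greedy pass repeatedly calls find_cluster_center on whatever IDs remain. A toy illustration of the selection rule (hand-built adjacency; real runs populate self.alignments via build_similarity_matrix):

    clusterer = SpecimenClusterer(min_identity=0.9)
    clusterer.alignments = {
        "a": {"b": 0.97, "c": 0.95},   # "a" has two neighbours above threshold
        "b": {"a": 0.97},
        "c": {"a": 0.95},
        "d": {},                       # no neighbours -> eventual singleton
    }
    center, members = clusterer.find_cluster_center({"a", "b", "c", "d"})
    # center == "a", members == {"a", "b", "c"}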
770
+ # ========================================================================
771
+ # Clustering Phase Helper Methods
772
+ # ========================================================================
773
+
774
+ def _run_initial_clustering(self, temp_dir: str, algorithm: str) -> List[Set[str]]:
775
+ """Phase 1: Run initial clustering algorithm.
776
+
777
+ Args:
778
+ temp_dir: Temporary directory for intermediate files
779
+ algorithm: 'graph' for MCL or 'greedy' for greedy clustering
780
+
781
+ Returns:
782
+ List of clusters (sets of read IDs), sorted by size (largest first)
783
+ """
784
+ if algorithm == "graph":
785
+ try:
786
+ initial_clusters = self.run_mcl_clustering(temp_dir)
787
+ except (subprocess.SubprocessError, FileNotFoundError) as e:
788
+ logging.error(f"MCL clustering failed: {str(e)}")
789
+ logging.error("You may need to install MCL: https://micans.org/mcl/")
790
+ logging.error("Falling back to greedy clustering algorithm...")
791
+ initial_clusters = self.run_greedy_clustering(temp_dir)
792
+ elif algorithm == "greedy":
793
+ initial_clusters = self.run_greedy_clustering(temp_dir)
794
+ else:
795
+ raise ValueError(f"Unknown clustering algorithm: {algorithm}")
796
+
797
+ # Sort initial clusters by size (largest first)
798
+ initial_clusters.sort(key=lambda c: len(c), reverse=True)
799
+ logging.info(f"Initial clustering produced {len(initial_clusters)} clusters")
800
+ return initial_clusters
801
+
802
+ def _run_prephasing_merge(self, initial_clusters: List[Set[str]]) -> List[Set[str]]:
803
+ """Phase 2: Merge initial clusters with HP-equivalent consensus.
804
+
805
+ Maximizes read depth for variant detection in the phasing phase.
806
+
807
+ Args:
808
+ initial_clusters: List of initial clusters from Phase 1
809
+
810
+ Returns:
811
+ List of merged clusters (sets of read IDs)
812
+ """
813
+ if self.disable_cluster_merging:
814
+ logging.info("Cluster merging disabled, skipping pre-phasing merge")
815
+ return initial_clusters
816
+
817
+ # Convert initial clusters to dict format for merge_similar_clusters
818
+ initial_cluster_dicts = [
819
+ {'read_ids': cluster, 'initial_cluster_num': i, 'allele_combo': None}
820
+ for i, cluster in enumerate(initial_clusters, 1)
821
+ ]
822
+ merged_dicts = self.merge_similar_clusters(initial_cluster_dicts, phase_name="Pre-phasing")
823
+ # Extract back to sets for Phase 3
824
+ return [d['read_ids'] for d in merged_dicts]
825
+
826
+ def _apply_early_filter(self, clusters: List[Set[str]]) -> Tuple[List[Set[str]], List[Set[str]]]:
827
+ """Apply early size filtering after pre-phasing merge.
828
+
829
+ Uses the same logic as _run_size_filtering() but operates before
830
+ variant phasing to avoid expensive processing of small clusters.
831
+
832
+ Args:
833
+ clusters: List of merged clusters from Phase 2
834
+
835
+ Returns:
836
+ Tuple of (clusters_to_process, filtered_clusters)
837
+ """
838
+ if not self.early_filter:
839
+ return clusters, []
840
+
841
+ # Get size of each cluster for filtering
842
+ cluster_sizes = [(c, len(c)) for c in clusters]
843
+
844
+ # Find largest cluster size for ratio filtering
845
+ if not cluster_sizes:
846
+ return [], []
847
+ largest_size = max(size for _, size in cluster_sizes)
848
+
849
+ keep_clusters = []
850
+ filtered_clusters = []
851
+
852
+ for cluster, size in cluster_sizes:
853
+ # Apply min_size filter
854
+ if size < self.min_size:
855
+ filtered_clusters.append(cluster)
856
+ continue
857
+
858
+ # Apply min_cluster_ratio filter
859
+ if self.min_cluster_ratio > 0 and size / largest_size < self.min_cluster_ratio:
860
+ filtered_clusters.append(cluster)
861
+ continue
862
+
863
+ keep_clusters.append(cluster)
864
+
865
+ if filtered_clusters:
866
+ # Collect discarded read IDs
867
+ discarded_count = 0
868
+ for cluster in filtered_clusters:
869
+ self.discarded_read_ids.update(cluster)
870
+ discarded_count += len(cluster)
871
+
872
+ logging.info(f"Early filter: {len(filtered_clusters)} clusters ({discarded_count} reads) "
873
+ f"below threshold, {len(keep_clusters)} clusters proceeding to phasing")
874
+
875
+ return keep_clusters, filtered_clusters
876
+
877
+ def _run_variant_phasing(self, merged_clusters: List[Set[str]]) -> List[Dict]:
878
+ """Phase 3: Detect variants and phase reads into haplotypes.
879
+
880
+ For each merged cluster:
881
+ 1. Sample reads if needed
882
+ 2. Generate consensus and MSA
883
+ 3. Optionally remove outliers
884
+ 4. Detect variant positions
885
+ 5. Phase reads by their alleles at variant positions
886
+
887
+ Args:
888
+ merged_clusters: List of merged clusters from Phase 2
889
+
890
+ Returns:
891
+ List of subclusters with provenance info (dicts with read_ids,
892
+ initial_cluster_num, allele_combo)
893
+ """
894
+ all_subclusters = []
895
+ all_discarded = set()
896
+ split_count = 0
897
+ logging.debug("Processing clusters for variant detection and phasing...")
898
+
899
+ indexed_clusters = list(enumerate(merged_clusters, 1))
900
+
901
+ # Create config object for workers (used by both parallel and sequential paths)
902
+ config = ClusterProcessingConfig(
903
+ outlier_identity_threshold=self.outlier_identity_threshold,
904
+ enable_secondpass_phasing=self.enable_secondpass_phasing,
905
+ disable_homopolymer_equivalence=self.disable_homopolymer_equivalence,
906
+ min_variant_frequency=self.min_variant_frequency,
907
+ min_variant_count=self.min_variant_count
908
+ )
909
+
910
+ # Build work packages with per-cluster data
911
+ work_packages = []
912
+ for initial_idx, cluster in indexed_clusters:
913
+ cluster_seqs = {sid: self.sequences[sid] for sid in cluster}
914
+ cluster_quals = {
915
+ sid: statistics.mean(self.records[sid].letter_annotations["phred_quality"])
916
+ for sid in cluster
917
+ }
918
+ work_packages.append((initial_idx, cluster, cluster_seqs, cluster_quals, config))
919
+
920
+ if self.max_threads > 1 and len(merged_clusters) > 10:
921
+ # Parallel processing with ProcessPoolExecutor for true parallelism
922
+ from concurrent.futures import ProcessPoolExecutor
923
+
924
+ with ProcessPoolExecutor(max_workers=self.max_threads) as executor:
925
+ from tqdm import tqdm
926
+ results = list(tqdm(
927
+ executor.map(_process_cluster_worker, work_packages),
928
+ total=len(work_packages),
929
+ desc="Processing clusters"
930
+ ))
931
+
932
+ # Collect results maintaining order
933
+ for subclusters, discarded_ids in results:
934
+ if len(subclusters) > 1:
935
+ split_count += 1
936
+ all_subclusters.extend(subclusters)
937
+ all_discarded.update(discarded_ids)
938
+ else:
939
+ # Sequential processing using same worker function as parallel path
940
+ for work_package in work_packages:
941
+ subclusters, discarded_ids = _process_cluster_worker(work_package)
942
+ if len(subclusters) > 1:
943
+ split_count += 1
944
+ all_subclusters.extend(subclusters)
945
+ all_discarded.update(discarded_ids)
946
+
947
+ # Update shared state after all processing complete
948
+ self.discarded_read_ids.update(all_discarded)
949
+
950
+ split_info = f" ({split_count} split)" if split_count > 0 else ""
951
+ logging.info(f"After phasing, created {len(all_subclusters)} sub-clusters from {len(merged_clusters)} merged clusters{split_info}")
952
+ return all_subclusters
953
+
954
+ def _run_postphasing_merge(self, subclusters: List[Dict]) -> List[Dict]:
955
+ """Phase 4: Merge subclusters with HP-equivalent consensus.
956
+
957
+ Args:
958
+ subclusters: List of subclusters from Phase 3
959
+
960
+ Returns:
961
+ List of merged subclusters
962
+ """
963
+ if self.disable_cluster_merging:
964
+ logging.info("Cluster merging disabled, skipping post-phasing merge")
965
+ return subclusters
966
+
967
+ return self.merge_similar_clusters(subclusters, phase_name="Post-phasing")
968
+
969
+ def _run_size_filtering(self, subclusters: List[Dict]) -> List[Dict]:
970
+ """Phase 5: Filter clusters by size and ratio thresholds.
971
+
972
+ Args:
973
+ subclusters: List of subclusters from Phase 4
974
+
975
+ Returns:
976
+ List of filtered clusters, sorted by size (largest first)
977
+ """
978
+ # Filter by absolute size
979
+ large_clusters = [c for c in subclusters if len(c['read_ids']) >= self.min_size]
980
+ small_clusters = [c for c in subclusters if len(c['read_ids']) < self.min_size]
981
+
982
+ if small_clusters:
983
+ filtered_count = len(small_clusters)
984
+ logging.info(f"Filtered {filtered_count} clusters below minimum size ({self.min_size})")
985
+ # Track discarded reads from size-filtered clusters
986
+ for cluster in small_clusters:
987
+ self.discarded_read_ids.update(cluster['read_ids'])
988
+
989
+ # Filter by relative size ratio
990
+ if large_clusters and self.min_cluster_ratio > 0:
991
+ largest_size = max(len(c['read_ids']) for c in large_clusters)
992
+ before_ratio_filter = len(large_clusters)
993
+ passing_ratio = [c for c in large_clusters
994
+ if len(c['read_ids']) / largest_size >= self.min_cluster_ratio]
995
+ failing_ratio = [c for c in large_clusters
996
+ if len(c['read_ids']) / largest_size < self.min_cluster_ratio]
997
+
998
+ if failing_ratio:
999
+ filtered_count = len(failing_ratio)
1000
+ logging.info(f"Filtered {filtered_count} clusters below minimum ratio ({self.min_cluster_ratio})")
1001
+ # Track discarded reads from ratio-filtered clusters
1002
+ for cluster in failing_ratio:
1003
+ self.discarded_read_ids.update(cluster['read_ids'])
1004
+
1005
+ large_clusters = passing_ratio
1006
+
1007
+ # Sort by size and renumber as c1, c2, c3...
1008
+ large_clusters.sort(key=lambda c: len(c['read_ids']), reverse=True)
1009
+
1010
+ total_sequences = len(self.sequences)
1011
+ sequences_covered = sum(len(c['read_ids']) for c in large_clusters)
1012
+
1013
+ if total_sequences > 0:
1014
+ logging.info(f"Final: {len(large_clusters)} clusters covering {sequences_covered} sequences "
1015
+ f"({sequences_covered / total_sequences:.1%} of total)")
1016
+ else:
1017
+ logging.info(f"Final: {len(large_clusters)} clusters (no sequences to cluster)")
1018
+
1019
+ return large_clusters
1020
+
1021
+ def _write_cluster_outputs(self, clusters: List[Dict], output_file: str) -> Tuple[int, int]:
1022
+ """Phase 6: Generate final consensus and write output files.
1023
+
1024
+ Args:
1025
+ clusters: List of filtered clusters from Phase 5
1026
+ output_file: Path to the output FASTA file
1027
+
1028
+ Returns:
1029
+ Tuple of (clusters_with_ambiguities, total_ambiguity_positions)
1030
+ """
1031
+ total_ambiguity_positions = 0
1032
+ clusters_with_ambiguities = 0
1033
+
1034
+ # Create config for consensus generation workers
1035
+ primers = getattr(self, 'primers', None)
1036
+ config = ConsensusGenerationConfig(
1037
+ max_sample_size=self.max_sample_size,
1038
+ enable_iupac_calling=self.enable_iupac_calling,
1039
+ min_ambiguity_frequency=self.min_ambiguity_frequency,
1040
+ min_ambiguity_count=self.min_ambiguity_count,
1041
+ disable_homopolymer_equivalence=self.disable_homopolymer_equivalence,
1042
+ primers=primers
1043
+ )
1044
+
1045
+ # Build work packages for each cluster
1046
+ work_packages = []
1047
+ for final_idx, cluster_dict in enumerate(clusters, 1):
1048
+ cluster = cluster_dict['read_ids']
1049
+ # Pre-compute quality means for each read
1050
+ qualities = {}
1051
+ for seq_id in cluster:
1052
+ record = self.records[seq_id]
1053
+ qualities[seq_id] = statistics.mean(record.letter_annotations["phred_quality"])
1054
+ # Extract sequences for this cluster
1055
+ sequences = {seq_id: self.sequences[seq_id] for seq_id in cluster}
1056
+ work_packages.append((final_idx, cluster, sequences, qualities, config))
1057
+
1058
+ # Run consensus generation (parallel or sequential based on settings)
1059
+ if self.max_threads > 1 and len(clusters) > 4:
1060
+ # Parallel execution with ProcessPoolExecutor
1061
+ from concurrent.futures import ProcessPoolExecutor
1062
+
1063
+ with ProcessPoolExecutor(max_workers=self.max_threads) as executor:
1064
+ results = list(tqdm(
1065
+ executor.map(_generate_cluster_consensus_worker, work_packages),
1066
+ total=len(work_packages),
1067
+ desc="Final consensus generation"
1068
+ ))
1069
+ else:
1070
+ # Sequential execution using same worker function
1071
+ results = []
1072
+ for work_package in work_packages:
1073
+ result = _generate_cluster_consensus_worker(work_package)
1074
+ results.append(result)
1075
+
1076
+ # Sort results by final_idx to ensure correct order
1077
+ results.sort(key=lambda r: r['final_idx'])
1078
+
1079
+ # Write output files sequentially (I/O bound, must preserve order)
1080
+ with open(output_file, 'w') as consensus_fasta_handle:
1081
+ for result in results:
1082
+ final_idx = result['final_idx']
1083
+ cluster = result['cluster']
1084
+ actual_size = result['actual_size']
1085
+
1086
+ # Log sampling info for large clusters
1087
+ if len(cluster) > self.max_sample_size:
1088
+ logging.debug(f"Cluster {final_idx}: Sampling {self.max_sample_size} from {len(cluster)} reads for final consensus")
1089
+
1090
+ consensus = result['consensus']
1091
+ iupac_count = result['iupac_count']
1092
+
1093
+ if consensus:
1094
+ if iupac_count > 0:
1095
+ logging.debug(f"Cluster {final_idx}: Called {iupac_count} IUPAC ambiguity position(s)")
1096
+ total_ambiguity_positions += iupac_count
1097
+ clusters_with_ambiguities += 1
1098
+
1099
+ # Write output files
1100
+ self.write_cluster_files(
1101
+ cluster_num=final_idx,
1102
+ cluster=cluster,
1103
+ consensus=consensus,
1104
+ trimmed_consensus=result['trimmed_consensus'],
1105
+ found_primers=result['found_primers'],
1106
+ rid=result['rid'],
1107
+ rid_min=result['rid_min'],
1108
+ actual_size=actual_size,
1109
+ consensus_fasta_handle=consensus_fasta_handle,
1110
+ sampled_ids=result['sampled_ids'],
1111
+ msa=result['msa'],
1112
+ sorted_cluster_ids=result['sorted_cluster_ids'],
1113
+ sorted_sampled_ids=result['sorted_sampled_ids'],
1114
+ iupac_count=iupac_count
1115
+ )
1116
+
1117
+ return clusters_with_ambiguities, total_ambiguity_positions
1118
+
1119
+ def _write_discarded_reads(self) -> None:
1120
+ """Write discarded reads to a FASTQ file for inspection.
1121
+
1122
+ Discards include:
1123
+ - Outlier reads removed during variant phasing
1124
+ - Reads from clusters filtered out by early filtering (Phase 2b)
1125
+ - Reads from clusters filtered out by size/ratio thresholds (Phase 5)
1126
+ - Reads filtered during orientation (when --orient-mode filter-failed)
1127
+
1128
+ Output: cluster_debug/{sample_name}-discards.fastq
1129
+ """
1130
+ if not self.discarded_read_ids:
1131
+ return
1132
+
1133
+ discards_file = os.path.join(self.debug_dir, f"{self.sample_name}-discards.fastq")
1134
+ with open(discards_file, 'w') as f:
1135
+ for seq_id in sorted(self.discarded_read_ids):
1136
+ if seq_id in self.records:
1137
+ SeqIO.write(self.records[seq_id], f, "fastq")
1138
+
1139
+ logging.info(f"Wrote {len(self.discarded_read_ids)} discarded reads to {discards_file}")
1140
+
1141
+ def cluster(self, algorithm: str = "graph") -> None:
1142
+ """Perform complete clustering process with variant phasing and write output files.
1143
+
1144
+ Pipeline:
1145
+ 1. Initial clustering (MCL or greedy)
1146
+ 2. Pre-phasing merge (combine HP-equivalent initial clusters)
1147
+ 2b. Early filtering (optional, skip small clusters before expensive phasing)
1148
+ 3. Variant detection + phasing (split clusters by haplotype)
1149
+ 4. Post-phasing merge (combine HP-equivalent subclusters)
1150
+ 5. Filtering (size and ratio thresholds)
1151
+ 6. Output generation
1152
+ 7. Write discarded reads (optional)
1153
+
1154
+ Args:
1155
+ algorithm: Clustering algorithm to use ('graph' for MCL or 'greedy')
1156
+ """
1157
+ with tempfile.TemporaryDirectory() as temp_dir:
1158
+ # Phase 1: Initial clustering
1159
+ initial_clusters = self._run_initial_clustering(temp_dir, algorithm)
1160
+
1161
+ # Phase 2: Pre-phasing merge
1162
+ merged_clusters = self._run_prephasing_merge(initial_clusters)
1163
+
1164
+ # Phase 2b: Early filtering (optional)
1165
+ clusters_to_phase, early_filtered = self._apply_early_filter(merged_clusters)
1166
+
1167
+ # Phase 3: Variant detection + phasing
1168
+ all_subclusters = self._run_variant_phasing(clusters_to_phase)
1169
+
1170
+ # Phase 4: Post-phasing merge
1171
+ merged_subclusters = self._run_postphasing_merge(all_subclusters)
1172
+
1173
+ # Phase 5: Size filtering
1174
+ filtered_clusters = self._run_size_filtering(merged_subclusters)
1175
+
1176
+ # Phase 6: Output generation
1177
+ consensus_output_file = os.path.join(self.output_dir, f"{self.sample_name}-all.fasta")
1178
+ clusters_with_ambiguities, total_ambiguity_positions = self._write_cluster_outputs(
1179
+ filtered_clusters, consensus_output_file
1180
+ )
1181
+
1182
+ # Phase 7: Write discarded reads (optional)
1183
+ if self.collect_discards and self.discarded_read_ids:
1184
+ self._write_discarded_reads()
1185
+
1186
+ # Write phasing statistics
1187
+ self.write_phasing_stats(
1188
+ initial_clusters_count=len(initial_clusters),
1189
+ after_prephasing_merge_count=len(merged_clusters),
1190
+ subclusters_count=len(all_subclusters),
1191
+ merged_count=len(merged_subclusters),
1192
+ final_count=len(filtered_clusters),
1193
+ clusters_with_ambiguities=clusters_with_ambiguities,
1194
+ total_ambiguity_positions=total_ambiguity_positions
1195
+ )
1196
+
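Putting the phases together, a typical end-to-end run looks roughly like this (a sketch, not taken from the package documentation; file paths are placeholders, and the mcl binary must be on PATH for algorithm="graph", otherwise the greedy fallback is used):

    from Bio import SeqIO

    clusterer = SpecimenClusterer(
        min_identity=0.9,
        min_size=5,
        sample_name="sample",
        output_dir="clusters",
        max_threads=4,
    )
    clusterer.input_file = "reads.fastq"   # optional; only echoed into the metadata JSON
    clusterer.algorithm = "graph"

    reads = list(SeqIO.parse("reads.fastq", "fastq"))
    clusterer.add_sequences(reads)

    clusterer.write_metadata()
    clusterer.cluster(algorithm="graph")
    # Consensus sequences: clusters/sample-all.fasta
    # Per-cluster debug files: clusters/cluster_debug/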
1197
+ def _create_id_mapping(self) -> None:
1198
+ """Create short numeric IDs for all sequences."""
1199
+ for i, seq_id in enumerate(self.sequences.keys()):
1200
+ short_id = str(i)
1201
+ self.id_map[short_id] = seq_id
1202
+ self.rev_id_map[seq_id] = short_id
1203
+
1204
+ def calculate_similarity(self, seq1: str, seq2: str) -> float:
1205
+ """Calculate sequence similarity using edlib alignment."""
1206
+ if len(seq1) == 0 or len(seq2) == 0:
1207
+ return 0.0
1208
+
1209
+ max_dist = int((1 - self.min_identity) * max(len(seq1), len(seq2)))
1210
+ result = edlib.align(seq1, seq2, task="distance", k=max_dist)
1211
+
1212
+ if result["editDistance"] == -1:
1213
+ return 0.0
1214
+
1215
+ return 1.0 - (result["editDistance"] / max(len(seq1), len(seq2)))
1216
+
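A toy check of the identity computation (illustrative sequences): edlib reports an edit distance of -1 once the k cut-off derived from min_identity is exceeded, which this method maps to 0.0.

    clusterer = SpecimenClusterer(min_identity=0.9)

    a = "ACGTACGTACGTACGTACGT"                    # 20 bp
    b = "ACGTACGTACGAACGTACGT"                    # one substitution
    print(clusterer.calculate_similarity(a, b))   # 1 - 1/20 = 0.95

    c = "TTTTTTTTTTTTTTTTTTTT"                    # far beyond the k cut-off (k = 2 here)
    print(clusterer.calculate_similarity(a, c))   # 0.0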
1217
+ def phase_reads_by_variants(
1218
+ self,
1219
+ msa_string: str,
1220
+ consensus_seq: str,
1221
+ cluster_read_ids: Set[str],
1222
+ variant_positions: List[Dict],
1223
+ alignments: Optional[List[ReadAlignment]] = None
1224
+ ) -> List[Tuple[str, Set[str]]]:
1225
+ """Phase reads into haplotypes. Wrapper around standalone function.
1226
+
1227
+ This method is provided for backward compatibility and testing.
1228
+ Internal processing uses _phase_reads_by_variants_standalone directly.
1229
+ """
1230
+ if not variant_positions:
1231
+ return [(None, cluster_read_ids)]
1232
+
1233
+ # Build sequences dict from self.sequences
1234
+ read_sequences = {rid: self.sequences[rid] for rid in cluster_read_ids if rid in self.sequences}
1235
+
1236
+ if not read_sequences:
1237
+ logging.warning("No sequences found for cluster reads")
1238
+ return [(None, cluster_read_ids)]
1239
+
1240
+ config = ClusterProcessingConfig(
1241
+ outlier_identity_threshold=self.outlier_identity_threshold,
1242
+ enable_secondpass_phasing=self.enable_secondpass_phasing,
1243
+ disable_homopolymer_equivalence=self.disable_homopolymer_equivalence,
1244
+ min_variant_frequency=self.min_variant_frequency,
1245
+ min_variant_count=self.min_variant_count
1246
+ )
1247
+
1248
+ return _phase_reads_by_variants_standalone(
1249
+ cluster_read_ids, self.sequences, variant_positions, config
1250
+ )
1251
+
1252
+ def load_primers(self, primer_file: str) -> None:
1253
+ """Load primers from FASTA file with position awareness."""
1254
+ # Store primers in separate lists by position
1255
+ self.forward_primers = []
1256
+ self.reverse_primers = []
1257
+ self.forward_primers_rc = [] # RC of forward primers
1258
+ self.reverse_primers_rc = [] # RC of reverse primers
1259
+
1260
+ # For backward compatibility with trim_primers
1261
+ self.primers = [] # Will be populated with all primers for existing code
1262
+
1263
+ try:
1264
+ primer_count = {'forward': 0, 'reverse': 0, 'unknown': 0}
1265
+
1266
+ for record in SeqIO.parse(primer_file, "fasta"):
1267
+ sequence = str(record.seq)
1268
+ sequence_rc = str(reverse_complement(sequence))
1269
+
1270
+ # Parse position from header
1271
+ if "position=forward" in record.description:
1272
+ self.forward_primers.append((record.id, sequence))
1273
+ self.forward_primers_rc.append((f"{record.id}_RC", sequence_rc))
1274
+ primer_count['forward'] += 1
1275
+ elif "position=reverse" in record.description:
1276
+ self.reverse_primers.append((record.id, sequence))
1277
+ self.reverse_primers_rc.append((f"{record.id}_RC", sequence_rc))
1278
+ primer_count['reverse'] += 1
1279
+ else:
1280
+ # For primers without position info, add to both lists
1281
+ logging.warning(f"Primer {record.id} has no position annotation, treating as bidirectional")
1282
+ self.forward_primers.append((record.id, sequence))
1283
+ self.forward_primers_rc.append((f"{record.id}_RC", sequence_rc))
1284
+ self.reverse_primers.append((record.id, sequence))
1285
+ self.reverse_primers_rc.append((f"{record.id}_RC", sequence_rc))
1286
+ primer_count['unknown'] += 1
1287
+
1288
+ # Maintain backward compatibility
1289
+ self.primers.append((record.id, sequence))
1290
+ self.primers.append((f"{record.id}_RC", sequence_rc))
1291
+
1292
+ total_primers = sum(primer_count.values())
1293
+ if total_primers == 0:
1294
+ logging.warning("No primers were loaded. Primer trimming will be disabled.")
1295
+ else:
1296
+ logging.debug(f"Loaded {total_primers} primers: {primer_count['forward']} forward, "
1297
+ f"{primer_count['reverse']} reverse, {primer_count['unknown']} unknown")
1298
+ except Exception as e:
1299
+ logging.error(f"Error loading primers: {str(e)}")
1300
+ raise
1301
+
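load_primers keys off a "position=forward" or "position=reverse" token in the FASTA description line; entries without one are treated as bidirectional. An illustrative primers file (example primer names and sequences):

    >ITS1F position=forward
    CTTGGTCATTTAGAGGAAGTAA
    >ITS4 position=reverse
    TCCTCCGCTTATTGATATGC

Loaded with clusterer.load_primers("primers.fasta"), this populates forward_primers, reverse_primers, and their reverse-complement counterparts.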
1302
+ def orient_sequences(self) -> set:
1303
+ """Normalize sequence orientations based on primer matches.
1304
+
1305
+ Scoring system:
1306
+ - +1 point if a forward primer is found at the expected position
1307
+ - +1 point if a reverse primer is found at the expected position
1308
+ - Maximum score: 2 (both primers found)
1309
+
1310
+ Decision logic:
1311
+ - If one orientation scores >0 and the other scores 0: use the non-zero orientation
1312
+ - If both score 0 or both score >0: keep original orientation (ambiguous/failed)
1313
+ """
1314
+ if not hasattr(self, 'forward_primers') or not hasattr(self, 'reverse_primers'):
1315
+ logging.warning("No positioned primers loaded, skipping orientation")
1316
+ return set()
1317
+
1318
+ if len(self.forward_primers) == 0 and len(self.reverse_primers) == 0:
1319
+ logging.warning("No positioned primers available, skipping orientation")
1320
+ return set()
1321
+
1322
+ logging.info("Starting sequence orientation based on primer positions...")
1323
+
1324
+ oriented_count = 0
1325
+ already_correct = 0
1326
+ failed_count = 0
1327
+ failed_sequences = set() # Track which sequences failed orientation
1328
+
1329
+ # Process each sequence
1330
+ for seq_id in tqdm(self.sequences, desc="Orienting sequences"):
1331
+ sequence = self.sequences[seq_id]
1332
+
1333
+ # Test both orientations (scores will be 0, 1, or 2)
1334
+ forward_score = self._score_orientation(sequence, "forward")
1335
+ reverse_score = self._score_orientation(sequence, "reverse")
1336
+
1337
+ # Decision logic
1338
+ if forward_score > 0 and reverse_score == 0:
1339
+ # Clear forward orientation
1340
+ already_correct += 1
1341
+ logging.debug(f"Kept {seq_id} as-is: forward_score={forward_score}, reverse_score={reverse_score}")
1342
+ elif reverse_score > 0 and forward_score == 0:
1343
+ # Clear reverse orientation - needs to be flipped
1344
+ self.sequences[seq_id] = str(reverse_complement(sequence))
1345
+
1346
+ # Also update the record if it exists
1347
+ if seq_id in self.records:
1348
+ record = self.records[seq_id]
1349
+ record.seq = reverse_complement(record.seq)
1350
+ # Reverse quality scores too if they exist
1351
+ if 'phred_quality' in record.letter_annotations:
1352
+ record.letter_annotations['phred_quality'] = \
1353
+ record.letter_annotations['phred_quality'][::-1]
1354
+
1355
+ oriented_count += 1
1356
+ logging.debug(f"Reoriented {seq_id}: forward_score={forward_score}, reverse_score={reverse_score}")
1357
+ else:
1358
+ # Both zero (no primers) or both non-zero (ambiguous) - orientation failed
1359
+ failed_count += 1
1360
+ failed_sequences.add(seq_id) # Track this sequence as failed
1361
+ if forward_score == 0 and reverse_score == 0:
1362
+ logging.debug(f"No primer matches for {seq_id}")
1363
+ else:
1364
+ logging.debug(f"Ambiguous orientation for {seq_id}: forward_score={forward_score}, reverse_score={reverse_score}")
1365
+
1366
+ logging.info(f"Orientation complete: {already_correct} kept as-is, "
1367
+ f"{oriented_count} reverse-complemented, {failed_count} orientation failed")
1368
+
1369
+ # Return set of failed sequence IDs for potential filtering
1370
+ return failed_sequences
1371
+
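For quick auditing, a hedged restatement of the branches above as a truth table (this only summarizes the code; no new behavior is implied):

#   forward_score   reverse_score   action in orient_sequences()
#   -------------   -------------   ------------------------------------------------
#   1 or 2          0               keep read as-is (already_correct)
#   0               1 or 2          reverse-complement read and quality (oriented_count)
#   0               0               failed: no primer evidence (failed_sequences)
#   >0              >0              failed: ambiguous primer evidence (failed_sequences)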
1372
+ def _score_orientation(self, sequence: str, orientation: str) -> int:
1373
+ """Score how well primers match in the given orientation.
1374
+
1375
+ Simple binary scoring:
1376
+ - +1 if a forward primer is found at the expected position
1377
+ - +1 if a reverse primer is found at the expected position
1378
+
1379
+ Args:
1380
+ sequence: The sequence to test
1381
+ orientation: Either "forward" or "reverse"
1382
+
1383
+ Returns:
1384
+ Score from 0-2 (integer)
1385
+ """
1386
+ score = 0
1387
+
1388
+ if orientation == "forward":
1389
+ # Forward orientation:
1390
+ # - Check for forward primers at 5' end (as-is)
1391
+ # - Check for RC of reverse primers at 3' end
1392
+ if self._has_primer_match(sequence, self.forward_primers, "start"):
1393
+ score += 1
1394
+ if self._has_primer_match(sequence, self.reverse_primers_rc, "end"):
1395
+ score += 1
1396
+ else:
1397
+ # Reverse orientation:
1398
+ # - Check for reverse primers at 5' end (as-is)
1399
+ # - Check for RC of forward primers at 3' end
1400
+ if self._has_primer_match(sequence, self.reverse_primers, "start"):
1401
+ score += 1
1402
+ if self._has_primer_match(sequence, self.forward_primers_rc, "end"):
1403
+ score += 1
1404
+
1405
+ return score
1406
+
1407
+ def _has_primer_match(self, sequence: str, primers: List[Tuple[str, str]], end: str) -> bool:
1408
+ """Check if any primer matches at the specified end of sequence.
1409
+
1410
+ Args:
1411
+ sequence: The sequence to search in
1412
+ primers: List of (name, sequence) tuples to search for
1413
+ end: Either "start" or "end"
1414
+
1415
+ Returns:
1416
+ True if any primer has a good match, False otherwise
1417
+ """
1418
+ if not primers or not sequence:
1419
+ return False
1420
+
1421
+ # Determine search region
1422
+ max_primer_len = max(len(p[1]) for p in primers)  # primers is non-empty here (checked above)
1423
+ if end == "start":
1424
+ search_region = sequence[:min(max_primer_len * 2, len(sequence))]
1425
+ else:
1426
+ search_region = sequence[-min(max_primer_len * 2, len(sequence)):]
1427
+
1428
+ for primer_name, primer_seq in primers:
1429
+ # Allow up to 25% errors
1430
+ k = len(primer_seq) // 4
1431
+
1432
+ # Use edlib to find best match
1433
+ result = edlib.align(primer_seq, search_region, task="distance", mode="HW", k=k)
1434
+
1435
+ if result["editDistance"] != -1:
1436
+ # Consider it a match if identity is >= 75%
1437
+ identity = 1.0 - (result["editDistance"] / len(primer_seq))
1438
+ if identity >= 0.75:
1439
+ logging.debug(f"Found {primer_name} at {end} with identity {identity:.2%} "
1440
+ f"(edit_dist={result['editDistance']}, len={len(primer_seq)})")
1441
+ return True
1442
+
1443
+ return False
1444
+
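As a sanity check on the matching criterion above, here is a minimal standalone sketch of the same edlib call; the primer and read fragment are invented for illustration and are not taken from the package:

import edlib

primer = "GGAAGTAAAAGTCGTAACAAGG"             # invented 22-mer
read_start = "TTGGAAGTAAAAGTCGTTACAAGGTTT"    # contains the primer with one substitution
k = len(primer) // 4                          # allow up to ~25% errors, as in _has_primer_match
res = edlib.align(primer, read_start, task="distance", mode="HW", k=k)
if res["editDistance"] != -1:
    identity = 1.0 - res["editDistance"] / len(primer)
    print(identity >= 0.75)                   # True: 1 mismatch in 22 bp is about 95.5% identity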
1445
+ def trim_primers(self, sequence: str) -> Tuple[str, List[str]]:
1446
+ """Trim primers from start and end of sequence. Wrapper around standalone function."""
1447
+ primers = getattr(self, 'primers', None)
1448
+ return _trim_primers_standalone(sequence, primers)
1449
+
1450
+ def calculate_consensus_distance(self, seq1: str, seq2: str, require_merge_compatible: bool = False) -> int:
1451
+ """Calculate distance between two consensus sequences using adjusted identity.
1452
+
1453
+ Uses custom adjustment parameters that enable only homopolymer normalization:
1454
+ - Homopolymer differences (e.g., AAA vs AAAAA) are treated as identical
1455
+ - Regular substitutions count as mismatches
1456
+ - Non-homopolymer indels optionally prevent merging
1457
+
1458
+ Args:
1459
+ seq1: First consensus sequence
1460
+ seq2: Second consensus sequence
1461
+ require_merge_compatible: If True, return -1 when sequences have variations
1462
+ that cannot be represented in an IUPAC consensus (non-homopolymer indels or terminal overhangs)
1463
+
1464
+ Returns:
1465
+ Distance between sequences (mismatches after homopolymer normalization: substitutions plus
1466
+ non-homopolymer indels), or -1 if require_merge_compatible=True and the sequences contain
+ non-homopolymer indels or terminal overhangs
1467
+ """
1468
+ if not seq1 or not seq2:
1469
+ return max(len(seq1), len(seq2))
1470
+
1471
+ # Get alignment from edlib (uses global NW alignment by default)
1472
+ result = edlib.align(seq1, seq2, task="path")
1473
+ if result["editDistance"] == -1:
1474
+ # Alignment failed, return maximum possible distance
1475
+ return max(len(seq1), len(seq2))
1476
+
1477
+ # Get nice alignment for adjusted identity scoring
1478
+ alignment = edlib.getNiceAlignment(result, seq1, seq2)
1479
+ if not alignment or not alignment.get('query_aligned') or not alignment.get('target_aligned'):
1480
+ # Fall back to edit distance if alignment extraction fails
1481
+ return result["editDistance"]
1482
+
1483
+ # Configure custom adjustment parameters for homopolymer normalization only
1484
+ # Use max_repeat_motif_length=1 to be consistent with variant detection
1485
+ # (extract_alignments_from_msa also uses length=1)
1486
+ custom_params = AdjustmentParams(
1487
+ normalize_homopolymers=True, # Enable homopolymer normalization
1488
+ handle_iupac_overlap=False, # Disable IUPAC overlap handling
1489
+ normalize_indels=False, # Disable indel normalization
1490
+ end_skip_distance=0, # Disable end trimming
1491
+ max_repeat_motif_length=1 # Single-base repeats only (consistent with variant detection)
1492
+ )
1493
+
1494
+ # Create custom scoring format to distinguish indels from substitutions
1495
+ custom_format = ScoringFormat(
1496
+ match='|',
1497
+ substitution='X', # Distinct code for substitutions
1498
+ indel_start='I', # Distinct code for indels
1499
+ indel_extension='-',
1500
+ homopolymer_extension='=',
1501
+ end_trimmed='.'
1502
+ )
1503
+
1504
+ # Calculate adjusted identity with custom format
1505
+ score_result = score_alignment(
1506
+ alignment['query_aligned'],
1507
+ alignment['target_aligned'],
1508
+ adjustment_params=custom_params,
1509
+ scoring_format=custom_format
1510
+ )
1511
+
1512
+ # Check for merge compatibility if requested
1513
+ # Both non-homopolymer indels ('I') and terminal overhangs ('.') prevent merging
1514
+ if require_merge_compatible:
1515
+ if 'I' in score_result.score_aligned:
1516
+ # logging.debug(f"Non-homopolymer indel detected, sequences not merge-compatible")
1517
+ return -1 # Signal that merging should not occur
1518
+ if '.' in score_result.score_aligned:
1519
+ # logging.debug(f"Terminal overhang detected, sequences not merge-compatible")
1520
+ return -1 # Signal that merging should not occur
1521
+
1522
+ # Count mismatches after homopolymer normalization
1523
+ # Note: score_result.mismatches includes both substitutions and non-homopolymer indels,
1524
+ # so this distance reflects every non-homopolymer difference between the consensuses
1525
+ distance = score_result.mismatches
1526
+
1527
+ # Log details about the variations found
1528
+ substitutions = score_result.score_aligned.count('X')
1529
+ indels = score_result.score_aligned.count('I')
1530
+ homopolymers = score_result.score_aligned.count('=')
1531
+
1532
+ # logging.debug(f"Consensus distance: {distance} total mismatches "
1533
+ # f"({substitutions} substitutions, {indels} indels, "
1534
+ # f"{homopolymers} homopolymer adjustments)")
1535
+
1536
+ return distance
1537
+
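A hedged usage sketch of the method above, assuming `clusterer` is an already-constructed SpecimenClusterer; the expected values follow from the docstring rather than from a verified run:

# Pure homopolymer-length difference (TTT vs TTTT): expected distance 0
print(clusterer.calculate_consensus_distance("ACGTTTAG", "ACGTTTTAG"))

# Single substitution: expected distance 1
print(clusterer.calculate_consensus_distance("ACGTATAG", "ACGCATAG"))

# Non-homopolymer insertion with the merge check: expected -1 (not merge-compatible)
print(clusterer.calculate_consensus_distance("ACGATAG", "ACGTATAG", require_merge_compatible=True))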
1538
+ def are_homopolymer_equivalent(self, seq1: str, seq2: str) -> bool:
1539
+ """Check if two sequences are equivalent when considering only homopolymer differences.
1540
+
1541
+ Uses adjusted-identity scoring with global alignment. Terminal overhangs (marked as '.')
1542
+ and non-homopolymer indels (marked as 'I') prevent merging, ensuring truncated sequences
1543
+ don't merge with full-length sequences.
1544
+ """
1545
+ if not seq1 or not seq2:
1546
+ return seq1 == seq2
1547
+
1548
+ # Use calculate_consensus_distance with merge compatibility check
1549
+ # Global alignment ensures terminal gaps are counted as indels
1550
+ # Returns: -1 (non-homopolymer indels or terminal overhangs), 0 (homopolymer-equivalent), >0 (substitutions)
1551
+ # Only distance == 0 means truly homopolymer-equivalent
1552
+ distance = self.calculate_consensus_distance(seq1, seq2, require_merge_compatible=True)
1553
+ return distance == 0
1554
+
1555
+ def parse_mcl_output(self, mcl_output_file: str) -> List[Set[str]]:
1556
+ """Parse MCL output file into clusters of original sequence IDs."""
1557
+ clusters = []
1558
+ with open(mcl_output_file) as f:
1559
+ for line in f:
1560
+ # Each line is a tab-separated list of cluster members
1561
+ short_ids = line.strip().split('\t')
1562
+ # Map short IDs back to original sequence IDs
1563
+ cluster = {self.id_map[short_id] for short_id in short_ids}
1564
+ clusters.append(cluster)
1565
+ return clusters
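
A hedged sketch of the MCL output this parser consumes (one tab-separated cluster per line). The short IDs and id_map entries are invented here, since the real short-ID scheme is set up elsewhere in this class, and `clusterer` is assumed to be a SpecimenClusterer instance:

import tempfile

# Two clusters, one per line, members separated by tabs
with tempfile.NamedTemporaryFile("w", suffix=".mcl", delete=False) as fh:
    fh.write("s0\ts2\ns1\ts3\n")
    mcl_path = fh.name

clusterer.id_map = {"s0": "readA", "s1": "readB", "s2": "readC", "s3": "readD"}
print(clusterer.parse_mcl_output(mcl_path))
# expected: [{"readA", "readC"}, {"readB", "readD"}]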