py-gbcms 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
gbcms/counter.py ADDED
@@ -0,0 +1,1074 @@
1
+ """
2
+ Base counting algorithms for variants - Pure Python Implementation.
3
+
4
+ This module provides the standard (non-optimized) base counting implementation.
5
+ It works directly with pysam objects and is easier to debug and modify.
6
+
7
+ **When to use this module:**
8
+ - Small datasets (<10K variants)
9
+ - Development and debugging
10
+ - When Numba is not available
11
+ - When you need to modify counting logic
12
+
13
+ **Performance:** Baseline (1x)
14
+
15
+ **Alternative:** For production workloads with large datasets, see `numba_counter.py`
16
+ which provides 50-100x speedup through JIT compilation.
17
+
18
+ **Key Classes:**
19
+ - BaseCounter: Main counting class with methods for SNP, DNP, and indel variants
20
+
21
+ **Usage:**
22
+ from gbcms.counter import BaseCounter
23
+ from gbcms.config import Config
24
+
25
+ config = Config(...)
26
+ counter = BaseCounter(config)
27
+ counter.count_bases_snp(variant, alignments, sample_name)
28
+ """
29
+
30
+ import logging
31
+ from collections import defaultdict
32
+
33
+ import numpy as np
34
+ import pysam
35
+
36
+ from .config import Config, CountType
37
+ from .variant import VariantEntry
38
+
39
+ logger = logging.getLogger(__name__)
40
+
41
+
42
+ class BaseCounter:
43
+ """
44
+ Performs base counting for variants using pure Python.
45
+
46
+ This is the standard (non-optimized) implementation that works directly
47
+ with pysam AlignedSegment objects. It's flexible and easy to debug but
48
+ slower than the Numba-optimized version.
49
+
50
+ **Performance Characteristics:**
51
+ - Speed: Baseline (1x)
52
+ - Memory: Low
53
+ - Flexibility: High (easy to modify)
54
+ - Debugging: Easy (pure Python)
55
+
56
+ **For better performance on large datasets, see:**
57
+ - `numba_counter.py` for 50-100x speedup
58
+ - Use with `parallel.py` for multi-core processing
59
+
60
+ **Attributes:**
61
+ config: Configuration object with quality thresholds and filters
62
+ warning_counts: Track warnings to avoid spam
63
+ """
64
+
65
+ def __init__(self, config: Config):
66
+ """
67
+ Initialize base counter.
68
+
69
+ Args:
70
+ config: Configuration object with quality filters and thresholds
71
+ """
72
+ self.config = config
73
+ self.warning_counts: dict[str, int] = defaultdict(int)
74
+
75
+ def filter_alignment(self, aln: pysam.AlignedSegment) -> bool:
76
+ """
77
+ Check if alignment should be filtered.
78
+
79
+ Args:
80
+ aln: BAM alignment
81
+
82
+ Returns:
83
+ True if alignment should be filtered (excluded)
84
+ """
85
+ if self.config.filter_duplicate and aln.is_duplicate:
86
+ return True
87
+ if self.config.filter_improper_pair and not aln.is_proper_pair:
88
+ return True
89
+ if self.config.filter_qc_failed and aln.is_qcfail:
90
+ return True
91
+ if self.config.filter_non_primary and aln.is_secondary:
92
+ return True
93
+ if self.config.filter_non_primary and aln.is_supplementary:
94
+ return True
95
+ if aln.mapping_quality < self.config.mapping_quality_threshold:
96
+ return True
97
+ if self.config.filter_indel and self._has_indel(aln):
98
+ return True
99
+ return False
100
+
101
+ @staticmethod
102
+ def _has_indel(aln: pysam.AlignedSegment) -> bool:
103
+ """Check if alignment has indels."""
104
+ if aln.cigartuples is None:
105
+ return False
106
+ for op, _length in aln.cigartuples:
107
+ if op in (1, 2): # Insertion or deletion
108
+ return True
109
+ return False
110
+
111
+ def count_bases_snp(
112
+ self, variant: VariantEntry, alignments: list[pysam.AlignedSegment], sample_name: str
113
+ ) -> None:
114
+ """
115
+ Count bases for SNP variants.
116
+
117
+ Args:
118
+ variant: Variant entry to count
119
+ alignments: List of alignments overlapping the variant
120
+ sample_name: Sample name for storing counts
121
+ """
122
+ counts = np.zeros(len(CountType), dtype=np.float32)
123
+
124
+ # Fragment tracking for fragment counts
125
+ dpf_map: dict[str, dict[int, int]] = {}
126
+ rdf_map: dict[str, dict[int, int]] = {}
127
+ adf_map: dict[str, dict[int, int]] = {}
128
+
129
+ for aln in alignments:
130
+ # Check if alignment overlaps variant position
131
+ if (aln.reference_start is not None and aln.reference_start > variant.pos) or (
132
+ aln.reference_end is not None and aln.reference_end <= variant.pos
133
+ ):
134
+ continue
135
+
136
+ # Get the base at variant position
137
+ read_pos = None
138
+ for read_idx, ref_idx in aln.get_aligned_pairs(matches_only=False):
139
+ if ref_idx == variant.pos:
140
+ read_pos = read_idx
141
+ break
142
+
143
+ if read_pos is None:
144
+ continue # Variant position is in deletion
145
+
146
+ # Check if query sequence and qualities are available
147
+ if aln.query_sequence is None or aln.query_qualities is None:
148
+ continue
149
+
150
+ # Get base and quality
151
+ base = aln.query_sequence[read_pos].upper()
152
+ qual = aln.query_qualities[read_pos]
153
+
154
+ if qual < self.config.base_quality_threshold:
155
+ continue
156
+
157
+ # Count total depth
158
+ counts[CountType.DP] += 1
159
+ if not aln.is_reverse:
160
+ counts[CountType.DPP] += 1
161
+
162
+ # Track fragment
163
+ end_no = 1 if aln.is_read1 else 2
164
+ if self.config.output_fragment_count:
165
+ if aln.query_name is not None:
166
+ if aln.query_name not in dpf_map:
167
+ dpf_map[aln.query_name] = {}
168
+ dpf_map[aln.query_name][end_no] = dpf_map[aln.query_name].get(end_no, 0) + 1
169
+
170
+ # Count ref/alt
171
+ if base == variant.ref:
172
+ counts[CountType.RD] += 1
173
+ if not aln.is_reverse:
174
+ counts[CountType.RDP] += 1
175
+ if self.config.output_fragment_count and aln.query_name is not None:
176
+ if aln.query_name not in rdf_map:
177
+ rdf_map[aln.query_name] = {}
178
+ rdf_map[aln.query_name][end_no] = rdf_map[aln.query_name].get(end_no, 0) + 1
179
+ elif base == variant.alt:
180
+ counts[CountType.AD] += 1
181
+ if not aln.is_reverse:
182
+ counts[CountType.ADP] += 1
183
+ if self.config.output_fragment_count and aln.query_name is not None:
184
+ if aln.query_name not in adf_map:
185
+ adf_map[aln.query_name] = {}
186
+ adf_map[aln.query_name][end_no] = adf_map[aln.query_name].get(end_no, 0) + 1
187
+
188
+ # Calculate strand bias for this sample
189
+ ref_forward = int(counts[CountType.RDP])
190
+ ref_reverse = int(counts[CountType.RD]) - ref_forward
191
+ alt_forward = int(counts[CountType.ADP])
192
+ alt_reverse = int(counts[CountType.AD]) - alt_forward
193
+
194
+ strand_bias_pval, strand_bias_or, strand_bias_dir = self.calculate_strand_bias(
195
+ ref_forward, ref_reverse, alt_forward, alt_reverse
196
+ )
197
+
198
+ # Note: Strand bias is calculated on-the-fly during output, not stored here
199
+
200
+ # Calculate fragment counts
201
+ if self.config.output_fragment_count:
202
+ counts[CountType.DPF] = len(dpf_map)
203
+
204
+ fragment_ref_weight = 0.5 if self.config.fragment_fractional_weight else 0
205
+ fragment_alt_weight = 0.5 if self.config.fragment_fractional_weight else 0
206
+
207
+ for frag_name, end_counts in dpf_map.items():
208
+ # Check for overlapping multimapped reads
209
+ if any(count > 1 for count in end_counts.values()):
210
+ if (
211
+ self.warning_counts["overlapping_multimap"]
212
+ < self.config.max_warning_per_type
213
+ ):
214
+ logger.warning(
215
+ f"Fragment {frag_name} has overlapping multiple mapped alignment "
216
+ f"at site: {variant.chrom}:{variant.pos + 1}, and will not be used"
217
+ )
218
+ self.warning_counts["overlapping_multimap"] += 1
219
+ continue
220
+
221
+ has_ref = frag_name in rdf_map
222
+ has_alt = frag_name in adf_map
223
+
224
+ if has_ref and has_alt:
225
+ counts[CountType.RDF] += fragment_ref_weight
226
+ counts[CountType.ADF] += fragment_alt_weight
227
+ elif has_ref:
228
+ counts[CountType.RDF] += 1
229
+ elif has_alt:
230
+ counts[CountType.ADF] += 1
231
+
232
+ # Store counts
233
+ if sample_name not in variant.base_count:
234
+ variant.base_count[sample_name] = counts
235
+ else:
236
+ variant.base_count[sample_name] += counts
237
+
238
    def count_bases_dnp(
        self, variant: VariantEntry, alignments: list[pysam.AlignedSegment], sample_name: str
    ) -> None:
        """
        Count bases for DNP (di-nucleotide polymorphism) variants.

        Accumulates read-level depth/ref/alt counts (and, when enabled,
        fragment-level counts) into ``variant.base_count[sample_name]``. A read
        contributes only when it covers all ``variant.dnp_len`` reference bases
        with match operations and every covered base meets the base-quality
        threshold.

        Args:
            variant: Variant entry to count
            alignments: List of alignments overlapping the variant
            sample_name: Sample name for storing counts
        """
        counts = np.zeros(len(CountType), dtype=np.float32)

        # Fragment tracking: fragment name -> {read end (1 or 2) -> hit count}.
        dpf_map: dict[str, dict[int, int]] = {}
        rdf_map: dict[str, dict[int, int]] = {}
        adf_map: dict[str, dict[int, int]] = {}

        for aln in alignments:
            # Skip alignments that do not fully span the DNP interval.
            if (aln.reference_start is not None and aln.reference_start > variant.pos) or (
                aln.reference_end is not None
                and aln.reference_end <= variant.pos + variant.dnp_len - 1
            ):
                continue

            # Collect (read index, base) for each reference position of the DNP.
            read_bases = []
            for read_idx, ref_idx in aln.get_aligned_pairs(matches_only=True):
                if ref_idx is not None and variant.pos <= ref_idx < variant.pos + variant.dnp_len:
                    if aln.query_sequence is not None:
                        read_bases.append((read_idx, aln.query_sequence[read_idx]))

            if len(read_bases) != variant.dnp_len:
                continue  # DNP not fully covered by match operations

            if aln.query_sequence is None or aln.query_qualities is None:
                continue

            # Observed DNP sequence and its weakest base quality.
            dnp_seq = "".join([base for _, base in read_bases]).upper()
            min_qual = min([aln.query_qualities[idx] for idx, _ in read_bases])

            if min_qual < self.config.base_quality_threshold:
                continue

            # Total depth (and positive-strand depth).
            counts[CountType.DP] += 1
            if not aln.is_reverse:
                counts[CountType.DPP] += 1

            # Track fragment occurrences per read end.
            end_no = 1 if aln.is_read1 else 2
            if self.config.output_fragment_count:
                if aln.query_name is not None:
                    if aln.query_name not in dpf_map:
                        dpf_map[aln.query_name] = {}
                    dpf_map[aln.query_name][end_no] = dpf_map[aln.query_name].get(end_no, 0) + 1

            # Ref/alt support by exact sequence match.
            if dnp_seq == variant.ref:
                counts[CountType.RD] += 1
                if not aln.is_reverse:
                    counts[CountType.RDP] += 1
                if self.config.output_fragment_count and aln.query_name is not None:
                    if aln.query_name not in rdf_map:
                        rdf_map[aln.query_name] = {}
                    rdf_map[aln.query_name][end_no] = rdf_map[aln.query_name].get(end_no, 0) + 1
            elif dnp_seq == variant.alt:
                counts[CountType.AD] += 1
                if not aln.is_reverse:
                    counts[CountType.ADP] += 1
                if self.config.output_fragment_count and aln.query_name is not None:
                    if aln.query_name not in adf_map:
                        adf_map[aln.query_name] = {}
                    adf_map[aln.query_name][end_no] = adf_map[aln.query_name].get(end_no, 0) + 1

        # Derive fragment-level counts from the per-read-end tallies.
        if self.config.output_fragment_count:
            counts[CountType.DPF] = len(dpf_map)

            # A fragment seen with BOTH alleles contributes 0.5 to each side
            # when fractional weighting is enabled, otherwise nothing.
            fragment_ref_weight = 0.5 if self.config.fragment_fractional_weight else 0
            fragment_alt_weight = 0.5 if self.config.fragment_fractional_weight else 0

            for frag_name, end_counts in dpf_map.items():
                # Skip fragments with overlapping multi-mapped read ends.
                if any(count > 1 for count in end_counts.values()):
                    continue

                has_ref = frag_name in rdf_map
                has_alt = frag_name in adf_map

                if has_ref and has_alt:
                    counts[CountType.RDF] += fragment_ref_weight
                    counts[CountType.ADF] += fragment_alt_weight
                elif has_ref:
                    counts[CountType.RDF] += 1
                elif has_alt:
                    counts[CountType.ADF] += 1

        # Accumulate into per-sample counts.
        if sample_name not in variant.base_count:
            variant.base_count[sample_name] = counts
        else:
            variant.base_count[sample_name] += counts
+
342
+ def count_bases_indel(
343
+ self, variant: VariantEntry, alignments: list[pysam.AlignedSegment], sample_name: str
344
+ ) -> None:
345
+ """
346
+ Count bases for indel variants using DMP method.
347
+
348
+ Args:
349
+ variant: Variant entry to count
350
+ alignments: List of alignments overlapping the variant
351
+ sample_name: Sample name for storing counts
352
+ """
353
+ counts = np.zeros(len(CountType), dtype=np.float32)
354
+
355
+ dpf_map: dict[str, dict[int, int]] = {}
356
+ rdf_map: dict[str, dict[int, int]] = {}
357
+ adf_map: dict[str, dict[int, int]] = {}
358
+
359
+ for aln in alignments:
360
+ # Check if alignment overlaps the indel region
361
+ if (aln.reference_start is not None and aln.reference_start > variant.pos + 1) or (
362
+ aln.reference_end is not None and aln.reference_end <= variant.pos
363
+ ):
364
+ continue
365
+
366
+ # Parse CIGAR to find indels at the variant position
367
+ matched_indel = False
368
+ ref_pos = aln.reference_start
369
+ read_pos = 0
370
+
371
+ if aln.cigartuples is None:
372
+ continue
373
+
374
+ for i, (cigar_op, cigar_len) in enumerate(aln.cigartuples):
375
+ if ref_pos is not None and ref_pos > variant.pos + 1:
376
+ break
377
+
378
+ if cigar_op == 0: # Match/mismatch (M)
379
+ # Check if variant position is at the end of this match
380
+ if ref_pos is not None and ref_pos + cigar_len - 1 == variant.pos:
381
+ # Look ahead for insertion or deletion
382
+ if i + 1 < len(aln.cigartuples):
383
+ next_op, next_len = aln.cigartuples[i + 1]
384
+
385
+ if next_op == 1 and variant.insertion: # Insertion (I)
386
+ expected_ins_len = len(variant.alt) - len(variant.ref)
387
+ if next_len == expected_ins_len:
388
+ # Check if insertion sequence matches
389
+ if aln.query_sequence is not None:
390
+ ins_seq = aln.query_sequence[
391
+ read_pos + cigar_len : read_pos + cigar_len + next_len
392
+ ]
393
+ expected_ins_seq = variant.alt[len(variant.ref) :]
394
+ if ins_seq == expected_ins_seq:
395
+ matched_indel = True
396
+ elif next_op == 2 and variant.deletion: # Deletion (D)
397
+ expected_del_len = len(variant.ref) - len(variant.alt)
398
+ if next_len == expected_del_len:
399
+ matched_indel = True
400
+
401
+ # Check if we can count depth at pos+1
402
+ if ref_pos is not None and ref_pos <= variant.pos + 1 < ref_pos + cigar_len:
403
+ # Check if query qualities are available
404
+ if aln.query_qualities is None:
405
+ continue
406
+
407
+ # Get base quality at pos+1
408
+ offset = variant.pos + 1 - ref_pos
409
+ qual = aln.query_qualities[read_pos + offset]
410
+
411
+ if qual >= self.config.base_quality_threshold:
412
+ # Count total depth
413
+ counts[CountType.DP] += 1
414
+ if not aln.is_reverse:
415
+ counts[CountType.DPP] += 1
416
+
417
+ # Track fragment
418
+ end_no = 1 if aln.is_read1 else 2
419
+ if self.config.output_fragment_count:
420
+ if aln.query_name is not None:
421
+ if aln.query_name not in dpf_map:
422
+ dpf_map[aln.query_name] = {}
423
+ dpf_map[aln.query_name][end_no] = (
424
+ dpf_map[aln.query_name].get(end_no, 0) + 1
425
+ )
426
+
427
+ # Count ref/alt based on matched indel
428
+ if matched_indel:
429
+ counts[CountType.AD] += 1
430
+ if not aln.is_reverse:
431
+ counts[CountType.ADP] += 1
432
+ if self.config.output_fragment_count and aln.query_name is not None:
433
+ if aln.query_name not in adf_map:
434
+ adf_map[aln.query_name] = {}
435
+ adf_map[aln.query_name][end_no] = (
436
+ adf_map[aln.query_name].get(end_no, 0) + 1
437
+ )
438
+ else:
439
+ counts[CountType.RD] += 1
440
+ if not aln.is_reverse:
441
+ counts[CountType.RDP] += 1
442
+ if self.config.output_fragment_count and aln.query_name is not None:
443
+ if aln.query_name not in rdf_map:
444
+ rdf_map[aln.query_name] = {}
445
+ rdf_map[aln.query_name][end_no] = (
446
+ rdf_map[aln.query_name].get(end_no, 0) + 1
447
+ )
448
+
449
+ if ref_pos is not None:
450
+ ref_pos += cigar_len
451
+ read_pos += cigar_len
452
+ elif cigar_op == 1: # Insertion (I)
453
+ read_pos += cigar_len
454
+ elif cigar_op == 2: # Deletion (D)
455
+ if ref_pos is not None:
456
+ ref_pos += cigar_len
457
+ elif cigar_op == 4: # Soft clip (S)
458
+ read_pos += cigar_len
459
+ elif cigar_op == 5: # Hard clip (H)
460
+ pass
461
+ elif cigar_op == 3: # Skipped region (N)
462
+ if ref_pos is not None:
463
+ ref_pos += cigar_len
464
+
465
+ # Calculate fragment counts
466
+ if self.config.output_fragment_count:
467
+ counts[CountType.DPF] = len(dpf_map)
468
+
469
+ fragment_ref_weight = 0.5 if self.config.fragment_fractional_weight else 0
470
+ fragment_alt_weight = 0.5 if self.config.fragment_fractional_weight else 0
471
+
472
+ for frag_name, end_counts in dpf_map.items():
473
+ if any(count > 1 for count in end_counts.values()):
474
+ continue
475
+
476
+ has_ref = frag_name in rdf_map
477
+ has_alt = frag_name in adf_map
478
+
479
+ if has_ref and has_alt:
480
+ counts[CountType.RDF] += fragment_ref_weight
481
+ counts[CountType.ADF] += fragment_alt_weight
482
+ elif has_ref:
483
+ counts[CountType.RDF] += 1
484
+ elif has_alt:
485
+ counts[CountType.ADF] += 1
486
+
487
+ # Store counts
488
+ if sample_name not in variant.base_count:
489
+ variant.base_count[sample_name] = counts
490
+ else:
491
+ variant.base_count[sample_name] += counts
492
+
493
+ def count_bases_generic(
494
+ self, variant: VariantEntry, alignments: list[pysam.AlignedSegment], sample_name: str
495
+ ) -> None:
496
+ """
497
+ Generic counting algorithm that works for all variant types.
498
+
499
+ This algorithm extracts the alignment allele by parsing CIGAR and comparing
500
+ directly to ref/alt. Works better for complex variants but may give slightly
501
+ different results than the specialized counting methods.
502
+
503
+ This is equivalent to the C++ baseCountGENERIC function.
504
+
505
+ Args:
506
+ variant: Variant to count
507
+ alignments: List of alignments overlapping the variant
508
+ sample_name: Sample name
509
+ """
510
+ counts = np.zeros(len(CountType), dtype=np.float32)
511
+ dpf_map: dict[str, dict[int, int]] = {}
512
+ rdf_map: dict[str, dict[int, int]] = {}
513
+ adf_map: dict[str, dict[int, int]] = {}
514
+
515
+ for aln in alignments:
516
+ if self.filter_alignment(aln):
517
+ continue
518
+
519
+ # Check if alignment overlaps variant region
520
+ if aln.reference_end is not None and (
521
+ aln.reference_end <= variant.pos or aln.reference_start > variant.end_pos
522
+ ):
523
+ continue
524
+
525
+ # Extract alignment allele by parsing CIGAR
526
+ alignment_allele = ""
527
+ cur_bq = float("inf")
528
+ partially_cover = False
529
+
530
+ if aln.reference_start > variant.pos or (
531
+ aln.reference_end is not None and aln.reference_end < variant.end_pos
532
+ ):
533
+ partially_cover = True
534
+
535
+ # Check if query sequence and qualities are available
536
+ if aln.query_sequence is None or aln.query_qualities is None:
537
+ continue
538
+
539
+ # Parse CIGAR to extract allele
540
+ ref_pos = aln.reference_start
541
+ read_pos = 0
542
+ additional_insertion = False
543
+
544
+ if aln.cigartuples is None:
545
+ continue
546
+
547
+ for i, (op, length) in enumerate(aln.cigartuples):
548
+ if (
549
+ aln.reference_end is not None
550
+ and ref_pos > variant.end_pos
551
+ and not additional_insertion
552
+ ):
553
+ break
554
+
555
+ if op == 0: # M (match/mismatch)
556
+ if ref_pos is not None and ref_pos + length - 1 >= variant.pos:
557
+ start_idx = read_pos + max(variant.pos, ref_pos) - ref_pos
558
+ str_len = min(
559
+ length,
560
+ min(variant.end_pos, ref_pos + length - 1)
561
+ + 1
562
+ - max(variant.pos, ref_pos),
563
+ )
564
+ alignment_allele += aln.query_sequence[start_idx : start_idx + str_len]
565
+
566
+ # Get minimum base quality
567
+ for bq_idx in range(str_len):
568
+ cur_bq = min(cur_bq, aln.query_qualities[start_idx + bq_idx])
569
+
570
+ if ref_pos is not None:
571
+ ref_pos += length
572
+ read_pos += length
573
+
574
+ # Allow additional insertion if M falls at variant end
575
+ if ref_pos is not None and ref_pos == variant.end_pos + 1:
576
+ if i + 1 < len(aln.cigartuples) and aln.cigartuples[i + 1][0] == 1:
577
+ additional_insertion = True
578
+
579
+ elif op == 1: # I (insertion)
580
+ if ref_pos is not None and ref_pos >= variant.pos:
581
+ alignment_allele += aln.query_sequence[read_pos : read_pos + length]
582
+ for bq_idx in range(length):
583
+ cur_bq = min(cur_bq, aln.query_qualities[read_pos + bq_idx])
584
+ read_pos += length
585
+ additional_insertion = False
586
+
587
+ elif op == 4: # S (soft clip)
588
+ read_pos += length
589
+
590
+ elif op in [2, 3]: # D or N (deletion/skip)
591
+ if (
592
+ aln.reference_end is not None
593
+ and ref_pos is not None
594
+ and ref_pos + length - 1 > variant.end_pos
595
+ ):
596
+ alignment_allele = "U" # Unmatched deletion
597
+ if ref_pos is not None:
598
+ ref_pos += length
599
+
600
+ # Allow additional insertion if D/N falls at variant end
601
+ if ref_pos is not None and ref_pos == variant.end_pos + 1:
602
+ if i + 1 < len(aln.cigartuples) and aln.cigartuples[i + 1][0] == 1:
603
+ additional_insertion = True
604
+
605
+ # Check base quality threshold
606
+ if cur_bq < self.config.base_quality_threshold:
607
+ continue
608
+
609
+ # Count depth
610
+ counts[CountType.DP] += 1
611
+ if not aln.is_reverse:
612
+ counts[CountType.DPP] += 1
613
+
614
+ # Track fragment
615
+ end_no = 1 if aln.is_read1 else 2
616
+ frag_name = aln.query_name
617
+
618
+ if self.config.output_fragment_count:
619
+ if frag_name is not None:
620
+ if frag_name not in dpf_map:
621
+ dpf_map[frag_name] = {}
622
+ if end_no not in dpf_map[frag_name]:
623
+ dpf_map[frag_name][end_no] = 0
624
+ dpf_map[frag_name][end_no] += 1
625
+
626
+ # Count ref/alt (skip if partially covered)
627
+ if not partially_cover:
628
+ if alignment_allele == variant.ref:
629
+ counts[CountType.RD] += 1
630
+ if not aln.is_reverse:
631
+ counts[CountType.RDP] += 1
632
+
633
+ if self.config.output_fragment_count and frag_name is not None:
634
+ if frag_name not in rdf_map:
635
+ rdf_map[frag_name] = {}
636
+ if end_no not in rdf_map[frag_name]:
637
+ rdf_map[frag_name][end_no] = 0
638
+ rdf_map[frag_name][end_no] += 1
639
+
640
+ elif alignment_allele == variant.alt:
641
+ counts[CountType.AD] += 1
642
+ if not aln.is_reverse:
643
+ counts[CountType.ADP] += 1
644
+
645
+ if self.config.output_fragment_count and frag_name is not None:
646
+ if frag_name not in adf_map:
647
+ adf_map[frag_name] = {}
648
+ if end_no not in adf_map[frag_name]:
649
+ adf_map[frag_name][end_no] = 0
650
+ adf_map[frag_name][end_no] += 1
651
+
652
+ # Calculate fragment counts
653
+ if self.config.output_fragment_count:
654
+ counts[CountType.DPF] = len(dpf_map)
655
+
656
+ fragment_ref_weight = 0.5 if self.config.fragment_fractional_weight else 0
657
+ fragment_alt_weight = 0.5 if self.config.fragment_fractional_weight else 0
658
+
659
+ for frag_name, end_counts in dpf_map.items():
660
+ # Check for overlapping multimapped reads
661
+ overlap_multimap = False
662
+ for count in end_counts.values():
663
+ if count > 1:
664
+ if (
665
+ self.warning_counts["overlapping_multimap"]
666
+ < self.config.max_warning_per_type
667
+ ):
668
+ logger.warning(
669
+ f"Fragment {frag_name} has overlapping multiple mapped alignment "
670
+ f"at site: {variant.chrom}:{variant.pos}"
671
+ )
672
+ self.warning_counts["overlapping_multimap"] += 1
673
+ overlap_multimap = True
674
+ break
675
+
676
+ if overlap_multimap:
677
+ continue
678
+
679
+ # Count fragment ref/alt
680
+ has_ref = frag_name in rdf_map
681
+ has_alt = frag_name in adf_map
682
+
683
+ if has_ref and has_alt:
684
+ # Both ref and alt in fragment
685
+ counts[CountType.RDF] += fragment_ref_weight
686
+ counts[CountType.ADF] += fragment_alt_weight
687
+ elif has_ref:
688
+ counts[CountType.RDF] += 1
689
+ elif has_alt:
690
+ counts[CountType.ADF] += 1
691
+
692
+ def count_bases_snp_numba(
693
+ self, variant: VariantEntry, alignments: list[pysam.AlignedSegment], sample_name: str
694
+ ) -> None:
695
+ """
696
+ Count SNP bases using fast numba_counter for simple SNPs.
697
+
698
+ This method converts pysam alignments to NumPy arrays and uses
699
+ the optimized numba_counter for maximum speed on simple SNPs.
700
+
701
+ Args:
702
+ variant: Variant entry to count
703
+ alignments: List of alignments overlapping the variant
704
+ sample_name: Sample name for storing counts
705
+ """
706
+ try:
707
+ from .numba_counter import count_snp_base
708
+
709
+ # Convert alignments to NumPy arrays for numba_counter
710
+ query_bases = []
711
+ query_qualities = []
712
+ reference_positions = []
713
+ is_reverse_flags = []
714
+
715
+ for aln in alignments:
716
+ # Find the base at variant position
717
+ for read_idx, ref_idx in aln.get_aligned_pairs(matches_only=False):
718
+ if ref_idx == variant.pos:
719
+ if (
720
+ aln.query_sequence is not None
721
+ and aln.query_qualities is not None
722
+ and read_idx is not None
723
+ ):
724
+ query_bases.append(aln.query_sequence[read_idx])
725
+ query_qualities.append(aln.query_qualities[read_idx])
726
+ reference_positions.append(ref_idx)
727
+ is_reverse_flags.append(aln.is_reverse)
728
+ break
729
+
730
+ if not query_bases:
731
+ # No valid bases found, fallback to regular counting
732
+ raise ValueError("No valid bases found at variant position")
733
+
734
+ # Convert to NumPy arrays
735
+ import numpy as np
736
+
737
+ bases_array = np.array(query_bases, dtype="U1") # Unicode string array
738
+ quals_array = np.array(query_qualities, dtype=np.uint8)
739
+ pos_array = np.array(reference_positions, dtype=np.int32)
740
+ reverse_array = np.array(is_reverse_flags, dtype=bool)
741
+
742
+ # Use numba_counter for fast counting
743
+ dp, rd, ad, dpp, rdp, adp = count_snp_base(
744
+ bases_array,
745
+ quals_array,
746
+ pos_array,
747
+ reverse_array,
748
+ variant.pos,
749
+ variant.ref,
750
+ variant.alt,
751
+ self.config.base_quality_threshold,
752
+ )
753
+
754
+ # Convert to our count format
755
+ counts = np.zeros(len(CountType), dtype=np.float32)
756
+ counts[CountType.DP] = dp
757
+ counts[CountType.RD] = rd
758
+ counts[CountType.AD] = ad
759
+ counts[CountType.DPP] = dpp
760
+ counts[CountType.RDP] = rdp
761
+ counts[CountType.ADP] = adp
762
+
763
+ # Handle fragment counting if enabled
764
+ if self.config.output_fragment_count:
765
+ # Calculate fragment counts from alignments
766
+ fragment_counts = self._calculate_fragment_counts_numba(alignments, variant)
767
+ counts[CountType.DPF] = fragment_counts["dpf"]
768
+ counts[CountType.RDF] = fragment_counts["rdf"]
769
+ counts[CountType.ADF] = fragment_counts["adf"]
770
+
771
+ # Calculate strand bias for this sample (on-the-fly during output)
772
+ # Note: We don't store strand bias in variant object anymore, calculate during output
773
+
774
+ # Calculate fragment strand bias if fragment counting is enabled
775
+ if self.config.output_fragment_count:
776
+ # Note: Fragment strand bias uses same forward/reverse as normal strand bias
777
+ # since fragments inherit strand orientation from reads
778
+ pass
779
+ if sample_name not in variant.base_count:
780
+ variant.base_count[sample_name] = counts
781
+ else:
782
+ variant.base_count[sample_name] += counts
783
+
784
+ except ImportError:
785
+ # numba not available, fallback to regular counting
786
+ logger.warning("numba_counter not available, falling back to regular SNP counting")
787
+ self.count_bases_snp(variant, alignments, sample_name)
788
+ except Exception as e:
789
+ # Any other error, fallback to regular counting
790
+ logger.warning(f"numba_counter failed: {e}, falling back to regular SNP counting")
791
+ self.count_bases_snp(variant, alignments, sample_name)
792
+
793
+ def _calculate_fragment_counts_numba(
794
+ self, alignments: list[pysam.AlignedSegment], variant: VariantEntry
795
+ ) -> dict:
796
+ """
797
+ Calculate fragment counts for numba_counter results.
798
+
799
+ Args:
800
+ alignments: Alignments for fragment counting
801
+ variant: Variant for context
802
+
803
+ Returns:
804
+ Dictionary with fragment counts
805
+ """
806
+ dpf_map: dict[str, dict[int, int]] = {}
807
+ rdf_map: dict[str, dict[int, int]] = {}
808
+ adf_map: dict[str, dict[int, int]] = {}
809
+
810
+ # Group by fragment and track ref/alt
811
+ for aln in alignments:
812
+ if aln.query_name is None:
813
+ continue
814
+
815
+ # Find base at variant position for this alignment
816
+ base = None
817
+ for read_idx, ref_idx in aln.get_aligned_pairs(matches_only=False):
818
+ if ref_idx == variant.pos and read_idx is not None:
819
+ if aln.query_sequence is not None:
820
+ base = aln.query_sequence[read_idx].upper()
821
+ break
822
+
823
+ if base is None:
824
+ continue
825
+
826
+ end_no = 1 if aln.is_read1 else 2
827
+ frag_name = aln.query_name
828
+
829
+ # Track depth per fragment
830
+ if frag_name not in dpf_map:
831
+ dpf_map[frag_name] = {}
832
+ dpf_map[frag_name][end_no] = dpf_map[frag_name].get(end_no, 0) + 1
833
+
834
+ # Track ref/alt per fragment
835
+ if base == variant.ref:
836
+ if frag_name not in rdf_map:
837
+ rdf_map[frag_name] = {}
838
+ rdf_map[frag_name][end_no] = rdf_map[frag_name].get(end_no, 0) + 1
839
+ elif base == variant.alt:
840
+ if frag_name not in adf_map:
841
+ adf_map[frag_name] = {}
842
+ adf_map[frag_name][end_no] = adf_map[frag_name].get(end_no, 0) + 1
843
+
844
+ # Calculate final fragment counts
845
+ dpf = len(dpf_map)
846
+
847
+ fragment_ref_weight = 0.5 if self.config.fragment_fractional_weight else 1.0
848
+ fragment_alt_weight = 0.5 if self.config.fragment_fractional_weight else 1.0
849
+
850
+ rdf = 0.0
851
+ adf = 0.0
852
+
853
+ for frag_name, end_counts in dpf_map.items():
854
+ # Check for overlapping multimapped reads
855
+ if any(count > 1 for count in end_counts.values()):
856
+ continue
857
+
858
+ has_ref = frag_name in rdf_map
859
+ has_alt = frag_name in adf_map
860
+
861
+ if has_ref and has_alt:
862
+ rdf += fragment_ref_weight
863
+ adf += fragment_alt_weight
864
+ elif has_ref:
865
+ rdf += 1.0
866
+ elif has_alt:
867
+ adf += 1.0
868
+
869
+ return {"dpf": dpf, "rdf": rdf, "adf": adf}
870
+
871
+ def calculate_strand_bias(
872
+ self,
873
+ ref_forward: int,
874
+ ref_reverse: int,
875
+ alt_forward: int,
876
+ alt_reverse: int,
877
+ min_depth: int = 10,
878
+ ) -> tuple[float, float, str]:
879
+ """
880
+ Calculate strand bias using Fisher's exact test.
881
+
882
+ Args:
883
+ ref_forward: Reference allele count on forward strand
884
+ ref_reverse: Reference allele count on reverse strand
885
+ alt_forward: Alternate allele count on forward strand
886
+ alt_reverse: Alternate allele count on reverse strand
887
+ min_depth: Minimum total depth to calculate bias
888
+
889
+ Returns:
890
+ Tuple of (p_value, odds_ratio, bias_direction)
891
+ """
892
+ try:
893
+ import numpy as np
894
+ from scipy.stats import fisher_exact
895
+
896
+ # Check minimum depth requirement
897
+ total_depth = ref_forward + ref_reverse + alt_forward + alt_reverse
898
+ if total_depth < min_depth:
899
+ return 1.0, 1.0, "insufficient_depth"
900
+
901
+ # Create 2x2 contingency table
902
+ # [[ref_forward, ref_reverse],
903
+ # [alt_forward, alt_reverse]]
904
+ table = np.array([[ref_forward, ref_reverse], [alt_forward, alt_reverse]])
905
+
906
+ # Fisher's exact test
907
+ odds_ratio, p_value = fisher_exact(table, alternative="two-sided")
908
+
909
+ # Determine bias direction
910
+ total_forward = ref_forward + alt_forward
911
+ total_reverse = ref_reverse + alt_reverse
912
+
913
+ if total_forward > 0 and total_reverse > 0:
914
+ forward_ratio = ref_forward / total_forward if total_forward > 0 else 0
915
+ reverse_ratio = ref_reverse / total_reverse if total_reverse > 0 else 0
916
+
917
+ if forward_ratio > reverse_ratio + 0.1: # 10% threshold
918
+ bias_direction = "forward"
919
+ elif reverse_ratio > forward_ratio + 0.1:
920
+ bias_direction = "reverse"
921
+ else:
922
+ bias_direction = "none"
923
+ else:
924
+ bias_direction = "none"
925
+
926
+ return p_value, odds_ratio, bias_direction
927
+
928
+ except ImportError:
929
+ logger.warning("scipy not available for strand bias calculation")
930
+ return 1.0, 1.0, "scipy_unavailable"
931
+ except Exception as e:
932
+ logger.warning(f"Error calculating strand bias: {e}")
933
+ return 1.0, 1.0, "error"
934
+
935
+ def get_strand_counts_for_sample(
936
+ self, variant: VariantEntry, sample_name: str
937
+ ) -> tuple[int, int, int, int]:
938
+ """
939
+ Get strand-specific counts for a sample.
940
+
941
+ Args:
942
+ variant: Variant with counts
943
+ sample_name: Sample name
944
+
945
+ Returns:
946
+ Tuple of (ref_forward, ref_reverse, alt_forward, alt_reverse)
947
+ """
948
+ # Get strand-specific counts
949
+ ref_forward = int(variant.get_count(sample_name, CountType.RDP))
950
+ alt_forward = int(variant.get_count(sample_name, CountType.ADP))
951
+
952
+ # Calculate reverse strand counts (total - forward)
953
+ ref_total = int(variant.get_count(sample_name, CountType.RD))
954
+ alt_total = int(variant.get_count(sample_name, CountType.AD))
955
+
956
+ ref_reverse = max(0, ref_total - ref_forward)
957
+ alt_reverse = max(0, alt_total - alt_forward)
958
+
959
+ return ref_forward, ref_reverse, alt_forward, alt_reverse
960
+
961
+ def is_simple_snp(self, variant: VariantEntry) -> bool:
962
+ """
963
+ Determine if variant is a simple SNP that can use fast numba_counter.
964
+
965
+ Simple SNPs are single-base substitutions that don't require complex
966
+ CIGAR parsing for accurate counting.
967
+
968
+ Args:
969
+ variant: Variant to analyze
970
+
971
+ Returns:
972
+ True if variant can use numba_counter, False if needs counter.py
973
+ """
974
+ # Must be a SNP (not insertion, deletion, or DNP)
975
+ if not variant.snp:
976
+ return False
977
+
978
+ # Must have single-base ref and alt alleles
979
+ if len(variant.ref) != 1 or len(variant.alt) != 1:
980
+ return False
981
+
982
+ # Check if we have alignments to analyze
983
+ if not hasattr(variant, "alignments") or not variant.alignments:
984
+ return False
985
+
986
+ # For simple SNPs, numba_counter should work well
987
+ # We'll validate this assumption with position checking
988
+ return True
989
+
990
+ def validate_variant_position(self, variant: VariantEntry, alignments: list) -> bool:
991
+ """
992
+ Validate that variant position makes sense for counting.
993
+
994
+ Args:
995
+ variant: Variant to validate
996
+ alignments: Alignments overlapping variant
997
+
998
+ Returns:
999
+ True if position is valid for counting
1000
+ """
1001
+ if not alignments:
1002
+ return False
1003
+
1004
+ # Check if any alignment actually overlaps the variant position
1005
+ variant_covered = False
1006
+ for aln in alignments:
1007
+ if aln.reference_end is not None and aln.reference_start is not None:
1008
+ if aln.reference_start <= variant.pos <= aln.reference_end:
1009
+ variant_covered = True
1010
+ break
1011
+
1012
+ return variant_covered
1013
+
1014
+ def smart_count_variant(
1015
+ self, variant: VariantEntry, alignments: list[pysam.AlignedSegment], sample_name: str
1016
+ ) -> None:
1017
+ """
1018
+ Smart counting strategy that chooses optimal algorithm based on variant complexity.
1019
+
1020
+ Strategy:
1021
+ - Simple SNPs: Use numba_counter for speed (50-100x faster)
1022
+ - Complex variants: Use counter.py for accuracy (better CIGAR handling)
1023
+
1024
+ Args:
1025
+ variant: Variant entry to count
1026
+ alignments: List of alignments overlapping the variant
1027
+ sample_name: Sample name for storing counts
1028
+ """
1029
+ # Validate variant position first
1030
+ if not self.validate_variant_position(variant, alignments):
1031
+ logger.warning(
1032
+ f"No alignments cover variant position for {variant.chrom}:{variant.pos + 1}"
1033
+ )
1034
+ return
1035
+
1036
+ # Choose counting strategy based on variant complexity and user preference
1037
+ if self.config.generic_counting:
1038
+ # User explicitly requested generic counting for all variants
1039
+ self.count_bases_generic(variant, alignments, sample_name)
1040
+ elif self.is_simple_snp(variant):
1041
+ try:
1042
+ # Try numba_counter for simple SNPs
1043
+ self.count_bases_snp_numba(variant, alignments, sample_name)
1044
+ except Exception as e:
1045
+ logger.warning(
1046
+ f"numba_counter failed for {variant.chrom}:{variant.pos + 1}, falling back to counter.py: {e}"
1047
+ )
1048
+ # Fallback to counter.py
1049
+ self.count_bases_snp(variant, alignments, sample_name)
1050
+ else:
1051
+ # Use counter.py for complex variants
1052
+ if variant.dnp:
1053
+ self.count_bases_dnp(variant, alignments, sample_name)
1054
+ elif variant.insertion or variant.deletion:
1055
+ self.count_bases_indel(variant, alignments, sample_name)
1056
+ else:
1057
+ # Default to generic counting for unknown variant types
1058
+ self.count_bases_generic(variant, alignments, sample_name)
1059
+
1060
+ def count_variant(
1061
+ self, variant: VariantEntry, alignments: list[pysam.AlignedSegment], sample_name: str
1062
+ ) -> None:
1063
+ """
1064
+ Count bases for a variant (dispatches to appropriate method).
1065
+
1066
+ Now uses smart counting strategy by default.
1067
+
1068
+ Args:
1069
+ variant: Variant to count
1070
+ alignments: List of alignments overlapping the variant
1071
+ sample_name: Sample name
1072
+ """
1073
+ # Use smart counting strategy for optimal performance/accuracy balance
1074
+ self.smart_count_variant(variant, alignments, sample_name)